new script to find optimal learning rate (order of magnitude)
Former-commit-id: 362ffca24b4a13e9fcb2dccaf6b64372b960b83b
Javi Ribera committed Oct 18, 2018
1 parent 226a5d6 commit 9a1056b
Showing 1 changed file with 175 additions and 0 deletions.
object-locator/find_lr.py
@@ -0,0 +1,175 @@
# Copyright © 2018 The Board of Trustees of Purdue University.
# All rights reserved.
#
# This source code is not to be distributed or modified
# without the written permission of Edward J. Delp at Purdue University
# Contact information: ace@ecn.purdue.edu
#
# ___Original___
# https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html
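#
# The script runs the learning-rate "range test" described in the post
# above: the learning rate grows exponentially during one pass over the
# training set, the smoothed loss is recorded for every mini-batch, and
# the resulting loss-vs-lr curve suggests the order of magnitude of a
# good learning rate.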
# =====================================================================

from __future__ import print_function

import math

import matplotlib
matplotlib.use('Agg')  # non-interactive backend; must be selected before pyplot is imported
import numpy as np
import torch
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
from matplotlib import pyplot as plt
from peterpy import peter
from ballpark import ballpark

from . import losses
from .models import unet_model
from .data import CSVDataset
from .data import csv_collator
from .data import RandomHorizontalFlipImageAndLabel
from .data import RandomVerticalFlipImageAndLabel
from .data import ScaleImageAndLabel
from . import argparser


# Parse command line arguments
args = argparser.parse_command_args('training')

# Tensor type to use, select CUDA or not
torch.set_default_dtype(torch.float32)
device_cpu = torch.device('cpu')
device = torch.device('cuda') if args.cuda else device_cpu

# Set seeds
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed_all(args.seed)

# Data loading code
training_transforms = []
if not args.no_data_augm:
    training_transforms += [RandomHorizontalFlipImageAndLabel(p=0.5)]
    training_transforms += [RandomVerticalFlipImageAndLabel(p=0.5)]
training_transforms += [ScaleImageAndLabel(size=(args.height, args.width))]
training_transforms += [transforms.ToTensor()]
training_transforms += [transforms.Normalize((0.5, 0.5, 0.5),
                                             (0.5, 0.5, 0.5))]
trainset = CSVDataset(args.train_dir,
                      transforms=transforms.Compose(training_transforms),
                      max_dataset_size=args.max_trainset_size)
trainset_loader = DataLoader(trainset,
                             batch_size=args.batch_size,
                             drop_last=args.drop_last_batch,
                             shuffle=True,
                             num_workers=args.nThreads,
                             collate_fn=csv_collator)

# Model
with peter('Building network'):
    model = unet_model.UNet(3, 1,
                            height=args.height,
                            width=args.width,
                            known_n_points=args.n_points)
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f" with {ballpark(num_params)} trainable parameters. ", end='')
model = nn.DataParallel(model)
model.to(device)


# Loss functions
loss_regress = nn.SmoothL1Loss()
loss_loc = losses.WeightedHausdorffDistance(resized_height=args.height,
                                            resized_width=args.width,
                                            p=args.p,
                                            return_2_terms=True,
                                            device=device)
l1_loss = nn.L1Loss(size_average=False)  # not used by this script
mse_loss = nn.MSELoss(reduce=False)  # not used by this script

optimizer = optim.SGD(model.parameters(),
                      lr=999)  # placeholder; find_lr() overwrites it at every step
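
# How the sweep below works: with N = len(trainset_loader) - 1
# mini-batches, the per-batch factor
#   mult = (final_value / init_value) ** (1 / N)
# satisfies init_value * mult**N == final_value, so after i batches the
# learning rate is
#   lr_i = init_value * (final_value / init_value) ** (i / N),
# i.e. the lr covers [init_value, final_value] evenly on a log scale.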


def find_lr(init_value=1e-6, final_value=1e-3, beta=0.7):
    num = len(trainset_loader) - 1
    mult = (final_value / init_value) ** (1 / num)
    lr = init_value
    optimizer.param_groups[0]['lr'] = lr
    avg_loss = 0.
    best_loss = 0.
    batch_num = 0
    losses_list = []  # renamed so it does not shadow the imported `losses` module
    log_lrs = []
    for imgs, dicts in tqdm(trainset_loader):
        batch_num += 1

        # Pull info from this batch and move to device
        imgs = imgs.to(device)
        target_locations = [dictt['locations'].to(device)
                            for dictt in dicts]
        target_counts = [dictt['count'].to(device)
                         for dictt in dicts]
        target_orig_heights = [dictt['orig_height'].to(device)
                               for dictt in dicts]
        target_orig_widths = [dictt['orig_width'].to(device)
                              for dictt in dicts]

        # Lists -> Tensor batches
        target_counts = torch.stack(target_counts)
        target_orig_heights = torch.stack(target_orig_heights)
        target_orig_widths = torch.stack(target_orig_widths)
        target_orig_sizes = torch.stack((target_orig_heights,
                                         target_orig_widths)).transpose(0, 1)

        # As before, get the loss for this mini-batch of inputs/outputs
        optimizer.zero_grad()
        est_maps, est_counts = model(imgs)
        term1, term2 = loss_loc(est_maps,
                                target_locations,
                                target_orig_sizes)
        est_counts = est_counts.view(-1)
        target_counts = target_counts.view(-1)
        term3 = loss_regress(est_counts, target_counts)
        term3 *= args.lambdaa
        loss = term1 + term2 + term3

        # Compute the smoothed loss
        # (bias-corrected exponential moving average, as in Adam)
        avg_loss = beta * avg_loss + (1 - beta) * loss.item()
        smoothed_loss = avg_loss / (1 - beta ** batch_num)

        # Stop if the loss is exploding
        if batch_num > 1 and smoothed_loss > 4 * best_loss:
            return log_lrs, losses_list

        # Record the best loss
        if smoothed_loss < best_loss or batch_num == 1:
            best_loss = smoothed_loss

        # Store the values
        losses_list.append(smoothed_loss)
        log_lrs.append(math.log10(lr))

        # Do the SGD step
        loss.backward()
        optimizer.step()

        # Update the lr for the next step
        lr *= mult
        optimizer.param_groups[0]['lr'] = lr
    return log_lrs, losses_list

log_lrs, smoothed_losses = find_lr()
plt.plot(log_lrs, smoothed_losses)
plt.xlabel('log10(learning rate)')
plt.ylabel('smoothed loss')
plt.savefig('/data/jprat/plot_beta0.7.png')
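
# To read the plot: a common heuristic (from the blog post referenced in
# the header) is to pick a learning rate roughly one order of magnitude
# below the lr at which the smoothed loss reaches its minimum; beyond
# that point the loss typically starts to diverge.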
