Merge pull request #83 from victoresque/fix_optim
Fix optimizer initialization order
SunQpark authored Dec 2, 2020
2 parents a39a102 + 3c2bbc6 commit 85c5535
Showing 5 changed files with 37 additions and 28 deletions.
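The problem this merge fixes: train.py used to construct the optimizer from model.parameters() first, while BaseTrainer.__init__ moved the model onto the GPU (and wrapped it in DataParallel) only afterwards. The torch.optim documentation recommends moving a model to its device before constructing an optimizer for it, so that the optimized parameters live in a consistent location. This commit therefore pulls the device setup out of BaseTrainer into a prepare_device utility that train.py calls before building the optimizer. A minimal sketch of the two orderings (illustrative code, not taken from this commit):

    import torch
    import torch.nn as nn

    model = nn.Linear(10, 2)

    # before this commit (problematic ordering): the optimizer was built from the
    # CPU model, and the model was moved to the device later, inside BaseTrainer
    #   optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    #   model = model.to('cuda:0')

    # after this commit: move the model first, then build the optimizer
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)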
base/base_trainer.py (24 changes: 1 addition & 23 deletions)
Expand Up @@ -12,12 +12,7 @@ def __init__(self, model, criterion, metric_ftns, optimizer, config):
self.config = config
self.logger = config.get_logger('trainer', config['trainer']['verbosity'])

# setup GPU device if available, move model into configured device
self.device, device_ids = self._prepare_device(config['n_gpu'])
self.model = model.to(self.device)
if len(device_ids) > 1:
self.model = torch.nn.DataParallel(model, device_ids=device_ids)

self.model = model
self.criterion = criterion
self.metric_ftns = metric_ftns
self.optimizer = optimizer
@@ -101,23 +96,6 @@ def train(self):
             if epoch % self.save_period == 0:
                 self._save_checkpoint(epoch, save_best=best)
 
-    def _prepare_device(self, n_gpu_use):
-        """
-        setup GPU device if available, move model into configured device
-        """
-        n_gpu = torch.cuda.device_count()
-        if n_gpu_use > 0 and n_gpu == 0:
-            self.logger.warning("Warning: There\'s no GPU available on this machine,"
-                                "training will be performed on CPU.")
-            n_gpu_use = 0
-        if n_gpu_use > n_gpu:
-            self.logger.warning("Warning: The number of GPU\'s configured to use is {}, but only {} are available "
-                                "on this machine.".format(n_gpu_use, n_gpu))
-            n_gpu_use = n_gpu
-        device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
-        list_ids = list(range(n_gpu_use))
-        return device, list_ids
-
     def _save_checkpoint(self, epoch, save_best=False):
         """
         Saving checkpoints
requirements.txt (5 changes: 5 additions & 0 deletions)
@@ -0,0 +1,5 @@
+torch>=1.1
+torchvision
+numpy
+tqdm
+tensorboard>=1.14
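These pinned dependencies can be installed in the usual way:

    pip install -r requirements.txt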
train.py (9 changes: 8 additions & 1 deletion)
@@ -8,6 +8,7 @@
 import model.model as module_arch
 from parse_config import ConfigParser
 from trainer import Trainer
+from utils import prepare_device
 
 
 # fix random seeds for reproducibility
@@ -28,18 +29,24 @@ def main(config):
     model = config.init_obj('arch', module_arch)
     logger.info(model)
 
+    # prepare for (multi-device) GPU training
+    device, device_ids = prepare_device(config['n_gpu'])
+    model = model.to(device)
+    if len(device_ids) > 1:
+        model = torch.nn.DataParallel(model, device_ids=device_ids)
+
     # get function handles of loss and metrics
     criterion = getattr(module_loss, config['loss'])
     metrics = [getattr(module_metric, met) for met in config['metrics']]
 
     # build optimizer, learning rate scheduler. delete every line containing lr_scheduler to disable the scheduler
     trainable_params = filter(lambda p: p.requires_grad, model.parameters())
     optimizer = config.init_obj('optimizer', torch.optim, trainable_params)
-
     lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer)
 
     trainer = Trainer(model, criterion, metrics, optimizer,
                       config=config,
+                      device=device,
                       data_loader=data_loader,
                       valid_data_loader=valid_data_loader,
                       lr_scheduler=lr_scheduler)
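For reference, config.init_obj('optimizer', torch.optim, trainable_params) instantiates the optimizer class named in the JSON config with the configured arguments. Roughly, and assuming an illustrative config entry (values here are examples, not taken from this commit), it amounts to:

    import torch
    import torch.nn as nn

    model = nn.Linear(10, 2)  # stand-in model for illustration
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())

    opt_cfg = {"type": "Adam", "args": {"lr": 0.001}}  # illustrative config entry
    optimizer = getattr(torch.optim, opt_cfg["type"])(trainable_params, **opt_cfg["args"])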
trainer/trainer.py (5 changes: 3 additions & 2 deletions)
@@ -9,10 +9,11 @@ class Trainer(BaseTrainer):
     """
     Trainer class
     """
-    def __init__(self, model, criterion, metric_ftns, optimizer, config, data_loader,
-                 valid_data_loader=None, lr_scheduler=None, len_epoch=None):
+    def __init__(self, model, criterion, metric_ftns, optimizer, config, device,
+                 data_loader, valid_data_loader=None, lr_scheduler=None, len_epoch=None):
         super().__init__(model, criterion, metric_ftns, optimizer, config)
         self.config = config
+        self.device = device
         self.data_loader = data_loader
         if len_epoch is None:
             # epoch-based training
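The device handle that Trainer now receives is what the training loop uses to place each batch on the same device as the model. A self-contained sketch of that pattern (illustrative, not code from this diff):

    import torch
    import torch.nn as nn

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = nn.Linear(10, 2).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # one training step: batch tensors are moved with .to(device) before the forward pass
    data, target = torch.randn(4, 10).to(device), torch.tensor([0, 1, 0, 1]).to(device)
    optimizer.zero_grad()
    loss = criterion(model(data), target)
    loss.backward()
    optimizer.step()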
utils/util.py (22 changes: 20 additions & 2 deletions)
@@ -1,4 +1,5 @@
 import json
+import torch
 import pandas as pd
 from pathlib import Path
 from itertools import repeat
@@ -25,12 +26,29 @@ def inf_loop(data_loader):
     for loader in repeat(data_loader):
         yield from loader
 
+def prepare_device(n_gpu_use):
+    """
+    Set up the GPU device if available, and get the GPU device indices used for DataParallel.
+    """
+    n_gpu = torch.cuda.device_count()
+    if n_gpu_use > 0 and n_gpu == 0:
+        print("Warning: There's no GPU available on this machine, "
+              "training will be performed on CPU.")
+        n_gpu_use = 0
+    if n_gpu_use > n_gpu:
+        print(f"Warning: The number of GPUs configured to use is {n_gpu_use}, but only {n_gpu} are "
+              "available on this machine.")
+        n_gpu_use = n_gpu
+    device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
+    list_ids = list(range(n_gpu_use))
+    return device, list_ids
+
 class MetricTracker:
     def __init__(self, *keys, writer=None):
         self.writer = writer
         self._data = pd.DataFrame(index=keys, columns=['total', 'counts', 'average'])
         self.reset()
 
     def reset(self):
         for col in self._data.columns:
             self._data[col].values[:] = 0
@@ -44,6 +62,6 @@ def update(self, key, value, n=1):
 
     def avg(self, key):
         return self._data.average[key]
 
     def result(self):
         return dict(self._data.average)
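For context, MetricTracker keeps a running total, count, and average per metric in a small DataFrame. Typical usage (an illustrative example, not part of this commit):

    from utils import MetricTracker

    tracker = MetricTracker('loss', 'accuracy')
    tracker.update('loss', 0.75, n=32)  # batch-averaged loss, weighted by batch size
    tracker.update('loss', 0.50, n=32)
    print(tracker.avg('loss'))   # 0.625
    print(tracker.result())      # {'loss': 0.625, 'accuracy': 0}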
