diff --git a/pytorch_caney/pipelines/finetuning/finetune_vanilla.py b/pytorch_caney/pipelines/finetuning/finetune_vanilla.py
deleted file mode 100644
index 72ade94..0000000
--- a/pytorch_caney/pipelines/finetuning/finetune_vanilla.py
+++ /dev/null
@@ -1,438 +0,0 @@
-from pytorch_caney.models.build import build_model
-
-from pytorch_caney.data.datamodules.finetune_datamodule \
-    import build_finetune_dataloaders
-
-from pytorch_caney.training.mim_utils \
-    import build_optimizer, save_checkpoint, reduce_tensor
-
-from pytorch_caney.config import get_config
-from pytorch_caney.loss.build import build_loss
-from pytorch_caney.lr_scheduler import build_scheduler, setup_scaled_lr
-from pytorch_caney.ptc_logging import create_logger
-from pytorch_caney.training.mim_utils import get_grad_norm
-
-import argparse
-import datetime
-import joblib
-import numpy as np
-import os
-import time
-
-import torch
-import torch.cuda.amp as amp
-import torch.backends.cudnn as cudnn
-import torch.distributed as dist
-
-from timm.utils import AverageMeter
-
-
-def parse_args():
-    """
-    Parse command-line arguments
-    """
-
-    parser = argparse.ArgumentParser(
-        'pytorch-caney finetuning',
-        add_help=False)
-
-    parser.add_argument(
-        '--cfg',
-        type=str,
-        required=True,
-        metavar="FILE",
-        help='path to config file')
-
-    parser.add_argument(
-        "--data-paths",
-        nargs='+',
-        required=True,
-        help="paths where dataset is stored")
-
-    parser.add_argument(
-        '--dataset',
-        type=str,
-        required=True,
-        help='Dataset to use')
-
-    parser.add_argument(
-        '--pretrained',
-        type=str,
-        help='path to pre-trained model')
-
-    parser.add_argument(
-        '--batch-size',
-        type=int,
-        help="batch size for single GPU")
-
-    parser.add_argument(
-        '--resume',
-        help='resume from checkpoint')
-
-    parser.add_argument(
-        '--accumulation-steps',
-        type=int,
-        help="gradient accumulation steps")
-
-    parser.add_argument(
-        '--use-checkpoint',
-        action='store_true',
-        help="whether to use gradient checkpointing to save memory")
-
-    parser.add_argument(
-        '--enable-amp',
-        action='store_true')
-
-    parser.add_argument(
-        '--disable-amp',
-        action='store_false',
-        dest='enable_amp')
-
-    parser.set_defaults(enable_amp=True)
-
-    parser.add_argument(
-        '--output',
-        default='output',
-        type=str,
-        metavar='PATH',
-        help='root of output folder, the full path is ' +
-             '<output>/<model_name>/<tag> (default: output)')
-
-    parser.add_argument(
-        '--tag',
-        help='tag of experiment')
-
-    args = parser.parse_args()
-
-    config = get_config(args)
-
-    return args, config
-
-
-def train(config,
-          dataloader_train,
-          dataloader_val,
-          model,
-          model_wo_ddp,
-          optimizer,
-          lr_scheduler,
-          scaler,
-          criterion):
-    """
-    Start fine-tuning a specific model and dataset.
-
-    Args:
-        config: config object
-        dataloader_train: training pytorch dataloader
-        dataloader_val: validation pytorch dataloader
-        model: model to pre-train
-        model_wo_ddp: model to pre-train that is not the DDP version
-        optimizer: pytorch optimizer
-        lr_scheduler: learning-rate scheduler
-        scaler: loss scaler
-        criterion: loss function to use for fine-tuning
-    """
-
-    loss = validate(config, model, dataloader_val, criterion)
-
-    logger.info(f'Model validation loss: {loss:.3f}%')
-
-    logger.info("Start fine-tuning")
-
-    start_time = time.time()
-
-    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
-
-        dataloader_train.sampler.set_epoch(epoch)
-
-        execute_one_epoch(config, model, dataloader_train,
-                          optimizer, criterion, epoch, lr_scheduler, scaler)
-
-        loss = validate(config, model, dataloader_val, criterion)
-
-        logger.info(f'Model validation loss: {loss:.3f}%')
-
-        if dist.get_rank() == 0 and \
-                (epoch % config.SAVE_FREQ == 0 or
-                 epoch == (config.TRAIN.EPOCHS - 1)):
-
-            save_checkpoint(config, epoch, model_wo_ddp, 0.,
-                            optimizer, lr_scheduler, scaler, logger)
-
-    total_time = time.time() - start_time
-
-    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
-
-    logger.info('Training time {}'.format(total_time_str))
-
-
-def execute_one_epoch(config,
-                      model,
-                      dataloader,
-                      optimizer,
-                      criterion,
-                      epoch,
-                      lr_scheduler,
-                      scaler):
-    """
-    Execute training iterations on a single epoch.
-
-    Args:
-        config: config object
-        model: model to pre-train
-        dataloader: dataloader to use
-        optimizer: pytorch optimizer
-        epoch: int epoch number
-        lr_scheduler: learning-rate scheduler
-        scaler: loss scaler
-    """
-    model.train()
-
-    optimizer.zero_grad()
-
-    num_steps = len(dataloader)
-
-    # Set up logging meters
-    batch_time = AverageMeter()
-    data_time = AverageMeter()
-    loss_meter = AverageMeter()
-    norm_meter = AverageMeter()
-    loss_scale_meter = AverageMeter()
-
-    start = time.time()
-    end = time.time()
-    for idx, (samples, targets) in enumerate(dataloader):
-
-        data_time.update(time.time() - start)
-
-        samples = samples.cuda(non_blocking=True)
-        targets = targets.cuda(non_blocking=True)
-
-        samples = samples.to(torch.bfloat16)
-
-        with amp.autocast(enabled=config.ENABLE_AMP):
-            logits = model(samples)
-
-        loss = criterion(logits, targets)
-        loss = loss / config.TRAIN.ACCUMULATION_STEPS
-
-        scaler.scale(loss).backward()
-
-        grad_norm = get_grad_norm(model.parameters())
-
-        if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
-            optimizer.zero_grad()
-            lr_scheduler.step_update(
-                (epoch * num_steps + idx) // config.TRAIN.ACCUMULATION_STEPS)
-
-        loss_scale_value = scaler.state_dict()["scale"]
-
-        torch.cuda.synchronize()
-
-        loss_meter.update(loss.item(), targets.size(0))
-        norm_meter.update(grad_norm)
-        loss_scale_meter.update(loss_scale_value)
-        batch_time.update(time.time() - end)
-        end = time.time()
-
-        if idx % config.PRINT_FREQ == 0:
-            lr = optimizer.param_groups[0]['lr']
-            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
-            etas = batch_time.avg * (num_steps - idx)
-            logger.info(
-                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
-                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
-                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
-                f'data_time {data_time.val:.4f} ({data_time.avg:.4f})\t'
-                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
-                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
-                f'loss_scale {loss_scale_meter.val:.4f}' +
-                f' ({loss_scale_meter.avg:.4f})\t'
-                f'mem {memory_used:.0f}MB')
-
-    epoch_time = time.time() - start
-    logger.info(
-        f"EPOCH {epoch} training takes " +
-        f"{datetime.timedelta(seconds=int(epoch_time))}")
-
-
-@torch.no_grad()
-def validate(config, model, dataloader, criterion):
-    """Validation function which given a model and validation loader
-    performs a validation run and returns the average loss according
-    to the criterion.
-
-    Args:
-        config: config object
-        model: pytorch model to validate
-        dataloader: pytorch validation loader
-        criterion: pytorch-friendly loss function
-
-    Returns:
-        loss_meter.avg: average of the loss throught the validation
-            iterations
-    """
-
-    model.eval()
-
-    batch_time = AverageMeter()
-
-    loss_meter = AverageMeter()
-
-    end = time.time()
-
-    for idx, (images, target) in enumerate(dataloader):
-
-        images = images.cuda(non_blocking=True)
-
-        target = target.cuda(non_blocking=True)
-
-        images = images.to(torch.bfloat16)
-
-        # compute output
-        with amp.autocast(enabled=config.ENABLE_AMP):
-            output = model(images)
-
-        # measure accuracy and record loss
-        loss = criterion(output, target)
-
-        loss = reduce_tensor(loss)
-
-        loss_meter.update(loss.item(), target.size(0))
-
-        # measure elapsed time
-        batch_time.update(time.time() - end)
-
-        end = time.time()
-
-        if idx % config.PRINT_FREQ == 0:
-
-            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
-
-            logger.info(
-                f'Test: [{idx}/{len(dataloader)}]\t'
-                f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
-                f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
-                f'Mem {memory_used:.0f}MB')
-
-    return loss_meter.avg
-
-
-def main(config):
-    """
-    Performs the main function of building model, loader, etc. and starts
-    training.
-    """
-
-    dataloader_train, dataloader_val = build_finetune_dataloaders(
-        config, logger)
-
-    model = build_finetune_model(config, logger)
-
-    optimizer = build_optimizer(config,
-                                model,
-                                is_pretrain=False,
-                                logger=logger)
-
-    model, model_wo_ddp = make_ddp(model)
-
-    n_iter_per_epoch = len(dataloader_train)
-
-    lr_scheduler = build_scheduler(config, optimizer, n_iter_per_epoch)
-
-    scaler = amp.GradScaler()
-
-    criterion = build_loss(config)
-
-    train(config,
-          dataloader_train,
-          dataloader_val,
-          model,
-          model_wo_ddp,
-          optimizer,
-          lr_scheduler,
-          scaler,
-          criterion)
-
-
-def build_finetune_model(config, logger):
-
-    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
-
-    # You can replace this section by simply calling your model class
-    # For example: model = UNet(parameters)
-    model = build_model(config,
-                        pretrain=False,
-                        pretrain_method='mim',
-                        logger=logger)
-
-    model.cuda()
-
-    logger.info(str(model))
-
-    return model
-
-
-def make_ddp(model):
-
-    model = torch.nn.parallel.DistributedDataParallel(
-        model,
-        device_ids=[int(os.environ["RANK"])],
-        broadcast_buffers=False,
-        find_unused_parameters=True)
-
-    model_without_ddp = model.module
-
-    return model, model_without_ddp
-
-
-def setup_rank_worldsize():
-    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
-        rank = int(os.environ["RANK"])
-        world_size = int(os.environ['WORLD_SIZE'])
-        print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
-    else:
-        rank = -1
-        world_size = -1
-    return rank, world_size
-
-
-def setup_distributed_processing(rank, world_size):
-    torch.cuda.set_device(int(os.environ["RANK"]))
-    torch.distributed.init_process_group(
-        backend='nccl', init_method='env://', world_size=world_size, rank=rank)
-    torch.distributed.barrier()
-
-
-def setup_seeding(config):
-    seed = config.SEED + dist.get_rank()
-    torch.manual_seed(seed)
-    np.random.seed(seed)
-
-
-if __name__ == '__main__':
-    _, config = parse_args()
-
-    rank, world_size = setup_rank_worldsize()
-
-    setup_distributed_processing(rank, world_size)
-
-    setup_seeding(config)
-
-    cudnn.benchmark = True
-
-    os.makedirs(config.OUTPUT, exist_ok=True)
-    logger = create_logger(output_dir=config.OUTPUT,
-                           dist_rank=dist.get_rank(),
-                           name=f"{config.MODEL.NAME}")
-
-    if dist.get_rank() == 0:
-        path = os.path.join(config.OUTPUT, "config.json")
-        with open(path, "w") as f:
-            f.write(config.dump())
-        logger.info(f"Full config saved to {path}")
-        logger.info(config.dump())
-        config_file_name = f'{config.TAG}.config.sav'
-        config_file_path = os.path.join(config.OUTPUT, config_file_name)
-        joblib.dump(config, config_file_path)
-
-    main(config)