-
Notifications
You must be signed in to change notification settings - Fork 0
/
interpolation.py
124 lines (98 loc) · 3.27 KB
/
interpolation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# -*- coding: utf-8 -*-
# //////////////////////////////////////////////
# ///////////// INTERPOLATION TASK /////////////
# //////////////////////////////////////////////
# ==============================================
# SETUP AND IMPORTS
# ==============================================
# We use `PyTorch Lightning` (a wrapper around `PyTorch`) as our main framework and
# `wandb` to track and log the experiments. All seeds are set through PyTorch Lightning's
# dedicated `seed_everything` function.
# Import libraries
import os
# torch imports
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, random_split
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10
import torch.nn.functional as F
import torchmetrics
# pl imports
import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pondernet import *
from cifar10data import *
# remaining imports
import wandb
from math import floor
# Reproducibility: seed every RNG in one call via Lightning's helper.
seed_everything(seed=1234)
# Authenticate with Weights & Biases up front, before any W&B logger is built.
wandb.login()
# ==============================================
# CONSTANTS AND HYPERPARAMETERS
# ==============================================
# --- Trainer ---
BATCH_SIZE = 128        # batch size handed to the CIFAR10 datamodule
EPOCHS = 10             # Trainer's `max_epochs`
# --- Optimizer ---
LR = 1e-3               # learning rate passed to PonderCIFAR
GRAD_NORM_CLIP = 0.5    # Trainer's `gradient_clip_val`
# --- Model (all forwarded to PonderCIFAR) ---
N_ELEMS = 512           # NOTE(review): input/feature size — confirm in pondernet
N_HIDDEN = 100          # NOTE(review): hidden width — confirm in pondernet
MAX_STEPS = 20          # presumably the maximum number of pondering steps
LAMBDA_P = 0.5          # presumably PonderNet's geometric-prior parameter lambda_p
BETA = 1e-2             # presumably the weight of PonderNet's regularisation term
# ==============================================
# CIFAR10 SETUP
# ==============================================
# Training pipeline: random 32x32 crop from the image padded by 4 pixels on
# every side, a random horizontal flip, then conversion to a tensor.
train_transform = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)
# Evaluation pipeline: no augmentation, only tensor conversion.
test_transform = transforms.Compose([transforms.ToTensor()])
# ==============================================
# RUN INTERPOLATION
# ==============================================
# Load the CIFAR10 dataset with no rotations and train PonderNet on it.
# Make sure to edit the `WandbLogger` call so that you log the experiment
# on your account's desired project.
# initialize datamodule and model
cifar10_dm = CIFAR10_DataModule(
    data_dir='./',
    train_transform=train_transform,
    test_transform=test_transform,
    batch_size=BATCH_SIZE)
model = PonderCIFAR(
    n_elems=N_ELEMS,
    n_hidden=N_HIDDEN,
    max_steps=MAX_STEPS,
    lambda_p=LAMBDA_P,
    beta=BETA,
    lr=LR)
# setup logger
logger = WandbLogger(project='Test-Histogram', name='interpolation', offline=False)
logger.watch(model)
# Request GPUs and 16-bit precision only when CUDA is actually present.
# On a CPU-only machine, `gpus=-1` makes Lightning raise "GPUs requested but
# none are available". Native 16-bit AMP is also not supported on CPU in the
# Lightning releases that take a `gpus` argument. GPU runs are unaffected.
use_gpu = torch.cuda.is_available()
trainer = Trainer(
    logger=logger,                          # W&B integration
    gpus=-1 if use_gpu else None,           # all available GPUs, else CPU
    max_epochs=EPOCHS,                      # maximum number of epochs
    gradient_clip_val=GRAD_NORM_CLIP,       # gradient clipping
    val_check_interval=0.25,                # validate 4 times per epoch
    precision=16 if use_gpu else 32,        # half precision only on GPU
    deterministic=True)                     # for reproducibility
# fit the model
trainer.fit(model, datamodule=cifar10_dm)
# evaluate on the test set
trainer.test(model, datamodule=cifar10_dm)
wandb.finish()