# -*- coding: utf-8 -*-
# //////////////////////////////////////////////
# ///////////// EXTRAPOLATION TASK /////////////
# //////////////////////////////////////////////
# ==============================================
# SETUP AND IMPORTS
# ==============================================
# We use `PyTorch Lightning` (wrapping `PyTorch`) as our main framework and `wandb`
# to track and log the experiments. We set all seeds through `PyTorch Lightning`'s dedicated function.
# import libraries
# torch imports
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
# pl imports
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import WandbLogger
# project imports (explicit imports instead of wildcards)
from pondernet import PonderCIFAR
from cifar10data import CIFAR10_DataModule
# remaining imports
import wandb
# set seeds
seed_everything(1234)
# log in to wandb
wandb.login()
# ==============================================
# CONSTANTS AND HYPERPARAMETERS
# ==============================================
# Trainer settings
BATCH_SIZE = 128
EPOCHS = 10
# Optimizer settings
LR = 0.001              # learning rate
GRAD_NORM_CLIP = 0.5    # max gradient norm for clipping
# Model hparams
N_ELEMS = 512
N_HIDDEN = 100          # hidden units in the recurrent network
MAX_STEPS = 20          # maximum number of pondering steps
LAMBDA_P = 0.5          # parameter of the geometric prior over halting steps
BETA = 0.01             # weight of the KL regularization term in the loss
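# NOTE: illustrative sketch (not used by the training run below). Under the
# PonderNet formulation, LAMBDA_P parameterizes a (truncated) geometric prior
# over the halting step, p_G(n) = LAMBDA_P * (1 - LAMBDA_P)**(n - 1), and BETA
# scales the KL term that pulls the learned halting distribution toward it.
# `PonderCIFAR` in `pondernet.py` implements its own version of this loss.
def _geometric_prior_sketch(lambda_p: float, max_steps: int) -> torch.Tensor:
    # probability of halting exactly at step n, for n = 1..max_steps
    n = torch.arange(1, max_steps + 1, dtype=torch.float32)
    return lambda_p * (1.0 - lambda_p) ** (n - 1)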
# ==============================================
# CIFAR10 SETUP
# ==============================================
def get_transforms():
    # define transformations
    transform_22 = transforms.Compose([
        transforms.RandomRotation(degrees=22.5),
        transforms.ToTensor(),
    ])
    transform_45 = transforms.Compose([
        transforms.RandomRotation(degrees=45),
        transforms.ToTensor(),
    ])
    transform_67 = transforms.Compose([
        transforms.RandomRotation(degrees=67.5),
        transforms.ToTensor(),
    ])
    transform_90 = transforms.Compose([
        transforms.RandomRotation(degrees=90),
        transforms.ToTensor(),
    ])
    train_transform = transform_22
    test_transform = [transform_22, transform_45, transform_67, transform_90]
    return train_transform, test_transform
train_transform, test_transform = get_transforms()
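# NOTE: illustrative sketch (an assumption, not the actual `cifar10data.py`
# code) of how the list of test transforms above can become one test
# dataloader per rotation magnitude, so accuracy can be compared across angles.
def _test_loaders_sketch(test_transforms, data_dir='./', batch_size=BATCH_SIZE):
    return [
        DataLoader(
            CIFAR10(data_dir, train=False, transform=t, download=True),
            batch_size=batch_size)
        for t in test_transforms
    ]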
# ==============================================
# RUN EXTRAPOLATION
# ==============================================
# Train PonderNet on CIFAR10 with mild (up to 22.5 degree) rotations and
# evaluate it on increasingly rotated test sets (up to 90 degrees) to
# measure extrapolation. Make sure to edit the `WandbLogger` call so that
# you log the experiment to the desired project on your account.
# initialize datamodule and model
cifar10_dm = CIFAR10_DataModule(
    data_dir='./',
    train_transform=train_transform,
    test_transform=test_transform,
    batch_size=BATCH_SIZE)
model = PonderCIFAR(
    n_elems=N_ELEMS,
    n_hidden=N_HIDDEN,
    max_steps=MAX_STEPS,
    lambda_p=LAMBDA_P,
    beta=BETA,
    lr=LR)
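# NOTE: illustrative sketch of the halting mechanism from the PonderNet paper
# (Banino et al., 2021); the actual implementation lives in `pondernet.py`.
# At every step the network emits a halting probability lambda_n; the
# probability of halting exactly at step n is
#     p_n = lambda_n * prod_{i<n} (1 - lambda_i).
def _halting_distribution_sketch(lambdas: torch.Tensor) -> torch.Tensor:
    # lambdas: shape (max_steps,), per-step halting probabilities in (0, 1)
    survive = torch.cumprod(1.0 - lambdas, dim=0)  # P(still running after step n)
    p = lambdas.clone()
    p[1:] = lambdas[1:] * survive[:-1]             # halt now * survived all earlier steps
    return p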
# setup logger
logger = WandbLogger(project='Test-Histogram', name='extrapolation', offline=False)
logger.watch(model)
trainer = Trainer(
    logger=logger,                       # W&B integration
    gpus=-1,                             # use all available GPUs
    max_epochs=EPOCHS,                   # maximum number of epochs
    gradient_clip_val=GRAD_NORM_CLIP,    # gradient clipping
    val_check_interval=0.25,             # validate 4 times per epoch
    precision=16,                        # train in half precision
    deterministic=True)                  # for reproducibility
# fit the model
trainer.fit(model, datamodule=cifar10_dm)
# evaluate on the test set
trainer.test(model, datamodule=cifar10_dm)
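# NOTE: `trainer.test` returns one metric dict per test dataloader, so with
# one dataloader per rotation angle the extrapolation curve (accuracy vs.
# angle) can be read from the returned list or the logged W&B metrics.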
wandb.finish()