Merge pull request #15 from Oisin-M/feat/minibatch
Allow updating via minibatches
fpichi authored Apr 12, 2024
2 parents 9132d3e + c7326f8 commit 5e2e881
Showing 3 changed files with 47 additions and 31 deletions.
3 changes: 3 additions & 0 deletions gca_rom/network.py
@@ -1,4 +1,5 @@
import torch
import numpy as np
from torch import nn
from gca_rom import gca, scaling

@@ -61,6 +62,8 @@ def __init__(self, argv, **kwargs):
self.gamma = 0.0001
self.num_nodes = 0
self.conv = 'GMMConv'
self.batch_size = np.inf
self.minibatch = False
self.net_dir = './' + self.net_name + '/' + self.net_run + '/' + self.variable + '_' + self.net_name + '_lmap' + str(self.lambda_map) + '_btt' + str(self.bottleneck_dim) \
+ '_seed' + str(self.seed) + '_lv' + str(len(self.layer_vec)-2) + '_hc' + str(len(self.hidden_channels)) + '_nd' + str(self.nodes) \
+ '_ffn' + str(self.ffn) + '_skip' + str(self.skip) + '_lr' + str(self.learning_rate) + '_sc' + str(self.scaling_type) + '_rate' + str(self.rate) + '_conv' + self.conv + '/'
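The two new attributes default to the previous full-batch behaviour: batch_size = np.inf and minibatch = False. A minimal, illustrative sketch of the intended override, assuming the attributes can simply be reassigned after the hyperparameter object is built (the SimpleNamespace stand-in and the value 16 are not part of the repo):

import numpy as np
from types import SimpleNamespace

# Stand-in for the HyperParams object; defaults match the diff above.
hparams = SimpleNamespace(batch_size=np.inf, minibatch=False)
hparams.minibatch = True    # opt in to per-batch optimizer updates in train()
hparams.batch_size = 16     # finite cap used by the DataLoaders in preprocessing.py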
6 changes: 2 additions & 4 deletions gca_rom/preprocessing.py
@@ -124,11 +124,9 @@ def graphs_dataset(dataset, HyperParams, param_sample=None):
print("Length of train dataset: ", len(train_dataset))
print("Length of test dataset: ", len(test_dataset))

max_batch_size = 100

loader = DataLoader(graphs, batch_size=1)
train_loader = DataLoader(train_dataset, batch_size=train_sims if train_sims<max_batch_size else max_batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=test_sims if test_sims<max_batch_size else max_batch_size, shuffle=False)
train_loader = DataLoader(train_dataset, batch_size=train_sims if train_sims<HyperParams.batch_size else HyperParams.batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=test_sims if test_sims<HyperParams.batch_size else HyperParams.batch_size, shuffle=False)
val_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

return loader, train_loader, test_loader, \
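The hard-coded max_batch_size = 100 is replaced by HyperParams.batch_size, so with the default np.inf each split is loaded as a single full batch, while a finite value caps the number of graphs per batch. A self-contained sketch of the same selection logic, assuming torch_geometric's DataLoader (the import is not visible in this hunk) and an illustrative helper name make_loaders:

import numpy as np
from torch_geometric.loader import DataLoader

def make_loaders(train_dataset, test_dataset, hparams):
    # With batch_size = np.inf the comparison always picks the dataset length,
    # i.e. one full batch (the pre-minibatch behaviour); a finite batch_size
    # caps the number of simulations per batch.
    train_bs = len(train_dataset) if len(train_dataset) < hparams.batch_size else hparams.batch_size
    test_bs = len(test_dataset) if len(test_dataset) < hparams.batch_size else hparams.batch_size
    train_loader = DataLoader(train_dataset, batch_size=train_bs, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=test_bs, shuffle=False)
    return train_loader, test_loader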
69 changes: 42 additions & 27 deletions gca_rom/training.py
@@ -1,7 +1,7 @@
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from tqdm import tqdm

def train(model, optimizer, device, scheduler, params, train_loader, test_loader, train_trajectories, test_trajectories, HyperParams):
"""Trains the autoencoder model.
@@ -37,27 +37,44 @@ def train(model, optimizer, device, scheduler, params, train_loader, test_loader
train_rmse_1 = train_rmse_2 = 0
sum_loss_1 = sum_loss_2 = 0

loss_train_mse = 0
loss_train_map = 0
start_ind = 0
optimizer.zero_grad()
for data in train_loader:
data = data.to(device)
out, z, z_estimation = model(data, params[train_trajectories[start_ind:start_ind+data.batch_size], :])
loss_train_mse += F.mse_loss(out, data.x, reduction='sum')/(len(train_trajectories)*HyperParams.num_nodes*HyperParams.comp)
loss_train_map += F.mse_loss(z_estimation, z, reduction='sum')/(len(train_trajectories)*HyperParams.bottleneck_dim)
start_ind += data.batch_size
loss_train = loss_train_mse + HyperParams.lambda_map * loss_train_map
loss_train.backward()
sum_loss += loss_train.item()
sum_loss_1 += loss_train_mse.item()
sum_loss_2 += loss_train_map.item()
optimizer.step()

if HyperParams.minibatch:
total_batches = 0
for data in train_loader:
optimizer.zero_grad()
data = data.to(device)
out, z, z_estimation = model(data, params[train_trajectories[start_ind:start_ind+data.batch_size], :])
loss_train_mse = F.mse_loss(out, data.x, reduction='mean')
loss_train_map = F.mse_loss(z_estimation, z, reduction='mean')
loss_train = loss_train_mse + HyperParams.lambda_map * loss_train_map
loss_train.backward()
optimizer.step()
train_rmse += loss_train.item()
train_rmse_1 += loss_train_mse.item()
train_rmse_2 += loss_train_map.item()
total_batches += 1
train_rmse /= total_batches
train_rmse_1 /= total_batches
train_rmse_2 /= total_batches
else:
optimizer.zero_grad()
loss_train_mse = 0
loss_train_map = 0
for data in train_loader:
data = data.to(device)
out, z, z_estimation = model(data, params[train_trajectories[start_ind:start_ind+data.batch_size], :])
loss_train_mse += F.mse_loss(out, data.x, reduction='sum')/(len(train_trajectories)*HyperParams.num_nodes*HyperParams.comp)
loss_train_map += F.mse_loss(z_estimation, z, reduction='sum')/(len(train_trajectories)*HyperParams.bottleneck_dim)
start_ind += data.batch_size
loss_train = loss_train_mse + HyperParams.lambda_map * loss_train_map
loss_train.backward()
optimizer.step()
train_rmse += loss_train.item()
train_rmse_1 += loss_train_mse.item()
train_rmse_2 += loss_train_map.item()

scheduler.step()
train_rmse = sum_loss
train_rmse_1 = sum_loss_1
train_rmse_2 = sum_loss_2

train_history['train'].append(train_rmse)
train_history['l1'].append(train_rmse_1)
train_history['l2'].append(train_rmse_2)
@@ -79,13 +96,11 @@ def train(model, optimizer, device, scheduler, params, train_loader, test_loader
loss_test_map += F.mse_loss(z_estimation, z, reduction='sum')/(len(test_trajectories)*HyperParams.bottleneck_dim)
start_ind += data.batch_size
loss_test = loss_test_mse + HyperParams.lambda_map * loss_test_map
sum_loss += loss_test.item()
sum_loss_1 += loss_test_mse.item()
sum_loss_2 += loss_test_map.item()

test_rmse = sum_loss
test_rmse_1 = sum_loss_1
test_rmse_2 = sum_loss_2

test_rmse += loss_test.item()
test_rmse_1 += loss_test_mse.item()
test_rmse_2 += loss_test_map.item()

test_history['test'].append(test_rmse)
test_history['l1'].append(test_rmse_1)
test_history['l2'].append(test_rmse_2)
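In short, train() now takes one optimizer step per batch (mean-reduced losses, averaged over the number of batches) when HyperParams.minibatch is set, and otherwise keeps the original single accumulated step per epoch (sum-reduced losses normalised by dataset size). A condensed, illustrative sketch of the two update strategies; the function name epoch_step is not from the repo, and the start_ind increment inside the minibatch loop is added here so the parameter slice stays aligned with the batches, while everything else mirrors the diff above:

import torch.nn.functional as F

def epoch_step(model, optimizer, device, train_loader, params, train_trajectories, HyperParams):
    start_ind = 0
    if HyperParams.minibatch:
        # One gradient step per minibatch; per-batch mean losses averaged over batches.
        train_rmse, total_batches = 0.0, 0
        for data in train_loader:
            optimizer.zero_grad()
            data = data.to(device)
            out, z, z_est = model(data, params[train_trajectories[start_ind:start_ind + data.batch_size], :])
            loss = F.mse_loss(out, data.x, reduction='mean') \
                 + HyperParams.lambda_map * F.mse_loss(z_est, z, reduction='mean')
            loss.backward()
            optimizer.step()
            train_rmse += loss.item()
            total_batches += 1
            start_ind += data.batch_size  # keep the parameter slice aligned with the next batch
        return train_rmse / total_batches
    else:
        # Previous behaviour: accumulate a normalised sum-loss over all batches,
        # then take a single gradient step for the whole epoch.
        optimizer.zero_grad()
        loss = 0
        for data in train_loader:
            data = data.to(device)
            out, z, z_est = model(data, params[train_trajectories[start_ind:start_ind + data.batch_size], :])
            loss = loss + F.mse_loss(out, data.x, reduction='sum') / (len(train_trajectories) * HyperParams.num_nodes * HyperParams.comp) \
                        + HyperParams.lambda_map * F.mse_loss(z_est, z, reduction='sum') / (len(train_trajectories) * HyperParams.bottleneck_dim)
            start_ind += data.batch_size
        loss.backward()
        optimizer.step()
        return loss.item()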
