Skip to content

Commit

Permalink
Merge pull request #8 from Oisin-M/feat/batched_update
Browse files Browse the repository at this point in the history
Feat/batched update
  • Loading branch information
fpichi authored Apr 11, 2024
2 parents 15b1bd7 + b95a435 commit 81813db
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 29 deletions.
6 changes: 4 additions & 2 deletions gca_rom/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,11 @@ def graphs_dataset(dataset, HyperParams, param_sample=None):
print("Length of train dataset: ", len(train_dataset))
print("Length of test dataset: ", len(test_dataset))

max_batch_size = 100

loader = DataLoader(graphs, batch_size=1)
train_loader = DataLoader(train_dataset, batch_size=train_sims, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=test_sims, shuffle=False)
train_loader = DataLoader(train_dataset, batch_size=train_sims if train_sims<max_batch_size else max_batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=test_sims if test_sims<max_batch_size else max_batch_size, shuffle=False)
val_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

return loader, train_loader, test_loader, \
Expand Down
61 changes: 34 additions & 27 deletions gca_rom/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,52 +33,59 @@ def train(model, optimizer, device, scheduler, params, train_loader, test_loader
model.train()
loop = tqdm(range(HyperParams.max_epochs))
for epoch in loop:
train_rmse = total_examples = sum_loss = 0
train_rmse = sum_loss = 0
train_rmse_1 = train_rmse_2 = 0
sum_loss_1 = sum_loss_2 = 0

loss_train_mse = 0
loss_train_map = 0
start_ind = 0
optimizer.zero_grad()
for data in train_loader:
optimizer.zero_grad()
data = data.to(device)
out, z, z_estimation = model(data, params[train_trajectories, :])
loss_train_mse = F.mse_loss(out, data.x, reduction='mean')
loss_train_map = F.mse_loss(z_estimation, z, reduction='mean')
loss_train = loss_train_mse + HyperParams.lambda_map * loss_train_map
loss_train.backward()
optimizer.step()
sum_loss += loss_train.item()
sum_loss_1 += loss_train_mse.item()
sum_loss_2 += loss_train_map.item()
total_examples += 1
out, z, z_estimation = model(data, params[train_trajectories[start_ind:start_ind+data.batch_size], :])
loss_train_mse += F.mse_loss(out, data.x, reduction='sum')/(len(train_trajectories)*HyperParams.num_nodes*HyperParams.comp)
loss_train_map += F.mse_loss(z_estimation, z, reduction='sum')/(len(train_trajectories)*HyperParams.bottleneck_dim)
start_ind += data.batch_size
loss_train = loss_train_mse + HyperParams.lambda_map * loss_train_map
loss_train.backward()
sum_loss += loss_train.item()
sum_loss_1 += loss_train_mse.item()
sum_loss_2 += loss_train_map.item()
optimizer.step()

scheduler.step()
train_rmse = sum_loss / total_examples
train_rmse_1 = sum_loss_1 / total_examples
train_rmse_2 = sum_loss_2 / total_examples
train_rmse = sum_loss
train_rmse_1 = sum_loss_1
train_rmse_2 = sum_loss_2
train_history['train'].append(train_rmse)
train_history['l1'].append(train_rmse_1)
train_history['l2'].append(train_rmse_2)

if HyperParams.cross_validation:
with torch.no_grad():
model.eval()
test_rmse = total_examples = sum_loss = 0
test_rmse = sum_loss = 0
test_rmse_1 = test_rmse_2 = 0
sum_loss_1 = sum_loss_2 = 0

loss_test_mse = 0
loss_test_map = 0
start_ind = 0
for data in test_loader:
data = data.to(device)
out, z, z_estimation = model(data, params[test_trajectories, :])
loss_test_mse = F.mse_loss(out, data.x, reduction='mean')
loss_test_map = F.mse_loss(z_estimation, z, reduction='mean')
loss_test = loss_test_mse + HyperParams.lambda_map * loss_test_map
sum_loss += loss_test.item()
sum_loss_1 += loss_test_mse.item()
sum_loss_2 += loss_test_map.item()
total_examples += 1
out, z, z_estimation = model(data, params[test_trajectories[start_ind:start_ind+data.batch_size], :])
loss_test_mse += F.mse_loss(out, data.x, reduction='sum')/(len(test_trajectories)*HyperParams.num_nodes)
loss_test_map += F.mse_loss(z_estimation, z, reduction='sum')/(len(test_trajectories)*HyperParams.bottleneck_dim)
start_ind += data.batch_size
loss_test = loss_test_mse + HyperParams.lambda_map * loss_test_map
sum_loss += loss_test.item()
sum_loss_1 += loss_test_mse.item()
sum_loss_2 += loss_test_map.item()

test_rmse = sum_loss / total_examples
test_rmse_1 = sum_loss_1 / total_examples
test_rmse_2 = sum_loss_2 / total_examples
test_rmse = sum_loss
test_rmse_1 = sum_loss_1
test_rmse_2 = sum_loss_2
test_history['test'].append(test_rmse)
test_history['l1'].append(test_rmse_1)
test_history['l2'].append(test_rmse_2)
Expand Down

0 comments on commit 81813db

Please sign in to comment.