-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
58 lines (45 loc) · 1.78 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import sys
sys.path.append('.')
from config import TrainConfig
from gnn.models import RGCNSkipConnection
from gnn.solver import RGCNSolver
from gnn.datasets import GNNDataset
import torch
from torch_geometric.data import DataLoader
import pandas as pd
def main():
    """Train the RGCN model and write test-set predictions to a CSV file.

    Seeds torch from the config, builds loaders for the train, validation,
    dev and test data, fits an ``RGCNSkipConnection`` through ``RGCNSolver``,
    then fills the ``ST1_GAP(eV)`` column of the sample submission with one
    prediction per test batch and saves the result as
    ``single_model_submission.csv`` in the current working directory.
    """
    # The config class itself is used (not an instance); all settings are
    # read from it as attributes.
    cfg = TrainConfig
    torch.manual_seed(cfg.seed)

    # Datasets: the first 27000 items of the train folder are used for
    # training, everything after that index for validation.
    split_at = 27000
    full_train = GNNDataset(f'{cfg.dir_data}/train')
    train_part = full_train[:split_at]
    valid_part = full_train[split_at:]
    dev_part = GNNDataset(f'{cfg.dir_data}/dev')
    test_part = GNNDataset(f'{cfg.dir_data}/test')

    # Loaders: only the training loader shuffles; dev and test are served
    # one item per batch.
    train_loader = DataLoader(
        train_part,
        batch_size=cfg.batch_size,
        shuffle=True,
    )
    valid_loader = DataLoader(
        valid_part,
        batch_size=cfg.batch_size,
        shuffle=False,
    )
    dev_loader = DataLoader(dev_part, batch_size=1, shuffle=False)
    test_loader = DataLoader(test_part, batch_size=1, shuffle=False)

    # Model and solver.
    net = RGCNSkipConnection(
        cfg.hidden_channels,
        cfg.hidden_dims,
        cfg.num_node_features,
        cfg.node_embedding_dim,
        cfg.dropout,
    )
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    solver = RGCNSolver(net, cfg.lr, cfg.n_epochs, device)

    # Training.
    solver.fit(train_loader, valid_loader, dev_loader)

    # Inference on the test set, written into the sample submission frame.
    sub = pd.read_csv('./data/sample_submission.csv')
    net.eval()
    with torch.no_grad():
        for batch in test_loader:
            # NOTE(review): the return value of .to() is not used, so this
            # presumably moves the batch in place - confirm against the
            # installed torch_geometric version.
            batch.to(device)
            output = net(batch.x, batch.edge_index, batch.edge_type, batch.batch)
            # .item() needs a single-element tensor, which matches the
            # batch_size=1 test loader.
            value = output.detach().cpu().item()
            # Overwrite the target column on the row(s) whose uid equals the
            # first uid in this batch.
            row_mask = sub['uid'] == batch.uid[0]
            sub.loc[row_mask, 'ST1_GAP(eV)'] = value

    sub.to_csv('single_model_submission.csv', index=False)
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()