-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
173 lines (147 loc) · 6.81 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import datetime
import multiprocessing
import os
import torch
torch.multiprocessing.set_sharing_strategy('file_system')
import logging
import graphgym # noqa, register custom modules
from graphgym.optimizer.extra_optimizers import ExtendedSchedulerConfig
from torch_geometric.graphgym.cmd_args import parse_args
from torch_geometric.graphgym.config import (cfg, dump_cfg,
set_cfg, load_cfg,
makedirs_rm_exist)
from torch_geometric.graphgym.logger import set_printing
from torch_geometric.graphgym.optim import OptimizerConfig
from modules.architecture.copt_module import create_model
from torch_geometric.graphgym.train import GraphGymDataModule, train
from torch_geometric.graphgym.utils.comp_budget import params_count
from torch_geometric.graphgym.utils.device import auto_select_device
from torch_geometric.graphgym.register import train_dict
from torch_geometric import seed_everything
from graphgym.patches import create_loader
# Opt in to TensorFloat-32 math for float32 matmuls (off by default since
# PyTorch 1.12) and for cuDNN ops (already on by default): trades a little
# float32 precision for speed on hardware that supports TF32.
torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = True
def new_optimizer_config(cfg):
    """Build an ``OptimizerConfig`` from the ``cfg.optim`` section.

    Args:
        cfg (CfgNode): Configuration node; its ``optim`` group supplies the
            optimizer name, base learning rate, weight decay and momentum.

    Returns:
        OptimizerConfig: Optimizer settings populated from ``cfg.optim``.
    """
    optim = cfg.optim
    return OptimizerConfig(
        optimizer=optim.optimizer,
        base_lr=optim.base_lr,
        weight_decay=optim.weight_decay,
        momentum=optim.momentum,
    )
def new_scheduler_config(cfg):
    """Build an ``ExtendedSchedulerConfig`` from the current configuration.

    Args:
        cfg (CfgNode): Configuration node; scheduler settings are read from
            ``cfg.optim`` and the training mode from ``cfg.train.mode``.

    Returns:
        ExtendedSchedulerConfig: LR-scheduler settings for this run.
    """
    optim = cfg.optim
    return ExtendedSchedulerConfig(
        scheduler=optim.scheduler,
        steps=optim.steps,
        lr_decay=optim.lr_decay,
        max_epoch=optim.max_epoch,
        reduce_factor=optim.reduce_factor,
        schedule_patience=optim.schedule_patience,
        min_lr=optim.min_lr,
        num_warmup_epochs=optim.num_warmup_epochs,
        train_mode=cfg.train.mode,
        eval_period=optim.eval_period,
    )
def custom_set_out_dir(cfg, cfg_fname, name_tag):
    """Set custom main output directory path to cfg.

    The new :obj:`cfg.out_dir` is ``<cfg.out_dir>/<stem>[-<tag>]``, where
    ``<stem>`` is the config filename without directory and extension, and
    ``<tag>`` is ``name_tag`` if non-empty, otherwise ``cfg.wandb.name``.
    When both are empty no suffix is appended, so the directory name never
    ends in a dangling ``-``.

    Args:
        cfg (CfgNode): Configuration node; ``cfg.out_dir`` is updated in place.
        cfg_fname (string): Filename for the yaml format configuration file
        name_tag (string): Additional name tag to identify this execution of the
            configuration file, specified in :obj:`cfg.name_tag`
    """
    run_name = os.path.splitext(os.path.basename(cfg_fname))[0]
    # Prefer the explicit name tag; fall back to the WandB run name.
    tag = name_tag or cfg.wandb.name
    if tag:
        run_name += f"-{tag}"
    cfg.out_dir = os.path.join(cfg.out_dir, run_name)
def custom_set_run_dir(cfg, run_id):
    """Point ``cfg.run_dir`` at ``<cfg.out_dir>/<run_id>`` and create it.

    Args:
        cfg (CfgNode): Configuration node; ``cfg.run_dir`` is set in place.
        run_id (int): Main for-loop iter id (the random seed or dataset split)
    """
    cfg.run_dir = os.path.join(cfg.out_dir, str(run_id))
    if not cfg.train.auto_resume:
        # Fresh run: makedirs_rm_exist recreates the directory, discarding
        # whatever a previous run left there.
        makedirs_rm_exist(cfg.run_dir)
        return
    # Resuming: create the directory if needed but keep existing contents.
    os.makedirs(cfg.run_dir, exist_ok=True)
def run_loop_settings(cfg, args):
    """Create main loop execution settings based on the current cfg.

    Configures the main execution loop to run in one of two modes:

    1. 'multi-seed' (``cfg.run_multiple_splits`` is empty) - reproduces the
       default behaviour of GraphGym: ``args.repeat`` controls how many times
       the experiment run is repeated, always on ``cfg.dataset.split_index``.
    2. 'multi-split' (``cfg.run_multiple_splits`` is non-empty) - executes the
       experiment run once per listed dataset split (e.g. multiple CV splits
       or multiple standard splits). Requires ``args.repeat == 1``.

    In both modes iteration ``i`` uses seed ``cfg.seed + i``, and the run ID
    of each iteration equals its seed.

    Args:
        cfg (CfgNode): Configuration node.
        args: Parsed command-line arguments; only ``args.repeat`` is read.

    Returns:
        Tuple ``(run_ids, seeds, split_indices)`` of equal-length sequences,
        meant to be zipped together by the caller:
        List of run IDs for each loop iteration
        List of rng seeds to loop over
        List of dataset split indices to loop over

    Raises:
        NotImplementedError: In 'multi-split' mode when ``args.repeat != 1``.
    """
    if len(cfg.run_multiple_splits) == 0:
        # 'multi-seed' run mode
        num_iterations = args.repeat
        split_indices = [cfg.dataset.split_index] * num_iterations
    else:
        # 'multi-split' run mode
        if args.repeat != 1:
            raise NotImplementedError("Running multiple repeats of multiple "
                                      "splits in one run is not supported.")
        split_indices = cfg.run_multiple_splits
        num_iterations = len(split_indices)
    # Both modes use consecutive seeds, and name each run after its seed.
    seeds = [cfg.seed + i for i in range(num_iterations)]
    run_ids = seeds
    return run_ids, seeds, split_indices
def adapt_args(args):
    """Normalize ``key=value`` command-line overrides into ``key value`` form.

    ``load_cfg`` merges ``args.opts`` as a flat list of alternating keys and
    values, so both ``optim.base_lr=0.1`` and ``optim.base_lr 0.1`` are
    accepted on the command line. Only a token in *key* position is split,
    and only on its first ``=``, so values may themselves contain ``=``
    (e.g. ``wandb.name=a=b`` -> ``['wandb.name', 'a=b']``).

    Args:
        args: Parsed command-line namespace; ``args.opts`` is replaced in
            place (``None`` is treated as no overrides).

    Returns:
        The same ``args`` object.
    """
    adapted_opts = []
    for opt in args.opts or ():
        # An even-length list means the next token should be a key.
        expecting_key = len(adapted_opts) % 2 == 0
        if expecting_key and '=' in opt:
            adapted_opts.extend(opt.split('=', 1))
        else:
            adapted_opts.append(opt)
    args.opts = adapted_opts
    return args
if __name__ == '__main__':
    # Load cmd line args ("key=value" overrides become "key value" pairs)
    args = parse_args()
    args = adapt_args(args)
    # Load config file
    set_cfg(cfg)
    cfg.train.mode = None  # XXX: temporary fix, need to register train.mode
    load_cfg(cfg, args)
    custom_set_out_dir(cfg, args.cfg_file, cfg.name_tag)
    # Cap data-loader workers at the number of available CPUs.
    num_cpus = torch.multiprocessing.cpu_count()
    if cfg.num_workers > num_cpus:
        logging.warning('cfg.num_workers is set to %s but only %s CPUs are '
                        'available. Setting cfg.num_workers to %s.',
                        cfg.num_workers, num_cpus, num_cpus)
        cfg.num_workers = num_cpus
    dump_cfg(cfg)
    # Set Pytorch environment
    torch.set_num_threads(cfg.num_threads)
    # Repeat for multiple experiment runs
    for run_id, seed, split_index in zip(*run_loop_settings(cfg, args)):
        # Set configurations for each run
        custom_set_run_dir(cfg, run_id)
        set_printing()
        cfg.dataset.split_index = split_index
        cfg.seed = seed
        cfg.run_id = run_id
        seed_everything(cfg.seed)
        auto_select_device()
        logging.info(f"[*] Run ID {run_id}: seed={cfg.seed}, "
                     f"split_index={cfg.dataset.split_index}")
        logging.info(f" Starting now: {datetime.datetime.now()}")
        # Set machine learning pipeline. Loaders are built before the model,
        # which reads cfg.dim_in / cfg.dim_out.
        loaders = create_loader()
        model = create_model(dim_in=cfg.dim_in, dim_out=cfg.dim_out)
        # Print model info
        logging.info(model)
        logging.info(cfg)
        cfg.params = params_count(model)
        logging.info('Num parameters: %s', cfg.params)
        # Start training
        if cfg.train.mode == 'standard':
            if cfg.wandb.use:
                logging.warning("[W] WandB logging is not supported with the "
                                "default train.mode, set it to `custom`")
            datamodule = GraphGymDataModule()
            train(model, datamodule, logger=True)
        else:
            if cfg.train.mode == 'copt_test':
                # copt_test mode always runs with WandB logging disabled.
                cfg.wandb.use = False
            if cfg.train.mode not in train_dict:
                # Without this, an unset mode surfaces as a bare `KeyError: None`.
                raise KeyError(
                    f"Unknown cfg.train.mode {cfg.train.mode!r}; registered "
                    f"modes are 'standard' and {sorted(train_dict)}")
            train_dict[cfg.train.mode](cfg, loaders, model)