Commit: deephyper example with qm9 fixed (#317)
Showing 3 changed files with 246 additions and 165 deletions.
@@ -1,125 +1,160 @@
-import os, json
-
+import os
+import pdb
+import json
 import torch
 import torch_geometric
+from torch_geometric.transforms import AddLaplacianEigenvectorPE
+import argparse

 # deprecated in torch_geometric 2.0
 try:
     from torch_geometric.loader import DataLoader
-except:
+except ImportError:
     from torch_geometric.data import DataLoader

 import hydragnn
-import argparse

 num_samples = 1000


 # Update each sample prior to loading.
-def qm9_pre_transform(data):
+def qm9_pre_transform(data, transform):
+    # LPE
+    data = transform(data)
     # Set descriptor as element type.
     data.x = data.z.float().view(-1, 1)
     # Only predict free energy (index 10 of 19 properties) for this run.
     data.y = data.y[:, 10] / len(data.x)
+    graph_features_dim = [1]
+    node_feature_dim = [1]
+    # gps requires relative edge features, introduced rel_lapPe as edge encodings
+    source_pe = data.pe[data.edge_index[0]]
+    target_pe = data.pe[data.edge_index[1]]
+    data.rel_pe = torch.abs(source_pe - target_pe)  # Compute feature-wise difference
     return data


 def qm9_pre_filter(data):
     return data.idx < num_samples

-parser = argparse.ArgumentParser()
-parser.add_argument("--model_type", help="model_type", default="EGNN")
-parser.add_argument("--hidden_dim", type=int, help="hidden_dim", default=5)
-parser.add_argument("--num_conv_layers", type=int, help="num_conv_layers", default=6)
-parser.add_argument("--num_headlayers", type=int, help="num_headlayers", default=2)
-parser.add_argument("--dim_headlayers", type=int, help="dim_headlayers", default=10)
-parser.add_argument("--log", help="log name", default="qm9_test")
-args = parser.parse_args()
-args.parameters = vars(args)
-
-num_samples = 1000
-
-# Configurable run choices (JSON file that accompanies this example script).
-filename = os.path.join(os.path.dirname(os.path.abspath(__file__)), "qm9.json")
-with open(filename, "r") as f:
-    config = json.load(f)
-verbosity = config["Verbosity"]["level"]
-var_config = config["NeuralNetwork"]["Variables_of_interest"]
-
-# Update the config dictionary with the suggested hyperparameters
-config["NeuralNetwork"]["Architecture"]["model_type"] = args.parameters["model_type"]
-config["NeuralNetwork"]["Architecture"]["hidden_dim"] = args.parameters["hidden_dim"]
-config["NeuralNetwork"]["Architecture"]["num_conv_layers"] = args.parameters[
-    "num_conv_layers"
-]
-
-dim_headlayers = [
-    args.parameters["dim_headlayers"] for i in range(args.parameters["num_headlayers"])
-]
-
-for head_type in config["NeuralNetwork"]["Architecture"]["output_heads"]:
-    config["NeuralNetwork"]["Architecture"]["output_heads"][head_type][
-        "num_headlayers"
-    ] = args.parameters["num_headlayers"]
-    config["NeuralNetwork"]["Architecture"]["output_heads"][head_type][
-        "dim_headlayers"
-    ] = dim_headlayers
-
-if args.parameters["model_type"] not in ["EGNN", "SchNet", "DimeNet"]:
-    config["NeuralNetwork"]["Architecture"]["equivariance"] = False
-
-# Always initialize for multi-rank training.
-world_size, world_rank = hydragnn.utils.setup_ddp()
-
-log_name = args.log
-# Enable print to log file.
-hydragnn.utils.setup_log(log_name)
-
-# Use built-in torch_geometric datasets.
-# Filter function above used to run quick example.
-# NOTE: data is moved to the device in the pre-transform.
-# NOTE: transforms/filters will NOT be re-run unless the qm9/processed/ directory is removed.
-dataset = torch_geometric.datasets.QM9(
-    root="dataset/qm9", pre_transform=qm9_pre_transform, pre_filter=qm9_pre_filter
-)
-train, val, test = hydragnn.preprocess.split_dataset(
-    dataset, config["NeuralNetwork"]["Training"]["perc_train"], False
-)
-(train_loader, val_loader, test_loader,) = hydragnn.preprocess.create_dataloaders(
-    train, val, test, config["NeuralNetwork"]["Training"]["batch_size"]
-)
-
-config = hydragnn.utils.update_config(config, train_loader, val_loader, test_loader)
-
-model = hydragnn.models.create_model_config(
-    config=config["NeuralNetwork"],
-    verbosity=verbosity,
-)
-model = hydragnn.utils.get_distributed_model(model, verbosity)
-
-learning_rate = config["NeuralNetwork"]["Training"]["Optimizer"]["learning_rate"]
-optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
-scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
-    optimizer, mode="min", factor=0.5, patience=5, min_lr=0.00001
-)
-
-# Run training with the given model and qm9 datasets.
-writer = hydragnn.utils.get_summary_writer(log_name)
-hydragnn.utils.save_config(config, log_name)
-
-hydragnn.train.train_validate_test(
-    model,
-    optimizer,
-    train_loader,
-    val_loader,
-    test_loader,
-    writer,
-    scheduler,
-    config["NeuralNetwork"],
-    log_name,
-    verbosity,
-)
-
-hydragnn.utils.save_model(model, optimizer, log_name)
-hydragnn.utils.print_timers(verbosity)
+def main(mpnn_type=None, global_attn_engine=None, global_attn_type=None):
+    # FIX random seed
+    random_state = 0
+    torch.manual_seed(random_state)
+
+    # Set this path for output.
+    try:
+        os.environ["SERIALIZED_DATA_PATH"]
+    except KeyError:
+        os.environ["SERIALIZED_DATA_PATH"] = os.getcwd()
+
+    # Configurable run choices (JSON file that accompanies this example script).
+    filename = os.path.join(os.path.dirname(os.path.abspath(__file__)), "qm9.json")
+    with open(filename, "r") as f:
+        config = json.load(f)
+
+    # If a model type is provided, update the configuration accordingly.
+    if global_attn_engine:
+        config["NeuralNetwork"]["Architecture"][
+            "global_attn_engine"
+        ] = global_attn_engine
+
+    if global_attn_type:
+        config["NeuralNetwork"]["Architecture"]["global_attn_type"] = global_attn_type
+
+    if mpnn_type:
+        config["NeuralNetwork"]["Architecture"]["mpnn_type"] = mpnn_type
+
+    verbosity = config["Verbosity"]["level"]
+    var_config = config["NeuralNetwork"]["Variables_of_interest"]
+
+    # Always initialize for multi-rank training.
+    world_size, world_rank = hydragnn.utils.distributed.setup_ddp()
+
+    log_name = f"qm9_test_{mpnn_type}" if mpnn_type else "qm9_test"
+    # Enable print to log file.
+    hydragnn.utils.print.print_utils.setup_log(log_name)
+
+    # LPE
+    transform = AddLaplacianEigenvectorPE(
+        k=config["NeuralNetwork"]["Architecture"]["pe_dim"],
+        attr_name="pe",
+        is_undirected=True,
+    )
+
+    # Use built-in torch_geometric datasets.
+    # Filter function above used to run quick example.
+    # NOTE: data is moved to the device in the pre-transform.
+    # NOTE: transforms/filters will NOT be re-run unless the qm9/processed/ directory is removed.
+    dataset = torch_geometric.datasets.QM9(
+        root="dataset/qm9",
+        pre_transform=lambda data: qm9_pre_transform(data, transform),
+        pre_filter=qm9_pre_filter,
+    )
+    train, val, test = hydragnn.preprocess.split_dataset(
+        dataset, config["NeuralNetwork"]["Training"]["perc_train"], False
+    )
+
+    (train_loader, val_loader, test_loader,) = hydragnn.preprocess.create_dataloaders(
+        train, val, test, config["NeuralNetwork"]["Training"]["batch_size"]
+    )
+
+    config = hydragnn.utils.input_config_parsing.update_config(
+        config, train_loader, val_loader, test_loader
+    )
+
+    model = hydragnn.models.create_model_config(
+        config=config["NeuralNetwork"],
+        verbosity=verbosity,
+    )
+    model = hydragnn.utils.distributed.get_distributed_model(model, verbosity)
+
+    learning_rate = config["NeuralNetwork"]["Training"]["Optimizer"]["learning_rate"]
+    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+        optimizer, mode="min", factor=0.5, patience=5, min_lr=0.00001
+    )
+
+    # Run training with the given model and qm9 datasets.
+    writer = hydragnn.utils.model.model.get_summary_writer(log_name)
+    hydragnn.utils.input_config_parsing.save_config(config, log_name)
+
+    hydragnn.train.train_validate_test(
+        model,
+        optimizer,
+        train_loader,
+        val_loader,
+        test_loader,
+        writer,
+        scheduler,
+        config["NeuralNetwork"],
+        log_name,
+        verbosity,
+    )
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Run the QM9 example with optional model type."
+    )
+    parser.add_argument(
+        "--mpnn_type",
+        type=str,
+        default=None,
+        help="Specify the model type for training (default: None).",
+    )
+    parser.add_argument(
+        "--global_attn_engine",
+        type=str,
+        default=None,
+        help="Specify if global attention is being used (default: None).",
+    )
+    parser.add_argument(
+        "--global_attn_type",
+        type=str,
+        default=None,
+        help="Specify the global attention type (default: None).",
+    )
+    args = parser.parse_args()
+    main(mpnn_type=args.mpnn_type)
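
For quick reference, a minimal sketch of how the refactored entry point can be exercised; the model name "EGNN" is illustrative (one of the types the old script listed), and note that the committed __main__ block forwards only mpnn_type to main(), so the global-attention keywords must be supplied when calling main() directly:

# Command-line invocation (only --mpnn_type is forwarded by the __main__ block):
#   python qm9.py --mpnn_type EGNN

# Programmatic invocation, e.g. from a driver script; main() accepts the
# global-attention keywords even though the CLI wrapper does not pass them.
main(mpnn_type="EGNN", global_attn_engine=None, global_attn_type=None)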
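
As an aside, a self-contained sketch of what the new pre-transform computes: AddLaplacianEigenvectorPE attaches k Laplacian-eigenvector coordinates per node under data.pe, and the rel_pe edge encoding is the feature-wise absolute difference of those coordinates across each edge. The toy 4-node cycle and k=2 below are purely illustrative and not part of the commit.

import torch
from torch_geometric.data import Data
from torch_geometric.transforms import AddLaplacianEigenvectorPE

# Toy undirected 4-node cycle (both directions of each edge listed).
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 0],
                           [1, 0, 2, 1, 3, 2, 0, 3]])
data = Data(edge_index=edge_index, num_nodes=4)

# Same transform the commit builds from config["..."]["pe_dim"]; k=2 here.
transform = AddLaplacianEigenvectorPE(k=2, attr_name="pe", is_undirected=True)
data = transform(data)  # adds data.pe with shape [4, 2]

# Relative LPE per edge, exactly as in qm9_pre_transform above.
source_pe = data.pe[data.edge_index[0]]
target_pe = data.pe[data.edge_index[1]]
rel_pe = torch.abs(source_pe - target_pe)
print(rel_pe.shape)  # torch.Size([8, 2]): one k-dim encoding per directed edge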