-
Notifications
You must be signed in to change notification settings - Fork 0
/
saturn.py
113 lines (95 loc) · 4.43 KB
/
saturn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""
Parent script that executes Sample Efficient Generative Molecular Design using Memory Manipulation (Saturn).
Takes as input a JSON configuration file that specifies all parameters for the generative experiment.
Adapted from https://github.com/MolecularAI/Reinvent/input.py.
"""
import json
import argparse
import torch
from utils.utils import set_seed_everywhere
# Distribution Learning
from distribution_learning.distribution_learning import DistributionLearningTrainer
from distribution_learning.dataclass import DistributionLearningConfiguration
# Goal-Directed Generation
from goal_directed_generation.reinforcement_learning import ReinforcementLearningAgent
from goal_directed_generation.dataclass import ReinforcementLearningParameters, GoalDirectedGenerationConfiguration
from experience_replay.dataclass import ExperienceReplayParameters
from hallucinated_memory.dataclass import HallucinatedMemoryParameters
from beam_enumeration.dataclass import BeamEnumerationParameters
from diversity_filter.dataclass import DiversityFilterParameters
# Oracle (for Goal-Directed Generation)
from oracles.oracle import Oracle
from oracles.dataclass import OracleConfiguration
# Scoring
from scoring.scorer import Scorer
from scoring.dataclass import ScoringConfiguration
def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser: one positional argument, the JSON configuration path."""
    cli = argparse.ArgumentParser(description="Run Saturn.")
    cli.add_argument("config", help="Path to the JSON configuration file.", type=str)
    return cli


# Module-level parser used by the script entry point.
parser = _build_parser()
def read_json_file(path: str):
    """
    Read a JSON file and return its parsed content.

    Carriage returns and line feeds are removed from the raw text before it is handed to
    `json.loads`, so the whole file is parsed as a single line.

    :param path: path to the JSON file
    :return: the parsed JSON value, or None when parsing fails (the error is printed to stdout)
    """
    with open(path) as f:
        json_input = f.read().replace("\r", "").replace("\n", "")
    try:
        return json.loads(json_input)
    except (ValueError, KeyError, TypeError) as e:
        # f-string placeholders are written {name}; the previous "${name}" form printed a stray "$"
        print(f"JSON format error in file {path}: \n {e}")
        # Best-effort contract kept from the original: report the problem and return None
        return None
def _run_distribution_learning(config, seed, model_architecture):
    """Build the Distribution Learning trainer from `config` and run it."""
    # 1. Construct the Distribution Learning Trainer
    distribution_learning_trainer = DistributionLearningTrainer(
        config["logging"]["logging_path"],
        config["logging"]["model_checkpoints_dir"],
        DistributionLearningConfiguration(
            seed,
            model_architecture,
            **config["distribution_learning"]["parameters"]
        )
    )
    # 2. Run Distribution Learning
    distribution_learning_trainer.run()


def _run_goal_directed_generation(config, seed, model_architecture, device):
    """Build the Oracle and the Reinforcement Learning agent from `config` and run the agent."""
    # 1. Construct the Oracle
    oracle = Oracle(OracleConfiguration(**config["oracle"]))
    # 2. Construct the Reinforcement Learning Agent
    goal_directed_config = config["goal_directed_generation"]
    reinforcement_learning_agent = ReinforcementLearningAgent(
        config["logging"]["logging_path"],
        config["logging"]["model_checkpoints_dir"],
        oracle=oracle,
        configuration=GoalDirectedGenerationConfiguration(
            seed,
            model_architecture,
            ReinforcementLearningParameters(**goal_directed_config["reinforcement_learning"]),
            ExperienceReplayParameters(**goal_directed_config["experience_replay"]),
            DiversityFilterParameters(**goal_directed_config["diversity_filter"]),
            HallucinatedMemoryParameters(**goal_directed_config["hallucinated_memory"]),
            BeamEnumerationParameters(**goal_directed_config["beam_enumeration"]),
        ),
        device=device
    )
    # 3. Run Goal-Directed Generation via Reinforcement Learning
    reinforcement_learning_agent.run()


def _run_scoring(config):
    """Build the Oracle and the Scorer from `config` and run the Scorer."""
    # 1. Construct the Oracle
    oracle = Oracle(OracleConfiguration(**config["oracle"]))
    # 2. Construct the Scorer
    scorer = Scorer(
        config["logging"]["logging_path"],
        oracle=oracle,
        # FIXME: Currently required because the Oracle takes as input a Diversity Filter - remove this dependency
        # NOTE: because of this, scoring mode still reads config["goal_directed_generation"]["diversity_filter"]
        diversity_filter_configuration=DiversityFilterParameters(**config["goal_directed_generation"]["diversity_filter"]),
        configuration=ScoringConfiguration(**config["scoring"])
    )
    # 3. Run Scoring
    scorer.run()


def main() -> None:
    """
    Script entry point: parse the command line, load the JSON configuration and dispatch
    on its "running_mode" value.

    Recognised modes (compared after lower-casing): "distribution_learning",
    "goal_directed_generation" and "scoring" / "scorer".

    :raises SystemExit: if the configuration file could not be parsed
    :raises ValueError: if the running mode is not one of the recognised values
    """
    args = parser.parse_args()
    config = read_json_file(args.config)
    if config is None:
        # read_json_file prints the parsing error and returns None; stop here with a clear
        # message instead of failing below with "'NoneType' object is not subscriptable"
        raise SystemExit(f"Could not load a valid JSON configuration from {args.config}.")

    running_mode = config["running_mode"].lower()

    # Set the seed
    device = config["device"]
    seed = config["seed"]
    set_seed_everywhere(seed, device)

    model_architecture = config["model_architecture"]

    if running_mode == "distribution_learning":
        _run_distribution_learning(config, seed, model_architecture)
    elif running_mode == "goal_directed_generation":
        _run_goal_directed_generation(config, seed, model_architecture, device)
    elif running_mode in ["scoring", "scorer"]:
        _run_scoring(config)
    else:
        raise ValueError(f"Running mode: {running_mode} is not implemented.")


if __name__ == "__main__":
    main()