This repository has been archived by the owner on Jan 26, 2024. It is now read-only.

makes the optimizer optional
Coos Baakman committed Dec 20, 2023
1 parent 0c2af21 commit 10e0500
Showing 4 changed files with 29 additions and 35 deletions.
README.md: 6 changes, 0 additions & 6 deletions
@@ -235,12 +235,6 @@ model = NeuralNet(data_set, cnn_reg, model_type='3d', task='reg',
                   metrics_exporters=[OutputExporter(out)],
                   cuda=False)
 
-# change the optimizer (optional)
-model.optimizer = optim.SGD(model.net.parameters(),
-                            lr=0.001,
-                            momentum=0.9,
-                            weight_decay=0.005)
-
 # do the prediction
 model.test()
 ```
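With the optimizer block removed from the README's prediction example, the two use cases separate cleanly: `test()` on a pretrained model needs no optimizer, while training now requires setting one first. A minimal sketch of the new flow, reusing `model` from the README example above; the `nepoch` keyword is an assumption for illustration, not taken from this diff:

```python
from torch import optim

# prediction with a pretrained model: no optimizer needed anymore
model.test()

# training: an optimizer must now be set explicitly,
# otherwise train() raises RuntimeError (see NeuralNet.py below)
model.optimizer = optim.SGD(model.net.parameters(),
                            lr=0.001,
                            momentum=0.9,
                            weight_decay=0.005)
model.train(nepoch=10)  # 'nepoch' is an assumed argument name
```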
deeprank/learn/NeuralNet.py: 32 changes, 22 additions & 10 deletions
@@ -223,16 +223,9 @@ def __init__(self, data_set, model,
         elif self.cuda:
             self.net = self.net.cuda()
 
-        # set the optimizer
-        #self.optimizer = optim.SGD(self.net.parameters(),
-        #                           lr=0.005,
-        #                           momentum=0.9,
-        #                           weight_decay=0.001)
-        self.optimizer = optim.AdamW(self.net.parameters(),
-                                     lr=0.005,
-                                     weight_decay=0.001)
-        if self.pretrained_model:
-            self.load_optimizer_params()
+        # set the optimizer to None in the beginning.
+        # if the user wants to train a model, they must set it first.
+        self.optimizer = None
 
         # ------------------------------------------
         # print
@@ -298,6 +291,13 @@ def train(self,
             save_model (str, optional): 'best' or 'all', save only the
                 best model or all models.
         """
+
+        if self.optimizer is None:
+            if self.pretrained_model is not None:
+                self.load_optimizer_params()
+            else:
+                raise RuntimeError("no optimizer set, cannot train")
+
         logger.info(f'\n: Batch Size: {train_batch_size}')
         if self.cuda:
             logger.info(f': NGPU : {self.ngpu}')
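The guard added at the top of `train()` defines the new contract: a `NeuralNet` built without `pretrained_model` must be given an optimizer before training. A minimal sketch under the assumption that `data_set` and `cnn_reg` are set up as in the README:

```python
from torch import optim

model = NeuralNet(data_set, cnn_reg, model_type='3d', task='reg', cuda=False)

try:
    model.train()  # optimizer is None and no pretrained_model was given
except RuntimeError as err:
    print(err)     # "no optimizer set, cannot train"

# any torch optimizer satisfies the guard
model.optimizer = optim.AdamW(model.net.parameters(),
                              lr=0.005, weight_decay=0.001)
model.train()
```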
@@ -407,6 +407,18 @@ def load_model_params(self):
 
     def load_optimizer_params(self):
         """Get optimizer parameters from a saved model."""
+
+        # guess the optimizer: param_groups is a list of dicts and only SGD groups carry 'momentum'
+        if 'momentum' in self.state['optimizer']['param_groups'][0]:
+            self.optimizer = optim.SGD(self.net.parameters(),
+                                       lr=0.005,
+                                       momentum=0.9,
+                                       weight_decay=0.001)
+        else:
+            self.optimizer = optim.AdamW(self.net.parameters(),
+                                         lr=0.005,
+                                         weight_decay=0.001)
+
         self.optimizer.load_state_dict(self.state['optimizer'])
 
     def load_nn_params(self):
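The guess in `load_optimizer_params` relies on the layout of a PyTorch optimizer `state_dict`: `param_groups` is a list of per-group hyperparameter dicts, and SGD groups carry a `momentum` key while AdamW groups do not. A self-contained check of that heuristic, with a toy `nn.Linear` standing in for the real network:

```python
import torch.nn as nn
import torch.optim as optim

net = nn.Linear(4, 2)  # stand-in for the real model

sgd = optim.SGD(net.parameters(), lr=0.005, momentum=0.9, weight_decay=0.001)
adamw = optim.AdamW(net.parameters(), lr=0.005, weight_decay=0.001)

# param_groups is a list of dicts, one per parameter group
print('momentum' in sgd.state_dict()['param_groups'][0])    # True
print('momentum' in adamw.state_dict()['param_groups'][0])  # False
```

The hard-coded `lr` and `weight_decay` values in the new code are only placeholders: the `load_state_dict` call that follows restores the saved hyperparameters into the param groups.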
test/data/models/best_valid_model.pth.tar: binary file added (contents not shown)
test/test_learn.py: 26 changes, 7 additions & 19 deletions
@@ -22,9 +22,12 @@
 
 
 
-def test_predict_without_target():
-    feature_modules = ["test.feature.feature1", "test.feature.feature2"]
+def test_predict():
     target_modules = []
+    feature_modules = [
+        'deeprank.features.atomic_contacts',
+        'deeprank.features.neighbour_profile',
+        'deeprank.features.accessibility']
 
     atomic_densities = {'C': 1.7, 'N': 1.55, 'O': 1.52, 'S': 1.8}
     grid_info = {
@@ -37,14 +40,8 @@ def test_predict_without_target():
 
     variants = [PdbVariantSelection("101m", "A", 10, valine, cysteine,
                                     protein_accession="P02144", protein_residue_number=10,
-                                    ),
-                PdbVariantSelection("101m", "A", 8, glutamine, cysteine,
-                                    protein_accession="P02144",
-                                    ),
-                PdbVariantSelection("101m", "A", 9, glutamine, cysteine,
-                                    protein_accession="P02144", protein_residue_number=9,
                                     )]
-    augmentation = 5
+    augmentation = 2
 
     work_dir_path = mkdtemp()
     try:
@@ -72,18 +69,9 @@
         metrics_directory = os.path.join(work_dir_path, "runs")
 
         neural_net = NeuralNet(dataset, cnn_class, model_type='3d',task='class',
+                               pretrained_model="test/data/models/best_valid_model.pth.tar",
                                cuda=False, metrics_exporters=[OutputExporter(metrics_directory),
                                                               TensorboardBinaryClassificationExporter(metrics_directory)])
-
-        neural_net.optimizer = optim.SGD(neural_net.net.parameters(),
-                                         lr=0.001,
-                                         momentum=0.9,
-                                         weight_decay=0.005)
-
-        neural_net.state = {}
-        neural_net.state["task"] = "class"
-        neural_net.state["criterion"] = neural_net.criterion
-
         neural_net.test()
     finally:
         rmtree(work_dir_path)
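The test no longer wires up `optimizer` and `state` by hand; passing `pretrained_model` makes `NeuralNet` restore everything from the bundled checkpoint. A quick way to peek at such a checkpoint; the listed keys are an assumption based on how `self.state` is used in this diff:

```python
import torch

state = torch.load("test/data/models/best_valid_model.pth.tar",
                   map_location="cpu")

# expected entries, judging from this diff: 'optimizer', 'task', 'criterion'
print(sorted(state.keys()))
```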
