Skip to content

Commit

Permalink
any network type is now allowed
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderFengler committed Feb 27, 2024
1 parent 9ce1375 commit 8243794
Showing 1 changed file with 21 additions and 13 deletions.
34 changes: 21 additions & 13 deletions lanfactory/trainers/torch_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ class TorchMLP(nn.Module):
Network configuration.
input_shape (int):
Input shape.
network_type (str):
Network type.
"""

# AF-TODO: Potentially split this via super-class
Expand All @@ -170,6 +172,7 @@ def __init__(
self,
network_config=None,
input_shape=10,
network_type=None,
**kwargs,
):
super(TorchMLP, self).__init__()
Expand All @@ -182,7 +185,12 @@ def __init__(
else:
self.train_output_type = "logprob"

self.network_type = "lan" if self.train_output_type == "logprob" else "cpn"
if network_type is not None:
self.network_type = network_type
else:
self.network_type = "lan" if self.train_output_type == "logprob" else "cpn"
print('Setting network type to "lan" or "cpn" based on train_output_type. \n' + \
'Note: This is only a default setting, and can be overwritten by the network_type argument.')

self.activations = {
"relu": torch.nn.ReLU(),
Expand Down Expand Up @@ -459,17 +467,17 @@ def train_and_evaluate(
)

# Identify network type:
if self.model.train_output_type == "logprob":
network_type = "lan"
elif self.model.train_output_type == "logits":
network_type = "cpn"
else:
network_type = "unknown"
print(
'Model type identified as "unknown" because the '
"training_output_type attribute"
+ ' of the supplied jax model is neither "logprob", nor "logits"'
)
# if self.model.train_output_type == "logprob":
# network_type = "lan"
# elif self.model.train_output_type == "logits":
# network_type = "cpn"
# else:
# network_type = "unknown"
# print(
# 'Model type identified as "unknown" because the '
# "training_output_type attribute"
# + ' of the supplied jax model is neither "logprob", nor "logits"'
# )

training_history = pd.DataFrame(
np.zeros((self.train_config["n_epochs"], 2)), columns=["epoch", "val_loss"]
Expand Down Expand Up @@ -569,7 +577,7 @@ def train_and_evaluate(

# Saving
full_path = (
output_folder + "/" + output_file_id + "_" + network_type + "_" + run_id
output_folder + "/" + output_file_id + "_" + self.model.network_type + "_" + run_id
)

if save_history or save_all:
Expand Down

0 comments on commit 8243794

Please sign in to comment.