diff --git a/lanfactory/trainers/jax_mlp.py b/lanfactory/trainers/jax_mlp.py
index b16069d..bb72dfc 100755
--- a/lanfactory/trainers/jax_mlp.py
+++ b/lanfactory/trainers/jax_mlp.py
@@ -185,11 +185,7 @@ def load_state_from_file(self, seed=42, input_dim=6, file_path=None):
         return loaded_state

     def make_forward_partial(
-        self,
-        seed=42,
-        input_dim=6,
-        state=None,
-        add_jitted=False,
+        self, seed=42, input_dim=6, state=None, add_jitted=False,
     ):
         """Creates a partial function for the forward pass of the network.
diff --git a/lanfactory/trainers/torch_mlp.py b/lanfactory/trainers/torch_mlp.py
index c5a9e17..0023130 100755
--- a/lanfactory/trainers/torch_mlp.py
+++ b/lanfactory/trainers/torch_mlp.py
@@ -169,11 +169,7 @@ class TorchMLP(nn.Module):
     # In the end I want 'eval', but differentiable
     # w.r.t to input ...., might be a problem
     def __init__(
-        self,
-        network_config=None,
-        input_shape=10,
-        network_type=None,
-        **kwargs,
+        self, network_config=None, input_shape=10, network_type=None, **kwargs,
     ):
         super(TorchMLP, self).__init__()
@@ -614,8 +610,7 @@ def train_and_evaluate(
             print("Saving model state dict")
             train_state_path = full_path + "_train_state_dict_torch.pt"
             torch.save(
-                self.model.state_dict(),
-                train_state_path,
+                self.model.state_dict(), train_state_path,
             )
             print("Saving model parameters to: " + train_state_path)
             self.file_path_model = train_state_path
@@ -706,7 +701,9 @@ def __init__(self, model_file_path=None, network_config=None, input_dim=None):
         )
         # self.net.load_state_dict(torch.load(self.model_file_path))
         if torch.cuda.is_available() == False:
-            self.net.load_state_dict(torch.load(self.model_file_path, map_location=torch.device('cpu')))
+            self.net.load_state_dict(
+                torch.load(self.model_file_path, map_location=torch.device("cpu"))
+            )
         else:
             self.net.load_state_dict(torch.load(self.model_file_path))
         self.net.to(self.dev)
@@ -774,7 +771,9 @@ def __init__(self, model_file_path=None, network_config=None, input_dim=None):
         )
         # self.net.load_state_dict(torch.load(self.model_file_path))
         if torch.cuda.is_available() == False:
-            self.net.load_state_dict(torch.load(self.model_file_path, map_location=torch.device('cpu')))
+            self.net.load_state_dict(
+                torch.load(self.model_file_path, map_location=torch.device("cpu"))
+            )
         else:
             self.net.load_state_dict(torch.load(self.model_file_path))
         self.net.to(self.dev)
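
Note: the two `torch_mlp.py` hunks at lines 706 and 774 only reflow an existing pattern: when CUDA is unavailable, `torch.load` is called with `map_location=torch.device("cpu")` so state dicts saved from GPU tensors can still be loaded on CPU-only machines. A minimal standalone sketch of that device-aware loading pattern follows; the helper name `load_state_dict_device_aware` and the arguments `model` / `state_dict_path` are illustrative only and not part of the LANfactory API.

```python
import torch
import torch.nn as nn


def load_state_dict_device_aware(model: nn.Module, state_dict_path: str) -> nn.Module:
    """Load a saved state dict onto the best available device.

    If CUDA is unavailable, map_location remaps tensors that were saved
    from GPU memory onto the CPU so the load does not fail.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
        state_dict = torch.load(state_dict_path)
    else:
        device = torch.device("cpu")
        state_dict = torch.load(state_dict_path, map_location=device)

    model.load_state_dict(state_dict)
    model.to(device)
    return model
```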