diff --git a/lanfactory/trainers/torch_mlp.py b/lanfactory/trainers/torch_mlp.py index a8d471d..c5a9e17 100755 --- a/lanfactory/trainers/torch_mlp.py +++ b/lanfactory/trainers/torch_mlp.py @@ -704,7 +704,11 @@ def __init__(self, model_file_path=None, network_config=None, input_dim=None): input_shape=self.input_dim, generative_model_id=None, ) - self.net.load_state_dict(torch.load(self.model_file_path)) + # self.net.load_state_dict(torch.load(self.model_file_path)) + if torch.cuda.is_available() == False: + self.net.load_state_dict(torch.load(self.model_file_path, map_location=torch.device('cpu'))) + else: + self.net.load_state_dict(torch.load(self.model_file_path)) self.net.to(self.dev) self.net.eval() @@ -768,7 +772,11 @@ def __init__(self, model_file_path=None, network_config=None, input_dim=None): input_shape=self.input_dim, generative_model_id=None, ) - self.net.load_state_dict(torch.load(self.model_file_path)) + # self.net.load_state_dict(torch.load(self.model_file_path)) + if torch.cuda.is_available() == False: + self.net.load_state_dict(torch.load(self.model_file_path, map_location=torch.device('cpu'))) + else: + self.net.load_state_dict(torch.load(self.model_file_path)) self.net.to(self.dev) # AF-TODO: Seemingly LoadTorchMLPInfer is still not callable !