
Commit

update some tutorials
AlexanderFengler committed Feb 8, 2024
1 parent aaabd3f commit 7257f63
Showing 6 changed files with 202 additions and 187 deletions.
34 changes: 19 additions & 15 deletions docs/basic_tutorial/basic_tutorial.ipynb
@@ -84,7 +84,7 @@
 "outputs": [],
 "source": [
 "# MAKE CONFIGS\n",
-"model = 'angle'\n",
+"model = \"angle\"\n",
 "# Initialize the generator config (for MLP LANs)\n",
 "generator_config = deepcopy(ssms.config.data_generator_config[\"lan\"])\n",
 "# Specify generative model (one from the list of included models mentioned above)\n",
@@ -158,7 +158,7 @@
 }
 ],
 "source": [
-"#training_data"
+"# training_data"
 ]
 },
 {
@@ -186,13 +186,15 @@
 "# MAKE DATALOADERS\n",
 "\n",
 "# List of datafiles (here only one)\n",
-"folder_ = \"data/lan_mlp/\" + model + \"/\" # + \"/training_data_0_nbins_0_n_1000/\"\n",
+"folder_ = \"data/lan_mlp/\" + model + \"/\"  # + \"/training_data_0_nbins_0_n_1000/\"\n",
 "file_list_ = [folder_ + file_ for file_ in os.listdir(folder_)]\n",
 "\n",
 "# Training dataset\n",
 "torch_training_dataset = lanfactory.trainers.DatasetTorch(\n",
-"    file_ids=file_list_, batch_size=128, features_key=\"lan_data\",\n",
-"    label_key=\"lan_labels\",\n",
+"    file_ids=file_list_,\n",
+"    batch_size=128,\n",
+"    features_key=\"lan_data\",\n",
+"    label_key=\"lan_labels\",\n",
 ")\n",
 "\n",
 "torch_training_dataloader = torch.utils.data.DataLoader(\n",
@@ -205,8 +207,10 @@
 "\n",
 "# Validation dataset\n",
 "torch_validation_dataset = lanfactory.trainers.DatasetTorch(\n",
-"    file_ids=file_list_, batch_size=128, features_key=\"lan_data\",\n",
-"    label_key=\"lan_labels\",\n",
+"    file_ids=file_list_,\n",
+"    batch_size=128,\n",
+"    features_key=\"lan_data\",\n",
+"    label_key=\"lan_labels\",\n",
 ")\n",
 "\n",
 "torch_validation_dataloader = torch.utils.data.DataLoader(\n",
@@ -580,14 +584,14 @@
 "source": [
 "# TRAIN MODEL\n",
 "model_trainer = lanfactory.trainers.ModelTrainerTorchMLP(\n",
-" model=net,\n",
-" train_config=train_config,\n",
-" train_dl=torch_training_dataloader,\n",
-" valid_dl=torch_validation_dataloader,\n",
-" allow_abs_path_folder_generation=False,\n",
-" pin_memory=True,\n",
-" seed = None,\n",
-" )\n",
+"    model=net,\n",
+"    train_config=train_config,\n",
+"    train_dl=torch_training_dataloader,\n",
+"    valid_dl=torch_validation_dataloader,\n",
+"    allow_abs_path_folder_generation=False,\n",
+"    pin_memory=True,\n",
+"    seed=None,\n",
+")\n",
 "\n",
 "# model_trainer.train_model(save_history=True, save_model=True, verbose=0)\n",
 "model_trainer.train_and_evaluate()\n",
18 changes: 10 additions & 8 deletions lanfactory/trainers/torch_mlp.py
@@ -13,7 +13,7 @@
 try:
     import wandb
 except ImportError:
-    print('passing 1')
+    print("passing 1")
     print("wandb not available")
 
 """This module contains the classes for training TorchMLP models."""
@@ -138,8 +138,10 @@ def __data_generation(self, batch_ids=None):
         elif self.tmp_data[self.label_key].ndim == 2:
             y = self.tmp_data[self.label_key][batch_ids]
         else:
-            raise ValueError("Label data has unexpected shape: " + \
-                str(self.tmp_data[self.label_key].shape))
+            raise ValueError(
+                "Label data has unexpected shape: "
+                + str(self.tmp_data[self.label_key].shape)
+            )
 
         if self.label_lower_bound is not None:
             y[y < self.label_lower_bound] = self.label_lower_bound
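The rewrapped ValueError above guards against label arrays that are neither 1-D nor 2-D, and the context lines below it clamp labels to an optional lower bound. A small NumPy sketch of that logic in isolation (the clip_labels name and the example bound are illustrative):

```python
import numpy as np


def clip_labels(y: np.ndarray, label_lower_bound: float | None = None) -> np.ndarray:
    """Reject unexpected label shapes, then clamp values below the lower bound."""
    if y.ndim not in (1, 2):
        raise ValueError("Label data has unexpected shape: " + str(y.shape))
    if label_lower_bound is not None:
        y = y.copy()
        y[y < label_lower_bound] = label_lower_bound
    return y


# Example: log-likelihood labels clamped at an assumed floor of log(1e-7).
print(clip_labels(np.array([-3.2, -25.0, -1.1]), label_lower_bound=np.log(1e-7)))
```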
@@ -316,7 +318,7 @@ def __init__(
     def __try_wandb(
         self, wandb_project_id="projectid", file_id="fileid", run_id="runid"
     ):
-        print('passing 2')
+        print("passing 2")
         try:
             wandb.init(
                 project=wandb_project_id,
@@ -332,7 +334,7 @@
             )
             print("Succefully initialized wandb!")
         except Exception as e:
-            print('passing 3')
+            print("passing 3")
             print(e)
             print("wandb not available, not storing results there")
 
@@ -474,7 +476,7 @@ def train_and_evaluate(
         try:
             wandb.watch(self.model, criterion=None, log="all", log_freq=1000)
         except Exception as e:
-            print('passing 4')
+            print("passing 4")
             print(e)
 
         step_cnt = 0
@@ -559,7 +561,7 @@ def train_and_evaluate(
                     wandb.log({"loss": loss, "val_loss": val_loss}, step=step_cnt)
                     # print('logged loss')
                 except Exception as e:
-                    print('passing 5')
+                    print("passing 5")
                     print(e)
 
             # Saving
@@ -607,7 +609,7 @@ def train_and_evaluate(
                 wandb.finish()
                 print("wandb uploaded")
             except Exception as e:
-                print('passing 6')
+                print("passing 6")
                 print(e)
 
         print("Training finished successfully...")
55 changes: 31 additions & 24 deletions notebooks/basic_tutorial.ipynb
@@ -75,7 +75,7 @@
 "outputs": [],
 "source": [
 "# MAKE CONFIGS\n",
-"model = 'angle'\n",
+"model = \"angle\"\n",
 "# Initialize the generator config (for MLP LANs)\n",
 "generator_config = deepcopy(ssms.config.data_generator_config[\"lan\"])\n",
 "# Specify generative model (one from the list of included models mentioned above)\n",
@@ -429,7 +429,6 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"\n",
 "for i in range(25):\n",
 "    print(i)\n",
 "    training_data = my_dataset_generator.generate_data_training_uniform(save=True)"
@@ -452,7 +451,7 @@
 }
 ],
 "source": [
-"#training_data"
+"# training_data"
 ]
 },
 {
@@ -480,13 +479,15 @@
 "# MAKE DATALOADERS\n",
 "\n",
 "# List of datafiles (here only one)\n",
-"folder_ = \"data/lan_mlp/\" + model + \"/\" # + \"/training_data_0_nbins_0_n_1000/\"\n",
+"folder_ = \"data/lan_mlp/\" + model + \"/\"  # + \"/training_data_0_nbins_0_n_1000/\"\n",
 "file_list_ = [folder_ + file_ for file_ in os.listdir(folder_)]\n",
 "\n",
 "# Training dataset\n",
 "torch_training_dataset = lanfactory.trainers.DatasetTorch(\n",
-"    file_ids=file_list_, batch_size=1024, features_key=\"lan_data\",\n",
-"    label_key=\"lan_labels\",\n",
+"    file_ids=file_list_,\n",
+"    batch_size=1024,\n",
+"    features_key=\"lan_data\",\n",
+"    label_key=\"lan_labels\",\n",
 ")\n",
 "\n",
 "torch_training_dataloader = torch.utils.data.DataLoader(\n",
@@ -499,8 +500,10 @@
 "\n",
 "# Validation dataset\n",
 "torch_validation_dataset = lanfactory.trainers.DatasetTorch(\n",
-"    file_ids=file_list_, batch_size=1024, features_key=\"lan_data\",\n",
-"    label_key=\"lan_labels\",\n",
+"    file_ids=file_list_,\n",
+"    batch_size=1024,\n",
+"    features_key=\"lan_data\",\n",
+"    label_key=\"lan_labels\",\n",
 ")\n",
 "\n",
 "torch_validation_dataloader = torch.utils.data.DataLoader(\n",
@@ -544,8 +547,8 @@
 "# SPECIFY NETWORK CONFIGS AND TRAINING CONFIGS\n",
 "\n",
 "network_config = deepcopy(lanfactory.config.network_configs.network_config_mlp)\n",
-"network_config['layer_sizes'] = [100, 100, 100, 1]\n",
-"network_config['activations'] = ['tanh', 'tanh', 'tanh', 'linear']\n",
+"network_config[\"layer_sizes\"] = [100, 100, 100, 1]\n",
+"network_config[\"activations\"] = [\"tanh\", \"tanh\", \"tanh\", \"linear\"]\n",
 "\n",
 "print(\"Network config: \")\n",
 "print(network_config)\n",
@@ -12892,19 +12895,21 @@
 "source": [
 "# TRAIN MODEL\n",
 "model_trainer = lanfactory.trainers.ModelTrainerTorchMLP(\n",
-" model=net,\n",
-" train_config=train_config,\n",
-" train_dl=torch_training_dataloader,\n",
-" valid_dl=torch_validation_dataloader,\n",
-" allow_abs_path_folder_generation=False,\n",
-" pin_memory=True,\n",
-" seed = None,\n",
-" )\n",
+"    model=net,\n",
+"    train_config=train_config,\n",
+"    train_dl=torch_training_dataloader,\n",
+"    valid_dl=torch_validation_dataloader,\n",
+"    allow_abs_path_folder_generation=False,\n",
+"    pin_memory=True,\n",
+"    seed=None,\n",
+")\n",
 "\n",
 "# model_trainer.train_model(save_history=True, save_model=True, verbose=0)\n",
-"model_trainer.train_and_evaluate(wandb_on = False,\n",
-" output_folder = 'data/torch_models/' + model + \"_lan\" + '/',\n",
-" output_file_id = model)\n",
+"model_trainer.train_and_evaluate(\n",
+"    wandb_on=False,\n",
+"    output_folder=\"data/torch_models/\" + model + \"_lan\" + \"/\",\n",
+"    output_file_id=model,\n",
+")\n",
 "# LOAD MODEL"
 ]
 },
@@ -12975,13 +12980,13 @@
 "\n",
 "# Direct call --> need tensor input\n",
 "direct_out = network(\n",
-"    torch.from_numpy(np.array([1, 1.5, 0.5, 1.0, 0.1, 0.65, 1.], dtype=np.float32))\n",
+"    torch.from_numpy(np.array([1, 1.5, 0.5, 1.0, 0.1, 0.65, 1.0], dtype=np.float32))\n",
 ")\n",
 "print(\"direct call out: \", direct_out)\n",
 "\n",
 "# predict_on_batch method\n",
 "predict_on_batch_out = network.predict_on_batch(\n",
-"    np.array([1, 1.5, 0.5, 1.0, 0.1, 0.65, 1.], dtype=np.float32)\n",
+"    np.array([1, 1.5, 0.5, 1.0, 0.1, 0.65, 1.0], dtype=np.float32)\n",
 ")\n",
 "print(\"predict_on_batch out: \", predict_on_batch_out)"
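Both calls above evaluate the network at a single seven-element input (for the 'angle' model this reads as five model parameters followed by an rt and a choice). predict_on_batch is the natural entry point for evaluating many rows at once; the sketch below builds a grid of reaction times at fixed parameters, reusing the network object from the cell above. The 2-D input shape and the exact parameter ordering are assumptions this diff does not confirm.

```python
import numpy as np

# Assumed row layout for the "angle" model: [v, a, z, t, theta, rt, choice].
theta = np.array([1.0, 1.5, 0.5, 1.0, 0.1], dtype=np.float32)
rts = np.linspace(0.05, 3.0, 100, dtype=np.float32)
choices = np.ones_like(rts)  # evaluate the upper-bound choice only

batch = np.column_stack([np.tile(theta, (rts.size, 1)), rts, choices]).astype(np.float32)

# Assumption: predict_on_batch accepts an (n_rows, 7) array and returns one value per row.
log_likelihoods = network.predict_on_batch(batch)
print(log_likelihoods.shape)
```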
@@ -13275,7 +13280,9 @@
 "# Plot simulations\n",
 "for i in range(10):\n",
 "    my_seed = np.random.choice(1000000)\n",
-"    sim_out = simulator(model=model, theta=data.values[0, :-2], n_samples=2000, random_state = my_seed)\n",
+"    sim_out = simulator(\n",
+"        model=model, theta=data.values[0, :-2], n_samples=2000, random_state=my_seed\n",
+"    )\n",
 "    plt.hist(\n",
 "        sim_out[\"rts\"] * sim_out[\"choices\"],\n",
 "        bins=100,\n",