From 3252275cce2c2b2aa48033f14ef7f9ec22818169 Mon Sep 17 00:00:00 2001 From: Alexander Fengler Date: Sat, 16 Sep 2023 15:55:22 -0400 Subject: [PATCH 1/8] added docstrings throughout --- lanfactory/config/network_configs.py | 5 + lanfactory/onnx/transform_onnx.py | 18 ++- lanfactory/trainers/jax_mlp.py | 174 ++++++++++++++++++++++++++- lanfactory/trainers/torch_mlp.py | 144 +++++++++++++++++----- lanfactory/utils/util_funs.py | 28 +++++ mkdocs.yml | 99 +++++++++++++++ 6 files changed, 435 insertions(+), 33 deletions(-) create mode 100644 mkdocs.yml diff --git a/lanfactory/config/network_configs.py b/lanfactory/config/network_configs.py index 230fb73..c6c304c 100755 --- a/lanfactory/config/network_configs.py +++ b/lanfactory/config/network_configs.py @@ -1,3 +1,8 @@ +"""This Module defines simple examples for network and training configurations that serve +as inputs to the training classes in the package. +""" + + network_config_mlp = { "layer_sizes": [100, 100, 1], "activations": ["tanh", "tanh", "linear"], diff --git a/lanfactory/onnx/transform_onnx.py b/lanfactory/onnx/transform_onnx.py index d86550b..e1cfd6c 100644 --- a/lanfactory/onnx/transform_onnx.py +++ b/lanfactory/onnx/transform_onnx.py @@ -5,6 +5,9 @@ import torch from lanfactory.trainers.torch_mlp import TorchMLP +"""This module contains the function to transform Torch/Jax models to ONNX format. +Can be run as a script. +""" def transform_to_onnx( network_config_file: str, @@ -15,11 +18,16 @@ def transform_to_onnx( """ Transforms a TorchMLP model to ONNX format. - Args: - network_config_file (str): Path to the pickle file containing the network configuration. - state_dict_file (str): Path to the file containing the state dictionary of the model. - input_shape (int): The size of the input tensor for the model. - output_onnx_file (str): Path to the output ONNX file. + Arguments + --------- + network_config_file (str): + Path to the pickle file containing the network configuration. + state_dict_file (str): + Path to the file containing the state dictionary of the model. + input_shape (int): + The size of the input tensor for the model. + output_onnx_file (str): + Path to the output ONNX file. """ with open(network_config_file, "rb") as f: network_config_mlp: Any = pickle.load(f) diff --git a/lanfactory/trainers/jax_mlp.py b/lanfactory/trainers/jax_mlp.py index 41ddd0c..7a864c9 100755 --- a/lanfactory/trainers/jax_mlp.py +++ b/lanfactory/trainers/jax_mlp.py @@ -22,8 +22,22 @@ except ImportError: print("wandb not available") +"""This module contains the JaxMLP class and the ModelTrainerJaxMLP class which + are used to train Jax based LANs and CPNs. +""" def MLPJaxFactory(network_config={}, train=True): + """Factory function to create a MLPJax object. + Arguments + --------- + network_config (dict): + Dictionary containing the network configuration. + train (bool): + Whether the model should be trained or not. + Returns + ------- + MLPJax class initialized with the correct network configuration. + """ return MLPJax( layer_sizes=network_config["layer_sizes"], activations=network_config["activations"], @@ -31,8 +45,20 @@ def MLPJaxFactory(network_config={}, train=True): train=train, ) - class MLPJax(nn.Module): + """JaxMLP class. + Arguments + --------- + layer_sizes (Sequence[int]): + Sequence of integers containing the sizes of the layers. + activations (Sequence[str]): + Sequence of strings containing the activation functions. + train (bool): + Whether the model should be set to training mode or not. + train_output_type (str): + The output type of the model during training. + """ + layer_sizes: Sequence[int] = (100, 90, 80, 1) activations: Sequence[str] = ("tanh", "tanh", "tanh", "linear") train: bool = True @@ -47,6 +73,9 @@ class MLPJax(nn.Module): network_type = "lan" if train_output_type == "logprob" else "cpn" def setup(self): + """Setup function for the JaxMLP class. + Initializes the layers and activation functions. + """ # TODO: Warn if unknown activation string used # TODO: Warn if linear activation is used before final layer self.layers = [nn.Dense(layer_size) for layer_size in self.layer_sizes] @@ -60,6 +89,18 @@ def setup(self): # self.network_type = "lan" if self.train_output_type == "logprob" else "cpn" def __call__(self, inputs): + """Call function for the JaxMLP class. + Performs forward pass through the network. + + Arguments + --------- + inputs (jax.numpy.ndarray): + Input tensor. + Returns + ------- + jax.numpy.ndarray: + Output tensor. + """ x = inputs for i, lyr in enumerate(self.layers): @@ -82,6 +123,23 @@ def __call__(self, inputs): return x def load_state_from_file(self, seed=42, input_dim=6, file_path=None): + """Loads the state dictionary from a file. + + Arguments + --------- + seed (int): + Seed for the random number generator. + input_dim (int): + Dimension of the input tensor. + file_path (str): + Path to the file containing the state dictionary. + + Returns + ------- + flax.core.frozen_dict.FrozenDict: + The state dictionary. + """ + if file_path is None: raise ValueError( "file_path argument needs to be speficied! " @@ -117,6 +175,29 @@ def make_forward_partial( file_path=None, add_jitted=False, ): + """Creates a partial function for the forward pass of the network. + + Arguments + --------- + seed (int): + Seed for the random number generator. + input_dim (int): + Dimension of the input tensor. + state_dict_from_file (bool): + Whether the state dictionary should be loaded from a file or not. + state (flax.core.frozen_dict.FrozenDict): + The state dictionary (if not loaded from file). + file_path (str): + Path to the file containing the state dictionary (if loaded from file). + add_jitted (bool): + Whether the partial function should be jitted or not. + + Returns + ------- + Callable: + The partial function for the forward pass of the network. + """ + if state_dict_from_file: if file_path is None: raise ValueError( @@ -156,6 +237,31 @@ def __init__( pin_memory=False, seed=None, ): + """Class for training JaxMLP models. + + Arguments + --------- + train_config (dict): + Dictionary containing the training configuration. + model (MLPJax): + The MLPJax model to be trained. + train_dl (torch.utils.data.DataLoader): + The training data loader. + valid_dl (torch.utils.data.DataLoader): + The validation data loader. + allow_abs_path_folder_generation (bool): + Whether the folder for the output files should be created or not. + pin_memory (bool): + Whether the data loader should pin memory or not. + seed (int): + Seed for the random number generator. + + Returns + ------- + ModelTrainerJaxMLP: + The ModelTrainerJaxMLP object. + + """ if "loss_dict" not in train_config.keys(): self.loss_dict = { "huber": {"fun": optax.huber_loss, "kwargs": {"delta": 1}}, @@ -198,12 +304,14 @@ def __init__( self.state = "Please run training for this attribute to be specified!" def __get_loss(self): + """Define loss function.""" self.loss = partial( self.loss_dict[self.train_config["loss"]]["fun"], **self.loss_dict[self.train_config["loss"]]["kwargs"], ) def __make_apply_model(self, train=True): + """Compile forward pass with loss aplication""" @jax.jit def apply_model_core(state, features, labels): def loss_fn(params): @@ -223,6 +331,7 @@ def loss_fn(params): return apply_model_core def __make_update_model(self): + """Compile gradient application""" @jax.jit def update_model(state, grads): return state.apply_gradients(grads=grads) @@ -232,6 +341,18 @@ def update_model(state, grads): def __try_wandb( self, wandb_project_id="projectid", file_id="fileid", run_id="runid" ): + """Helper function to initialize wandb + + Arguments + --------- + wandb_project_id (str): + The wandb project id. + file_id (str): + The file id. + run_id (str): + The run id. + + """ try: wandb.init( project=wandb_project_id, @@ -250,6 +371,7 @@ def __try_wandb( print("No wandb found, proceeding without logging") def create_train_state(self, rng): + """Create initial train state""" params = self.model.init( rng, jnp.ones((1, self.train_dl.dataset.input_dim)) ) # self.train_config['input_size']))) @@ -266,6 +388,25 @@ def create_train_state(self, rng): ) def run_epoch(self, state, train=True, verbose=1, epoch=0, max_epochs=0): + """Run one epoch of training or validation + Arguments + --------- + state (flax.core.frozen_dict.FrozenDict): + The state dictionary. + train (bool): + Whether the model should is in training mode or not. + verbose (int): + The verbosity level. + epoch (int): + The current epoch. + max_epochs (int): + The maximum number of epochs. + + Returns + ------- + tuple (flax.core.frozen_dict.FrozenDict, float): + The state dictionary and the mean epoch loss. + """ if train: tmp_dataloader = self.train_dl train_str = "Training" @@ -353,6 +494,37 @@ def train_and_evaluate( save_data_details=True, verbose=1, ): + """Train and evaluate JAXMLP model. + Arguments + --------- + + output_folder (str): + Path to the output folder. + output_file_id (str): + The file id. + run_id (str): + The run id. + wandb_on (bool): + Whether to use wandb or not. + wandb_project_id (str): + Project id for wandb. + save_history (bool): + Whether to save the training history or not. + save_model (bool): + Whether to save the model or not. + save_config (bool): + Whether to save the training configuration or not. + save_all (bool): + Whether to save all files or not. + save_data_details (bool): + Whether to save the data details or not. + verbose (int): + The verbosity level. + Returns + ------- + flax.core.frozen_dict.FrozenDict: + The final state dictionary (model state). + """ try_gen_folder( folder=output_folder, allow_abs_path_folder_generation=self.allow_abs_path_folder_generation, diff --git a/lanfactory/trainers/torch_mlp.py b/lanfactory/trainers/torch_mlp.py index 8b7a8b3..cff8e2c 100755 --- a/lanfactory/trainers/torch_mlp.py +++ b/lanfactory/trainers/torch_mlp.py @@ -15,8 +15,29 @@ except ImportError: print("wandb not available") +"""This module contains the classes for training TorchMLP models.""" class DatasetTorch(torch.utils.data.Dataset): + """Dataset class for TorchMLP training. + + Arguments + ---------- + file_ids (list): + List of paths to the data files. + batch_size (int): + Batch size. + label_lower_bound (float): + Lower bound for the labels. + label_upper_bound (float): + Upper bound for the labels. + features_key (str): + Key for the features in the data files. + label_key (str): + Key for the labels in the data files. + out_framework (str): + Output framework. + """ + def __init__( self, file_ids, @@ -122,6 +143,15 @@ def __data_generation(self, batch_ids=None): class TorchMLP(nn.Module): + """TorchMLP class. + + Arguments + ---------- + network_config (dict): + Network configuration. + input_shape (int): + Input shape. + """ # AF-TODO: Potentially split this via super-class # In the end I want 'eval', but differentiable # w.r.t to input ...., might be a problem @@ -192,6 +222,18 @@ def __init__( # Define forward pass def forward(self, x): + """Forward pass through network. + + Arguments + --------- + x (torch.Tensor): + Input tensor. + + Returns + ------- + torch.Tensor: + Output tensor. + """ for i in range(self.len_layers - 1): x = self.layers[i](x) if self.training or self.train_output_type == "logprob": @@ -215,7 +257,24 @@ def __init__( pin_memory=True, seed=None, ): - # Class to train MLP models (This is in fact not MLP specific --> rename?) + """Class to train Torch Models. + Arguments + --------- + train_config (dict): + Training configuration. + model (TorchMLP): + TorchMLP model. + train_dl (DatasetTorch): + Training dataloader. + valid_dl (DatasetTorch): + Validation dataloader. + allow_abs_path_folder_generation (bool): + Whether to allow absolute path folder generation. + pin_memory (bool): + Whether to pin memory (dataloader). Can affect speed. + seed (int): + Random seed. + """ torch.backends.cudnn.benchmark = True self.dev = ( torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") @@ -247,28 +306,7 @@ def __try_wandb( + run_id ), config=self.train_config, - # { - # "learning_rate": self.train_config["learning_rate"], - # "weight_decay": self.train_config["weight_decay"], - # "epochs": self.train_config["n_epochs"], - # "batch_size": self.train_config["batch_size"] - # "file_id": file_id, - # "lr_scheduler": self.train_config["lr_scheduler"], - # "lr_scheduler_params": self.train_config["lr_scheduler_params"], - # "run_id": run_id, - # }, ) - - # wandb.config = { - # "learning_rate": self.train_config["learning_rate"], - # "weight_decay": self.train_config["weight_decay"], - # "epochs": self.train_config["n_epochs"], - # "batch_size": self.train_config["gpu_batch_size"] - # if torch.cuda.is_available() - # else self.train_config["cpu_batch_size"], - # "model_id": self.model.model_id, - # } - print("Succefully initialized wandb!") except Exception as e: print(e) @@ -353,6 +391,33 @@ def train_and_evaluate( save_data_details=True, verbose=1, ): + """Train and evaluate the model. + + Arguments + --------- + output_folder (str): + Output folder. + output_file_id (str): + Output file ID. + run_id (str): + Run ID. + wandb_on (bool): + Whether to use wandb. + wandb_project_id (str): + Wandb project ID. + save_history (bool): + Whether to save the training history. + save_model (bool): + Whether to save the model. + save_config (bool): + Whether to save the training configuration. + save_all (bool): + Whether to save all. + save_data_details (bool): + Whether to save the data details. + verbose (int): + Verbosity level. + """ try_gen_folder( folder=output_folder, allow_abs_path_folder_generation=self.allow_abs_path_folder_generation, @@ -516,8 +581,21 @@ def train_and_evaluate( print("Training finished successfully...") - class LoadTorchMLPInfer: + """Class to load TorchMLP models for inference. (This + was originally useful directly for application in the + HDDM toolbox). + + Arguments + --------- + model_file_path (str): + Path to the model file. + network_config (dict): + Network configuration. + input_dim (int): + Input dimension. + + """ def __init__(self, model_file_path=None, network_config=None, input_dim=None): torch.backends.cudnn.benchmark = True self.dev = ( @@ -548,8 +626,9 @@ def predict_on_batch(self, x=None): from a matrix input. To be used primarily through the HDDM toolbox. - :Arguments: - x: numpy.ndarray(dtype=numpy.float32) + Arguments + --------- + x (numpy.ndarray(dtype=numpy.float32)): Matrix which will be passed through the network. LANs expect the matrix columns to follow a specific order. When used in HDDM, x will be passed as follows. @@ -558,8 +637,9 @@ def predict_on_batch(self, x=None): The last two columns are filled with trial wise reaction times and choices. When not used via HDDM, no such restriction applies. - :Output: - numpy.ndarray(dtype = numpy.float32) + Output + ------ + numpy.ndarray(dtype = numpy.float32): Output of the network. When called through HDDM, this is expected as trial-wise log likelihoods of a given generative model. @@ -569,6 +649,16 @@ def predict_on_batch(self, x=None): class LoadTorchMLP: + """Class to load TorchMLP models. + + Arguments + --------- + model_file_path (str): + Path to the model file. + network_config (dict): + Network configuration. + input_dim (int): + Input dimension.""" def __init__(self, model_file_path=None, network_config=None, input_dim=None): ##torch.backends.cudnn.benchmark = True self.dev = ( diff --git a/lanfactory/utils/util_funs.py b/lanfactory/utils/util_funs.py index 93790b8..ae921b3 100755 --- a/lanfactory/utils/util_funs.py +++ b/lanfactory/utils/util_funs.py @@ -2,8 +2,20 @@ import pickle import warnings +"""Some utility functions for the lanfactory package.""" + def try_gen_folder(folder=None, allow_abs_path_folder_generation=True): + """Fucntion to generate a folder from a string. If the folder already exists, it will not be generated. + + Arguments + --------- + folder (str): + The folder string to generate. + allow_abs_path_folder_generation (bool): + If True, the folder string is treated as an absolute path. If False, the folder string is treated as a relative path. + + """ folder_list = folder.split("/") # Check if folder string supplied defines a relative or absolute path @@ -65,6 +77,22 @@ def save_configs( train_config=None, allow_abs_path_folder_generation=True, ): + """Function to save the network and training configurations to a folder. + + Arguments + --------- + model_id (str): + The id of the model. + save_folder (str): + The folder to save the configurations to. + network_config (dict): + The network configuration dictionary. + train_config (dict): + The training configuration dictionary. + allow_abs_path_folder_generation (bool): + If True, the folder string is treated as an absolute path. If False, the folder string is treated as a relative path. + """ + # Generate save_folder if it doesn't yet exist try_gen_folder( folder=save_folder, diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 0000000..45fb64d --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,99 @@ +site_name: LANFactory +repo_name: AlexanderFengler/LANFactory +repo_url: htps://github.com/AlexanderFengler/LANFactory +edit_uri: edit/main/docs + +nav: + - Home: + - Overview: index.md + - Basic Tutorial: + - Installation: basic_tutorial/basic_tutorial.ipynb + - API: + - ssms: api/ssms.md + - basic simulators: api/basic_simulators.md + - config: api/config.md + - data generators: api/dataset_generators.md + +plugins: + - search + - autorefs + - mkdocs-jupyter: + execute: True + execute_ignore: + - basic_tutorial/basic_tutorial.ipynb + - mkdocstrings: + default_handler: python + handlers: + python: + import: + - https://docs.python.org/3/objects.inv + - https://mkdocstrings.github.io/objects.inv + - https://mkdocstrings.github.io/griffe/objects.inv + options: + show_submodules: true + separate_signature: true + merge_init_into_class: true + docstring_options: + ignore_init_summary: true + docstring_style: "numpy" + docstring_section_style: "list" + show_root_members_full_path: true + show_object_full_path: false + show_category_heading: true + show_signature_annotations: false + show_source: false + group_by_category: false + signature_crossrefs: true + +theme: + name: material + custom_dir: docs/overrides + features: + - navigation.tracking + - navigation.tabs + - navigation.tabs.sticky + - navigation.sections + - navigation.path + - navigation.top + - content.code.copy + - content.action.view + - content.action.edit + - header.autohide + - announce.dismiss + palette: + # Palette toggle for automatic mode + - media: "(prefers-color-scheme)" + toggle: + icon: material/brightness-auto + name: Switch to dark mode + + # Palette toggle for light mode + - media: "(prefers-color-scheme: light)" + scheme: default + toggle: + icon: material/brightness-7 + name: Switch to light mode + + # Palette toggle for dark mode + - media: "(prefers-color-scheme: dark)" + scheme: slate + toggle: + icon: material/brightness-4 + name: Switch to automatic mode + +extra: + homepage: "https://AlexanderFengler.github.io/LANFactory/" + +markdown_extensions: + - admonition + - pymdownx.highlight: + anchor_linenums: true + line_spans: __span + pygments_lang_class: true + - pymdownx.inlinehilite + - pymdownx.snippets + - pymdownx.superfences + - attr_list + - pymdownx.emoji: + emoji_index: !!python/name:materialx.emoji.twemoji + emoji_generator: !!python/name:materialx.emoji.to_svg \ No newline at end of file From 1abf6b3add27e98d06d2ed276d1fe8e90aac66ff Mon Sep 17 00:00:00 2001 From: Alexander Fengler Date: Sat, 16 Sep 2023 15:56:40 -0400 Subject: [PATCH 2/8] black and ruff --- lanfactory/onnx/transform_onnx.py | 9 ++-- lanfactory/trainers/jax_mlp.py | 80 ++++++++++++++++--------------- lanfactory/trainers/torch_mlp.py | 51 +++++++++++--------- lanfactory/utils/util_funs.py | 24 +++++----- 4 files changed, 88 insertions(+), 76 deletions(-) diff --git a/lanfactory/onnx/transform_onnx.py b/lanfactory/onnx/transform_onnx.py index e1cfd6c..4ff5b17 100644 --- a/lanfactory/onnx/transform_onnx.py +++ b/lanfactory/onnx/transform_onnx.py @@ -9,6 +9,7 @@ Can be run as a script. """ + def transform_to_onnx( network_config_file: str, state_dict_file: str, @@ -20,13 +21,13 @@ def transform_to_onnx( Arguments --------- - network_config_file (str): + network_config_file (str): Path to the pickle file containing the network configuration. - state_dict_file (str): + state_dict_file (str): Path to the file containing the state dictionary of the model. - input_shape (int): + input_shape (int): The size of the input tensor for the model. - output_onnx_file (str): + output_onnx_file (str): Path to the output ONNX file. """ with open(network_config_file, "rb") as f: diff --git a/lanfactory/trainers/jax_mlp.py b/lanfactory/trainers/jax_mlp.py index 7a864c9..ff0d114 100755 --- a/lanfactory/trainers/jax_mlp.py +++ b/lanfactory/trainers/jax_mlp.py @@ -26,13 +26,14 @@ are used to train Jax based LANs and CPNs. """ + def MLPJaxFactory(network_config={}, train=True): """Factory function to create a MLPJax object. Arguments --------- - network_config (dict): + network_config (dict): Dictionary containing the network configuration. - train (bool): + train (bool): Whether the model should be trained or not. Returns ------- @@ -45,11 +46,12 @@ def MLPJaxFactory(network_config={}, train=True): train=train, ) + class MLPJax(nn.Module): """JaxMLP class. Arguments --------- - layer_sizes (Sequence[int]): + layer_sizes (Sequence[int]): Sequence of integers containing the sizes of the layers. activations (Sequence[str]): Sequence of strings containing the activation functions. @@ -57,7 +59,7 @@ class MLPJax(nn.Module): Whether the model should be set to training mode or not. train_output_type (str): The output type of the model during training. - """ + """ layer_sizes: Sequence[int] = (100, 90, 80, 1) activations: Sequence[str] = ("tanh", "tanh", "tanh", "linear") @@ -73,7 +75,7 @@ class MLPJax(nn.Module): network_type = "lan" if train_output_type == "logprob" else "cpn" def setup(self): - """Setup function for the JaxMLP class. + """Setup function for the JaxMLP class. Initializes the layers and activation functions. """ # TODO: Warn if unknown activation string used @@ -191,7 +193,7 @@ def make_forward_partial( Path to the file containing the state dictionary (if loaded from file). add_jitted (bool): Whether the partial function should be jitted or not. - + Returns ------- Callable: @@ -312,6 +314,7 @@ def __get_loss(self): def __make_apply_model(self, train=True): """Compile forward pass with loss aplication""" + @jax.jit def apply_model_core(state, features, labels): def loss_fn(params): @@ -332,6 +335,7 @@ def loss_fn(params): def __make_update_model(self): """Compile gradient application""" + @jax.jit def update_model(state, grads): return state.apply_gradients(grads=grads) @@ -342,7 +346,7 @@ def __try_wandb( self, wandb_project_id="projectid", file_id="fileid", run_id="runid" ): """Helper function to initialize wandb - + Arguments --------- wandb_project_id (str): @@ -351,7 +355,7 @@ def __try_wandb( The file id. run_id (str): The run id. - + """ try: wandb.init( @@ -401,7 +405,7 @@ def run_epoch(self, state, train=True, verbose=1, epoch=0, max_epochs=0): The current epoch. max_epochs (int): The maximum number of epochs. - + Returns ------- tuple (flax.core.frozen_dict.FrozenDict, float): @@ -495,36 +499,36 @@ def train_and_evaluate( verbose=1, ): """Train and evaluate JAXMLP model. - Arguments - --------- + Arguments + --------- - output_folder (str): - Path to the output folder. - output_file_id (str): - The file id. - run_id (str): - The run id. - wandb_on (bool): - Whether to use wandb or not. - wandb_project_id (str): - Project id for wandb. - save_history (bool): - Whether to save the training history or not. - save_model (bool): - Whether to save the model or not. - save_config (bool): - Whether to save the training configuration or not. - save_all (bool): - Whether to save all files or not. - save_data_details (bool): - Whether to save the data details or not. - verbose (int): - The verbosity level. - Returns - ------- - flax.core.frozen_dict.FrozenDict: - The final state dictionary (model state). - """ + output_folder (str): + Path to the output folder. + output_file_id (str): + The file id. + run_id (str): + The run id. + wandb_on (bool): + Whether to use wandb or not. + wandb_project_id (str): + Project id for wandb. + save_history (bool): + Whether to save the training history or not. + save_model (bool): + Whether to save the model or not. + save_config (bool): + Whether to save the training configuration or not. + save_all (bool): + Whether to save all files or not. + save_data_details (bool): + Whether to save the data details or not. + verbose (int): + The verbosity level. + Returns + ------- + flax.core.frozen_dict.FrozenDict: + The final state dictionary (model state). + """ try_gen_folder( folder=output_folder, allow_abs_path_folder_generation=self.allow_abs_path_folder_generation, diff --git a/lanfactory/trainers/torch_mlp.py b/lanfactory/trainers/torch_mlp.py index cff8e2c..d9d61db 100755 --- a/lanfactory/trainers/torch_mlp.py +++ b/lanfactory/trainers/torch_mlp.py @@ -17,9 +17,10 @@ """This module contains the classes for training TorchMLP models.""" + class DatasetTorch(torch.utils.data.Dataset): """Dataset class for TorchMLP training. - + Arguments ---------- file_ids (list): @@ -152,6 +153,7 @@ class TorchMLP(nn.Module): input_shape (int): Input shape. """ + # AF-TODO: Potentially split this via super-class # In the end I want 'eval', but differentiable # w.r.t to input ...., might be a problem @@ -223,12 +225,12 @@ def __init__( # Define forward pass def forward(self, x): """Forward pass through network. - + Arguments --------- x (torch.Tensor): Input tensor. - + Returns ------- torch.Tensor: @@ -258,23 +260,23 @@ def __init__( seed=None, ): """Class to train Torch Models. - Arguments - --------- - train_config (dict): - Training configuration. - model (TorchMLP): - TorchMLP model. - train_dl (DatasetTorch): - Training dataloader. - valid_dl (DatasetTorch): - Validation dataloader. - allow_abs_path_folder_generation (bool): - Whether to allow absolute path folder generation. - pin_memory (bool): - Whether to pin memory (dataloader). Can affect speed. - seed (int): - Random seed. - """ + Arguments + --------- + train_config (dict): + Training configuration. + model (TorchMLP): + TorchMLP model. + train_dl (DatasetTorch): + Training dataloader. + valid_dl (DatasetTorch): + Validation dataloader. + allow_abs_path_folder_generation (bool): + Whether to allow absolute path folder generation. + pin_memory (bool): + Whether to pin memory (dataloader). Can affect speed. + seed (int): + Random seed. + """ torch.backends.cudnn.benchmark = True self.dev = ( torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") @@ -581,9 +583,10 @@ def train_and_evaluate( print("Training finished successfully...") + class LoadTorchMLPInfer: - """Class to load TorchMLP models for inference. (This - was originally useful directly for application in the + """Class to load TorchMLP models for inference. (This + was originally useful directly for application in the HDDM toolbox). Arguments @@ -596,6 +599,7 @@ class LoadTorchMLPInfer: Input dimension. """ + def __init__(self, model_file_path=None, network_config=None, input_dim=None): torch.backends.cudnn.benchmark = True self.dev = ( @@ -650,7 +654,7 @@ def predict_on_batch(self, x=None): class LoadTorchMLP: """Class to load TorchMLP models. - + Arguments --------- model_file_path (str): @@ -659,6 +663,7 @@ class LoadTorchMLP: Network configuration. input_dim (int): Input dimension.""" + def __init__(self, model_file_path=None, network_config=None, input_dim=None): ##torch.backends.cudnn.benchmark = True self.dev = ( diff --git a/lanfactory/utils/util_funs.py b/lanfactory/utils/util_funs.py index ae921b3..33719df 100755 --- a/lanfactory/utils/util_funs.py +++ b/lanfactory/utils/util_funs.py @@ -6,14 +6,15 @@ def try_gen_folder(folder=None, allow_abs_path_folder_generation=True): - """Fucntion to generate a folder from a string. If the folder already exists, it will not be generated. + """Function to generate a folder from a string. If the folder already exists, it will not be generated. Arguments --------- - folder (str): + folder (str): The folder string to generate. - allow_abs_path_folder_generation (bool): - If True, the folder string is treated as an absolute path. If False, the folder string is treated as a relative path. + allow_abs_path_folder_generation (bool): + If True, the folder string is treated as an absolute path. + If False, the folder string is treated as a relative path. """ folder_list = folder.split("/") @@ -81,18 +82,19 @@ def save_configs( Arguments --------- - model_id (str): + model_id (str): The id of the model. - save_folder (str): + save_folder (str): The folder to save the configurations to. - network_config (dict): + network_config (dict): The network configuration dictionary. - train_config (dict): + train_config (dict): The training configuration dictionary. - allow_abs_path_folder_generation (bool): - If True, the folder string is treated as an absolute path. If False, the folder string is treated as a relative path. + allow_abs_path_folder_generation (bool): + If True, the folder string is treated as an absolute path. + If False, the folder string is treated as a relative path. """ - + # Generate save_folder if it doesn't yet exist try_gen_folder( folder=save_folder, From 8b2b3d453e4be8734775376e7e7d9df96fb0a1c4 Mon Sep 17 00:00:00 2001 From: Alexander Fengler Date: Sat, 16 Sep 2023 22:22:13 -0400 Subject: [PATCH 3/8] add scaffolding for mkdocs, add in onnx submodule --- README.md | 6 + docs/api/config.md | 1 + docs/api/lanfactory.md | 1 + docs/api/onnx.md | 1 + docs/api/trainers.md | 1 + docs/api/utils.md | 1 + .../basic_tutorial}/basic_tutorial.ipynb | 0 docs/index.md | 332 ++++++++++++++++++ docs/overrides/main.html | 18 + lanfactory/__init__.py | 3 +- lanfactory/onnx/__init__.py | 3 + mkdocs.yml | 7 +- 12 files changed, 370 insertions(+), 4 deletions(-) create mode 100644 docs/api/config.md create mode 100644 docs/api/lanfactory.md create mode 100644 docs/api/onnx.md create mode 100644 docs/api/trainers.md create mode 100644 docs/api/utils.md rename {notebooks => docs/basic_tutorial}/basic_tutorial.ipynb (100%) create mode 100644 docs/index.md create mode 100644 docs/overrides/main.html create mode 100644 lanfactory/onnx/__init__.py diff --git a/README.md b/README.md index 404f5df..d0a6474 100755 --- a/README.md +++ b/README.md @@ -1,5 +1,11 @@ # LANfactory +![PyPI](https://img.shields.io/pypi/v/lanfactory) +![PyPI_dl](https://img.shields.io/pypi/dm/lanfactory) +[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/ambv/black) +[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) + + Lightweight python package to help with training [LANs](https://elifesciences.org/articles/65074) (Likelihood approximation networks). ### Quick Start diff --git a/docs/api/config.md b/docs/api/config.md new file mode 100644 index 0000000..49422e7 --- /dev/null +++ b/docs/api/config.md @@ -0,0 +1 @@ +:::lanfactory.config \ No newline at end of file diff --git a/docs/api/lanfactory.md b/docs/api/lanfactory.md new file mode 100644 index 0000000..c7f2ea1 --- /dev/null +++ b/docs/api/lanfactory.md @@ -0,0 +1 @@ +:::lanfactory diff --git a/docs/api/onnx.md b/docs/api/onnx.md new file mode 100644 index 0000000..9fabb9d --- /dev/null +++ b/docs/api/onnx.md @@ -0,0 +1 @@ +:::lanfactory.onnx \ No newline at end of file diff --git a/docs/api/trainers.md b/docs/api/trainers.md new file mode 100644 index 0000000..cdb65a0 --- /dev/null +++ b/docs/api/trainers.md @@ -0,0 +1 @@ +:::lanfactory.trainers \ No newline at end of file diff --git a/docs/api/utils.md b/docs/api/utils.md new file mode 100644 index 0000000..807f76f --- /dev/null +++ b/docs/api/utils.md @@ -0,0 +1 @@ +:::lanfactory.utils \ No newline at end of file diff --git a/notebooks/basic_tutorial.ipynb b/docs/basic_tutorial/basic_tutorial.ipynb similarity index 100% rename from notebooks/basic_tutorial.ipynb rename to docs/basic_tutorial/basic_tutorial.ipynb diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 0000000..d0a6474 --- /dev/null +++ b/docs/index.md @@ -0,0 +1,332 @@ +# LANfactory + +![PyPI](https://img.shields.io/pypi/v/lanfactory) +![PyPI_dl](https://img.shields.io/pypi/dm/lanfactory) +[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/ambv/black) +[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) + + +Lightweight python package to help with training [LANs](https://elifesciences.org/articles/65074) (Likelihood approximation networks). + +### Quick Start + +The `LANfactory` package is a light-weight convenience package for training `likelihood approximation networks` (LANs) in torch (or keras), +starting from supplied training data. + +[LANs](https://elifesciences.org/articles/65074), although more general in potential scope of applications, were conceived in the context of sequential sampling modeling +to account for cognitive processes giving rise to *choice* and *reaction time* data in *n-alternative forced choice experiments* commonly encountered in the cognitive sciences. + +In this quick tutorial we will use the [`ssms`](https://github.com/AlexanderFengler/ssm_simulators) package to generate our training data using such a sequential sampling model (SSM). The use is in no way bound to utilize the `ssms` package. + +#### Install + +To install the `ssms` package type, + +`pip install git+https://github.com/AlexanderFengler/ssm_simulators` + +To install the `LANfactory` package type, + +`pip install git+https://github.com/AlexanderFengler/LANfactory` + +Necessary dependency should be installed automatically in the process. + +#### Basic Tutorial + + +```python +# Load necessary packages +import ssms +import lanfactory +import os +import numpy as np +from copy import deepcopy +import torch +``` + +#### Generate Training Data +First we need to generate some training data. As mentioned above we will do so using the `ssms` python package, however without delving into a detailed explanation +of this package. Please refer to the [basic ssms tutorial] (https://github.com/AlexanderFengler/ssm_simulators) in case you want to learn more. + + +```python +# MAKE CONFIGS + +# Initialize the generator config (for MLP LANs) +generator_config = deepcopy(ssms.config.data_generator_config['lan']['mlp']) +# Specify generative model (one from the list of included models mentioned above) +generator_config['dgp_list'] = 'angle' +# Specify number of parameter sets to simulate +generator_config['n_parameter_sets'] = 100 +# Specify how many samples a simulation run should entail +generator_config['n_samples'] = 1000 +# Specify folder in which to save generated data +generator_config['output_folder'] = 'data/lan_mlp/' + +# Make model config dict +model_config = ssms.config.model_config['angle'] +``` + +```python +# MAKE DATA + +my_dataset_generator = ssms.dataset_generators.data_generator(generator_config = generator_config, + model_config = model_config) + +training_data = my_dataset_generator.generate_data_training_uniform(save = True) +``` + + n_cpus used: 6 + checking: data/lan_mlp/ + simulation round: 1 of 10 + simulation round: 2 of 10 + simulation round: 3 of 10 + simulation round: 4 of 10 + simulation round: 5 of 10 + simulation round: 6 of 10 + simulation round: 7 of 10 + simulation round: 8 of 10 + simulation round: 9 of 10 + simulation round: 10 of 10 + Writing to file: data/lan_mlp/training_data_0_nbins_0_n_1000/angle/training_data_angle_ef5b9e0eb76c11eca684acde48001122.pickle + + +#### Prepare for Training + +Next we set up dataloaders for training with pytorch. The `LANfactory` uses custom dataloaders, taking into account particularities of the expected training data. +Specifically, we expect to receive a bunch of training data files (the present example generates only one), where each file hosts a large number of training examples. +So we want to define a dataloader which spits out batches from data with a specific training data file, and keeps checking when to load in a new file. +The way this is implemented here, is via the `DatasetTorch` class in `lanfactory.trainers`, which inherits from `torch.utils.data.Dataset` and prespecifies a `batch_size`. Finally this is supplied to a [`DataLoader`](https://pytorch.org/docs/stable/data.html), for which we keep the `batch_size` argument at 0. + +The `DatasetTorch` class is then called as an iterator via the DataLoader and takes care of batching as well as file loading internally. + +You may choose your own way of defining the `DataLoader` classes, downstream you are simply expected to supply one. + + +```python +# MAKE DATALOADERS + +# List of datafiles (here only one) +folder_ = 'data/lan_mlp/training_data_0_nbins_0_n_1000/angle/' +file_list_ = [folder_ + file_ for file_ in os.listdir(folder_)] + +# Training dataset +torch_training_dataset = lanfactory.trainers.DatasetTorch(file_IDs = file_list_, + batch_size = 128) + +torch_training_dataloader = torch.utils.data.DataLoader(torch_training_dataset, + shuffle = True, + batch_size = None, + num_workers = 1, + pin_memory = True) + +# Validation dataset +torch_validation_dataset = lanfactory.trainers.DatasetTorch(file_IDs = file_list_, + batch_size = 128) + +torch_validation_dataloader = torch.utils.data.DataLoader(torch_validation_dataset, + shuffle = True, + batch_size = None, + num_workers = 1, + pin_memory = True) +``` + +Now we define two configuration dictionariers, + +1. The `network_config` dictionary defines the architecture and properties of the network +2. The `train_config` dictionary defines properties concerning training hyperparameters + +Two examples (which we take as provided by the package, but which you can adjust according to your needs) are provided below. + + +```python +# SPECIFY NETWORK CONFIGS AND TRAINING CONFIGS + +network_config = lanfactory.config.network_configs.network_config_mlp + +print('Network config: ') +print(network_config) + +train_config = lanfactory.config.network_configs.train_config_mlp + +print('Train config: ') +print(train_config) +``` + + Network config: + {'layer_types': ['dense', 'dense', 'dense'], 'layer_sizes': [100, 100, 1], 'activations': ['tanh', 'tanh', 'linear'], 'loss': ['huber'], 'callbacks': ['checkpoint', 'earlystopping', 'reducelr']} + Train config: + {'batch_size': 128, 'n_epochs': 10, 'optimizer': 'adam', 'learning_rate': 0.002, 'loss': 'huber', 'save_history': True, 'metrics': [, ], 'callbacks': ['checkpoint', 'earlystopping', 'reducelr']} + + +We can now load a network, and save the configuration files for convenience. + + +```python +# LOAD NETWORK +net = lanfactory.trainers.TorchMLP(network_config = deepcopy(network_config), + input_shape = torch_training_dataset.input_dim, + save_folder = '/data/torch_models/', + generative_model_id = 'angle') + +# SAVE CONFIGS +lanfactory.utils.save_configs(model_id = net.model_id + '_torch_', + save_folder = 'data/torch_models/angle/', + network_config = network_config, + train_config = train_config, + allow_abs_path_folder_generation = True) +``` + +To finally train the network we supply our network, the dataloaders and training config to the `ModelTrainerTorchMLP` class, from `lanfactory.trainers`. + + +```python +# TRAIN MODEL +model_trainer.train_model(save_history = True, + save_model = True, + verbose = 0) +``` + + Epoch took 0 / 10, took 11.54538607597351 seconds + epoch 0 / 10, validation_loss: 0.3431 + Epoch took 1 / 10, took 13.032279014587402 seconds + epoch 1 / 10, validation_loss: 0.2732 + Epoch took 2 / 10, took 12.421074867248535 seconds + epoch 2 / 10, validation_loss: 0.1941 + Epoch took 3 / 10, took 12.097641229629517 seconds + epoch 3 / 10, validation_loss: 0.2028 + Epoch took 4 / 10, took 12.030233144760132 seconds + epoch 4 / 10, validation_loss: 0.184 + Epoch took 5 / 10, took 12.695374011993408 seconds + epoch 5 / 10, validation_loss: 0.1433 + Epoch took 6 / 10, took 12.177874326705933 seconds + epoch 6 / 10, validation_loss: 0.1115 + Epoch took 7 / 10, took 11.908828258514404 seconds + epoch 7 / 10, validation_loss: 0.1084 + Epoch took 8 / 10, took 12.066670179367065 seconds + epoch 8 / 10, validation_loss: 0.0864 + Epoch took 9 / 10, took 12.37562108039856 seconds + epoch 9 / 10, validation_loss: 0.07484 + Saving training history + Saving model state dict + Training finished successfully... + + +#### Load Model for Inference and Call + +The `LANfactory` provides some convenience functions to use networks for inference after training. +We can load a model using the `LoadTorchMLPInfer` class, which then allows us to run fast inference via either +a direct call, which expects a `torch.tensor` as input, or the `predict_on_batch` method, which expects a `numpy.array` +of `dtype`, `np.float32`. + + +```python +network_path_list = os.listdir('data/torch_models/angle') +network_file_path = ['data/torch_models/angle/' + file_ for file_ in network_path_list if 'state_dict' in file_][0] + +network = lanfactory.trainers.LoadTorchMLPInfer(model_file_path = network_file_path, + network_config = network_config, + input_dim = torch_training_dataset.input_dim) +``` + + +```python + +# Two ways to call the network + +# Direct call --> need tensor input +direct_out = network(torch.from_numpy(np.array([1, 1.5, 0.5, 1.0, 0.1, 0.65, 1], dtype = np.float32))) +print('direct call out: ', direct_out) + +# predict_on_batch method +predict_on_batch_out = network.predict_on_batch(np.array([1, 1.5, 0.5, 1.0, 0.1, 0.65, 1], dtype = np.float32)) +print('predict_on_batch out: ', predict_on_batch_out) + +``` + + direct call out: tensor([-16.4997]) + predict_on_batch out: [-16.499687] + + +#### A peek into the first passage distribution computed by the network + +We can compare the learned likelihood function in our `network` with simulation data from the underlying generative model. +For this purpose we recruit the [`ssms`](https://github.com/AlexanderFengler/ssm_simulators) package again. + + +```python +import pandas as pd +import matplotlib.pyplot as plt + +data = pd.DataFrame(np.zeros((2000, 7), dtype = np.float32), columns = ['v', 'a', 'z', 't', 'theta', 'rt', 'choice']) +data['v'] = 0.5 +data['a'] = 0.75 +data['z'] = 0.5 +data['t'] = 0.2 +data['theta'] = 0.1 +data['rt'].iloc[:1000] = np.linspace(5, 0, 1000) +data['rt'].iloc[1000:] = np.linspace(0, 5, 1000) +data['choice'].iloc[:1000] = -1 +data['choice'].iloc[1000:] = 1 + +# Network predictions +predict_on_batch_out = network.predict_on_batch(data.values.astype(np.float32)) + +# Simulations +from ssms.basic_simulators import simulator +sim_out = simulator(model = 'angle', + theta = data.values[0, :-2], + n_samples = 2000) +``` + + +```python +# Plot network predictions +plt.plot(data['rt'] * data['choice'], np.exp(predict_on_batch_out), color = 'black', label = 'network') + +# Plot simulations +plt.hist(sim_out['rts'] * sim_out['choices'], bins = 30, histtype = 'step', label = 'simulations', color = 'blue', density = True) +plt.legend() +plt.title('SSM likelihood') +plt.xlabel('rt') +plt.ylabel('likelihod') +``` + + + + + Text(0, 0.5, 'likelihod') + + + + + +![png](basic_tutorial_files/basic_tutorial_22_1.png) + + + +### TorchMLP to ONNX Converter + +The `transform_onnx.py` script converts a TorchMLP model to the ONNX format. It takes a network configuration file (in pickle format), a state dictionary file (Torch model weights), the size of the input tensor, and the desired output ONNX file path. + +### Usage + +```python onnx/transform_onnx.py ``` + +Replace the placeholders with the appropriate values: + +- : Path to the pickle file containing the network configuration. +- : Path to the file containing the state dictionary of the model. +- : The size of the input tensor for the model (integer). +- : Path to the output ONNX file. + +For example: + +``` +python onnx/transform_onnx.py '0d9f0e94175b11eca9e93cecef057438_lca_no_bias_4_torch__network_config.pickle' '0d9f0e94175b11eca9e93cecef057438_lca_no_bias_4_torch_state_dict.pt' 11 'lca_no_bias_4_torch.onnx' +``` +This onnx file can be used directly with the [`HSSM`](https://github.com/lnccbrown/HSSM) package. + +We hope this package may be helpful in case you attempt to train [LANs](https://elifesciences.org/articles/65074) for your own research. + +#### END + diff --git a/docs/overrides/main.html b/docs/overrides/main.html new file mode 100644 index 0000000..d6e01b3 --- /dev/null +++ b/docs/overrides/main.html @@ -0,0 +1,18 @@ +{% extends "base.html" %} + +{% block announce %} + + + {% include ".icons/fontawesome/solid/angles-down.svg" %} + + Navigate the site here! + + + v0.4.0 is out! + + + + {% include ".icons/material/head-question.svg" %} + + +{% endblock %} \ No newline at end of file diff --git a/lanfactory/__init__.py b/lanfactory/__init__.py index 6eab8a3..1e9aa45 100755 --- a/lanfactory/__init__.py +++ b/lanfactory/__init__.py @@ -3,5 +3,6 @@ from . import config from . import trainers from . import utils +from . import onnx -__all__ = ["config", "trainers", "utils"] +__all__ = ["config", "trainers", "utils", "onnx"] diff --git a/lanfactory/onnx/__init__.py b/lanfactory/onnx/__init__.py new file mode 100644 index 0000000..e088a1e --- /dev/null +++ b/lanfactory/onnx/__init__.py @@ -0,0 +1,3 @@ +from .transform_onnx import transform_to_onnx + +__all__ = ["transform_to_onnx"] \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index 45fb64d..fbc0878 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -9,10 +9,11 @@ nav: - Basic Tutorial: - Installation: basic_tutorial/basic_tutorial.ipynb - API: - - ssms: api/ssms.md - - basic simulators: api/basic_simulators.md + - lanfactory: api/lanfactory.md - config: api/config.md - - data generators: api/dataset_generators.md + - onnx: api/onnx.md + - trainers: api/trainers.md + - utils: api/utils.md plugins: - search From 6b94c48996364d5ab119d0414a0117383dc69eb4 Mon Sep 17 00:00:00 2001 From: Alexander Fengler Date: Sat, 16 Sep 2023 22:23:07 -0400 Subject: [PATCH 4/8] adjust version --- docs/overrides/main.html | 2 +- lanfactory/__init__.py | 2 +- pyproject.toml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/overrides/main.html b/docs/overrides/main.html index d6e01b3..abdd7e9 100644 --- a/docs/overrides/main.html +++ b/docs/overrides/main.html @@ -8,7 +8,7 @@ Navigate the site here! - v0.4.0 is out! + v0.4.1 is out! diff --git a/lanfactory/__init__.py b/lanfactory/__init__.py index 1e9aa45..64b684c 100755 --- a/lanfactory/__init__.py +++ b/lanfactory/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.4.0" +__version__ = "0.4.1" from . import config from . import trainers diff --git a/pyproject.toml b/pyproject.toml index 152c004..33fdca4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ requires = ["setuptools", "wheel"] [project] name= "lanfactory" -version= "0.4.0" +version= "0.4.1" authors= [{name = "Alexander Fenger", email = "alexander_fengler@brown.edu"}] description= "Package with convenience functions to train LANs" readme = "README.md" From 485ffca5b563b411371d2eec67733bd63b2b46dc Mon Sep 17 00:00:00 2001 From: Alexander Fengler Date: Sat, 16 Sep 2023 22:33:04 -0400 Subject: [PATCH 5/8] link to documentation in readme --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index d0a6474..0eb7877 100755 --- a/README.md +++ b/README.md @@ -8,6 +8,8 @@ Lightweight python package to help with training [LANs](https://elifesciences.org/articles/65074) (Likelihood approximation networks). +Please find the original [documentation here](https://alexanderfengler.github.io/LANfactory/). + ### Quick Start The `LANfactory` package is a light-weight convenience package for training `likelihood approximation networks` (LANs) in torch (or keras), From 746936472d06af23d18ea73d0808f5db758b4b31 Mon Sep 17 00:00:00 2001 From: Alexander Fengler Date: Sat, 16 Sep 2023 22:43:57 -0400 Subject: [PATCH 6/8] slight change to manifest file --- MANIFEST.in | 1 + 1 file changed, 1 insertion(+) diff --git a/MANIFEST.in b/MANIFEST.in index 6abdc87..5db5da4 100755 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,3 +1,4 @@ +recursive-include docs include README.md LICENSE include lanfactory *.py include notebooks/* From 0fdeca044c8263dbbbe560abb17b669e77366530 Mon Sep 17 00:00:00 2001 From: Alexander Fengler Date: Sat, 16 Sep 2023 22:52:46 -0400 Subject: [PATCH 7/8] make some adjustments to git workflow for pushing to pypi --- .github/workflows/build_wheels.yml | 42 +++++++++++++++--------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/.github/workflows/build_wheels.yml b/.github/workflows/build_wheels.yml index 520fde9..553a94e 100644 --- a/.github/workflows/build_wheels.yml +++ b/.github/workflows/build_wheels.yml @@ -11,28 +11,28 @@ on: - published jobs: - build_wheels: - name: Build wheels on ${{ matrix.os }} - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-22.04, windows-2022, macos-11] + # build_wheels: + # name: Build wheels on ${{ matrix.os }} + # runs-on: ${{ matrix.os }} + # strategy: + # matrix: + # os: [ubuntu-22.04, windows-2022, macos-11] - steps: - - uses: actions/checkout@v4 + # steps: + # - uses: actions/checkout@v4 - - name: Build wheels - uses: pypa/cibuildwheel@v2.15.0 - env: - CIBW_BUILD: cp39-* cp310-* - CIBW_ARCHS_MACOS: "x86_64 universal2 arm64" - CIBW_ARCHS_LINUX: auto64 - CIBW_ARCHS_WINDOWS: auto64 - CIBW_BUILD_FRONTEND: build + # - name: Build wheels + # uses: pypa/cibuildwheel@v2.15.0 + # env: + # CIBW_BUILD: cp39-* cp310-* + # CIBW_ARCHS_MACOS: "x86_64 universal2 arm64" + # CIBW_ARCHS_LINUX: auto64 + # CIBW_ARCHS_WINDOWS: auto64 + # CIBW_BUILD_FRONTEND: build - - uses: actions/upload-artifact@v3 - with: - path: ./wheelhouse/*.whl + # - uses: actions/upload-artifact@v3 + # with: + # path: ./wheelhouse/*.whl build_sdist: name: Build source distribution @@ -48,7 +48,7 @@ jobs: path: dist/*.tar.gz upload_test_pypi: - needs: [build_wheels, build_sdist] + needs: [build_sdist] #[build_wheels, build_sdist] runs-on: ubuntu-latest if: github.event_name == 'release' && github.event.action == 'published' # or, alternatively, upload to PyPI on every tag starting with 'v' (remove on: release above to use this) @@ -67,7 +67,7 @@ jobs: repository-url: https://test.pypi.org/legacy/ upload_pypi: - needs: [build_wheels, build_sdist, upload_test_pypi] + needs: [build_sdist, upload_test_pypi] #[build_wheels, build_sdist, upload_test_pypi] runs-on: ubuntu-latest # Add these back after setting up trusted publishing # environment: pypi From 4289e8006983e1866302b5ade757aa90e90712e7 Mon Sep 17 00:00:00 2001 From: Alexander Fengler Date: Sat, 16 Sep 2023 22:58:08 -0400 Subject: [PATCH 8/8] black --- lanfactory/onnx/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lanfactory/onnx/__init__.py b/lanfactory/onnx/__init__.py index e088a1e..8db9b26 100644 --- a/lanfactory/onnx/__init__.py +++ b/lanfactory/onnx/__init__.py @@ -1,3 +1,3 @@ from .transform_onnx import transform_to_onnx -__all__ = ["transform_to_onnx"] \ No newline at end of file +__all__ = ["transform_to_onnx"]