diff --git a/notebooks/README.md b/notebooks/README.md index 351aef1..76b1200 100644 --- a/notebooks/README.md +++ b/notebooks/README.md @@ -7,4 +7,6 @@ This directory contains example notebooks that show how to use various aspects o | [generate_rosetta_dataset.ipynb](generate_rosetta_dataset.ipynb) | Generate a Rosetta pretraining dataset using molecular simulations data obtained from the [metl-sim](https://github.com/gitter-lab/metl-sim) repository. | | [train_test_split.ipynb](train_test_split.ipynb) | Create train, validation, and test splits for experimental datasets. | | [pretraining.ipynb](pretraining.ipynb) | Pretrain METL models with Rosetta data. | -| [finetuning.ipynb](finetuning.ipynb) | Finetune METL models with experimental data. | \ No newline at end of file +| [finetuning.ipynb](finetuning.ipynb) | Finetune METL models with experimental data. | +| [colab_finetuning.ipynb](colab_finetuning.ipynb) | Finetune METL models with experimental data on Colab. | +| [colab_predicting.ipynb](colab_predicting.ipynb) | Predict with METL models on Colab. | \ No newline at end of file diff --git a/notebooks/colab_finetuning.ipynb b/notebooks/colab_finetuning.ipynb new file mode 100644 index 0000000..7401b7b --- /dev/null +++ b/notebooks/colab_finetuning.ipynb @@ -0,0 +1,2039 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "e30ea18e-6b5a-47d4-b7a4-1330804b5602", + "metadata": { + "id": "e30ea18e-6b5a-47d4-b7a4-1330804b5602" + }, + "source": [ + "# Finetune on experimental data\n", + "This notebook demonstrates how to finetune METL models on experimental data." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "vZx7K4mpi4w1", + "metadata": { + "cellView": "form", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "vZx7K4mpi4w1", + "outputId": "08606a43-e6f3-4967-e09c-a05bda6e2fed" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Cloning into 'metl'...\n", + "remote: Enumerating objects: 416, done.\u001b[K\n", + "remote: Counting objects: 100% (416/416), done.\u001b[K\n", + "remote: Compressing objects: 100% (280/280), done.\u001b[K\n", + "remote: Total 416 (delta 166), reused 330 (delta 98), pack-reused 0 (from 0)\u001b[K\n", + "Receiving objects: 100% (416/416), 18.08 MiB | 14.06 MiB/s, done.\n", + "Resolving deltas: 100% (166/166), done.\n", + "/content/metl\n" + ] + } + ], + "source": [ + "# @title Cloning metl\n", + "!git clone https://github.com/gitter-lab/metl.git\n", + "%cd metl" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "vl7ugAoEjNFQ", + "metadata": { + "cellView": "form", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "vl7ugAoEjNFQ", + "outputId": "37a27f98-3c5a-4351-93b0-2a950970e7fe" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "--2024-08-28 22:01:31-- https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh\n", + "Resolving repo.anaconda.com (repo.anaconda.com)... 104.16.191.158, 104.16.32.241, 2606:4700::6810:20f1, ...\n", + "Connecting to repo.anaconda.com (repo.anaconda.com)|104.16.191.158|:443... connected.\n", + "HTTP request sent, awaiting response... 
200 OK\n", + "Length: 148981743 (142M) [application/octet-stream]\n", + "Saving to: ‘./miniconda.sh’\n", + "\n", + "./miniconda.sh 100%[===================>] 142.08M 87.1MB/s in 1.6s \n", + "\n", + "2024-08-28 22:01:32 (87.1 MB/s) - ‘./miniconda.sh’ saved [148981743/148981743]\n", + "\n", + "PREFIX=/usr/local\n", + "Unpacking payload ...\n", + "\n", + "Installing base environment...\n", + "\n", + "Preparing transaction: ...working... done\n", + "Executing transaction: ...working... done\n", + "installation finished.\n", + "WARNING:\n", + " You currently have a PYTHONPATH environment variable set. This may cause\n", + " unexpected behavior when running the Python interpreter in Miniconda3.\n", + " For best results, please verify that your PYTHONPATH only points to\n", + " directories of packages that are compatible with the Python interpreter\n", + " in Miniconda3: /usr/local\n", + "Channels:\n", + " - conda-forge\n", + " - defaults\n", + " - pytorch\n", + "Platform: linux-64\n", + "Collecting package metadata (repodata.json): ...working... done\n", + "Solving environment: ...working... done\n", + "Preparing transaction: ...working... done\n", + "Verifying transaction: ...working... done\n", + "Executing transaction: ...working... By downloading and using the CUDA Toolkit conda packages, you accept the terms and conditions of the CUDA End User License Agreement (EULA): https://docs.nvidia.com/cuda/eula/index.html\n", + "\n", + "done\n", + "Installing pip dependencies: ...working... 
import os
import sys

# make the packages that conda installed under /usr/local importable from this kernel
sys.path.append('/usr/local/lib/python3.9/site-packages')

# name of the directory that marks the project root
project_root_dir_name = "metl"

# climb from the working directory towards the filesystem root, stopping at the
# first directory named like the project root (or at the root itself, whose
# parent is itself)
current_dir = os.getcwd()
while True:
    parent_dir = os.path.dirname(current_dir)
    reached_project_root = os.path.basename(current_dir) == project_root_dir_name
    reached_fs_root = parent_dir == current_dir
    if reached_project_root or reached_fs_root:
        break
    current_dir = parent_dir

# only switch directories when the walk actually ended on the project root
if os.path.basename(current_dir) != project_root_dir_name:
    print("project root directory not found")
else:
    os.chdir(current_dir)

# expose the project's code folder to the import system (skipped if already present)
module_path = os.path.abspath("code")
if module_path not in sys.path:
    sys.path.append(module_path)
+ "id": "19876208-66f9-46b5-8f50-8e798fa815a4", + "metadata": { + "id": "19876208-66f9-46b5-8f50-8e798fa815a4" + }, + "source": [ + "# Acquire an experimental dataset\n", + "\n", + "For demonstration purposes, this repository contains the [avGFP dataset](https://github.com/gitter-lab/metl/tree/main/data/dms_data/avgfp) from [Sarkisyan et al. (2016)](https://doi.org/10.1038/nature17995).\n", + "See the [metl-pub](https://github.com/gitter-lab/metl-pub) repository to access the other experimental datasets we used in our preprint.\n", + "See the README in the [dms_data](https://github.com/gitter-lab/metl/tree/main/data/dms_data/) directory for information about how to use your own experimental dataset." + ] + }, + { + "cell_type": "markdown", + "id": "d6abf8b1-aa2d-4055-9184-d962ba0d4582", + "metadata": { + "id": "d6abf8b1-aa2d-4055-9184-d962ba0d4582" + }, + "source": [ + "# Acquire a pretrained model\n", + "Pretrained METL models are available in the [metl-pretrained](https://github.com/gitter-lab/metl-pretrained) repository. You can use one of those, or you can pretrain your own METL model (see [pretraining.ipynb](https://github.com/gitter-lab/metl/blob/main/notebooks/pretraining.ipynb)).\n", + "\n", + "For demonstration purposes, we include a pretrained avGFP METL-Local model from the [metl-pretrained](https://github.com/gitter-lab/metl-pretrained) repository in the [pretrained_models](https://github.com/gitter-lab/metl/tree/main/pretrained_models) directory. This model is `METL-L-2M-3D-GFP` (UUID: `Hr4GNHws`).\n", + "It is the avGFP METL-Local source model we used for the analysis in our preprint.\n", + "\n", + "We will show how to finetune this model using the [experimental avGFP dataset](https://github.com/gitter-lab/metl/tree/main/data/dms_data/avgfp)." 
+ ] + }, + { + "cell_type": "markdown", + "id": "23a30235-357a-4326-a4ff-77ab26eb5d7f", + "metadata": { + "id": "23a30235-357a-4326-a4ff-77ab26eb5d7f" + }, + "source": [ + "# Training arguments\n", + "\n", + "The script for finetuning on experimental data is [train_target_model.py](https://github.com/gitter-lab/metl/blob/main/code/train_target_model.py). This script has a number of arguments you can view by uncommenting and running the below cell. There are additional arguments related to architecture that won't show up if you run the command, but you can view them in [models.py](https://github.com/gitter-lab/metl/tree/main/code/models.py) in the `TransferModel` class." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "bca8aeea-3dc3-47eb-915c-d80132be8fef", + "metadata": { + "id": "bca8aeea-3dc3-47eb-915c-d80132be8fef" + }, + "outputs": [], + "source": [ + "# !python code/train_target_model.py -h" + ] + }, + { + "cell_type": "markdown", + "id": "7ec8c31b-2da2-4ba7-9f4e-39e30dce8056", + "metadata": { + "id": "7ec8c31b-2da2-4ba7-9f4e-39e30dce8056" + }, + "source": [ + "We set up finetuning arguments for this example in [finetune_avgfp_local.txt](https://github.com/gitter-lab/metl/tree/main/args/finetune_avgfp_local.txt) in the [args](https://github.com/gitter-lab/metl/tree/main/args) directory. This argument file can be used directly with [train_target_model.py](https://github.com/gitter-lab/metl/blob/main/code/train_target_model.py) by calling the command `!python code/train_target_model.py @args/finetune_avgfp_local.txt` (we do this in the next section).\n", + "\n", + "Uncomment and run the cell below to view the contents of the argument file. The sections below will walk through and explain the key arguments." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "a06a897f-877d-4e41-9bee-4d3eabeead7d", + "metadata": { + "id": "a06a897f-877d-4e41-9bee-4d3eabeead7d" + }, + "outputs": [], + "source": [ + "# with open(\"args/finetune_avgfp_local.txt\", \"r\") as file:\n", + "# contents = file.read()\n", + "# print(contents)" + ] + }, + { + "cell_type": "markdown", + "id": "c2610124-fa2c-4709-98fc-bae51b258338", + "metadata": { + "id": "c2610124-fa2c-4709-98fc-bae51b258338" + }, + "source": [ + "## Dataset arguments" + ] + }, + { + "cell_type": "markdown", + "id": "9f56ee90-90be-41fa-bc99-c13f94e14976", + "metadata": { + "id": "9f56ee90-90be-41fa-bc99-c13f94e14976" + }, + "source": [ + "\n", + "Specify the dataset name and the train/val/test split. The dataset must be defined in [datasets.yml](https://github.com/gitter-lab/metl/tree/main/data/dms_data/datasets.yml). For demonstration purposes, we are using one of the reduced dataset size splits with a dataset size of 160 (train size of 128).\n", + "```\n", + "--ds_name\n", + "avgfp\n", + "--split_dir\n", + "data/dms_data/avgfp/splits/resampled/resampled_ds160_val0.2_te0.1_w1abc2f4e9a64_s1_r8099/resampled_ds160_val0.2_te0.1_w1abc2f4e9a64_s1_r8099_rep_0\n", + "```\n", + "\n", + "Specify the names of the train, validation, and test set files in the split directory. Using \"auto\" for the test_name will select the super test set (\"stest.txt\") if it exists in the split directory, otherwise it will use the standard test set (\"test.txt\").\n", + "\n", + "```\n", + "--train_name\n", + "train\n", + "--val_name\n", + "val\n", + "--test_name\n", + "test\n", + "```\n", + "\n", + "The name of the target column in the dataset dataframe. The model will be finetuned to predict the score in this column.\n", + "\n", + "```\n", + "--target_names\n", + "score\n", + "```\n", + "\n", + "The METL-Local model we are finetuning uses 3D structure-based relative position embeddings, so we need to specify the PDB filename. 
This PDB file is in the [data/pdb_files](https://github.com/gitter-lab/metl/tree/main/data/pdb_files) directory, which the script checks by default, so there is no need to specify the full path. You can also just specify \"auto\" to use the PDB file defined for this dataset in [datasets.yml](https://github.com/gitter-lab/metl/tree/main/data/dms_data/datasets.yml).\n", + "\n", + "```\n", + "--pdb_fn\n", + "1gfl_cm.pdb\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "890cea13-feae-4e54-bf0f-dcbe97f4409f", + "metadata": { + "id": "890cea13-feae-4e54-bf0f-dcbe97f4409f" + }, + "source": [ + "## Network architecture arguments" + ] + }, + { + "cell_type": "markdown", + "id": "72ee9762-cae7-4e21-8435-f6dd49781b8c", + "metadata": { + "id": "72ee9762-cae7-4e21-8435-f6dd49781b8c" + }, + "source": [ + "For finetuning, we implemented a special model `transfer_model` that handles pretrained checkpoints with top nets.\n", + "```\n", + "--model_name\n", + "transfer_model\n", + "```\n", + "\n", + "The pretrained checkpoint can be a PyTorch checkpoint (.pt file) downloaded from the [metl-pretrained](https://github.com/gitter-lab/metl-pretrained) repository or a PyTorch Lightning checkpoint (.ckpt file) obtained from pretraining a model with this repository.\n", + "```\n", + "--pretrained_ckpt_path\n", + "pretrained_models/Hr4GNHws.pt\n", + "```\n", + "\n", + "The backbone cutoff determines where to cutoff the pretrained model and place the new prediction head. For METL-Local models, we recommend backbone cutoff -1, and for METL-Global models we recommend backbone cutoff -2.\n", + "\n", + "```\n", + "--backbone_cutoff\n", + "-1\n", + "```\n", + "\n", + "The remaining arguments determine the encoding, which should be set to `int_seqs`, whether to use dropout after the backbone cutoff, and the architecture of the new top net. 
You can leave these values as-is to match what we did for the preprint.\n", + "\n", + "```\n", + "--encoding\n", + "int_seqs\n", + "--dropout_after_backbone\n", + "--dropout_after_backbone_rate\n", + "0.5\n", + "--top_net_type\n", + "linear\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "8d94c112-9770-4a5f-93e0-acf4d9acae16", + "metadata": { + "id": "8d94c112-9770-4a5f-93e0-acf4d9acae16" + }, + "source": [ + "## Finetuning strategy arguments" + ] + }, + { + "cell_type": "markdown", + "id": "7bb96cb6-7815-4efa-9b6f-305df9bb3050", + "metadata": { + "id": "7bb96cb6-7815-4efa-9b6f-305df9bb3050" + }, + "source": [ + "We implemented a dual-phase finetuning strategy. During the first phase, the backbone weights are frozen and only the top net is trained. During the second phase, all the network weights are unfrozen and trained at a reduced learning rate.\n", + "\n", + "The unfreeze_backbone_at_epoch argument determines the training epoch at which to unfreeze the backbone. We train the models for 500 epochs, so the backbone is unfrozen halfway through at epoch 250.\n", + "\n", + "```\n", + "--finetuning\n", + "--finetuning_strategy\n", + "backbone\n", + "--unfreeze_backbone_at_epoch\n", + "250\n", + "--backbone_always_align_lr\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "f51d4584-a0ce-45c3-8fb7-8c34d3a984c3", + "metadata": { + "id": "f51d4584-a0ce-45c3-8fb7-8c34d3a984c3" + }, + "source": [ + "## Optimization arguments" + ] + }, + { + "cell_type": "markdown", + "id": "d90d10e8-21f3-4b9e-8134-99cb053bef13", + "metadata": { + "id": "d90d10e8-21f3-4b9e-8134-99cb053bef13" + }, + "source": [ + "Basic optimizer arguments include the batch size, learning rate, and maximum number of epochs to train for. 
Unless early stopping is enabled, the model will train for the given number of epochs.\n", + "\n", + "```\n", + "--optimizer\n", + "adamw\n", + "--weight_decay\n", + "0.1\n", + "--batch_size\n", + "128\n", + "--learning_rate\n", + "0.001\n", + "--max_epochs\n", + "500\n", + "--gradient_clip_val\n", + "0.5\n", + "```\n", + "\n", + "The learning rate scheduler we used for finetuning is a dual phase learning rate schedule that matches the dual phase finetuning strategy. Each phase has a linear learning rate warmup for 1% of the total steps in that phase. There is also a cosine decay for the learning rate for each phase. The phase 2 learning rate is 10% of the phase 1 learning rate.\n", + "\n", + "```\n", + "--lr_scheduler\n", + "dual_phase_warmup_constant_cosine_decay\n", + "--warmup_steps\n", + ".01\n", + "--phase2_lr_ratio\n", + "0.1\n", + "\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "16327f53-7beb-412e-a925-12884e66d70b", + "metadata": { + "id": "16327f53-7beb-412e-a925-12884e66d70b" + }, + "source": [ + "## Logging arguments" + ] + }, + { + "cell_type": "markdown", + "id": "132db93c-85e6-4658-a31e-9b103df34cb7", + "metadata": { + "id": "132db93c-85e6-4658-a31e-9b103df34cb7" + }, + "source": [ + "We have built in functionality for tracking model training with Weights & Biases. 
If you have a Weights & Biases account, set the environment variable `WANDB_API_KEY` to your API key and set the flag `--use_wandb` instead of `--no_use_wandb` below.\n", + "\n", + "```\n", + "--no_use_wandb\n", + "--wandb_project\n", + "metl-target\n", + "--wandb_online\n", + "--experiment\n", + "default\n", + "```\n", + "\n", + "The below argument determines where to place the log directory locally.\n", + "```\n", + "--log_dir_base\n", + "output/training_logs\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "53a2fda3-6dfa-46d5-ad3d-3055eda0b29a", + "metadata": { + "id": "53a2fda3-6dfa-46d5-ad3d-3055eda0b29a" + }, + "source": [ + "# Running training" + ] + }, + { + "cell_type": "markdown", + "id": "8d3d8d23-9d54-4888-842d-4fc8fd843b40", + "metadata": { + "id": "8d3d8d23-9d54-4888-842d-4fc8fd843b40" + }, + "source": [ + "All the arguments described above are contained in [finetune_avgfp_local.txt](https://github.com/gitter-lab/metl/tree/main/args/finetune_avgfp_local.txt), which can be fed directly into [train_target_model.py](https://github.com/gitter-lab/metl/blob/main/code/train_target_model.py).\n", + "\n", + "PyTorch Lightning has a built-in progress bar that is convenient for seeing training progress, but it does not display correctly in Jupyter when calling the script with `!python`. We are going to disable the progress bar by setting the flag `--enable_progress_bar false`. Instead, we implemented a simple print statement to track training progress, which we will enable with the flag `--enable_simple_progress_messages`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "977b4d8d-4662-4e03-955c-dc4a8ae7c1dc", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "977b4d8d-4662-4e03-955c-dc4a8ae7c1dc", + "outputId": "a487d3f2-97fd-47ea-d07b-9c0cd21c9fb7" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Random seed not specified, using: 522644021\n", + "Global seed set to 522644021\n", + "Created model UUID: fmngE6sB\n", + "Created log directory: output/training_logs/fmngE6sB\n", + "Final UUID: fmngE6sB\n", + "Final log directory: output/training_logs/fmngE6sB\n", + "Trainer already configured with model summary callbacks: []. Skipping setting a default `ModelSummary` callback.\n", + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "Loading `train_dataloader` to estimate number of stepping batches.\n", + "/usr/local/lib/python3.9/site-packages/torch/utils/data/dataloader.py:563: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", + " warnings.warn(_create_warning_msg(\n", + "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1892: PossibleUserWarning: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). 
Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n", + " rank_zero_warn(\n", + "Number of training steps is 50\n", + "Number of warmup steps is 0.5\n", + "Second warmup phase starts at step 25\n", + "total_steps 50\n", + "phase1_total_steps 25\n", + "phase2_total_steps 25\n", + "┏━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━┓\n", + "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35m In sizes\u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35m Out sizes\u001b[0m\u001b[1;35m \u001b[0m┃\n", + "┡━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━┩\n", + "│\u001b[2m \u001b[0m\u001b[2m0\u001b[0m\u001b[2m \u001b[0m│ model │ TransferModel │ 2.4 M │\u001b[37m \u001b[0m\u001b[37m[128, 237]\u001b[0m\u001b[37m \u001b[0m│\u001b[37m \u001b[0m\u001b[37m [128, 1]\u001b[0m\u001b[37m \u001b[0m│\n", + "│\u001b[2m \u001b[0m\u001b[2m1\u001b[0m\u001b[2m \u001b[0m│ model.model │ SequentialWithArgs │ 2.4 M │\u001b[37m \u001b[0m\u001b[37m[128, 237]\u001b[0m\u001b[37m \u001b[0m│\u001b[37m \u001b[0m\u001b[37m [128, 1]\u001b[0m\u001b[37m \u001b[0m│\n", + "│\u001b[2m \u001b[0m\u001b[2m2\u001b[0m\u001b[2m \u001b[0m│ model.model.backbone │ SequentialWithArgs │ 2.4 M │\u001b[37m \u001b[0m\u001b[37m[128, 237]\u001b[0m\u001b[37m \u001b[0m│\u001b[37m \u001b[0m\u001b[37m[128, 256]\u001b[0m\u001b[37m \u001b[0m│\n", + "│\u001b[2m \u001b[0m\u001b[2m3\u001b[0m\u001b[2m \u001b[0m│ model.model.dropout │ Dropout │ 0 │\u001b[37m \u001b[0m\u001b[37m[128, 256]\u001b[0m\u001b[37m \u001b[0m│\u001b[37m \u001b[0m\u001b[37m[128, 256]\u001b[0m\u001b[37m \u001b[0m│\n", + "│\u001b[2m \u001b[0m\u001b[2m4\u001b[0m\u001b[2m \u001b[0m│ 
model.model.flatten │ Flatten │ 0 │\u001b[37m \u001b[0m\u001b[37m[128, 256]\u001b[0m\u001b[37m \u001b[0m│\u001b[37m \u001b[0m\u001b[37m[128, 256]\u001b[0m\u001b[37m \u001b[0m│\n", + "│\u001b[2m \u001b[0m\u001b[2m5\u001b[0m\u001b[2m \u001b[0m│ model.model.prediction │ Linear │ 257 │\u001b[37m \u001b[0m\u001b[37m[128, 256]\u001b[0m\u001b[37m \u001b[0m│\u001b[37m \u001b[0m\u001b[37m [128, 1]\u001b[0m\u001b[37m \u001b[0m│\n", + "│\u001b[2m \u001b[0m\u001b[2m6\u001b[0m\u001b[2m \u001b[0m│ test_pearson │ PearsonCorrCoef │ 0 │\u001b[37m \u001b[0m\u001b[37m ?\u001b[0m\u001b[37m \u001b[0m│\u001b[37m \u001b[0m\u001b[37m ?\u001b[0m\u001b[37m \u001b[0m│\n", + "│\u001b[2m \u001b[0m\u001b[2m7\u001b[0m\u001b[2m \u001b[0m│ test_spearman │ SpearmanCorrCoef │ 0 │\u001b[37m \u001b[0m\u001b[37m ?\u001b[0m\u001b[37m \u001b[0m│\u001b[37m \u001b[0m\u001b[37m ?\u001b[0m\u001b[37m \u001b[0m│\n", + "└───┴────────────────────────┴────────────────────┴────────┴────────────┴────────────┘\n", + "\u001b[1mTrainable params\u001b[0m: 257 \n", + "\u001b[1mNon-trainable params\u001b[0m: 2.4 M \n", + "\u001b[1mTotal params\u001b[0m: 2.4 M \n", + "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 9 \n", + "Starting sanity check...\n", + "Sanity check complete.\n", + "Starting training...\n", + "Epoch 0: Train Loss = 2.322, Val Loss = 2.024\n", + "Epoch 1: Train Loss = 2.348, Val Loss = 1.977\n", + "Epoch 2: Train Loss = 2.255, Val Loss = 1.931\n", + "Epoch 3: Train Loss = 2.265, Val Loss = 1.888\n", + "Epoch 4: Train Loss = 2.179, Val Loss = 1.846\n", + "Epoch 5: Train Loss = 2.138, Val Loss = 1.807\n", + "Epoch 6: Train Loss = 2.080, Val Loss = 1.770\n", + "Epoch 7: Train Loss = 2.037, Val Loss = 1.735\n", + "Epoch 8: Train Loss = 1.982, Val Loss = 1.703\n", + "Epoch 9: Train Loss = 1.948, Val Loss = 1.674\n", + "Epoch 10: Train Loss = 1.962, Val Loss = 1.648\n", + "Epoch 11: Train Loss = 1.894, Val Loss = 1.624\n", + "Epoch 12: Train Loss = 1.903, Val Loss = 1.603\n", + "Epoch 13: Train Loss 
= 1.889, Val Loss = 1.585\n", + "Epoch 14: Train Loss = 1.822, Val Loss = 1.570\n", + "Epoch 15: Train Loss = 1.838, Val Loss = 1.556\n", + "Epoch 16: Train Loss = 1.811, Val Loss = 1.546\n", + "Epoch 17: Train Loss = 1.808, Val Loss = 1.537\n", + "Epoch 18: Train Loss = 1.803, Val Loss = 1.530\n", + "Epoch 19: Train Loss = 1.754, Val Loss = 1.525\n", + "Epoch 20: Train Loss = 1.788, Val Loss = 1.521\n", + "Epoch 21: Train Loss = 1.779, Val Loss = 1.519\n", + "Epoch 22: Train Loss = 1.771, Val Loss = 1.518\n", + "Epoch 23: Train Loss = 1.790, Val Loss = 1.517\n", + "Epoch 24: Train Loss = 1.803, Val Loss = 1.517\n", + "Epoch 25: Train Loss = 1.824, Val Loss = 1.517\n", + "Epoch 26: Train Loss = 1.773, Val Loss = 1.486\n", + "Epoch 27: Train Loss = 1.741, Val Loss = 1.455\n", + "Epoch 28: Train Loss = 1.674, Val Loss = 1.425\n", + "Epoch 29: Train Loss = 1.671, Val Loss = 1.394\n", + "Epoch 30: Train Loss = 1.592, Val Loss = 1.365\n", + "Epoch 31: Train Loss = 1.603, Val Loss = 1.335\n", + "Epoch 32: Train Loss = 1.581, Val Loss = 1.307\n", + "Epoch 33: Train Loss = 1.526, Val Loss = 1.279\n", + "Epoch 34: Train Loss = 1.489, Val Loss = 1.253\n", + "Epoch 35: Train Loss = 1.445, Val Loss = 1.228\n", + "Epoch 36: Train Loss = 1.375, Val Loss = 1.203\n", + "Epoch 37: Train Loss = 1.394, Val Loss = 1.181\n", + "Epoch 38: Train Loss = 1.337, Val Loss = 1.160\n", + "Epoch 39: Train Loss = 1.358, Val Loss = 1.142\n", + "Epoch 40: Train Loss = 1.326, Val Loss = 1.126\n", + "Epoch 41: Train Loss = 1.259, Val Loss = 1.111\n", + "Epoch 42: Train Loss = 1.200, Val Loss = 1.099\n", + "Epoch 43: Train Loss = 1.180, Val Loss = 1.090\n", + "Epoch 44: Train Loss = 1.148, Val Loss = 1.082\n", + "Epoch 45: Train Loss = 1.126, Val Loss = 1.076\n", + "Epoch 46: Train Loss = 1.182, Val Loss = 1.072\n", + "Epoch 47: Train Loss = 1.177, Val Loss = 1.070\n", + "Epoch 48: Train Loss = 1.168, Val Loss = 1.069\n", + "Epoch 49: Train Loss = 1.091, Val Loss = 1.069\n", + "`Trainer.fit` stopped: 
`max_epochs=50` reached.\n", + "Restoring states from the checkpoint path at output/training_logs/fmngE6sB/checkpoints/epoch=49-step=50.ckpt\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "Loaded model weights from checkpoint at output/training_logs/fmngE6sB/checkpoints/epoch=49-step=50.ckpt\n", + "/usr/local/lib/python3.9/site-packages/torch/utils/data/dataloader.py:563: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", + " warnings.warn(_create_warning_msg(\n", + "Starting testing...\n", + "Testing complete.\n", + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.2109477519989014 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_pearson \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.6558916568756104 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_spearman \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.6332594752311707 \u001b[0m\u001b[35m \u001b[0m│\n", + "└───────────────────────────┴───────────────────────────┘\n", + "Restoring states from the checkpoint path at output/training_logs/fmngE6sB/checkpoints/epoch=49-step=50.ckpt\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "Loaded model weights from checkpoint at output/training_logs/fmngE6sB/checkpoints/epoch=49-step=50.ckpt\n", + "Starting prediction...\n", + "Prediction 
complete.\n", + "saving a scatter plot for set: train (128 variants)\n", + "saving a scatter plot for set: val (32 variants)\n", + "saving a scatter plot for set: test (4655 variants)\n", + " mse pearsonr r2 spearmanr\n", + "set \n", + "train 1.286097 0.734186 -0.102364 0.693757\n", + "val 1.069018 0.737532 0.008699 0.725690\n", + "test 1.210948 0.655892 -0.087824 0.633260\n" + ] + } + ], + "source": [ + "!python code/train_target_model.py @args/finetune_avgfp_local.txt --enable_progress_bar false --enable_simple_progress_messages --max_epochs 50 --unfreeze_backbone_at_epoch 25" + ] + }, + { + "cell_type": "markdown", + "id": "f33fc407-6ab1-45e3-8e6a-9b717dca7f00", + "metadata": { + "id": "f33fc407-6ab1-45e3-8e6a-9b717dca7f00" + }, + "source": [ + "# Additional recommendations" + ] + }, + { + "cell_type": "markdown", + "id": "39c8e0e5-8bb5-4200-ab45-e559b0f20896", + "metadata": { + "id": "39c8e0e5-8bb5-4200-ab45-e559b0f20896", + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "## Model selection\n", + "\n", + "Selecting the model from the epoch with the lowest validation set loss can help prevent overfitting. It requires having a big enough validation set that provides an accurate estimate of performance.\n", + "\n", + "We enabled model selection if the validation set size was ≥ 32 for METL-Local and ≥ 128 for METL-Global. 
We found the optimization was more stable for METL-Local than METL-Global, thus smaller validation sets were still reliable.\n", + "\n", + "Enable model selection by setting argument `--ckpt_monitor val`.\n" + ] + }, + { + "cell_type": "markdown", + "id": "b18f773b-8209-4993-b3f0-994b0ab2b133", + "metadata": { + "id": "b18f773b-8209-4993-b3f0-994b0ab2b133" + }, + "source": [ + "## Backbone cutoff for METL-Global\n", + "Finetuning METL-Global is largely the same as METL-Local, except we recommend using a different threshold for model selection (see above), as well as a different backbone cutoff.\n", + "\n", + "For METL-Local, we set `--backbone_cutoff -1`, which attaches the new prediction head immediately after the final fully connected layer.\n", + "\n", + "For METL-Global, we recommend setting `--backbone_cutoff -2`, which attaches the new prediction head immediately after the global pooling layer. We found this resulted in better finetuning performance for METL-Global." + ] + }, + { + "cell_type": "markdown", + "id": "2a591eb8-3d5e-437f-9189-3c0834f7f447", + "metadata": { + "id": "2a591eb8-3d5e-437f-9189-3c0834f7f447" + }, + "source": [ + "# Running inference using finetuned model" + ] + }, + { + "cell_type": "markdown", + "id": "af85ff8f-1a30-4ba2-bf3b-967a773e0e80", + "metadata": { + "id": "af85ff8f-1a30-4ba2-bf3b-967a773e0e80" + }, + "source": [ + "The PyTorch Lightning framework supports inference, but while we put together a working example, we recommend converting the PyTorch Lightning checkpoint to pure PyTorch and using the [metl-pretrained](https://github.com/gitter-lab/metl-pretrained) package to run inference in pure PyTorch." + ] + }, + { + "cell_type": "markdown", + "id": "1acca5d1-1bca-4c3f-b9d3-56525cf11186", + "metadata": { + "id": "1acca5d1-1bca-4c3f-b9d3-56525cf11186" + }, + "source": [ + "## Convert to PyTorch\n", + "Lightning checkpoints are compatible with pure pytorch, but they may contain additional items that are not needed for inference. 
This script loads the checkpoint and saves a smaller checkpoint with just the model weights and hyperparameters." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "63d8ce0a-5534-406f-90b6-6c155cb6ea9c", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "63d8ce0a-5534-406f-90b6-6c155cb6ea9c", + "outputId": "3406c304-902c-425d-ebe5-65d03c3b480c" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Processing checkpoint: output/training_logs/fmngE6sB/checkpoints/epoch=49-step=50.ckpt\n", + "Saving converted checkpoint to: output/training_logs/fmngE6sB/checkpoints/fmngE6sB.pt\n" + ] + } + ], + "source": [ + "# the Lightning checkpoint from the finetuning we performed above\n", + "fine_tuning_dir_name = os.listdir('output/training_logs')[0]\n", + "\n", + "ckpt_fn = f\"output/training_logs/{fine_tuning_dir_name}/checkpoints/epoch=49-step=50.ckpt\"\n", + "\n", + "# run the conversion script\n", + "!python code/convert_ckpt.py --ckpt_path $ckpt_fn" + ] + }, + { + "cell_type": "markdown", + "id": "98b562aa-663a-4b0d-a719-e85555cf875d", + "metadata": { + "id": "98b562aa-663a-4b0d-a719-e85555cf875d" + }, + "source": [ + "## Load checkpoint with metl-pretrained package\n", + "Using the Hugging Face wrapper, we can load the metl library and use it to load our newly trained model checkpoint and run inference with it." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "ff35ce69-97ed-4a5a-b082-f197aae1addc", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 341, + "referenced_widgets": [ + "b7b47f570c8c43b6bf2230e6a44d3e09", + "219fdb01577845ea8fe0ad54a77cf65d", + "275bb85ac0104926a7988da85947b332", + "f7f2291a716346888c03cf0481b3b4b7", + "55be7aa5f252407bbe8e00f1c182fb22", + "ac4e32e574354ec095f92678a0313f4e", + "5c3b1f7e1e75469da8d453d63c762aad", + "a2f6a37975d54b418fde8afdeb86ffa0", + "6f0c2bfb83ab494fa67af09ad82f1a19", + "1dea0603d40b4d4f84193876183ebfbc", + "cdf78bd9b62d4d5c9809f9ff99cd296b", + "adf1bab3a2cc4ceb94031be73340031e", + "a9e9d6e3751f42b0857e3c6736af3fc7", + "98ffe8af93d5480e94c00ba5f09db454", + "cc76bf4678fd49148f5f25236daa43d2", + "908beed2fd40426db19081750bfac4fb", + "2b5138e0f130467d9af35736eb8bce41", + "865015002a294027951beda6cbc1a07c", + "5642d3dea048402abf2cad21c25a4c01", + "7e920825842744eeb4d00631b5ffc08a", + "4d0d620efc564efa8e0e069c59a14460", + "79b93dcc01304185a65877e0a51e861d", + "ef9b97e9f9434dc79b2201bde062a0c4", + "e6df8ef3f9ca4e17a76a093c753a5b47", + "f78d7b0a73974cbea0d992824c9a795f", + "16a9352d655043e682c1cb20f8214479", + "f12897838468411c94e8126a56890057", + "5d4f099cd3624ae9adc500aa93a80ae0", + "c4250b0106404bc78368f17819f51fde", + "dacad6e28f7e4b369dbf7f03c88fa00e", + "ae47fbb36d6a4f4ba3394a2c85fb599e", + "95c96e9e3f2e4ad48cf08e55a4341d1e", + "ace8a8608099458e84cfb991ad462631" + ] + }, + "id": "ff35ce69-97ed-4a5a-b082-f197aae1addc", + "outputId": "e549e68a-df63-42e2-f687-c6923882db07" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n", + "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", + "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and 
restart your session.\n", + "You will be able to reuse this secret in all of your notebooks.\n", + "Please note that authentication is recommended but still optional to access public models or datasets.\n", + " warnings.warn(\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "config.json: 0%| | 0.00/269 [00:00] 76.92K --.-KB/s in 0.01s \n", + "\n", + "2024-08-28 21:50:07 (6.60 MB/s) - ‘2qmt_p.pdb’ saved [78764/78764]\n", + "\n" + ] + } + ], + "source": [ + "# @title Download the example pdb file\n", + "!wget -O 2qmt_p.pdb https://raw.githubusercontent.com/gitter-lab/metl-pretrained/main/pdbs/2qmt_p.pdb" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "cellView": "form", + "id": "mpjvGjzArLF1" + }, + "outputs": [], + "source": [ + "# @title Importing required libraries\n", + "from transformers import AutoModel, AutoConfig, logging\n", + "import ipywidgets as widgets\n", + "from IPython.display import clear_output, HTML, display\n", + "import pandas as pd\n", + "import torch\n", + "import io\n", + "import json\n", + "import biopandas\n", + "\n", + "logging.set_verbosity_error()\n", + "\n", + "# Declaring this here so that it's available regardless if later cells are run or not\n", + "variant_file = None\n", + "pdb_file_path = '2qmt_p.pdb'\n", + "\n", + "def to_zero_based(variants):\n", + " zero_based = []\n", + " for line in variants:\n", + " line_as_json = json.loads(line)\n", + " new_variants = []\n", + " for variant in line_as_json:\n", + " new_variant = []\n", + " mutations = variant.split(',')\n", + " for mutation in mutations:\n", + " residue_zero_based = int(mutation[1:-1]) - 1\n", + " new_variant.append(f\"{mutation[0]}{residue_zero_based}{mutation[-1]}\")\n", + " new_variants.append(\",\".join(new_variant))\n", + " zero_based.append(new_variants)\n", + "\n", + " return zero_based" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "JWJCsBukrLF7" + }, + "source": [ + "We will then load 
a METL model through the 🤗 API. trust_remote_code=True is required to use METL models through 🤗." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "cellView": "form", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 237, + "referenced_widgets": [ + "548aee36a23840da8d2d9112c1a126e4", + "9875bd75c19140e8982cfc83ab405980", + "08b5e6a111014f50be5b549dc61c04bd", + "47cb28f509c648a1a435b946bfc07518", + "541c0af07fc64fcfabaa49cd90ccacd2", + "7df715d2ab2d45adb6cf04e1b908fe34", + "9df76c111df140c3b7dfadb724aae88f", + "43f1996def8b4fca9ef92d15e935f1ba", + "1cea69ccb7ab49eab6b3ceccc60f846b", + "b8b4d2ec394547cc86dcfe1eddaf0402", + "579c2d34818e4ff0ba157532d81edfb7", + "08ce7a7c160f48488a1ccfff9e52fc19", + "289c3d37c529432d9f2d4ee502b427b5", + "634f94c13ef74971a26f4988ade714fa", + "55eb8ffb95564d25acb8b40d2ac86b0c", + "f935b17bb46649b5a81733fb0e5529ab", + "1ee82a5c2acc4c7fb0c63ddbb30b471c", + "67a7158ed55842d38aa55392ab1d3744", + "812642ecbb744904a74c1dbcf414cea9", + "c38ea535ff6e4d60a772700d1667c7c1", + "5b68d66d5ceb473eaf8e45a979e58e8f", + "1f6ef9a762f841f4acf9fe8edd589b2b", + "6eabe0cb30a142f2b79185efb52f6bb0", + "a05042c291aa4e1e9bb4f09be3d9f806", + "3f9981ff162245f49f533d90c0100a9c", + "531cde6e3d8f44009ab3480338bb316e", + "8ae64abd8d234c218a8fb4e92a059377", + "640f95344552481aa56d9d17398a9b45", + "3fcaf29e93bc4f0ea5293b761fbe9194", + "69bd8895c8f948c7b879b1c869539334", + "61c6494ff4fb4f7a888e3129ac641df4", + "7243419b29bc471fa4a8ef56d6041f47", + "a4b4fdbd7d5648a1ab7850e2f81ead15" + ] + }, + "id": "vIp69-KLrLF8", + "outputId": "27117ccf-b5ca-4906-c532-f23bd26f65dc" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n", + "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", + "To authenticate with the Hugging Face Hub, create a token in your settings tab 
(https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", + "You will be able to reuse this secret in all of your notebooks.\n", + "Please note that authentication is recommended but still optional to access public models or datasets.\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "548aee36a23840da8d2d9112c1a126e4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "config.json: 0%| | 0.00/269 [00:00\n", + " .variant_text_area > textarea::placeholder {\n", + " color: var(--colab-primary-text-color);\n", + " }\n", + "\n", + " .variant_text_area > textarea {\n", + " background-color: var(--colab-secondary-surface-color);\n", + " color: var(--colab-primary-text-color);\n", + " }\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9d004e2783e54df7a8dca1952ec2434d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Textarea(value='', description='Variant String:', layout=Layout(height='100px', width='500px'), placeholder='[…" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# @title Variant text input\n", + "# @markdown The placeholder variants below use 0-based indexing.\n", + "variants_string = \"\"\"[\"T17P,T54F\", \"V28L,F51A\", \"T17P,V28L,F51A,T54F\"]\n", + "[\"T13P,T33F\"]\"\"\"\n", + "style = {'description_width':'initial'}\n", + "\n", + "variant_text = widgets.Textarea(\n", + " value='',\n", + " placeholder=variants_string,\n", + " description='Variant String:',\n", + " disabled=False,\n", + " style = style,\n", + " layout=widgets.Layout(height='100px', width='500px'),\n", + ")\n", + "\n", + "variant_text.add_class('variant_text_area')\n", + "\n", + "style = \"\"\"\n", + "\n", + "\"\"\"\n", + "\n", + "display(HTML(style))\n", + 
"display(variant_text)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ggs9zSgPHOYj" + }, + "source": [ + "If you would rather upload a file, run the cell below and use it to upload a file. If a file is uploaded, the input above will not be looked at for variants\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "cellView": "form", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 49, + "referenced_widgets": [ + "7b4e21e8ad294346b7552f231dee95b4", + "5141f5b94903473c9e89d2ec2ed8e5c4", + "a03c2859b511486da91a88babee36f63" + ] + }, + "id": "8TLYi6orHN_8", + "outputId": "d6a92501-fa48-4e4c-fe90-6612a5cd57a1" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "7b4e21e8ad294346b7552f231dee95b4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "FileUpload(value={}, accept='.json, .txt', description='Upload')" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# @title Variant file upload\n", + "# @markdown If you want to upload a variant JSON file, run this cell and upload the file with the provided button that appears below.\n", + "\n", + "\n", + "def update_variant_file(button_input):\n", + " global variant_file\n", + " for name, data in button_input['new'].items():\n", + " clear_output()\n", + " display(variant_upload)\n", + " print(f'Loaded file: {name}')\n", + " variant_file = data['content'].decode('utf-8').splitlines()\n", + "\n", + "variant_upload = widgets.FileUpload(\n", + " accept='.json, .txt',\n", + " multiple=False\n", + ")\n", + "\n", + "variant_upload.observe(update_variant_file, names='value')\n", + "variant_upload" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "cellView": "form", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "mASfltPUrLGI", + "outputId": "d0ff24de-5f1d-44d9-9eca-010f6fdebe91" + }, + "outputs": [ + { + "name": "stdout", + 
"output_type": "stream", + "text": [ + "Using variant placeholder\n" + ] + } + ], + "source": [ + "# @title Variant Selecting Logic (always run this)\n", + "\n", + "clear_output()\n", + "if len(variant_text.value) > 0:\n", + " print(\"Using text area input\")\n", + " variants = variant_text.value\n", + "elif variant_file:\n", + " print(\"Using variants file\")\n", + " variants = variant_file\n", + "else:\n", + " print(\"Using variant placeholder\")\n", + " variants = variant_text.placeholder.splitlines()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YOfEUeNYrLGI" + }, + "source": [ + "For biologists, one-based indexing is commonly used. However, METL models were designed to used zero-based indexing. If one-based indexing is needed, select it in the dropdown below." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "cellView": "form", + "id": "ecLAU9IOrLGK" + }, + "outputs": [], + "source": [ + "# @title Transform input from 1 based indexing to zero based indexing\n", + "# @markdown Select indexing for residue mutations\n", + "indexing = \"0\" # @param ['0', '1']" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ixhZbYm8rLGM" + }, + "source": [ + "Since both file and string variants give the same result, we only need to use one moving forwards. We will use the string_variants variable.\n", + "\n", + "To predict with METL, we will need to use the loaded model and encoder with our variables we defined above. We will wrap this in a for loop to predict on all of our variants as we have multiple lines of them." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "cellView": "form", + "id": "th1JyojWrLGN" + }, + "outputs": [], + "source": [ + "# @title METL predicting loop\n", + "output = []\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "metl = metl.to(device)\n", + "\n", + "if indexing == \"1\":\n", + " predict_variants = to_zero_based(variants)\n", + "else:\n", + " predict_variants = variants\n", + "\n", + "for variant in predict_variants:\n", + " # First in METL we need to encode our variants\n", + " if not isinstance(variant, list):\n", + " variant = json.loads(variant)\n", + " encoded_variants = metl.encoder.encode_variants(wildtype, variant)\n", + "\n", + " #Next, we predict\n", + " with torch.no_grad():\n", + " if pdb_file_path:\n", + " predictions = metl(torch.tensor(encoded_variants).to(device), pdb_fn=pdb_file_path)\n", + " else:\n", + " predictions = metl(torch.tensor(encoded_variants).to(device))\n", + "\n", + " output.append({\n", + " \"wt\": wildtype,\n", + " \"variants\": variant,\n", + " \"output\": predictions.tolist()\n", + " })" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "cellView": "form", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 300 + }, + "id": "J5XJ-4memIYp", + "outputId": "982a13f1-6434-4ea8-ce68-8664b94c4ae0" + }, + "outputs": [ + { + "data": { + "application/javascript": "google.colab.output.setIframeHeight(0, true, {maxHeight: 300})", + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[\n", + " {\n", + " \"wt\": \"MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDDATKTFTVTE\",\n", + " \"variants\": [\n", + " \"T17P,T54F\",\n", + " \"V28L,F51A\",\n", + " \"T17P,V28L,F51A,T54F\"\n", + " ],\n", + " \"output\": [\n", + " [\n", + " 0.07013243436813354,\n", + " -0.25929510593414307,\n", + " 
0.03156351298093796,\n", + " 0.7507199048995972,\n", + " -0.07624393701553345,\n", + " 0.19459080696105957,\n", + " 0.4965912103652954,\n", + " -1.2189791202545166,\n", + " 0.047027163207530975,\n", + " 2.024402618408203,\n", + " 0.652029275894165,\n", + " 0.38363683223724365,\n", + " -0.29057905077934265,\n", + " 0.889313280582428,\n", + " -0.19636523723602295,\n", + " 0.82305508852005,\n", + " 0.2896294891834259,\n", + " -0.6637641191482544,\n", + " -0.39318302273750305,\n", + " 0.3603857457637787,\n", + " 0.857086181640625,\n", + " 0.9503828287124634,\n", + " 0.3519744873046875,\n", + " 1.4875739812850952,\n", + " 0.10600201040506363,\n", + " 0.33237022161483765,\n", + " -0.3101063072681427,\n", + " -0.2387685775756836,\n", + " -0.5087016224861145,\n", + " 0.686012327671051,\n", + " 0.4524794816970825,\n", + " 0.20570813119411469,\n", + " 0.30475038290023804,\n", + " -0.31177929043769836,\n", + " 0.9250588417053223,\n", + " 0.6042543649673462,\n", + " 1.1186459064483643,\n", + " 0.03992972895503044,\n", + " -0.8309147357940674,\n", + " 0.9451456665992737,\n", + " 1.0041853189468384,\n", + " 0.4524748921394348,\n", + " -0.9685558080673218,\n", + " 0.49829766154289246,\n", + " 1.1165390014648438,\n", + " 0.7435593605041504,\n", + " -0.5233420133590698,\n", + " -0.5309135913848877,\n", + " 2.0981578826904297,\n", + " 0.552436351776123,\n", + " -1.0519232749938965,\n", + " 2.288227081298828,\n", + " 0.01622585952281952,\n", + " 1.7331675291061401,\n", + " -0.1840410679578781\n", + " ],\n", + " [\n", + " 0.08126264810562134,\n", + " 1.445512056350708,\n", + " 0.3407595753669739,\n", + " -0.8155512809753418,\n", + " -0.6581068634986877,\n", + " -0.28225141763687134,\n", + " -0.43325313925743103,\n", + " 0.014442211017012596,\n", + " -0.16074422001838684,\n", + " -0.595990777015686,\n", + " -0.01838766783475876,\n", + " -0.9363240003585815,\n", + " 1.200121521949768,\n", + " 0.10803645849227905,\n", + " -0.47195640206336975,\n", + " -0.15199805796146393,\n", + " 
-0.2927914261817932,\n", + " -0.40319469571113586,\n", + " -0.5474604964256287,\n", + " -1.711698293685913,\n", + " -0.697638988494873,\n", + " -1.560241460800171,\n", + " -0.7673114538192749,\n", + " -1.1705756187438965,\n", + " 0.45397740602493286,\n", + " 0.5323038697242737,\n", + " 0.6842557787895203,\n", + " 0.37687546014785767,\n", + " -0.377780944108963,\n", + " 0.451775461435318,\n", + " 0.7726845741271973,\n", + " 0.17020709812641144,\n", + " -0.674484372138977,\n", + " 0.18455049395561218,\n", + " -0.5841971635818481,\n", + " -0.5465129613876343,\n", + " -0.9710506796836853,\n", + " 0.015543186105787754,\n", + " -0.1828489601612091,\n", + " -1.1787317991256714,\n", + " -1.264413833618164,\n", + " 0.7726802825927734,\n", + " -0.9820785522460938,\n", + " -0.6351808905601501,\n", + " -0.030753612518310547,\n", + " -0.04128456115722656,\n", + " -0.17912821471691132,\n", + " -0.3816293478012085,\n", + " 0.10558390617370605,\n", + " -0.049853961914777756,\n", + " -0.20711421966552734,\n", + " 0.2800188660621643,\n", + " 0.00046062562614679337,\n", + " 0.5665276050567627,\n", + " -0.15274114906787872\n", + " ],\n", + " [\n", + " 1.2070168256759644,\n", + " 1.7057719230651855,\n", + " -0.09893297404050827,\n", + " 1.4028812646865845,\n", + " -0.5031474828720093,\n", + " -0.1666990965604782,\n", + " 0.14461153745651245,\n", + " -1.1639573574066162,\n", + " -0.3492702841758728,\n", + " 2.2343969345092773,\n", + " 0.782096803188324,\n", + " 0.08824858069419861,\n", + " -0.30225202441215515,\n", + " 0.9438788294792175,\n", + " -0.34925132989883423,\n", + " 0.9428834915161133,\n", + " 0.6026463508605957,\n", + " -1.0658704042434692,\n", + " -0.4869558811187744,\n", + " -1.143844723701477,\n", + " 0.18776728212833405,\n", + " -1.3262150287628174,\n", + " -0.9022694826126099,\n", + " -0.0188913457095623,\n", + " 0.02201986312866211,\n", + " 0.7734603881835938,\n", + " 0.014177754521369934,\n", + " 0.5803621411323547,\n", + " 0.058175165206193924,\n", + " 
0.7798131704330444,\n", + " 1.2709546089172363,\n", + " -0.09427222609519958,\n", + " -0.7404829263687134,\n", + " -0.44868460297584534,\n", + " 0.6765958666801453,\n", + " -0.32317063212394714,\n", + " 0.19626344740390778,\n", + " -0.061478253453969955,\n", + " -0.5555139183998108,\n", + " 0.22458764910697937,\n", + " 0.2969154417514801,\n", + " 1.2709496021270752,\n", + " -2.17531681060791,\n", + " 0.6555665731430054,\n", + " 2.507157802581787,\n", + " 0.3506653308868408,\n", + " 0.7168694734573364,\n", + " 0.04246610403060913,\n", + " 2.134085178375244,\n", + " 1.2071470022201538,\n", + " -0.540690541267395,\n", + " 2.7881016731262207,\n", + " -0.0013065459206700325,\n", + " 2.872316360473633,\n", + " -0.18022766709327698\n", + " ]\n", + " ]\n", + " },\n", + " {\n", + " \"wt\": \"MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDDATKTFTVTE\",\n", + " \"variants\": [\n", + " \"T13P,T33F\"\n", + " ],\n", + " \"output\": [\n", + " [\n", + " 5.645543098449707,\n", + " -0.08823452889919281,\n", + " 3.270817518234253,\n", + " 1.2460161447525024,\n", + " 1.3037234544754028,\n", + " 0.5322443246841431,\n", + " 4.857309818267822,\n", + " -1.1067686080932617,\n", + " -0.6436038613319397,\n", + " 1.9419636726379395,\n", + " -0.8911130428314209,\n", + " 0.652341902256012,\n", + " -0.7837158441543579,\n", + " 9.617301940917969,\n", + " 2.1150782108306885,\n", + " 5.2092108726501465,\n", + " 2.7020974159240723,\n", + " -0.5816158652305603,\n", + " -0.5514440536499023,\n", + " 0.43818673491477966,\n", + " 0.4274188280105591,\n", + " 1.6742618083953857,\n", + " -0.8721922636032104,\n", + " 2.444389820098877,\n", + " -1.9845408201217651,\n", + " -1.1729459762573242,\n", + " -0.4310912489891052,\n", + " 0.5365148186683655,\n", + " 0.08791925013065338,\n", + " 0.2855781018733978,\n", + " 0.7514391541481018,\n", + " -0.9945945739746094,\n", + " -0.01072685420513153,\n", + " -2.4255361557006836,\n", + " -0.7554265260696411,\n", + " -1.6911653280258179,\n", + " 1.3066627979278564,\n", + 
" 0.041956719011068344,\n", + " -0.8304670453071594,\n", + " 1.5064003467559814,\n", + " 1.612711787223816,\n", + " 0.7514318823814392,\n", + " -1.3307963609695435,\n", + " 0.3869069516658783,\n", + " 0.6399444341659546,\n", + " 1.8638325929641724,\n", + " 1.619768738746643,\n", + " -0.7826499938964844,\n", + " 2.801478862762451,\n", + " -0.6386407017707825,\n", + " 0.8158987760543823,\n", + " 0.967729926109314,\n", + " -0.08497878909111023,\n", + " 0.14499327540397644,\n", + " 0.13757933676242828\n", + " ]\n", + " ]\n", + " }\n", + "]\n" + ] + } + ], + "source": [ + "# @title Display METL preditions\n", + "from IPython.display import Javascript\n", + "\n", + "display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))\n", + "print(json.dumps(output, indent=2))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "pO8jdBssrLGO" + }, + "source": [ + "Finally, we will save our output. We will save our output as a list of JSON Objects." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "cellView": "form", + "id": "nYz6NRB_rLGP" + }, + "outputs": [], + "source": [ + "# @title Saving the predictions\n", + "with open('./output.json', 'w') as f:\n", + " f.write(json.dumps(output, indent=2))" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "08b5e6a111014f50be5b549dc61c04bd": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": 
"@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_43f1996def8b4fca9ef92d15e935f1ba", + "max": 269, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_1cea69ccb7ab49eab6b3ceccc60f846b", + "value": 269 + } + }, + "08ce7a7c160f48488a1ccfff9e52fc19": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_289c3d37c529432d9f2d4ee502b427b5", + "IPY_MODEL_634f94c13ef74971a26f4988ade714fa", + "IPY_MODEL_55eb8ffb95564d25acb8b40d2ac86b0c" + ], + "layout": "IPY_MODEL_f935b17bb46649b5a81733fb0e5529ab" + } + }, + "1cea69ccb7ab49eab6b3ceccc60f846b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "1ee82a5c2acc4c7fb0c63ddbb30b471c": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": 
"@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "1f6ef9a762f841f4acf9fe8edd589b2b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "289c3d37c529432d9f2d4ee502b427b5": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_1ee82a5c2acc4c7fb0c63ddbb30b471c", + "placeholder": "​", + "style": 
"IPY_MODEL_67a7158ed55842d38aa55392ab1d3744", + "value": "huggingface_wrapper.py: 100%" + } + }, + "3f9981ff162245f49f533d90c0100a9c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_69bd8895c8f948c7b879b1c869539334", + "max": 176, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_61c6494ff4fb4f7a888e3129ac641df4", + "value": 176 + } + }, + "3fcaf29e93bc4f0ea5293b761fbe9194": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "43f1996def8b4fca9ef92d15e935f1ba": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": 
null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "47cb28f509c648a1a435b946bfc07518": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_b8b4d2ec394547cc86dcfe1eddaf0402", + "placeholder": "​", + "style": "IPY_MODEL_579c2d34818e4ff0ba157532d81edfb7", + "value": " 269/269 [00:00<00:00, 3.64kB/s]" + } + }, + "5141f5b94903473c9e89d2ec2ed8e5c4": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + 
"grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "531cde6e3d8f44009ab3480338bb316e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_7243419b29bc471fa4a8ef56d6041f47", + "placeholder": "​", + "style": "IPY_MODEL_a4b4fdbd7d5648a1ab7850e2f81ead15", + "value": " 176/176 [00:00<00:00, 3.73kB/s]" + } + }, + "541c0af07fc64fcfabaa49cd90ccacd2": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, 
+ "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "548aee36a23840da8d2d9112c1a126e4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_9875bd75c19140e8982cfc83ab405980", + "IPY_MODEL_08b5e6a111014f50be5b549dc61c04bd", + "IPY_MODEL_47cb28f509c648a1a435b946bfc07518" + ], + "layout": "IPY_MODEL_541c0af07fc64fcfabaa49cd90ccacd2" + } + }, + "55eb8ffb95564d25acb8b40d2ac86b0c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_5b68d66d5ceb473eaf8e45a979e58e8f", + "placeholder": "​", + "style": "IPY_MODEL_1f6ef9a762f841f4acf9fe8edd589b2b", + "value": " 95.9k/95.9k [00:00<00:00, 1.46MB/s]" + } + }, + "579c2d34818e4ff0ba157532d81edfb7": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + 
"_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "5b68d66d5ceb473eaf8e45a979e58e8f": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "61c6494ff4fb4f7a888e3129ac641df4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + 
"634f94c13ef74971a26f4988ade714fa": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_812642ecbb744904a74c1dbcf414cea9", + "max": 95901, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_c38ea535ff6e4d60a772700d1667c7c1", + "value": 95901 + } + }, + "640f95344552481aa56d9d17398a9b45": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + 
"67a7158ed55842d38aa55392ab1d3744": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "69bd8895c8f948c7b879b1c869539334": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "6eabe0cb30a142f2b79185efb52f6bb0": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": 
"HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_a05042c291aa4e1e9bb4f09be3d9f806", + "IPY_MODEL_3f9981ff162245f49f533d90c0100a9c", + "IPY_MODEL_531cde6e3d8f44009ab3480338bb316e" + ], + "layout": "IPY_MODEL_8ae64abd8d234c218a8fb4e92a059377" + } + }, + "7243419b29bc471fa4a8ef56d6041f47": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "7b4e21e8ad294346b7552f231dee95b4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FileUploadModel", + "state": { + "_counter": 0, + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FileUploadModel", + "_view_count": null, + 
"_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "FileUploadView", + "accept": ".json, .txt", + "button_style": "", + "data": [], + "description": "Upload", + "description_tooltip": null, + "disabled": false, + "error": "", + "icon": "upload", + "layout": "IPY_MODEL_5141f5b94903473c9e89d2ec2ed8e5c4", + "metadata": [], + "multiple": false, + "style": "IPY_MODEL_a03c2859b511486da91a88babee36f63" + } + }, + "7df715d2ab2d45adb6cf04e1b908fe34": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "812642ecbb744904a74c1dbcf414cea9": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + 
"_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "8131214e739d4fea923d9a4b61bc0a8e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": "100px", + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + 
"object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": "500px" + } + }, + "88183c92883b4c9787ecaf63b1f73f6e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FileUploadModel", + "state": { + "_counter": 0, + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FileUploadModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "FileUploadView", + "accept": ".pdb", + "button_style": "", + "data": [], + "description": "Upload", + "description_tooltip": null, + "disabled": false, + "error": "", + "icon": "upload", + "layout": "IPY_MODEL_974500e54aa5465fb97c17152439918a", + "metadata": [], + "multiple": false, + "style": "IPY_MODEL_e3d473ff87114b4aafb71269df884892" + } + }, + "8ae64abd8d234c218a8fb4e92a059377": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": 
null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "974500e54aa5465fb97c17152439918a": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "9875bd75c19140e8982cfc83ab405980": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": 
"IPY_MODEL_7df715d2ab2d45adb6cf04e1b908fe34", + "placeholder": "​", + "style": "IPY_MODEL_9df76c111df140c3b7dfadb724aae88f", + "value": "config.json: 100%" + } + }, + "9d004e2783e54df7a8dca1952ec2434d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "TextareaModel", + "state": { + "_dom_classes": [ + "variant_text_area" + ], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "TextareaModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "TextareaView", + "continuous_update": true, + "description": "Variant String:", + "description_tooltip": null, + "disabled": false, + "layout": "IPY_MODEL_8131214e739d4fea923d9a4b61bc0a8e", + "placeholder": "[\"T17P,T54F\", \"V28L,F51A\", \"T17P,V28L,F51A,T54F\"]\n[\"T13P,T33F\"]", + "rows": null, + "style": "IPY_MODEL_ea799d8361454d628d2dcdb09c6b18a8", + "value": "" + } + }, + "9df76c111df140c3b7dfadb724aae88f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "a03c2859b511486da91a88babee36f63": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ButtonStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ButtonStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "button_color": null, + "font_weight": "" + } + }, + "a05042c291aa4e1e9bb4f09be3d9f806": { + "model_module": 
"@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_640f95344552481aa56d9d17398a9b45", + "placeholder": "​", + "style": "IPY_MODEL_3fcaf29e93bc4f0ea5293b761fbe9194", + "value": "model.safetensors: 100%" + } + }, + "a4b4fdbd7d5648a1ab7850e2f81ead15": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "b8b4d2ec394547cc86dcfe1eddaf0402": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + 
"margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "c38ea535ff6e4d60a772700d1667c7c1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "e3d473ff87114b4aafb71269df884892": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ButtonStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ButtonStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "button_color": null, + "font_weight": "" + } + }, + "ea799d8361454d628d2dcdb09c6b18a8": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "initial" + } + }, + "f935b17bb46649b5a81733fb0e5529ab": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": 
"1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/notebooks/finetuning.ipynb b/notebooks/finetuning.ipynb index 7322c3d..8eb5603 100644 --- a/notebooks/finetuning.ipynb +++ b/notebooks/finetuning.ipynb @@ -338,9 +338,7 @@ "\n", "PyTorch Lightning has a built-in progress bar that is convenient for seeing training progress, but it does not display correctly in Jupyter when calling the script with `!python`. We are going to disable the progress bar for by setting the flag `--enable_progress_bar false`. Instead, we implemented a simple print statement to track training progress, which we will enable with the flag `--enable_simple_progress_messages`. \n", "\n", - "The [train_target_model.py](../code/train_target_model.py) script can support running on Apple Silicon with acceleration via MPS, but the version of PyTorch used in this environment is slightly outdated and does not support all MPS operations, so MPS support has been disabled. 
The script will run on GPU via CUDA if available, otherwise it will use CPUs.\n", - "\n", - "To speed up training for demo purposes, we also override `--max_epochs 50` and `--unfreeze_backbone_at_epoch 25`." + "The [train_target_model.py](../code/train_target_model.py) script can support running on Apple Silicon with acceleration via MPS, but the version of PyTorch used in this environment is slightly outdated and does not support all MPS operations, so MPS support has been disabled. The script will run on GPU via CUDA if available, otherwise it will use CPUs." ] }, {