This repository contains training, embedding and sampling code for the "Graph Infused Representation Assembled From a multi-Faceted variational auto-Encoder" (Giraffe) model used to create meaningful molecular representations for small molecules.
An overview of the Giraffe architecture and training logic is shown below:
NOTE: If the `--no-vae` or `--wae` flags are used during training, the encoder only outputs a single vector.
All packages needed to use GIRAFFE are provided in `environment.yml`. To create a new Anaconda environment from it, run the following:

```bash
conda env create -f environment.yml
```
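Afterwards, activate the new environment. The environment name used below (`giraffe`) is an assumption; use the `name` defined in `environment.yml`:

```bash
conda activate giraffe
```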
Training a new model on a file with SMILES strings can be achieved as follows:

```bash
python train.py data/100k.smi
```

The call above will train a VAE. To train a traditional autoencoder, use the `--no-vae` flag. It is also possible to train a Wasserstein autoencoder (WAE) that uses the maximum mean discrepancy in inverse multiquadratic kernel space to match the encoder distribution with a Gaussian; use the `--wae` flag to train an MMD WAE.
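For example, a WAE run with a custom batch size, learning rate, and number of epochs might look like the following; the values are illustrative, not recommended defaults:

```bash
python train.py data/100k.smi --wae -b 256 -l 5e-4 -e 100
```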
To get all the options, call `python train.py --help`:
```
Usage: train.py [OPTIONS] FILENAME

Options:
  -n, --run_name TEXT           Name of the run for saving (filename if
                                omitted).
  -d, --delimiter TEXT          Column delimiter of input file.
  -c, --smls_col TEXT           Name of column that contains SMILES.
  -e, --epochs INTEGER          Nr. of epochs to train.
  -o, --dropout FLOAT           Dropout fraction.
  -b, --batch_size INTEGER      Number of molecules per batch.
  -r, --random                  Randomly sample molecules in each training
                                step.
  -p, --props TEXT              Comma-separated list of descriptors to use.
                                All, if omitted
  --epoch_steps, --es INTEGER   If random, number of batches per epoch.
  -v, --val FLOAT               Fraction of the data to use for validation.
  -l, --lr FLOAT                Learning rate.
  --lr_fact, --lf FLOAT         Learning rate decay factor.
  --lr_step, --ls INTEGER       LR step decay after nr. of epochs.
  -a, --after INTEGER           Epoch steps to save model.
  -t, --temp FLOAT              Temperature to use during SMILES sampling.
  --n_sample, --ns INTEGER      Nr. SMILES to sample after each training
                                epoch.
  --kernels_gnn, --nk INTEGER   Nr. GNN kernels
  --layers_gnn, --ng INTEGER    Nr. GNN layers
  --layers_rnn, --nr INTEGER    Nr. RNN layers
  --layers_mlp, --nm INTEGER    Nr. MLP layers
  --dim_gnn, --dg INTEGER       Hidden dimension of GNN layers
  --dim_rnn, --dr INTEGER       Hidden dimension of RNN layers
  --dim_tok, --dt INTEGER       Dimension of RNN token embedding
  --dim_mlp, --dm INTEGER       Hidden dimension of MLP layers
  --weight_prop, --wp FLOAT     Factor for weighting property loss in VAE
                                loss
  --weight_vae, --wk FLOAT      Factor for weighting KL divergence loss in
                                VAE loss
  --anneal_type, --at TEXT      Shape of cyclical annealing: linear or
                                sigmoid
  --anneal_cycle, --ac INTEGER  Number of epochs for one VAE loss annealing
                                cycle
  --anneal_grow, --ag INTEGER   Number of annealing cycles with increasing
                                values
  --anneal_ratio, --ar FLOAT    Fraction of annealing vs. constant VAE loss
                                weight
  --vae / --no-vae              Whether to train a variational AE or
                                classical AE
  --wae / --no-wae              Whether to train a Wasserstein autoencoder
                                using MMD
  --scale / --no-scale          Whether to scale all properties from 0 to 1
  --n_proc, --np INTEGER        Number of CPU processes to use
  --help                        Show this message and exit.
```
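A run that restricts property prediction to a few descriptors and resizes the GNN could look like the following; the descriptor names are assumed to match RDKit descriptor names, and all values are illustrative:

```bash
python train.py data/100k.smi -p "MolWt,MolLogP,TPSA" --dg 256 --ng 4
```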
After training, a config file containing all options used will be saved in the checkpoints folder. This file is used for later sampling and embedding tasks.
If the input file only contains SMILES strings (single column, with or without header), Giraffe uses all calculable RDKit properties (scaled from 0 to 1). If the input file contains other numerical columns next to the SMILES, it will use those values as properties. The user has to ensure the properties are scaled to a reasonable range.
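For example, if your data lives in a comma-separated file with a header such as `smiles,logp,tpsa` (column and file names here are illustrative), you can point the training script at the SMILES column explicitly:

```bash
python train.py data/my_props.csv -d "," -c smiles
```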
To randomly sample up to 100 SMILES strings of maximum length 96 at temperature 0.6 from a trained model checkpoint (in this case epoch 70 of the model `pub_vae_sig`), run the following:

```bash
python sampling.py -r -e 70 -t 0.6 -l 96 -n 100 -c models/pub_vae_sig
```
If you omit the `-r` flag, instead of selecting random points in latent space, the model will select 100 random SMILES from the training data and sample SMILES based on those seeds.
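In other words, the same call as above without `-r` performs seed-based sampling from the training data:

```bash
python sampling.py -e 70 -t 0.6 -l 96 -n 100 -c models/pub_vae_sig
```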
Conditional sampling around a single SMILES string of interest using epoch 70 of the pretrained model `pub_vae_sig`:

```bash
python sampling.py -e 70 -t 0.6 -l 96 -n 100 -s "CC1(CC(CC(N1)(C)C)OC2=NN=C(C=C2)C3=C(C=C(C=C3)C4=CNN=C4)O)C" -c models/pub_vae_sig
```
The sampled SMILES strings are stored in `output/sampled.csv` together with the negative log-likelihood score.

To get all available options, call `python sampling.py --help`:
```
Usage: sampling.py [OPTIONS]

Options:
  -c, --checkpoint TEXT  Checkpoint folder.
  -e, --epoch INTEGER    Epoch of models to load.
  -s, --smiles TEXT      Reference SMILES to use as seed for sampling.
  -n, --num INTEGER      How many molecules to sample.
  -t, --temp FLOAT       Temperature to transform logits before multinomial
                         sampling.
  -l, --maxlen INTEGER   Maximum allowed SMILES string length.
  -o, --out TEXT         Output filename.
  -i, --interpolate      Linear interpolation between 2 SMILES (',' separated
                         in -s).
  -r, --random           Randomly sample from latent space.
  -p, --parent           Store parent seed molecule in output file.
  --help                 Show this message and exit.
```
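To use the interpolation option, pass two comma-separated SMILES via `-s`; the two molecules below are illustrative placeholders:

```bash
python sampling.py -i -e 70 -t 0.6 -l 96 -n 100 -s "CCO,c1ccccc1" -c models/pub_vae_sig
```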
To embed SMILES strings using the pretrained GNN, proceed as follows:

```bash
python embedding.py -f models/pub_vae_sig -e 70 data/1k.txt output/test/embeddings.csv
```

To get all available options, call `python embedding.py --help`:
```
Usage: embedding.py [OPTIONS] INPUT_FILE OUTPUT_FILE

Options:
  -d, --delimiter TEXT      Column delimiter of input file.
  -c, --smls_col TEXT       Name of column that contains SMILES.
  -i, --id_col TEXT         Name of column that contains compound IDs.
  -f, --folder TEXT         Checkpoint folder to load models from.
  -e, --epoch INTEGER       Epoch of models to load.
  -b, --batch_size INTEGER  Batch size to use for embedding.
  -n, --n_mols INTEGER      Number of molecules to randomly sub-sample.
                            Default: 0 = all
  -j, --n_jobs INTEGER      Number of cores to use for data loader.
  --help                    Show this message and exit.
```
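For a delimited input file with named SMILES and ID columns, the column options can be set explicitly; the file and column names below are illustrative:

```bash
python embedding.py -f models/pub_vae_sig -e 70 -d "," -c smiles -i compound_id -b 256 data/compounds.csv output/compounds_embeddings.csv
```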
Finetuning a trained model on another set of SMILES strings can be achieved as follows:

```bash
python finetune.py data/actives.smi
```

To get all available options, call `python finetune.py --help`:
```
Usage: finetune.py [OPTIONS] FILENAME

Options:
  -c, --checkpoint TEXT        Checkpoint folder.
  -e, --epoch_load INTEGER     Epoch of models to load.
  -n, --run_name TEXT          Name of the run for saving (filename if
                               omitted).
  -d, --delimiter TEXT         Column delimiter of input file.
  -sc, --smls_col TEXT         Name of column that contains SMILES.
  -ne, --epochs INTEGER        Nr. of epochs to train.
  -o, --dropout FLOAT          Dropout fraction.
  -b, --batch_size INTEGER     Number of molecules per batch.
  -r, --random                 Randomly sample molecules in each training
                               step.
  -es, --epoch_steps INTEGER   If random, number of batches per epoch.
  -v, --val FLOAT              Fraction of the data to use for validation.
  -l, --lr FLOAT               Learning rate.
  -lf, --lr_fact FLOAT         Learning rate decay factor.
  -ls, --lr_step INTEGER       LR step decay after nr. of epochs.
  -a, --after INTEGER          Epoch steps to save model.
  -t, --temp FLOAT             Temperature to use during SMILES sampling.
  -ns, --n_sample INTEGER      Nr. SMILES to sample after each training epoch.
  -wp, --weight_prop FLOAT     Factor for weighting property loss in VAE loss
  -wk, --weight_vae FLOAT      Factor for weighting VAE loss
  -ac, --anneal_cycle INTEGER  Number of epochs for one VAE loss annealing
                               cycle
  -ag, --anneal_grow INTEGER   Number of annealing cycles with increasing
                               values
  -ar, --anneal_ratio FLOAT    Fraction of annealing vs. constant VAE loss
                               weight
  --vae / --no-vae             Whether to train a VAE or only AE
  --scale / --no-scale        Whether to scale all properties from 0 to 1
  -p, --n_proc INTEGER         Number of CPU processes to use
  --help                       Show this message and exit.
```
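A typical finetuning call loads an existing checkpoint and epoch explicitly. The checkpoint path and epoch below mirror the pretrained model used earlier; the remaining values are illustrative:

```bash
python finetune.py data/actives.smi -c models/pub_vae_sig -e 70 -ne 20 -l 1e-4
```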
In the VAE setup, we are employing a growing cyclical annealing schedule. Here's an example of how the schedule looks for the two best-performing cyclical annealing strategies for β values during training. Top (red): linear increase over 4 cycles with a cycle size of 10’000 steps, consisting of 7’500 increasing and 2’500 constant steps. Bottom (blue): sigmoidal increase over 20 cycles with a cycle size of 5’000 steps, consisting of 3’750 increasing and 1’250 constant steps. Both strategies were allowed to reach a maximum β value of 0.2 and performed best in the tested benchmarks at the step indicated by a dashed line.
Adapted from https://github.com/haofuml/cyclical_annealing
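The linear schedule shown in the figure roughly maps to the annealing options of `train.py`. Note that `--anneal_cycle` is specified in epochs rather than steps, so the value depends on your dataset and batch size; the call below is an illustrative assumption, not a recommended setting:

```bash
python train.py data/100k.smi --at linear --ag 4 --ar 0.75 --ac 10 --wk 0.2
```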
To benchmark the obtained representation, use `benchmark.py`. It relies on the Chembench repository and, optionally, on the CDDD repository. Please follow the installation instructions described in their READMEs.
GIRAFFE also contains a script for benchmark datasets hosted on the Polaris Hub. First, log in to the Polaris Hub by running `polaris login`. Then adapt the benchmark datasets in `examples/benchmark_polaris.py` and finally run the script with your desired model checkpoint:

```bash
python examples/benchmark_polaris.py -m models/pub_vae_sig/atfp_70.pt <polaris username>
```
The `examples` folder contains various scripts with examples of how to use the trained GIRAFFE models. Some examples reproduce figures presented in the corresponding publication.
If you'd like to contribute to GIRAFFE, have a look at `CONTRIBUTING.md`.
Please check the `LICENSE` file for more information.
If you are using GIRAFFE in your research, please cite the following publication:
```bibtex
@article{mueller2024giraffe,
  title={Combining Graph Attention and Recurrent Neural Networks in a Variational Autoencoder for Molecular Representation Learning and Drug Design},
  author={M{\"u}ller, Alex T and Atz, Kenneth and Reutlinger, Michael and Zorn, Nicolas},
  journal={International Conference on Machine Learning, ML4LMS Workshop},
  year={2024},
  url={https://openreview.net/forum?id=7WYcOGds6R}
}
```
A PDF of the publication is provided here.
To reproduce what is presented in the above publication, use the commit with tag `v1.0.0`.