Source code for Unsupervised Object Representation Learning using Translation and Rotation Group Equivariant VAE
Overall Framework
Translation and Rotation Group Equivariant Encoder
Spatially equivariant generator
TARGET-VAE identifies protein heterogeneity in the cryo-EM particle stack from the EMPIAR-10025 dataset.
Dependencies
- Python 3
- PyTorch >= 1.11
- torchvision >= 0.12
- numpy >= 1.21
- scikit-learn >= 1.0.2
- astropy >= 5.0.4
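The minimum versions above can be checked before training. The following is a minimal sketch, not part of the repository: the helper names (`version_tuple`, `meets_minimum`, `check_dependencies`) are hypothetical, and the keys are assumed to be the PyPI distribution names of the listed packages (for example, PyTorch is installed as `torch`).

```python
from importlib import metadata

# Minimum versions copied from the dependency list above.
# Keys are assumed PyPI distribution names (PyTorch is published as "torch").
MINIMUMS = {
    "torch": "1.11",
    "torchvision": "0.12",
    "numpy": "1.21",
    "scikit-learn": "1.0.2",
    "astropy": "5.0.4",
}


def version_tuple(text):
    """Turn '1.11.0+cu113' into (1, 11, 0), keeping only the leading digits of each part."""
    parts = []
    for piece in text.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(installed, minimum):
    """Compare two version strings numerically, padding the shorter one with zeros."""
    have, need = version_tuple(installed), version_tuple(minimum)
    width = max(len(have), len(need))
    have += (0,) * (width - len(have))
    need += (0,) * (width - len(need))
    return have >= need


def check_dependencies(minimums=MINIMUMS):
    """Return a dict mapping each package name to 'ok', 'missing', or a 'too old' message."""
    report = {}
    for name, minimum in minimums.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "missing"
            continue
        if meets_minimum(installed, minimum):
            report[name] = "ok"
        else:
            report[name] = f"too old ({installed} < {minimum})"
    return report


if __name__ == "__main__":
    for name, status in check_dependencies().items():
        print(f"{name}: {status}")
```

The comparison is deliberately simple and ignores pre-release suffixes; a full version parser would be a reasonable replacement if stricter checks are needed.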
The rotated and translated MNIST datasets can be downloaded from these links:
MNIST(N)
MNIST(U)
The scripts train_mnist.py, train_particles.py, train_dsprites.py, and train_galaxy.py train TARGET-VAE on MNIST (regular, MNIST(N), MNIST(U)), cryo-EM particle stacks, dSprites, and galaxy datasets, respectively. The scripts whose names start with clustering apply a trained model to cluster a specific dataset.
For example, to train TARGET-VAE with P8 group convolutions and z_dim=2 on the MNIST(U) dataset (mnist-U, described in the paper):
python train_mnist.py -z 2 --dataset mnist-U --t-inf attention --r-inf attention+offsets --groupconv 8 --fourier-expansion
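To use this trained model for clustering (the --path-to-encoder value is an example; point it at the inference.sav written by your own training run):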
python clustering_mnist.py -z 2 --dataset mnist-U --path-to-encoder training_logs/2022-06-08-18-53_mnist-N_zDim_2_translation_attention_rotation_attention+offsets_groupconv8/inference.sav --t-inf attention --r-inf attention+offsets
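The encoder path in the command above contains a timestamped run directory, so it changes with every training run. The sketch below locates the most recent inference.sav instead of typing the path by hand. It is not part of the repository: `latest_encoder` is a hypothetical helper, and it assumes that run directories live under training_logs/, start with a zero-padded timestamp, and contain the dataset name between underscores, as the example paths in this README suggest.

```python
from pathlib import Path


def latest_encoder(log_root="training_logs", dataset=None, filename="inference.sav"):
    """Return the newest <log_root>/<run>/<filename> as a Path, or None if none exists.

    Assumes run directory names begin with a zero-padded timestamp such as
    2022-06-08-18-53, so the lexicographically largest name is the newest run.
    """
    root = Path(log_root)
    if not root.is_dir():
        return None
    runs = []
    for run_dir in root.iterdir():
        if not run_dir.is_dir() or not (run_dir / filename).is_file():
            continue
        # Optional filter on the dataset tag embedded in the directory name.
        if dataset is not None and f"_{dataset}_" not in run_dir.name:
            continue
        runs.append(run_dir)
    if not runs:
        return None
    newest = max(runs, key=lambda d: d.name)
    return newest / filename


if __name__ == "__main__":
    # Prints the path to pass to --path-to-encoder, or None if nothing was found.
    print(latest_encoder(dataset="mnist-U"))
```

The printed path can then be passed to --path-to-encoder in the clustering command.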
To train TARGET-VAE with P8 group convolutions and z_dim=2 on the particle stack saved in the folder 'data/EMPIAR10025/mrcs/':
python train_particles.py --train-path data/EMPIAR10025/mrcs/ -z 2 --t-inf attention --r-inf attention+offsets --groupconv 8 --fourier-expansion
To use the trained model for clustering:
python clustering_particles.py -z 2 --test-path data/EMPIAR10025/mrcs/ --path-to-encoder 2022-05-27-14-45_10025_zDim_2_translation_attention_rotation_attention+offsets_groupconv8/inference.sav --t-inf attention --r-inf attention+offsets
This source code is provided under the MIT License.