TARGET-VAE

Source code for "Unsupervised Object Representation Learning using Translation and Rotation Group Equivariant VAE" (NeurIPS 2022), implemented in PyTorch.
Overall framework (figure: TARGET-VAE framework)

Translation and rotation group equivariant encoder (figure: encoder of TARGET-VAE)

Spatially equivariant generator (figure: generator of TARGET-VAE)

TARGET-VAE identifies protein heterogeneity in the cryo-EM particle stack from the EMPIAR-10025 dataset (figure: protein heterogeneity identified by TARGET-VAE)

Setup

Dependencies

  • Python 3
  • PyTorch >= 1.11
  • torchvision >= 0.12
  • numpy >= 1.21
  • scikit-learn >= 1.0.2
  • astropy >= 5.0.4
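
Assuming a standard Python 3 environment, the dependencies above can typically be installed with pip (the exact PyTorch and torchvision builds depend on your CUDA setup; see pytorch.org for the matching install command):

pip install "torch>=1.11" "torchvision>=0.12" "numpy>=1.21" "scikit-learn>=1.0.2" "astropy>=5.0.4"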

Datasets

The rotated and translated MNIST datasets can be downloaded from the following links:
MNIST(N)
MNIST(U)

Usage

The scripts train_mnist.py, train_particles.py, train_dsprites.py, and train_galaxy.py train TARGET-VAE on MNIST (regular, MNIST(N), MNIST(U)), cryo-EM particle stacks, dSprites, and galaxy datasets, respectively. The scripts whose names start with clustering apply a trained model to cluster the corresponding dataset.

For example, to train TARGET-VAE with P8 group convolution and z_dim=2 on the MNIST(U) dataset (described in the paper):

python train_mnist.py -z 2 --dataset mnist-U --t-inf attention --r-inf attention+offsets --groupconv 8 --fourier-expansion

To use this trained model for clustering:

python clustering_mnist.py -z 2 --dataset mnist-U --path-to-encoder training_logs/2022-06-08-18-53_mnist-N_zDim_2_translation_attention_rotation_attention+offsets_groupconv8/inference.sav --t-inf attention --r-inf attention+offsets 
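
If you want to inspect the trained inference network outside of the clustering script, a minimal sketch along these lines may help. It assumes inference.sav is a full PyTorch module pickled with torch.save; if the training script instead writes a state_dict, instantiate the encoder first and use load_state_dict.

import torch

# Path produced by the training run above (adjust to your own log directory).
checkpoint_path = "training_logs/2022-06-08-18-53_mnist-N_zDim_2_translation_attention_rotation_attention+offsets_groupconv8/inference.sav"

# Assumption: the file is a pickled nn.Module saved with torch.save(model, path).
encoder = torch.load(checkpoint_path, map_location="cpu")
encoder.eval()
print(encoder)  # inspect the group-equivariant encoder architecture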



To train TARGET-VAE with P8 group convolution and z_dim=2 on the particle stack saved in the folder data/EMPIAR10025/mrcs/:

python train_particles.py --train-path data/EMPIAR10025/mrcs/ -z 2 --t-inf attention --r-inf attention+offsets --groupconv 8 --fourier-expansion

To use the trained model for clustering:

python clustering_particles.py -z 2 --test-path data/EMPIAR10025/mrcs/ --path-to-encoder 2022-05-27-14-45_10025_zDim_2_translation_attention_rotation_attention+offsets_groupconv8/inference.sav --t-inf attention --r-inf attention+offsets
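
Before training or clustering on a particle stack, it can be useful to sanity-check the input files. The sketch below uses the mrcfile package, which is not among the listed dependencies and is included purely for illustration (install it separately if you want to try this).

import glob

import mrcfile  # assumption: not a listed dependency of TARGET-VAE

# List the .mrcs stacks in the data folder and print the shape of each one.
for path in sorted(glob.glob("data/EMPIAR10025/mrcs/*.mrcs")):
    with mrcfile.open(path, permissive=True) as mrc:
        # For a particle stack, mrc.data is (n_particles, height, width).
        print(path, mrc.data.shape)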

License

This source code is provided under the MIT License
