Implementation of CalciumGAN, used to obtain the results detailed in CalciumGAN: A Generative Adversarial Network Model for Synthesising Realistic Calcium Imaging Data of Neuronal Populations.
Citing this work
@article{li2020calciumgan,
title={CalciumGAN: A Generative Adversarial Network Model for Synthesising Realistic Calcium Imaging Data of Neuronal Populations},
author={Li, Bryan M and Amvrosiadis, Theoklitos and Rochefort, Nathalie and Onken, Arno},
journal={arXiv preprint arXiv:2009.02707},
year={2020}
}
- 1. Installation
- 2. Dataset
- 3. Training model
- 4. Spike analysis
- 5. Visualization and Profiling
- 6. Hyper-parameter Search
1. Installation
- It is recommended to install the codebase in a virtual environment, such as conda.
- Create a new conda environment with Python 3.6:
    conda create -n calciumgan python=3.6
- Activate the calciumgan virtual environment:
    conda activate calciumgan
- Install all dependencies and packages with the setup.sh script, which works on both Linux and macOS:
    sh setup.sh
Install the following packages:
- TensorFlow
- j-friedrich/OASIS
- Neo
- Elephant
- packages listed in requirements.txt
- code from dg_python is also used for the dichotomized Gaussian model
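A quick way to confirm that the packages above are importable from the active environment is sketched below. The import names, in particular oasis for j-friedrich/OASIS, are assumptions; adjust them if your installation exposes different module names.

```python
import importlib.util

# Assumed import names for the packages listed above (oasis is assumed to be
# the module installed by j-friedrich/OASIS).
REQUIRED_MODULES = ["tensorflow", "oasis", "neo", "elephant"]


def missing_modules(modules):
    """Return the names in `modules` that cannot be found on the import path."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages are importable.")
```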
2. Dataset
Recorded data
- Navigate to the dataset directory:
    cd dataset
- Place all raw calcium imaging data under dataset/raw_data
- Apply OASIS to infer spike trains:
    python spike_train_inference.py --input_dir raw_data
- Generate TFRecords from a specific pickle file (--input): normalize the data, perform segmentation and store the TFRecords in --output_dir. Use --help to see all available arguments.
    python generate_tfrecords.py --input raw_data/signals.pkl --output_dir tfrecords/sl2048 --sequence_length 2048 --normalize
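The normalization and segmentation performed before writing the TFRecords can be sketched as follows. This is a minimal NumPy sketch, assuming the signals are a (neurons, frames) array, that --normalize means per-neuron min-max scaling to [0, 1], and that segments of --sequence_length frames do not overlap; the script's actual scheme may differ, and writing the TFRecords themselves is omitted.

```python
import numpy as np


def normalize(signals):
    """Min-max scale each neuron (row) of a (neurons, frames) array to [0, 1].

    Assumption: the repository's --normalize flag may use a different scheme.
    """
    signals = np.asarray(signals, dtype=np.float32)
    lo = signals.min(axis=1, keepdims=True)
    hi = signals.max(axis=1, keepdims=True)
    # Flat traces would give a zero range; divide by 1 instead to avoid NaNs.
    span = np.where(hi > lo, hi - lo, 1.0)
    return (signals - lo) / span


def segment(signals, sequence_length):
    """Split a (neurons, frames) array into non-overlapping windows.

    Returns an array of shape (num_segments, sequence_length, neurons);
    trailing frames that do not fill a whole window are dropped.
    """
    num_neurons, num_frames = signals.shape
    num_segments = num_frames // sequence_length
    trimmed = signals[:, : num_segments * sequence_length]
    windows = trimmed.reshape(num_neurons, num_segments, sequence_length)
    return windows.transpose(1, 2, 0)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    raw = rng.normal(size=(3, 5000))
    segments = segment(normalize(raw), sequence_length=2048)
    print(segments.shape)  # (2, 2048, 3)
```

Dropping the incomplete final window keeps every segment the same length, which a fixed-size model input requires.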
Dichotomized Gaussian artificial data
- Generate artificial spike trains and calcium-like signals from the dichotomized Gaussian distribution with the mean and covariance of the data in --input, and save the output pickle file to --output. Use --help to see all available arguments.
    python generate_dg_data.py --input raw_data/signals.pkl --output dg.pkl
- Generate TFRecords from the dichotomized Gaussian pickle file (--input): normalize the data, perform segmentation and store the TFRecords in --output_dir. Use --help to see all available arguments.
    python generate_tfrecords.py --input dg.pkl --output_dir tfrecords/sl2048_dg --sequence_length 2048 --normalize
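The dichotomized Gaussian step can be illustrated as follows: a latent Gaussian vector is thresholded at zero to give binary spikes, and each spike train is turned into a calcium-like trace with a first-order autoregressive kernel c[t] = g * c[t-1] + s[t], the same AR(1) model that OASIS can deconvolve. This is a sketch, not the code in generate_dg_data.py: the latent correlation matrix is taken as given rather than fitted to the data covariance, and the decay g and noise level are hypothetical. statistics.NormalDist requires Python 3.8 or later.

```python
from statistics import NormalDist

import numpy as np


def sample_dg_spikes(firing_prob, latent_corr, num_bins, rng):
    """Draw binary spike trains from a dichotomized Gaussian.

    firing_prob: (neurons,) probability of a spike per time bin.
    latent_corr: (neurons, neurons) correlation matrix of the latent Gaussian
        (assumed given; fitting it to a target covariance is omitted).
    Returns an int array of shape (num_bins, neurons).
    """
    # With unit variance, P(latent > 0) = Phi(gamma), so gamma = Phi^-1(p).
    gamma = np.array([NormalDist().inv_cdf(p) for p in firing_prob])
    latent = rng.multivariate_normal(gamma, latent_corr, size=num_bins)
    return (latent > 0).astype(np.int64)


def calcium_from_spikes(spikes, decay=0.95, noise_std=0.0, rng=None):
    """Convolve spikes with an AR(1) kernel: c[t] = decay * c[t-1] + s[t].

    spikes: (num_bins, neurons). Optional Gaussian noise is added afterwards.
    """
    spikes = np.asarray(spikes, dtype=np.float64)
    calcium = np.zeros_like(spikes)
    previous = np.zeros(spikes.shape[1])
    for t in range(spikes.shape[0]):
        previous = decay * previous + spikes[t]
        calcium[t] = previous
    if noise_std > 0:
        if rng is None:
            rng = np.random.default_rng()
        calcium = calcium + rng.normal(scale=noise_std, size=calcium.shape)
    return calcium


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    corr = np.array([[1.0, 0.3], [0.3, 1.0]])
    spikes = sample_dg_spikes([0.05, 0.1], corr, num_bins=10000, rng=rng)
    traces = calcium_from_spikes(spikes, decay=0.95, noise_std=0.05, rng=rng)
    print(spikes.mean(axis=0), traces.shape)
```

With enough bins, spikes.mean(axis=0) should approach the target firing probabilities, while the latent correlation induces correlated spiking between neurons.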
3. Training model
- Train CalciumGAN on the recorded calcium imaging data with the default hyper-parameters for 400 epochs. Checkpoints, generated data and model training information are stored in --output_dir.
    python main.py --input_dir dataset/tfrecords/sl2048 --output_dir runs/001 --epochs 400 --batch_size 128 --model calciumgan --algorithm wgan-gp --noise_dim 32 --num_units 64 --kernel_size 24 --strides 2 --m 10 --layer_norm --mixed_precision --save_generated last
- Use --help to check all available arguments. Mixed precision compute, TensorBoard profiling and hyper-parameter search are some of the features built into this codebase.
- The same training command applies to both the recorded data and the dichotomized Gaussian artificial data (e.g. with --input_dir dataset/tfrecords/sl2048_dg).
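The training command above selects wgan-gp. The WGAN-GP critic objective (Gulrajani et al., 2017) adds a penalty that pushes the norm of the critic's input gradient towards 1 on random interpolates between real and generated samples. The NumPy sketch below takes the critic and its input gradient as callables (in TensorFlow the gradient would typically come from tf.GradientTape); gp_weight=10 is the WGAN-GP paper default and is not claimed to correspond to any flag of main.py.

```python
import numpy as np


def gradient_penalty(critic_grad, real, fake, rng):
    """WGAN-GP penalty: mean of (||grad D(x_hat)||_2 - 1)^2 over interpolates.

    critic_grad: callable mapping a batch x_hat to dD/dx_hat (same shape).
    real, fake: arrays of shape (batch, ...).
    """
    batch = real.shape[0]
    # One interpolation coefficient per example, broadcast over other axes.
    eps = rng.uniform(size=(batch,) + (1,) * (real.ndim - 1))
    x_hat = eps * real + (1.0 - eps) * fake
    grads = critic_grad(x_hat).reshape(batch, -1)
    norms = np.linalg.norm(grads, axis=1)
    return np.mean((norms - 1.0) ** 2)


def critic_loss(critic, critic_grad, real, fake, rng, gp_weight=10.0):
    """Critic objective: E[D(fake)] - E[D(real)] + gp_weight * penalty."""
    wasserstein = np.mean(critic(fake)) - np.mean(critic(real))
    return wasserstein + gp_weight * gradient_penalty(critic_grad, real, fake, rng)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    real = rng.normal(size=(16, 8))
    fake = rng.normal(size=(16, 8))
    # Toy critic D(x) = 0.5 * ||x||^2, whose input gradient is x itself.
    critic = lambda x: 0.5 * np.sum(x.reshape(len(x), -1) ** 2, axis=1)
    critic_grad = lambda x: x
    print(critic_loss(critic, critic_grad, real, fake, rng))
```

Penalizing the gradient on interpolates, rather than clipping weights as in the original WGAN, is the design choice that distinguishes WGAN-GP.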
4. Spike analysis
Recorded data
- Deconvolve the calcium signals of the generated data in --output_dir into spike trains, then compute various spike statistics. Use --help to check all available arguments.
    python compute_metrics.py --output_dir runs/001
- All the plots can be found in runs/001/metrics/plots
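The exact list of "various spike statistics" is not given here. Two common ones, mean firing rate and pairwise correlation of binned spike counts, can be sketched in NumPy as follows; the frame_rate and bin_size parameters are hypothetical, and Elephant, listed in the dependencies, provides reference implementations of statistics of this kind.

```python
import numpy as np


def firing_rates(spikes, frame_rate):
    """Mean firing rate (spikes per second) of each neuron.

    spikes: (frames, neurons) array of inferred spike counts.
    frame_rate: imaging frame rate in Hz (hypothetical parameter).
    """
    spikes = np.asarray(spikes, dtype=np.float64)
    duration = spikes.shape[0] / frame_rate
    return spikes.sum(axis=0) / duration


def pairwise_correlations(spikes, bin_size):
    """Pearson correlations between neurons after summing counts into bins.

    Returns the upper-triangular entries, one value per neuron pair.
    """
    spikes = np.asarray(spikes, dtype=np.float64)
    num_bins = spikes.shape[0] // bin_size
    binned = spikes[: num_bins * bin_size].reshape(num_bins, bin_size, -1).sum(axis=1)
    corr = np.corrcoef(binned, rowvar=False)
    rows, cols = np.triu_indices(corr.shape[0], k=1)
    return corr[rows, cols]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    spikes = rng.poisson(0.1, size=(1000, 5))
    print(firing_rates(spikes, frame_rate=24.0))
    print(pairwise_correlations(spikes, bin_size=10))
```

Comparing the distributions of these values between recorded and generated data, rather than individual neurons, avoids assuming any correspondence between real and synthetic neurons.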
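Dichotomized Gaussian artificial data
- For a model trained on the dichotomized Gaussian data (runs/002 in this example), deconvolve the calcium signals of the generated data in --output_dir into spike trains, then compute various spike statistics. Use --help to check all available arguments.
    python compute_dg_metrics.py --output_dir runs/002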
- Plots of mean and covariance can be found in diagrams/
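Because the dichotomized Gaussian data are defined by their mean and covariance, a natural check is how closely the generated spike trains reproduce these two moments. A minimal sketch, assuming both inputs are (samples, neurons) arrays; the error measure (mean absolute difference) is an illustrative choice, not necessarily what compute_dg_metrics.py uses.

```python
import numpy as np


def moment_errors(reference, generated):
    """Compare per-neuron means and neuron-by-neuron covariances.

    reference, generated: (samples, neurons) arrays, e.g. dichotomized
    Gaussian spike trains and spike trains inferred from generated data.
    Returns (mean absolute error of means, mean absolute error of covariances).
    """
    reference = np.asarray(reference, dtype=np.float64)
    generated = np.asarray(generated, dtype=np.float64)
    mean_error = np.mean(np.abs(reference.mean(axis=0) - generated.mean(axis=0)))
    cov_ref = np.cov(reference, rowvar=False)
    cov_gen = np.cov(generated, rowvar=False)
    cov_error = np.mean(np.abs(cov_ref - cov_gen))
    return mean_error, cov_error


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    reference = rng.binomial(1, 0.1, size=(5000, 4))
    generated = rng.binomial(1, 0.12, size=(5000, 4))
    print(moment_errors(reference, generated))
```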
5. Visualization and Profiling
- Run TensorBoard:
    tensorboard --logdir runs/001
- Profiling is implemented with TensorFlow Profiler support. Enable it with the --profile flag when training the model with main.py.
6. Hyper-parameter Search
- We have incorporated the Hyperparameter Tuning with Keras feature. Modify the hyper-parameters you would like to test in search.py and run:
    python search.py --input_dir dataset/tfrecords/sl2048 --output_dir runs/hparams_search --epochs 400 --mixed_precision
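How search.py enumerates candidate settings is not described here. A plain grid over a hypothetical search space can be sketched as follows; the keys mirror flags of main.py used in the training command, but the values, and whether search.py uses a grid at all, are assumptions.

```python
import itertools

# Hypothetical search space: keys mirror main.py flags, values are assumed.
SEARCH_SPACE = {
    "num_units": [32, 64],
    "kernel_size": [12, 24],
    "noise_dim": [16, 32],
}


def grid(search_space):
    """Yield one dict per combination of hyper-parameter values."""
    keys = list(search_space)
    for values in itertools.product(*(search_space[k] for k in keys)):
        yield dict(zip(keys, values))


def to_flags(hparams):
    """Render a hyper-parameter dict as command-line arguments for main.py."""
    flags = []
    for name, value in hparams.items():
        flags.extend(["--" + name, str(value)])
    return flags


if __name__ == "__main__":
    for hparams in grid(SEARCH_SPACE):
        print(" ".join(["python", "main.py"] + to_flags(hparams)))
```

A grid is exhaustive but grows multiplicatively with each added hyper-parameter, so random sampling from the same space is a common alternative for larger searches.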