This repository contains an implementation of a Physics-Informed Neural Network (PINN) for solving systems of misspecified ODEs at multiple parameter settings. The framework is built for approximating solutions of models of hERG ion channel kinetics that do not fully represent the experimental data.
We use two toy problems to assess the performance of the PINN in fitting a simpler model to synthetic data generated by a more detailed model. The first generative model for the toy problem (M1) is described in Kemp et al. (https://doi.org/10.1085/jgp.202112923). The second generative model for the toy problem (M2) is described in Wang et al. (https://doi.org/10.1111/j.1469-7793.1997.045bl.x). The current produced by a generative (true) model is used to optimise the parameters of a model with an incorrect structure, in our case the Hodgkin-Huxley (HH) model. The core idea is to train a PINN on the structure of the misspecified model (HH) and the current generated by either M1 or M2, so that it approximates a family of ODE solutions obtained for different parameter settings. The trained PINN can then be used in ODE parameter optimisation in place of widely used iterative solvers.
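The kind of objective described above can be illustrated with a minimal, dependency-free sketch. It is not the repository's implementation: the single-gate model, the exponential rate form, the parameter values, and the names `gate_rhs` and `pinn_style_loss` are all assumptions made for illustration. In an actual PINN the time derivative would come from automatic differentiation of the network output; here it is passed in explicitly so the example runs with the standard library only.

```python
import math

def gate_rhs(a, v, p):
    """Assumed HH-type gate dynamics: da/dt = k_open * (1 - a) - k_close * a."""
    k_open = p[0] * math.exp(p[1] * v)     # hypothetical voltage-dependent opening rate
    k_close = p[2] * math.exp(-p[3] * v)   # hypothetical voltage-dependent closing rate
    return k_open * (1.0 - a) - k_close * a

def pinn_style_loss(a, da_dt, v, i_obs, p, g=0.1, e_rev=-85.0, weight_ode=1.0):
    """Mean squared data misfit plus weighted mean squared ODE residual."""
    n = len(a)
    # Data term: current predicted from the candidate state vs. observed current.
    data = sum((g * ai * (vi - e_rev) - io) ** 2
               for ai, vi, io in zip(a, v, i_obs)) / n
    # Physics term: how far the candidate state is from satisfying the ODE.
    ode = sum((di - gate_rhs(ai, vi, p)) ** 2
              for ai, di, vi in zip(a, da_dt, v)) / n
    return data + weight_ode * ode

# Sanity check at a constant holding voltage, where the gate ODE has a closed form.
p = (2e-4, 0.07, 3e-5, 0.05)   # illustrative values only
v_hold = -20.0
k_open = p[0] * math.exp(p[1] * v_hold)
k_close = p[2] * math.exp(-p[3] * v_hold)
tau = 1.0 / (k_open + k_close)
a_inf = k_open * tau

times = [100.0 * k for k in range(50)]
a_exact = [a_inf * (1.0 - math.exp(-t / tau)) for t in times]    # solution with a(0) = 0
da_exact = [a_inf / tau * math.exp(-t / tau) for t in times]     # its time derivative
v = [v_hold] * len(times)
# The "observed" current is generated by the same gate model here, purely so
# that the loss of the exact solution should vanish; it is not misspecified data.
i_obs = [0.1 * a * (v_hold + 85.0) for a in a_exact]

loss_exact = pinn_style_loss(a_exact, da_exact, v, i_obs, p)
a_shifted = [a + 0.05 for a in a_exact]                          # deliberately wrong state
loss_perturbed = pinn_style_loss(a_shifted, da_exact, v, i_obs, p)
```

With misspecified data (for example, a current generated by M1 or M2) the two terms generally cannot both be driven to zero, and `weight_ode` controls the trade-off between matching the data and satisfying the assumed model structure.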
To install and run the ionch_pinn project, follow these steps:
Clone the repository:
git clone https://github.com/anastasia-rk/ionch_pinn.git
cd ionch_pinn
Create a virtual environment (optional but recommended):
python3 -m venv .venv
source .venv/bin/activate  # On Windows use .venv\Scripts\activate
Install dependencies:
pip install -r requirements.txt
Run the training script for the HH-generated data (used for benchmarking):
nohup python train_pinn_on_hh_data.py > output_for_hh_data.txt &
Or run the training script for the M1 generated data:
nohup python train_pinn_on_kemp_data.py > output_for_kemp_data.txt &
- The deterministic algorithm requirement cannot be satisfied for the ODE gradient computation on CUDA devices (it includes running the backward step).
- No arguments are currently passed to the script from the command line; unless the IsATest flag is toggled to True in the script, it will run the full training loop.
- The code is currently 80% functional programming with a simple fully connected network class; a comprehensive class with proper methods is needed to utilise multiprocessing.
- M2 generated data has not been used for training yet.
- Too many figures are currently stored for monitoring performance throughout the training process.
- Annoying warning about the log scale of the axes in the costs figure.
- The ODE parameter optimisation is missing - can be brought over from the ionch repo.
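One possible way to address the command-line limitation listed above is to read the test switch with `argparse` from the standard library. This is a sketch only: the `--test` flag and the `parse_args` helper are hypothetical and not part of the current scripts; only the `IsATest` name comes from the list above.

```python
import argparse

def parse_args(argv=None):
    """Parse a hypothetical --test switch; argv=None reads sys.argv[1:]."""
    parser = argparse.ArgumentParser(
        description="Train the PINN (illustrative CLI sketch)."
    )
    parser.add_argument(
        "--test",
        action="store_true",
        help="run a short test instead of the full training loop",
    )
    return parser.parse_args(argv)

# In a training script this would be parse_args(); an explicit empty list is
# used here so the snippet behaves the same wherever it is run.
args = parse_args([])
IsATest = args.test  # False by default, so the full training loop would run
```

With such a flag in place, a short run could be requested as `python train_pinn_on_hh_data.py --test`, while omitting the flag would keep the current default behaviour.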