- Training ECG-Byte
- Training LLM
- Inferencing LLM
- Analysis of Methods
- Training 1st Stage
- Training 2nd Stage
- Inferencing 2nd Stage
This repository is the official implementation of [ECG-Byte: A Tokenizer for End-to-End Generative Electrocardiogram Language Modeling](https://arxiv.org/abs/2412.14373) by William Jongwon Han, Chaojing Duan, Michael A. Rosenberg, Emerson Liu, and Ding Zhao.
Please read the documentation below carefully before running the pipeline. If you have any questions or find bugs, please do not hesitate to reach out to wjhan{@}andrew{dot}cmu{edu} or submit an issue with the corresponding details.
All installations and experiments were completed on Ubuntu 20.04.5 LTS with NVIDIA A6000 GPUs.
We want to note that the codebase is quite messy; we plan to clean it up and add more modularity for better readability, structure, and modifiability. For now, please read the following description of the main files and folders in the codebase.
- `data_loader.py` - Dataloader classes for training the end-to-end LLM and the 2-stage methods.
- `main.py` - Main end-to-end training pipeline.
- `train_tokenizer.py` - File for training ECG-Byte.
- `pretrain.py` - Pretraining pipeline for 1st-stage training.
- `finetune.py` - Finetuning pipeline for 2nd-stage training.
- `interp_analysis.py` - Pipeline for the attention visualizations, run after training the end-to-end pipeline.
- `scheduler.py` - Our custom scheduler for training both the end-to-end and the 2-stage methods.
We provide the functions utilized throughout the codebase in the following folders.
- `analysis` - Mapping between ECGs and tokens, and plotting of token distributions/usage.
- `models` - All models implemented in the study.
- `preprocess` - The main preprocessing pipeline.
- `runners` - Our training, inferencing, and interpretability runners.
- `rust_bpe` - The main ECG-Byte code.
- `scripts` - Bash scripts to execute all preprocessing steps and main experiments.
- `utils` - All of the helper functions used throughout the codebase.
- To install Rust: `curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain=1.79.0 -y`
- Open a new terminal to set the PATH for the Rust installation.
- After opening a new terminal, check the installation by running `rustc --version`.
- Create the conda virtual environment via `conda create -n ecg-byte python=3.10.15`.
- Activate the environment: `conda activate ecg-byte`.
- Install PyTorch: `pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu118`
- `cd` into the `transformers` directory and run `pip install -e .`.
- Now `cd ../` and run `pip install -e .`.
- Run `ecg_byte/test/test_gpu.py` to ensure you are able to use your GPU.
- Run `ecg_byte/test/test_transformers.py` to ensure you properly installed the `transformers` package.
- `cd` into `ecg_byte/ecg_byte/rust_bpe` and execute `maturin develop --release` to compile the tokenizer.
- Another consideration is that we use gated models (e.g., Llama 3.2, Gemma) from HuggingFace, so you will need to get an API key and log in via `huggingface-cli login` in the terminal. We also require you to log in inside the main training `*.py` file via the login function `from huggingface_hub import login`.
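For reference, a minimal sketch of that in-script login (the token string is a placeholder; substitute your own key and never commit it):

```python
# Minimal sketch: authenticate with the HuggingFace Hub from inside a script.
from huggingface_hub import login

login(token="hf_...")  # placeholder token; use your personal API key
```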
NOTE: From now on, all instructions assume you are working from the `ecg_byte/ecg_byte` directory.
As described in the paper, our experiments only train on 25% of each respective dataset. However, if you have enough compute, feel free to modify the pipeline to run on the full data!
- Please download the PTB-XL dataset through this link.
- Create a `data` folder, unzip the zip file inside the `data` folder, and rename the unzipped folder to `ptb`.
- Please download the MIMIC-IV ECG dataset through this link.
- Unzip the zip file inside the `data` directory and rename the unzipped directory to `mimic`.
- To download the ECG-QA dataset, execute the following command in the `data` folder: `git clone https://github.com/Jwoo5/ecg-qa.git`
- We exactly follow the instructions in this section of the ECG-QA repository for mapping the PTB-XL and MIMIC-IV ECG datasets to the question-answer pairs, so we encourage you to do the same. If you run into any trouble, please do not hesitate to submit an issue in this repository or theirs.
- After mapping the datasets, you should have an output folder in the `data/ecg-qa` folder with the mapped `paraphrased` and `template` question-answer pairs.
Pretrain MIMIC dataset curated by ECG-Chat (Zhao et al.):

- Next, create a `data/ecg_chat_data` directory and download the `pretrain_mimic.json` file from this Dropbox link.
Once you are finished with these steps, it's time to preprocess the data!
- Execute the preprocessing script via `bash scripts/preprocess.sh`. We provide default configurations for all the datasets used in our study, but feel free to experiment with others!
- After preprocessing, you will see folders with the structure `data/{data_name}_{seg_len}`, where `data_name` is the dataset name and `seg_len` is the segment length of each ECG (e.g., 500, 1250, 2500); a small illustrative sketch of this segmentation appears below. For sampling and training the tokenizer, we utilize data from MIMIC with `seg_len` = 2500, i.e., the full 10 seconds.
- Due to the enormity of the data, in our paper we sample a subset of the 2500-`seg_len` ECGs to train the tokenizer. To do this, execute `bash scripts/sample_ecg.sh`.
After sampling, you should see a `.txt` file appear in the `data` directory, which represents the sampled ECGs.
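To make the `{data_name}_{seg_len}` naming concrete, here is a small illustrative sketch of slicing a recording into fixed-length segments (the `(leads, samples)` array shape and the values are assumptions for illustration; the real logic lives in the `preprocess` folder):

```python
# Illustrative only: slice a 12-lead recording into non-overlapping
# segments of seg_len samples, mirroring data/{data_name}_{seg_len}.
import numpy as np

ecg = np.random.randn(12, 5000)  # hypothetical (leads, samples) recording
seg_len = 2500

segments = [
    ecg[:, start:start + seg_len]
    for start in range(0, ecg.shape[1] - seg_len + 1, seg_len)
]
print(len(segments), segments[0].shape)  # -> 2 (12, 2500)
```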
- Once you have sampled the ECGs, simply run `bash scripts/train_tok.sh` to train the tokenizer. We also provide a script to load your trained tokenizer and inspect the encoding compression rate and the original vs. decoded signal. Lastly, we provide basic configurations, but please feel free to modify these.
NOTE: We also provide a trained tokenizer at this link. Please feel free to use this or train your own!
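For intuition only, the sketch below mimics a byte-pair-encoding merge loop on a toy symbol sequence and reports a compression rate; the actual ECG-Byte tokenizer is the Rust implementation in `rust_bpe`, and the symbol alphabet here is invented:

```python
# Toy BPE sketch: repeatedly merge the most frequent adjacent symbol pair.
# The real ECG-Byte tokenizer lives in rust_bpe; this is for intuition only.
from collections import Counter

def train_bpe(seq, num_merges):
    merges = []
    for _ in range(num_merges):
        pairs = Counter(zip(seq, seq[1:]))
        if not pairs:
            break
        pair = pairs.most_common(1)[0][0]
        merges.append(pair)
        merged, i = [], 0
        while i < len(seq):
            if i + 1 < len(seq) and (seq[i], seq[i + 1]) == pair:
                merged.append(seq[i] + seq[i + 1])  # fuse the pair
                i += 2
            else:
                merged.append(seq[i])
                i += 1
        seq = merged
    return seq, merges

raw = list("aabbaabbccddaabb")  # stand-in for a quantized ECG string
tokens, merges = train_bpe(raw, num_merges=5)
print(f"compression rate: {len(raw) / len(tokens):.2f}x")
```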
- We provide training scripts in `scripts/train_model.sh` for both the distributed and single-GPU settings. Please use whichever fits your setup. To train the model, simply run `bash scripts/train_model.sh` after defining the correct paths. We provide an example of how it should look below and in the `.sh` file.
```bash
python main.py \
--model=meta-llama/Llama-3.2-1B \
--tokenizer_check=$TOKENIZER_NAME \
--device=cuda:0 \
--batch_size=2 \
--pad_to_max=1020 \
--peft \
--num_merges=3500 \
--epochs=1 \
--percentiles=$PATH_TO_PERCENTILES_FROM_PREPROCESSING \
--dataset=$DATASET_FOLDER_NAME \
--gpus=0,1,2,3 \
--dis \
--ports=12359
```
If you do not have multiple GPUs, simply modify the script like so:
```bash
python main.py \
--model=meta-llama/Llama-3.2-1B \
--tokenizer_check=$TOKENIZER_NAME \
--device=cuda:0 \
--batch_size=2 \
--pad_to_max=1020 \
--peft \
--num_merges=3500 \
--epochs=1 \
--percentiles=$PATH_TO_PERCENTILES_FROM_PREPROCESSING \
--dataset=$DATASET_FOLDER_NAME
```
NOTE: With our LoRA configurations, a sequence length of 1024, and a batch size of 2, training uses ~14 GB of GPU memory. We also provide finetuned checkpoints of the main model (Llama 3.2 1B) for ECG-QA PTB-XL, ECG-QA MIMIC-IV, and Pretrain MIMIC at this link.
- We provide inference scripts in `scripts/inference.sh` for generating answers. We have set initial configurations, but feel free to modify them. To run inference, simply run `bash scripts/inference.sh`. We provide an example of how the script should look below and in the `.sh` file.
```bash
python main.py \
--model=meta-llama/Llama-3.2-1B \
--tokenizer_check=$TOKENIZER_NAME \
--batch_size=1 \
--pad_to_max=1020 \
--device=cuda:4 \
--peft \
--num_merges=3500 \
--percentiles=$PATH_TO_PERCENTILES_FROM_PREPROCESSING \
--checkpoint=$PATH_TO_CHECKPOINT_FOLDER \
--inference \
--dataset=$DATASET_FOLDER_NAME
```
NOTE: With our LoRA configurations and a batch size of 1, inference uses ~5 GB of GPU memory.
- We provide a number of analysis scripts used in our paper, namely the attention visualizations in subsection 5.5 and the ECG-Byte analysis in subsection 5.4.
- To visualize the attentions and obtain the overlayed attention images, use the `scripts/interpret.sh` script; a minimal sketch of such an overlay follows this list.
- To visualize how ECG-Byte merges the signal, use the `scripts/track_encoding.sh` script.
- To visualize the token usage and distribution of ECG-Byte, use the `scripts/token_dist.sh` script.
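As a companion to `scripts/interpret.sh`, here is a minimal illustrative sketch of overlaying sample-level attention on an ECG trace (the file names, shapes, and the existence of a per-sample attention array are assumptions, not the repository's API):

```python
# Illustrative only: shade an ECG trace by attention strength.
import numpy as np
import matplotlib.pyplot as plt

ecg = np.load("lead_signal.npy")       # hypothetical 1-D signal, shape (T,)
attn = np.load("attn_per_sample.npy")  # hypothetical attention, shape (T,)
attn = (attn - attn.min()) / (attn.max() - attn.min() + 1e-8)

fig, ax = plt.subplots(figsize=(12, 3))
ax.plot(ecg, color="black", linewidth=0.8)
ymin, ymax = ax.get_ylim()
# Paint attention as a translucent heat strip behind the trace.
ax.imshow(attn[np.newaxis, :], aspect="auto", cmap="Reds", alpha=0.4,
          extent=(0, len(ecg), ymin, ymax), zorder=0)
ax.set_xlim(0, len(ecg))
ax.set_ylim(ymin, ymax)
ax.set_title("Attention overlayed on ECG (illustrative)")
fig.savefig("attention_overlay.png", dpi=200)
```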
- The 1st-stage training script is provided in `scripts/pretrain.sh`, where we provide the default configurations for training all of the implementations. Please feel free to follow the paper's configurations or use your own! Below is an example of how the script should look.
```bash
python pretrain.py \
--model=resnet \
--batch_size=64 \
--device=cuda:2 \
--peft \
--epochs=20 \
--seg \
--dataset=$DATASET_FOLDER_NAME \
--gpus=0,1,2,3 \
--dis \
--log
```
NOTE: We also provide pretrained checkpoints of all the 2-stage methods we implemented at this link.
- We also provide the script to train the 2nd stage in `scripts/finetune.sh`, where we provide the default configurations for training all of the said implementations. Again, please feel free to follow the paper's configurations or use your own! Below is an example of how the script should look.
```bash
python finetune.py \
--model=resnet_model \
--batch_size=2 \
--pad_to_max=1022 \
--peft \
--epochs=1 \
--seg \
--dataset=ptb_qa \
--first_check=$PATH_TO_1st_STAGE_CHECKPOINT_FOLDER \
--gpus=0,1,2,3 \
--dis \
--ports=12359
```
- We provide scripts to run inference with the 2-stage methods after completing the 1st and 2nd stages. Please use the `scripts/inference.sh` script for inference.
```bash
python finetune.py \
--model=resnet_model \
--batch_size=1 \
--pad_to_max=1022 \
--device=cuda:4 \
--peft \
--num_merges=3500 \
--dataset=ptb_qa \
--inference \
--first_check=$PATH_TO_1st_STAGE_CHECKPOINT_FOLDER \
--checkpoint=$PATH_TO_CHECKPOINT_FOLDER
```
We encountered some issues during the development of ECG-Byte and hope to contribute to the open-source community by reporting them here, along with tips where possible. If you happen to know a good solution to any of them, please do not hesitate to open an issue or pull request!
- `tqdm` bar freezing with multiprocessing: We noticed that the tqdm bar sometimes freezes when placed inside a multiprocessing job (especially during preprocessing). We recommend adding print statements before and after the main operations inside the tqdm loop to ensure the operations are being executed. There is a thread about this issue in the tqdm repository; please feel free to look at it.
- Utilizing `inputs_embeds` for generation: We noticed that using `inputs_embeds` as the primary input to the model for generation is quite unstable (e.g., example1 from HF, example2 from Stack Overflow, example3 from vLLM but related, example4 from HF). When we tried generating with only `inputs_embeds`, the model failed to generate anything coherent (i.e., mostly empty strings). Our current workaround is passing in both `input_ids` and `inputs_embeds` as inputs for generation; a minimal sketch follows this list. The reasoning behind this comes from the GenerationMixin code and this thread. From the code, it seems the model creates an empty `input_ids` tensor of shape (batch_size, 0) and uses the embeddings only for the first forward pass. However, this can be unstable because there is no explicit token mapping for the embeddings, making it harder for the model to maintain coherence between the embedded representation and subsequent token generation. The proper solution would be to create better `inputs_embeds` from the get-go. However, we wanted to add some guidance to the generation, so we provide embeddings for the initial forward pass while also passing `input_ids` that explicitly map to those embeddings, giving generation a more stable foundation. This is not "true" generation using only `inputs_embeds`, so we believe this reinforces our method of representing ECGs even more.
- HuggingFace API key not being recognized: We also noticed that the main training script sometimes crashes because the HuggingFace API key is not recognized. The current workaround is simply to log in again with your personal API key.
- NaN values during preprocessing: We noticed that the MIMIC-IV ECG dataset contains many NaN values, so we work around this by skipping those recordings during preprocessing.
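To make the `inputs_embeds` workaround above concrete, here is a minimal sketch (the model name and prompt are illustrative, exact behavior depends on your `transformers` version, and in our pipeline the embeddings come from the ECG-Byte and text tokens):

```python
# Sketch of the workaround: pass input_ids together with matching
# inputs_embeds so generation has an explicit token mapping.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "meta-llama/Llama-3.2-1B"  # illustrative; any causal LM works
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

prompt_ids = tokenizer("example prompt", return_tensors="pt").input_ids
# Embeddings that explicitly map to prompt_ids; in ECG-Byte these would be
# the embedded ECG + text tokens rather than plain text embeddings.
prompt_embeds = model.get_input_embeddings()(prompt_ids)

with torch.no_grad():
    output = model.generate(
        input_ids=prompt_ids,
        inputs_embeds=prompt_embeds,
        max_new_tokens=32,
    )
print(tokenizer.decode(output[0], skip_special_tokens=True))
```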
This work is done in collaboration with the Mario Lemieux Center for Heart Rhythm Care at Allegheny General Hospital. We thank Wenhao Ding, Haohong Lin, Shiqi Liu, and Hyoeun Kang for the valuable discussions.
We thank the authors of MERL for their ResNet code and the authors of ECG-QA and ECG-Chat for their publicly released datasets.
Lastly, we thank HuggingFace for providing the APIs for the models.
If this work has helped you, please cite the following:

```bibtex
@misc{han2024ecgbytetokenizerendtoendgenerative,
      title={ECG-Byte: A Tokenizer for End-to-End Generative Electrocardiogram Language Modeling},
      author={William Han and Chaojing Duan and Michael A. Rosenberg and Emerson Liu and Ding Zhao},
      year={2024},
      eprint={2412.14373},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2412.14373},
}
```