-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Jonas Miesenboeck
committed
May 23, 2024
1 parent
3b515ae
commit 98ea233
Showing
10 changed files
with
324 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,68 @@ | ||
# clinical-study-GAN | ||
# patGAN | ||
|
||
patGAN is a Python tool that allows you to generate synthetic clinical study data in the form of individual synthetic patients. The data is generated using a **GAN** (Generative Adversarial Network) that was trained on a subset of the **Framingham Heart Study** data. The individual patients consist of *continuous* and *binary* demographic and clinical features, including vital parameter measurements, lifestyle choices, and medical history. The generated data is exported to a CSV file. | ||
|
||
## Requirements | ||
|
||
Before using patGAN, make sure your environment meets the following requirements: | ||
|
||
- `Python 3.x` (The tool was tested with Python 3.11) | ||
- Required dependencies (See installation guide below) | ||
|
||
## Installation | ||
|
||
1. Clone or download this repository to your local machine. | ||
|
||
2. Navigate to the root directory of the project. | ||
|
||
3. Install the required dependencies from the `requirements.txt` file: | ||
|
||
`pip install -r requirements.txt` | ||
|
||
## How to Use | ||
|
||
To use patGAN, follow these steps: | ||
|
||
1. Make sure you have the required dependencies installed. | ||
|
||
2. Open a terminal or command prompt in the root directory of your project. | ||
|
||
3. Run the CLI by executing the following command: | ||
`python patgan.py [OPTIONS]` | ||
|
||
4. You can get an overview of all available options by displaying the help section: | ||
`python patgan.py -h` | ||
|
||
|
||
|
||
## CLI Options | ||
|
||
The patGAN CLI supports the following options: | ||
|
||
- `-n` or `--n_patients`: Specifies the number of patients to generate. The default value is 100. | ||
- `-o` or `--output`: Specifies the name of the output CSV file. The default name is `generated_patients.csv` | ||
|
||
## Examples | ||
|
||
Here are some example usages of the patGAN CLI: | ||
|
||
1. Generate the default amount of patients (100): | ||
``` | ||
python patGAN.py | ||
``` | ||
|
||
2. Generate a specific amount of patients (e.g. 1000): | ||
``` | ||
python patGAN.py -n 1000 | ||
``` | ||
3. Generate the default amount of patients (100) and specify the output CSV name: | ||
``` | ||
python patGAN.py -o custom_name | ||
``` | ||
|
||
4. Generate a specific amount of patients (e.g. 5000) and specify the output CSV name: | ||
``` | ||
python patGAN.py -n 5000 -o my_patients | ||
``` | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
[COLUMNS] | ||
continuous_columns = age,totChol,sysBP,diaBP,BMI,heartRate,glucose | ||
discrete_count_columns = cigsPerDay | ||
binary_columns = male,smoker,BPMeds,prevStroke,hypertension,diabetes | ||
|
||
[POSTPROCESSING] | ||
integer_columns = age,totChol,heartRate,glucose,cigsPerDay | ||
two_decimal_columns = BMI | ||
ptfive_round_columns = sysBP,diaBP | ||
smoker_column = smoker | ||
cigs_per_day_column = cigsPerDay | ||
sys_bp_column = sysBP | ||
dia_bp_column = diaBP | ||
|
||
[MODEL] | ||
latent_dim = 8 | ||
min_samples = 10 | ||
max_samples = 1000000 |
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import argparse | ||
|
||
import onnxruntime | ||
import pandas as pd | ||
|
||
import utils.helpers as helpers | ||
from utils.configparser import ConfigParser | ||
from utils.postprocessor import PostProcessor | ||
|
||
MODEL_PATH = "model/model.onnx" | ||
CONFIG_PATH = "config/config.ini" | ||
CONT_SCALER_PATH = "config/continuous_scaler.bin" | ||
DC_SCALER_PATH = "config/discrete_count_scaler.bin" | ||
|
||
|
||
class PatGAN: | ||
def __init__(self, config_parser: ConfigParser) -> None: | ||
self.__model = self.load_model() | ||
self.__config_parser = config_parser | ||
self.__postprocessor = PostProcessor(self.__config_parser) | ||
self.__postprocessor.load_scalers(CONT_SCALER_PATH, DC_SCALER_PATH) | ||
|
||
def load_model(self): | ||
return onnxruntime.InferenceSession(MODEL_PATH) | ||
|
||
def generate_samples(self, n_samples): | ||
noise = helpers.generate_noise(n_samples * 3, self.__config_parser.get_latent_dim()) | ||
generated_samples = self.__model.run(None, {"args_0": noise})[0] | ||
generated_samples = pd.DataFrame( | ||
generated_samples, | ||
columns=self.__config_parser.get_continuous_cols() | ||
+ self.__config_parser.get_discrete_count_cols() | ||
+ self.__config_parser.get_binary_cols(), | ||
) | ||
generated_samples = self.__postprocessor.reverse_scaling(generated_samples) | ||
generated_samples = self.__postprocessor.fit_transform(generated_samples) | ||
generated_samples = self.__postprocessor.filter(generated_samples, n_samples) | ||
return generated_samples | ||
|
||
def export_samples(self, samples, output): | ||
samples.to_csv(output, index=False) | ||
|
||
def parse_arguments(self): | ||
parser = argparse.ArgumentParser( | ||
description="patGAN - Generate synthetic clinical study data in the form of individual patients (CSV file)." | ||
) | ||
parser.add_argument( | ||
"-n", | ||
"--n_patients", | ||
metavar="\b", | ||
type=int, | ||
default=100, | ||
help="The number of patients to generate. Default is 100.", | ||
) | ||
parser.add_argument("-o", "--output", metavar="\b", help="The name of the output CSV file.") | ||
args = parser.parse_args() | ||
return args | ||
|
||
|
||
if __name__ == "__main__": | ||
config_parser = ConfigParser(CONFIG_PATH) | ||
gan = PatGAN(config_parser) | ||
min_samples, max_samples = config_parser.get_min_samples(), config_parser.get_max_samples() | ||
args = gan.parse_arguments() | ||
|
||
args.output = args.output or "generated_patients" | ||
if not args.output.endswith(".csv"): | ||
args.output += ".csv" | ||
if args.n_patients < min_samples or args.n_patients > max_samples: | ||
print(f"\nERROR: The number of patients must be between {min_samples} and {max_samples}.") | ||
else: | ||
print(f"\nGenerating {args.n_patients} patients...") | ||
generated_samples = gan.generate_samples(args.n_patients) | ||
if generated_samples is None: | ||
exit(1) | ||
gan.export_samples(generated_samples, args.output) | ||
print(f"Generated patients saved to {args.output}") | ||
print() | ||
print(generated_samples) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
joblib==1.4.2 | ||
numpy==1.26.4 | ||
onnxruntime==1.16.3 | ||
pandas==2.1.3 | ||
scikit-learn==1.5.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import configparser | ||
|
||
|
||
class ConfigParser: | ||
def __init__(self, config_path): | ||
self.__config = configparser.ConfigParser() | ||
self.__config.read(config_path) | ||
self.__continuous_cols = self.__config.get("COLUMNS", "continuous_columns").split(",") | ||
self.__discrete_count_cols = self.__config.get("COLUMNS", "discrete_count_columns").split( | ||
"," | ||
) | ||
self.__binary_cols = self.__config.get("COLUMNS", "binary_columns").split(",") | ||
self.__integer_cols = self.__config.get("POSTPROCESSING", "integer_columns").split(",") | ||
self.__two_dec_cols = self.__config.get("POSTPROCESSING", "two_decimal_columns").split(",") | ||
self.__ptfive_cols = self.__config.get("POSTPROCESSING", "ptfive_round_columns").split(",") | ||
self.__smoker_col = self.__config.get("POSTPROCESSING", "smoker_column") | ||
self.__cigs_per_day_col = self.__config.get("POSTPROCESSING", "cigs_per_day_column") | ||
self.__sys_bp_col = self.__config.get("POSTPROCESSING", "sys_bp_column") | ||
self.__dia_bp_col = self.__config.get("POSTPROCESSING", "dia_bp_column") | ||
self.__latent_dim = self.__config.getint("MODEL", "latent_dim") | ||
self.__min_samples = self.__config.getint("MODEL", "min_samples") | ||
self.__max_samples = self.__config.getint("MODEL", "max_samples") | ||
|
||
def get_continuous_cols(self): | ||
return self.__continuous_cols | ||
|
||
def get_discrete_count_cols(self): | ||
return self.__discrete_count_cols | ||
|
||
def get_binary_cols(self): | ||
return self.__binary_cols | ||
|
||
def get_integer_cols(self): | ||
return self.__integer_cols | ||
|
||
def get_two_dec_cols(self): | ||
return self.__two_dec_cols | ||
|
||
def get_ptfive_cols(self): | ||
return self.__ptfive_cols | ||
|
||
def get_smoker_col(self): | ||
return self.__smoker_col | ||
|
||
def get_cigs_per_day_col(self): | ||
return self.__cigs_per_day_col | ||
|
||
def get_sys_bp_col(self): | ||
return self.__sys_bp_col | ||
|
||
def get_dia_bp_col(self): | ||
return self.__dia_bp_col | ||
|
||
def get_latent_dim(self): | ||
return self.__latent_dim | ||
|
||
def get_min_samples(self): | ||
return self.__min_samples | ||
|
||
def get_max_samples(self): | ||
return self.__max_samples |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
import numpy as np | ||
|
||
|
||
def generate_noise(batch_size, latent_dim): | ||
noise = np.random.normal(0, 1, size=(batch_size, latent_dim)) | ||
noise = noise.astype(np.float32) | ||
return noise |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import pandas as pd | ||
from joblib import load | ||
from sklearn.preprocessing import StandardScaler | ||
|
||
from utils.configparser import ConfigParser | ||
|
||
|
||
class PostProcessor: | ||
def __init__(self, config_parser): | ||
self.__config_parser: ConfigParser = config_parser | ||
|
||
self.__continuous_scaler = StandardScaler(with_mean=True, with_std=True) | ||
self.__discrete_count_scaler = StandardScaler(with_mean=False, with_std=True) | ||
|
||
self.__continuous_cols = self.__config_parser.get_continuous_cols() | ||
self.__discrete_count_cols = self.__config_parser.get_discrete_count_cols() | ||
self.__binary_cols = self.__config_parser.get_binary_cols() | ||
|
||
self.__integer_cols = self.__config_parser.get_integer_cols() | ||
self.__two_dec_cols = self.__config_parser.get_two_dec_cols() | ||
self.__ptfive_cols = self.__config_parser.get_ptfive_cols() | ||
|
||
self.__smoker_col = self.__config_parser.get_smoker_col() | ||
self.__cigs_per_day_col = self.__config_parser.get_cigs_per_day_col() | ||
self.__sys_bp_col = self.__config_parser.get_sys_bp_col() | ||
self.__dia_bp_col = self.__config_parser.get_dia_bp_col() | ||
|
||
def fit_transform(self, data: pd.DataFrame): | ||
data_postprocessed = data.copy() | ||
|
||
data_postprocessed = self.round_cols(data_postprocessed, self.__integer_cols, 1) | ||
data_postprocessed = self.round_cols(data_postprocessed, self.__ptfive_cols, 0.5) | ||
data_postprocessed = self.round_cols(data_postprocessed, self.__two_dec_cols, 0.01) | ||
|
||
for col in data_postprocessed.columns: | ||
data_postprocessed[col] = data_postprocessed[col].apply(lambda x: max(0, x)) | ||
|
||
if self.__binary_cols is not None: | ||
for col in self.__binary_cols: | ||
data_postprocessed[col] = data_postprocessed[col].apply( | ||
lambda x: 0 if x < 0.5 else 1 | ||
) | ||
|
||
return data_postprocessed | ||
|
||
def round_cols(self, data, cols, round_to): | ||
if cols is not None: | ||
for col in cols: | ||
data[col] = data[col].apply(lambda x: round(x / round_to) * round_to) | ||
return data | ||
|
||
def reverse_scaling(self, data: pd.DataFrame, scale_binary=True): | ||
data_restored = data.copy() | ||
if self.__continuous_scaler is not None: | ||
data_restored[self.__continuous_cols] = self.__continuous_scaler.inverse_transform( | ||
data_restored[self.__continuous_cols] | ||
) | ||
if self.__discrete_count_scaler is not None: | ||
data_restored[self.__discrete_count_cols] = ( | ||
self.__discrete_count_scaler.inverse_transform( | ||
data_restored[self.__discrete_count_cols] | ||
) | ||
) | ||
if scale_binary and self.__binary_cols is not None: | ||
for col in self.__binary_cols: | ||
data_restored[col] = data_restored[col].apply(lambda x: 0 if x < 0.5 else 1) | ||
|
||
return data_restored | ||
|
||
def filter(self, data: pd.DataFrame, n_samples): | ||
cond_1 = (data[self.__smoker_col] == 0) & (data[self.__cigs_per_day_col] > 0) | ||
cond_2 = (data[self.__smoker_col] == 1) & (data[self.__cigs_per_day_col] == 0) | ||
cond_3 = data[self.__sys_bp_col] < data[self.__dia_bp_col] | ||
conditions = cond_1 | cond_2 | cond_3 | ||
data = data[~conditions] | ||
|
||
try: | ||
data = data.sample(n_samples) | ||
data.reset_index(drop=True, inplace=True) | ||
return data | ||
except ValueError: | ||
print("\nERROR: Not enough samples to filter, please try again.") | ||
|
||
def load_scalers(self, continuous_scaler_path, discrete_count_scaler_path): | ||
self.__continuous_scaler = load(continuous_scaler_path) | ||
self.__discrete_count_scaler = load(discrete_count_scaler_path) |