Added project
Jonas Miesenboeck committed May 23, 2024
1 parent 3b515ae commit 98ea233
Showing 10 changed files with 324 additions and 1 deletion.
69 changes: 68 additions & 1 deletion README.md
@@ -1 +1,68 @@
-# clinical-study-GAN
+# patGAN

patGAN is a Python tool for generating synthetic clinical study data in the form of individual synthetic patients. The data is produced by a **GAN** (Generative Adversarial Network) trained on a subset of the **Framingham Heart Study** data. Each patient consists of *continuous*, *discrete count*, and *binary* demographic and clinical features, including vital sign measurements, lifestyle factors, and medical history. The generated data is exported to a CSV file.
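
For example, the exported file can be inspected with pandas — a minimal sketch, assuming the default output name:

```python
import pandas as pd

# Load a generated cohort (default output name); columns follow
# config/config.ini: continuous vitals such as age, sysBP, and BMI,
# the count feature cigsPerDay, and binary flags such as smoker.
patients = pd.read_csv("generated_patients.csv")
print(patients.head())
print(patients.describe())
```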

## Requirements

Before using patGAN, make sure your environment meets the following requirements:

- `Python 3.x` (tested with Python 3.11)
- Required dependencies (see the installation guide below)

## Installation

1. Clone or download this repository to your local machine.

2. Navigate to the root directory of the project.

3. Install the required dependencies from the `requirements.txt` file:

`pip install -r requirements.txt`

## How to Use

To use patGAN, follow these steps:

1. Make sure you have the required dependencies installed.

2. Open a terminal or command prompt in the root directory of the project.

3. Run the CLI by executing the following command:
`python patgan.py [OPTIONS]`

4. Get an overview of all available options by displaying the help message:
`python patgan.py -h`



## CLI Options

The patGAN CLI supports the following options:

- `-n` or `--n_patients`: Specifies the number of patients to generate. The default value is 100; the accepted range is bounded by `min_samples` and `max_samples` in `config/config.ini` (10 to 1,000,000 by default).
- `-o` or `--output`: Specifies the name of the output CSV file. The default name is `generated_patients.csv`; the `.csv` extension is appended automatically if omitted.

## Examples

Here are some example usages of the patGAN CLI:

1. Generate the default number of patients (100):
```
python patgan.py
```

2. Generate a specific number of patients (e.g. 1000):
```
python patgan.py -n 1000
```

3. Generate the default number of patients (100) and specify the output CSV name:
```
python patgan.py -o custom_name
```

4. Generate a specific number of patients (e.g. 5000) and specify the output CSV name:
```
python patgan.py -n 5000 -o my_patients
```

18 changes: 18 additions & 0 deletions config/config.ini
@@ -0,0 +1,18 @@
[COLUMNS]
continuous_columns = age,totChol,sysBP,diaBP,BMI,heartRate,glucose
discrete_count_columns = cigsPerDay
binary_columns = male,smoker,BPMeds,prevStroke,hypertension,diabetes

[POSTPROCESSING]
integer_columns = age,totChol,heartRate,glucose,cigsPerDay
two_decimal_columns = BMI
ptfive_round_columns = sysBP,diaBP
smoker_column = smoker
cigs_per_day_column = cigsPerDay
sys_bp_column = sysBP
dia_bp_column = diaBP

[MODEL]
latent_dim = 8
min_samples = 10
max_samples = 1000000
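
For orientation, the `[POSTPROCESSING]` column groups map to one shared rounding rule, `round(x / step) * step`, applied later in `utils/postprocessor.py` — an illustrative sketch:

```python
step = 0.5                           # ptfive_round_columns (sysBP, diaBP)
print(round(121.3 / step) * step)    # 121.5
step = 0.01                          # two_decimal_columns (BMI)
print(round(24.5678 / step) * step)  # ~24.57
```
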
Binary file added config/continuous_scaler.bin
Binary file not shown.
Binary file added config/discrete_count_scaler.bin
Binary file not shown.
Binary file added model/model.onnx
Binary file not shown.
79 changes: 79 additions & 0 deletions patgan.py
@@ -0,0 +1,79 @@
import argparse

import onnxruntime
import pandas as pd

import utils.helpers as helpers
from utils.configparser import ConfigParser
from utils.postprocessor import PostProcessor

MODEL_PATH = "model/model.onnx"
CONFIG_PATH = "config/config.ini"
CONT_SCALER_PATH = "config/continuous_scaler.bin"
DC_SCALER_PATH = "config/discrete_count_scaler.bin"


class PatGAN:
    def __init__(self, config_parser: ConfigParser) -> None:
        self.__model = self.load_model()
        self.__config_parser = config_parser
        self.__postprocessor = PostProcessor(self.__config_parser)
        self.__postprocessor.load_scalers(CONT_SCALER_PATH, DC_SCALER_PATH)

    def load_model(self):
        return onnxruntime.InferenceSession(MODEL_PATH)

    def generate_samples(self, n_samples):
        # Generate 3x the requested samples so enough rows survive filtering.
        noise = helpers.generate_noise(n_samples * 3, self.__config_parser.get_latent_dim())
        generated_samples = self.__model.run(None, {"args_0": noise})[0]
        generated_samples = pd.DataFrame(
            generated_samples,
            columns=self.__config_parser.get_continuous_cols()
            + self.__config_parser.get_discrete_count_cols()
            + self.__config_parser.get_binary_cols(),
        )
        generated_samples = self.__postprocessor.reverse_scaling(generated_samples)
        generated_samples = self.__postprocessor.fit_transform(generated_samples)
        generated_samples = self.__postprocessor.filter(generated_samples, n_samples)
        return generated_samples

    def export_samples(self, samples, output):
        samples.to_csv(output, index=False)

    def parse_arguments(self):
        parser = argparse.ArgumentParser(
            description="patGAN - Generate synthetic clinical study data in the form of individual patients (CSV file)."
        )
        parser.add_argument(
            "-n",
            "--n_patients",
            metavar="\b",
            type=int,
            default=100,
            help="The number of patients to generate. Default is 100.",
        )
        parser.add_argument("-o", "--output", metavar="\b", help="The name of the output CSV file.")
        args = parser.parse_args()
        return args


if __name__ == "__main__":
    config_parser = ConfigParser(CONFIG_PATH)
    gan = PatGAN(config_parser)
    min_samples, max_samples = config_parser.get_min_samples(), config_parser.get_max_samples()
    args = gan.parse_arguments()

    args.output = args.output or "generated_patients"
    if not args.output.endswith(".csv"):
        args.output += ".csv"
    if args.n_patients < min_samples or args.n_patients > max_samples:
        print(f"\nERROR: The number of patients must be between {min_samples} and {max_samples}.")
    else:
        print(f"\nGenerating {args.n_patients} patients...")
        generated_samples = gan.generate_samples(args.n_patients)
        if generated_samples is None:
            exit(1)
        gan.export_samples(generated_samples, args.output)
        print(f"Generated patients saved to {args.output}")
        print()
        print(generated_samples)
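
The class can also be driven programmatically rather than through the CLI — a minimal sketch, assuming it runs from the project root so the relative `MODEL_PATH` and `CONFIG_PATH` constants resolve (the output filename here is arbitrary):

```python
from patgan import PatGAN, CONFIG_PATH
from utils.configparser import ConfigParser

gan = PatGAN(ConfigParser(CONFIG_PATH))
patients = gan.generate_samples(50)

# generate_samples returns None when filtering leaves fewer rows than requested.
if patients is not None:
    gan.export_samples(patients, "demo_patients.csv")
```
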
5 changes: 5 additions & 0 deletions requirements.txt
@@ -0,0 +1,5 @@
joblib==1.4.2
numpy==1.26.4
onnxruntime==1.16.3
pandas==2.1.3
scikit-learn==1.5.0
61 changes: 61 additions & 0 deletions utils/configparser.py
@@ -0,0 +1,61 @@
import configparser


class ConfigParser:
    def __init__(self, config_path):
        self.__config = configparser.ConfigParser()
        self.__config.read(config_path)
        self.__continuous_cols = self.__config.get("COLUMNS", "continuous_columns").split(",")
        self.__discrete_count_cols = self.__config.get("COLUMNS", "discrete_count_columns").split(
            ","
        )
        self.__binary_cols = self.__config.get("COLUMNS", "binary_columns").split(",")
        self.__integer_cols = self.__config.get("POSTPROCESSING", "integer_columns").split(",")
        self.__two_dec_cols = self.__config.get("POSTPROCESSING", "two_decimal_columns").split(",")
        self.__ptfive_cols = self.__config.get("POSTPROCESSING", "ptfive_round_columns").split(",")
        self.__smoker_col = self.__config.get("POSTPROCESSING", "smoker_column")
        self.__cigs_per_day_col = self.__config.get("POSTPROCESSING", "cigs_per_day_column")
        self.__sys_bp_col = self.__config.get("POSTPROCESSING", "sys_bp_column")
        self.__dia_bp_col = self.__config.get("POSTPROCESSING", "dia_bp_column")
        self.__latent_dim = self.__config.getint("MODEL", "latent_dim")
        self.__min_samples = self.__config.getint("MODEL", "min_samples")
        self.__max_samples = self.__config.getint("MODEL", "max_samples")

    def get_continuous_cols(self):
        return self.__continuous_cols

    def get_discrete_count_cols(self):
        return self.__discrete_count_cols

    def get_binary_cols(self):
        return self.__binary_cols

    def get_integer_cols(self):
        return self.__integer_cols

    def get_two_dec_cols(self):
        return self.__two_dec_cols

    def get_ptfive_cols(self):
        return self.__ptfive_cols

    def get_smoker_col(self):
        return self.__smoker_col

    def get_cigs_per_day_col(self):
        return self.__cigs_per_day_col

    def get_sys_bp_col(self):
        return self.__sys_bp_col

    def get_dia_bp_col(self):
        return self.__dia_bp_col

    def get_latent_dim(self):
        return self.__latent_dim

    def get_min_samples(self):
        return self.__min_samples

    def get_max_samples(self):
        return self.__max_samples
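
A short usage sketch of this wrapper, run from the project root:

```python
from utils.configparser import ConfigParser

config = ConfigParser("config/config.ini")
print(config.get_continuous_cols())  # ['age', 'totChol', 'sysBP', ...]
print(config.get_latent_dim())       # 8
```
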
7 changes: 7 additions & 0 deletions utils/helpers.py
@@ -0,0 +1,7 @@
import numpy as np


def generate_noise(batch_size, latent_dim):
    # Standard-normal latent noise, cast to float32 for ONNX Runtime.
    noise = np.random.normal(0, 1, size=(batch_size, latent_dim))
    noise = noise.astype(np.float32)
    return noise
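
A quick sanity check of the helper (a sketch; 8 matches `latent_dim` in `config/config.ini`, and float32 is the dtype the ONNX session input `"args_0"` expects):

```python
import numpy as np

from utils.helpers import generate_noise

noise = generate_noise(300, 8)
assert noise.shape == (300, 8)
assert noise.dtype == np.float32
```
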
86 changes: 86 additions & 0 deletions utils/postprocessor.py
@@ -0,0 +1,86 @@
import pandas as pd
from joblib import load
from sklearn.preprocessing import StandardScaler

from utils.configparser import ConfigParser


class PostProcessor:
    def __init__(self, config_parser):
        self.__config_parser: ConfigParser = config_parser

        self.__continuous_scaler = StandardScaler(with_mean=True, with_std=True)
        self.__discrete_count_scaler = StandardScaler(with_mean=False, with_std=True)

        self.__continuous_cols = self.__config_parser.get_continuous_cols()
        self.__discrete_count_cols = self.__config_parser.get_discrete_count_cols()
        self.__binary_cols = self.__config_parser.get_binary_cols()

        self.__integer_cols = self.__config_parser.get_integer_cols()
        self.__two_dec_cols = self.__config_parser.get_two_dec_cols()
        self.__ptfive_cols = self.__config_parser.get_ptfive_cols()

        self.__smoker_col = self.__config_parser.get_smoker_col()
        self.__cigs_per_day_col = self.__config_parser.get_cigs_per_day_col()
        self.__sys_bp_col = self.__config_parser.get_sys_bp_col()
        self.__dia_bp_col = self.__config_parser.get_dia_bp_col()

    def fit_transform(self, data: pd.DataFrame):
        data_postprocessed = data.copy()

        # Round each column group to its configured precision.
        data_postprocessed = self.round_cols(data_postprocessed, self.__integer_cols, 1)
        data_postprocessed = self.round_cols(data_postprocessed, self.__ptfive_cols, 0.5)
        data_postprocessed = self.round_cols(data_postprocessed, self.__two_dec_cols, 0.01)

        # Clamp negative values to zero.
        for col in data_postprocessed.columns:
            data_postprocessed[col] = data_postprocessed[col].apply(lambda x: max(0, x))

        # Threshold binary features at 0.5.
        if self.__binary_cols is not None:
            for col in self.__binary_cols:
                data_postprocessed[col] = data_postprocessed[col].apply(
                    lambda x: 0 if x < 0.5 else 1
                )

        return data_postprocessed

    def round_cols(self, data, cols, round_to):
        if cols is not None:
            for col in cols:
                data[col] = data[col].apply(lambda x: round(x / round_to) * round_to)
        return data

    def reverse_scaling(self, data: pd.DataFrame, scale_binary=True):
        data_restored = data.copy()
        if self.__continuous_scaler is not None:
            data_restored[self.__continuous_cols] = self.__continuous_scaler.inverse_transform(
                data_restored[self.__continuous_cols]
            )
        if self.__discrete_count_scaler is not None:
            data_restored[self.__discrete_count_cols] = (
                self.__discrete_count_scaler.inverse_transform(
                    data_restored[self.__discrete_count_cols]
                )
            )
        if scale_binary and self.__binary_cols is not None:
            for col in self.__binary_cols:
                data_restored[col] = data_restored[col].apply(lambda x: 0 if x < 0.5 else 1)

        return data_restored

    def filter(self, data: pd.DataFrame, n_samples):
        # Drop implausible rows: non-smokers with cigarettes, smokers without,
        # and systolic pressure below diastolic.
        cond_1 = (data[self.__smoker_col] == 0) & (data[self.__cigs_per_day_col] > 0)
        cond_2 = (data[self.__smoker_col] == 1) & (data[self.__cigs_per_day_col] == 0)
        cond_3 = data[self.__sys_bp_col] < data[self.__dia_bp_col]
        conditions = cond_1 | cond_2 | cond_3
        data = data[~conditions]

        try:
            data = data.sample(n_samples)
            data.reset_index(drop=True, inplace=True)
            return data
        except ValueError:
            print("\nERROR: Not enough samples to filter, please try again.")

    def load_scalers(self, continuous_scaler_path, discrete_count_scaler_path):
        self.__continuous_scaler = load(continuous_scaler_path)
        self.__discrete_count_scaler = load(discrete_count_scaler_path)
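
To illustrate the plausibility rules in `filter` — a toy sketch with hand-made rows, not model output:

```python
import pandas as pd

from utils.configparser import ConfigParser
from utils.postprocessor import PostProcessor

pp = PostProcessor(ConfigParser("config/config.ini"))

toy = pd.DataFrame({
    "smoker":     [0, 1, 0, 1],
    "cigsPerDay": [5, 0, 0, 10],              # rows 0 and 1 violate the smoker rules
    "sysBP":      [120.0, 130.0, 100.0, 140.0],
    "diaBP":      [80.0, 85.0, 110.0, 90.0],  # row 2 has sysBP < diaBP
})

print(pp.filter(toy, n_samples=1))  # only row 3 survives the filter
```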
