Skip to content

Commit

Permalink
Update ruff and fix more issues
Browse files Browse the repository at this point in the history
  • Loading branch information
amrit110 committed Apr 25, 2024
1 parent cde2eaa commit e2a1d90
Show file tree
Hide file tree
Showing 10 changed files with 279 additions and 6,720 deletions.
25 changes: 22 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0 # Use the ref you want to point at
rev: v4.6.0 # Use the ref you want to point at
hooks:
- id: trailing-whitespace
- id: check-ast
Expand All @@ -16,7 +16,7 @@ repos:
- id: check-toml

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: 'v0.3.7'
rev: 'v0.4.2'
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand All @@ -25,7 +25,7 @@ repos:
types_or: [python, jupyter]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.8.0
rev: v1.10.0
hooks:
- id: mypy
entry: python3 -m mypy --config-file pyproject.toml
Expand All @@ -41,3 +41,22 @@ repos:
language: system
pass_filenames: false
always_run: true

- repo: local
hooks:
- id: nbstripout
name: nbstripout
language: system
entry: python3 -m nbstripout

ci:
autofix_commit_msg: |
[pre-commit.ci] Add auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
autofix_prs: true
autoupdate_branch: ''
autoupdate_commit_msg: '[pre-commit.ci] pre-commit autoupdate'
autoupdate_schedule: weekly
skip: [pytest,nbstripout,mypy]
submodules: false
89 changes: 17 additions & 72 deletions odyssey/data/DataProcessor.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -3,49 +3,32 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2024-03-13T16:14:45.546088300Z",
"start_time": "2024-03-13T16:14:43.587090300Z"
},
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"import pickle\n",
"import random\n",
"from typing import Any, Dict, List, Optional\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"\n",
"ROOT = \"/h/afallah/odyssey/odyssey\"\n",
"DATA_ROOT = f\"{ROOT}/odyssey/data/bigbird_data\"\n",
"DATASET = f\"{DATA_ROOT}/patient_sequences/patient_sequences_2048.parquet\"\n",
"MAX_LEN = 2048\n",
"\n",
"os.chdir(ROOT)\n",
"\n",
"from odyssey.utils.utils import seed_everything\n",
"from odyssey.data.processor import (\n",
" filter_by_num_visit,\n",
" filter_by_length_of_stay,\n",
" get_last_occurence_index,\n",
" check_readmission_label,\n",
" get_length_of_stay,\n",
" get_visit_cutoff_at_threshold,\n",
" process_length_of_stay_dataset,\n",
" get_finetune_split,\n",
" process_condition_dataset,\n",
" process_length_of_stay_dataset,\n",
" process_mortality_dataset,\n",
" process_readmission_dataset,\n",
" process_multi_dataset,\n",
" stratified_train_test_split,\n",
" sample_balanced_subset,\n",
" get_pretrain_test_split,\n",
" get_finetune_split,\n",
" process_readmission_dataset,\n",
")\n",
"from odyssey.utils.utils import seed_everything\n",
"\n",
"\n",
"SEED = 23\n",
"seed_everything(seed=SEED)"
Expand All @@ -54,13 +37,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2024-03-13T16:15:12.321718600Z",
"start_time": "2024-03-13T16:14:45.553089800Z"
},
"collapsed": false
},
"metadata": {},
"outputs": [],
"source": [
"# Load complete dataset\n",
Expand Down Expand Up @@ -95,13 +72,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2024-03-13T16:15:16.075719400Z",
"start_time": "2024-03-13T16:15:12.335721100Z"
},
"collapsed": false
},
"metadata": {},
"outputs": [],
"source": [
"# Process the dataset for mortality in two weeks or one month task\n",
Expand All @@ -111,13 +82,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2024-03-13T16:15:47.326996100Z",
"start_time": "2024-03-13T16:15:16.094719300Z"
},
"collapsed": false
},
"metadata": {},
"outputs": [],
"source": [
"# Process the dataset for hospital readmission in one month task\n",
Expand Down Expand Up @@ -212,17 +177,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2024-03-13T16:15:51.800996200Z",
"start_time": "2024-03-13T16:15:50.494996100Z"
},
"collapsed": false
},
"metadata": {},
"outputs": [],
"source": [
"# Get finetune split\n",
"for task in task_config.keys():\n",
"for task in task_config:\n",
" patient_ids_dict = get_finetune_split(\n",
" task_config=task_config,\n",
" task=task,\n",
Expand All @@ -233,13 +192,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2024-03-13T14:14:10.181184300Z",
"start_time": "2024-03-13T14:13:39.154567400Z"
},
"collapsed": false
},
"metadata": {},
"outputs": [],
"source": [
"dataset_2048_mortality.to_parquet(\n",
Expand Down Expand Up @@ -365,9 +318,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"metadata": {},
"outputs": [],
"source": [
"# Assuming dataset.event_tokens is your DataFrame column\n",
Expand All @@ -383,9 +334,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"metadata": {},
"outputs": [],
"source": [
"# len(patient_ids_dict['group3']['cv'])\n",
Expand All @@ -402,9 +351,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"metadata": {},
"outputs": [],
"source": [
"##### DEAD ZONE | DO NOT ENTER #####\n",
Expand All @@ -424,9 +371,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"metadata": {},
"outputs": [],
"source": [
"# dataset_2048_readmission = dataset_2048.loc[dataset_2048['num_visits'] > 1]\n",
Expand Down
20 changes: 5 additions & 15 deletions odyssey/evals/CompareAUROC-Poster.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
Expand All @@ -19,9 +17,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"metadata": {},
"outputs": [],
"source": [
"# Import dependencies and define useful constants\n",
Expand Down Expand Up @@ -51,9 +47,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"metadata": {},
"outputs": [],
"source": [
"# Load predictions, labels, and probabilities of different models\n",
Expand Down Expand Up @@ -81,9 +75,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"metadata": {},
"outputs": [],
"source": [
"# Plot ROC Curve for XGBoost, Bi-LSTM, and Transformer\n",
Expand All @@ -103,9 +95,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"metadata": {},
"outputs": [],
"source": [
"# Plot Information\n",
Expand Down
Loading

0 comments on commit e2a1d90

Please sign in to comment.