Skip to content

Commit

Permalink
Merge pull request #21 from MannLabs/update-metrics
Browse files — browse the repository at this point in the history
UPDATE result metrics table
  • Loading branch information
furkanmtorun authored May 10, 2023
2 parents a855732 + 35446b9 commit 6149895
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 18 deletions.
1 change: 1 addition & 0 deletions omiclearn/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def run():
file_path,
"--global.developmentMode=false",
"--browser.gatherUsageStats=False",
"--logger.level=error",
]

sys.argv = args
Expand Down
43 changes: 37 additions & 6 deletions omiclearn/utils/ml_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,9 @@ def perform_cross_validation(state, cohort_column=None):

for metric_name, metric_fct in scorer_dict.items():
_cv_results[metric_name] = []
_cv_results[metric_name + "_train"] = []
_cv_results["pr_auc"] = [] # ADD pr_auc manually
_cv_results["pr_auc_train"] = [] # ADD pr_auc manually

X = state.X
y = state.y
Expand Down Expand Up @@ -367,10 +369,20 @@ def perform_cross_validation(state, cohort_column=None):

calibrated_clf = CalibratedClassifierCV(clf, cv=cv_generator)
calibrated_clf.fit(X_train, y_train)

# Train
y_train_pred = calibrated_clf.predict(X_train)
y_train_pred_proba = calibrated_clf.predict_proba(X_train)
# Validation
y_pred = calibrated_clf.predict(X_test)
y_pred_proba = calibrated_clf.predict_proba(X_test)
else:
clf.fit(X_train, y_train)

# Train
y_train_pred = clf.predict(X_train)
y_train_pred_proba = clf.predict_proba(X_train)
# Validation
y_pred = clf.predict(X_test)
y_pred_proba = clf.predict_proba(X_test)

Expand All @@ -395,21 +407,42 @@ def perform_cross_validation(state, cohort_column=None):
feature_importance = None

# ROC CURVE
# Validation
fpr, tpr, cutoffs = roc_curve(y_test, y_pred_proba[:, 1])

# PR CURVE
# Train
precision_train, recall_train, _train = precision_recall_curve(
y_train, y_train_pred_proba[:, 1]
)
# Validation
precision, recall, _ = precision_recall_curve(y_test, y_pred_proba[:, 1])

for metric_name, metric_fct in scorer_dict.items():
if metric_name == "roc_auc":
# Train
_cv_results[metric_name + "_train"].append(
metric_fct(y_train, y_train_pred_proba[:, 1])
)
# Validation
_cv_results[metric_name].append(
metric_fct(y_test, y_pred_proba[:, 1])
)
elif metric_name in ["precision", "recall", "f1"]:
# Train
_cv_results[metric_name + "_train"].append(
metric_fct(y_train, y_train_pred, zero_division=0)
)
# Validation
_cv_results[metric_name].append(
metric_fct(y_test, y_pred, zero_division=0)
)
else:
# Train
_cv_results[metric_name + "_train"].append(
metric_fct(y_train, y_train_pred)
)
# Validation
_cv_results[metric_name].append(metric_fct(y_test, y_pred))

# Results of Cross Validation
Expand All @@ -423,12 +456,10 @@ def perform_cross_validation(state, cohort_column=None):
_cv_results["n_class_0_test"].append(np.sum(y_test))
_cv_results["n_class_1_test"].append(np.sum(~y_test))
_cv_results["class_ratio_test"].append(np.sum(y_test) / len(y_test))
_cv_results["pr_auc"].append(
auc(recall, precision)
) # ADD PR Curve AUC Score
_cv_curves["pr_auc"].append(
auc(recall, precision)
) # ADD PR Curve AUC Score
# Train PR Curve AUC Score
_cv_results["pr_auc_train"].append(auc(recall_train, precision_train))
# Validation PR Curve AUC Score
_cv_results["pr_auc"].append(auc(recall, precision))
_cv_curves["roc_curves_"].append((fpr, tpr, cutoffs))
_cv_curves["pr_curves_"].append((precision, recall, _))
_cv_curves["y_hats_"].append((y_test.values, y_pred))
Expand Down
5 changes: 1 addition & 4 deletions tests/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,13 @@ def test_load_data():
# csv
df = pd.DataFrame({"A": [1, 1], "B": [0, 0]})
csv_data, warnings = load_data("test_csv_c.csv", "Comma (,)")
print(csv_data)
pd.testing.assert_frame_equal(csv_data, df)

csv_data, warnings = load_data("test_csv_sc.csv", "Semicolon (;)")
print(csv_data)
pd.testing.assert_frame_equal(csv_data, df)

# TSV
tsv_data, warnings = load_data("test_tsv.tsv", "Tab (\\t) for TSV")
print(tsv_data)
pd.testing.assert_frame_equal(tsv_data, df)


Expand Down Expand Up @@ -153,10 +150,10 @@ def test_integration():
test_state["cv_repeats"] = 2
test_state["bar"] = st.progress(0)
test_state["features"] = ["AAA", "BBB", "CCC", "_study"]

# Generate X and y
main_analysis_run(test_state)

# print("\n", test_state, "\n")
_cv_results, _cv_curves = perform_cross_validation(test_state, cohort_column=None)
assert _cv_results == expected_cv_results, "Error in CV Results"
assert str(_cv_curves) == str(expected_cv_curves_str), "Error in CV Curves"
Expand Down
30 changes: 22 additions & 8 deletions tests/test_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,35 @@
0.8333333333333334,
0.8571428571428571,
0.8333333333333334,
1,
1.0,
],
"roc_auc": [0.875, 0.8333333333333333, 1, 0.875, 0.8333333333333333, 1],
"precision": [0.75, 1, 0.75, 0.75, 1, 1],
"recall": [1, 0.6666666666666666, 1, 1, 0.6666666666666666, 1],
"f1": [0.8571428571428571, 0.8, 0.8571428571428571, 0.8571428571428571, 0.8, 1],
"accuracy_train": [1.0, 1.0, 0.9230769230769231, 1.0, 1.0, 0.9230769230769231],
"roc_auc": [0.875, 0.8333333333333333, 1.0, 0.875, 0.8333333333333333, 1.0],
"roc_auc_train": [1.0, 1.0, 0.9761904761904763, 1.0, 1.0, 0.9761904761904763],
"precision": [0.75, 1.0, 0.75, 0.75, 1.0, 1.0],
"precision_train": [1.0, 1.0, 0.8571428571428571, 1.0, 1.0, 0.8571428571428571],
"recall": [1.0, 0.6666666666666666, 1.0, 1.0, 0.6666666666666666, 1.0],
"recall_train": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
"f1": [0.8571428571428571, 0.8, 0.8571428571428571, 0.8571428571428571, 0.8, 1.0],
"f1_train": [1.0, 1.0, 0.923076923076923, 1.0, 1.0, 0.923076923076923],
"balanced_accuracy": [
0.875,
0.8333333333333333,
0.8333333333333333,
0.875,
0.8333333333333333,
1,
1.0,
],
"pr_auc": [0.875, 0.9166666666666666, 1, 0.875, 0.9166666666666666, 1],
"balanced_accuracy_train": [
1.0,
1.0,
0.9285714285714286,
1.0,
1.0,
0.9285714285714286,
],
"pr_auc": [0.875, 0.9166666666666666, 1.0, 0.875, 0.9166666666666666, 1.0],
"pr_auc_train": [1.0, 1.0, 0.9742063492063492, 1.0, 1.0, 0.9742063492063492],
}

expected_cv_curves_str = """{'pr_auc': [0.875, 0.9166666666666666, 1.0, 0.875, 0.9166666666666666, 1.0], 'roc_curves_': [(array([0. , 0.25, 1. ]), array([0., 1., 1.]), array([1.7956569 , 0.7956569 , 0.20434304], dtype=float32)), (array([0., 0., 1.]), array([0. , 0.66666667, 1. ]), array([1.8162205 , 0.8162206 , 0.15752529], dtype=float32)), (array([0. , 0. , 0.33333333, 1. ]), array([0., 1., 1., 1.]), array([1.8069754, 0.8069754, 0.502567 , 0.1422766], dtype=float32)), (array([0. , 0.25, 1. ]), array([0., 1., 1.]), array([1.7956569 , 0.7956569 , 0.20434304], dtype=float32)), (array([0., 0., 1.]), array([0. , 0.66666667, 1. ]), array([1.8162205 , 0.8162206 , 0.15752529], dtype=float32)), (array([0., 0., 1.]), array([0., 1., 1.]), array([1.8069754, 0.8069754, 0.1422766], dtype=float32))], 'pr_curves_': [(array([0.42857143, 0.75 , 1. ]), array([1., 1., 0.]), array([0.20434304, 0.7956569 ], dtype=float32)), (array([0.5, 1. , 1. ]), array([1. , 0.66666667, 0. ]), array([0.15752529, 0.8162206 ], dtype=float32)), (array([0.5 , 0.75, 1. , 1. ]), array([1., 1., 1., 0.]), array([0.1422766, 0.502567 , 0.8069754], dtype=float32)), (array([0.42857143, 0.75 , 1. ]), array([1., 1., 0.]), array([0.20434304, 0.7956569 ], dtype=float32)), (array([0.5, 1. , 1. ]), array([1. , 0.66666667, 0. ]), array([0.15752529, 0.8162206 ], dtype=float32)), (array([0.5, 1. , 1. 
]), array([1., 1., 0.]), array([0.1422766, 0.8069754], dtype=float32))], 'y_hats_': [(array([ True, True, True, False, False, False, False]), array([1, 1, 1, 1, 0, 0, 0])), (array([ True, True, True, False, False, False]), array([1, 1, 0, 0, 0, 0])), (array([ True, True, True, False, False, False]), array([1, 1, 1, 1, 0, 0])), (array([ True, True, True, False, False, False, False]), array([1, 1, 1, 1, 0, 0, 0])), (array([ True, True, True, False, False, False]), array([1, 1, 0, 0, 0, 0])), (array([ True, True, True, False, False, False]), array([1, 1, 1, 0, 0, 0]))], 'feature_importances_': [{'_study': 0.0, 'CCC': 1.0, 'BBB': 0.0, 'AAA': 0.0}, {'_study': 0.0, 'CCC': 1.0, 'BBB': 0.0, 'AAA': 0.0}, {'_study': 0.0, 'AAA': 0.0717181, 'CCC': 0.9282819, 'BBB': 0.0}, {'_study': 0.0, 'CCC': 1.0, 'BBB': 0.0, 'AAA': 0.0}, {'_study': 0.0, 'CCC': 1.0, 'BBB': 0.0, 'AAA': 0.0}, {'_study': 0.0, 'AAA': 0.0717181, 'CCC': 0.9282819, 'BBB': 0.0}], 'features_': []}"""
expected_cv_curves_str = """{'pr_auc': [], 'roc_curves_': [(array([0. , 0.25, 1. ]), array([0., 1., 1.]), array([1.7956569 , 0.7956569 , 0.20434304], dtype=float32)), (array([0., 0., 1.]), array([0. , 0.66666667, 1. ]), array([1.8162205 , 0.8162206 , 0.15752529], dtype=float32)), (array([0. , 0. , 0.33333333, 1. ]), array([0., 1., 1., 1.]), array([1.8069754, 0.8069754, 0.502567 , 0.1422766], dtype=float32)), (array([0. , 0.25, 1. ]), array([0., 1., 1.]), array([1.7956569 , 0.7956569 , 0.20434304], dtype=float32)), (array([0., 0., 1.]), array([0. , 0.66666667, 1. ]), array([1.8162205 , 0.8162206 , 0.15752529], dtype=float32)), (array([0., 0., 1.]), array([0., 1., 1.]), array([1.8069754, 0.8069754, 0.1422766], dtype=float32))], 'pr_curves_': [(array([0.42857143, 0.75 , 1. ]), array([1., 1., 0.]), array([0.20434304, 0.7956569 ], dtype=float32)), (array([0.5, 1. , 1. ]), array([1. , 0.66666667, 0. ]), array([0.15752529, 0.8162206 ], dtype=float32)), (array([0.5 , 0.75, 1. , 1. ]), array([1., 1., 1., 0.]), array([0.1422766, 0.502567 , 0.8069754], dtype=float32)), (array([0.42857143, 0.75 , 1. ]), array([1., 1., 0.]), array([0.20434304, 0.7956569 ], dtype=float32)), (array([0.5, 1. , 1. ]), array([1. , 0.66666667, 0. ]), array([0.15752529, 0.8162206 ], dtype=float32)), (array([0.5, 1. , 1. 
]), array([1., 1., 0.]), array([0.1422766, 0.8069754], dtype=float32))], 'y_hats_': [(array([ True, True, True, False, False, False, False]), array([1, 1, 1, 1, 0, 0, 0])), (array([ True, True, True, False, False, False]), array([1, 1, 0, 0, 0, 0])), (array([ True, True, True, False, False, False]), array([1, 1, 1, 1, 0, 0])), (array([ True, True, True, False, False, False, False]), array([1, 1, 1, 1, 0, 0, 0])), (array([ True, True, True, False, False, False]), array([1, 1, 0, 0, 0, 0])), (array([ True, True, True, False, False, False]), array([1, 1, 1, 0, 0, 0]))], 'feature_importances_': [{'_study': 0.0, 'CCC': 1.0, 'BBB': 0.0, 'AAA': 0.0}, {'_study': 0.0, 'CCC': 1.0, 'BBB': 0.0, 'AAA': 0.0}, {'_study': 0.0, 'AAA': 0.0717181, 'CCC': 0.9282819, 'BBB': 0.0}, {'_study': 0.0, 'CCC': 1.0, 'BBB': 0.0, 'AAA': 0.0}, {'_study': 0.0, 'CCC': 1.0, 'BBB': 0.0, 'AAA': 0.0}, {'_study': 0.0, 'AAA': 0.0717181, 'CCC': 0.9282819, 'BBB': 0.0}], 'features_': []}"""

0 comments on commit 6149895

Please sign in to comment.