import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, accuracy_score
from xgboost import XGBClassifier
import shap
import eli5
from eli5.sklearn import PermutationImportance
from pdpbox import pdp
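# Assumed environment (dependencies are not pinned in the repo): a Streamlit release with
# st.cache_data support, an xgboost 1.x release (use_label_encoder / eval_metric passed to
# fit), and pdpbox's classic v0.2 API. Newer releases may need small changes.
# Run locally with: streamlit run app.py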
st.set_page_config(layout="wide", page_title='Explaining Heart Diseases ML Model')
st.set_option('deprecation.showPyplotGlobalUse', False)
shap.initjs()
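# Top-level layout containers; each app section below renders its widgets into one of these.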
header = st.container()
dataset = st.container()
model_section = st.container()
explainable = st.container()
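# Streamlit reruns the whole script on every widget interaction, so the data load and the
# train/test split are cached (persisted to disk) to keep those reruns fast.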
@st.cache_data(persist="disk")
def read_data():
    """
    Read the dataset
    @return: dataframe with the renamed column names
    """
    data = pd.read_csv('data/heart.csv')
    data.columns = ['age', 'sex', 'chest_pain_agnia', 'resting_blood_pressure', 'cholestrol', 'fasting_blood_sugar',
                    'resting_ecg', 'max_heart_rate', 'exercise_induced_agnia', 'st_depression_rt_rest', 'slope',
                    'number_of_major_vessels', 'thalassemia', 'target']
    return data
@st.cache_data(persist="disk")
def train_test_split_data(df):
    """
    One-hot encode the dataframe and return the train/test split
    """
    data_catg = df.copy()
    data_catg['chest_pain_agnia'] = df['chest_pain_agnia'].map(
        {0: "asymptomatic", 1: "typical", 2: "atypical", 3: "non_anginal"})
    data_catg['sex'] = df['sex'].map({0: "female", 1: "male"})
    data_catg['exercise_induced_agnia'] = df['exercise_induced_agnia'].map({0: "false", 1: "true"})
    data_catg['slope'] = df['slope'].map({1: "upsloping", 2: "flat", 3: "downsloping"})
    data_catg['thalassemia'] = df['thalassemia'].map({1: "normal", 2: "fixed_defect", 3: "reversable_defect"})
    data_catg['resting_ecg'] = df['resting_ecg'].map(
        {0: 'normal', 1: 'st_wave_abnormal', 2: 'left_ventricular_hypertrophy'})
    data_catg['fasting_blood_sugar'] = df['fasting_blood_sugar'].map({0: '<=120mg/ml', 1: '>120mg/ml'})
    df = pd.get_dummies(data_catg, drop_first=True)
    X = df.drop("target", axis=1).values
    y = df["target"].astype("float").values
    # encoded column names are reused for the charts below
    encoded_df_column_list = df.columns.drop('target')
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=32)
    return X_train, X_test, y_train, y_test, encoded_df_column_list
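# Note: get_dummies(..., drop_first=True) yields encoded columns such as 'sex_male' and
# 'thalassemia_reversable_defect'; these names feed the explainability charts below.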
st.sidebar.markdown("""
**Author**: Rishu Shrivastava
**Last Published**: 01-Aug-2021
**Feature Detailed Description**: [Kaggle Heart Disease Dataset](https://www.kaggle.com/ronitf/heart-disease-uci)
**Codebase**: [GitHub Repo: explainable-ai-app](https://github.com/rishuatgithub/explainable-ai-app)
**Report**: [Issues](https://github.com/rishuatgithub/explainable-ai-app/issues)
**References**:
- [Kaggle Explainable AI Course](https://www.kaggle.com/learn/machine-learning-explainability)
- [Interpretable Machine Learning](https://christophm.github.io/interpretable-ml-book/)
- [Kaggle Notebooks](https://www.kaggle.com/chingchunyeh/heart-disease-report)
""")
with header:
    st.title("Explaining Heart Diseases ML Model")
    st.markdown("""
Many people say machine learning models are **black boxes**: they can make good predictions, but you can't
understand the logic behind those predictions. This is true in the sense that most data scientists don't yet
know how to extract such insights from their models.

This interactive application explains the presence of heart disease in a person, based on the
[Heart Disease UCI](https://archive.ics.uci.edu/ml/datasets/Heart+Disease) dataset, using **Explainable AI** techniques.
""")
with dataset:
    st.header("**Dataset**")
    st.markdown("""
This database contains 76 attributes, but all published experiments refer to using a subset of 14 of them.
In particular, the Cleveland database is the only one that has been used by ML researchers to date.
The **target** attribute refers to the presence of heart disease in the person.
""")
    df = read_data()
    st.write(df.head())
    st.markdown("""
The dataset has several **categorical** features.
We will apply the [dummy encoding](https://pandas.pydata.org/docs/reference/api/pandas.get_dummies.html)
technique to convert the categorical data into binary features.
""")
    st.text("""
Note: There are many feature engineering techniques that could be applied to this dataset.
For the purpose of this exercise, we will not go beyond the encoding described above.
""")
with model_section:
    st.header("**Model Training**")
    st.markdown("""
We will train an **XGBoost classifier** on the dataset with the parameters selected below.
The data is split into train and test sets with an 80:20 ratio.
""")
    # splitting the dataset
    X_train, X_test, y_train, y_test, encoded_df_column_list = train_test_split_data(df)
    model_col1, model_col2, model_col3 = st.columns((1, 1, 2))
    model_max_depth = model_col1.slider("Max Depth", min_value=3, max_value=10, step=1, value=5)
    model_learning_rate = model_col1.selectbox("Learning Rate", options=[0.001, 0.01, 0.1, 0.3, 0.5], index=1)
    model_estimators = model_col1.selectbox("Number of estimators", options=[100, 300, 400, 500, 700], index=3)
    model = XGBClassifier(max_depth=model_max_depth,
                          learning_rate=model_learning_rate,
                          n_estimators=model_estimators,
                          use_label_encoder=False,
                          # enable_categorical=True,
                          n_jobs=1)
    # evaluate on both splits each boosting round so the loss curves can be charted below
    eval_set = [(X_train, y_train), (X_test, y_test)]
    model_train = model.fit(X_train, y_train, eval_metric=['error', 'logloss'], eval_set=eval_set, verbose=False)
    y_pred = model_train.predict(X_test)
    y_proba = model_train.predict_proba(X_test)[:, 1]
    accuracy = accuracy_score(y_test, y_pred)
    # with labels=[1, 0] the raveled confusion matrix unpacks as tp, fn, fp, tn
    tp, fn, fp, tn = confusion_matrix(y_test, y_pred, labels=[1, 0]).ravel()
    precision_rate = tp / (tp + fp)
    recall_rate = tp / (tp + fn)
    # F1 = 2 * (precision * recall) / (precision + recall)
    f1_score = 2 * (precision_rate * recall_rate) / (precision_rate + recall_rate)
    results = model_train.evals_result()
    epochs = len(results['validation_0']['error'])
    x_axis = range(0, epochs)
    model_col2.subheader("**Model Accuracy**")
    model_col2.write(round(accuracy * 100, 2))
    model_col2.subheader("**Model Precision**")
    model_col2.write(round(precision_rate * 100, 2))
    model_col2.subheader("**Model Recall**")
    model_col2.write(round(recall_rate * 100, 2))
    model_col2.subheader("**Model F1 Score**")
    model_col2.write(round(f1_score * 100, 2))
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=list(x_axis), y=results['validation_0']['logloss'], name='Train'))
    fig.add_trace(go.Scatter(x=list(x_axis), y=results['validation_1']['logloss'], name='Test'))
    fig.update_layout(title='<b>Model Loss</b>',
                      margin=dict(l=1, r=1, b=0),
                      height=400,
                      width=600,
                      xaxis_title="Epochs",
                      yaxis_title="Model Loss")
    model_chart_error_loss = fig
    model_col3.plotly_chart(model_chart_error_loss)
    st.text("Note: All the prediction values are displayed in percentage (%).")
with explainable:
    st.header("**Explaining the Model**")
    st.markdown("""
In the section above, the XGBoost model was trained to predict heart disease and produced output scores.
With the initial model parameters, an **F1 score** of `""" + str(round(f1_score * 100, 2)) + """%` was
achieved on the validation/test data. However, some open questions easily come to mind:

1. Which features in the data did the model **think are most important**?
2. How did each **feature** in the data **affect a particular prediction**?
3. How does each feature **affect** the model's predictions **over a larger dataset**?

In the field of **Explainable AI**, [explainability](
https://en.wikipedia.org/wiki/Explainable_artificial_intelligence) is defined as "the collection of
features of the interpretable domain, that have contributed for a given example to produce a decision (e.g.,
classification or regression)". If algorithms meet these requirements, they provide a basis for justifying
decisions, tracking and thereby verifying them, improving the algorithms, and exploring new facts.

Using the model trained above, let us try to answer these three basic questions.
""")
    feature_dict = dict(enumerate(encoded_df_column_list))
    features_list = encoded_df_column_list
    st.markdown("### **Permutation Importance**")
    pi_col1, pi_col2 = st.columns(2)
    pi_col1.markdown("""
**_Which features have the biggest impact on the predictions?_**

This way of measuring feature importance is called Permutation Importance. It is fast to calculate, easy to
understand, and is computed on the trained model.

The values at the top are the most important features, and those at the bottom matter least. On our
dataset _(based on the initial model parameters)_, the top 3 most important features are
`number_of_major_vessels`, `cholestrol` and `st_depression_rt_rest`.

The model considers blood disorders such as thalassemia and higher cholesterol levels to be among the key
indicators of heart disease in a person. According to the [NHS - UK website](
https://www.nhs.uk/conditions/cardiovascular-disease/), heart-related diseases are often caused by
pre-existing conditions in a person's blood and blood vessels.
""")
    # shuffle one feature column at a time and measure the drop in the model's score
    perm = PermutationImportance(model_train, random_state=1).fit(X_test, y_test)
    # eli5.show_weights returns an IPython HTML object; .data holds the raw HTML string
    permutation_imp_chart = eli5.show_weights(perm, feature_names=list(feature_dict.values())).data
    pi_col2.markdown(permutation_imp_chart.replace('\n', ''), unsafe_allow_html=True)
st.markdown("### **Partial Dependence Plots**")
pp_col1, pp_col2 = st.columns(2)
selected_feature = pp_col2.selectbox('Feature Attributes', features_list, index=5)
pp_col1.markdown("""
_How a feature effects a prediction?_
Partial Dependence Plots (PDP) show the dependence between the target response and a set of input features of
interest, marginalizing over the values of all other input features (the ‘complement’ features). Intuitively,
we can interpret the partial dependence as the expected target response as a function of the input features
of interest. PDP are also used in Google's [What-If Tool](
https://pair-code.github.io/what-if-tool/learn/tutorials/walkthrough/). The target here is to predict the
heart related diseases.
In PDPs, the y-axis or the feature column predicts the **change in prediction** from what it would be
predicted at the baseline or left-most value.
Let's look into one of the feature: `number_of_major_vessels`. With the increase in the number of vessels,
the model thinks the probability of having a heart diseases decreases.
For feature: `max_heart_rate` the chance of having a heart disease increases with the increase in the heart
rate. """)
    X_test_df = pd.DataFrame(X_test).rename(columns=feature_dict)
    # Partial dependence plot for the selected feature. This is a sketch assuming pdpbox's
    # classic v0.2 API (pdp.pdp_isolate / pdp.pdp_plot); newer pdpbox releases renamed
    # these helpers, which may be why this block was originally commented out.
    pdp_dist = pdp.pdp_isolate(model=model,
                               dataset=X_test_df,
                               model_features=list(features_list),
                               feature=selected_feature)
    pdp_fig, _ = pdp.pdp_plot(pdp_dist, selected_feature)
    pp_col2.pyplot(pdp_fig, bbox_inches='tight')
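    # Alternative sketch in case pdpbox is unavailable (assumes scikit-learn >= 1.0):
    #   from sklearn.inspection import PartialDependenceDisplay
    #   disp = PartialDependenceDisplay.from_estimator(model, X_test_df, [selected_feature])
    #   pp_col2.pyplot(disp.figure_)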
st.markdown("### **SHAP (SHapley Additive exPlanations)** ")
st.markdown("""
_How much a prediction was driven by the fact that a person's_ `max_heart_rate` _is greater than 120?_
A prediction can be explained by assuming that each feature value of the instance is a “player” in a game
where the prediction is the payout. [Shapley values](
https://christophm.github.io/interpretable-ml-book/shapley.html) – a method from coalition game theory –
tells us how to fairly distribute the “payout” among the features. SHAP values interpret the impact of having
a certain value for a given feature in comparison to the prediction we'd make if that feature took some
baseline value.
The SHAP values provide two great advantages:
- **Global interpretability**: The SHAP values can show how much each predictor contributes,
either positively or negatively, to the target variable. This is like the partial dependence plot but it is
able to show the positive or negative relationship for each variable with the target. - **Local
interpretability**: Each observation gets its own set of SHAP values. This greatly increases its
transparency. We can explain why a case receives its prediction and the contributions of the predictors.
Traditional variable importance algorithms only show the results across the entire population but not on each
individual case. The local interpretability enables us to pinpoint and contrast the impacts of the factors.
""")
    shap_col1, shap_col2 = st.columns((1.5, 2))
    shap_col1.markdown("""The chart on the right-hand side shows the individual feature contributions towards
the model's predicted output.

If you select `Person: 1` from the selection box, the model predicted `-2.74`, whereas the base value is
`0.50`. Feature values pushing the prediction higher are in _pink_, and their visual size shows the magnitude
of the feature's effect. Feature values decreasing the prediction are in _blue_. The biggest impact comes
from `number_of_major_vessels` being `2`. As we found in the partial dependence plot, a higher number of
major blood vessels decreases the chance of having a heart-related disease.

However, if you interpret `Person: 2`, the model predicted a SHAP value of `+1.32` against the base value of
`0.50`. This person is at a high risk of having a heart disease, and the features contributing most to this
score are `sex_male` and `thalassemia_reversable_defect`. This prediction is plausible, as being a male with
thalassemia does increase the chance of heart disease.
""")
    select_person = shap_col2.selectbox('Select the Person', range(1, len(X_test)), index=0)
    # shap force plot for the selected person
    select_person_row = pd.DataFrame(X_test).rename(columns=feature_dict).iloc[[select_person]]
    def plot_force_shap_values(model, patient_data):
        explainer = shap.TreeExplainer(model)
        shap_values = explainer.shap_values(patient_data)
        shap.initjs()
        return shap.force_plot(explainer.expected_value,
                               shap_values,
                               patient_data,
                               matplotlib=True,
                               show=False,
                               feature_names=list(feature_dict.values()),
                               text_rotation=10)
    shap_plt = plot_force_shap_values(model, select_person_row.values)
    shap_col2.pyplot(shap_plt)
    plt.clf()
    shap_col3, shap_col4 = st.columns(2)
    shap_col4.markdown("""
The **SHAP summary plot** provides an overall view of the feature contributions across a larger set of data.
The summary plot on the left-hand side has many dots. Each dot has the following characteristics:

- Vertical location shows which feature it depicts
- Color shows whether that feature was high or low for that row of the dataset
- Horizontal location shows whether the effect of that value caused a higher or lower prediction

In our heart prediction model, `thalassemia_normal` does not contribute much to the overall model prediction.
Features like `age` might contribute to an increase in prediction _(the older the person, the higher the
chance of disease)_ in specific cases, but from a bird's-eye view across the whole dataset they do not play
a significant role. The same goes for `fasting_blood_sugar` and `cholestrol`.

SHAP plots give us an overall understanding of the feature contributions across a larger dataset, which in
turn helps us take an informed approach to feature engineering.
""")
    # summary plot across the whole test set
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X_test)
    # shap.summary_plot draws onto the global matplotlib figure
    summary_plot = shap.summary_plot(shap_values, X_test, feature_names=list(feature_dict.values()))
    shap_col3.pyplot(summary_plot)
    plt.clf()