-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
integrate feature importance analysis into comparing (#21)
* init feature_importance * integrate result into comparing result * changed the title of plots * changed for flask8 * resolved bugs and added tests for best_model.csv * updated the test file * clear for lint
- Loading branch information
1 parent
919f2a9
commit 59ec798
Showing
18 changed files
with
1,126 additions
and
438 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,4 +6,6 @@ env | |
*.pkl | ||
*.ipynb | ||
__pycache__ | ||
.DS_Store | ||
.DS_Store | ||
tool_test_output.html | ||
tool_test_output.json |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
import base64 | ||
import logging | ||
import os | ||
|
||
import matplotlib.pyplot as plt | ||
|
||
import pandas as pd | ||
|
||
from pycaret.classification import ClassificationExperiment | ||
from pycaret.regression import RegressionExperiment | ||
|
||
logging.basicConfig(level=logging.DEBUG) | ||
LOG = logging.getLogger(__name__) | ||
|
||
|
||
class FeatureImportanceAnalyzer: | ||
def __init__( | ||
self, | ||
task_type, | ||
output_dir, | ||
data_path=None, | ||
data=None, | ||
target_col=None): | ||
|
||
if data is not None: | ||
self.data = data | ||
LOG.info("Data loaded from memory") | ||
else: | ||
self.target_col = target_col | ||
self.data = pd.read_csv(data_path, sep=None, engine='python') | ||
self.data.columns = self.data.columns.str.replace('.', '_') | ||
self.data = self.data.fillna(self.data.median(numeric_only=True)) | ||
self.task_type = task_type | ||
self.target = self.data.columns[int(target_col) - 1] | ||
self.exp = ClassificationExperiment() \ | ||
if task_type == 'classification' \ | ||
else RegressionExperiment() | ||
self.plots = {} | ||
self.output_dir = output_dir | ||
|
||
def setup_pycaret(self): | ||
LOG.info("Initializing PyCaret") | ||
setup_params = { | ||
'target': self.target, | ||
'session_id': 123, | ||
'html': True, | ||
'log_experiment': False, | ||
'system_log': False | ||
} | ||
LOG.info(self.task_type) | ||
LOG.info(self.exp) | ||
self.exp.setup(self.data, **setup_params) | ||
|
||
def save_coefficients(self): | ||
model = self.exp.create_model('lr') | ||
coef_df = pd.DataFrame({ | ||
'Feature': self.data.columns.drop(self.target), | ||
'Coefficient': model.coef_[0] | ||
}) | ||
coef_html = coef_df.to_html(index=False) | ||
return coef_html | ||
|
||
def save_tree_importance(self): | ||
model = self.exp.create_model('rf') | ||
importances = model.feature_importances_ | ||
feature_importances = pd.DataFrame({ | ||
'Feature': self.data.columns.drop(self.target), | ||
'Importance': importances | ||
}).sort_values(by='Importance', ascending=False) | ||
plt.figure(figsize=(10, 6)) | ||
plt.barh( | ||
feature_importances['Feature'], | ||
feature_importances['Importance']) | ||
plt.xlabel('Importance') | ||
plt.title('Feature Importance (Random Forest)') | ||
plot_path = os.path.join( | ||
self.output_dir, | ||
'tree_importance.png') | ||
plt.savefig(plot_path) | ||
plt.close() | ||
self.plots['tree_importance'] = plot_path | ||
|
||
def save_shap_values(self): | ||
model = self.exp.create_model('lightgbm') | ||
import shap | ||
explainer = shap.Explainer(model) | ||
shap_values = explainer.shap_values( | ||
self.data.drop(columns=[self.target])) | ||
shap.summary_plot(shap_values, self.data.drop( | ||
columns=[self.target]), show=False) | ||
plt.title('Shap (LightGBM)') | ||
plot_path = os.path.join( | ||
self.output_dir, 'shap_summary.png') | ||
plt.savefig(plot_path) | ||
plt.close() | ||
self.plots['shap_summary'] = plot_path | ||
|
||
def generate_feature_importance(self): | ||
coef_html = self.save_coefficients() | ||
self.save_tree_importance() | ||
self.save_shap_values() | ||
return coef_html | ||
|
||
def encode_image_to_base64(self, img_path): | ||
with open(img_path, 'rb') as img_file: | ||
return base64.b64encode(img_file.read()).decode('utf-8') | ||
|
||
def generate_html_report(self, coef_html): | ||
LOG.info("Generating HTML report") | ||
|
||
# Read and encode plot images | ||
plots_html = "" | ||
for plot_name, plot_path in self.plots.items(): | ||
encoded_image = self.encode_image_to_base64(plot_path) | ||
plots_html += f""" | ||
<div class="plot" id="{plot_name}"> | ||
<h2>Feature importance analysis from a | ||
trained Random Forest</h2> | ||
<h3>{'Use gini impurity for' | ||
'calculating feature importance for classification' | ||
'and Variance Reduction for regression' | ||
if plot_name == 'tree_importance' | ||
else 'SHAP Summary from a trained lightgbm'}</h3> | ||
<img src="data:image/png;base64, | ||
{encoded_image}" alt="{plot_name}"> | ||
</div> | ||
""" | ||
|
||
# Generate HTML content with tabs | ||
html_content = f""" | ||
<h1>PyCaret Feature Importance Report</h1> | ||
<div id="coefficients" class="tabcontent"> | ||
<h2>Coefficients (based on a trained | ||
{'Logistic Regression' | ||
if self.task_type == 'classification' | ||
else 'Linear Regression'} Model)</h2> | ||
<div>{coef_html}</div> | ||
</div> | ||
{plots_html} | ||
""" | ||
|
||
return html_content | ||
|
||
def run(self): | ||
LOG.info("Running feature importance analysis") | ||
self.setup_pycaret() | ||
coef_html = self.generate_feature_importance() | ||
html_content = self.generate_html_report(coef_html) | ||
LOG.info("Feature importance analysis completed") | ||
return html_content | ||
|
||
|
||
if __name__ == "__main__": | ||
import argparse | ||
parser = argparse.ArgumentParser(description="Feature Importance Analysis") | ||
parser.add_argument( | ||
"--data_path", type=str, help="Path to the dataset") | ||
parser.add_argument( | ||
"--target_col", type=int, | ||
help="Index of the target column (1-based)") | ||
parser.add_argument( | ||
"--task_type", type=str, | ||
choices=["classification", "regression"], | ||
help="Task type: classification or regression") | ||
parser.add_argument( | ||
"--output_dir", | ||
type=str, | ||
help="Directory to save the outputs") | ||
args = parser.parse_args() | ||
|
||
analyzer = FeatureImportanceAnalyzer( | ||
args.data_path, args.target_col, | ||
args.task_type, args.output_dir) | ||
analyzer.run() |
Oops, something went wrong.