Skip to content

Commit

Permalink
Regression (#3)
Browse files Browse the repository at this point in the history
* add regression

restructure classes

* better html report
  • Loading branch information
qchiujunhao authored Jun 13, 2024
1 parent 2ea7a84 commit 96251f6
Show file tree
Hide file tree
Showing 7 changed files with 272 additions and 122 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ env
*.txt
*.pkl
*.html
*.ipynb
*.ipynb
__pycache__
164 changes: 164 additions & 0 deletions tools/base_model_trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import sys
import pandas as pd
import os
import logging
import base64

logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger(__name__)

class BaseModelTrainer:
def __init__(self, input_file, target_col, output_dir):
self.exp = None # This will be set in the subclass
self.input_file = input_file
self.target_col = target_col
self.output_dir = output_dir
self.data = None
self.target = None
self.best_model = None
self.results = None
self.plots = {}

def load_data(self):
LOG.info(f"Loading data from {self.input_file}")
self.data = pd.read_csv(self.input_file, sep=None, engine='python')
names = self.data.columns.to_list()
self.target = names[int(self.target_col)-1]
self.data = self.data.fillna(self.data.median(numeric_only=True))
self.data.columns = self.data.columns.str.replace('.', '_')

def setup_pycaret(self):
LOG.info("Initializing PyCaret")
self.exp.setup(self.data, target=self.target,
session_id=123, html=True,
log_experiment=False, system_log=False)

def train_model(self):
LOG.info("Training and selecting the best model")
self.best_model = self.exp.compare_models()
self.results = self.exp.pull()

def save_model(self):
LOG.info("Saving the model")
self.exp.save_model(self.best_model, "model.pkl")

def generate_plots(self):
raise NotImplementedError("Subclasses should implement this method")

def encode_image_to_base64(self, img_path):
with open(img_path, 'rb') as img_file:
return base64.b64encode(img_file.read()).decode('utf-8')

def save_html_report(self):
LOG.info("Saving HTML report")

model_name = type(self.best_model).__name__

# Save model summary
best_model_params = pd.DataFrame(self.best_model.get_params().items(), columns=['Parameter', 'Value'])
best_model_params.to_csv(os.path.join(self.output_dir, 'best_model.csv'), index=False)

# Save comparison results
self.results.to_csv(os.path.join(self.output_dir, "comparison_results.csv"))

# Read and encode plot images
plots_html = ""
for plot_name, plot_path in self.plots.items():
encoded_image = self.encode_image_to_base64(plot_path)
plots_html += f"""
<div class="plot">
<h3>{plot_name.capitalize()}</h3>
<img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">
</div>
"""

# Generate HTML content
html_content = f"""
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>PyCaret Model Training Report</title>
<style>
body {{
font-family: Arial, sans-serif;
margin: 0;
padding: 20px;
background-color: #f4f4f4;
}}
.container {{
max-width: 800px;
margin: auto;
background: white;
padding: 20px;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
}}
h1 {{
text-align: center;
color: #333;
}}
h2 {{
border-bottom: 2px solid #4CAF50;
color: #4CAF50;
padding-bottom: 5px;
}}
table {{
width: 100%;
border-collapse: collapse;
margin: 20px 0;
}}
table, th, td {{
border: 1px solid #ddd;
}}
th, td {{
padding: 8px;
text-align: left;
}}
th {{
background-color: #4CAF50;
color: white;
}}
.plot {{
text-align: center;
margin: 20px 0;
}}
.plot img {{
max-width: 100%;
height: auto;
}}
</style>
</head>
<body>
<div class="container">
<h1>PyCaret Model Training Report</h1>
<h2>Best Model: {model_name}</h2>
<table>
<tr><th>Parameter</th><th>Value</th></tr>
{best_model_params.to_html(index=False, header=False, classes='table')}
</table>
<h2>Comparison Results</h2>
<table>
{self.results.to_html(index=False, classes='table')}
</table>
<h2>Plots</h2>
{plots_html}
</div>
</body>
</html>
"""

with open(os.path.join(self.output_dir, "comparison_result.html"), "w") as file:
file.write(html_content)

def save_dashboard(self):
raise NotImplementedError("Subclasses should implement this method")

def run(self):
self.load_data()
self.setup_pycaret()
self.train_model()
self.save_model()
self.generate_plots()
self.save_html_report()
self.save_dashboard()
66 changes: 65 additions & 1 deletion tools/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger(__name__)

def generate_dashboard(
def generate_classifier_explainer_dashboard(
exp,
estimator,
display_format: str = "dash",
Expand Down Expand Up @@ -76,4 +76,68 @@ def generate_dashboard(
)
return ExplainerDashboard(
explainer, mode=display_format, contributions=False, whatif=False, **dashboard_kwargs
)

def generate_regression_explainer_dashboard(
exp,
estimator,
display_format: str = "dash",
dashboard_kwargs: Optional[Dict[str, Any]] = None,
run_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
):
"""
This function is changed from pycaret.regression.oop.dashboard()
This function generates the interactive dashboard for a trained model. The
dashboard is implemented using ExplainerDashboard (explainerdashboard.readthedocs.io)
estimator: scikit-learn compatible object
Trained model object
display_format: str, default = 'dash'
Render mode for the dashboard. The default is set to ``dash`` which will
render a dashboard in browser. There are four possible options:
- 'dash' - displays the dashboard in browser
- 'inline' - displays the dashboard in the jupyter notebook cell.
- 'jupyterlab' - displays the dashboard in jupyterlab pane.
- 'external' - displays the dashboard in a separate tab. (use in Colab)
dashboard_kwargs: dict, default = {} (empty dict)
Dictionary of arguments passed to the ``ExplainerDashboard`` class.
run_kwargs: dict, default = {} (empty dict)
Dictionary of arguments passed to the ``run`` method of ``ExplainerDashboard``.
**kwargs:
Additional keyword arguments to pass to the ``ClassifierExplainer`` or
``RegressionExplainer`` class.
Returns:
ExplainerDashboard
"""

dashboard_kwargs = dashboard_kwargs or {}
run_kwargs = run_kwargs or {}

from explainerdashboard import ExplainerDashboard, RegressionExplainer

# Replaceing chars which dash doesnt accept for column name `.` , `{`, `}`
X_test_df = exp.X_test_transformed.copy()
X_test_df.columns = [
col.replace(".", "__").replace("{", "__").replace("}", "__")
for col in X_test_df.columns
]
explainer = RegressionExplainer(
estimator, X_test_df, exp.y_test_transformed, **kwargs
)
return ExplainerDashboard(
explainer, mode=display_format, contributions=False, whatif=False, shap_interaction=False, decision_trees=False, **dashboard_kwargs
)
129 changes: 11 additions & 118 deletions tools/pycaret_classification.py
Original file line number Diff line number Diff line change
@@ -1,130 +1,23 @@
import sys
import pandas as pd
from base_model_trainer import BaseModelTrainer
from pycaret.classification import ClassificationExperiment
import os
import logging
from dashboard import generate_dashboard
from jinja_report.generate_report import main as generate_report
import base64
from dashboard import generate_classifier_explainer_dashboard

logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger(__name__)

class ModelTrainer:
class ClassificationModelTrainer(BaseModelTrainer):
def __init__(self, input_file, target_col, output_dir):
super().__init__(input_file, target_col, output_dir)
self.exp = ClassificationExperiment()
self.input_file = input_file
self.target_col = target_col
self.output_dir = output_dir
self.data = None
self.target = None
self.best_model = None
self.results = None
self.plots = {}

def load_data(self):
LOG.info(f"Loading data from {self.input_file}")
self.data = pd.read_csv(self.input_file, sep=None, engine='python')
names = self.data.columns.to_list()
self.target = names[int(self.target_col)-1]
self.data = self.data.fillna(self.data.median(numeric_only=True))
self.data.columns = self.data.columns.str.replace('.', '_')

def setup_pycaret(self):
LOG.info("Initializing PyCaret")
self.exp.setup(self.data, target=self.target,
session_id=123, html=True,
log_experiment=False, system_log=False)

def train_model(self):
LOG.info("Training and selecting the best model")
self.best_model = self.exp.compare_models()
self.results = self.exp.pull()

def save_model(self):
LOG.info("Saving the model")
self.exp.save_model(self.best_model, "model")

def save_dashboard(self):
LOG.info("Saving explainer dashboard")
dashboard = generate_classifier_explainer_dashboard(self.exp, self.best_model)
dashboard.save_html("dashboard.html")

def generate_plots(self):
LOG.info("Generating and saving plots")
# Generate PyCaret plots
plots = ['auc', 'confusion_matrix',
'threshold',
'pr', 'error',
'class_report', 'learning',
'calibration', 'vc',
'dimension',
'manifold', 'rfe',
'feature', 'feature_all']
plots = ['auc', 'confusion_matrix', 'threshold', 'pr', 'error', 'class_report', 'learning', 'calibration', 'vc', 'dimension', 'manifold', 'rfe', 'feature', 'feature_all']
for plot_name in plots:
plot_path = self.exp.plot_model(self.best_model, plot=plot_name,
save=True)
plot_path = self.exp.plot_model(self.best_model, plot=plot_name, save=True)
self.plots[plot_name] = plot_path

def encode_image_to_base64(self, img_path):
with open(img_path, 'rb') as img_file:
return base64.b64encode(img_file.read()).decode('utf-8')

def save_html_report(self):
LOG.info("Saving HTML report")

model_name = type(self.best_model).__name__

report_data = {
"title": "PyCaret Model Training Report",
'Best Model': [
{
'type': 'table',
'src': os.path.join(self.output_dir, 'best_model.csv'),
'label': f'Best Model: {model_name}'
}
],
'Comparison Results': [
{
'type': 'table',
'src': os.path.join(self.output_dir, 'comparison_results.csv'),
'label': 'Comparison Result <br> The scoring grid with average cross-validation scores'
}
],
"Plots": []
}

# Save model summary
best_model_params = pd.DataFrame(self.best_model.get_params().items(), columns=['Parameter', 'Value'])
best_model_params.to_csv(os.path.join(self.output_dir, 'best_model.csv'), index=False)

# Save comparison results
self.results.to_csv(os.path.join(self.output_dir, "comparison_results.csv"))

# Add plots to the report data
for plot_name, plot_path in self.plots.items():
encoded_image = self.encode_image_to_base64(plot_path)
report_data['Plots'].append({
'type': 'html',
'src': f'data:image/png;base64,{encoded_image}',
'label': plot_name.capitalize()
})

generate_report(inputs=report_data, outfile=os.path.join(self.output_dir, "comparison_result.html"))

def save_dashboard(self):
LOG.info("Saving explainer dashboard")
dashboard = generate_dashboard(self.exp, self.best_model)
dashboard.save_html("dashboard.html")

def run(self):
self.load_data()
self.setup_pycaret()
self.train_model()
self.save_model()
self.generate_plots()
self.save_html_report()
self.save_dashboard()

if __name__ == "__main__":
input_file = sys.argv[1]
target_col = sys.argv[2]
output_dir = sys.argv[3]

trainer = ModelTrainer(input_file, target_col, output_dir)
trainer.run()
Loading

0 comments on commit 96251f6

Please sign in to comment.