Skip to content

Commit

Permalink
updates for pr actions
Browse files Browse the repository at this point in the history
1. formart for planemo test and lint, and flask8 lint
2. found and fixed bugs during planemo test
3. added test-data files
  • Loading branch information
qchiujunhao committed Jul 16, 2024
1 parent 5696cf1 commit 9bc0a76
Show file tree
Hide file tree
Showing 18 changed files with 2,751 additions and 128 deletions.
403 changes: 403 additions & 0 deletions .github/workflows/pr.yaml

Large diffs are not rendered by default.

6 changes: 2 additions & 4 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
.env
env
*.csv
*.tsv
*.log
*.png
*.txt
*.pkl
*.html
*.ipynb
__pycache__
__pycache__
.DS_Store
Binary file added galaxy-master.tar.gz
Binary file not shown.
74 changes: 49 additions & 25 deletions tools/base_model_trainer.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import sys
import pandas as pd
import os
import logging
import logging
import base64

logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger(__name__)


class BaseModelTrainer:

def __init__(self, input_file, target_col, output_dir, **kwargs):
self.exp = None # This will be set in the subclass
self.input_file = input_file
Expand All @@ -21,7 +22,7 @@ def __init__(self, input_file, target_col, output_dir, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
self.setup_params = {}

LOG.info(f"Model kwargs: {self.__dict__}")

def load_data(self):
Expand All @@ -48,26 +49,35 @@ def setup_pycaret(self):
if hasattr(self, 'normalize') and self.normalize is not None:
self.setup_params['normalize'] = self.normalize

if hasattr(self, 'feature_selection') and self.feature_selection is not None:
if hasattr(self, 'feature_selection') and \
self.feature_selection is not None:
self.setup_params['feature_selection'] = self.feature_selection

if hasattr(self, 'cross_validation') and self.cross_validation is not None and self.cross_validation == False:
if hasattr(self, 'cross_validation') and \
self.cross_validation is not None \
and self.cross_validation is False:
self.setup_params['cross_validation'] = self.cross_validation

if hasattr(self, 'cross_validation') and self.cross_validation is not None:
if hasattr(self, 'cross_validation') and \
self.cross_validation is not None:
if hasattr(self, 'cross_validation_folds'):
self.setup_params['fold'] = self.cross_validation_folds

if hasattr(self, 'remove_outliers') and self.remove_outliers is not None:
if hasattr(self, 'remove_outliers') and \
self.remove_outliers is not None:
self.setup_params['remove_outliers'] = self.remove_outliers

if hasattr(self, 'remove_multicollinearity') and self.remove_multicollinearity is not None:
self.setup_params['remove_multicollinearity'] = self.remove_multicollinearity
if hasattr(self, 'remove_multicollinearity') and \
self.remove_multicollinearity is not None:
self.setup_params['remove_multicollinearity'] = \
self.remove_multicollinearity

if hasattr(self, 'polynomial_features') and self.polynomial_features is not None:
if hasattr(self, 'polynomial_features') and \
self.polynomial_features is not None:
self.setup_params['polynomial_features'] = self.polynomial_features

if hasattr(self, 'fix_imbalance') and self.fix_imbalance is not None:
if hasattr(self, 'fix_imbalance') and \
self.fix_imbalance is not None:
self.setup_params['fix_imbalance'] = self.fix_imbalance

LOG.info(self.setup_params)
Expand All @@ -93,17 +103,25 @@ def save_html_report(self):
LOG.info("Saving HTML report")

model_name = type(self.best_model).__name__

excluded_params = ['html', 'log_experiment', 'system_log']
filtered_setup_params = {k: v for k, v in self.setup_params.items() if k not in excluded_params}
setup_params_table = pd.DataFrame(list(filtered_setup_params.items()), columns=['Parameter', 'Value'])

filtered_setup_params = {
k: v
for k, v in self.setup_params.items() if k not in excluded_params
}
setup_params_table = pd.DataFrame(
list(filtered_setup_params.items()),
columns=['Parameter', 'Value'])
# Save model summary
best_model_params = pd.DataFrame(self.best_model.get_params().items(), columns=['Parameter', 'Value'])
best_model_params.to_csv(os.path.join(self.output_dir, 'best_model.csv'), index=False)
best_model_params = pd.DataFrame(
self.best_model.get_params().items(),
columns=['Parameter', 'Value'])
best_model_params.to_csv(
os.path.join(self.output_dir, 'best_model.csv'),
index=False)

# Save comparison results
self.results.to_csv(os.path.join(self.output_dir, "comparison_results.csv"))
self.results.to_csv(os.path.join(
self.output_dir, "comparison_results.csv"))

# Read and encode plot images
plots_html = ""
Expand All @@ -112,7 +130,8 @@ def save_html_report(self):
plots_html += f"""
<div class="plot">
<h3>{plot_name.capitalize()}</h3>
<img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">
<img src="data:image/png;base64,
{encoded_image}" alt="{plot_name}">
</div>
"""

Expand All @@ -122,7 +141,8 @@ def save_html_report(self):
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<meta name="viewport" content="width=device-width,
initial-scale=1.0">
<title>PyCaret Model Training Report</title>
<style>
body {{
Expand Down Expand Up @@ -179,16 +199,19 @@ def save_html_report(self):
<h2>Setup Parameters</h2>
<table>
<tr><th>Parameter</th><th>Value</th></tr>
{setup_params_table.to_html(index=False, header=False, classes='table')}
{setup_params_table.to_html(index=False,
header=False, classes='table')}
</table>
<h2>Best Model: {model_name}</h2>
<table>
<tr><th>Parameter</th><th>Value</th></tr>
{best_model_params.to_html(index=False, header=False, classes='table')}
{best_model_params.to_html(index=False,
header=False, classes='table')}
</table>
<h2>Comparison Results</h2>
<table>
{self.results.to_html(index=False, classes='table')}
{self.results.to_html(index=False,
classes='table')}
</table>
<h2>Plots</h2>
{plots_html}
Expand All @@ -197,12 +220,13 @@ def save_html_report(self):
</html>
"""

with open(os.path.join(self.output_dir, "comparison_result.html"), "w") as file:
with open(os.path.join(
self.output_dir, "comparison_result.html"), "w") as file:
file.write(html_content)

def save_dashboard(self):
raise NotImplementedError("Subclasses should implement this method")

def run(self):
self.load_data()
self.setup_pycaret()
Expand Down
Loading

0 comments on commit 9bc0a76

Please sign in to comment.