Skip to content

Commit

Permalink
added model selection and tested (#20)
Browse files Browse the repository at this point in the history
* added model selection and tested

* changed expected test outputs

since the new docker image/pycaret with extra models, the results have been changed
  • Loading branch information
qchiujunhao authored Jul 25, 2024
1 parent 88a264c commit 919f2a9
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 104 deletions.
6 changes: 5 additions & 1 deletion tools/base_model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,11 @@ def setup_pycaret(self):

def train_model(self):
LOG.info("Training and selecting the best model")
self.best_model = self.exp.compare_models()
if hasattr(self, 'models') and self.models is not None:
self.best_model = self.exp.compare_models(
include=self.models)
else:
self.best_model = self.exp.compare_models()
self.results = self.exp.pull()

def save_model(self):
Expand Down
16 changes: 13 additions & 3 deletions tools/pycaret_train.xml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@
<command>
<![CDATA[
python $__tool_directory__/pycaret_train.py --input_file $input_file --target_col $target_feature --output_dir "`pwd`" --model_type $model_type
#if $model_type == "classification"
#if $classification_models
--models $classification_models
#end if
#end if
#if $model_type == "regression"
#if $regression_models
--models $regression_models
#end if
#end if
#if $customize_defaults == "true"
#if $train_size
--train_size $train_size
Expand Down Expand Up @@ -42,12 +52,12 @@
<inputs>
<param name="input_file" type="data" format="csv,tabular" label="Input Dataset (CSV or TSV)" />
<param name="target_feature" multiple="false" type="data_column" use_header_names="true" data_ref="input_file" label="Select the target column:" />
<!-- <conditional name="model_selection"> -->
<conditional name="model_selection">
<param name="model_type" type="select" label="Task">
<option value="classification">classification</option>
<option value="regression">regression</option>
</param>
<!-- <when value="classification">
<when value="classification">
<param name="classification_models" type="select" multiple="true" label="Only Select Classification Models if you don't want to compare all models">
<option value="lr">Logistic Regression</option>
<option value="knn">K Neighbors Classifier</option>
Expand Down Expand Up @@ -98,7 +108,7 @@
<option value="catboost">CatBoost Regressor</option>
</param>
</when>
</conditional> -->
</conditional>
<conditional name="advanced_settings">
<param name="customize_defaults" type="select" label="Customize Default Settings?" help="Select yes if you want to customize the default settings of the experiment.">
<option value="false" selected="true">No</option>
Expand Down
156 changes: 67 additions & 89 deletions tools/test-data/expected_comparison_result_regression.html

Large diffs are not rendered by default.

Loading

0 comments on commit 919f2a9

Please sign in to comment.