Skip to content

Commit

Permalink
resolved a bug in feature importance ana
Browse files Browse the repository at this point in the history
  • Loading branch information
qchiujunhao committed Oct 18, 2024
1 parent d21412f commit 590f692
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 11 deletions.
17 changes: 13 additions & 4 deletions tools/base_model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,16 @@ def __init__(
def load_data(self):
LOG.info(f"Loading data from {self.input_file}")
self.data = pd.read_csv(self.input_file, sep=None, engine='python')
self.data = self.data.apply(pd.to_numeric, errors='coerce')
self.data.columns = self.data.columns.str.replace('.', '_')

numeric_cols = self.data.select_dtypes(include=['number']).columns
non_numeric_cols = self.data.select_dtypes(exclude=['number']).columns

self.data[numeric_cols] = self.data[numeric_cols].apply(pd.to_numeric, errors='coerce')

if len(non_numeric_cols) > 0:
LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}")

names = self.data.columns.to_list()
target_index = int(self.target_col)-1
self.target = names[target_index]
Expand All @@ -71,12 +80,12 @@ def load_data(self):
else:
# Default strategy if not specified
self.data = self.data.fillna(self.data.median(numeric_only=True))
self.data.columns = self.data.columns.str.replace('.', '_')


if self.test_file:
LOG.info(f"Loading test data from {self.test_file}")
self.test_data = pd.read_csv(self.test_file, sep=None, engine='python')
self.test_data = self.test_data.apply(pd.to_numeric, errors='coerce')
self.test_data = self.test_data[numeric_cols].apply(pd.to_numeric, errors='coerce')
self.test_data.columns = self.test_data.columns.str.replace('.', '_')


Expand Down Expand Up @@ -221,7 +230,7 @@ def save_html_report(self):
"""

analyzer = FeatureImportanceAnalyzer(
data=self.data,
data=self.exp.X_transformed,
target_col=self.target_col,
task_type=self.task_type,
output_dir=self.output_dir)
Expand Down
7 changes: 5 additions & 2 deletions tools/feature_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,11 @@ def setup_pycaret(self):
def save_tree_importance(self):
model = self.exp.create_model('rf')
importances = model.feature_importances_
processed_features = self.exp.get_config('X_transformed').columns
LOG.debug(f"Feature importances: {importances}")
LOG.debug(f"Features: {processed_features}")
feature_importances = pd.DataFrame({
'Feature': self.data.columns.drop(self.target),
'Feature': processed_features,
'Importance': importances
}).sort_values(by='Importance', ascending=False)
plt.figure(figsize=(10, 6))
Expand All @@ -77,7 +80,7 @@ def save_tree_importance(self):
self.output_dir,
'tree_importance.png')
plt.savefig(plot_path)
plt.close()
plt.close()
self.plots['tree_importance'] = plot_path

def save_shap_values(self):
Expand Down
9 changes: 4 additions & 5 deletions tools/pycaret_train.xml
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,11 @@
#if $fix_imbalance
--fix_imbalance
#end if
#end if
--model_type $model_type
#if $test_file
--test_file $test_file &&
#end if
#if $test_file
--test_file $test_file
#end if
--model_type $model_type &&
mkdir -p $comparison_result.extra_files_path &&
cp -r best_model.csv $comparison_result.extra_files_path
]]>
Expand Down

0 comments on commit 590f692

Please sign in to comment.