Skip to content

Commit

Permalink
Make sklearn visualisations support validation datasets.
Browse files Browse the repository at this point in the history
  • Loading branch information
tlapusan authored and parrt committed Apr 6, 2023
1 parent 8c15ccc commit 113ec46
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
13 changes: 12 additions & 1 deletion dtreeviz/models/sklearn_decision_trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,19 @@ def get_node_feature(self, id) -> int:
return self.tree_model.tree_.feature[id]

def get_node_nsamples_by_class(self, id):
# The commented-out code below returns the nsamples/class straight from the tree metadata. It is faster, but the
# visualisations then cannot be made on new datasets.
# if self.is_classifier():
# return self.tree_model.tree_.value[id][0]

# This code returns the nsamples/class based on the dataset in use, whether training or validation.
if self.is_classifier():
return self.tree_model.tree_.value[id][0]
all_nodes = self.internal + self.leaves
node_value = [node.n_sample_classes() for node in all_nodes if node.id == id]
if self.get_class_weights() is None:
return node_value[0]
else:
return node_value[0] * self.get_class_weights()

def get_prediction(self, id):
if self.is_classifier():
Expand Down
4 changes: 4 additions & 0 deletions dtreeviz/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -1252,6 +1252,10 @@ def _class_leaf_viz(node: ShadowDecTreeNode,
counts = node.class_counts()
prediction = node.prediction_name()

# When using a dataset other than the training dataset, some leaves could have 0 samples.
# Trying to make a pie chart for such a leaf raises a deprecation warning, so nothing is drawn for it.
if sum(counts) == 0:
return
if leaftype == 'pie':
_draw_piechart(counts, size=size, colors=colors, filename=filename, label=f"n={nsamples}\n{prediction}",
graph_colors=graph_colors, fontname=fontname)
Expand Down

0 comments on commit 113ec46

Please sign in to comment.