Skip to content

Commit

Permalink
include num_classes
Browse files Browse the repository at this point in the history
  • Loading branch information
nmcardoso committed Aug 23, 2023
1 parent b455bcf commit 1937843
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion mergernet/estimators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,11 @@ def get_dataaug_block(

def get_metric(self, metric: str):
if metric == 'f1':
return tfa.metrics.F1Score(name='f1', average='weighted')
return tfa.metrics.F1Score(
num_classes=self.dataset.config.n_classes,
name='f1',
average='weighted'
)
# return tf.keras.metrics.F1Score(name='f1', average='weighted') # tf 2.13
elif metric == 'precision':
return tf.keras.metrics.Precision(name='precision')
Expand Down

0 comments on commit 1937843

Please sign in to comment.