Skip to content

Commit

Permalink
using tfa
Browse files Browse the repository at this point in the history
  • Loading branch information
nmcardoso committed Aug 22, 2023
1 parent 86ed7ad commit 85316a4
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
4 changes: 3 additions & 1 deletion mergernet/estimators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Callable, List, Tuple, Union

import tensorflow as tf
import tensorflow_addons as tfa

from mergernet.core.constants import RANDOM_SEED
from mergernet.core.experiment import Experiment
Expand Down Expand Up @@ -175,7 +176,8 @@ def get_dataaug_block(

def get_metric(self, metric: str):
if metric == 'f1':
return tf.keras.metrics.F1Score(name='f1', average='weighted')
return tfa.metrics.F1Score(name='f1', average='weighted')
# return tf.keras.metrics.F1Score(name='f1', average='weighted') # tf 2.13
elif metric == 'precision':
return tf.keras.metrics.Precision(name='precision')
elif metric == 'recall':
Expand Down
2 changes: 2 additions & 0 deletions mergernet/jobs/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
import tensorflow_addons as tfa

tfa.register_all()

0 comments on commit 85316a4

Please sign in to comment.