Skip to content

Commit

Permalink
Enhance Testing Coverage for keras/metrics/accuracy_metrics.py (ker…
Browse files Browse the repository at this point in the history
…as-team#19429)

* Add tests  SparseCategoricalAccuracy class

* Add tests for sparsecategoricalaccuracy class

* Add tests for categorical_accuracy

* Add test for matching shapes without squeeze in SparseCategoricalAccuracy

* Add tests for accuracy metrics
  • Loading branch information
Faisal-Alsrheed authored Apr 3, 2024
1 parent 5f31cac commit 4b612a3
Showing 1 changed file with 91 additions and 0 deletions.
91 changes: 91 additions & 0 deletions keras/metrics/accuracy_metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,97 @@ def test_weighted(self):
result = sp_cat_acc_obj.result()
self.assertAllClose(result, 0.3, atol=1e-3)

def test_squeeze_y_true(self):
sp_cat_acc_obj = accuracy_metrics.SparseCategoricalAccuracy(
name="sparse_categorical_accuracy", dtype="float32"
)
# Scenario with 100% accuracy for simplicity.
# y_true is a 2D tensor with shape (3, 1) to test squeeze.
y_true = np.array([[0], [1], [2]])
y_pred = np.array(
[[0.9, 0.05, 0.05], [0.05, 0.9, 0.05], [0.05, 0.05, 0.9]]
)
sp_cat_acc_obj.update_state(y_true, y_pred)
result = sp_cat_acc_obj.result()
self.assertAllClose(result, 1.0, atol=1e-4)

def test_cast_y_pred_dtype(self):
sp_cat_acc_obj = accuracy_metrics.SparseCategoricalAccuracy(
name="sparse_categorical_accuracy", dtype="float32"
)
# Scenario with 100% accuracy for simplicity.
# y_true is a 1D tensor with shape (2,) to test cast.
y_true = np.array([0, 1], dtype=np.int64)
y_pred = np.array([[0.9, 0.1], [0.1, 0.9]], dtype=np.float32)
sp_cat_acc_obj.update_state(y_true, y_pred)
result = sp_cat_acc_obj.result()
self.assertAllClose(result, 1.0, atol=1e-4)

def test_reshape_matches(self):
sp_cat_acc_obj = accuracy_metrics.SparseCategoricalAccuracy(
name="sparse_categorical_accuracy", dtype="float32"
)
# Scenario with 100% accuracy for simplicity.
# y_true is a 2D tensor with shape (2, 1) to test reshape.
y_true = np.array([[0], [0]], dtype=np.int64)
y_pred = np.array(
[[[0.9, 0.1, 0.0], [0.8, 0.15, 0.05]]], dtype=np.float32
)
sp_cat_acc_obj.update_state(y_true, y_pred)
result = sp_cat_acc_obj.result()
self.assertAllClose(result, np.array([1.0, 1.0]))

def test_squeeze_y_true_shape(self):
sp_cat_acc_obj = accuracy_metrics.SparseCategoricalAccuracy(
name="sparse_categorical_accuracy", dtype="float32"
)
# True labels are in the shape (num_samples, 1) should be squeezed.
y_true = np.array([[0], [1], [2]])
y_pred = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
sp_cat_acc_obj.update_state(y_true, y_pred)
result = sp_cat_acc_obj.result()
self.assertAllClose(result, 1.0, atol=1e-4)

def test_cast_y_pred_to_match_y_true_dtype(self):
sp_cat_acc_obj = accuracy_metrics.SparseCategoricalAccuracy(
name="sparse_categorical_accuracy", dtype="float32"
)
# True labels are integers, while predictions are floats.
y_true = np.array([0, 1, 2], dtype=np.int32)
y_pred = np.array(
[[0.9, 0.1, 0.0], [0.0, 0.9, 0.1], [0.1, 0.0, 0.9]],
dtype=np.float64,
)
sp_cat_acc_obj.update_state(y_true, y_pred)
result = sp_cat_acc_obj.result()
self.assertAllClose(result, 1.0, atol=1e-4)

def test_reshape_matches_to_original_y_true_shape(self):
sp_cat_acc_obj = accuracy_metrics.SparseCategoricalAccuracy(
name="sparse_categorical_accuracy", dtype="float32"
)
# True labels have an additional dimension that needs to be squeezed.
y_true = np.array([[0], [1]])
# Predictions must trigger a reshape of matches.
y_pred = np.array([[0.9, 0.1], [0.1, 0.9]])
sp_cat_acc_obj.update_state(y_true, y_pred)
result = sp_cat_acc_obj.result()
self.assertAllClose(result, 1.0, atol=1e-4)

def test_matching_shapes_without_squeeze(self):
sp_cat_acc_obj = accuracy_metrics.SparseCategoricalAccuracy(
name="sparse_categorical_accuracy", dtype="float32"
)
y_true = np.array([2, 1, 0], dtype=np.int32)
y_pred = np.array(
[[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]],
dtype=np.float32,
)
# No need to squeeze or reshape.
sp_cat_acc_obj.update_state(y_true, y_pred)
result = sp_cat_acc_obj.result()
self.assertAllClose(result, 1.0, atol=1e-4)


class TopKCategoricalAccuracyTest(testing.TestCase):
def test_config(self):
Expand Down

0 comments on commit 4b612a3

Please sign in to comment.