Skip to content

Commit

Permalink
Adds InvertLogarithmPreprocessorTest to invert the binary RMSLE optim…
Browse files Browse the repository at this point in the history
…ization objective.

PiperOrigin-RevId: 660728458
  • Loading branch information
tf-model-analysis-team authored and tfx-copybara committed Aug 8, 2024
1 parent 78b1889 commit 8922093
Show file tree
Hide file tree
Showing 3 changed files with 217 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@
"""Init module for TensorFlow Model Analysis related preprocssors."""
from tensorflow_model_analysis.metrics.preprocessors import utils
from tensorflow_model_analysis.metrics.preprocessors.image_preprocessors import DecodeImagePreprocessor
from tensorflow_model_analysis.metrics.preprocessors.invert_logarithm_preprocessors import InvertBinaryLogarithmPreprocessor
from tensorflow_model_analysis.metrics.preprocessors.object_detection_preprocessors import BoundingBoxMatchPreprocessor
from tensorflow_model_analysis.metrics.preprocessors.set_match_preprocessors import SetMatchPreprocessor
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Includes preprocessors for log2 inversion transformation."""

from typing import Iterator, Optional

import numpy as np
from tensorflow_model_analysis import constants
from tensorflow_model_analysis.api import types
from tensorflow_model_analysis.metrics import metric_types
from tensorflow_model_analysis.metrics import metric_util
from tensorflow_model_analysis.utils import util

_INVERT_BINARY_LOGARITHM_PREPROCESSOR_BASE_NAME = (
'invert_binary_logarithm_preprocessor'
)


def _invert_log2_values(
log_values: np.ndarray,
) -> np.ndarray:
"""Invert the binary logarithm and return an ndarray."""
# We invert the following formula: log_2(y_pred + 1.0)
return np.power(2.0, log_values) - 1.0


class InvertBinaryLogarithmPreprocessor(metric_types.Preprocessor):
"""Read label and prediction from binary logarithm to numpy array."""

def __init__(
self,
name: Optional[str] = None,
model_name: str = '',
):
"""Initialize the preprocessor for binary logarithm inversion.
Args:
name: (Optional) name for the preprocessor.
model_name: (Optional) model name (if multi-model evaluation).
"""
if not name:
name = metric_util.generate_private_name_from_arguments(
_INVERT_BINARY_LOGARITHM_PREPROCESSOR_BASE_NAME
)
super().__init__(name=name)
self._model_name = model_name

def _read_label_or_prediction_in_multiple_dicts(
self,
key: str,
extracts: util.StandardExtracts,
) -> np.ndarray:
"""Reads and inverts the binary logarithm from extracts."""
if key == constants.LABELS_KEY:
value = extracts.get_labels(self._model_name)
else:
value = extracts.get_predictions(self._model_name)
return _invert_log2_values(value)

def process(
self, extracts: types.Extracts
) -> Iterator[metric_types.StandardMetricInputs]:
"""Reads and inverts the binary logarithm from extracts.
It will search in labels/predictions, features and transformed features.
Args:
extracts: A tfma extract contains the regression data.
Yields:
A standard metric input contains the following key and values:
- {'labels'}: A numpy array represents the regressed values.
- {'predictions'}: A numpy array represents the regression predictions.
- {'example_weights'}: (Optional) A numpy array represents the example
weights.
"""
extracts = util.StandardExtracts(extracts)

extracts[constants.LABELS_KEY] = (
self._read_label_or_prediction_in_multiple_dicts(
constants.LABELS_KEY, extracts
)
)

extracts[constants.PREDICTIONS_KEY] = (
self._read_label_or_prediction_in_multiple_dicts(
constants.PREDICTIONS_KEY,
extracts,
)
)

if (
extracts[constants.LABELS_KEY].shape
!= extracts[constants.PREDICTIONS_KEY].shape
):
raise ValueError(
'The size of ground truth '
f'{extracts[constants.LABELS_KEY].shape} does not match '
'with the size of prediction '
f'{extracts[constants.PREDICTIONS_KEY].shape}'
)

yield metric_util.to_standard_metric_inputs(extracts)
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for invert logarithm preprocessors."""

from absl.testing import absltest
from absl.testing import parameterized
import apache_beam as beam
from apache_beam.testing import util as beam_testing_util
import numpy as np
from tensorflow_model_analysis import constants
from tensorflow_model_analysis.metrics.preprocessors import invert_logarithm_preprocessors
from tensorflow_model_analysis.utils import util


class InvertBinaryLogarithmPreprocessorTest(parameterized.TestCase):

def setUp(self):
super().setUp()
values = np.array([[1, 2, 4], [1, 2, 4]], dtype=np.int32)
processed_values = np.array([[1, 3, 15], [1, 3, 15]], dtype=np.float32)

self._extract_inputs = [{
constants.LABELS_KEY: values,
constants.PREDICTIONS_KEY: values,
}]
self._expected_processed_inputs = [
util.StandardExtracts({
constants.LABELS_KEY: processed_values,
constants.PREDICTIONS_KEY: processed_values,
})
]

def testInvertBinaryLogarithmPreprocessor(self):
with beam.Pipeline() as pipeline:
updated_pcoll = (
pipeline
| 'Create' >> beam.Create(self._extract_inputs)
| 'Preprocess'
>> beam.ParDo(
invert_logarithm_preprocessors.InvertBinaryLogarithmPreprocessor()
)
)

def check_result(result):
# Only single extract case is tested
self.assertLen(result, len(self._expected_processed_inputs))
for updated_extracts, expected_input in zip(
result, self._expected_processed_inputs
):
self.assertIn(constants.PREDICTIONS_KEY, updated_extracts)
np.testing.assert_allclose(
updated_extracts[constants.PREDICTIONS_KEY],
expected_input[constants.PREDICTIONS_KEY],
)
self.assertIn(constants.LABELS_KEY, updated_extracts)
np.testing.assert_allclose(
updated_extracts[constants.LABELS_KEY],
expected_input[constants.LABELS_KEY],
)
if constants.EXAMPLE_WEIGHTS_KEY in expected_input:
self.assertIn(constants.EXAMPLE_WEIGHTS_KEY, updated_extracts)
np.testing.assert_allclose(
updated_extracts[constants.EXAMPLE_WEIGHTS_KEY],
expected_input[constants.EXAMPLE_WEIGHTS_KEY],
)

beam_testing_util.assert_that(updated_pcoll, check_result)

def testName(self):
preprocessor = (
invert_logarithm_preprocessors.InvertBinaryLogarithmPreprocessor()
)
self.assertEqual(
preprocessor.name, '_invert_binary_logarithm_preprocessor:'
)

def testLabelPreidictionSizeMismatch(self):
extracts = {
constants.LABELS_KEY: np.array([[1, 2]]),
constants.PREDICTIONS_KEY: np.array([[1, 2, 3]]),
}
with self.assertRaisesRegex(ValueError, 'does not match'):
_ = next(
invert_logarithm_preprocessors.InvertBinaryLogarithmPreprocessor().process(
extracts=extracts
)
)


if __name__ == '__main__':
absltest.main()

0 comments on commit 8922093

Please sign in to comment.