diff --git a/tensorflow_model_analysis/extractors/inference_base_test.py b/tensorflow_model_analysis/extractors/inference_base_test.py index 7f0c643109..f89d13f780 100644 --- a/tensorflow_model_analysis/extractors/inference_base_test.py +++ b/tensorflow_model_analysis/extractors/inference_base_test.py @@ -19,18 +19,12 @@ import os -import apache_beam as beam -from apache_beam.testing import util -import numpy as np import tensorflow as tf from tensorflow_model_analysis import constants -from tensorflow_model_analysis.api import model_eval_lib -from tensorflow_model_analysis.eval_saved_model import testutil -from tensorflow_model_analysis.eval_saved_model.example_trainers import fixed_prediction_estimator_extra_fields from tensorflow_model_analysis.extractors import features_extractor from tensorflow_model_analysis.extractors import inference_base -from tensorflow_model_analysis.extractors import tfx_bsl_predictions_extractor from tensorflow_model_analysis.proto import config_pb2 +from tensorflow_model_analysis.utils import test_util as testutil from tfx_bsl.tfxio import tensor_adapter from tfx_bsl.tfxio import test_util @@ -76,114 +70,6 @@ def _create_tfxio_and_feature_extractor( ) return tfx_io, feature_extractor - def testRegressionModel(self): - temp_export_dir = self._getExportDir() - export_dir, _ = ( - fixed_prediction_estimator_extra_fields.simple_fixed_prediction_estimator_extra_fields( - temp_export_dir, None - ) - ) - - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec( - name='model_1', signature_name='serving_default' - ) - ] - ) - eval_shared_model = self.createTestEvalSharedModel( - eval_saved_model_path=export_dir, tags=[tf.saved_model.SERVING] - ) - tfx_io, feature_extractor = self._create_tfxio_and_feature_extractor( - eval_config, - text_format.Parse( - """ - feature { - name: "prediction" - type: FLOAT - } - feature { - name: "label" - type: FLOAT - } - feature { - name: "fixed_int" - type: INT - } - feature { - name: "fixed_float" - type: FLOAT - } - feature { - name: "fixed_string" - type: BYTES - } - """, - schema_pb2.Schema(), - ), - ) - - examples = [ - self._makeExample( - prediction=0.2, - label=1.0, - fixed_int=1, - fixed_float=1.0, - fixed_string='fixed_string1', - ), - self._makeExample( - prediction=0.8, - label=0.0, - fixed_int=1, - fixed_float=1.0, - fixed_string='fixed_string2', - ), - self._makeExample( - prediction=0.5, - label=0.0, - fixed_int=2, - fixed_float=1.0, - fixed_string='fixed_string3', - ), - ] - num_examples = len(examples) - - tfx_bsl_inference_ptransform = inference_base.RunInference( - tfx_bsl_predictions_extractor.TfxBslInferenceWrapper( - eval_config.model_specs, {'': eval_shared_model} - ), - output_batch_size=num_examples, - ) - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | 'Create' - >> beam.Create( - [e.SerializeToString() for e in examples], reshuffle=False - ) - | 'BatchExamples' >> tfx_io.BeamSource(batch_size=num_examples) - | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts() - | feature_extractor.stage_name >> feature_extractor.ptransform - | 'RunInferenceBase' >> tfx_bsl_inference_ptransform - ) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 1) - self.assertIn(constants.PREDICTIONS_KEY, got[0]) - self.assertAllClose( - np.array([[0.2], [0.8], [0.5]]), got[0][constants.PREDICTIONS_KEY] - ) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result) - def testIsValidConfigForBulkInferencePass(self): saved_model_proto = text_format.Parse( """ @@ -231,7 +117,7 @@ def testIsValidConfigForBulkInferencePass(self): ] ) eval_shared_model = self.createTestEvalSharedModel( - eval_saved_model_path=temp_dir.full_path, + model_path=temp_dir.full_path, model_name='model_1', tags=[tf.saved_model.SERVING], model_type=constants.TF_GENERIC, @@ -286,7 +172,7 @@ def testIsValidConfigForBulkInferencePassDefaultSignatureLookUp(self): model_specs=[config_pb2.ModelSpec(name='model_1')] ) eval_shared_model = self.createTestEvalSharedModel( - eval_saved_model_path=temp_dir.full_path, + model_path=temp_dir.full_path, model_name='model_1', tags=[tf.saved_model.SERVING], model_type=constants.TF_GENERIC, @@ -343,7 +229,7 @@ def testIsValidConfigForBulkInferenceFailNoSignatureFound(self): ] ) eval_shared_model = self.createTestEvalSharedModel( - eval_saved_model_path=temp_dir.full_path, + model_path=temp_dir.full_path, model_name='model_1', model_type=constants.TF_GENERIC, ) @@ -400,7 +286,7 @@ def testIsValidConfigForBulkInferenceFailKerasModel(self): ] ) eval_shared_model = self.createTestEvalSharedModel( - eval_saved_model_path=temp_dir.full_path, + model_path=temp_dir.full_path, model_name='model_1', model_type=constants.TF_KERAS, ) @@ -410,63 +296,6 @@ def testIsValidConfigForBulkInferenceFailKerasModel(self): ) ) - def testIsValidConfigForBulkInferenceFailMoreThanOneInput(self): - saved_model_proto = text_format.Parse( - """ - saved_model_schema_version: 1 - meta_graphs { - meta_info_def { - tags: "serve" - } - signature_def: { - key: "serving_default" - value: { - inputs: { - key: "inputs" - value { - dtype: DT_STRING - name: "input_node:0" - } - } - method_name: "predict" - outputs: { - key: "outputs" - value { - dtype: DT_FLOAT - tensor_shape { - dim { size: -1 } - dim { size: 100 } - } - } - } - } - } - } - """, - saved_model_pb2.SavedModel(), - ) - temp_dir = self.create_tempdir() - temp_dir.create_file( - 'saved_model.pb', content=saved_model_proto.SerializeToString() - ) - eval_config = config_pb2.EvalConfig( - model_specs=[ - config_pb2.ModelSpec( - name='model_1', signature_name='serving_default' - ) - ] - ) - eval_shared_model = self.createTestEvalSharedModel( - eval_saved_model_path=temp_dir.full_path, - model_name='model_1', - model_type=constants.TF_GENERIC, - ) - self.assertFalse( - inference_base.is_valid_config_for_bulk_inference( - eval_config, eval_shared_model - ) - ) - def testIsValidConfigForBulkInferenceFailWrongInputType(self): saved_model_proto = text_format.Parse( """ @@ -514,7 +343,7 @@ def testIsValidConfigForBulkInferenceFailWrongInputType(self): ] ) eval_shared_model = self.createTestEvalSharedModel( - eval_saved_model_path=temp_dir.full_path, + model_path=temp_dir.full_path, model_name='model_1', model_type=constants.TF_GENERIC, )