Skip to content

Commit

Permalink
Remove estimator-related tests in inference_base_test.py
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 685841191
  • Loading branch information
zhouhao138 authored and tfx-copybara committed Oct 14, 2024
1 parent c8b5fbd commit 2291159
Showing 1 changed file with 6 additions and 177 deletions.
183 changes: 6 additions & 177 deletions tensorflow_model_analysis/extractors/inference_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,12 @@

import os

import apache_beam as beam
from apache_beam.testing import util
import numpy as np
import tensorflow as tf
from tensorflow_model_analysis import constants
from tensorflow_model_analysis.api import model_eval_lib
from tensorflow_model_analysis.eval_saved_model import testutil
from tensorflow_model_analysis.eval_saved_model.example_trainers import fixed_prediction_estimator_extra_fields
from tensorflow_model_analysis.extractors import features_extractor
from tensorflow_model_analysis.extractors import inference_base
from tensorflow_model_analysis.extractors import tfx_bsl_predictions_extractor
from tensorflow_model_analysis.proto import config_pb2
from tensorflow_model_analysis.utils import test_util as testutil
from tfx_bsl.tfxio import tensor_adapter
from tfx_bsl.tfxio import test_util

Expand Down Expand Up @@ -76,114 +70,6 @@ def _create_tfxio_and_feature_extractor(
)
return tfx_io, feature_extractor

def testRegressionModel(self):
  """Runs TFX-BSL bulk inference end-to-end over a fixed-prediction model.

  Exports an estimator whose serving signature echoes the input
  "prediction" feature, feeds three serialized tf.Examples through the
  feature-extraction + RunInference pipeline, and asserts that the model
  outputs are surfaced unchanged under PREDICTIONS_KEY.
  """
  temp_export_dir = self._getExportDir()
  # Export a model that passes the "prediction" feature straight through.
  export_dir, _ = (
      fixed_prediction_estimator_extra_fields
      .simple_fixed_prediction_estimator_extra_fields(temp_export_dir, None)
  )

  eval_config = config_pb2.EvalConfig(
      model_specs=[
          config_pb2.ModelSpec(
              name='model_1', signature_name='serving_default'
          )
      ]
  )
  # NOTE(review): sibling tests in this file pass `model_path=`; confirm
  # whether createTestEvalSharedModel still accepts `eval_saved_model_path`.
  eval_shared_model = self.createTestEvalSharedModel(
      eval_saved_model_path=export_dir, tags=[tf.saved_model.SERVING]
  )
  tfx_io, feature_extractor = self._create_tfxio_and_feature_extractor(
      eval_config,
      text_format.Parse(
          """
          feature {
            name: "prediction"
            type: FLOAT
          }
          feature {
            name: "label"
            type: FLOAT
          }
          feature {
            name: "fixed_int"
            type: INT
          }
          feature {
            name: "fixed_float"
            type: FLOAT
          }
          feature {
            name: "fixed_string"
            type: BYTES
          }
          """,
          schema_pb2.Schema(),
      ),
  )

  examples = [
      self._makeExample(
          prediction=0.2,
          label=1.0,
          fixed_int=1,
          fixed_float=1.0,
          fixed_string='fixed_string1',
      ),
      self._makeExample(
          prediction=0.8,
          label=0.0,
          fixed_int=1,
          fixed_float=1.0,
          fixed_string='fixed_string2',
      ),
      self._makeExample(
          prediction=0.5,
          label=0.0,
          fixed_int=2,
          fixed_float=1.0,
          fixed_string='fixed_string3',
      ),
  ]
  num_examples = len(examples)

  # Batch size equals the example count so everything lands in one batch.
  tfx_bsl_inference_ptransform = inference_base.RunInference(
      tfx_bsl_predictions_extractor.TfxBslInferenceWrapper(
          eval_config.model_specs, {'': eval_shared_model}
      ),
      output_batch_size=num_examples,
  )

  with beam.Pipeline() as pipeline:
    # pylint: disable=no-value-for-parameter
    result = (
        pipeline
        | 'Create'
        >> beam.Create(
            [e.SerializeToString() for e in examples], reshuffle=False
        )
        | 'BatchExamples' >> tfx_io.BeamSource(batch_size=num_examples)
        | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts()
        | feature_extractor.stage_name >> feature_extractor.ptransform
        | 'RunInferenceBase' >> tfx_bsl_inference_ptransform
    )
    # pylint: enable=no-value-for-parameter

    def check_result(got):
      try:
        # One batch containing the predictions for all three examples.
        self.assertLen(got, 1)
        self.assertIn(constants.PREDICTIONS_KEY, got[0])
        self.assertAllClose(
            np.array([[0.2], [0.8], [0.5]]),
            got[0][constants.PREDICTIONS_KEY],
        )
      except AssertionError as err:
        # Fix: BeamAssertException expects a message string; wrap str(err)
        # and chain with `from err` so the original assertion is preserved
        # in the traceback instead of passing the exception object itself.
        raise util.BeamAssertException(str(err)) from err

    util.assert_that(result, check_result)

def testIsValidConfigForBulkInferencePass(self):
saved_model_proto = text_format.Parse(
"""
Expand Down Expand Up @@ -231,7 +117,7 @@ def testIsValidConfigForBulkInferencePass(self):
]
)
eval_shared_model = self.createTestEvalSharedModel(
eval_saved_model_path=temp_dir.full_path,
model_path=temp_dir.full_path,
model_name='model_1',
tags=[tf.saved_model.SERVING],
model_type=constants.TF_GENERIC,
Expand Down Expand Up @@ -286,7 +172,7 @@ def testIsValidConfigForBulkInferencePassDefaultSignatureLookUp(self):
model_specs=[config_pb2.ModelSpec(name='model_1')]
)
eval_shared_model = self.createTestEvalSharedModel(
eval_saved_model_path=temp_dir.full_path,
model_path=temp_dir.full_path,
model_name='model_1',
tags=[tf.saved_model.SERVING],
model_type=constants.TF_GENERIC,
Expand Down Expand Up @@ -343,7 +229,7 @@ def testIsValidConfigForBulkInferenceFailNoSignatureFound(self):
]
)
eval_shared_model = self.createTestEvalSharedModel(
eval_saved_model_path=temp_dir.full_path,
model_path=temp_dir.full_path,
model_name='model_1',
model_type=constants.TF_GENERIC,
)
Expand Down Expand Up @@ -400,7 +286,7 @@ def testIsValidConfigForBulkInferenceFailKerasModel(self):
]
)
eval_shared_model = self.createTestEvalSharedModel(
eval_saved_model_path=temp_dir.full_path,
model_path=temp_dir.full_path,
model_name='model_1',
model_type=constants.TF_KERAS,
)
Expand All @@ -410,63 +296,6 @@ def testIsValidConfigForBulkInferenceFailKerasModel(self):
)
)

def testIsValidConfigForBulkInferenceFailMoreThanOneInput(self):
  """Checks that config validation rejects this saved-model/config pair.

  Writes a minimal SavedModel proto to a temp directory, builds a matching
  EvalConfig, and asserts is_valid_config_for_bulk_inference returns False.
  """
  # Minimal SavedModel with a single string-input serving signature.
  model_proto = text_format.Parse(
      """
      saved_model_schema_version: 1
      meta_graphs {
        meta_info_def {
          tags: "serve"
        }
        signature_def: {
          key: "serving_default"
          value: {
            inputs: {
              key: "inputs"
              value {
                dtype: DT_STRING
                name: "input_node:0"
              }
            }
            method_name: "predict"
            outputs: {
              key: "outputs"
              value {
                dtype: DT_FLOAT
                tensor_shape {
                  dim { size: -1 }
                  dim { size: 100 }
                }
              }
            }
          }
        }
      }
      """,
      saved_model_pb2.SavedModel(),
  )
  export_dir = self.create_tempdir()
  export_dir.create_file(
      'saved_model.pb', content=model_proto.SerializeToString()
  )

  model_spec = config_pb2.ModelSpec(
      name='model_1', signature_name='serving_default'
  )
  eval_config = config_pb2.EvalConfig(model_specs=[model_spec])
  shared_model = self.createTestEvalSharedModel(
      eval_saved_model_path=export_dir.full_path,
      model_name='model_1',
      model_type=constants.TF_GENERIC,
  )

  is_valid = inference_base.is_valid_config_for_bulk_inference(
      eval_config, shared_model
  )
  self.assertFalse(is_valid)

def testIsValidConfigForBulkInferenceFailWrongInputType(self):
saved_model_proto = text_format.Parse(
"""
Expand Down Expand Up @@ -514,7 +343,7 @@ def testIsValidConfigForBulkInferenceFailWrongInputType(self):
]
)
eval_shared_model = self.createTestEvalSharedModel(
eval_saved_model_path=temp_dir.full_path,
model_path=temp_dir.full_path,
model_name='model_1',
model_type=constants.TF_GENERIC,
)
Expand Down

0 comments on commit 2291159

Please sign in to comment.