Skip to content

Commit

Permalink
Update model_eval_lib_test to use the new export method for Keras mod…
Browse files Browse the repository at this point in the history
…els.

PiperOrigin-RevId: 686750089
  • Loading branch information
zhouhao138 authored and tfx-copybara committed Oct 17, 2024
1 parent f8fef85 commit 181ac2f
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions tensorflow_model_analysis/api/model_eval_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ def _exportEvalSavedModel(self, classifier):
return eval_export_dir

def _exportKerasModel(self, classifier):
temp_export_dir = os.path.join(self._getTempDir(), 'keras_export_dir')
classifier.save(temp_export_dir, save_format='tf')
temp_export_dir = os.path.join(self._getTempDir(), 'saved_model_export_dir')
classifier.export(temp_export_dir)
return temp_export_dir

def _writeTFExamplesToTFRecords(self, examples):
Expand Down Expand Up @@ -200,7 +200,7 @@ def testRunModelAnalysis(self):
self._makeExample(age=5.0, language='chinese', label=1.0),
self._makeExample(age=5.0, language='hindi', label=1.0),
]
classifier = example_keras_model.ExampleClassifierModel(
classifier = example_keras_model.get_example_classifier_model(
example_keras_model.LANGUAGE
)
classifier.compile(optimizer=keras.optimizers.Adam(), loss='mse')
Expand Down Expand Up @@ -533,7 +533,7 @@ def _build_keras_model(
f.write(tflite_model)
elif model_type == constants.TF_JS:
src_model_path = tempfile.mkdtemp()
model.save(src_model_path, save_format='tf')
model.export(src_model_path)

tfjs_converter.convert([
'--input_format=tf_saved_model',
Expand All @@ -543,7 +543,7 @@ def _build_keras_model(
model_location,
])
else:
model.save(model_location, save_format='tf')
model.export(model_location)
return model_eval_lib.default_eval_shared_model(
eval_saved_model_path=model_location,
eval_config=eval_config,
Expand Down Expand Up @@ -768,7 +768,8 @@ def check_eval_result(eval_result, model_location):
},
}
if (
model_type not in (constants.TF_LITE, constants.TF_JS)
model_type
not in (constants.TF_LITE, constants.TF_JS, constants.TF_KERAS)
and _TF_MAJOR_VERSION >= 2
):
expected_metrics[''] = {'loss': True}
Expand Down Expand Up @@ -801,7 +802,7 @@ def _build_keras_model(eval_config, export_name='export_dir'):
model = tf_keras.models.Model(layers_per_output, layers_per_output)
model.compile(loss=tf_keras.losses.categorical_crossentropy)
model_location = os.path.join(self._getTempDir(), export_name)
model.save(model_location, save_format='tf')
model.export(model_location)
return model_eval_lib.default_eval_shared_model(
eval_saved_model_path=model_location,
eval_config=eval_config,
Expand Down Expand Up @@ -978,7 +979,7 @@ def testRunModelAnalysisWithQueryBasedMetrics(self):
model.fit(dataset, steps_per_epoch=1)

model_location = os.path.join(self._getTempDir(), 'export_dir')
model.save(model_location, save_format='tf')
model.export(model_location)

schema = text_format.Parse(
"""
Expand Down Expand Up @@ -1129,7 +1130,7 @@ def testRunModelAnalysisWithUncertainty(self):
self._makeExample(age=5.0, language='chinese', label=1.0),
self._makeExample(age=5.0, language='hindi', label=1.0),
]
classifier = example_keras_model.ExampleClassifierModel(
classifier = example_keras_model.get_example_classifier_model(
example_keras_model.LANGUAGE
)
classifier.compile(optimizer=keras.optimizers.Adam(), loss='mse')
Expand Down Expand Up @@ -1226,7 +1227,7 @@ def testRunModelAnalysisWithDeterministicConfidenceIntervals(self):
self._makeExample(age=5.0, language='chinese', label=1.0),
self._makeExample(age=5.0, language='hindi', label=1.0),
]
classifier = example_keras_model.ExampleClassifierModel(
classifier = example_keras_model.get_example_classifier_model(
example_keras_model.LANGUAGE
)
classifier.compile(optimizer=keras.optimizers.Adam(), loss='mse')
Expand Down Expand Up @@ -1333,7 +1334,7 @@ def testRunModelAnalysisWithSchema(self):
self._makeExample(age=5.0, language='hindi', label=2.0),
]
data_location = self._writeTFExamplesToTFRecords(examples)
classifier = example_keras_model.ExampleClassifierModel(
classifier = example_keras_model.get_example_classifier_model(
example_keras_model.LANGUAGE
)
classifier.compile(optimizer=keras.optimizers.Adam(), loss='mse')
Expand Down

0 comments on commit 181ac2f

Please sign in to comment.