Skip to content

Commit

Permalink
Adopt TF2 Keras layer tests to TF 2.17
Browse files Browse the repository at this point in the history
Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
  • Loading branch information
rkazants committed Aug 29, 2024
1 parent aad0533 commit 4c3bad5
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 2 deletions.
3 changes: 2 additions & 1 deletion tests/layer_tests/common/tf2_layer_test_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
def save_to_tf2_savedmodel(tf2_model, path_to_saved_tf2_model):
import tensorflow as tf
assert int(tf.__version__.split('.')[0]) >= 2, "TensorFlow 2 must be used for this suite validation"
tf.keras.models.save_model(tf2_model, path_to_saved_tf2_model, save_format='tf')
# Since TF 2.16 this is only way to serialize Keras objects into SavedModel format
tf2_model.export(path_to_saved_tf2_model)
assert os.path.isdir(path_to_saved_tf2_model), "the model haven't been saved " \
"here: {}".format(path_to_saved_tf2_model)
return path_to_saved_tf2_model
Expand Down
1 change: 0 additions & 1 deletion tests/layer_tests/tensorflow_tests/test_tf_Bucketize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import numpy as np
import platform
import platform
import pytest
import tensorflow as tf
from common.tf_layer_test_class import CommonTFLayerTest
Expand Down
4 changes: 4 additions & 0 deletions tests/layer_tests/tensorflow_tests/test_tf_Round.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import platform
import pytest
import tensorflow as tf
from common.tf_layer_test_class import CommonTFLayerTest
Expand Down Expand Up @@ -40,6 +41,9 @@ def test_round_basic(self, input_shape, input_type, ie_device, precision,
ir_version, temp_dir, use_legacy_frontend):
if input_type in [np.int8, np.int16, np.int32, np.int64]:
pytest.skip('TensorFlow issue: https://github.com/tensorflow/tensorflow/issues/74789')
if platform.machine() in ["aarch64", "arm64", "ARM64"] and \
input_type == np.float32 and input_shape == [10, 5, 1, 5]:
pytest.skip("150999: Accuracy issue on CPU")
self._test(*self.create_tf_round_net(input_shape, input_type),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_legacy_frontend=use_legacy_frontend)

0 comments on commit 4c3bad5

Please sign in to comment.