From 6ddc634945088dae2e60deefd1ded389e282636d Mon Sep 17 00:00:00 2001 From: David de la Iglesia Castro Date: Thu, 6 Jul 2023 17:24:33 +0200 Subject: [PATCH] fix(tests): Update imports in keras (#621) * fix(tests): Update imports in keras * Fix test --- src/dvclive/keras.py | 7 +++---- tests/test_frameworks/test_keras.py | 17 ++++++----------- 2 files changed, 9 insertions(+), 15 deletions(-) diff --git a/src/dvclive/keras.py b/src/dvclive/keras.py index d74b4f5d..c85bfa82 100644 --- a/src/dvclive/keras.py +++ b/src/dvclive/keras.py @@ -2,14 +2,13 @@ import os from typing import Dict, Optional -from tensorflow.keras.callbacks import Callback -from tensorflow.keras.models import load_model +import tensorflow as tf from dvclive import Live from dvclive.utils import standardize_metric_name -class DVCLiveCallback(Callback): +class DVCLiveCallback(tf.keras.callbacks.Callback): def __init__( self, model_file=None, @@ -31,7 +30,7 @@ def on_train_begin(self, logs=None): if self.save_weights_only: self.model.load_weights(self.model_file) else: - self.model = load_model(self.model_file) + self.model = tf.keras.models.load_model(self.model_file) def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None): logs = logs or {} diff --git a/tests/test_frameworks/test_keras.py b/tests/test_frameworks/test_keras.py index b7c5478c..e12e10d6 100644 --- a/tests/test_frameworks/test_keras.py +++ b/tests/test_frameworks/test_keras.py @@ -15,18 +15,17 @@ @pytest.fixture() def xor_model(): import numpy as np - from tensorflow.python.keras import Sequential - from tensorflow.python.keras.layers import Activation, Dense + import tensorflow as tf def make(): x = np.array([[0, 0], [0, 1], [1, 0], [1, 1]]) y = np.array([[0], [1], [1], [0]]) - model = Sequential() - model.add(Dense(8, input_dim=2)) - model.add(Activation("relu")) - model.add(Dense(1)) - model.add(Activation("sigmoid")) + model = tf.keras.Sequential() + model.add(tf.keras.layers.Dense(8, input_dim=2)) + model.add(tf.keras.layers.Activation("relu")) + model.add(tf.keras.layers.Dense(1)) + model.add(tf.keras.layers.Activation("sigmoid")) model.compile(loss="binary_crossentropy", optimizer="sgd", metrics=["accuracy"]) @@ -90,8 +89,6 @@ def test_keras_model_file(tmp_dir, xor_model, mocker, save_weights_only): @pytest.mark.parametrize("save_weights_only", [True, False]) def test_keras_load_model_on_resume(tmp_dir, xor_model, mocker, save_weights_only): - import dvclive.keras - model, x, y = xor_model() if save_weights_only: @@ -100,7 +97,6 @@ def test_keras_load_model_on_resume(tmp_dir, xor_model, mocker, save_weights_onl model.save("model.h5") load_weights = mocker.spy(model, "load_weights") - load_model = mocker.spy(dvclive.keras, "load_model") model.fit( x, @@ -116,7 +112,6 @@ def test_keras_load_model_on_resume(tmp_dir, xor_model, mocker, save_weights_onl ], ) - assert load_model.call_count != save_weights_only assert load_weights.call_count == save_weights_only