Skip to content

Commit

Permalink
fix(tests): Update imports in keras (#621)
Browse files Browse the repository at this point in the history
* fix(tests): Update imports in keras

* Fix test
  • Loading branch information
daavoo authored Jul 6, 2023
1 parent 26524fd commit 6ddc634
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 15 deletions.
7 changes: 3 additions & 4 deletions src/dvclive/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@
import os
from typing import Dict, Optional

from tensorflow.keras.callbacks import Callback
from tensorflow.keras.models import load_model
import tensorflow as tf

from dvclive import Live
from dvclive.utils import standardize_metric_name


class DVCLiveCallback(Callback):
class DVCLiveCallback(tf.keras.callbacks.Callback):
def __init__(
self,
model_file=None,
Expand All @@ -31,7 +30,7 @@ def on_train_begin(self, logs=None):
if self.save_weights_only:
self.model.load_weights(self.model_file)
else:
self.model = load_model(self.model_file)
self.model = tf.keras.models.load_model(self.model_file)

def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None):
logs = logs or {}
Expand Down
17 changes: 6 additions & 11 deletions tests/test_frameworks/test_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,17 @@
@pytest.fixture()
def xor_model():
import numpy as np
from tensorflow.python.keras import Sequential
from tensorflow.python.keras.layers import Activation, Dense
import tensorflow as tf

def make():
x = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
y = np.array([[0], [1], [1], [0]])

model = Sequential()
model.add(Dense(8, input_dim=2))
model.add(Activation("relu"))
model.add(Dense(1))
model.add(Activation("sigmoid"))
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(8, input_dim=2))
model.add(tf.keras.layers.Activation("relu"))
model.add(tf.keras.layers.Dense(1))
model.add(tf.keras.layers.Activation("sigmoid"))

model.compile(loss="binary_crossentropy", optimizer="sgd", metrics=["accuracy"])

Expand Down Expand Up @@ -90,8 +89,6 @@ def test_keras_model_file(tmp_dir, xor_model, mocker, save_weights_only):

@pytest.mark.parametrize("save_weights_only", [True, False])
def test_keras_load_model_on_resume(tmp_dir, xor_model, mocker, save_weights_only):
import dvclive.keras

model, x, y = xor_model()

if save_weights_only:
Expand All @@ -100,7 +97,6 @@ def test_keras_load_model_on_resume(tmp_dir, xor_model, mocker, save_weights_onl
model.save("model.h5")

load_weights = mocker.spy(model, "load_weights")
load_model = mocker.spy(dvclive.keras, "load_model")

model.fit(
x,
Expand All @@ -116,7 +112,6 @@ def test_keras_load_model_on_resume(tmp_dir, xor_model, mocker, save_weights_onl
],
)

assert load_model.call_count != save_weights_only
assert load_weights.call_count == save_weights_only


Expand Down

0 comments on commit 6ddc634

Please sign in to comment.