Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add integration test for backend-specific imports #945

Closed
wants to merge 14 commits into from
51 changes: 51 additions & 0 deletions integration_tests/basic_full_flow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import numpy as np
import pytest

import keras_core as keras
from keras_core import layers
from keras_core import losses
from keras_core import metrics
from keras_core import optimizers
from keras_core import testing


class MyModel(keras.Model):
def __init__(self, hidden_dim, output_dim, **kwargs):
super().__init__(**kwargs)
self.hidden_dim = hidden_dim
self.output_dim = output_dim
self.dense1 = layers.Dense(hidden_dim, activation="relu")
self.dense2 = layers.Dense(hidden_dim, activation="relu")
self.dense3 = layers.Dense(output_dim)

def call(self, x):
x = self.dense1(x)
x = self.dense2(x)
return self.dense3(x)


@pytest.mark.requires_trainable_backend
class BasicFlowTest(testing.TestCase):
def test_basic_fit(self):
model = MyModel(hidden_dim=256, output_dim=16)

x = np.random.random((50000, 128))
y = np.random.random((50000, 16))
batch_size = 32
epochs = 3

model.compile(
optimizer=optimizers.SGD(learning_rate=0.001),
loss=losses.MeanSquaredError(),
metrics=[metrics.MeanSquaredError()],
)
output_before_fit = model(x)
history = model.fit(
x, y, batch_size=batch_size, epochs=epochs, validation_split=0.2
)
output_after_fit = model(x)

print("History:")
print(history.history)

self.assertNotAllClose(output_before_fit, output_after_fit)
118 changes: 118 additions & 0 deletions integration_tests/import_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import os
import re
import subprocess

from keras_core import backend

BACKEND_REQ = {
"tensorflow": "tensorflow",
"torch": "torch torchvision",
"jax": "jax jaxlib",
}


def setup_package():
subprocess.run("rm -rf tmp_build_dir", shell=True)
build_process = subprocess.run(
"python3 pip_build.py",
capture_output=True,
text=True,
shell=True,
)
print(build_process.stdout)
match = re.search(
r"\s[^\s]*\.whl",
build_process.stdout,
)
if not match:
raise ValueError("Installing Keras Core package unsuccessful. ")
print(build_process.stderr)
whl_path = match.group()
return whl_path


def create_virtualenv():
env_setup = [
# Create and activate virtual environment
"python3 -m venv test_env",
# "source ./test_env/bin/activate",
]
os.environ["PATH"] = (
"/test_env/bin/" + os.pathsep + os.environ.get("PATH", "")
)
run_commands_local(env_setup)


def manage_venv_installs(whl_path):
other_backends = list(set(BACKEND_REQ.keys()) - {backend.backend()})
install_setup = [
# Installs the backend's package and common requirements
"pip install " + BACKEND_REQ[backend.backend()],
"pip install -r requirements-common.txt",
"pip install pytest",
# Ensure other backends are uninstalled
"pip uninstall -y "
+ BACKEND_REQ[other_backends[0]]
+ " "
+ BACKEND_REQ[other_backends[1]],
# Install `.whl` package
"pip install " + whl_path + " --force-reinstall --no-dependencies",
]
run_commands_venv(install_setup)


def run_keras_core_flow():
test_script = [
# Runs the example script
"python -m pytest integration_tests/basic_full_flow.py",
]
run_commands_venv(test_script)


def cleanup():
cleanup_script = [
# Exits virtual environment, deletes files, and any
# miscellaneous install logs
"exit",
"rm -rf test_env",
"rm -rf tmp_build_dir",
"rm -f *+cpu",
]
run_commands_local(cleanup_script)


def run_commands_local(commands):
for command in commands:
subprocess.run(command, shell=True)


def run_commands_venv(commands):
for command in commands:
cmd_with_args = command.split(" ")
cmd_with_args[0] = "test_env/bin/" + cmd_with_args[0]
p = subprocess.Popen(cmd_with_args)
p.wait()


def test_keras_core_imports():
# Ensures packages from all backends are installed.
# Builds Keras core package and returns package file path.
whl_path = setup_package()

# Creates and activates a virtual environment.
create_virtualenv()

# Ensures the backend's package is installed
# and the other backends are uninstalled.
manage_venv_installs(whl_path)

# Runs test of basic flow in Keras Core.
# Tests for backend-specific imports and `model.fit()`.
run_keras_core_flow()

# Removes virtual environment and associated files
cleanup()


if __name__ == "__main__":
test_keras_core_imports()
Loading