Skip to content

Commit

Permalink
add ci run test
Browse files Browse the repository at this point in the history
  • Loading branch information
pirocheto committed Nov 27, 2023
1 parent 766a14b commit 9ff0c26
Show file tree
Hide file tree
Showing 9 changed files with 159 additions and 161 deletions.
34 changes: 34 additions & 0 deletions .github/workflows/run_test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# CI workflow: installs the project with Poetry, pulls the model artifacts
# with DVC and runs the pytest suite on pushes to the listed branches.
name: Run Tests
run-name: Run Tests
on:
  push:
    branches:
      - train
      - dev
jobs:
  # NOTE(review): the job id is kept as-is so the reported check name does not
  # change; consider renaming it, since this job only runs the tests.
  publish_model:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: "3.10.13"
      - name: Install Poetry
        uses: snok/install-poetry@v1
        with:
          # Creates the virtualenv at ./.venv so later steps can activate it.
          virtualenvs-in-project: true
      - name: Install dependencies
        run: |
          poetry install --no-root --no-interaction --with test
      - name: Pull model
        env:
          AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
          AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
        run: |
          source .venv/bin/activate
          dvc pull live/model/model.pkl live/model/model.onnx
      # NOTE(review): the test fixtures also read data/test.parquet; confirm
      # that file is available in CI (tracked in git or pulled above).
      - name: Run test
        run: |
          # Every `run` step starts a new shell, so the virtualenv activated in
          # the previous step must be activated again before calling pytest.
          source .venv/bin/activate
          pytest
4 changes: 0 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,3 @@ modelcard:
optuna_dashboard:
optuna-dashboard notebooks/optunalog/optuna.db


test:
coverage run --source=src -m pytest -s
coverage report -m
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,9 @@ exclude_also = ['if __name__ == "__main__":']

[tool.pylint]
good-names = ["X_train", "X_test"]

[tool.pylint.MASTER]
# Regular expression for test modules that pylint should skip. The dot is
# escaped so it only matches a literal "." before the "py" extension.
# NOTE(review): presumably matched against file base names; confirm in the
# pylint documentation.
ignore-patterns = ["test_.*\\.py"]

# NOTE(review): `ignore` is listed here under the "report" section; confirm
# pylint picks it up from this table.
[tool.pylint.report]
ignore = ["tests"]
2 changes: 1 addition & 1 deletion src/create_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def pkl2onnx(model):
return onx.SerializeToString()


def create_onnx() -> str: # pragma: no cover
def create_onnx() -> str:
"""Create an ONNX file from a pickled machine learning model."""

params = dvc.api.params_show()
Expand Down
6 changes: 3 additions & 3 deletions src/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,16 +99,16 @@ def load_pickle_model(path: str):
return model


def load_onnx_model(path: str):
def load_onnx_session(path: str):
"""Load a machine learning model from a file."""
from pathlib import Path

import onnxruntime

model_path = Path(path)

model = onnxruntime.InferenceSession(
sess = onnxruntime.InferenceSession(
model_path,
providers=["CPUExecutionProvider"],
)
return model
return sess
42 changes: 7 additions & 35 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,19 @@
import dvc.api
import pandas as pd
import pytest
from sklearn.preprocessing import LabelEncoder

from src.helper import create_model, load_model
from src.helper import load_data, load_onnx_session, load_pickle_model


@pytest.fixture
def sample_data():
params = dvc.api.params_show()
df_test = pd.read_parquet(params["data"]["test"])
df_sample = df_test.head(50)
y = LabelEncoder().fit_transform(df_sample["status"])
return df_sample["url"], y
def X_sample():
X, _ = load_data("data/test.parquet")
return X[:50]


@pytest.fixture
def dummy_model():
return create_model()


@pytest.fixture
def fitted_dummy_model(dummy_model, sample_data):
X_train, y_train = sample_data
dummy_model.fit(X_train, y_train)
return dummy_model


@pytest.fixture
def model():
model = create_model()
return model
def onnx_sess():
return load_onnx_session("live/model/model.onnx")


@pytest.fixture
def pkl_model():
params = dvc.api.params_show()
model = load_model(params["model"]["pickle"])
return model


@pytest.fixture
def onnx_model():
params = dvc.api.params_show()
model = load_model(params["model"]["onnx"], model_format="onnx")
return model
return load_pickle_model("live/model/model.pkl")
85 changes: 0 additions & 85 deletions tests/test_code.py

This file was deleted.

8 changes: 0 additions & 8 deletions tests/test_data.py

This file was deleted.

133 changes: 108 additions & 25 deletions tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,117 @@
def test_training(sample_data, model):
X_train, y_train = sample_data
model.fit(X_train, y_train)
assert model.__sklearn_is_fitted__()
import time
from pathlib import Path

import dvc.api
import numpy as np
import pytest

def test_predict(sample_data, pkl_model):
X_test, _ = sample_data
X_pred = pkl_model.predict(X_test)
assert X_pred.shape == X_test.shape
# ! Constants for test conditions
# Maximum allowed model size in megabytes (MB)
MAX_SIZE = 50

# Minimum required F1 score
MIN_F1 = 0.90

def test_onnx_model(onnx_model):
model_inputs = onnx_model.get_inputs()
assert len(model_inputs) == 1
# Minimum required precision score
MIN_PRECISION = 0.90

model_input = model_inputs[0]
# Minimum required recall score
MIN_RECALL = 0.90

assert model_input.name == "inputs"
assert model_input.type == "tensor(string)"
assert model_input.shape == [None]
# Minimum required ROC AUC score
MIN_ROC_AUC = 0.90

model_outputs = onnx_model.get_outputs()
assert len(model_outputs) == 2
# Minimum required accuracy score
MIN_ACCURACY = 0.90

output_label = model_outputs[0]
assert output_label.name == "label"
assert output_label.type == "tensor(int64)"
assert output_label.shape == [None]
# Maximum allowed inference time in seconds
MAX_INFERENCE_TIME = 0.5

output_proba = model_outputs[1]
assert output_proba.name == "probabilities"
assert output_proba.type == "tensor(float)"
assert output_proba.shape == [None, 2]

def test_metrics():
    """Check that every tracked metric is strictly above its minimum.

    The metric values are read with ``dvc.api.metrics_show()`` and compared
    against the module-level ``MIN_*`` constants. The first metric that does
    not exceed its threshold fails the test with a message naming the metric,
    its value and the threshold.
    """
    # Retrieve metrics using dvc.api
    metrics = dvc.api.metrics_show()

    # (key in the metrics mapping, label used in the message, threshold)
    checks = [
        ("f1", "F1", MIN_F1),
        ("precision", "Precision", MIN_PRECISION),
        ("recall", "Recall", MIN_RECALL),
        ("roc_auc", "ROC AUC", MIN_ROC_AUC),
        ("accuracy", "Accuracy", MIN_ACCURACY),
    ]

    for key, label, minimum in checks:
        score = metrics[key]
        assert (
            score > minimum
        ), f"{label} score ({score}) below the minimum required ({minimum})"


@pytest.mark.parametrize("path", ["live/model/model.onnx", "live/model/model.pkl"])
def test_model_size(path):
    """Verify that the model file at ``path`` is smaller than ``MAX_SIZE``."""
    # The byte count from the filesystem is scaled by 1024**2 before comparing.
    bytes_per_unit = 1024**2
    size_mo = Path(path).stat().st_size / bytes_per_unit

    assert (
        size_mo < MAX_SIZE
    ), f"Model size ({size_mo} MB) exceeds the maximum allowed size ({MAX_SIZE} MB)"


def get_inference_time(data, predict):
    """Calculate the average inference time per input.

    Args:
        data: Collection of inputs; it is converted to a NumPy array of
            strings before timing starts.
        predict: Callable invoked once with each individual input.

    Returns:
        The average duration, in seconds, of one ``predict`` call.

    Raises:
        ValueError: If ``data`` contains no inputs, since an average over
            zero calls is undefined.
    """
    inputs = np.array(data, dtype="str")
    if len(inputs) == 0:
        raise ValueError("Cannot measure inference time on empty data")

    # perf_counter is monotonic and high-resolution, so the measured interval
    # cannot be skewed by system clock adjustments.
    start_time = time.perf_counter()
    for x in inputs:
        predict(x)
    elapsed = time.perf_counter() - start_time

    return elapsed / len(inputs)


def test_onnx_inference_time(X_sample, onnx_sess):
    """Ensure the ONNX session's average per-input time stays under the limit."""

    def run_single(item):
        # Each input is sent to the session as a one-element batch.
        feed = {"inputs": [item]}
        return onnx_sess.run(None, feed)

    inference_time = get_inference_time(X_sample, run_single)
    failure_message = (
        f"Inference time ({inference_time:.4f} seconds) for ONNX model "
        f"exceeds the maximum allowed time ({MAX_INFERENCE_TIME} seconds)"
    )
    assert inference_time < MAX_INFERENCE_TIME, failure_message


def test_pickle_inference_time(X_sample, pkl_model):
    """Ensure the pickled model's average per-input time stays under the limit."""

    def run_single(item):
        # Each input is wrapped in a one-element list before prediction.
        batch = [item]
        return pkl_model.predict(batch)

    inference_time = get_inference_time(X_sample, run_single)
    failure_message = (
        f"Inference time ({inference_time:.4f} seconds) for pickle model "
        f"exceeds the maximum allowed time ({MAX_INFERENCE_TIME} seconds)"
    )
    assert inference_time < MAX_INFERENCE_TIME, failure_message


def test_equivalence(X_sample, pkl_model, onnx_sess):
    """Verify that the pickled model and the ONNX session give close probabilities."""
    pkl_probs = pkl_model.predict_proba(X_sample)

    # The second entry of the session outputs is the one compared against the
    # pickled model's probabilities.
    onnx_outputs = onnx_sess.run(None, {"inputs": X_sample})
    onnx_probs = onnx_outputs[1]

    # Both the relative and the absolute tolerance are set to 1e-2.
    np.testing.assert_allclose(pkl_probs, onnx_probs, rtol=1e-2, atol=1e-2)

0 comments on commit 9ff0c26

Please sign in to comment.