Skip to content

Commit

Permalink
add factory test
Browse files Browse the repository at this point in the history
  • Loading branch information
rpreen committed Jun 27, 2024
1 parent 8bbe1cb commit 3a1de14
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 8 deletions.
50 changes: 50 additions & 0 deletions tests/attacks/test_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""Test attack factory."""

from __future__ import annotations

import json
import os

import pytest
import yaml
from sklearn.ensemble import RandomForestClassifier

from aisdc.attacks.factory import run_attacks
from aisdc.config.attack import _get_attack


@pytest.mark.parametrize(
"get_target", [RandomForestClassifier(random_state=1)], indirect=True
)
def test_factory(monkeypatch, get_target):
"""Test Target object creation, saving, and loading."""
# create target_dir
target = get_target
target.save("target")

# create LiRA config with default params
mock_input = "yes"
monkeypatch.setattr("builtins.input", lambda _: mock_input)
attacks = []
attacks.append(_get_attack("lira"))

# create attack.yaml
filename: str = "attack.yaml"
with open(filename, "w", encoding="utf-8") as fp:
yaml.dump({"attacks": attacks}, fp)

# run attacks
run_attacks("target", "attack.yaml")

# load JSON report
path = os.path.normpath("outputs/report.json")
with open(path, encoding="utf-8") as fp:
report = json.load(fp)

# check report output
nr = list(report.keys())[0]
metrics = report[nr]["attack_experiment_logger"]["attack_instance_logger"][
"instance_0"
]
tpr = metrics["TPR"]
assert tpr == pytest.approx(0.92, 0.01)
17 changes: 9 additions & 8 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from aisdc.attacks.target import Target

np.random.seed(1)

folders = [
"RES",
"dt.sav",
Expand All @@ -31,6 +33,7 @@
"release_dir",
"safekeras.tf",
"save_test",
"target",
"test_lira_target",
"test_output_lira",
"test_output_sa",
Expand Down Expand Up @@ -89,7 +92,7 @@ def _cleanup():

@pytest.fixture()
def get_target(request) -> Target: # pylint: disable=too-many-locals
"""Wrap the model and data in a Target object.
"""Return a target object with test data and fitted model.
Uses a randomly sampled 10+10% of the nursery data set.
"""
Expand Down Expand Up @@ -123,13 +126,7 @@ def get_target(request) -> Target: # pylint: disable=too-many-locals
X_test_orig,
y_train_orig,
y_test_orig,
) = train_test_split(
x,
y,
test_size=0.05,
stratify=y,
shuffle=True,
)
) = train_test_split(x, y, test_size=0.05, stratify=y, shuffle=True, random_state=1)

# now resample the training data reduce number of examples
_, X_train_orig, _, y_train_orig = train_test_split(
Expand All @@ -138,6 +135,7 @@ def get_target(request) -> Target: # pylint: disable=too-many-locals
test_size=0.05,
stratify=y_train_orig,
shuffle=True,
random_state=1,
)

# [Researcher] Preprocess dataset
Expand All @@ -162,6 +160,9 @@ def get_target(request) -> Target: # pylint: disable=too-many-locals
xmore = np.concatenate((X_train_orig, X_test_orig))
n_features = np.shape(X_train_orig)[1]

# fit model
model.fit(X_train, y_train)

# wrap
target = Target(model=model)
target.dataset_name = "nursery"
Expand Down

0 comments on commit 3a1de14

Please sign in to comment.