Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
douglasrizzo committed Mar 23, 2024
1 parent f9f8d67 commit c6ab36a
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 11 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ unfixable = []
# Allow unused variables when underscore-prefixed.
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"

[tool.ruff.lint.pydocstyle]
[tool.ruff.pydocstyle]
convention = "google"

[tool.ruff.format]
Expand Down
3 changes: 2 additions & 1 deletion src/catsim/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ def __init__(
:type dist_params: tuple, optional
"""
super().__init__()
assert isinstance(dist_type, InitializationDistribution), "dist_type must be an InitializationDistribution"
self._dist_type = dist_type
self._dist_params = dist_params
self.__rng = numpy.random.default_rng(1337)
self.__rng = numpy.random.default_rng()

def initialize(self, index: int | None = None, **kwargs: dict[str, Any]) -> float:
"""Generates a value using the chosen distribution and parameters.
Expand Down
1 change: 1 addition & 0 deletions src/catsim/irt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import math
from enum import Enum
from math import pi # noqa: F401
from typing import Any

import numexpr
Expand Down
3 changes: 2 additions & 1 deletion src/catsim/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,13 +333,14 @@ def __init__(
self._estimator = estimator
self._stopper = stopper

self.__rng = numpy.random.default_rng()

# `examinees` is passed to its special setter
self._examinees = self._to_distribution(examinees)

self._estimations: list[list[int]] = [[] for _ in range(self.examinees.shape[0])]
self._administered_items: list[list[int]] = [[] for _ in range(self.examinees.shape[0])]
self._response_vectors: list[list[bool]] = [[] for _ in range(self.examinees.shape[0])]
self.__rng = numpy.random.default_rng()

@property
def items(self) -> numpy.ndarray:
Expand Down
3 changes: 2 additions & 1 deletion src/catsim/stopping.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Any
import numpy

from . import irt
Expand All @@ -23,7 +24,7 @@ def __init__(self, max_itens: int) -> None:
super().__init__()
self._max_itens = max_itens

def stop(self, index: int | None = None, administered_items: numpy.ndarray = None) -> bool:
def stop(self, index: int | None = None, administered_items: numpy.ndarray = None, **kwargs: dict[str, Any]) -> bool:
"""Check whether the test reached its stopping criterion for the given user.
:param index: the index of the current examinee
Expand Down
16 changes: 9 additions & 7 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from catsim import cat, irt, plot
from catsim.cat import generate_item_bank
from catsim.estimation import NumericalSearchEstimator
from catsim.initialization import FixedPointInitializer, RandomInitializer
from catsim.initialization import FixedPointInitializer, InitializationDistribution, RandomInitializer
from catsim.selection import (
AStratBBlockSelector,
AStratSelector,
Expand Down Expand Up @@ -37,7 +37,7 @@ def one_simulation(

@pytest.mark.parametrize("examinees", [100])
@pytest.mark.parametrize("bank_size", [500])
@pytest.mark.parametrize("initializer", [RandomInitializer("uniform", (-5, 5))])
@pytest.mark.parametrize("initializer", [RandomInitializer(InitializationDistribution.UNIFORM, (-5, 5))])
@pytest.mark.parametrize("estimator", [NumericalSearchEstimator()])
@pytest.mark.parametrize("stopper", [MaxItemStopper(30), MinErrorStopper(0.4)])
def test_cism(
Expand All @@ -62,7 +62,7 @@ def test_cism(
@pytest.mark.parametrize(
"initializer",
[
RandomInitializer("uniform", (-5, 5)),
RandomInitializer(InitializationDistribution.UNIFORM, (-5, 5)),
FixedPointInitializer(0),
],
)
Expand All @@ -77,8 +77,9 @@ def test_finite_selectors(
initializer: Initializer,
estimator: Estimator,
):
rng = np.random.default_rng()
finite_selectors = [
LinearSelector(list(np.random.choice(bank_size, size=test_size, replace=False))),
LinearSelector(list(rng.choice(bank_size, size=test_size, replace=False))),
AStratSelector(test_size),
AStratBBlockSelector(test_size),
MaxInfoStratSelector(test_size),
Expand All @@ -91,7 +92,7 @@ def test_finite_selectors(
for selector in finite_selectors:
items = generate_item_bank(bank_size, itemtype=logistic_model)
responses = cat.random_response_vector(random.randint(1, test_size - 1))
administered_items = np.random.choice(bank_size, len(responses), replace=False)
administered_items = rng.choice(bank_size, len(responses), replace=False)
est_theta = initializer.initialize()
selector.select(
items=items,
Expand All @@ -118,7 +119,7 @@ def test_finite_selectors(
@pytest.mark.parametrize(
"initializer",
[
RandomInitializer("uniform", (-5, 5)),
RandomInitializer(InitializationDistribution.UNIFORM, (-5, 5)),
FixedPointInitializer(0),
],
)
Expand Down Expand Up @@ -149,10 +150,11 @@ def test_infinite_selectors(
estimator: Estimator,
stopper: Stopper,
):
rng = np.random.default_rng()
items = generate_item_bank(bank_size, itemtype=logistic_model)
max_administered_items = stopper.max_itens if type(stopper) == MaxItemStopper else bank_size
responses = cat.random_response_vector(random.randint(1, max_administered_items))
administered_items = np.random.choice(bank_size, len(responses), replace=False)
administered_items = rng.choice(bank_size, len(responses), replace=False)
est_theta = initializer.initialize()
selector.select(
items=items,
Expand Down

0 comments on commit c6ab36a

Please sign in to comment.