Skip to content

Commit

Permalink
Adjust tests to work with v4
Browse files Browse the repository at this point in the history
  • Loading branch information
stanmart committed Aug 9, 2023
1 parent 9b04f8c commit 5c064c2
Showing 1 changed file with 5 additions and 9 deletions.
14 changes: 5 additions & 9 deletions tests/test_formula.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,7 @@ def test_matrix_against_expectation(df, formula, expected):
assert len(model_df.matrices) == len(expected.matrices)
for res, exp in zip(model_df.matrices, expected.matrices):
assert type(res) == type(exp)
if isinstance(res, tm.DenseMatrix):
np.testing.assert_array_equal(res, exp)
elif isinstance(res, tm.SparseMatrix):
if isinstance(res, (tm.DenseMatrix, tm.SparseMatrix)):
np.testing.assert_array_equal(res.A, res.A)
elif isinstance(res, tm.CategoricalMatrix):
assert (exp.cat == res.cat).all()
Expand Down Expand Up @@ -269,9 +267,7 @@ def test_matrix_against_expectation_qcl(df, formula, expected):
assert len(model_df.matrices) == len(expected.matrices)
for res, exp in zip(model_df.matrices, expected.matrices):
assert type(res) == type(exp)
if isinstance(res, tm.DenseMatrix):
np.testing.assert_array_equal(res, exp)
elif isinstance(res, tm.SparseMatrix):
if isinstance(res, (tm.DenseMatrix, tm.SparseMatrix)):
np.testing.assert_array_equal(res.A, res.A)
elif isinstance(res, tm.CategoricalMatrix):
assert (exp.cat == res.cat).all()
Expand Down Expand Up @@ -694,19 +690,19 @@ def test_state(self, materializer):
mm = materializer.get_model_matrix("center(a) - 1")
assert isinstance(mm, tm.MatrixBase)
assert list(mm.model_spec.column_names) == ["center(a)"]
assert np.allclose(mm.getcol(0).squeeze(), [-1, 0, 1])
assert np.allclose(mm.getcol(0).unpack().squeeze(), [-1, 0, 1])

mm2 = TabmatMaterializer(pd.DataFrame({"a": [4, 5, 6]})).get_model_matrix(
mm.model_spec
)
assert isinstance(mm2, tm.MatrixBase)
assert list(mm2.model_spec.column_names) == ["center(a)"]
assert np.allclose(mm2.getcol(0).squeeze(), [2, 3, 4])
assert np.allclose(mm2.getcol(0).unpack().squeeze(), [2, 3, 4])

mm3 = mm.model_spec.get_model_matrix(pd.DataFrame({"a": [4, 5, 6]}))
assert isinstance(mm3, tm.MatrixBase)
assert list(mm3.model_spec.column_names) == ["center(a)"]
assert np.allclose(mm3.getcol(0).squeeze(), [2, 3, 4])
assert np.allclose(mm3.getcol(0).unpack().squeeze(), [2, 3, 4])

def test_factor_evaluation_edge_cases(self, materializer):
# Test that categorical kinds are set if type would otherwise be numerical
Expand Down

0 comments on commit 5c064c2

Please sign in to comment.