diff --git a/tests/test_formula.py b/tests/test_formula.py index 89d9ff9d..bd2c712a 100644 --- a/tests/test_formula.py +++ b/tests/test_formula.py @@ -158,9 +158,7 @@ def test_matrix_against_expectation(df, formula, expected): assert len(model_df.matrices) == len(expected.matrices) for res, exp in zip(model_df.matrices, expected.matrices): assert type(res) == type(exp) - if isinstance(res, tm.DenseMatrix): - np.testing.assert_array_equal(res, exp) - elif isinstance(res, tm.SparseMatrix): + if isinstance(res, (tm.DenseMatrix, tm.SparseMatrix)): np.testing.assert_array_equal(res.A, res.A) elif isinstance(res, tm.CategoricalMatrix): assert (exp.cat == res.cat).all() @@ -269,9 +267,7 @@ def test_matrix_against_expectation_qcl(df, formula, expected): assert len(model_df.matrices) == len(expected.matrices) for res, exp in zip(model_df.matrices, expected.matrices): assert type(res) == type(exp) - if isinstance(res, tm.DenseMatrix): - np.testing.assert_array_equal(res, exp) - elif isinstance(res, tm.SparseMatrix): + if isinstance(res, (tm.DenseMatrix, tm.SparseMatrix)): np.testing.assert_array_equal(res.A, res.A) elif isinstance(res, tm.CategoricalMatrix): assert (exp.cat == res.cat).all() @@ -694,19 +690,19 @@ def test_state(self, materializer): mm = materializer.get_model_matrix("center(a) - 1") assert isinstance(mm, tm.MatrixBase) assert list(mm.model_spec.column_names) == ["center(a)"] - assert np.allclose(mm.getcol(0).squeeze(), [-1, 0, 1]) + assert np.allclose(mm.getcol(0).unpack().squeeze(), [-1, 0, 1]) mm2 = TabmatMaterializer(pd.DataFrame({"a": [4, 5, 6]})).get_model_matrix( mm.model_spec ) assert isinstance(mm2, tm.MatrixBase) assert list(mm2.model_spec.column_names) == ["center(a)"] - assert np.allclose(mm2.getcol(0).squeeze(), [2, 3, 4]) + assert np.allclose(mm2.getcol(0).unpack().squeeze(), [2, 3, 4]) mm3 = mm.model_spec.get_model_matrix(pd.DataFrame({"a": [4, 5, 6]})) assert isinstance(mm3, tm.MatrixBase) assert list(mm3.model_spec.column_names) == ["center(a)"] - assert np.allclose(mm3.getcol(0).squeeze(), [2, 3, 4]) + assert np.allclose(mm3.getcol(0).unpack().squeeze(), [2, 3, 4]) def test_factor_evaluation_edge_cases(self, materializer): # Test that categorical kinds are set if type would otherwise be numerical