Skip to content

Commit

Permalink
Fix typing issues
Browse files (browse the repository at this point in the history)
  • Loading branch information
stanmart committed Sep 3, 2024
1 parent 4fda475 commit 2007569
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 57 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ repos:
# mypy
- id: mypy
name: mypy
entry: pixi run -e default mypy
exclude: (^tests/|^src/glum_benchmarks/orig_sklearn_fork/)
entry: pixi run -e default mypy --allow-redefinition
exclude: (^tests/)
language: system
types: [python]
require_serial: true
Expand Down
19 changes: 10 additions & 9 deletions src/tabmat/categorical_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,9 @@ def _extract_codes_and_categories(cat_vec):


def _row_col_indexing(
arr: np.ndarray, rows: Optional[np.ndarray], cols: Optional[np.ndarray]
arr: Union[np.ndarray, sps.spmatrix],
rows: Optional[np.ndarray],
cols: Optional[np.ndarray],
) -> np.ndarray:
if isinstance(rows, slice) and rows == slice(None, None, None):
rows = None
Expand Down Expand Up @@ -411,7 +413,7 @@ def recover_orig(self) -> np.ndarray:
def _matvec_setup(
self,
other: Union[list, np.ndarray],
cols: np.ndarray = None,
cols: Optional[np.ndarray] = None,
) -> tuple[np.ndarray, Optional[np.ndarray]]:
other = np.asarray(other)
if other.ndim > 1:
Expand All @@ -434,8 +436,8 @@ def _matvec_setup(
def matvec(
self,
other: Union[list, np.ndarray],
cols: np.ndarray = None,
out: np.ndarray = None,
cols: Optional[np.ndarray] = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
"""
Multiply self with vector 'other', and add vector 'out' if it is present.
Expand Down Expand Up @@ -524,8 +526,7 @@ def transpose_matvec(
"CategoricalMatrix.transpose_matvec is only implemented for 1d arrays."
)

out_is_none = out is None
if out_is_none:
if out_is_none := out is None:
out = np.zeros(self.shape[1], dtype=self.dtype)
else:
check_transpose_matvec_out_shape(self, out)
Expand Down Expand Up @@ -558,8 +559,8 @@ def transpose_matvec(
def sandwich(
self,
d: Union[np.ndarray, list],
rows: np.ndarray = None,
cols: np.ndarray = None,
rows: Optional[np.ndarray] = None,
cols: Optional[np.ndarray] = None,
) -> sps.dia_matrix:
"""
Perform a sandwich product: X.T @ diag(d) @ X.
Expand Down Expand Up @@ -594,7 +595,7 @@ def sandwich(
def _cross_sandwich(
self,
other: MatrixBase,
d: Union[np.ndarray, list],
d: np.ndarray,
rows: Optional[np.ndarray] = None,
L_cols: Optional[np.ndarray] = None,
R_cols: Optional[np.ndarray] = None,
Expand Down
12 changes: 6 additions & 6 deletions src/tabmat/constructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
try:
import polars as pl
except ImportError:
pl = None
pl = None # type: ignore
try:
import pandas as pd
except ImportError:
pd = None
pd = None # type: ignore


def _is_boolean(series, engine: str):
Expand Down Expand Up @@ -106,7 +106,7 @@ def _from_dataframe(
"""

matrices: list[Union[DenseMatrix, SparseMatrix, CategoricalMatrix]] = []
indices: list[list[int]] = []
indices: list[np.ndarray] = []
is_cat: list[bool] = []

dense_dfidx = [] # column index in original DataFrame
Expand Down Expand Up @@ -197,7 +197,7 @@ def _from_dataframe(
term_names=np.asarray(df.columns)[dense_dfidx],
)
)
indices.append(dense_tmidx)
indices.append(np.asarray(dense_tmidx))
is_cat.append(False)
if sparse_dfidx:
matrices.append(
Expand All @@ -208,7 +208,7 @@ def _from_dataframe(
term_names=np.asarray(df.columns)[sparse_dfidx],
)
)
indices.append(sparse_tmidx)
indices.append(np.asarray(sparse_tmidx))
is_cat.append(False)

if cat_position == "end":
Expand Down Expand Up @@ -357,7 +357,7 @@ def from_polars(
)


def _reindex_cat(indices, is_cat, mxcolidx):
def _reindex_cat(indices: list[np.ndarray], is_cat: list[bool], mxcolidx: int):
new_indices = []
for mat_indices, is_cat_ in zip(indices, is_cat):
if is_cat_:
Expand Down
15 changes: 9 additions & 6 deletions src/tabmat/dense_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,10 @@ def unpack(self):
return self._array

def sandwich(
self, d: np.ndarray, rows: np.ndarray = None, cols: np.ndarray = None
self,
d: np.ndarray,
rows: Optional[np.ndarray] = None,
cols: Optional[np.ndarray] = None,
) -> np.ndarray:
"""Perform a sandwich product: X.T @ diag(d) @ X."""
d = np.asarray(d)
Expand Down Expand Up @@ -219,9 +222,9 @@ def _matvec_helper(
def transpose_matvec(
self,
vec: Union[np.ndarray, list],
rows: np.ndarray = None,
cols: np.ndarray = None,
out: np.ndarray = None,
rows: Optional[np.ndarray] = None,
cols: Optional[np.ndarray] = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
"""Perform: self[rows, cols].T @ vec[rows]."""
check_transpose_matvec_out_shape(self, out)
Expand All @@ -230,8 +233,8 @@ def transpose_matvec(
def matvec(
self,
vec: Union[np.ndarray, list],
cols: np.ndarray = None,
out: np.ndarray = None,
cols: Optional[np.ndarray] = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
"""Perform self[:, cols] @ other[cols]."""
check_matvec_out_shape(self, out)
Expand Down
15 changes: 10 additions & 5 deletions src/tabmat/matrix_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ class MatrixBase(ABC):
dtype: np.dtype

@abstractmethod
def matvec(self, other, cols: np.ndarray = None, out: np.ndarray = None):
def matvec(
self, other, cols: Optional[np.ndarray] = None, out: Optional[np.ndarray] = None
):
"""
Perform: self[:, cols] @ other[cols], so result[i] = sum_j self[i, j] other[j].
Expand All @@ -32,9 +34,9 @@ def matvec(self, other, cols: np.ndarray = None, out: np.ndarray = None):
def transpose_matvec(
self,
vec: Union[np.ndarray, list],
rows: np.ndarray = None,
cols: np.ndarray = None,
out: np.ndarray = None,
rows: Optional[np.ndarray] = None,
cols: Optional[np.ndarray] = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
"""
Perform: self[rows, cols].T @ vec[rows], so result[i] = sum_j self[j, i] vec[j].
Expand All @@ -61,7 +63,10 @@ def transpose_matvec(

@abstractmethod
def sandwich(
self, d: np.ndarray, rows: np.ndarray = None, cols: np.ndarray = None
self,
d: np.ndarray,
rows: Optional[np.ndarray] = None,
cols: Optional[np.ndarray] = None,
) -> np.ndarray:
"""
Perform a sandwich product: (self[rows, cols].T * d[rows]) @ self[rows, cols].
Expand Down
23 changes: 14 additions & 9 deletions src/tabmat/sparse_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,10 @@ def dot(self, other):
return self._array.dot(other)

def sandwich(
self, d: np.ndarray, rows: np.ndarray = None, cols: np.ndarray = None
self,
d: np.ndarray,
rows: Optional[np.ndarray] = None,
cols: Optional[np.ndarray] = None,
) -> np.ndarray:
"""Perform a sandwich product: X.T @ diag(d) @ X."""
d = np.asarray(d)
Expand All @@ -190,7 +193,7 @@ def _cross_sandwich(
self,
other: MatrixBase,
d: np.ndarray,
rows: np.ndarray,
rows: Optional[np.ndarray],
L_cols: Optional[np.ndarray] = None,
R_cols: Optional[np.ndarray] = None,
):
Expand All @@ -209,9 +212,9 @@ def sandwich_dense(
self,
B: np.ndarray,
d: np.ndarray,
rows: np.ndarray,
L_cols: np.ndarray,
R_cols: np.ndarray,
rows: Optional[np.ndarray],
L_cols: Optional[np.ndarray],
R_cols: Optional[np.ndarray],
) -> np.ndarray:
"""Perform a sandwich product: self.T @ diag(d) @ B."""
if not hasattr(d, "dtype"):
Expand Down Expand Up @@ -276,17 +279,19 @@ def _matvec_helper(
out[rows] += res
return out

def matvec(self, vec, cols: np.ndarray = None, out: np.ndarray = None):
def matvec(
self, vec, cols: Optional[np.ndarray] = None, out: Optional[np.ndarray] = None
):
"""Perform self[:, cols] @ other[cols]."""
check_matvec_out_shape(self, out)
return self._matvec_helper(vec, None, cols, out, False)

def transpose_matvec(
self,
vec: Union[np.ndarray, list],
rows: np.ndarray = None,
cols: np.ndarray = None,
out: np.ndarray = None,
rows: Optional[np.ndarray] = None,
cols: Optional[np.ndarray] = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
"""Perform: self[rows, cols].T @ vec[rows]."""
check_transpose_matvec_out_shape(self, out)
Expand Down
24 changes: 14 additions & 10 deletions src/tabmat/split_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def as_tabmat(a: Union[MatrixBase, StandardizedMatrix, np.ndarray, sps.spmatrix]
if isinstance(a, (MatrixBase, StandardizedMatrix)):
return a
elif sps.issparse(a):
return SparseMatrix(a.tocsc(copy=False))
return SparseMatrix(a.tocsc(copy=False)) # type: ignore
elif isinstance(a, np.ndarray):
return DenseMatrix(a)
else:
Expand Down Expand Up @@ -59,7 +59,7 @@ def hstack(tup: Sequence[Union[MatrixBase, np.ndarray, sps.spmatrix]]) -> Matrix
return SplitMatrix(matrices)


def _prepare_out_array(out: Optional[np.ndarray], out_shape, out_dtype):
def _prepare_out_array(out: Optional[np.ndarray], out_shape, out_dtype) -> np.ndarray:
if out is None:
out = np.zeros(out_shape, out_dtype)
else:
Expand Down Expand Up @@ -266,7 +266,7 @@ def __init__(

def _split_col_subsets(
self, cols: Optional[np.ndarray]
) -> tuple[list[np.ndarray], list[Optional[np.ndarray]], int]:
) -> tuple[list[np.ndarray], Union[list[np.ndarray], list[None]], int]:
"""
Return tuple of things helpful for applying column restrictions to sub-matrices.
Expand Down Expand Up @@ -322,8 +322,8 @@ def getcol(self, i: int) -> Union[np.ndarray, sps.csr_matrix]:
def sandwich(
self,
d: Union[np.ndarray, list],
rows: np.ndarray = None,
cols: np.ndarray = None,
rows: Optional[np.ndarray] = None,
cols: Optional[np.ndarray] = None,
) -> np.ndarray:
"""Perform a sandwich product: X.T @ diag(d) @ X."""
if np.shape(d) != (self.shape[0],):
Expand Down Expand Up @@ -370,7 +370,10 @@ def _get_col_stds(self, weights: np.ndarray, col_means: np.ndarray) -> np.ndarra
return col_stds

def matvec(
self, v: np.ndarray, cols: np.ndarray = None, out: np.ndarray = None
self,
v: np.ndarray,
cols: Optional[np.ndarray] = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
"""Perform self[:, cols] @ other[cols]."""
assert not isinstance(v, sps.spmatrix)
Expand All @@ -393,6 +396,7 @@ def matvec(
# as the target for storing the final output. This reduces the number
# of output arrays allocated from 2 to 1.
is_matrix_dense = [isinstance(m, DenseMatrix) for m in self.matrices]
dense_matrix_idx: Union[int, np.intp]
if np.any(is_matrix_dense):
dense_matrix_idx = np.argmax(is_matrix_dense)
sub_cols = subset_cols[dense_matrix_idx]
Expand All @@ -411,14 +415,14 @@ def matvec(
continue
in_vec = v[idx, ...]
mat.matvec(in_vec, sub_cols, out=out)
return out
return out # type: ignore

def transpose_matvec(
self,
v: Union[np.ndarray, list],
rows: np.ndarray = None,
cols: np.ndarray = None,
out: np.ndarray = None,
rows: Optional[np.ndarray] = None,
cols: Optional[np.ndarray] = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
"""
Perform: self[rows, cols].T @ vec[rows].
Expand Down
24 changes: 14 additions & 10 deletions src/tabmat/standardized_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(
self,
mat: MatrixBase,
shift: Union[np.ndarray, list],
mult: Union[np.ndarray, list] = None,
mult: Optional[Union[np.ndarray, list]] = None,
):
shift_arr = np.atleast_1d(np.squeeze(shift))
expected_shape = (mat.shape[1],)
Expand All @@ -48,14 +48,15 @@ def __init__(
but it has shape {np.asarray(shift).shape}"""
)

mult_arr = mult
if mult_arr is not None:
mult_arr = np.atleast_1d(np.squeeze(mult_arr))
if mult is not None:
mult_arr = np.atleast_1d(np.squeeze(mult))
if not mult_arr.shape == expected_shape:
raise ValueError(
f"""Expected mult to be able to conform to shape {expected_shape},
but it has shape {np.asarray(mult).shape}"""
)
else:
mult_arr = None

self.shift = shift_arr
self.mult = mult_arr
Expand All @@ -67,8 +68,8 @@ def __init__(
def matvec(
self,
other_mat: Union[np.ndarray, list],
cols: np.ndarray = None,
out: np.ndarray = None,
cols: Optional[np.ndarray] = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
"""
Perform self[:, cols] @ other[cols].
Expand Down Expand Up @@ -119,7 +120,10 @@ def getcol(self, i: int):
return StandardizedMatrix(col, [self.shift[i]], mult)

def sandwich(
self, d: np.ndarray, rows: np.ndarray = None, cols: np.ndarray = None
self,
d: np.ndarray,
rows: Optional[np.ndarray] = None,
cols: Optional[np.ndarray] = None,
) -> np.ndarray:
"""Perform a sandwich product: X.T @ diag(d) @ X."""
if not hasattr(d, "dtype"):
Expand Down Expand Up @@ -169,9 +173,9 @@ def unstandardize(self) -> MatrixBase:
def transpose_matvec(
self,
other: Union[np.ndarray, list],
rows: np.ndarray = None,
cols: np.ndarray = None,
out: np.ndarray = None,
rows: Optional[np.ndarray] = None,
cols: Optional[np.ndarray] = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
"""
Perform: self[rows, cols].T @ vec[rows].
Expand Down

0 comments on commit 2007569

Please sign in to comment.