Skip to content

Commit

Permalink
Be consistent when instantiating from 1d arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
stanmart committed Aug 9, 2023
1 parent 38813e7 commit 78d0278
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/tabmat/sparse_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,14 @@ class SparseMatrix(MatrixBase):
SparseMatrix is instantiated in the same way as scipy.sparse.csc_matrix.
"""

def __init__(self, arg1, shape=None, dtype=None, copy=False):
self._array = sps.csc_matrix(arg1, shape, dtype, copy)
def __init__(self, input_array, shape=None, dtype=None, copy=False):
if isinstance(input_array, np.ndarray):
if input_array.ndim == 1:
input_array = input_array.reshape(-1, 1)
elif input_array.ndim > 2:
raise ValueError("Input array must be 1- or 2-dimensional")

self._array = sps.csc_matrix(input_array, shape, dtype, copy)

self.idx_dtype = max(self._array.indices.dtype, self._array.indptr.dtype)
if self._array.indices.dtype != self.idx_dtype:
Expand Down

0 comments on commit 78d0278

Please sign in to comment.