Skip to content

Commit

Permalink
Fix cached(CompactSparseTensor)
Browse files Browse the repository at this point in the history
  • Loading branch information
Philipp Holl committed Aug 18, 2024
1 parent 5ffdafd commit c48229e
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions phiml/math/_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1580,7 +1580,7 @@ def tensor(data,
return data
else:
if None in shape.sizes:
shape = shape.with_sizes(data.shape.sizes)
shape = shape.with_sizes(data.shape)
return data._with_shape_replaced(shape)
elif isinstance(data, Shape):
if shape is None:
Expand Down Expand Up @@ -1921,7 +1921,7 @@ def assemble_tree(obj: PhiTreeNodeType, values: List[Tensor], attr_type=variable


def cached(t: TensorOrTree) -> TensorOrTree:
from ._sparse import SparseCoordinateTensor, CompressedSparseMatrix
from ._sparse import SparseCoordinateTensor, CompressedSparseMatrix, CompactSparseTensor
assert isinstance(t, (Tensor, PhiTreeNode)), f"All arguments must be Tensors but got {type(t)}"
if isinstance(t, NativeTensor):
return t._cached()
Expand All @@ -1934,9 +1934,11 @@ def cached(t: TensorOrTree) -> TensorOrTree:
native = choose_backend(*natives).stack(natives, axis=t.shape.index(t._stack_dim.name))
return NativeTensor(native, t.shape)
elif isinstance(t, SparseCoordinateTensor):
return SparseCoordinateTensor(cached(t._indices), cached(t._values), t._dense_shape, t._can_contain_double_entries, t._indices_sorted, t._indices_constant)
return SparseCoordinateTensor(cached(t._indices), cached(t._values), t._dense_shape, t._can_contain_double_entries, t._indices_sorted, t._indices_constant, t._matrix_rank)
elif isinstance(t, CompressedSparseMatrix):
return CompressedSparseMatrix(cached(t._indices), cached(t._pointers), cached(t._values), t._uncompressed_dims, t._compressed_dims, t._indices_constant, t._uncompressed_offset, t._uncompressed_indices, t._uncompressed_indices_perm)
return CompressedSparseMatrix(cached(t._indices), cached(t._pointers), cached(t._values), t._uncompressed_dims, t._compressed_dims, t._indices_constant, t._uncompressed_offset, t._uncompressed_indices, t._uncompressed_indices_perm, t._matrix_rank)
elif isinstance(t, CompactSparseTensor):
return CompactSparseTensor(cached(t._indices), cached(t._values), t._compressed_dims, t._indices_constant, t._matrix_rank)
elif isinstance(t, Layout):
return t
elif isinstance(t, PhiTreeNode):
Expand Down

0 comments on commit c48229e

Please sign in to comment.