Skip to content

Commit

Permalink
argmin,argmax for non-sparse dims of sparse tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
Philipp Holl committed Aug 14, 2024
1 parent 6ef94b1 commit 4918f04
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions phiml/math/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1715,6 +1715,8 @@ def argmax(x: Tensor, dim: DimFilter, index_dim=channel('index')):
else: # all sparse dims are reduced
result = scatter_val.true_values[0]
return rename_dims(result, channel(scatter_val), index_dim.with_sizes(dims.name_list))
elif dims.isdisjoint(sparse_dims(x)): # only argmax across values dim
return x._with_values(argmax(x._values, dims))
else:
raise NotImplementedError
v_native = reshaped_native(x, [keep, dims])
Expand Down Expand Up @@ -1755,6 +1757,8 @@ def argmin(x: Tensor, dim: DimFilter, index_dim=channel('index')):
else: # all sparse dims are reduced
result = scatter_val.true_values[0]
return rename_dims(result, channel(scatter_val), index_dim.with_sizes(dims.name_list))
elif dims.isdisjoint(sparse_dims(x)): # only argmin across values dim
return x._with_values(argmin(x._values, dims))
else:
raise NotImplementedError
v_native = reshaped_native(x, [keep, dims])
Expand Down

0 comments on commit 4918f04

Please sign in to comment.