Skip to content

Commit

Permalink
Add squeeze method to Tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
voltjia committed Aug 21, 2024
1 parent 270ba96 commit ea03213
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/ninetoothed/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,15 @@ def expand(self, shape):
original=self.original,
)

def squeeze(self, dim):
# TODO: Add error handling.
return type(self)(
shape=[size for i, size in enumerate(self.shape) if dim != i],
dtype=self.dtype,
strides=[stride for i, stride in enumerate(self.strides) if dim != i],
original=self.original,
)

def names(self):
return (
{self.original.pointer_string()}
Expand Down

0 comments on commit ea03213

Please sign in to comment.