Skip to content

Commit

Permalink
Turn tensor.Tensor.shape and tensor.Tensor.strides into tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
voltjia committed Sep 9, 2024
1 parent ed489bc commit b09f519
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/ninetoothed/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def visit_Attribute(self, node):
if isinstance(value, Tensor):
inner = value.dtype

return Symbol(inner.__dict__[node.attr]).node
return Symbol(getattr(inner, node.attr)).node

self.generic_visit(node)

Expand Down
20 changes: 18 additions & 2 deletions src/ninetoothed/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def __init__(
self.name = f"tensor_{type(self).num_instances}"

if ndim is not None:
self.shape = [Symbol(self.size_string(i)) for i in range(ndim)]
self.strides = [Symbol(self.stride_string(i)) for i in range(ndim)]
self.shape = (Symbol(self.size_string(i)) for i in range(ndim))
self.strides = (Symbol(self.stride_string(i)) for i in range(ndim))
else:
self.shape = shape

Expand Down Expand Up @@ -191,6 +191,22 @@ def stride(self, dim=None):

return self.strides[dim]

@property
def shape(self):
return self._shape

@shape.setter
def shape(self, value):
self._shape = tuple(value)

@property
def strides(self):
return self._strides

@strides.setter
def strides(self, value):
self._strides = tuple(value)

@property
def ndim(self):
return len(self.shape)
Expand Down

0 comments on commit b09f519

Please sign in to comment.