Skip to content

Commit

Permalink
Switching len(x.shape) to x.dim()
Browse files Browse the repository at this point in the history
  • Loading branch information
bonevbs committed Sep 9, 2024
1 parent 856c316 commit 3b2b7c2
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
16 changes: 8 additions & 8 deletions torch_harmonics/distributed/distributed_sht.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ def extra_repr(self):

def forward(self, x: torch.Tensor):

if len(x.shape) < 3:
raise ValueError(f"Expected tensor with at least 3 dimensions but got {len(x.shape)} instead")
if x.dim() < 3:
raise ValueError(f"Expected tensor with at least 3 dimensions but got {x.dim()} instead")

# we need to ensure that we can split the channels evenly
num_chans = x.shape[-3]
Expand Down Expand Up @@ -237,8 +237,8 @@ def extra_repr(self):

def forward(self, x: torch.Tensor):

if len(x.shape) < 3:
raise ValueError(f"Expected tensor with at least 3 dimensions but got {len(x.shape)} instead")
if x.dim() < 3:
raise ValueError(f"Expected tensor with at least 3 dimensions but got {x.dim()} instead")

# we need to ensure that we can split the channels evenly
num_chans = x.shape[-3]
Expand Down Expand Up @@ -366,8 +366,8 @@ def extra_repr(self):

def forward(self, x: torch.Tensor):

if len(x.shape) < 4:
raise ValueError(f"Expected tensor with at least 4 dimensions but got {len(x.shape)} instead")
if x.dim() < 4:
raise ValueError(f"Expected tensor with at least 4 dimensions but got {x.dim()} instead")

# we need to ensure that we can split the channels evenly
num_chans = x.shape[-4]
Expand Down Expand Up @@ -490,8 +490,8 @@ def extra_repr(self):

def forward(self, x: torch.Tensor):

if len(x.shape) < 4:
raise ValueError(f"Expected tensor with at least 4 dimensions but got {len(x.shape)} instead")
if x.dim() < 4:
raise ValueError(f"Expected tensor with at least 4 dimensions but got {x.dim()} instead")

# store num channels
num_chans = x.shape[-4]
Expand Down
12 changes: 6 additions & 6 deletions torch_harmonics/sht.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ def extra_repr(self):

def forward(self, x: torch.Tensor):

if len(x.shape) < 2:
raise ValueError(f"Expected tensor with at least 2 dimensions but got {len(x.shape)} instead")
if x.dim() < 2:
raise ValueError(f"Expected tensor with at least 2 dimensions but got {x.dim()} instead")

assert(x.shape[-2] == self.nlat)
assert(x.shape[-1] == self.nlon)
Expand Down Expand Up @@ -275,8 +275,8 @@ def extra_repr(self):

def forward(self, x: torch.Tensor):

if len(x.shape) < 3:
raise ValueError(f"Expected tensor with at least 3 dimensions but got {len(x.shape)} instead")
if x.dim() < 3:
raise ValueError(f"Expected tensor with at least 3 dimensions but got {x.dim()} instead")

assert(x.shape[-2] == self.nlat)
assert(x.shape[-1] == self.nlon)
Expand Down Expand Up @@ -364,8 +364,8 @@ def extra_repr(self):

def forward(self, x: torch.Tensor):

if len(x.shape) < 3:
raise ValueError(f"Expected tensor with at least 3 dimensions but got {len(x.shape)} instead")
if x.dim() < 3:
raise ValueError(f"Expected tensor with at least 3 dimensions but got {x.dim()} instead")

assert(x.shape[-2] == self.lmax)
assert(x.shape[-1] == self.mmax)
Expand Down

0 comments on commit 3b2b7c2

Please sign in to comment.