Skip to content

Commit

Permalink
Skip the classes instead of the methods if CUDA is not available
Browse files Browse the repository at this point in the history
  • Loading branch information
voltjia committed Aug 2, 2024
1 parent 8940fbd commit 301c5b8
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 14 deletions.
10 changes: 5 additions & 5 deletions tests/test_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ def add_kernel(
return output


class TestAdd:
@skip_if_cuda_not_available
class TestCUDA:
@classmethod
def setup_class(cls):
torch.manual_seed(0)
Expand All @@ -33,9 +34,8 @@ def setup_class(cls):
cls.lhs = torch.rand(size, device="cuda")
cls.rhs = torch.rand(size, device="cuda")

@skip_if_cuda_not_available
def test_cuda(self):
lhs = type(self).lhs
rhs = type(self).rhs
def test_fp32(self):
lhs = type(self).lhs.to(torch.float32)
rhs = type(self).rhs.to(torch.float32)

assert torch.allclose(add(lhs, rhs), lhs + rhs)
17 changes: 8 additions & 9 deletions tests/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,26 +43,25 @@ def matmul_kernel(lhs: lhs_tiled, rhs: rhs_tiled, output: output_tiled):
return output


class TestMatMul:
@skip_if_cuda_not_available
class TestCUDA:
@classmethod
def setup_class(cls):
torch.manual_seed(0)

shape = (512, 512)

cls.lhs = torch.randn(shape, device="cuda", dtype=torch.float16)
cls.rhs = torch.randn(shape, device="cuda", dtype=torch.float16)
cls.lhs = torch.randn(shape, device="cuda")
cls.rhs = torch.randn(shape, device="cuda")

@skip_if_cuda_not_available
def test_cuda_fp16(self):
lhs = type(self).lhs
rhs = type(self).rhs
def test_fp16(self):
lhs = type(self).lhs.to(torch.float16)
rhs = type(self).rhs.to(torch.float16)

assert torch.allclose(matmul(lhs, rhs), torch.matmul(lhs, rhs))

@skip_if_cuda_not_available
@skip_if_float8_e5m2_not_supported
def test_cuda_fp8(self):
def test_fp8(self):
lhs = type(self).lhs.to(torch.float8_e5m2)
rhs = type(self).rhs.T.to(torch.float8_e5m2)

Expand Down

0 comments on commit 301c5b8

Please sign in to comment.