IterDomain resize for pad, cat, slice

naoyam authored and jacobhinkle committed Mar 15, 2023
1 parent 7ee79cf commit 58990bd
Showing 9 changed files with 53 additions and 56 deletions.
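Not part of the diff below, but for orientation: the commit title refers to modeling pad, cat, and slice through IterDomain resize, i.e. growing or shrinking an axis extent by a left/right amount. The short PyTorch-only sketch that follows illustrates the shape effect of those three ops; how they map onto resize expressions internally is not shown in this page's hunks.

import torch
import torch.nn.functional as F

x = torch.arange(8.0).reshape(2, 4)

padded = F.pad(x, (1, 2))          # last axis: 4 -> 1 + 4 + 2 = 7 (grow by 1 on the left, 2 on the right)
sliced = x[:, 1:3]                 # last axis: 4 -> 2 (shrink by 1 on the left, 1 on the right)
catted = torch.cat([x, x], dim=1)  # last axis: 4 + 4 = 8

print(padded.shape, sliced.shape, catted.shape)
# torch.Size([2, 7]) torch.Size([2, 2]) torch.Size([2, 8])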
12 changes: 9 additions & 3 deletions csrc/executor_kernel_arg.h
@@ -135,9 +135,15 @@ struct ArgAbstract {
  bool isType(ArgType type) const override { \
    return ArgType::TARGET_TYPE == type; \
  } \
-  ArgType type() const override { return ArgType::TARGET_TYPE; } \
-  const void* arg() const override { return &ARG_NAME; } \
-  void* arg() override { return &ARG_NAME; } \
+  ArgType type() const override { \
+    return ArgType::TARGET_TYPE; \
+  } \
+  const void* arg() const override { \
+    return &ARG_NAME; \
+  } \
+  void* arg() override { \
+    return &ARG_NAME; \
+  } \
  std::unique_ptr<ArgAbstract> copy_unique_ptr() const override { \
    return std::make_unique<TARGET_TYPE##Arg>(*this); \
  }
41 changes: 13 additions & 28 deletions csrc/kernel_db/test/test_data/kernel_db_for_query_test/kernel_0.cu
@@ -1,41 +1,26 @@
-// clang-format off
-/*
- * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
- * All rights reserved.
- * SPDX-License-Identifier: BSD-3-Clause
- */
-// clang-format on
-__global__ void kernel1(
-    Tensor<float, 3> T0,
-    Tensor<float, 3> T1,
-    Tensor<float, 3> T2) {
+__global__ void kernel1(Tensor<float, 3> T0, Tensor<float, 3> T1, Tensor<float, 3> T2) {
  int i76;
-  i76 =
-      ((((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x)) /
-      (T0.size[1] * T0.size[2]);
+  i76 = ((((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x)) / (T0.size[1] * T0.size[2]);
  int i78;
-  i78 = (((((nvfuser_index_t)blockIdx.x) * 128) +
-          ((nvfuser_index_t)threadIdx.x)) %
-         (T0.size[1] * T0.size[2])) /
-      T0.size[2];
+  i78 = (((((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x)) % (T0.size[1] * T0.size[2])) / T0.size[2];
  int i79;
-  i79 = (((((nvfuser_index_t)blockIdx.x) * 128) +
-          ((nvfuser_index_t)threadIdx.x)) %
-         (T0.size[1] * T0.size[2])) %
-      T0.size[2];
+  i79 = (((((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x)) % (T0.size[1] * T0.size[2])) % T0.size[2];
  int i120;
  i120 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x);
  if ((i120 < (T0.size[0] * (T0.size[1] * T0.size[2])))) {
    float T4[1];
    T4[0] = 0;
-    T4[0] =
-        T1[(i76 * T1.stride[0]) + (i78 * T1.stride[1]) + (i79 * T1.stride[2])];
+    T4[0]
+       = T1[(i76 * T1.stride[0]) + (i78 * T1.stride[1]) + (i79 * T1.stride[2])];
    float T3[1];
    T3[0] = 0;
-    T3[0] =
-        T0[(i76 * T0.stride[0]) + (i78 * T0.stride[1]) + (i79 * T0.stride[2])];
+    T3[0]
+       = T0[(i76 * T0.stride[0]) + (i78 * T0.stride[1]) + (i79 * T0.stride[2])];
    float T5[1];
-    T5[0] = T3[0] + T4[0];
-    T2[i120] = T5[0];
+    T5[0]
+      = T3[0]
+      + T4[0];
+    T2[i120]
+      = T5[0];
  }
}
20 changes: 12 additions & 8 deletions csrc/ops/arith.cpp
@@ -552,10 +552,12 @@ TensorView* eye(Val* size, DataType dtype) {

// UNARY OPERATIONS

-#define NVFUSER_DEFINE_UNARY_OP(op_name, op_type)                   \
-  Val* op_name(Val* v) { return unaryOp(UnaryOpType::op_type, v); } \
-  TensorView* op_name(TensorView* tv) {                             \
-    return unaryOp(UnaryOpType::op_type, tv);                       \
+#define NVFUSER_DEFINE_UNARY_OP(op_name, op_type) \
+  Val* op_name(Val* v) { \
+    return unaryOp(UnaryOpType::op_type, v); \
+  } \
+  TensorView* op_name(TensorView* tv) { \
+    return unaryOp(UnaryOpType::op_type, tv); \
  }

NVFUSER_DEFINE_UNARY_OP(set, Set)
@@ -686,10 +688,12 @@ NVFUSER_DEFINE_UNARY_FLOAT_OP(tan, Tan)
NVFUSER_DEFINE_UNARY_FLOAT_OP(tanh, Tanh)
#undef NVFUSER_DEFINE_UNARY_FLOAT_OP

-#define NVFUSER_DEFINE_UNARY_IS_OP(op_name, op_type)                  \
-  Val* op_name(Val* v) { return unaryIsOp(UnaryOpType::op_type, v); } \
-  TensorView* op_name(TensorView* tv) {                               \
-    return unaryIsOp(UnaryOpType::op_type, tv);                       \
+#define NVFUSER_DEFINE_UNARY_IS_OP(op_name, op_type) \
+  Val* op_name(Val* v) { \
+    return unaryIsOp(UnaryOpType::op_type, v); \
+  } \
+  TensorView* op_name(TensorView* tv) { \
+    return unaryIsOp(UnaryOpType::op_type, tv); \
  }

NVFUSER_DEFINE_UNARY_IS_OP(isfinite, IsFinite)
5 changes: 3 additions & 2 deletions csrc/utils.h
@@ -267,8 +267,9 @@ std::vector<KeyType> getSortedKeys(

// Based on https://stackoverflow.com/a/9154394
template <typename T>
-static auto hasToStringHelper(int)
-    -> decltype(std::declval<typename std::remove_pointer<T>::type>().toString(), std::true_type{});
+static auto hasToStringHelper(int) -> decltype(
+    std::declval<typename std::remove_pointer<T>::type>().toString(),
+    std::true_type{});

template <typename>
static auto hasToStringHelper(long) -> std::false_type;
9 changes: 4 additions & 5 deletions python/__init__.py
@@ -1,10 +1,9 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
-# we need to import _C here to avoid confusing error message generated from failure in this python script ended up with
-# complaining on `_C` not defined for `_C._FusionDefinition`
+# we need to import _C here to avoid confusing error message generated from failure in this python script ended up with complaining on `_C` not defined for `_C._FusionDefinition`
from . import _C
-from ._C import * # noqa: F401,F403
+from ._C import *

class FusionDefinition(_C._FusionDefinition):
def __enter__(self):
@@ -53,7 +52,7 @@ def execute(self, inputs, **kwargs):
def from_pytorch(self, tensor) :
"""
Defines an nvfuser input tensor from a pytorch tensor
Args:
tensor (torch.Tensor): Input tensor to nvFuser
@@ -69,7 +68,7 @@ def from_pytorch(self, tensor) :
raise ValueError("Tensor should be on a cuda device!")

return self.define_tensor(sizes=tensor.size(), strides=tensor.stride(),
dtype=torch_dtype_to_nvfuser_dtype(tensor.dtype))
dtype=torch_dtype_to_nvfuser_dtype(tensor.dtype))

from .nvfuser_version import __version__
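For context, the FusionDefinition pieces touched on this page (__enter__, execute, from_pytorch and its define_tensor call in this file, fd.ops.sum and fd.add_output in the test and repro files further down) compose roughly as in the following sketch. The import paths ("nvfuser", "nvfuser.pytorch_utils") and the float32 entry of the dtype map are assumptions for illustration, not something this diff shows.

import torch
from nvfuser import FusionDefinition                            # assumed package name
from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype  # assumed module path

inputs = [torch.randn(2, 32, device="cuda")]  # from_pytorch requires a CUDA tensor (see the ValueError above)

with FusionDefinition() as fd:        # __enter__ is defined in this file
    t0 = fd.from_pytorch(inputs[0])   # wraps define_tensor(sizes=..., strides=..., dtype=...)
    t1 = fd.ops.sum(t0, [-1], False, torch_dtype_to_nvfuser_dtype(torch.float32))
    fd.add_output(t1)

outputs = fd.execute(inputs)  # execute(self, inputs, **kwargs) per the hunk header above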

4 changes: 3 additions & 1 deletion python/pytorch_utils.py
@@ -7,7 +7,7 @@
from ._C import DataType

NumberTypeType = Union[Type[bool], Type[int], Type[float], Type[complex]]

_torch_dtype_to_nvfuser_dtype_map = {
torch.cdouble: DataType.ComplexDouble,
torch.cfloat: DataType.ComplexFloat,
@@ -31,3 +31,5 @@ def torch_dtype_to_nvfuser_dtype(dtype: Union[torch.dtype, NumberTypeType]):
Translates from torch.dtype to nvFuser's DataType enum
"""
return _torch_dtype_to_nvfuser_dtype_map[dtype]


8 changes: 4 additions & 4 deletions python_tests/test_python_frontend.py
@@ -961,7 +961,7 @@ def fusion_func(fd: FusionDefinition) :
nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)

# Is there a better way to test distribution?!
-        self.assertTrue(nvf_out[0].mean().cpu().float().isclose(torch.tensor((hi - lo) / 2.0), rtol=1e-2, atol=1e-2).item())
+        self.assertTrue(nvf_out[0].mean().cpu().float().isclose(torch.tensor((hi-lo)/2.0), rtol=1e-2, atol=1e-2).item())
self.assertTrue(nvf_out[0].min().cpu().float().isclose(torch.tensor(lo), rtol=1e-2, atol=1e-2).item())
self.assertTrue(nvf_out[0].max().cpu().float().isclose(torch.tensor(hi), rtol=1e-2, atol=1e-2).item())

@@ -1141,11 +1141,10 @@ def fusion_func(fd: FusionDefinition):

self.assertEqual(at_rfloat, rfloat)
self.assertEqual(at_rdouble, rdouble)

def test_reduction_complex_number(self) :
def test_dtype(torch_dtype):
inputs = [torch.randn(2, 32, device='cuda', dtype=torch_dtype)]

def fusion_func(fd: FusionDefinition) :
t0 = fd.from_pytorch(inputs[0])
t1 = fd.ops.sum(t0, [-1], False, torch_dtype_to_nvfuser_dtype(torch_dtype))
@@ -1453,7 +1452,8 @@ def nvfuser_fusion_id(fd : FusionDefinition) -> None :
def test_real_imag(self):
for dtype in [
torch.complex128,
-            torch.complex64]:
+            torch.complex64,
+        ]:
inputs = [
torch.randn(5, dtype=dtype, device='cuda'),
]
8 changes: 4 additions & 4 deletions python_tests/test_torchscript.py
@@ -101,11 +101,11 @@ def __init__(self):
torch._C._debug_set_autodiff_subgraph_inlining(False)
self.old_value = torch._C._jit_set_autocast_mode(True)

-        if (RUN_CUDA):
+        if(RUN_CUDA):
self.old_nvfuser = torch._C._jit_set_nvfuser_enabled(True)

def restore(self):
-        if (RUN_CUDA):
+        if(RUN_CUDA):
torch._C._jit_set_nvfuser_enabled(self.old_nvfuser)
torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuse)
torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuse)
@@ -178,7 +178,7 @@ def setUp(self):
if TEST_BF16:
self.support_tensor_dtypes.append(torch.bfloat16)

-        if (RUN_NVFUSER):
+        if(RUN_NVFUSER):
self.cuda_fuser_options = CudaFuserTestOptions()

def tearDown(self):
@@ -188,7 +188,7 @@ def tearDown(self):
if not disabled_flag:
torch._C._jit_set_nvfuser_skip_node_kind(op, True)

-        if (RUN_NVFUSER):
+        if(RUN_NVFUSER):
self.cuda_fuser_options.restore()
super(TestCudaFuser, self).tearDown()

2 changes: 1 addition & 1 deletion tools/examples/repro.py
@@ -37,4 +37,4 @@
fd.add_output(T16)
fd.add_output(T20)
fd.add_output(T24)
-fd.add_output(T32)
+fd.add_output(T32)
