[aotinductor] add versions for the sdpa shim api (pytorch#113487)
In our first implementation of the sdpa shim API, we didn't consider
the case where the optional scale argument could be None. This went
unnoticed because the CUDA backend always supplied a default value
for scale; the issue was detected with the CPU backend.
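
The shim convention for passing an optional scalar across the C ABI is
a raw pointer, where a null pointer encodes None. As a minimal sketch
(in the spirit of the pointer_to_optional helper used in
shim_common.cpp below; the exact upstream definition may differ):

```cpp
#include <c10/util/Optional.h>

// Sketch: a null pointer means the optional argument was None;
// a non-null pointer carries the value.
template <class T>
c10::optional<T> pointer_to_optional(T* ptr) {
  return ptr ? c10::make_optional(*ptr) : c10::nullopt;
}
```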

This PR implements versioning for shim kernels. Currently, only the
sdpa API has more than one version. We expect to maintain only a very
small number of ABI-compatible shim APIs with multiple versions.
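
For illustration, here is how generated code might call the new _v2
entry point declared in shim.h below. The wrapper function and its
argument values are hypothetical; only the shim signature comes from
this PR:

```cpp
#include <torch/csrc/inductor/aoti_torch/c/shim.h>

// Hypothetical caller: maybe_scale == nullptr encodes scale=None,
// the case the original (v1) shim signature could not express.
AOTITorchError run_sdpa_v2(
    AtenTensorHandle q,
    AtenTensorHandle k,
    AtenTensorHandle v,
    double* maybe_scale) {
  AtenTensorHandle r0, r1, r2, r3, r6, r7, r8;
  int64_t r4, r5;
  return aoti_torch__scaled_dot_product_flash_attention_v2(
      q, k, v,
      /*dropout_p=*/0.0,
      /*is_causal=*/0,
      /*return_debug_mask=*/0,
      maybe_scale,
      &r0, &r1, &r2, &r3, &r4, &r5, &r6, &r7, &r8);
}
```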

Pull Request resolved: pytorch#113487
Approved by: https://github.com/int3, https://github.com/desertfire
chenyang78 authored and pytorchmergebot committed Nov 13, 2023
1 parent 6ea20f5 commit a144eb5
Showing 5 changed files with 98 additions and 11 deletions.
2 changes: 0 additions & 2 deletions test/inductor/test_aot_inductor.py
@@ -1202,8 +1202,6 @@ class AOTInductorTestABICompatibleCpu(TestCase):
"test_poi_multiple_dynamic": TestFailure(("abi_compatible_cpu",)),
# There is a double-free issue which will be fixed in another PR
"test_repeat_output": TestFailure(("abi_compatible_cpu",), is_skip=True),
"test_sdpa": TestFailure(("abi_compatible_cpu",)),
"test_sdpa_2": TestFailure(("abi_compatible_cpu",)),
"test_simple_dynamic": TestFailure(("abi_compatible_cpu",)),
},
)
9 changes: 7 additions & 2 deletions torch/_inductor/codegen/wrapper.py
@@ -1730,7 +1730,12 @@ def generate_c_shim_extern_kernel_alloc_call(self, extern_kernel, args):
else:
raise NotImplementedError(f"unsupported type of {output=}")
args = args + output_args
self.generate_c_shim_extern_kernel_call(extern_kernel.kernel, args)
assert (
extern_kernel.abi_compatible_kernel is not None
), f"abi_compatible_kernel is None for {extern_kernel.kernel=}"
self.generate_c_shim_extern_kernel_call(
extern_kernel.abi_compatible_kernel, args
)
for raii_handle in output_raii_handles:
self.writeline(raii_handle)

@@ -2332,7 +2337,7 @@ def val_to_cpp_arg_str(self, type_, val, is_legacy_abi) -> str:
and isinstance(type_, torch.OptionalType)
):
if val is None:
return "nullptr"
return "0" # nullptr is not available in C
if isinstance(val, (bool, int, str, float)):
var_name = f"var_{next(self.arg_var_id)}"
self.writeline(f"auto {var_name} = {self.val_to_arg_str(val)};")
29 changes: 27 additions & 2 deletions torch/_inductor/ir.py
@@ -3621,8 +3621,6 @@ def is_legacy_abi_kernel(self):
return False

def codegen_kwargs(self):
if not self.kwargs:
return []
if V.graph.cpp_wrapper:
# FIXME: we should unconditionally fill self.kwargs with missing default values
# instead of carrying an extra self.ordered_kwargs_for_cpp_kernel
@@ -3777,9 +3775,35 @@ def __init__(self, count: int, device: torch.device):


class ExternKernelAlloc(ExternKernel):
# Generate abi-compatible kernel names for shim kernels.
# Each individual shim kernel may have its own versioning rule.
# However, we don't expect we would end up with too many of such rules.
def _get_abi_compatible_kernel(self):
if not V.graph.cpp_wrapper:
return self.kernel

def sdpa_ver_fn():
# For sdpa, we need the v2 version only if any optional
# kwarg is missing.
if any(
self.get_kwargs_value(arg_name) is None
for arg_name in self.ordered_kwargs_for_cpp_kernel
):
return f"{self.kernel}_v2"
else:
return self.kernel

kernel_to_ver = {"at::_scaled_dot_product_flash_attention": sdpa_ver_fn}
if (ver_fn := kernel_to_ver.get(self.kernel, None)) is not None:
return ver_fn()
return self.kernel

def codegen(self, wrapper):
self.codegen_comment(wrapper)
args = [*self.codegen_args(), *self.codegen_kwargs()]
# Now we setup abi_compatible_kernel after self.kernel
# and kwargs are adjusted appropriately.
self.abi_compatible_kernel = self._get_abi_compatible_kernel()
V.graph.wrapper_code.generate_extern_kernel_alloc(self, args)
if isinstance(self.layout, Layout):
self.codegen_size_asserts(wrapper)
@@ -3799,6 +3823,7 @@ def __init__(
)
self.name = V.graph.register_buffer(self)
self.kernel = cpp_kernel if V.graph.cpp_wrapper else kernel
self.abi_compatible_kernel = None
self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel

def should_allocate(self):
21 changes: 21 additions & 0 deletions torch/csrc/inductor/aoti_torch/c/shim.h
@@ -177,6 +177,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob(
AtenTensorHandle* ret // returns new reference
);

// This version is deprecated. We will remove it later
AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
AtenTensorHandle query,
AtenTensorHandle key,
@@ -196,6 +197,26 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
AtenTensorHandle* ret8 // returns new reference
);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch__scaled_dot_product_flash_attention_v2(
AtenTensorHandle query,
AtenTensorHandle key,
AtenTensorHandle value,
double dropout_p,
int is_causal,
int return_debug_mask,
double* scale,
AtenTensorHandle* ret0, // returns new reference
AtenTensorHandle* ret1, // returns new reference
AtenTensorHandle* ret2, // returns new reference
AtenTensorHandle* ret3, // returns new reference
int64_t* ret4,
int64_t* ret5,
AtenTensorHandle* ret6, // returns new reference
AtenTensorHandle* ret7, // returns new reference
AtenTensorHandle* ret8 // returns new reference
);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_mm(
AtenTensorHandle self,
AtenTensorHandle mat2,
48 changes: 43 additions & 5 deletions torch/csrc/inductor/aoti_torch/shim_common.cpp
@@ -252,14 +252,14 @@ AOTITorchError aoti_torch_create_tensor_from_blob(
});
}

AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
AOTITorchError aoti_torch__scaled_dot_product_flash_attention_v2(
AtenTensorHandle query,
AtenTensorHandle key,
AtenTensorHandle value,
double dropout_p,
bool is_causal,
bool return_debug_mask,
double scale,
int is_causal,
int return_debug_mask,
double* scale,
AtenTensorHandle* ret0, // returns new reference
AtenTensorHandle* ret1, // returns new reference
AtenTensorHandle* ret2, // returns new reference
@@ -274,6 +274,7 @@ AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
at::Tensor* query_tensor = tensor_handle_to_tensor_pointer(query);
at::Tensor* key_tensor = tensor_handle_to_tensor_pointer(key);
at::Tensor* value_tensor = tensor_handle_to_tensor_pointer(value);
auto optional_scale = pointer_to_optional(scale);
auto [r0, r1, r2, r3, r4, r5, r6, r7, r8] =
at::_scaled_dot_product_flash_attention(
*query_tensor,
Expand All @@ -282,7 +283,7 @@ AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
dropout_p,
is_causal,
return_debug_mask,
scale);
optional_scale);

at::Tensor* ret0_tensor = new at::Tensor(std::move(r0));
*ret0 = tensor_pointer_to_tensor_handle(ret0_tensor);
@@ -308,6 +309,43 @@ AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
});
}

AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
AtenTensorHandle query,
AtenTensorHandle key,
AtenTensorHandle value,
double dropout_p,
bool is_causal,
bool return_debug_mask,
double scale,
AtenTensorHandle* ret0, // returns new reference
AtenTensorHandle* ret1, // returns new reference
AtenTensorHandle* ret2, // returns new reference
AtenTensorHandle* ret3, // returns new reference
int64_t* ret4,
int64_t* ret5,
AtenTensorHandle* ret6, // returns new reference
AtenTensorHandle* ret7, // returns new reference
AtenTensorHandle* ret8 // returns new reference
) {
return aoti_torch__scaled_dot_product_flash_attention_v2(
query,
key,
value,
dropout_p,
is_causal,
return_debug_mask,
&scale,
ret0,
ret1,
ret2,
ret3,
ret4,
ret5,
ret6,
ret7,
ret8);
}

AOTITorchError aoti_torch_new_uninitialized_tensor(AtenTensorHandle* ret) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
at::Tensor* out_tensor = new at::Tensor();
