diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 822342364ae6b3..c62fa09e6fff4d 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -1202,8 +1202,6 @@ class AOTInductorTestABICompatibleCpu(TestCase): "test_poi_multiple_dynamic": TestFailure(("abi_compatible_cpu",)), # There is a double-free issue which will be fixed in another PR "test_repeat_output": TestFailure(("abi_compatible_cpu",), is_skip=True), - "test_sdpa": TestFailure(("abi_compatible_cpu",)), - "test_sdpa_2": TestFailure(("abi_compatible_cpu",)), "test_simple_dynamic": TestFailure(("abi_compatible_cpu",)), }, ) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 10cf6f6a97812b..4ac5e28626ce81 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -1730,7 +1730,12 @@ def generate_c_shim_extern_kernel_alloc_call(self, extern_kernel, args): else: raise NotImplementedError("unsupported type of {output=}") args = args + output_args - self.generate_c_shim_extern_kernel_call(extern_kernel.kernel, args) + assert ( + extern_kernel.abi_compatible_kernel is not None + ), f"abi_compatible_kernel is None for {extern_kernel.kernel=}" + self.generate_c_shim_extern_kernel_call( + extern_kernel.abi_compatible_kernel, args + ) for raii_handle in output_raii_handles: self.writeline(raii_handle) @@ -2332,7 +2337,7 @@ def val_to_cpp_arg_str(self, type_, val, is_legacy_abi) -> str: and isinstance(type_, torch.OptionalType) ): if val is None: - return "nullptr" + return "0" # nullptr is not available in C if isinstance(val, (bool, int, str, float)): var_name = f"var_{next(self.arg_var_id)}" self.writeline(f"auto {var_name} = {self.val_to_arg_str(val)};") diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index b5ccf3fbb5033e..6311781361ca93 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -3621,8 +3621,6 @@ def is_legacy_abi_kernel(self): return False def codegen_kwargs(self): - if not self.kwargs: - return [] if V.graph.cpp_wrapper: # FIXME: we should unconditionally fill self.kwargs with missing default values # instead of carrying an extra self.ordered_kwargs_for_cpp_kernel @@ -3777,9 +3775,35 @@ def __init__(self, count: int, device: torch.device): class ExternKernelAlloc(ExternKernel): + # Generate abi-compatible kernel names for shim kernels. + # Each individual shim kernel may have its own versioning rule. + # However, we don't expect we would end up with too many of such rules. + def _get_abi_compatible_kernel(self): + if not V.graph.cpp_wrapper: + return self.kernel + + def sdpa_ver_fn(): + # For sdpa, we need the v2 version only if any optional + # kwarg is missing. + if any( + self.get_kwargs_value(arg_name) is None + for arg_name in self.ordered_kwargs_for_cpp_kernel + ): + return f"{self.kernel}_v2" + else: + return self.kernel + + kernel_to_ver = {"at::_scaled_dot_product_flash_attention": sdpa_ver_fn} + if (ver_fn := kernel_to_ver.get(self.kernel, None)) is not None: + return ver_fn() + return self.kernel + def codegen(self, wrapper): self.codegen_comment(wrapper) args = [*self.codegen_args(), *self.codegen_kwargs()] + # Now we setup abi_compatible_kernel after self.kernel + # and kwargs are adjusted appropriately. + self.abi_compatible_kernel = self._get_abi_compatible_kernel() V.graph.wrapper_code.generate_extern_kernel_alloc(self, args) if isinstance(self.layout, Layout): self.codegen_size_asserts(wrapper) @@ -3799,6 +3823,7 @@ def __init__( ) self.name = V.graph.register_buffer(self) self.kernel = cpp_kernel if V.graph.cpp_wrapper else kernel + self.abi_compatible_kernel = None self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel def should_allocate(self): diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index 42e8603999ef73..cebb3e1d1701f5 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -177,6 +177,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob( AtenTensorHandle* ret // returns new reference ); +// This version is deprecated. We will remove it later AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_dot_product_flash_attention( AtenTensorHandle query, AtenTensorHandle key, @@ -196,6 +197,26 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_dot_product_flash_attention( AtenTensorHandle* ret8 // returns new reference ); +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch__scaled_dot_product_flash_attention_v2( + AtenTensorHandle query, + AtenTensorHandle key, + AtenTensorHandle value, + double dropout_p, + int is_causal, + int return_debug_mask, + double* scale, + AtenTensorHandle* ret0, // returns new reference + AtenTensorHandle* ret1, // returns new reference + AtenTensorHandle* ret2, // returns new reference + AtenTensorHandle* ret3, // returns new reference + int64_t* ret4, + int64_t* ret5, + AtenTensorHandle* ret6, // returns new reference + AtenTensorHandle* ret7, // returns new reference + AtenTensorHandle* ret8 // returns new reference +); + AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_mm( AtenTensorHandle self, AtenTensorHandle mat2, diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index 0bfef59bfbace9..edad43c3033172 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -252,14 +252,14 @@ AOTITorchError aoti_torch_create_tensor_from_blob( }); } -AOTITorchError aoti_torch__scaled_dot_product_flash_attention( +AOTITorchError aoti_torch__scaled_dot_product_flash_attention_v2( AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, double dropout_p, - bool is_causal, - bool return_debug_mask, - double scale, + int is_causal, + int return_debug_mask, + double* scale, AtenTensorHandle* ret0, // returns new reference AtenTensorHandle* ret1, // returns new reference AtenTensorHandle* ret2, // returns new reference @@ -274,6 +274,7 @@ AOTITorchError aoti_torch__scaled_dot_product_flash_attention( at::Tensor* query_tensor = tensor_handle_to_tensor_pointer(query); at::Tensor* key_tensor = tensor_handle_to_tensor_pointer(key); at::Tensor* value_tensor = tensor_handle_to_tensor_pointer(value); + auto optional_scale = pointer_to_optional(scale); auto [r0, r1, r2, r3, r4, r5, r6, r7, r8] = at::_scaled_dot_product_flash_attention( *query_tensor, @@ -282,7 +283,7 @@ AOTITorchError aoti_torch__scaled_dot_product_flash_attention( dropout_p, is_causal, return_debug_mask, - scale); + optional_scale); at::Tensor* ret0_tensor = new at::Tensor(std::move(r0)); *ret0 = tensor_pointer_to_tensor_handle(ret0_tensor); @@ -308,6 +309,43 @@ AOTITorchError aoti_torch__scaled_dot_product_flash_attention( }); } +AOTITorchError aoti_torch__scaled_dot_product_flash_attention( + AtenTensorHandle query, + AtenTensorHandle key, + AtenTensorHandle value, + double dropout_p, + bool is_causal, + bool return_debug_mask, + double scale, + AtenTensorHandle* ret0, // returns new reference + AtenTensorHandle* ret1, // returns new reference + AtenTensorHandle* ret2, // returns new reference + AtenTensorHandle* ret3, // returns new reference + int64_t* ret4, + int64_t* ret5, + AtenTensorHandle* ret6, // returns new reference + AtenTensorHandle* ret7, // returns new reference + AtenTensorHandle* ret8 // returns new reference +) { + return aoti_torch__scaled_dot_product_flash_attention_v2( + query, + key, + value, + dropout_p, + is_causal, + return_debug_mask, + &scale, + ret0, + ret1, + ret2, + ret3, + ret4, + ret5, + ret6, + ret7, + ret8); +} + AOTITorchError aoti_torch_new_uninitialized_tensor(AtenTensorHandle* ret) { AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ at::Tensor* out_tensor = new at::Tensor();