From cb1ccf03e4da1c030e0275c11cf4135080976353 Mon Sep 17 00:00:00 2001 From: bottler Date: Fri, 5 Jul 2024 15:03:44 +0000 Subject: [PATCH] Revert "small export from fbcode" This reverts commit 5ef2741e85a176ce4903f10d1d3c650501897e3c. __original_commit__ = fairinternal/xformers@e9c088791f7fed014ea6a5e6a3650dc445859054 --- .../benchmarks/benchmark_attn_decoding.py | 10 ++----- .../synchronization_kernels.cu | 2 +- xformers/ops/fmha/flash.py | 2 +- xformers/triton/vararg_kernel.py | 29 +++++-------------- 4 files changed, 11 insertions(+), 32 deletions(-) diff --git a/xformers/benchmarks/benchmark_attn_decoding.py b/xformers/benchmarks/benchmark_attn_decoding.py index 24b902b5a4..f53f4aa0c5 100644 --- a/xformers/benchmarks/benchmark_attn_decoding.py +++ b/xformers/benchmarks/benchmark_attn_decoding.py @@ -393,10 +393,8 @@ def test_flash_attention_decoder(name, case): torch.testing.assert_close(decoder_output, baseline_out, atol=1e-2, rtol=0) -def main() -> None: - """ - run performance benchmark - """ +# run benchmark performance +if __name__ == "__main__": benchmark_main_helper2( "attn_decoding", fw=True, @@ -404,7 +402,3 @@ def main() -> None: functions=BENCHMARKS, min_run_time=min_run_time, ) - - -if __name__ == "__main__": - main() # pragma: no cover diff --git a/xformers/csrc/sequence_parallel_fused/synchronization_kernels.cu b/xformers/csrc/sequence_parallel_fused/synchronization_kernels.cu index df3775e1b0..afc1bb5bd1 100644 --- a/xformers/csrc/sequence_parallel_fused/synchronization_kernels.cu +++ b/xformers/csrc/sequence_parallel_fused/synchronization_kernels.cu @@ -51,7 +51,7 @@ __host__ __device__ const char* assertion, const char* file, unsigned int line, - const char* function) noexcept __attribute__((__noreturn__)); + const char* function) throw() __attribute__((__noreturn__)); } #endif // NDEBUG diff --git a/xformers/ops/fmha/flash.py b/xformers/ops/fmha/flash.py index 95e209d898..4b85ebecf0 100644 --- a/xformers/ops/fmha/flash.py +++ b/xformers/ops/fmha/flash.py @@ -192,7 +192,7 @@ def _flash_fwd( cu_seq_lens_q, cu_seq_lens_k, seqused_k, - block_tables, + block_tables, # block_table None, # alibi_slopes max_seq_len_q, max_seq_len_k, diff --git a/xformers/triton/vararg_kernel.py b/xformers/triton/vararg_kernel.py index c55beccb01..11d1831c6c 100644 --- a/xformers/triton/vararg_kernel.py +++ b/xformers/triton/vararg_kernel.py @@ -7,9 +7,7 @@ import copy import functools import linecache -import os import sys -import tempfile from typing import Any, Dict, List import triton @@ -118,11 +116,6 @@ def visit_For(self, node): _getlines_orig = None _FILENAME_TO_SRC: Dict[str, List[str]] = {} -# Materializing the codegen to disk can be useful for external tools, e.g. ncu -# Disabled by default because writing to disk at module import time is unexpected and error-prone. -_should_materialize_codegen = os.environ.get("XFORMERS_MATERIALIZE_CODEGEN") == "1" -_tmp_dir = None - def _monkey_patched_getlines(filename, module_globals=None): if filename in _FILENAME_TO_SRC: @@ -139,7 +132,7 @@ def unroll_varargs(kernel, N: int): NOTE: Because it's quite costly to call `triton.jit`, we cache the returned value with `lru_cache` """ - global _FILENAME_TO_SRC, _getlines_orig, _tmp_dir + global _FILENAME_TO_SRC, _getlines_orig k = triton.JITFunction(kernel.fn) parsed = ast.parse(k.src) @@ -155,20 +148,7 @@ def unroll_varargs(kernel, N: int): # Now we want to `eval` the function, but we need all this # boilerplate code to make sure triton can run `inspect.getsource` - fn_basename = f"unroll_varargs-{kernel.fn.__name__}-{N}" - if _should_materialize_codegen: - if not _tmp_dir: - _tmp_dir = tempfile.TemporaryDirectory() - fn_filename = os.path.join(_tmp_dir.name, f"{fn_basename}.py") - with open(fn_filename, "w") as f: - f.write(new_src) - else: - # Patch `getlines` only the first time - if not _FILENAME_TO_SRC: - _getlines_orig = linecache.getlines - linecache.getlines = _monkey_patched_getlines - fn_filename = f"<{fn_basename}>" - _FILENAME_TO_SRC[fn_filename] = new_src.splitlines(keepends=True) + fn_filename = f"" # Create function given source code = compile(new_src, fn_filename, "exec") @@ -177,6 +157,11 @@ def unroll_varargs(kernel, N: int): exec(code, kernel.fn.__globals__, _locals) assert len(_locals) == 1, len(_locals) fn = next(iter(_locals.values())) + # Patch `getlines` only the first time + if not _FILENAME_TO_SRC: + _getlines_orig = linecache.getlines + linecache.getlines = _monkey_patched_getlines + _FILENAME_TO_SRC[fn_filename] = [line + "\n" for line in new_src.splitlines()] jitted_fn = triton.jit(fn) jitted_fn.src = new_src