Skip to content

Commit

Permalink
Revert "small export from fbcode"
Browse files: browse the repository at this point in the history.
This reverts commit 5ef2741.

__original_commit__ = fairinternal/xformers@e9c0887
  • Loading branch information
bottler authored and xFormers Bot committed Jul 5, 2024
1 parent cdab041 commit cb1ccf0
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 32 deletions.
10 changes: 2 additions & 8 deletions xformers/benchmarks/benchmark_attn_decoding.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -393,18 +393,12 @@ def test_flash_attention_decoder(name, case):
torch.testing.assert_close(decoder_output, baseline_out, atol=1e-2, rtol=0)


def main() -> None:
"""
run performance benchmark
"""
# run benchmark performance
if __name__ == "__main__":
benchmark_main_helper2(
"attn_decoding",
fw=True,
cases=CASES,
functions=BENCHMARKS,
min_run_time=min_run_time,
)


if __name__ == "__main__":
main() # pragma: no cover
Original file line number | Diff line number | Diff line change
Expand Up @@ -51,7 +51,7 @@ __host__ __device__
const char* assertion,
const char* file,
unsigned int line,
const char* function) noexcept __attribute__((__noreturn__));
const char* function) throw() __attribute__((__noreturn__));
}
#endif // NDEBUG

Expand Down
2 changes: 1 addition & 1 deletion xformers/ops/fmha/flash.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -192,7 +192,7 @@ def _flash_fwd(
cu_seq_lens_q,
cu_seq_lens_k,
seqused_k,
block_tables,
block_tables, # block_table
None, # alibi_slopes
max_seq_len_q,
max_seq_len_k,
Expand Down
29 changes: 7 additions & 22 deletions xformers/triton/vararg_kernel.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -7,9 +7,7 @@
import copy
import functools
import linecache
import os
import sys
import tempfile
from typing import Any, Dict, List

import triton
Expand Down Expand Up @@ -118,11 +116,6 @@ def visit_For(self, node):
_getlines_orig = None
_FILENAME_TO_SRC: Dict[str, List[str]] = {}

# Materializing the codegen to disk can be useful for external tools, e.g. ncu
# Disabled by default because writing to disk at module import time is unexpected and error-prone.
_should_materialize_codegen = os.environ.get("XFORMERS_MATERIALIZE_CODEGEN") == "1"
_tmp_dir = None


def _monkey_patched_getlines(filename, module_globals=None):
if filename in _FILENAME_TO_SRC:
Expand All @@ -139,7 +132,7 @@ def unroll_varargs(kernel, N: int):
NOTE: Because it's quite costly to call `triton.jit`,
we cache the returned value with `lru_cache`
"""
global _FILENAME_TO_SRC, _getlines_orig, _tmp_dir
global _FILENAME_TO_SRC, _getlines_orig

k = triton.JITFunction(kernel.fn)
parsed = ast.parse(k.src)
Expand All @@ -155,20 +148,7 @@ def unroll_varargs(kernel, N: int):
# Now we want to `eval` the function, but we need all this
# boilerplate code to make sure triton can run `inspect.getsource`

fn_basename = f"unroll_varargs-{kernel.fn.__name__}-{N}"
if _should_materialize_codegen:
if not _tmp_dir:
_tmp_dir = tempfile.TemporaryDirectory()
fn_filename = os.path.join(_tmp_dir.name, f"{fn_basename}.py")
with open(fn_filename, "w") as f:
f.write(new_src)
else:
# Patch `getlines` only the first time
if not _FILENAME_TO_SRC:
_getlines_orig = linecache.getlines
linecache.getlines = _monkey_patched_getlines
fn_filename = f"<{fn_basename}>"
_FILENAME_TO_SRC[fn_filename] = new_src.splitlines(keepends=True)
fn_filename = f"<unroll_varargs-{kernel.fn.__name__}-{N}>"

# Create function given source
code = compile(new_src, fn_filename, "exec")
Expand All @@ -177,6 +157,11 @@ def unroll_varargs(kernel, N: int):
exec(code, kernel.fn.__globals__, _locals)
assert len(_locals) == 1, len(_locals)
fn = next(iter(_locals.values()))
# Patch `getlines` only the first time
if not _FILENAME_TO_SRC:
_getlines_orig = linecache.getlines
linecache.getlines = _monkey_patched_getlines
_FILENAME_TO_SRC[fn_filename] = [line + "\n" for line in new_src.splitlines()]

jitted_fn = triton.jit(fn)
jitted_fn.src = new_src
Expand Down

0 comments on commit cb1ccf0

Please sign in to comment.