From f2a015602f053efffbebbca287d66d4440986490 Mon Sep 17 00:00:00 2001 From: Dan Zimmerman Date: Mon, 30 Sep 2024 11:40:18 -0700 Subject: [PATCH] Try to use triton.language.extra.libdevice when possible (#3196) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3196 X-link: https://github.com/pytorch/pytorch/pull/136997 X-link: https://github.com/facebookresearch/generative-recommenders/pull/90 X-link: https://github.com/facebookresearch/FBGEMM/pull/294 In view of https://github.com/triton-lang/triton/pull/3825 we should try to use `triton.language.extra.libdevice` instead of `triton.language.extra.cuda.libdevice`. Reviewed By: bertmaher, karthik-man Differential Revision: D63583965 fbshipit-source-id: d32f35f7524d45c1e7c95c095144ad27f16eaa5a --- .../experimental/gen_ai/test/kv_cache/rope_padded.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/experimental/gen_ai/test/kv_cache/rope_padded.py b/fbgemm_gpu/experimental/gen_ai/test/kv_cache/rope_padded.py index ab728db6f..4a7211c68 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/kv_cache/rope_padded.py +++ b/fbgemm_gpu/experimental/gen_ai/test/kv_cache/rope_padded.py @@ -38,8 +38,12 @@ # pyre-fixme[21]: Could not find name `pow` in `triton.language.math`. from triton.language.math import pow except ImportError: - # @manual=//triton:triton - from triton.language.extra.cuda.libdevice import pow + try: + # @manual=//triton:triton + from triton.language.extra.libdevice import pow + except ImportError: + # @manual=//triton:triton + from triton.language.extra.cuda.libdevice import pow _INTERNAL_DTYPE_MAP: Dict[str, int] = {"": 0, "f32": 1, "f64": 2}