From 09293c7a91de1775884c9f81b4a5198633407570 Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Mon, 16 Dec 2024 16:42:22 -0800 Subject: [PATCH 01/16] Disable c10::optional macros in deeplearning Reviewed By: jwfromm Differential Revision: D67293350 fbshipit-source-id: 76ee573031729fd918cbdc0133c14f3fbbe3decf --- .../gen_ai/src/quantize/quantize.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp index cb17b4e23..5cc5e851d 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp @@ -302,7 +302,7 @@ at::Tensor f8f8bf16_rowwise_meta( at::Tensor WQ, // FP8 at::Tensor /* x_scale */, at::Tensor /* w_scale */, - std::optional /* bias = c10::nullopt */, + std::optional /* bias = std::nullopt */, bool /* use_fast_accum = true */) { const at::SymInt M = XQ.sym_size(0); const at::SymInt N = WQ.sym_size(0); @@ -316,7 +316,7 @@ void f8f8bf16_rowwise_out_meta( at::Tensor /* x_scale */, at::Tensor /* w_scale */, at::Tensor /* output */, - std::optional /* bias = c10::nullopt */, + std::optional /* bias = std::nullopt */, bool /* use_fast_accum = true */) { return; } @@ -326,9 +326,9 @@ at::Tensor f8f8bf16_rowwise_batched_meta( at::Tensor WQ, // FP8 at::Tensor /* x_scale */, at::Tensor /* w_scale */, - std::optional /* bias = c10::nullopt */, + std::optional /* bias = std::nullopt */, bool /* use_fast_accum = true */, - std::optional /* output = c10::nullopt */) { + std::optional /* output = std::nullopt */) { int B = XQ.size(0); int M = XQ.size(1); int N = WQ.size(1); @@ -363,10 +363,10 @@ std::vector quantize_fp8_per_tensor_meta( at::Tensor f8f8bf16_cublas_meta( at::Tensor X, at::Tensor W, - std::optional /* x_scale = c10::nullopt */, - std::optional /* w_scale = c10::nullopt */, + std::optional /* x_scale = std::nullopt */, + std::optional /* w_scale = std::nullopt */, bool /* use_fast_accum = true */, - std::optional /* output = c10::nullopt */) { + std::optional /* output = std::nullopt */) { const at::SymInt M = X.sym_size(0); const at::SymInt N = W.sym_size(0); auto Y = at::empty_symint({M, N}, X.options().dtype(at::kBFloat16)); @@ -458,7 +458,7 @@ std::vector f8f8bf16_grouped_meta( const std::vector& XQ, const std::vector& WQ, const std::vector& /* scale */, - std::optional /* zero_start_index_M = c10::nullopt */, + std::optional /* zero_start_index_M = std::nullopt */, bool /* use_fast_accum = true */) { std::vector Y; for (int i = 0; i < XQ.size(); i++) { @@ -472,7 +472,7 @@ std::vector f8f8bf16_grouped_meta( at::Tensor bf16bf16bf16_grouped_meta( const std::vector& X, const std::vector& W, - std::optional /* zero_start_index_M = c10::nullopt */ + std::optional /* zero_start_index_M = std::nullopt */ ) { int problem_count = X.size(); int total_output_size = 0; From b5140f2605e0cbcafa852890b980e03f94493255 Mon Sep 17 00:00:00 2001 From: Andrew Gallagher Date: Tue, 17 Dec 2024 14:38:44 -0800 Subject: [PATCH 02/16] Pull in PR for Kleidi-based FP16 kernel (#3507) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3507 X-link: https://github.com/facebookresearch/FBGEMM/pull/588 https://fb.workplace.com/groups/1943855219377055/permalink/1982050665557510/ https://gitlab.arm.com/kleidi/kleidiai/-/commit/8e6db85154a9dd100d5553a20c0df6ee437eb745 https://github.com/pytorch/FBGEMM/pull/3440/commits/667ce9b33c6a9dbc42c462ad08f4cafca2ac80c5 Reviewed By: meyering Differential Revision: D66766306 fbshipit-source-id: f70ef581e214edad15aa0ca093d753adbacb163d --- Makefile.FP16Benchmark.aarch64 | 46 + bench/BenchUtils.h | 2 + defs.bzl | 12 +- include/fbgemm/FbgemmFPCommon.h | 15 + include/fbgemm/FbgemmPackMatrixB.h | 10 +- src/FbgemmFP16.cc | 40 +- src/FbgemmFPCommon.cc | 129 ++ src/KleidiAIFP16UKernelsNeon.cc | 2474 ++++++++++++++++++++++++++++ src/KleidiAIFP16UKernelsNeon.h | 29 + 9 files changed, 2752 insertions(+), 5 deletions(-) create mode 100644 Makefile.FP16Benchmark.aarch64 create mode 100644 src/KleidiAIFP16UKernelsNeon.cc create mode 100644 src/KleidiAIFP16UKernelsNeon.h diff --git a/Makefile.FP16Benchmark.aarch64 b/Makefile.FP16Benchmark.aarch64 new file mode 100644 index 000000000..6c1365f4b --- /dev/null +++ b/Makefile.FP16Benchmark.aarch64 @@ -0,0 +1,46 @@ +# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliate +# SPDX-License-Identifier: BSD-3-Clause + +CC := g++ +SRCEXT := cc +OBJEXT := o +BUILD_DIR := obj +KLEIDIAI_DIR := external/kleidiai/kai/ukernels/matmul/matmul_f32_f32_f16p + +CFLAGS := -O3 -mcpu=native -std=c++20 -fvisibility=hidden -fopenmp +CDEFINES := -DFBGEMM_ENABLE_KLEIDIAI=1 -DCPUINFO_SUPPORTED_PLATFORM=1 -DFBGEMM_FP16_FALLBACK_TO_REF_KERNEL=1 +INCLUDES := -I./include -I./ -I./external/cpuinfo/include -I./external/asmjit/src -I./external/googletest/googletest/include -I./src + +SRC_DIR := src +BENCH_DIR := bench +TEST_DIR := test +SOURCES := src/FbgemmFP16.cc src/FbgemmFPCommon.cc src/Utils.cc src/RefImplementations.cc src/TransposeUtils.cc src/FbgemmFP16UKernelsSve128.cc +SRC_OBJECTS := $(patsubst $(SRC_DIR)/%,$(BUILD_DIR)/%,$(SOURCES:.$(SRCEXT)=.$(OBJEXT))) +KLEIDIAI_OBJECTS := $(BUILD_DIR)/KleidiAIFP16UKernelsNeon.$(OBJEXT) + +LIB := -lcpuinfo -fopenmp +LIBDIR := -L/usr/lib/aarch64-linux-gnu + +BENCH_TARGET := FP16Benchmark +TEST_TARGET := FP16Test + +all: $(BENCH_TARGET) $(TEST_TARGET) + +$(TEST_TARGET): $(SRC_OBJECTS) $(KLEIDIAI_OBJECTS) $(BUILD_DIR)/BenchUtils.$(OBJEXT) $(BUILD_DIR)/FP16Test.$(OBJEXT) + $(CC) -o $(TEST_TARGET) $^ $(LIBDIR) $(LIB) -lgtest + +$(BENCH_TARGET): $(SRC_OBJECTS) $(KLEIDIAI_OBJECTS) $(BUILD_DIR)/BenchUtils.$(OBJEXT) $(BUILD_DIR)/FP16Benchmark.$(OBJEXT) + $(CC) -o $(BENCH_TARGET) $^ $(LIBDIR) $(LIB) + +$(BUILD_DIR)/%.$(OBJEXT): $(SRC_DIR)/%.$(SRCEXT) + @mkdir -p $(dir $@) + $(CC) $(CDEFINES) $(CFLAGS) $(INCLUDES) -c -o $@ $< + +$(BUILD_DIR)/%.$(OBJEXT): $(BENCH_DIR)/%.$(SRCEXT) + $(CC) $(CDEFINES) $(CFLAGS) $(INCLUDES) -c -o $@ $< + +$(BUILD_DIR)/%.$(OBJEXT): $(KLEIDIAI_DIR)/%.$(SRCEXT) + $(CC) $(CDEFINES) $(CFLAGS) $(INCLUDES) -c -o $@ $< + +$(BUILD_DIR)/%.$(OBJEXT): $(TEST_DIR)/%.$(SRCEXT) + $(CC) $(CDEFINES) $(CFLAGS) $(INCLUDES) -c -o $@ $< \ No newline at end of file diff --git a/bench/BenchUtils.h b/bench/BenchUtils.h index e5e7dc6c1..334d4e53a 100644 --- a/bench/BenchUtils.h +++ b/bench/BenchUtils.h @@ -70,7 +70,9 @@ NOINLINE float cache_evict(const T& vec) { float dummy = 0.0f; for (std::size_t i = 0; i < dataSize; i += CACHE_LINE_SIZE) { dummy += data[i] * 1.0f; +#ifndef __aarch64__ _mm_mfence(); +#endif #ifndef _MSC_VER asm volatile("" ::: "memory"); #endif diff --git a/defs.bzl b/defs.bzl index d2366fb90..4f9c56b06 100644 --- a/defs.bzl +++ b/defs.bzl @@ -143,10 +143,18 @@ def get_fbgemm_inline_avx512_srcs(msvc = False, buck = False): return asm_srcs if not msvc else intrinsics_srcs def get_fbgemm_inline_sve_srcs(msvc = False, buck = False): - intrinsics_srcs = ["src/FbgemmFP16UKernelsSve128.cc", "src/UtilsSve.cc"] + intrinsics_srcs = [ + "src/FbgemmFP16UKernelsSve128.cc", + "src/KleidiAIFP16UKernelsNeon.cc", + "src/UtilsSve.cc", + ] #FP16 kernels contain inline assembly and inline assembly syntax for MSVC is different. - asm_srcs = ["src/FbgemmFP16UKernelsSve128.cc", "src/UtilsSve.cc"] + asm_srcs = [ + "src/FbgemmFP16UKernelsSve128.cc", + "src/KleidiAIFP16UKernelsNeon.cc", + "src/UtilsSve.cc", + ] if buck: return select({ "DEFAULT": asm_srcs, diff --git a/include/fbgemm/FbgemmFPCommon.h b/include/fbgemm/FbgemmFPCommon.h index 9f2f889d7..f3556123d 100644 --- a/include/fbgemm/FbgemmFPCommon.h +++ b/include/fbgemm/FbgemmFPCommon.h @@ -25,6 +25,9 @@ using partition_array_t = std::array, 2>, 121>; extern partition_array_t partition_avx2; extern partition_array_t partition_avx512; extern partition_array_t partition_sve128; +#ifdef FBGEMM_ENABLE_KLEIDIAI +extern partition_array_t partition_neon; +#endif template struct GemmParams { @@ -35,7 +38,11 @@ struct GemmParams { float* C; uint64_t ldc; uint64_t b_block_cols; +#ifdef FBGEMM_ENABLE_KLEIDIAI + uint64_t lda; +#else uint64_t b_block_size; +#endif }; template @@ -155,8 +162,12 @@ void cblas_gemm_compute( for (auto m2 = m_start; m2 < m_end; m2 += kernel_nrows) { assert(kernel_nrows * kb < static_cast(scratchpad->size())); if (m != 1) { +#ifdef FBGEMM_ENABLE_KLEIDIAI + gp.A = const_cast(&A[m2 * k + k_ind]); +#else PackA(kernel_nrows, kb, &A[m2 * k + k_ind], k, scratchpad->data()); gp.A = scratchpad->data(); +#endif } else { // When m == 1, it is actually vector matrix multiplication. We // don't need to do the transposition for packA here. Instead, we @@ -172,7 +183,11 @@ void cblas_gemm_compute( gp.C = &C[m2 * ldc]; gp.ldc = ldc * sizeof(C[0]); gp.b_block_cols = nbcol; +#ifdef FBGEMM_ENABLE_KLEIDIAI + gp.lda = k * sizeof(A[0]); +#else gp.b_block_size = gp.k * Bp.blockColSize() * sizeof(gp.B[0]); +#endif if ((n % Bp.blockColSize()) == 0) { int64_t jb_begin, jb_end; diff --git a/include/fbgemm/FbgemmPackMatrixB.h b/include/fbgemm/FbgemmPackMatrixB.h index 9a2de6d95..e43bfccbc 100644 --- a/include/fbgemm/FbgemmPackMatrixB.h +++ b/include/fbgemm/FbgemmPackMatrixB.h @@ -60,7 +60,15 @@ class PackedGemmMatrixB { const float alpha, const float* smat, const int brow = 512) - : nrow_(nrow), ncol_(ncol), brow_(brow), kernel_ncol_blocks_(2) { + : nrow_(nrow), + ncol_(ncol), + brow_(brow), +#ifdef FBGEMM_ENABLE_KLEIDIAI + kernel_ncol_blocks_(1) +#else + kernel_ncol_blocks_(2) +#endif + { initializeParam(); initializeMemory(); // copy source matrix into packed matrix diff --git a/src/FbgemmFP16.cc b/src/FbgemmFP16.cc index 4e07b12f6..eb538387c 100644 --- a/src/FbgemmFP16.cc +++ b/src/FbgemmFP16.cc @@ -14,7 +14,12 @@ #include "./FbgemmFP16UKernelsAvx2.h" #include "./FbgemmFP16UKernelsAvx512.h" #include "./FbgemmFP16UKernelsAvx512_256.h" +#ifdef __aarch64__ #include "./FbgemmFP16UKernelsSve128.h" +#endif +#ifdef FBGEMM_ENABLE_KLEIDIAI +#include "./KleidiAIFP16UKernelsNeon.h" +#endif #include "fbgemm/Fbgemm.h" #include "fbgemm/FbgemmFPCommon.h" @@ -27,13 +32,17 @@ namespace { // the restrictions of ymm register numbers (16). constexpr kernel_array_t kernel_fp16_avx2 = { nullptr, +#ifndef __aarch64__ gemmkernel_1x2_Avx2_fp16_fA0fB0fC0, gemmkernel_2x2_Avx2_fp16_fA0fB0fC0, gemmkernel_3x2_Avx2_fp16_fA0fB0fC0, gemmkernel_4x2_Avx2_fp16_fA0fB0fC0, gemmkernel_5x2_Avx2_fp16_fA0fB0fC0, - gemmkernel_6x2_Avx2_fp16_fA0fB0fC0}; + gemmkernel_6x2_Avx2_fp16_fA0fB0fC0 +#endif +}; +#ifndef FBGEMM_ENABLE_KLEIDIAI constexpr kernel_array_t kernel_fp16_sve128 = { nullptr, #ifdef __aarch64__ @@ -52,9 +61,25 @@ constexpr kernel_array_t kernel_fp16_sve128 = { nullptr, #endif }; +#endif + +#ifdef FBGEMM_ENABLE_KLEIDIAI +constexpr kernel_array_t kernel_fp16_neon = { + nullptr, + kleidiai::gemmkernel_1x1_Neon_fp16_fA0fB0fC0, + kleidiai::gemmkernel_2x1_Neon_fp16_fA0fB0fC0, + kleidiai::gemmkernel_3x1_Neon_fp16_fA0fB0fC0, + kleidiai::gemmkernel_4x1_Neon_fp16_fA0fB0fC0, + kleidiai::gemmkernel_5x1_Neon_fp16_fA0fB0fC0, + kleidiai::gemmkernel_6x1_Neon_fp16_fA0fB0fC0, + kleidiai::gemmkernel_7x1_Neon_fp16_fA0fB0fC0, + kleidiai::gemmkernel_8x1_Neon_fp16_fA0fB0fC0, +}; +#endif constexpr kernel_array_t kernel_fp16_avx512_256 = { nullptr, +#ifndef __aarch64__ gemmkernel_1x2_Avx2_fp16_fA0fB0fC0, gemmkernel_2x2_Avx2_fp16_fA0fB0fC0, gemmkernel_3x2_Avx2_fp16_fA0fB0fC0, @@ -68,7 +93,9 @@ constexpr kernel_array_t kernel_fp16_avx512_256 = { gemmkernel_11x2_Avx512_256_fp16_fA0fB0fC0, gemmkernel_12x2_Avx512_256_fp16_fA0fB0fC0, gemmkernel_13x2_Avx512_256_fp16_fA0fB0fC0, - gemmkernel_14x2_Avx512_256_fp16_fA0fB0fC0}; + gemmkernel_14x2_Avx512_256_fp16_fA0fB0fC0 +#endif +}; constexpr kernel_array_t kernel_fp16_avx512 = { #ifndef __aarch64__ @@ -102,12 +129,21 @@ const isa_descriptor& getIsaHandlers(inst_set_t isa, float16) { std::make_tuple(kernel_fp16_avx512, partition_avx512); static isa_descriptor avx512_256_descriptor = std::make_tuple(kernel_fp16_avx512_256, partition_avx512); +#ifdef FBGEMM_ENABLE_KLEIDIAI + static isa_descriptor neon_descriptor = + std::make_tuple(kernel_fp16_neon, partition_neon); +#else static isa_descriptor sve128_descriptor = std::make_tuple(kernel_fp16_sve128, partition_sve128); +#endif switch (isa) { case inst_set_t::sve: +#ifdef FBGEMM_ENABLE_KLEIDIAI + return neon_descriptor; +#else return sve128_descriptor; +#endif case inst_set_t::anyarch: case inst_set_t::avx2: return avx2_descriptor; diff --git a/src/FbgemmFPCommon.cc b/src/FbgemmFPCommon.cc index 9d648d836..84969e2c1 100644 --- a/src/FbgemmFPCommon.cc +++ b/src/FbgemmFPCommon.cc @@ -301,6 +301,135 @@ partition_array_t partition_sve128 = { } }; +partition_array_t partition_neon = { + // NOTE: clang-format wants to use a different formatting but the current + // formatting should be easier to read. + { + {{ { 0, 0 }, { 0, 0 } } }, // 0 + {{ { 1, 1 }, { 0, 0 } } }, // 1 + {{ { 2, 1 }, { 0, 0 } } }, // 2 + {{ { 3, 1 }, { 0, 0 } } }, // 3 + {{ { 4, 1 }, { 0, 0 } } }, // 4 + {{ { 5, 1 }, { 0, 0 } } }, // 5 + {{ { 6, 1 }, { 0, 0 } } }, // 6 + {{ { 7, 1 }, { 0, 0 } } }, // 7 + {{ { 8, 1 }, { 0, 0 } } }, // 8 + {{ { 5, 1 }, { 4, 1 } } }, // 9 + {{ { 5, 2 }, { 0, 0 } } }, // 10 + {{ { 6, 1 }, { 5, 1 } } }, // 11 + {{ { 6, 2 }, { 0, 0 } } }, // 12 + {{ { 7, 1 }, { 6, 1 } } }, // 13 + {{ { 8, 1 }, { 6, 1 } } }, // 14 + {{ { 8, 1 }, { 7, 1 } } }, // 15 + {{ { 8, 2 }, { 0, 0 } } }, // 16 + {{ { 8, 2 }, { 1, 1 } } }, // 17 + {{ { 6, 3 }, { 0, 0 } } }, // 18 + {{ { 8, 2 }, { 3, 1 } } }, // 19 + {{ { 5, 4 }, { 0, 0 } } }, // 20 + {{ { 5, 3 }, { 6, 1 } } }, // 21 + {{ { 8, 2 }, { 6, 1 } } }, // 22 + {{ { 8, 2 }, { 7, 1 } } }, // 23 + {{ { 8, 3 }, { 0, 0 } } }, // 24 + {{ { 8, 3 }, { 1, 1 } } }, // 25 + {{ { 8, 3 }, { 2, 1 } } }, // 26 + {{ { 8, 3 }, { 3, 1 } } }, // 27 + {{ { 8, 3 }, { 4, 1 } } }, // 28 + {{ { 8, 3 }, { 5, 1 } } }, // 29 + {{ { 8, 3 }, { 6, 1 } } }, // 30 + {{ { 8, 3 }, { 7, 1 } } }, // 31 + {{ { 8, 4 }, { 0, 0 } } }, // 32 + {{ { 8, 4 }, { 1, 1 } } }, // 33 + {{ { 8, 4 }, { 2, 1 } } }, // 34 + {{ { 8, 4 }, { 3, 1 } } }, // 35 + {{ { 8, 4 }, { 4, 1 } } }, // 36 + {{ { 8, 4 }, { 5, 1 } } }, // 37 + {{ { 8, 4 }, { 6, 1 } } }, // 38 + {{ { 8, 4 }, { 7, 1 } } }, // 39 + {{ { 8, 5 }, { 0, 0 } } }, // 40 + {{ { 8, 5 }, { 1, 1 } } }, // 41 + {{ { 8, 5 }, { 2, 1 } } }, // 42 + {{ { 8, 5 }, { 3, 1 } } }, // 43 + {{ { 8, 5 }, { 4, 1 } } }, // 44 + {{ { 8, 5 }, { 5, 1 } } }, // 45 + {{ { 8, 5 }, { 6, 1 } } }, // 46 + {{ { 8, 5 }, { 7, 1 } } }, // 47 + {{ { 8, 6 }, { 0, 0 } } }, // 48 + {{ { 8, 6 }, { 1, 1 } } }, // 49 + {{ { 8, 6 }, { 2, 1 } } }, // 50 + {{ { 8, 6 }, { 3, 1 } } }, // 51 + {{ { 8, 6 }, { 4, 1 } } }, // 52 + {{ { 8, 6 }, { 5, 1 } } }, // 53 + {{ { 8, 6 }, { 6, 1 } } }, // 54 + {{ { 8, 6 }, { 7, 1 } } }, // 55 + {{ { 8, 7 }, { 0, 0 } } }, // 56 + {{ { 8, 7 }, { 1, 1 } } }, // 57 + {{ { 8, 7 }, { 2, 1 } } }, // 58 + {{ { 8, 7 }, { 3, 1 } } }, // 59 + {{ { 8, 7 }, { 4, 1 } } }, // 60 + {{ { 8, 7 }, { 5, 1 } } }, // 61 + {{ { 8, 7 }, { 6, 1 } } }, // 62 + {{ { 8, 7 }, { 7, 1 } } }, // 63 + {{ { 8, 8 }, { 0, 0 } } }, // 64 + {{ { 8, 8 }, { 1, 1 } } }, // 65 + {{ { 8, 8 }, { 2, 1 } } }, // 66 + {{ { 8, 8 }, { 3, 1 } } }, // 67 + {{ { 8, 8 }, { 4, 1 } } }, // 68 + {{ { 8, 8 }, { 5, 1 } } }, // 69 + {{ { 8, 8 }, { 6, 1 } } }, // 70 + {{ { 8, 8 }, { 7, 1 } } }, // 71 + {{ { 8, 9 }, { 0, 0 } } }, // 72 + {{ { 8, 9 }, { 1, 1 } } }, // 73 + {{ { 8, 9 }, { 2, 1 } } }, // 74 + {{ { 8, 9 }, { 3, 1 } } }, // 75 + {{ { 8, 9 }, { 4, 1 } } }, // 76 + {{ { 8, 9 }, { 5, 1 } } }, // 77 + {{ { 8, 9 }, { 6, 1 } } }, // 78 + {{ { 8, 9 }, { 7, 1 } } }, // 79 + {{ { 8, 10 }, { 0, 0 } } }, // 80 + {{ { 8, 10 }, { 1, 1 } } }, // 81 + {{ { 8, 10 }, { 2, 1 } } }, // 82 + {{ { 8, 10 }, { 3, 1 } } }, // 83 + {{ { 8, 10 }, { 4, 1 } } }, // 84 + {{ { 8, 10 }, { 5, 1 } } }, // 85 + {{ { 8, 10 }, { 6, 1 } } }, // 86 + {{ { 8, 10 }, { 7, 1 } } }, // 87 + {{ { 8, 11 }, { 0, 0 } } }, // 88 + {{ { 8, 11 }, { 1, 1 } } }, // 89 + {{ { 8, 11 }, { 2, 1 } } }, // 90 + {{ { 8, 11 }, { 3, 1 } } }, // 91 + {{ { 8, 11 }, { 4, 1 } } }, // 92 + {{ { 8, 11 }, { 5, 1 } } }, // 93 + {{ { 8, 11 }, { 6, 1 } } }, // 94 + {{ { 8, 11 }, { 7, 1 } } }, // 95 + {{ { 8, 12 }, { 0, 0 } } }, // 96 + {{ { 8, 12 }, { 1, 1 } } }, // 97 + {{ { 8, 12 }, { 2, 1 } } }, // 98 + {{ { 8, 12 }, { 3, 1 } } }, // 99 + {{ { 8, 12 }, { 4, 1 } } }, // 100 + {{ { 8, 12 }, { 5, 1 } } }, // 101 + {{ { 8, 12 }, { 6, 1 } } }, // 102 + {{ { 8, 12 }, { 7, 1 } } }, // 103 + {{ { 8, 13 }, { 0, 0 } } }, // 104 + {{ { 8, 13 }, { 1, 1 } } }, // 105 + {{ { 8, 13 }, { 2, 1 } } }, // 106 + {{ { 8, 13 }, { 3, 1 } } }, // 107 + {{ { 8, 13 }, { 4, 1 } } }, // 108 + {{ { 8, 13 }, { 5, 1 } } }, // 109 + {{ { 8, 13 }, { 6, 1 } } }, // 110 + {{ { 8, 13 }, { 7, 1 } } }, // 111 + {{ { 8, 14 }, { 0, 0 } } }, // 112 + {{ { 8, 14 }, { 1, 1 } } }, // 113 + {{ { 8, 14 }, { 2, 1 } } }, // 114 + {{ { 8, 14 }, { 3, 1 } } }, // 115 + {{ { 8, 14 }, { 4, 1 } } }, // 116 + {{ { 8, 14 }, { 5, 1 } } }, // 117 + {{ { 8, 14 }, { 6, 1 } } }, // 118 + {{ { 8, 14 }, { 7, 1 } } }, // 119 + {{ { 8, 15 }, { 0, 0 } } }, // 120 + } +}; + + partition_array_t partition_avx512 = { // NOTE: clang-format wants to use a different formatting but the current // formatting should be easier to read. diff --git a/src/KleidiAIFP16UKernelsNeon.cc b/src/KleidiAIFP16UKernelsNeon.cc new file mode 100644 index 000000000..7e6eacae8 --- /dev/null +++ b/src/KleidiAIFP16UKernelsNeon.cc @@ -0,0 +1,2474 @@ +// @lint-ignore-every LICENSELINT +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// +// SPDX-License-Identifier: Apache-2.0 +// + +#ifdef FBGEMM_ENABLE_KLEIDIAI + +#include "KleidiAIFP16UKernelsNeon.h" + +namespace kleidiai { + +void NOINLINE gemmkernel_1x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp) { +#ifdef __aarch64__ + __asm__ __volatile__( + "ldr w20, [%x[gp], %[offsetof_beta]]\n" + "mov x25, #0x1\n" + "fmov v29.8h, #1.0\n" + "ldr x24, [%x[gp], %[offsetof_b_block_cols]]\n" + "ldr x23, [%x[gp], %[offsetof_B]]\n" + "ldr x22, [%x[gp], %[offsetof_C]]\n" + "bic x20, x20, #0x80000000\n" + "cmp x20, #0x0\n" + "csel x25, XZR, x25, EQ\n" + "1:" // Height 1: Column loop + "tbz x25, #0, 2f\n" + "ldr q30, [x22, #0x0]\n" + "ldr q31, [x22, #0x10]\n" + "add x20, %x[gp], %[offsetof_beta]\n" + "ld1r { v16.4s }, [x20]\n" + "fmul v30.4s, v30.4s, v16.4s\n" + "fmul v31.4s, v31.4s, v16.4s\n" + "b 3f\n" + "2:" // Height 1: no accumulate + "movi v30.16b, #0x0\n" + "movi v31.16b, #0x0\n" + "3:" // Height 1: setup done + "ldr x20, [%x[gp], %[offsetof_A]]\n" + "ldr x21, [%x[gp], %[offsetof_k]]\n" + "mov x20, x20\n" + "cmp x21, #0x4\n" + "blt 7f\n" + "ldr q0, [x20, #0x0]\n" + "ldr q1, [x23, #0x0]\n" + "cmp x21, #0x8\n" + "ldr q4, [x23, #0x10]\n" + "ldr q7, [x23, #0x20]\n" + "ldr q10, [x23, #0x30]\n" + "blt 6f\n" + "5:" // Height 1: Multiply loop: Main loop head + "movi v2.16b, #0x0\n" + "movi v3.16b, #0x0\n" + "sub x21, x21, #0x4\n" + "add x20, x20, #0x10\n" + "movi v5.16b, #0x0\n" + "movi v6.16b, #0x0\n" + "cmp x21, #0x8\n" + "add x23, x23, #0x40\n" + "fmlal v2.4s, v1.4h, v29.4h\n" + "fmlal2 v3.4s, v1.4h, v29.4h\n" + "ldr q1, [x23, #0x0]\n" + "movi v8.16b, #0x0\n" + "fmlal v5.4s, v4.4h, v29.4h\n" + "fmlal2 v6.4s, v4.4h, v29.4h\n" + "ldr q4, [x23, #0x10]\n" + "movi v9.16b, #0x0\n" + "fmlal v8.4s, v7.4h, v29.4h\n" + "movi v11.16b, #0x0\n" + "prfm pldl1keep, [x20, #0x80]\n" + "fmlal2 v9.4s, v7.4h, v29.4h\n" + "ldr q7, [x23, #0x20]\n" + "movi v12.16b, #0x0\n" + "fmla v30.4s, v2.4s, v0.s[0]\n" + "fmla v31.4s, v3.4s, v0.s[0]\n" + "fmlal v11.4s, v10.4h, v29.4h\n" + "fmlal2 v12.4s, v10.4h, v29.4h\n" + "ldr q10, [x23, #0x30]\n" + "fmla v30.4s, v5.4s, v0.s[1]\n" + "fmla v31.4s, v6.4s, v0.s[1]\n" + "fmla v30.4s, v8.4s, v0.s[2]\n" + "fmla v31.4s, v9.4s, v0.s[2]\n" + "fmla v30.4s, v11.4s, v0.s[3]\n" + "fmla v31.4s, v12.4s, v0.s[3]\n" + "ldr q0, [x20, #0x0]\n" + "bge 5b\n" + "6:" // Height 1: Multiply loop: Single iteration only + "movi v2.16b, #0x0\n" + "movi v3.16b, #0x0\n" + "add x20, x20, #0x10\n" + "sub x21, x21, #0x4\n" + "movi v5.16b, #0x0\n" + "movi v6.16b, #0x0\n" + "prfm pldl1keep, [x20, #0x80]\n" + "add x23, x23, #0x40\n" + "fmlal v2.4s, v1.4h, v29.4h\n" + "fmlal2 v3.4s, v1.4h, v29.4h\n" + "movi v8.16b, #0x0\n" + "fmlal v5.4s, v4.4h, v29.4h\n" + "fmlal2 v6.4s, v4.4h, v29.4h\n" + "movi v9.16b, #0x0\n" + "fmlal v8.4s, v7.4h, v29.4h\n" + "movi v11.16b, #0x0\n" + "fmlal2 v9.4s, v7.4h, v29.4h\n" + "movi v12.16b, #0x0\n" + "fmla v30.4s, v2.4s, v0.s[0]\n" + "fmla v31.4s, v3.4s, v0.s[0]\n" + "fmlal v11.4s, v10.4h, v29.4h\n" + "fmlal2 v12.4s, v10.4h, v29.4h\n" + "fmla v30.4s, v5.4s, v0.s[1]\n" + "fmla v31.4s, v6.4s, v0.s[1]\n" + "fmla v30.4s, v8.4s, v0.s[2]\n" + "fmla v31.4s, v9.4s, v0.s[2]\n" + "fmla v30.4s, v11.4s, v0.s[3]\n" + "fmla v31.4s, v12.4s, v0.s[3]\n" + "7:" // Height 1: Multiply loop: Main loop skip + "cbz x21, 9f\n" + "8:" // Height 1: Multiply loop: Odd block loop + "ldr q13, [x23, #0x0]\n" + "ldr s0, [x20], #0x4\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "sub x21, x21, #0x1\n" + "add x23, x23, #0x10\n" + "fmlal v14.4s, v13.4h, v29.4h\n" + "fmlal2 v15.4s, v13.4h, v29.4h\n" + "fmla v30.4s, v14.4s, v0.s[0]\n" + "fmla v31.4s, v15.4s, v0.s[0]\n" + "cbnz x21, 8b\n" + "9:" // Height 1: Multiply loop: No odd multiplies + "prfm pstl1keep, [x22, #0x0]\n" + "str q30, [x22, #0x0]\n" + "str q31, [x22, #0x10]\n" + "add x22, x22, #0x20\n" + "subs x24, x24, #0x1\n" + "bgt 1b\n" + : + : [gp] "r"(gp), + [offsetof_A] "I"(offsetof(GemmParamsFP16, A)), + [offsetof_B] "I"(offsetof(GemmParamsFP16, B)), + [offsetof_C] "I"(offsetof(GemmParamsFP16, C)), + [offsetof_b_block_cols] "I"(offsetof(GemmParamsFP16, b_block_cols)), + [offsetof_beta] "I"(offsetof(GemmParamsFP16, beta)), + [offsetof_k] "I"(offsetof(GemmParamsFP16, k)) + : "cc", + "memory", + "v0", + "v1", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v2", + "v29", + "v3", + "v30", + "v31", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "x20", + "x21", + "x22", + "x23", + "x24", + "x25"); +#endif // __aarch64__ +} + +void NOINLINE gemmkernel_2x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp) { +#ifdef __aarch64__ + __asm__ __volatile__( + "ldr w20, [%x[gp], %[offsetof_beta]]\n" + "mov x26, #0x1\n" + "fmov v27.8h, #1.0\n" + "ldr x25, [%x[gp], %[offsetof_b_block_cols]]\n" + "ldr x24, [%x[gp], %[offsetof_B]]\n" + "ldr x23, [%x[gp], %[offsetof_C]]\n" + "bic x20, x20, #0x80000000\n" + "cmp x20, #0x0\n" + "csel x26, XZR, x26, EQ\n" + "1:" // Height 2: Column loop + "tbz x26, #0, 2f\n" + "ldr q28, [x23, #0x0]\n" + "ldr q29, [x23, #0x10]\n" + "add x20, %x[gp], %[offsetof_beta]\n" + "ld1r { v16.4s }, [x20]\n" + "ldr x20, [%x[gp], %[offsetof_ldc]]\n" + "add x20, x23, x20\n" + "ldr q30, [x20, #0x0]\n" + "ldr q31, [x20, #0x10]\n" + "fmul v28.4s, v28.4s, v16.4s\n" + "fmul v29.4s, v29.4s, v16.4s\n" + "fmul v30.4s, v30.4s, v16.4s\n" + "fmul v31.4s, v31.4s, v16.4s\n" + "b 3f\n" + "2:" // Height 2: no accumulate + "movi v28.16b, #0x0\n" + "movi v29.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "movi v31.16b, #0x0\n" + "3:" // Height 2: setup done + "ldr x21, [%x[gp], %[offsetof_A]]\n" + "ldr x20, [%x[gp], %[offsetof_lda]]\n" + "ldr x22, [%x[gp], %[offsetof_k]]\n" + "mov x21, x21\n" + "add x20, x21, x20\n" + "cmp x22, #0x4\n" + "blt 7f\n" + "ldr q0, [x21, #0x0]\n" + "ldr q2, [x24, #0x0]\n" + "cmp x22, #0x8\n" + "ldr q1, [x20, #0x0]\n" + "ldr q5, [x24, #0x10]\n" + "ldr q8, [x24, #0x20]\n" + "ldr q11, [x24, #0x30]\n" + "blt 6f\n" + "5:" // Height 2: Multiply loop: Main loop head + "movi v3.16b, #0x0\n" + "movi v4.16b, #0x0\n" + "sub x22, x22, #0x4\n" + "add x21, x21, #0x10\n" + "movi v6.16b, #0x0\n" + "movi v7.16b, #0x0\n" + "add x20, x20, #0x10\n" + "cmp x22, #0x8\n" + "fmlal v3.4s, v2.4h, v27.4h\n" + "fmlal2 v4.4s, v2.4h, v27.4h\n" + "movi v9.16b, #0x0\n" + "add x24, x24, #0x40\n" + "ldr q2, [x24, #0x0]\n" + "fmlal v6.4s, v5.4h, v27.4h\n" + "fmlal2 v7.4s, v5.4h, v27.4h\n" + "ldr q5, [x24, #0x10]\n" + "movi v10.16b, #0x0\n" + "fmlal v9.4s, v8.4h, v27.4h\n" + "prfm pldl1keep, [x21, #0x80]\n" + "prfm pldl1keep, [x20, #0x80]\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "fmla v28.4s, v3.4s, v0.s[0]\n" + "fmla v30.4s, v3.4s, v1.s[0]\n" + "fmla v29.4s, v4.4s, v0.s[0]\n" + "fmla v31.4s, v4.4s, v1.s[0]\n" + "fmlal2 v10.4s, v8.4h, v27.4h\n" + "ldr q8, [x24, #0x20]\n" + "fmlal v12.4s, v11.4h, v27.4h\n" + "fmlal2 v13.4s, v11.4h, v27.4h\n" + "ldr q11, [x24, #0x30]\n" + "fmla v28.4s, v6.4s, v0.s[1]\n" + "fmla v30.4s, v6.4s, v1.s[1]\n" + "fmla v29.4s, v7.4s, v0.s[1]\n" + "fmla v31.4s, v7.4s, v1.s[1]\n" + "fmla v28.4s, v9.4s, v0.s[2]\n" + "fmla v30.4s, v9.4s, v1.s[2]\n" + "fmla v29.4s, v10.4s, v0.s[2]\n" + "fmla v31.4s, v10.4s, v1.s[2]\n" + "fmla v28.4s, v12.4s, v0.s[3]\n" + "fmla v30.4s, v12.4s, v1.s[3]\n" + "fmla v29.4s, v13.4s, v0.s[3]\n" + "ldr q0, [x21, #0x0]\n" + "fmla v31.4s, v13.4s, v1.s[3]\n" + "ldr q1, [x20, #0x0]\n" + "bge 5b\n" + "6:" // Height 2: Multiply loop: Single iteration only + "movi v3.16b, #0x0\n" + "movi v4.16b, #0x0\n" + "add x21, x21, #0x10\n" + "add x20, x20, #0x10\n" + "movi v6.16b, #0x0\n" + "movi v7.16b, #0x0\n" + "sub x22, x22, #0x4\n" + "prfm pldl1keep, [x21, #0x80]\n" + "fmlal v3.4s, v2.4h, v27.4h\n" + "fmlal2 v4.4s, v2.4h, v27.4h\n" + "movi v9.16b, #0x0\n" + "prfm pldl1keep, [x20, #0x80]\n" + "fmlal v6.4s, v5.4h, v27.4h\n" + "fmlal2 v7.4s, v5.4h, v27.4h\n" + "movi v10.16b, #0x0\n" + "add x24, x24, #0x40\n" + "fmlal v9.4s, v8.4h, v27.4h\n" + "movi v12.16b, #0x0\n" + "fmlal2 v10.4s, v8.4h, v27.4h\n" + "movi v13.16b, #0x0\n" + "fmla v28.4s, v3.4s, v0.s[0]\n" + "fmla v30.4s, v3.4s, v1.s[0]\n" + "fmla v29.4s, v4.4s, v0.s[0]\n" + "fmla v31.4s, v4.4s, v1.s[0]\n" + "fmlal v12.4s, v11.4h, v27.4h\n" + "fmlal2 v13.4s, v11.4h, v27.4h\n" + "fmla v28.4s, v6.4s, v0.s[1]\n" + "fmla v30.4s, v6.4s, v1.s[1]\n" + "fmla v29.4s, v7.4s, v0.s[1]\n" + "fmla v31.4s, v7.4s, v1.s[1]\n" + "fmla v28.4s, v9.4s, v0.s[2]\n" + "fmla v30.4s, v9.4s, v1.s[2]\n" + "fmla v29.4s, v10.4s, v0.s[2]\n" + "fmla v31.4s, v10.4s, v1.s[2]\n" + "fmla v28.4s, v12.4s, v0.s[3]\n" + "fmla v30.4s, v12.4s, v1.s[3]\n" + "fmla v29.4s, v13.4s, v0.s[3]\n" + "fmla v31.4s, v13.4s, v1.s[3]\n" + "7:" // Height 2: Multiply loop: Main loop skip + "cbz x22, 9f\n" + "8:" // Height 2: Multiply loop: Odd block loop + "ldr q14, [x24, #0x0]\n" + "ldr s0, [x21], #0x4\n" + "movi v15.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "ldr s1, [x20], #0x4\n" + "sub x22, x22, #0x1\n" + "add x24, x24, #0x10\n" + "fmlal v15.4s, v14.4h, v27.4h\n" + "fmlal2 v16.4s, v14.4h, v27.4h\n" + "fmla v28.4s, v15.4s, v0.s[0]\n" + "fmla v30.4s, v15.4s, v1.s[0]\n" + "fmla v29.4s, v16.4s, v0.s[0]\n" + "fmla v31.4s, v16.4s, v1.s[0]\n" + "cbnz x22, 8b\n" + "9:" // Height 2: Multiply loop: No odd multiplies + "ldr x20, [%x[gp], %[offsetof_ldc]]\n" + "prfm pstl1keep, [x23, #0x0]\n" + "str q28, [x23, #0x0]\n" + "str q29, [x23, #0x10]\n" + "add x20, x23, x20\n" + "add x23, x23, #0x20\n" + "prfm pstl1keep, [x20, #0x0]\n" + "str q30, [x20, #0x0]\n" + "str q31, [x20, #0x10]\n" + "subs x25, x25, #0x1\n" + "bgt 1b\n" + : + : [gp] "r"(gp), + [offsetof_A] "I"(offsetof(GemmParamsFP16, A)), + [offsetof_B] "I"(offsetof(GemmParamsFP16, B)), + [offsetof_C] "I"(offsetof(GemmParamsFP16, C)), + [offsetof_b_block_cols] "I"(offsetof(GemmParamsFP16, b_block_cols)), + [offsetof_beta] "I"(offsetof(GemmParamsFP16, beta)), + [offsetof_k] "I"(offsetof(GemmParamsFP16, k)), + [offsetof_lda] "I"(offsetof(GemmParamsFP16, lda)), + [offsetof_ldc] "I"(offsetof(GemmParamsFP16, ldc)) + : "cc", + "memory", + "v0", + "v1", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v2", + "v27", + "v28", + "v29", + "v3", + "v30", + "v31", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "x20", + "x21", + "x22", + "x23", + "x24", + "x25", + "x26"); +#endif // __aarch64__ +} + +void NOINLINE gemmkernel_3x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp) { +#ifdef __aarch64__ + __asm__ __volatile__( + "ldr w20, [%x[gp], %[offsetof_beta]]\n" + "mov x27, #0x1\n" + "fmov v25.8h, #1.0\n" + "ldr x26, [%x[gp], %[offsetof_b_block_cols]]\n" + "ldr x25, [%x[gp], %[offsetof_B]]\n" + "ldr x24, [%x[gp], %[offsetof_C]]\n" + "bic x20, x20, #0x80000000\n" + "cmp x20, #0x0\n" + "csel x27, XZR, x27, EQ\n" + "1:" // Height 3: Column loop + "tbz x27, #0, 2f\n" + "ldr q26, [x24, #0x0]\n" + "ldr q27, [x24, #0x10]\n" + "add x20, %x[gp], %[offsetof_beta]\n" + "ld1r { v16.4s }, [x20]\n" + "ldr x21, [%x[gp], %[offsetof_ldc]]\n" + "add x20, x24, x21\n" + "ldr q28, [x20, #0x0]\n" + "ldr q29, [x20, #0x10]\n" + "add x20, x20, x21\n" + "ldr q30, [x20, #0x0]\n" + "ldr q31, [x20, #0x10]\n" + "fmul v26.4s, v26.4s, v16.4s\n" + "fmul v27.4s, v27.4s, v16.4s\n" + "fmul v28.4s, v28.4s, v16.4s\n" + "fmul v29.4s, v29.4s, v16.4s\n" + "fmul v30.4s, v30.4s, v16.4s\n" + "fmul v31.4s, v31.4s, v16.4s\n" + "b 3f\n" + "2:" // Height 3: no accumulate + "movi v26.16b, #0x0\n" + "movi v27.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "movi v29.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "movi v31.16b, #0x0\n" + "3:" // Height 3: setup done + "ldr x21, [%x[gp], %[offsetof_A]]\n" + "ldr x20, [%x[gp], %[offsetof_lda]]\n" + "ldr x23, [%x[gp], %[offsetof_k]]\n" + "mov x22, x21\n" + "add x21, x22, x20\n" + "add x20, x21, x20\n" + "cmp x23, #0x4\n" + "blt 7f\n" + "ldr q0, [x22, #0x0]\n" + "ldr q3, [x25, #0x0]\n" + "cmp x23, #0x8\n" + "ldr q1, [x21, #0x0]\n" + "ldr q2, [x20, #0x0]\n" + "ldr q6, [x25, #0x10]\n" + "ldr q9, [x25, #0x20]\n" + "ldr q12, [x25, #0x30]\n" + "blt 6f\n" + "5:" // Height 3: Multiply loop: Main loop head + "movi v4.16b, #0x0\n" + "movi v5.16b, #0x0\n" + "sub x23, x23, #0x4\n" + "add x22, x22, #0x10\n" + "movi v7.16b, #0x0\n" + "movi v8.16b, #0x0\n" + "add x21, x21, #0x10\n" + "add x20, x20, #0x10\n" + "fmlal v4.4s, v3.4h, v25.4h\n" + "fmlal2 v5.4s, v3.4h, v25.4h\n" + "movi v10.16b, #0x0\n" + "cmp x23, #0x8\n" + "fmlal v7.4s, v6.4h, v25.4h\n" + "fmlal2 v8.4s, v6.4h, v25.4h\n" + "movi v11.16b, #0x0\n" + "add x25, x25, #0x40\n" + "ldr q3, [x25, #0x0]\n" + "ldr q6, [x25, #0x10]\n" + "fmlal v10.4s, v9.4h, v25.4h\n" + "movi v13.16b, #0x0\n" + "fmlal2 v11.4s, v9.4h, v25.4h\n" + "ldr q9, [x25, #0x20]\n" + "movi v14.16b, #0x0\n" + "prfm pldl1keep, [x22, #0x80]\n" + "fmla v26.4s, v4.4s, v0.s[0]\n" + "fmla v28.4s, v4.4s, v1.s[0]\n" + "prfm pldl1keep, [x21, #0x80]\n" + "prfm pldl1keep, [x20, #0x80]\n" + "fmla v30.4s, v4.4s, v2.s[0]\n" + "fmla v27.4s, v5.4s, v0.s[0]\n" + "fmla v29.4s, v5.4s, v1.s[0]\n" + "fmla v31.4s, v5.4s, v2.s[0]\n" + "fmlal v13.4s, v12.4h, v25.4h\n" + "fmlal2 v14.4s, v12.4h, v25.4h\n" + "ldr q12, [x25, #0x30]\n" + "fmla v26.4s, v7.4s, v0.s[1]\n" + "fmla v28.4s, v7.4s, v1.s[1]\n" + "fmla v30.4s, v7.4s, v2.s[1]\n" + "fmla v27.4s, v8.4s, v0.s[1]\n" + "fmla v29.4s, v8.4s, v1.s[1]\n" + "fmla v31.4s, v8.4s, v2.s[1]\n" + "fmla v26.4s, v10.4s, v0.s[2]\n" + "fmla v28.4s, v10.4s, v1.s[2]\n" + "fmla v30.4s, v10.4s, v2.s[2]\n" + "fmla v27.4s, v11.4s, v0.s[2]\n" + "fmla v29.4s, v11.4s, v1.s[2]\n" + "fmla v31.4s, v11.4s, v2.s[2]\n" + "fmla v26.4s, v13.4s, v0.s[3]\n" + "fmla v28.4s, v13.4s, v1.s[3]\n" + "fmla v30.4s, v13.4s, v2.s[3]\n" + "fmla v27.4s, v14.4s, v0.s[3]\n" + "ldr q0, [x22, #0x0]\n" + "fmla v29.4s, v14.4s, v1.s[3]\n" + "ldr q1, [x21, #0x0]\n" + "fmla v31.4s, v14.4s, v2.s[3]\n" + "ldr q2, [x20, #0x0]\n" + "bge 5b\n" + "6:" // Height 3: Multiply loop: Single iteration only + "movi v4.16b, #0x0\n" + "movi v5.16b, #0x0\n" + "add x22, x22, #0x10\n" + "add x21, x21, #0x10\n" + "movi v7.16b, #0x0\n" + "movi v8.16b, #0x0\n" + "add x20, x20, #0x10\n" + "sub x23, x23, #0x4\n" + "fmlal v4.4s, v3.4h, v25.4h\n" + "fmlal2 v5.4s, v3.4h, v25.4h\n" + "movi v10.16b, #0x0\n" + "prfm pldl1keep, [x22, #0x80]\n" + "fmlal v7.4s, v6.4h, v25.4h\n" + "fmlal2 v8.4s, v6.4h, v25.4h\n" + "movi v11.16b, #0x0\n" + "prfm pldl1keep, [x21, #0x80]\n" + "fmlal v10.4s, v9.4h, v25.4h\n" + "movi v13.16b, #0x0\n" + "prfm pldl1keep, [x20, #0x80]\n" + "add x25, x25, #0x40\n" + "fmlal2 v11.4s, v9.4h, v25.4h\n" + "movi v14.16b, #0x0\n" + "fmla v26.4s, v4.4s, v0.s[0]\n" + "fmla v28.4s, v4.4s, v1.s[0]\n" + "fmla v30.4s, v4.4s, v2.s[0]\n" + "fmla v27.4s, v5.4s, v0.s[0]\n" + "fmla v29.4s, v5.4s, v1.s[0]\n" + "fmla v31.4s, v5.4s, v2.s[0]\n" + "fmlal v13.4s, v12.4h, v25.4h\n" + "fmla v26.4s, v7.4s, v0.s[1]\n" + "fmlal2 v14.4s, v12.4h, v25.4h\n" + "fmla v28.4s, v7.4s, v1.s[1]\n" + "fmla v30.4s, v7.4s, v2.s[1]\n" + "fmla v27.4s, v8.4s, v0.s[1]\n" + "fmla v29.4s, v8.4s, v1.s[1]\n" + "fmla v31.4s, v8.4s, v2.s[1]\n" + "fmla v26.4s, v10.4s, v0.s[2]\n" + "fmla v28.4s, v10.4s, v1.s[2]\n" + "fmla v30.4s, v10.4s, v2.s[2]\n" + "fmla v27.4s, v11.4s, v0.s[2]\n" + "fmla v29.4s, v11.4s, v1.s[2]\n" + "fmla v31.4s, v11.4s, v2.s[2]\n" + "fmla v26.4s, v13.4s, v0.s[3]\n" + "fmla v28.4s, v13.4s, v1.s[3]\n" + "fmla v30.4s, v13.4s, v2.s[3]\n" + "fmla v27.4s, v14.4s, v0.s[3]\n" + "fmla v29.4s, v14.4s, v1.s[3]\n" + "fmla v31.4s, v14.4s, v2.s[3]\n" + "7:" // Height 3: Multiply loop: Main loop skip + "cbz x23, 9f\n" + "8:" // Height 3: Multiply loop: Odd block loop + "ldr q15, [x25, #0x0]\n" + "ldr s0, [x22], #0x4\n" + "movi v16.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "ldr s1, [x21], #0x4\n" + "ldr s2, [x20], #0x4\n" + "sub x23, x23, #0x1\n" + "add x25, x25, #0x10\n" + "fmlal v16.4s, v15.4h, v25.4h\n" + "fmlal2 v17.4s, v15.4h, v25.4h\n" + "fmla v26.4s, v16.4s, v0.s[0]\n" + "fmla v28.4s, v16.4s, v1.s[0]\n" + "fmla v30.4s, v16.4s, v2.s[0]\n" + "fmla v27.4s, v17.4s, v0.s[0]\n" + "fmla v29.4s, v17.4s, v1.s[0]\n" + "fmla v31.4s, v17.4s, v2.s[0]\n" + "cbnz x23, 8b\n" + "9:" // Height 3: Multiply loop: No odd multiplies + "ldr x20, [%x[gp], %[offsetof_ldc]]\n" + "prfm pstl1keep, [x24, #0x0]\n" + "str q26, [x24, #0x0]\n" + "str q27, [x24, #0x10]\n" + "add x21, x24, x20\n" + "add x24, x24, #0x20\n" + "prfm pstl1keep, [x21, #0x0]\n" + "str q28, [x21, #0x0]\n" + "add x20, x21, x20\n" + "prfm pstl1keep, [x20, #0x0]\n" + "str q29, [x21, #0x10]\n" + "str q30, [x20, #0x0]\n" + "str q31, [x20, #0x10]\n" + "subs x26, x26, #0x1\n" + "bgt 1b\n" + : + : [gp] "r"(gp), + [offsetof_A] "I"(offsetof(GemmParamsFP16, A)), + [offsetof_B] "I"(offsetof(GemmParamsFP16, B)), + [offsetof_C] "I"(offsetof(GemmParamsFP16, C)), + [offsetof_b_block_cols] "I"(offsetof(GemmParamsFP16, b_block_cols)), + [offsetof_beta] "I"(offsetof(GemmParamsFP16, beta)), + [offsetof_k] "I"(offsetof(GemmParamsFP16, k)), + [offsetof_lda] "I"(offsetof(GemmParamsFP16, lda)), + [offsetof_ldc] "I"(offsetof(GemmParamsFP16, ldc)) + : "cc", + "memory", + "v0", + "v1", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v2", + "v25", + "v26", + "v27", + "v28", + "v29", + "v3", + "v30", + "v31", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "x20", + "x21", + "x22", + "x23", + "x24", + "x25", + "x26", + "x27"); +#endif // __aarch64__ +} + +void NOINLINE gemmkernel_4x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp) { +#ifdef __aarch64__ + __asm__ __volatile__( + "ldr w20, [%x[gp], %[offsetof_beta]]\n" + "mov x28, #0x1\n" + "fmov v23.8h, #1.0\n" + "ldr x27, [%x[gp], %[offsetof_b_block_cols]]\n" + "ldr x26, [%x[gp], %[offsetof_B]]\n" + "ldr x25, [%x[gp], %[offsetof_C]]\n" + "bic x20, x20, #0x80000000\n" + "cmp x20, #0x0\n" + "csel x28, XZR, x28, EQ\n" + "1:" // Height 4: Column loop + "tbz x28, #0, 2f\n" + "ldr q24, [x25, #0x0]\n" + "ldr q25, [x25, #0x10]\n" + "add x20, %x[gp], %[offsetof_beta]\n" + "ld1r { v16.4s }, [x20]\n" + "ldr x21, [%x[gp], %[offsetof_ldc]]\n" + "add x20, x25, x21\n" + "ldr q26, [x20, #0x0]\n" + "ldr q27, [x20, #0x10]\n" + "add x20, x20, x21\n" + "ldr q28, [x20, #0x0]\n" + "ldr q29, [x20, #0x10]\n" + "add x20, x20, x21\n" + "ldr q30, [x20, #0x0]\n" + "ldr q31, [x20, #0x10]\n" + "fmul v24.4s, v24.4s, v16.4s\n" + "fmul v25.4s, v25.4s, v16.4s\n" + "fmul v26.4s, v26.4s, v16.4s\n" + "fmul v27.4s, v27.4s, v16.4s\n" + "fmul v28.4s, v28.4s, v16.4s\n" + "fmul v29.4s, v29.4s, v16.4s\n" + "fmul v30.4s, v30.4s, v16.4s\n" + "fmul v31.4s, v31.4s, v16.4s\n" + "b 3f\n" + "2:" // Height 4: no accumulate + "movi v24.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v26.16b, #0x0\n" + "movi v27.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "movi v29.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "movi v31.16b, #0x0\n" + "3:" // Height 4: setup done + "ldr x21, [%x[gp], %[offsetof_A]]\n" + "ldr x20, [%x[gp], %[offsetof_lda]]\n" + "ldr x24, [%x[gp], %[offsetof_k]]\n" + "mov x23, x21\n" + "add x22, x23, x20\n" + "add x21, x22, x20\n" + "add x20, x21, x20\n" + "cmp x24, #0x4\n" + "blt 7f\n" + "ldr q0, [x23, #0x0]\n" + "ldr q4, [x26, #0x0]\n" + "cmp x24, #0x8\n" + "ldr q1, [x22, #0x0]\n" + "ldr q2, [x21, #0x0]\n" + "ldr q3, [x20, #0x0]\n" + "ldr q7, [x26, #0x10]\n" + "ldr q10, [x26, #0x20]\n" + "ldr q13, [x26, #0x30]\n" + "blt 6f\n" + "5:" // Height 4: Multiply loop: Main loop head + "movi v5.16b, #0x0\n" + "movi v6.16b, #0x0\n" + "sub x24, x24, #0x4\n" + "add x23, x23, #0x10\n" + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "add x22, x22, #0x10\n" + "add x21, x21, #0x10\n" + "fmlal v5.4s, v4.4h, v23.4h\n" + "fmlal2 v6.4s, v4.4h, v23.4h\n" + "movi v11.16b, #0x0\n" + "add x20, x20, #0x10\n" + "fmlal v8.4s, v7.4h, v23.4h\n" + "fmlal2 v9.4s, v7.4h, v23.4h\n" + "movi v12.16b, #0x0\n" + "cmp x24, #0x8\n" + "fmlal v11.4s, v10.4h, v23.4h\n" + "movi v14.16b, #0x0\n" + "add x26, x26, #0x40\n" + "prfm pldl1keep, [x23, #0x80]\n" + "ldr q4, [x26, #0x0]\n" + "ldr q7, [x26, #0x10]\n" + "fmlal2 v12.4s, v10.4h, v23.4h\n" + "movi v15.16b, #0x0\n" + "ldr q10, [x26, #0x20]\n" + "fmla v24.4s, v5.4s, v0.s[0]\n" + "fmla v26.4s, v5.4s, v1.s[0]\n" + "prfm pldl1keep, [x22, #0x80]\n" + "fmla v28.4s, v5.4s, v2.s[0]\n" + "fmla v30.4s, v5.4s, v3.s[0]\n" + "prfm pldl1keep, [x21, #0x80]\n" + "prfm pldl1keep, [x20, #0x80]\n" + "fmla v25.4s, v6.4s, v0.s[0]\n" + "fmla v27.4s, v6.4s, v1.s[0]\n" + "fmla v29.4s, v6.4s, v2.s[0]\n" + "fmla v31.4s, v6.4s, v3.s[0]\n" + "fmla v24.4s, v8.4s, v0.s[1]\n" + "fmla v26.4s, v8.4s, v1.s[1]\n" + "fmla v28.4s, v8.4s, v2.s[1]\n" + "fmla v30.4s, v8.4s, v3.s[1]\n" + "fmla v25.4s, v9.4s, v0.s[1]\n" + "fmla v27.4s, v9.4s, v1.s[1]\n" + "fmla v29.4s, v9.4s, v2.s[1]\n" + "fmla v31.4s, v9.4s, v3.s[1]\n" + "fmlal v14.4s, v13.4h, v23.4h\n" + "fmla v24.4s, v11.4s, v0.s[2]\n" + "fmla v26.4s, v11.4s, v1.s[2]\n" + "fmla v28.4s, v11.4s, v2.s[2]\n" + "fmla v30.4s, v11.4s, v3.s[2]\n" + "fmla v25.4s, v12.4s, v0.s[2]\n" + "fmla v27.4s, v12.4s, v1.s[2]\n" + "fmla v29.4s, v12.4s, v2.s[2]\n" + "fmla v31.4s, v12.4s, v3.s[2]\n" + "fmlal2 v15.4s, v13.4h, v23.4h\n" + "ldr q13, [x26, #0x30]\n" + "fmla v24.4s, v14.4s, v0.s[3]\n" + "fmla v26.4s, v14.4s, v1.s[3]\n" + "fmla v28.4s, v14.4s, v2.s[3]\n" + "fmla v30.4s, v14.4s, v3.s[3]\n" + "fmla v25.4s, v15.4s, v0.s[3]\n" + "ldr q0, [x23, #0x0]\n" + "fmla v27.4s, v15.4s, v1.s[3]\n" + "ldr q1, [x22, #0x0]\n" + "fmla v29.4s, v15.4s, v2.s[3]\n" + "ldr q2, [x21, #0x0]\n" + "fmla v31.4s, v15.4s, v3.s[3]\n" + "ldr q3, [x20, #0x0]\n" + "bge 5b\n" + "6:" // Height 4: Multiply loop: Single iteration only + "movi v5.16b, #0x0\n" + "movi v6.16b, #0x0\n" + "add x23, x23, #0x10\n" + "add x22, x22, #0x10\n" + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "add x21, x21, #0x10\n" + "add x20, x20, #0x10\n" + "fmlal v5.4s, v4.4h, v23.4h\n" + "fmlal2 v6.4s, v4.4h, v23.4h\n" + "movi v11.16b, #0x0\n" + "sub x24, x24, #0x4\n" + "fmlal v8.4s, v7.4h, v23.4h\n" + "fmlal2 v9.4s, v7.4h, v23.4h\n" + "movi v12.16b, #0x0\n" + "prfm pldl1keep, [x23, #0x80]\n" + "fmlal v11.4s, v10.4h, v23.4h\n" + "movi v14.16b, #0x0\n" + "prfm pldl1keep, [x22, #0x80]\n" + "prfm pldl1keep, [x21, #0x80]\n" + "fmlal2 v12.4s, v10.4h, v23.4h\n" + "movi v15.16b, #0x0\n" + "prfm pldl1keep, [x20, #0x80]\n" + "add x26, x26, #0x40\n" + "fmla v24.4s, v5.4s, v0.s[0]\n" + "fmla v26.4s, v5.4s, v1.s[0]\n" + "fmla v28.4s, v5.4s, v2.s[0]\n" + "fmla v30.4s, v5.4s, v3.s[0]\n" + "fmla v25.4s, v6.4s, v0.s[0]\n" + "fmla v27.4s, v6.4s, v1.s[0]\n" + "fmla v29.4s, v6.4s, v2.s[0]\n" + "fmla v31.4s, v6.4s, v3.s[0]\n" + "fmla v24.4s, v8.4s, v0.s[1]\n" + "fmla v26.4s, v8.4s, v1.s[1]\n" + "fmla v28.4s, v8.4s, v2.s[1]\n" + "fmla v30.4s, v8.4s, v3.s[1]\n" + "fmla v25.4s, v9.4s, v0.s[1]\n" + "fmla v27.4s, v9.4s, v1.s[1]\n" + "fmla v29.4s, v9.4s, v2.s[1]\n" + "fmla v31.4s, v9.4s, v3.s[1]\n" + "fmlal v14.4s, v13.4h, v23.4h\n" + "fmla v24.4s, v11.4s, v0.s[2]\n" + "fmla v26.4s, v11.4s, v1.s[2]\n" + "fmla v28.4s, v11.4s, v2.s[2]\n" + "fmla v30.4s, v11.4s, v3.s[2]\n" + "fmla v25.4s, v12.4s, v0.s[2]\n" + "fmla v27.4s, v12.4s, v1.s[2]\n" + "fmla v29.4s, v12.4s, v2.s[2]\n" + "fmla v31.4s, v12.4s, v3.s[2]\n" + "fmlal2 v15.4s, v13.4h, v23.4h\n" + "fmla v24.4s, v14.4s, v0.s[3]\n" + "fmla v26.4s, v14.4s, v1.s[3]\n" + "fmla v28.4s, v14.4s, v2.s[3]\n" + "fmla v30.4s, v14.4s, v3.s[3]\n" + "fmla v25.4s, v15.4s, v0.s[3]\n" + "fmla v27.4s, v15.4s, v1.s[3]\n" + "fmla v29.4s, v15.4s, v2.s[3]\n" + "fmla v31.4s, v15.4s, v3.s[3]\n" + "7:" // Height 4: Multiply loop: Main loop skip + "cbz x24, 9f\n" + "8:" // Height 4: Multiply loop: Odd block loop + "ldr q16, [x26, #0x0]\n" + "ldr s0, [x23], #0x4\n" + "movi v17.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "ldr s1, [x22], #0x4\n" + "ldr s2, [x21], #0x4\n" + "sub x24, x24, #0x1\n" + "add x26, x26, #0x10\n" + "ldr s3, [x20], #0x4\n" + "fmlal v17.4s, v16.4h, v23.4h\n" + "fmlal2 v18.4s, v16.4h, v23.4h\n" + "fmla v24.4s, v17.4s, v0.s[0]\n" + "fmla v26.4s, v17.4s, v1.s[0]\n" + "fmla v28.4s, v17.4s, v2.s[0]\n" + "fmla v30.4s, v17.4s, v3.s[0]\n" + "fmla v25.4s, v18.4s, v0.s[0]\n" + "fmla v27.4s, v18.4s, v1.s[0]\n" + "fmla v29.4s, v18.4s, v2.s[0]\n" + "fmla v31.4s, v18.4s, v3.s[0]\n" + "cbnz x24, 8b\n" + "9:" // Height 4: Multiply loop: No odd multiplies + "ldr x20, [%x[gp], %[offsetof_ldc]]\n" + "prfm pstl1keep, [x25, #0x0]\n" + "str q24, [x25, #0x0]\n" + "str q25, [x25, #0x10]\n" + "add x22, x25, x20\n" + "add x25, x25, #0x20\n" + "prfm pstl1keep, [x22, #0x0]\n" + "str q26, [x22, #0x0]\n" + "add x21, x22, x20\n" + "add x20, x21, x20\n" + "prfm pstl1keep, [x21, #0x0]\n" + "prfm pstl1keep, [x20, #0x0]\n" + "str q27, [x22, #0x10]\n" + "str q28, [x21, #0x0]\n" + "str q29, [x21, #0x10]\n" + "str q30, [x20, #0x0]\n" + "str q31, [x20, #0x10]\n" + "subs x27, x27, #0x1\n" + "bgt 1b\n" + : + : [gp] "r"(gp), + [offsetof_A] "I"(offsetof(GemmParamsFP16, A)), + [offsetof_B] "I"(offsetof(GemmParamsFP16, B)), + [offsetof_C] "I"(offsetof(GemmParamsFP16, C)), + [offsetof_b_block_cols] "I"(offsetof(GemmParamsFP16, b_block_cols)), + [offsetof_beta] "I"(offsetof(GemmParamsFP16, beta)), + [offsetof_k] "I"(offsetof(GemmParamsFP16, k)), + [offsetof_lda] "I"(offsetof(GemmParamsFP16, lda)), + [offsetof_ldc] "I"(offsetof(GemmParamsFP16, ldc)) + : "cc", + "memory", + "v0", + "v1", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v2", + "v23", + "v24", + "v25", + "v26", + "v27", + "v28", + "v29", + "v3", + "v30", + "v31", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "x20", + "x21", + "x22", + "x23", + "x24", + "x25", + "x26", + "x27", + "x28"); +#endif // __aarch64__ +} + +void NOINLINE gemmkernel_5x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp) { +#ifdef __aarch64__ + __asm__ __volatile__( + "ldr w20, [%x[gp], %[offsetof_beta]]\n" + "mov x9, #0x1\n" + "fmov v21.8h, #1.0\n" + "ldr x28, [%x[gp], %[offsetof_b_block_cols]]\n" + "ldr x27, [%x[gp], %[offsetof_B]]\n" + "ldr x26, [%x[gp], %[offsetof_C]]\n" + "bic x20, x20, #0x80000000\n" + "cmp x20, #0x0\n" + "csel x9, XZR, x9, EQ\n" + "1:" // Height 5: Column loop + "tbz x9, #0, 2f\n" + "ldr q22, [x26, #0x0]\n" + "ldr q23, [x26, #0x10]\n" + "add x20, %x[gp], %[offsetof_beta]\n" + "ld1r { v16.4s }, [x20]\n" + "ldr x21, [%x[gp], %[offsetof_ldc]]\n" + "add x20, x26, x21\n" + "ldr q24, [x20, #0x0]\n" + "ldr q25, [x20, #0x10]\n" + "add x20, x20, x21\n" + "ldr q26, [x20, #0x0]\n" + "ldr q27, [x20, #0x10]\n" + "add x20, x20, x21\n" + "ldr q28, [x20, #0x0]\n" + "ldr q29, [x20, #0x10]\n" + "add x20, x20, x21\n" + "fmul v22.4s, v22.4s, v16.4s\n" + "ldr q30, [x20, #0x0]\n" + "ldr q31, [x20, #0x10]\n" + "fmul v23.4s, v23.4s, v16.4s\n" + "fmul v24.4s, v24.4s, v16.4s\n" + "fmul v25.4s, v25.4s, v16.4s\n" + "fmul v26.4s, v26.4s, v16.4s\n" + "fmul v27.4s, v27.4s, v16.4s\n" + "fmul v28.4s, v28.4s, v16.4s\n" + "fmul v29.4s, v29.4s, v16.4s\n" + "fmul v30.4s, v30.4s, v16.4s\n" + "fmul v31.4s, v31.4s, v16.4s\n" + "b 3f\n" + "2:" // Height 5: no accumulate + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v26.16b, #0x0\n" + "movi v27.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "movi v29.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "movi v31.16b, #0x0\n" + "3:" // Height 5: setup done + "ldr x21, [%x[gp], %[offsetof_A]]\n" + "ldr x20, [%x[gp], %[offsetof_lda]]\n" + "ldr x25, [%x[gp], %[offsetof_k]]\n" + "mov x24, x21\n" + "add x23, x24, x20\n" + "add x22, x23, x20\n" + "add x21, x22, x20\n" + "add x20, x21, x20\n" + "cmp x25, #0x4\n" + "blt 7f\n" + "ldr q0, [x24, #0x0]\n" + "ldr q5, [x27, #0x0]\n" + "cmp x25, #0x8\n" + "ldr q1, [x23, #0x0]\n" + "ldr q2, [x22, #0x0]\n" + "ldr q3, [x21, #0x0]\n" + "ldr q4, [x20, #0x0]\n" + "ldr q8, [x27, #0x10]\n" + "ldr q11, [x27, #0x20]\n" + "ldr q14, [x27, #0x30]\n" + "blt 6f\n" + "5:" // Height 5: Multiply loop: Main loop head + "movi v6.16b, #0x0\n" + "movi v7.16b, #0x0\n" + "sub x25, x25, #0x4\n" + "add x24, x24, #0x10\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "add x23, x23, #0x10\n" + "add x22, x22, #0x10\n" + "fmlal v6.4s, v5.4h, v21.4h\n" + "fmlal2 v7.4s, v5.4h, v21.4h\n" + "movi v12.16b, #0x0\n" + "add x21, x21, #0x10\n" + "fmlal v9.4s, v8.4h, v21.4h\n" + "fmlal2 v10.4s, v8.4h, v21.4h\n" + "movi v13.16b, #0x0\n" + "add x20, x20, #0x10\n" + "fmlal v12.4s, v11.4h, v21.4h\n" + "movi v15.16b, #0x0\n" + "cmp x25, #0x8\n" + "add x27, x27, #0x40\n" + "ldr q5, [x27, #0x0]\n" + "ldr q8, [x27, #0x10]\n" + "fmlal2 v13.4s, v11.4h, v21.4h\n" + "movi v16.16b, #0x0\n" + "ldr q11, [x27, #0x20]\n" + "fmla v22.4s, v6.4s, v0.s[0]\n" + "fmla v24.4s, v6.4s, v1.s[0]\n" + "prfm pldl1keep, [x24, #0x80]\n" + "fmla v26.4s, v6.4s, v2.s[0]\n" + "fmla v28.4s, v6.4s, v3.s[0]\n" + "prfm pldl1keep, [x23, #0x80]\n" + "prfm pldl1keep, [x22, #0x80]\n" + "fmla v30.4s, v6.4s, v4.s[0]\n" + "fmla v23.4s, v7.4s, v0.s[0]\n" + "prfm pldl1keep, [x21, #0x80]\n" + "prfm pldl1keep, [x20, #0x80]\n" + "fmla v25.4s, v7.4s, v1.s[0]\n" + "fmla v27.4s, v7.4s, v2.s[0]\n" + "fmla v29.4s, v7.4s, v3.s[0]\n" + "fmla v31.4s, v7.4s, v4.s[0]\n" + "fmla v22.4s, v9.4s, v0.s[1]\n" + "fmla v24.4s, v9.4s, v1.s[1]\n" + "fmla v26.4s, v9.4s, v2.s[1]\n" + "fmla v28.4s, v9.4s, v3.s[1]\n" + "fmla v30.4s, v9.4s, v4.s[1]\n" + "fmla v23.4s, v10.4s, v0.s[1]\n" + "fmla v25.4s, v10.4s, v1.s[1]\n" + "fmla v27.4s, v10.4s, v2.s[1]\n" + "fmla v29.4s, v10.4s, v3.s[1]\n" + "fmla v31.4s, v10.4s, v4.s[1]\n" + "fmla v22.4s, v12.4s, v0.s[2]\n" + "fmla v24.4s, v12.4s, v1.s[2]\n" + "fmla v26.4s, v12.4s, v2.s[2]\n" + "fmla v28.4s, v12.4s, v3.s[2]\n" + "fmla v30.4s, v12.4s, v4.s[2]\n" + "fmla v23.4s, v13.4s, v0.s[2]\n" + "fmla v25.4s, v13.4s, v1.s[2]\n" + "fmla v27.4s, v13.4s, v2.s[2]\n" + "fmla v29.4s, v13.4s, v3.s[2]\n" + "fmla v31.4s, v13.4s, v4.s[2]\n" + "fmlal v15.4s, v14.4h, v21.4h\n" + "fmlal2 v16.4s, v14.4h, v21.4h\n" + "ldr q14, [x27, #0x30]\n" + "fmla v22.4s, v15.4s, v0.s[3]\n" + "fmla v24.4s, v15.4s, v1.s[3]\n" + "fmla v26.4s, v15.4s, v2.s[3]\n" + "fmla v28.4s, v15.4s, v3.s[3]\n" + "fmla v30.4s, v15.4s, v4.s[3]\n" + "fmla v23.4s, v16.4s, v0.s[3]\n" + "ldr q0, [x24, #0x0]\n" + "fmla v25.4s, v16.4s, v1.s[3]\n" + "ldr q1, [x23, #0x0]\n" + "fmla v27.4s, v16.4s, v2.s[3]\n" + "ldr q2, [x22, #0x0]\n" + "fmla v29.4s, v16.4s, v3.s[3]\n" + "ldr q3, [x21, #0x0]\n" + "fmla v31.4s, v16.4s, v4.s[3]\n" + "ldr q4, [x20, #0x0]\n" + "bge 5b\n" + "6:" // Height 5: Multiply loop: Single iteration only + "movi v6.16b, #0x0\n" + "movi v7.16b, #0x0\n" + "add x24, x24, #0x10\n" + "add x23, x23, #0x10\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "add x22, x22, #0x10\n" + "add x21, x21, #0x10\n" + "fmlal v6.4s, v5.4h, v21.4h\n" + "fmlal2 v7.4s, v5.4h, v21.4h\n" + "movi v12.16b, #0x0\n" + "add x20, x20, #0x10\n" + "fmlal v9.4s, v8.4h, v21.4h\n" + "fmlal2 v10.4s, v8.4h, v21.4h\n" + "movi v13.16b, #0x0\n" + "sub x25, x25, #0x4\n" + "fmlal v12.4s, v11.4h, v21.4h\n" + "movi v15.16b, #0x0\n" + "prfm pldl1keep, [x24, #0x80]\n" + "prfm pldl1keep, [x23, #0x80]\n" + "fmlal2 v13.4s, v11.4h, v21.4h\n" + "movi v16.16b, #0x0\n" + "prfm pldl1keep, [x22, #0x80]\n" + "prfm pldl1keep, [x21, #0x80]\n" + "fmla v22.4s, v6.4s, v0.s[0]\n" + "fmla v24.4s, v6.4s, v1.s[0]\n" + "prfm pldl1keep, [x20, #0x80]\n" + "add x27, x27, #0x40\n" + "fmla v26.4s, v6.4s, v2.s[0]\n" + "fmla v28.4s, v6.4s, v3.s[0]\n" + "fmla v30.4s, v6.4s, v4.s[0]\n" + "fmla v23.4s, v7.4s, v0.s[0]\n" + "fmla v25.4s, v7.4s, v1.s[0]\n" + "fmla v27.4s, v7.4s, v2.s[0]\n" + "fmla v29.4s, v7.4s, v3.s[0]\n" + "fmla v31.4s, v7.4s, v4.s[0]\n" + "fmla v22.4s, v9.4s, v0.s[1]\n" + "fmla v24.4s, v9.4s, v1.s[1]\n" + "fmla v26.4s, v9.4s, v2.s[1]\n" + "fmla v28.4s, v9.4s, v3.s[1]\n" + "fmla v30.4s, v9.4s, v4.s[1]\n" + "fmla v23.4s, v10.4s, v0.s[1]\n" + "fmla v25.4s, v10.4s, v1.s[1]\n" + "fmla v27.4s, v10.4s, v2.s[1]\n" + "fmla v29.4s, v10.4s, v3.s[1]\n" + "fmla v31.4s, v10.4s, v4.s[1]\n" + "fmla v22.4s, v12.4s, v0.s[2]\n" + "fmla v24.4s, v12.4s, v1.s[2]\n" + "fmla v26.4s, v12.4s, v2.s[2]\n" + "fmla v28.4s, v12.4s, v3.s[2]\n" + "fmla v30.4s, v12.4s, v4.s[2]\n" + "fmla v23.4s, v13.4s, v0.s[2]\n" + "fmla v25.4s, v13.4s, v1.s[2]\n" + "fmla v27.4s, v13.4s, v2.s[2]\n" + "fmla v29.4s, v13.4s, v3.s[2]\n" + "fmla v31.4s, v13.4s, v4.s[2]\n" + "fmlal v15.4s, v14.4h, v21.4h\n" + "fmlal2 v16.4s, v14.4h, v21.4h\n" + "fmla v22.4s, v15.4s, v0.s[3]\n" + "fmla v24.4s, v15.4s, v1.s[3]\n" + "fmla v26.4s, v15.4s, v2.s[3]\n" + "fmla v28.4s, v15.4s, v3.s[3]\n" + "fmla v30.4s, v15.4s, v4.s[3]\n" + "fmla v23.4s, v16.4s, v0.s[3]\n" + "fmla v25.4s, v16.4s, v1.s[3]\n" + "fmla v27.4s, v16.4s, v2.s[3]\n" + "fmla v29.4s, v16.4s, v3.s[3]\n" + "fmla v31.4s, v16.4s, v4.s[3]\n" + "7:" // Height 5: Multiply loop: Main loop skip + "cbz x25, 9f\n" + "8:" // Height 5: Multiply loop: Odd block loop + "ldr q17, [x27, #0x0]\n" + "ldr s0, [x24], #0x4\n" + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "ldr s1, [x23], #0x4\n" + "ldr s2, [x22], #0x4\n" + "sub x25, x25, #0x1\n" + "add x27, x27, #0x10\n" + "ldr s3, [x21], #0x4\n" + "ldr s4, [x20], #0x4\n" + "fmlal v18.4s, v17.4h, v21.4h\n" + "fmlal2 v19.4s, v17.4h, v21.4h\n" + "fmla v22.4s, v18.4s, v0.s[0]\n" + "fmla v24.4s, v18.4s, v1.s[0]\n" + "fmla v26.4s, v18.4s, v2.s[0]\n" + "fmla v28.4s, v18.4s, v3.s[0]\n" + "fmla v30.4s, v18.4s, v4.s[0]\n" + "fmla v23.4s, v19.4s, v0.s[0]\n" + "fmla v25.4s, v19.4s, v1.s[0]\n" + "fmla v27.4s, v19.4s, v2.s[0]\n" + "fmla v29.4s, v19.4s, v3.s[0]\n" + "fmla v31.4s, v19.4s, v4.s[0]\n" + "cbnz x25, 8b\n" + "9:" // Height 5: Multiply loop: No odd multiplies + "ldr x20, [%x[gp], %[offsetof_ldc]]\n" + "prfm pstl1keep, [x26, #0x0]\n" + "str q22, [x26, #0x0]\n" + "str q23, [x26, #0x10]\n" + "add x23, x26, x20\n" + "add x26, x26, #0x20\n" + "prfm pstl1keep, [x23, #0x0]\n" + "str q24, [x23, #0x0]\n" + "add x22, x23, x20\n" + "add x21, x22, x20\n" + "add x20, x21, x20\n" + "prfm pstl1keep, [x22, #0x0]\n" + "prfm pstl1keep, [x21, #0x0]\n" + "str q25, [x23, #0x10]\n" + "prfm pstl1keep, [x20, #0x0]\n" + "str q26, [x22, #0x0]\n" + "str q27, [x22, #0x10]\n" + "str q28, [x21, #0x0]\n" + "str q29, [x21, #0x10]\n" + "str q30, [x20, #0x0]\n" + "str q31, [x20, #0x10]\n" + "subs x28, x28, #0x1\n" + "bgt 1b\n" + : + : [gp] "r"(gp), + [offsetof_A] "I"(offsetof(GemmParamsFP16, A)), + [offsetof_B] "I"(offsetof(GemmParamsFP16, B)), + [offsetof_C] "I"(offsetof(GemmParamsFP16, C)), + [offsetof_b_block_cols] "I"(offsetof(GemmParamsFP16, b_block_cols)), + [offsetof_beta] "I"(offsetof(GemmParamsFP16, beta)), + [offsetof_k] "I"(offsetof(GemmParamsFP16, k)), + [offsetof_lda] "I"(offsetof(GemmParamsFP16, lda)), + [offsetof_ldc] "I"(offsetof(GemmParamsFP16, ldc)) + : "cc", + "memory", + "v0", + "v1", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v2", + "v21", + "v22", + "v23", + "v24", + "v25", + "v26", + "v27", + "v28", + "v29", + "v3", + "v30", + "v31", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "x20", + "x21", + "x22", + "x23", + "x24", + "x25", + "x26", + "x27", + "x28", + "x9"); +#endif // __aarch64__ +} + +void NOINLINE gemmkernel_6x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp) { +#ifdef __aarch64__ + __asm__ __volatile__( + "ldr w20, [%x[gp], %[offsetof_beta]]\n" + "mov x10, #0x1\n" + "fmov v19.8h, #1.0\n" + "ldr x9, [%x[gp], %[offsetof_b_block_cols]]\n" + "ldr x28, [%x[gp], %[offsetof_B]]\n" + "ldr x27, [%x[gp], %[offsetof_C]]\n" + "bic x20, x20, #0x80000000\n" + "cmp x20, #0x0\n" + "csel x10, XZR, x10, EQ\n" + "1:" // Height 6: Column loop + "tbz x10, #0, 2f\n" + "ldr q20, [x27, #0x0]\n" + "ldr q21, [x27, #0x10]\n" + "add x20, %x[gp], %[offsetof_beta]\n" + "ld1r { v16.4s }, [x20]\n" + "ldr x21, [%x[gp], %[offsetof_ldc]]\n" + "add x20, x27, x21\n" + "ldr q22, [x20, #0x0]\n" + "ldr q23, [x20, #0x10]\n" + "add x20, x20, x21\n" + "ldr q24, [x20, #0x0]\n" + "ldr q25, [x20, #0x10]\n" + "add x20, x20, x21\n" + "ldr q26, [x20, #0x0]\n" + "ldr q27, [x20, #0x10]\n" + "add x20, x20, x21\n" + "fmul v20.4s, v20.4s, v16.4s\n" + "ldr q28, [x20, #0x0]\n" + "ldr q29, [x20, #0x10]\n" + "add x20, x20, x21\n" + "fmul v21.4s, v21.4s, v16.4s\n" + "ldr q30, [x20, #0x0]\n" + "ldr q31, [x20, #0x10]\n" + "fmul v22.4s, v22.4s, v16.4s\n" + "fmul v23.4s, v23.4s, v16.4s\n" + "fmul v24.4s, v24.4s, v16.4s\n" + "fmul v25.4s, v25.4s, v16.4s\n" + "fmul v26.4s, v26.4s, v16.4s\n" + "fmul v27.4s, v27.4s, v16.4s\n" + "fmul v28.4s, v28.4s, v16.4s\n" + "fmul v29.4s, v29.4s, v16.4s\n" + "fmul v30.4s, v30.4s, v16.4s\n" + "fmul v31.4s, v31.4s, v16.4s\n" + "b 3f\n" + "2:" // Height 6: no accumulate + "movi v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v26.16b, #0x0\n" + "movi v27.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "movi v29.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "movi v31.16b, #0x0\n" + "3:" // Height 6: setup done + "ldr x21, [%x[gp], %[offsetof_A]]\n" + "ldr x20, [%x[gp], %[offsetof_lda]]\n" + "ldr x26, [%x[gp], %[offsetof_k]]\n" + "mov x25, x21\n" + "add x24, x25, x20\n" + "add x23, x24, x20\n" + "add x22, x23, x20\n" + "add x21, x22, x20\n" + "add x20, x21, x20\n" + "cmp x26, #0x4\n" + "blt 7f\n" + "ldr q0, [x25, #0x0]\n" + "ldr q6, [x28, #0x0]\n" + "cmp x26, #0x8\n" + "ldr q1, [x24, #0x0]\n" + "ldr q2, [x23, #0x0]\n" + "ldr q3, [x22, #0x0]\n" + "ldr q4, [x21, #0x0]\n" + "ldr q5, [x20, #0x0]\n" + "ldr q9, [x28, #0x10]\n" + "ldr q12, [x28, #0x20]\n" + "ldr q15, [x28, #0x30]\n" + "blt 6f\n" + "5:" // Height 6: Multiply loop: Main loop head + "movi v7.16b, #0x0\n" + "movi v8.16b, #0x0\n" + "sub x26, x26, #0x4\n" + "add x25, x25, #0x10\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "add x24, x24, #0x10\n" + "add x23, x23, #0x10\n" + "fmlal v7.4s, v6.4h, v19.4h\n" + "fmlal2 v8.4s, v6.4h, v19.4h\n" + "movi v13.16b, #0x0\n" + "add x22, x22, #0x10\n" + "fmlal v10.4s, v9.4h, v19.4h\n" + "fmlal2 v11.4s, v9.4h, v19.4h\n" + "movi v14.16b, #0x0\n" + "add x21, x21, #0x10\n" + "fmlal v13.4s, v12.4h, v19.4h\n" + "movi v16.16b, #0x0\n" + "add x20, x20, #0x10\n" + "cmp x26, #0x8\n" + "fmlal2 v14.4s, v12.4h, v19.4h\n" + "movi v17.16b, #0x0\n" + "add x28, x28, #0x40\n" + "prfm pldl1keep, [x25, #0x80]\n" + "ldr q6, [x28, #0x0]\n" + "ldr q9, [x28, #0x10]\n" + "fmla v20.4s, v7.4s, v0.s[0]\n" + "fmla v22.4s, v7.4s, v1.s[0]\n" + "ldr q12, [x28, #0x20]\n" + "fmla v24.4s, v7.4s, v2.s[0]\n" + "fmla v26.4s, v7.4s, v3.s[0]\n" + "prfm pldl1keep, [x24, #0x80]\n" + "fmla v28.4s, v7.4s, v4.s[0]\n" + "fmla v30.4s, v7.4s, v5.s[0]\n" + "prfm pldl1keep, [x23, #0x80]\n" + "prfm pldl1keep, [x22, #0x80]\n" + "fmla v21.4s, v8.4s, v0.s[0]\n" + "fmla v23.4s, v8.4s, v1.s[0]\n" + "prfm pldl1keep, [x21, #0x80]\n" + "prfm pldl1keep, [x20, #0x80]\n" + "fmla v25.4s, v8.4s, v2.s[0]\n" + "fmla v27.4s, v8.4s, v3.s[0]\n" + "fmla v29.4s, v8.4s, v4.s[0]\n" + "fmla v31.4s, v8.4s, v5.s[0]\n" + "fmla v20.4s, v10.4s, v0.s[1]\n" + "fmla v22.4s, v10.4s, v1.s[1]\n" + "fmla v24.4s, v10.4s, v2.s[1]\n" + "fmla v26.4s, v10.4s, v3.s[1]\n" + "fmla v28.4s, v10.4s, v4.s[1]\n" + "fmla v30.4s, v10.4s, v5.s[1]\n" + "fmla v21.4s, v11.4s, v0.s[1]\n" + "fmla v23.4s, v11.4s, v1.s[1]\n" + "fmla v25.4s, v11.4s, v2.s[1]\n" + "fmla v27.4s, v11.4s, v3.s[1]\n" + "fmla v29.4s, v11.4s, v4.s[1]\n" + "fmla v31.4s, v11.4s, v5.s[1]\n" + "fmla v20.4s, v13.4s, v0.s[2]\n" + "fmla v22.4s, v13.4s, v1.s[2]\n" + "fmla v24.4s, v13.4s, v2.s[2]\n" + "fmla v26.4s, v13.4s, v3.s[2]\n" + "fmla v28.4s, v13.4s, v4.s[2]\n" + "fmla v30.4s, v13.4s, v5.s[2]\n" + "fmla v21.4s, v14.4s, v0.s[2]\n" + "fmla v23.4s, v14.4s, v1.s[2]\n" + "fmla v25.4s, v14.4s, v2.s[2]\n" + "fmla v27.4s, v14.4s, v3.s[2]\n" + "fmla v29.4s, v14.4s, v4.s[2]\n" + "fmla v31.4s, v14.4s, v5.s[2]\n" + "fmlal v16.4s, v15.4h, v19.4h\n" + "fmlal2 v17.4s, v15.4h, v19.4h\n" + "ldr q15, [x28, #0x30]\n" + "fmla v20.4s, v16.4s, v0.s[3]\n" + "fmla v22.4s, v16.4s, v1.s[3]\n" + "fmla v24.4s, v16.4s, v2.s[3]\n" + "fmla v26.4s, v16.4s, v3.s[3]\n" + "fmla v28.4s, v16.4s, v4.s[3]\n" + "fmla v30.4s, v16.4s, v5.s[3]\n" + "fmla v21.4s, v17.4s, v0.s[3]\n" + "ldr q0, [x25, #0x0]\n" + "fmla v23.4s, v17.4s, v1.s[3]\n" + "ldr q1, [x24, #0x0]\n" + "fmla v25.4s, v17.4s, v2.s[3]\n" + "ldr q2, [x23, #0x0]\n" + "fmla v27.4s, v17.4s, v3.s[3]\n" + "ldr q3, [x22, #0x0]\n" + "fmla v29.4s, v17.4s, v4.s[3]\n" + "ldr q4, [x21, #0x0]\n" + "fmla v31.4s, v17.4s, v5.s[3]\n" + "ldr q5, [x20, #0x0]\n" + "bge 5b\n" + "6:" // Height 6: Multiply loop: Single iteration only + "movi v7.16b, #0x0\n" + "movi v8.16b, #0x0\n" + "add x25, x25, #0x10\n" + "add x24, x24, #0x10\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "add x23, x23, #0x10\n" + "add x22, x22, #0x10\n" + "fmlal v7.4s, v6.4h, v19.4h\n" + "fmlal2 v8.4s, v6.4h, v19.4h\n" + "movi v13.16b, #0x0\n" + "add x21, x21, #0x10\n" + "fmlal v10.4s, v9.4h, v19.4h\n" + "fmlal2 v11.4s, v9.4h, v19.4h\n" + "movi v14.16b, #0x0\n" + "add x20, x20, #0x10\n" + "fmlal v13.4s, v12.4h, v19.4h\n" + "movi v16.16b, #0x0\n" + "prfm pldl1keep, [x25, #0x80]\n" + "sub x26, x26, #0x4\n" + "fmlal2 v14.4s, v12.4h, v19.4h\n" + "movi v17.16b, #0x0\n" + "prfm pldl1keep, [x24, #0x80]\n" + "prfm pldl1keep, [x23, #0x80]\n" + "fmla v20.4s, v7.4s, v0.s[0]\n" + "fmla v22.4s, v7.4s, v1.s[0]\n" + "prfm pldl1keep, [x22, #0x80]\n" + "prfm pldl1keep, [x21, #0x80]\n" + "fmla v24.4s, v7.4s, v2.s[0]\n" + "fmla v26.4s, v7.4s, v3.s[0]\n" + "prfm pldl1keep, [x20, #0x80]\n" + "add x28, x28, #0x40\n" + "fmla v28.4s, v7.4s, v4.s[0]\n" + "fmla v30.4s, v7.4s, v5.s[0]\n" + "fmla v21.4s, v8.4s, v0.s[0]\n" + "fmla v23.4s, v8.4s, v1.s[0]\n" + "fmla v25.4s, v8.4s, v2.s[0]\n" + "fmla v27.4s, v8.4s, v3.s[0]\n" + "fmla v29.4s, v8.4s, v4.s[0]\n" + "fmla v31.4s, v8.4s, v5.s[0]\n" + "fmla v20.4s, v10.4s, v0.s[1]\n" + "fmla v22.4s, v10.4s, v1.s[1]\n" + "fmla v24.4s, v10.4s, v2.s[1]\n" + "fmla v26.4s, v10.4s, v3.s[1]\n" + "fmla v28.4s, v10.4s, v4.s[1]\n" + "fmla v30.4s, v10.4s, v5.s[1]\n" + "fmla v21.4s, v11.4s, v0.s[1]\n" + "fmla v23.4s, v11.4s, v1.s[1]\n" + "fmla v25.4s, v11.4s, v2.s[1]\n" + "fmla v27.4s, v11.4s, v3.s[1]\n" + "fmla v29.4s, v11.4s, v4.s[1]\n" + "fmla v31.4s, v11.4s, v5.s[1]\n" + "fmla v20.4s, v13.4s, v0.s[2]\n" + "fmla v22.4s, v13.4s, v1.s[2]\n" + "fmla v24.4s, v13.4s, v2.s[2]\n" + "fmla v26.4s, v13.4s, v3.s[2]\n" + "fmla v28.4s, v13.4s, v4.s[2]\n" + "fmla v30.4s, v13.4s, v5.s[2]\n" + "fmlal v16.4s, v15.4h, v19.4h\n" + "fmla v21.4s, v14.4s, v0.s[2]\n" + "fmla v23.4s, v14.4s, v1.s[2]\n" + "fmla v25.4s, v14.4s, v2.s[2]\n" + "fmla v27.4s, v14.4s, v3.s[2]\n" + "fmla v29.4s, v14.4s, v4.s[2]\n" + "fmla v31.4s, v14.4s, v5.s[2]\n" + "fmlal2 v17.4s, v15.4h, v19.4h\n" + "fmla v20.4s, v16.4s, v0.s[3]\n" + "fmla v22.4s, v16.4s, v1.s[3]\n" + "fmla v24.4s, v16.4s, v2.s[3]\n" + "fmla v26.4s, v16.4s, v3.s[3]\n" + "fmla v28.4s, v16.4s, v4.s[3]\n" + "fmla v30.4s, v16.4s, v5.s[3]\n" + "fmla v21.4s, v17.4s, v0.s[3]\n" + "fmla v23.4s, v17.4s, v1.s[3]\n" + "fmla v25.4s, v17.4s, v2.s[3]\n" + "fmla v27.4s, v17.4s, v3.s[3]\n" + "fmla v29.4s, v17.4s, v4.s[3]\n" + "fmla v31.4s, v17.4s, v5.s[3]\n" + "7:" // Height 6: Multiply loop: Main loop skip + "cbz x26, 9f\n" + "8:" // Height 6: Multiply loop: Odd block loop + "ldr q18, [x28, #0x0]\n" + "ldr s0, [x25], #0x4\n" + "movi v6.16b, #0x0\n" + "movi v7.16b, #0x0\n" + "ldr s1, [x24], #0x4\n" + "ldr s2, [x23], #0x4\n" + "sub x26, x26, #0x1\n" + "add x28, x28, #0x10\n" + "ldr s3, [x22], #0x4\n" + "ldr s4, [x21], #0x4\n" + "ldr s5, [x20], #0x4\n" + "fmlal v6.4s, v18.4h, v19.4h\n" + "fmlal2 v7.4s, v18.4h, v19.4h\n" + "fmla v20.4s, v6.4s, v0.s[0]\n" + "fmla v22.4s, v6.4s, v1.s[0]\n" + "fmla v24.4s, v6.4s, v2.s[0]\n" + "fmla v26.4s, v6.4s, v3.s[0]\n" + "fmla v28.4s, v6.4s, v4.s[0]\n" + "fmla v30.4s, v6.4s, v5.s[0]\n" + "fmla v21.4s, v7.4s, v0.s[0]\n" + "fmla v23.4s, v7.4s, v1.s[0]\n" + "fmla v25.4s, v7.4s, v2.s[0]\n" + "fmla v27.4s, v7.4s, v3.s[0]\n" + "fmla v29.4s, v7.4s, v4.s[0]\n" + "fmla v31.4s, v7.4s, v5.s[0]\n" + "cbnz x26, 8b\n" + "9:" // Height 6: Multiply loop: No odd multiplies + "ldr x24, [%x[gp], %[offsetof_ldc]]\n" + "prfm pstl1keep, [x27, #0x0]\n" + "str q20, [x27, #0x0]\n" + "str q21, [x27, #0x10]\n" + "add x20, x27, x24\n" + "add x27, x27, #0x20\n" + "prfm pstl1keep, [x20, #0x0]\n" + "str q22, [x20, #0x0]\n" + "add x23, x20, x24\n" + "add x22, x23, x24\n" + "add x21, x22, x24\n" + "prfm pstl1keep, [x23, #0x0]\n" + "prfm pstl1keep, [x22, #0x0]\n" + "str q23, [x20, #0x10]\n" + "add x20, x21, x24\n" + "prfm pstl1keep, [x21, #0x0]\n" + "str q24, [x23, #0x0]\n" + "prfm pstl1keep, [x20, #0x0]\n" + "str q25, [x23, #0x10]\n" + "str q26, [x22, #0x0]\n" + "str q27, [x22, #0x10]\n" + "str q28, [x21, #0x0]\n" + "str q29, [x21, #0x10]\n" + "str q30, [x20, #0x0]\n" + "str q31, [x20, #0x10]\n" + "subs x9, x9, #0x1\n" + "bgt 1b\n" + : + : [gp] "r"(gp), + [offsetof_A] "I"(offsetof(GemmParamsFP16, A)), + [offsetof_B] "I"(offsetof(GemmParamsFP16, B)), + [offsetof_C] "I"(offsetof(GemmParamsFP16, C)), + [offsetof_b_block_cols] "I"(offsetof(GemmParamsFP16, b_block_cols)), + [offsetof_beta] "I"(offsetof(GemmParamsFP16, beta)), + [offsetof_k] "I"(offsetof(GemmParamsFP16, k)), + [offsetof_lda] "I"(offsetof(GemmParamsFP16, lda)), + [offsetof_ldc] "I"(offsetof(GemmParamsFP16, ldc)) + : "cc", + "memory", + "v0", + "v1", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v2", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25", + "v26", + "v27", + "v28", + "v29", + "v3", + "v30", + "v31", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "x10", + "x20", + "x21", + "x22", + "x23", + "x24", + "x25", + "x26", + "x27", + "x28", + "x9"); +#endif // __aarch64__ +} + +void NOINLINE gemmkernel_7x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp) { +#ifdef __aarch64__ + __asm__ __volatile__( + "ldr w20, [%x[gp], %[offsetof_beta]]\n" + "mov x11, #0x1\n" + "fmov v17.8h, #1.0\n" + "ldr x10, [%x[gp], %[offsetof_b_block_cols]]\n" + "ldr x9, [%x[gp], %[offsetof_B]]\n" + "ldr x28, [%x[gp], %[offsetof_C]]\n" + "bic x20, x20, #0x80000000\n" + "cmp x20, #0x0\n" + "csel x11, XZR, x11, EQ\n" + "1:" // Height 7: Column loop + "tbz x11, #0, 2f\n" + "ldr q18, [x28, #0x0]\n" + "ldr q19, [x28, #0x10]\n" + "add x20, %x[gp], %[offsetof_beta]\n" + "ld1r { v16.4s }, [x20]\n" + "ldr x21, [%x[gp], %[offsetof_ldc]]\n" + "add x20, x28, x21\n" + "ldr q20, [x20, #0x0]\n" + "ldr q21, [x20, #0x10]\n" + "add x20, x20, x21\n" + "ldr q22, [x20, #0x0]\n" + "ldr q23, [x20, #0x10]\n" + "add x20, x20, x21\n" + "ldr q24, [x20, #0x0]\n" + "ldr q25, [x20, #0x10]\n" + "add x20, x20, x21\n" + "fmul v18.4s, v18.4s, v16.4s\n" + "ldr q26, [x20, #0x0]\n" + "ldr q27, [x20, #0x10]\n" + "add x20, x20, x21\n" + "fmul v19.4s, v19.4s, v16.4s\n" + "ldr q28, [x20, #0x0]\n" + "ldr q29, [x20, #0x10]\n" + "add x20, x20, x21\n" + "fmul v20.4s, v20.4s, v16.4s\n" + "ldr q30, [x20, #0x0]\n" + "ldr q31, [x20, #0x10]\n" + "fmul v21.4s, v21.4s, v16.4s\n" + "fmul v22.4s, v22.4s, v16.4s\n" + "fmul v23.4s, v23.4s, v16.4s\n" + "fmul v24.4s, v24.4s, v16.4s\n" + "fmul v25.4s, v25.4s, v16.4s\n" + "fmul v26.4s, v26.4s, v16.4s\n" + "fmul v27.4s, v27.4s, v16.4s\n" + "fmul v28.4s, v28.4s, v16.4s\n" + "fmul v29.4s, v29.4s, v16.4s\n" + "fmul v30.4s, v30.4s, v16.4s\n" + "fmul v31.4s, v31.4s, v16.4s\n" + "b 3f\n" + "2:" // Height 7: no accumulate + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "movi v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v26.16b, #0x0\n" + "movi v27.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "movi v29.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "movi v31.16b, #0x0\n" + "3:" // Height 7: setup done + "ldr x21, [%x[gp], %[offsetof_A]]\n" + "ldr x20, [%x[gp], %[offsetof_lda]]\n" + "ldr x27, [%x[gp], %[offsetof_k]]\n" + "mov x26, x21\n" + "add x25, x26, x20\n" + "add x24, x25, x20\n" + "add x23, x24, x20\n" + "add x22, x23, x20\n" + "add x21, x22, x20\n" + "add x20, x21, x20\n" + "cmp x27, #0x4\n" + "blt 7f\n" + "ldr q0, [x26, #0x0]\n" + "ldr q7, [x9, #0x0]\n" + "cmp x27, #0x8\n" + "ldr q1, [x25, #0x0]\n" + "ldr q2, [x24, #0x0]\n" + "ldr q3, [x23, #0x0]\n" + "ldr q4, [x22, #0x0]\n" + "ldr q5, [x21, #0x0]\n" + "ldr q6, [x20, #0x0]\n" + "ldr q10, [x9, #0x10]\n" + "ldr q13, [x9, #0x20]\n" + "ldr q16, [x9, #0x30]\n" + "blt 6f\n" + "5:" // Height 7: Multiply loop: Main loop head + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "sub x27, x27, #0x4\n" + "add x26, x26, #0x10\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "add x25, x25, #0x10\n" + "add x24, x24, #0x10\n" + "fmlal v8.4s, v7.4h, v17.4h\n" + "fmlal2 v9.4s, v7.4h, v17.4h\n" + "movi v14.16b, #0x0\n" + "add x23, x23, #0x10\n" + "fmlal v11.4s, v10.4h, v17.4h\n" + "fmlal2 v12.4s, v10.4h, v17.4h\n" + "movi v15.16b, #0x0\n" + "add x22, x22, #0x10\n" + "fmlal v14.4s, v13.4h, v17.4h\n" + "movi v7.16b, #0x0\n" + "add x21, x21, #0x10\n" + "add x20, x20, #0x10\n" + "fmlal2 v15.4s, v13.4h, v17.4h\n" + "cmp x27, #0x8\n" + "add x9, x9, #0x40\n" + "prfm pldl1keep, [x26, #0x80]\n" + "ldr q10, [x9, #0x10]\n" + "ldr q13, [x9, #0x20]\n" + "fmla v18.4s, v8.4s, v0.s[0]\n" + "fmla v20.4s, v8.4s, v1.s[0]\n" + "fmla v22.4s, v8.4s, v2.s[0]\n" + "fmla v24.4s, v8.4s, v3.s[0]\n" + "prfm pldl1keep, [x25, #0x80]\n" + "prfm pldl1keep, [x24, #0x80]\n" + "fmla v26.4s, v8.4s, v4.s[0]\n" + "fmla v28.4s, v8.4s, v5.s[0]\n" + "prfm pldl1keep, [x23, #0x80]\n" + "prfm pldl1keep, [x22, #0x80]\n" + "fmla v30.4s, v8.4s, v6.s[0]\n" + "fmla v19.4s, v9.4s, v0.s[0]\n" + "movi v8.16b, #0x0\n" + "prfm pldl1keep, [x21, #0x80]\n" + "fmla v21.4s, v9.4s, v1.s[0]\n" + "fmla v23.4s, v9.4s, v2.s[0]\n" + "prfm pldl1keep, [x20, #0x80]\n" + "fmla v25.4s, v9.4s, v3.s[0]\n" + "fmla v27.4s, v9.4s, v4.s[0]\n" + "fmla v29.4s, v9.4s, v5.s[0]\n" + "fmla v31.4s, v9.4s, v6.s[0]\n" + "fmla v18.4s, v11.4s, v0.s[1]\n" + "fmla v20.4s, v11.4s, v1.s[1]\n" + "fmla v22.4s, v11.4s, v2.s[1]\n" + "fmla v24.4s, v11.4s, v3.s[1]\n" + "fmla v26.4s, v11.4s, v4.s[1]\n" + "fmla v28.4s, v11.4s, v5.s[1]\n" + "fmla v30.4s, v11.4s, v6.s[1]\n" + "fmla v19.4s, v12.4s, v0.s[1]\n" + "fmla v21.4s, v12.4s, v1.s[1]\n" + "fmla v23.4s, v12.4s, v2.s[1]\n" + "fmla v25.4s, v12.4s, v3.s[1]\n" + "fmla v27.4s, v12.4s, v4.s[1]\n" + "fmla v29.4s, v12.4s, v5.s[1]\n" + "fmla v31.4s, v12.4s, v6.s[1]\n" + "fmlal v7.4s, v16.4h, v17.4h\n" + "fmlal2 v8.4s, v16.4h, v17.4h\n" + "ldr q16, [x9, #0x30]\n" + "fmla v18.4s, v14.4s, v0.s[2]\n" + "fmla v20.4s, v14.4s, v1.s[2]\n" + "fmla v22.4s, v14.4s, v2.s[2]\n" + "fmla v24.4s, v14.4s, v3.s[2]\n" + "fmla v26.4s, v14.4s, v4.s[2]\n" + "fmla v28.4s, v14.4s, v5.s[2]\n" + "fmla v30.4s, v14.4s, v6.s[2]\n" + "fmla v19.4s, v15.4s, v0.s[2]\n" + "fmla v21.4s, v15.4s, v1.s[2]\n" + "fmla v23.4s, v15.4s, v2.s[2]\n" + "fmla v25.4s, v15.4s, v3.s[2]\n" + "fmla v27.4s, v15.4s, v4.s[2]\n" + "fmla v29.4s, v15.4s, v5.s[2]\n" + "fmla v31.4s, v15.4s, v6.s[2]\n" + "fmla v18.4s, v7.4s, v0.s[3]\n" + "fmla v20.4s, v7.4s, v1.s[3]\n" + "fmla v22.4s, v7.4s, v2.s[3]\n" + "fmla v24.4s, v7.4s, v3.s[3]\n" + "fmla v26.4s, v7.4s, v4.s[3]\n" + "fmla v28.4s, v7.4s, v5.s[3]\n" + "fmla v30.4s, v7.4s, v6.s[3]\n" + "ldr q7, [x9, #0x0]\n" + "fmla v19.4s, v8.4s, v0.s[3]\n" + "ldr q0, [x26, #0x0]\n" + "fmla v21.4s, v8.4s, v1.s[3]\n" + "ldr q1, [x25, #0x0]\n" + "fmla v23.4s, v8.4s, v2.s[3]\n" + "ldr q2, [x24, #0x0]\n" + "fmla v25.4s, v8.4s, v3.s[3]\n" + "ldr q3, [x23, #0x0]\n" + "fmla v27.4s, v8.4s, v4.s[3]\n" + "ldr q4, [x22, #0x0]\n" + "fmla v29.4s, v8.4s, v5.s[3]\n" + "ldr q5, [x21, #0x0]\n" + "fmla v31.4s, v8.4s, v6.s[3]\n" + "ldr q6, [x20, #0x0]\n" + "bge 5b\n" + "6:" // Height 7: Multiply loop: Single iteration only + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "add x26, x26, #0x10\n" + "add x25, x25, #0x10\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "add x24, x24, #0x10\n" + "add x23, x23, #0x10\n" + "fmlal v8.4s, v7.4h, v17.4h\n" + "fmlal2 v9.4s, v7.4h, v17.4h\n" + "movi v14.16b, #0x0\n" + "add x22, x22, #0x10\n" + "fmlal v11.4s, v10.4h, v17.4h\n" + "fmlal2 v12.4s, v10.4h, v17.4h\n" + "movi v15.16b, #0x0\n" + "add x21, x21, #0x10\n" + "fmlal v14.4s, v13.4h, v17.4h\n" + "movi v7.16b, #0x0\n" + "add x20, x20, #0x10\n" + "prfm pldl1keep, [x26, #0x80]\n" + "fmlal2 v15.4s, v13.4h, v17.4h\n" + "prfm pldl1keep, [x25, #0x80]\n" + "sub x27, x27, #0x4\n" + "prfm pldl1keep, [x24, #0x80]\n" + "fmla v18.4s, v8.4s, v0.s[0]\n" + "fmla v20.4s, v8.4s, v1.s[0]\n" + "prfm pldl1keep, [x23, #0x80]\n" + "prfm pldl1keep, [x22, #0x80]\n" + "fmla v22.4s, v8.4s, v2.s[0]\n" + "fmla v24.4s, v8.4s, v3.s[0]\n" + "prfm pldl1keep, [x21, #0x80]\n" + "prfm pldl1keep, [x20, #0x80]\n" + "fmla v26.4s, v8.4s, v4.s[0]\n" + "fmla v28.4s, v8.4s, v5.s[0]\n" + "add x9, x9, #0x40\n" + "fmla v30.4s, v8.4s, v6.s[0]\n" + "fmla v19.4s, v9.4s, v0.s[0]\n" + "movi v8.16b, #0x0\n" + "fmla v21.4s, v9.4s, v1.s[0]\n" + "fmla v23.4s, v9.4s, v2.s[0]\n" + "fmla v25.4s, v9.4s, v3.s[0]\n" + "fmla v27.4s, v9.4s, v4.s[0]\n" + "fmla v29.4s, v9.4s, v5.s[0]\n" + "fmla v31.4s, v9.4s, v6.s[0]\n" + "fmla v18.4s, v11.4s, v0.s[1]\n" + "fmla v20.4s, v11.4s, v1.s[1]\n" + "fmla v22.4s, v11.4s, v2.s[1]\n" + "fmla v24.4s, v11.4s, v3.s[1]\n" + "fmla v26.4s, v11.4s, v4.s[1]\n" + "fmla v28.4s, v11.4s, v5.s[1]\n" + "fmla v30.4s, v11.4s, v6.s[1]\n" + "fmla v19.4s, v12.4s, v0.s[1]\n" + "fmla v21.4s, v12.4s, v1.s[1]\n" + "fmla v23.4s, v12.4s, v2.s[1]\n" + "fmla v25.4s, v12.4s, v3.s[1]\n" + "fmla v27.4s, v12.4s, v4.s[1]\n" + "fmla v29.4s, v12.4s, v5.s[1]\n" + "fmla v31.4s, v12.4s, v6.s[1]\n" + "fmlal v7.4s, v16.4h, v17.4h\n" + "fmlal2 v8.4s, v16.4h, v17.4h\n" + "fmla v18.4s, v14.4s, v0.s[2]\n" + "fmla v20.4s, v14.4s, v1.s[2]\n" + "fmla v22.4s, v14.4s, v2.s[2]\n" + "fmla v24.4s, v14.4s, v3.s[2]\n" + "fmla v26.4s, v14.4s, v4.s[2]\n" + "fmla v28.4s, v14.4s, v5.s[2]\n" + "fmla v30.4s, v14.4s, v6.s[2]\n" + "fmla v19.4s, v15.4s, v0.s[2]\n" + "fmla v21.4s, v15.4s, v1.s[2]\n" + "fmla v23.4s, v15.4s, v2.s[2]\n" + "fmla v25.4s, v15.4s, v3.s[2]\n" + "fmla v27.4s, v15.4s, v4.s[2]\n" + "fmla v29.4s, v15.4s, v5.s[2]\n" + "fmla v31.4s, v15.4s, v6.s[2]\n" + "fmla v18.4s, v7.4s, v0.s[3]\n" + "fmla v20.4s, v7.4s, v1.s[3]\n" + "fmla v22.4s, v7.4s, v2.s[3]\n" + "fmla v24.4s, v7.4s, v3.s[3]\n" + "fmla v26.4s, v7.4s, v4.s[3]\n" + "fmla v28.4s, v7.4s, v5.s[3]\n" + "fmla v30.4s, v7.4s, v6.s[3]\n" + "fmla v19.4s, v8.4s, v0.s[3]\n" + "fmla v21.4s, v8.4s, v1.s[3]\n" + "fmla v23.4s, v8.4s, v2.s[3]\n" + "fmla v25.4s, v8.4s, v3.s[3]\n" + "fmla v27.4s, v8.4s, v4.s[3]\n" + "fmla v29.4s, v8.4s, v5.s[3]\n" + "fmla v31.4s, v8.4s, v6.s[3]\n" + "7:" // Height 7: Multiply loop: Main loop skip + "cbz x27, 9f\n" + "8:" // Height 7: Multiply loop: Odd block loop + "ldr q9, [x9, #0x0]\n" + "ldr s0, [x26], #0x4\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "ldr s1, [x25], #0x4\n" + "ldr s2, [x24], #0x4\n" + "sub x27, x27, #0x1\n" + "add x9, x9, #0x10\n" + "ldr s3, [x23], #0x4\n" + "ldr s4, [x22], #0x4\n" + "ldr s5, [x21], #0x4\n" + "ldr s6, [x20], #0x4\n" + "fmlal v10.4s, v9.4h, v17.4h\n" + "fmlal2 v11.4s, v9.4h, v17.4h\n" + "fmla v18.4s, v10.4s, v0.s[0]\n" + "fmla v20.4s, v10.4s, v1.s[0]\n" + "fmla v22.4s, v10.4s, v2.s[0]\n" + "fmla v24.4s, v10.4s, v3.s[0]\n" + "fmla v26.4s, v10.4s, v4.s[0]\n" + "fmla v28.4s, v10.4s, v5.s[0]\n" + "fmla v30.4s, v10.4s, v6.s[0]\n" + "fmla v19.4s, v11.4s, v0.s[0]\n" + "fmla v21.4s, v11.4s, v1.s[0]\n" + "fmla v23.4s, v11.4s, v2.s[0]\n" + "fmla v25.4s, v11.4s, v3.s[0]\n" + "fmla v27.4s, v11.4s, v4.s[0]\n" + "fmla v29.4s, v11.4s, v5.s[0]\n" + "fmla v31.4s, v11.4s, v6.s[0]\n" + "cbnz x27, 8b\n" + "9:" // Height 7: Multiply loop: No odd multiplies + "ldr x25, [%x[gp], %[offsetof_ldc]]\n" + "prfm pstl1keep, [x28, #0x0]\n" + "str q18, [x28, #0x0]\n" + "str q19, [x28, #0x10]\n" + "add x20, x28, x25\n" + "add x28, x28, #0x20\n" + "prfm pstl1keep, [x20, #0x0]\n" + "str q20, [x20, #0x0]\n" + "add x24, x20, x25\n" + "add x23, x24, x25\n" + "add x22, x23, x25\n" + "prfm pstl1keep, [x24, #0x0]\n" + "prfm pstl1keep, [x23, #0x0]\n" + "str q21, [x20, #0x10]\n" + "add x21, x22, x25\n" + "prfm pstl1keep, [x22, #0x0]\n" + "str q22, [x24, #0x0]\n" + "add x20, x21, x25\n" + "prfm pstl1keep, [x21, #0x0]\n" + "prfm pstl1keep, [x20, #0x0]\n" + "str q23, [x24, #0x10]\n" + "str q24, [x23, #0x0]\n" + "str q25, [x23, #0x10]\n" + "str q26, [x22, #0x0]\n" + "str q27, [x22, #0x10]\n" + "str q28, [x21, #0x0]\n" + "str q29, [x21, #0x10]\n" + "str q30, [x20, #0x0]\n" + "str q31, [x20, #0x10]\n" + "subs x10, x10, #0x1\n" + "bgt 1b\n" + : + : [gp] "r"(gp), + [offsetof_A] "I"(offsetof(GemmParamsFP16, A)), + [offsetof_B] "I"(offsetof(GemmParamsFP16, B)), + [offsetof_C] "I"(offsetof(GemmParamsFP16, C)), + [offsetof_b_block_cols] "I"(offsetof(GemmParamsFP16, b_block_cols)), + [offsetof_beta] "I"(offsetof(GemmParamsFP16, beta)), + [offsetof_k] "I"(offsetof(GemmParamsFP16, k)), + [offsetof_lda] "I"(offsetof(GemmParamsFP16, lda)), + [offsetof_ldc] "I"(offsetof(GemmParamsFP16, ldc)) + : "cc", + "memory", + "v0", + "v1", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v2", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25", + "v26", + "v27", + "v28", + "v29", + "v3", + "v30", + "v31", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "x10", + "x11", + "x20", + "x21", + "x22", + "x23", + "x24", + "x25", + "x26", + "x27", + "x28", + "x9"); +#endif // __aarch64__ +} + +void NOINLINE gemmkernel_8x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp) { +#ifdef __aarch64__ + __asm__ __volatile__( + "ldr w20, [%x[gp], %[offsetof_beta]]\n" + "mov x12, #0x1\n" + "fmov v15.8h, #1.0\n" + "ldr x11, [%x[gp], %[offsetof_b_block_cols]]\n" + "ldr x10, [%x[gp], %[offsetof_B]]\n" + "ldr x9, [%x[gp], %[offsetof_C]]\n" + "bic x20, x20, #0x80000000\n" + "cmp x20, #0x0\n" + "csel x12, XZR, x12, EQ\n" + "1:" // Height 8: Column loop + "tbz x12, #0, 2f\n" + "ldr q16, [x9, #0x0]\n" + "ldr q17, [x9, #0x10]\n" + "add x20, %x[gp], %[offsetof_beta]\n" + "ld1r { v0.4s }, [x20]\n" + "ldr x21, [%x[gp], %[offsetof_ldc]]\n" + "add x20, x9, x21\n" + "ldr q18, [x20, #0x0]\n" + "ldr q19, [x20, #0x10]\n" + "add x20, x20, x21\n" + "ldr q20, [x20, #0x0]\n" + "ldr q21, [x20, #0x10]\n" + "add x20, x20, x21\n" + "ldr q22, [x20, #0x0]\n" + "ldr q23, [x20, #0x10]\n" + "add x20, x20, x21\n" + "fmul v16.4s, v16.4s, v0.4s\n" + "ldr q24, [x20, #0x0]\n" + "ldr q25, [x20, #0x10]\n" + "add x20, x20, x21\n" + "fmul v17.4s, v17.4s, v0.4s\n" + "ldr q26, [x20, #0x0]\n" + "ldr q27, [x20, #0x10]\n" + "add x20, x20, x21\n" + "fmul v18.4s, v18.4s, v0.4s\n" + "ldr q28, [x20, #0x0]\n" + "ldr q29, [x20, #0x10]\n" + "add x20, x20, x21\n" + "fmul v19.4s, v19.4s, v0.4s\n" + "ldr q30, [x20, #0x0]\n" + "ldr q31, [x20, #0x10]\n" + "fmul v20.4s, v20.4s, v0.4s\n" + "fmul v21.4s, v21.4s, v0.4s\n" + "fmul v22.4s, v22.4s, v0.4s\n" + "fmul v23.4s, v23.4s, v0.4s\n" + "fmul v24.4s, v24.4s, v0.4s\n" + "fmul v25.4s, v25.4s, v0.4s\n" + "fmul v26.4s, v26.4s, v0.4s\n" + "fmul v27.4s, v27.4s, v0.4s\n" + "fmul v28.4s, v28.4s, v0.4s\n" + "fmul v29.4s, v29.4s, v0.4s\n" + "fmul v30.4s, v30.4s, v0.4s\n" + "fmul v31.4s, v31.4s, v0.4s\n" + "b 3f\n" + "2:" // Height 8: no accumulate + "movi v16.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "movi v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v26.16b, #0x0\n" + "movi v27.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "movi v29.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "movi v31.16b, #0x0\n" + "3:" // Height 8: setup done + "ldr x21, [%x[gp], %[offsetof_A]]\n" + "ldr x20, [%x[gp], %[offsetof_lda]]\n" + "ldr x28, [%x[gp], %[offsetof_k]]\n" + "mov x27, x21\n" + "add x26, x27, x20\n" + "add x25, x26, x20\n" + "add x24, x25, x20\n" + "add x23, x24, x20\n" + "add x22, x23, x20\n" + "add x21, x22, x20\n" + "add x20, x21, x20\n" + "cmp x28, #0x4\n" + "blt 7f\n" + "ldr q0, [x27, #0x0]\n" + "ldr q8, [x10, #0x0]\n" + "cmp x28, #0x8\n" + "ldr q1, [x26, #0x0]\n" + "ldr q2, [x25, #0x0]\n" + "ldr q3, [x24, #0x0]\n" + "ldr q4, [x23, #0x0]\n" + "ldr q5, [x22, #0x0]\n" + "ldr q6, [x21, #0x0]\n" + "ldr q7, [x20, #0x0]\n" + "ldr q11, [x10, #0x10]\n" + "ldr q14, [x10, #0x20]\n" + "blt 6f\n" + "5:" // Height 8: Multiply loop: Main loop head + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "sub x28, x28, #0x4\n" + "add x27, x27, #0x10\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "add x26, x26, #0x10\n" + "add x25, x25, #0x10\n" + "fmlal v9.4s, v8.4h, v15.4h\n" + "fmlal2 v10.4s, v8.4h, v15.4h\n" + "movi v8.16b, #0x0\n" + "add x24, x24, #0x10\n" + "fmlal v12.4s, v11.4h, v15.4h\n" + "fmlal2 v13.4s, v11.4h, v15.4h\n" + "movi v11.16b, #0x0\n" + "add x23, x23, #0x10\n" + "fmlal v8.4s, v14.4h, v15.4h\n" + "add x22, x22, #0x10\n" + "add x21, x21, #0x10\n" + "prfm pldl1keep, [x27, #0x80]\n" + "add x20, x20, #0x10\n" + "cmp x28, #0x8\n" + "prfm pldl1keep, [x26, #0x80]\n" + "prfm pldl1keep, [x25, #0x80]\n" + "fmla v16.4s, v9.4s, v0.s[0]\n" + "fmla v18.4s, v9.4s, v1.s[0]\n" + "prfm pldl1keep, [x24, #0x80]\n" + "prfm pldl1keep, [x23, #0x80]\n" + "fmla v20.4s, v9.4s, v2.s[0]\n" + "fmla v22.4s, v9.4s, v3.s[0]\n" + "prfm pldl1keep, [x22, #0x80]\n" + "prfm pldl1keep, [x21, #0x80]\n" + "fmla v24.4s, v9.4s, v4.s[0]\n" + "fmla v26.4s, v9.4s, v5.s[0]\n" + "prfm pldl1keep, [x20, #0x80]\n" + "fmla v28.4s, v9.4s, v6.s[0]\n" + "fmla v30.4s, v9.4s, v7.s[0]\n" + "movi v9.16b, #0x0\n" + "fmla v17.4s, v10.4s, v0.s[0]\n" + "fmla v19.4s, v10.4s, v1.s[0]\n" + "fmla v21.4s, v10.4s, v2.s[0]\n" + "fmla v23.4s, v10.4s, v3.s[0]\n" + "fmla v25.4s, v10.4s, v4.s[0]\n" + "fmla v27.4s, v10.4s, v5.s[0]\n" + "fmla v29.4s, v10.4s, v6.s[0]\n" + "fmla v31.4s, v10.4s, v7.s[0]\n" + "ldr q10, [x10, #0x30]\n" + "add x10, x10, #0x40\n" + "fmlal2 v9.4s, v14.4h, v15.4h\n" + "ldr q14, [x10, #0x20]\n" + "fmla v16.4s, v12.4s, v0.s[1]\n" + "fmla v18.4s, v12.4s, v1.s[1]\n" + "fmla v20.4s, v12.4s, v2.s[1]\n" + "fmla v22.4s, v12.4s, v3.s[1]\n" + "fmla v24.4s, v12.4s, v4.s[1]\n" + "fmla v26.4s, v12.4s, v5.s[1]\n" + "fmla v28.4s, v12.4s, v6.s[1]\n" + "fmla v30.4s, v12.4s, v7.s[1]\n" + "fmla v17.4s, v13.4s, v0.s[1]\n" + "movi v12.16b, #0x0\n" + "fmla v19.4s, v13.4s, v1.s[1]\n" + "fmla v21.4s, v13.4s, v2.s[1]\n" + "fmla v23.4s, v13.4s, v3.s[1]\n" + "fmla v25.4s, v13.4s, v4.s[1]\n" + "fmla v27.4s, v13.4s, v5.s[1]\n" + "fmla v29.4s, v13.4s, v6.s[1]\n" + "fmla v31.4s, v13.4s, v7.s[1]\n" + "fmlal v11.4s, v10.4h, v15.4h\n" + "fmla v16.4s, v8.4s, v0.s[2]\n" + "fmla v18.4s, v8.4s, v1.s[2]\n" + "fmla v20.4s, v8.4s, v2.s[2]\n" + "fmla v22.4s, v8.4s, v3.s[2]\n" + "fmla v24.4s, v8.4s, v4.s[2]\n" + "fmla v26.4s, v8.4s, v5.s[2]\n" + "fmla v28.4s, v8.4s, v6.s[2]\n" + "fmla v30.4s, v8.4s, v7.s[2]\n" + "ldr q8, [x10, #0x0]\n" + "fmla v17.4s, v9.4s, v0.s[2]\n" + "fmla v19.4s, v9.4s, v1.s[2]\n" + "fmla v21.4s, v9.4s, v2.s[2]\n" + "fmla v23.4s, v9.4s, v3.s[2]\n" + "fmla v25.4s, v9.4s, v4.s[2]\n" + "fmla v27.4s, v9.4s, v5.s[2]\n" + "fmla v29.4s, v9.4s, v6.s[2]\n" + "fmla v31.4s, v9.4s, v7.s[2]\n" + "fmlal2 v12.4s, v10.4h, v15.4h\n" + "fmla v16.4s, v11.4s, v0.s[3]\n" + "fmla v18.4s, v11.4s, v1.s[3]\n" + "fmla v20.4s, v11.4s, v2.s[3]\n" + "fmla v22.4s, v11.4s, v3.s[3]\n" + "fmla v24.4s, v11.4s, v4.s[3]\n" + "fmla v26.4s, v11.4s, v5.s[3]\n" + "fmla v28.4s, v11.4s, v6.s[3]\n" + "fmla v30.4s, v11.4s, v7.s[3]\n" + "ldr q11, [x10, #0x10]\n" + "fmla v17.4s, v12.4s, v0.s[3]\n" + "ldr q0, [x27, #0x0]\n" + "fmla v19.4s, v12.4s, v1.s[3]\n" + "ldr q1, [x26, #0x0]\n" + "fmla v21.4s, v12.4s, v2.s[3]\n" + "ldr q2, [x25, #0x0]\n" + "fmla v23.4s, v12.4s, v3.s[3]\n" + "ldr q3, [x24, #0x0]\n" + "fmla v25.4s, v12.4s, v4.s[3]\n" + "ldr q4, [x23, #0x0]\n" + "fmla v27.4s, v12.4s, v5.s[3]\n" + "ldr q5, [x22, #0x0]\n" + "fmla v29.4s, v12.4s, v6.s[3]\n" + "ldr q6, [x21, #0x0]\n" + "fmla v31.4s, v12.4s, v7.s[3]\n" + "ldr q7, [x20, #0x0]\n" + "bge 5b\n" + "6:" // Height 8: Multiply loop: Single iteration only + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "add x27, x27, #0x10\n" + "add x26, x26, #0x10\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "add x25, x25, #0x10\n" + "add x24, x24, #0x10\n" + "fmlal v9.4s, v8.4h, v15.4h\n" + "fmlal2 v10.4s, v8.4h, v15.4h\n" + "movi v8.16b, #0x0\n" + "add x23, x23, #0x10\n" + "fmlal v12.4s, v11.4h, v15.4h\n" + "fmlal2 v13.4s, v11.4h, v15.4h\n" + "movi v11.16b, #0x0\n" + "add x22, x22, #0x10\n" + "fmlal v8.4s, v14.4h, v15.4h\n" + "add x21, x21, #0x10\n" + "add x20, x20, #0x10\n" + "prfm pldl1keep, [x27, #0x80]\n" + "prfm pldl1keep, [x26, #0x80]\n" + "prfm pldl1keep, [x25, #0x80]\n" + "sub x28, x28, #0x4\n" + "fmla v16.4s, v9.4s, v0.s[0]\n" + "fmla v18.4s, v9.4s, v1.s[0]\n" + "fmla v20.4s, v9.4s, v2.s[0]\n" + "prfm pldl1keep, [x24, #0x80]\n" + "prfm pldl1keep, [x23, #0x80]\n" + "fmla v22.4s, v9.4s, v3.s[0]\n" + "fmla v24.4s, v9.4s, v4.s[0]\n" + "prfm pldl1keep, [x22, #0x80]\n" + "prfm pldl1keep, [x21, #0x80]\n" + "fmla v26.4s, v9.4s, v5.s[0]\n" + "fmla v28.4s, v9.4s, v6.s[0]\n" + "prfm pldl1keep, [x20, #0x80]\n" + "fmla v30.4s, v9.4s, v7.s[0]\n" + "fmla v17.4s, v10.4s, v0.s[0]\n" + "movi v9.16b, #0x0\n" + "fmla v19.4s, v10.4s, v1.s[0]\n" + "fmla v21.4s, v10.4s, v2.s[0]\n" + "fmla v23.4s, v10.4s, v3.s[0]\n" + "fmla v25.4s, v10.4s, v4.s[0]\n" + "fmla v27.4s, v10.4s, v5.s[0]\n" + "fmla v29.4s, v10.4s, v6.s[0]\n" + "fmla v31.4s, v10.4s, v7.s[0]\n" + "ldr q10, [x10, #0x30]\n" + "fmlal2 v9.4s, v14.4h, v15.4h\n" + "add x10, x10, #0x40\n" + "fmla v16.4s, v12.4s, v0.s[1]\n" + "fmla v18.4s, v12.4s, v1.s[1]\n" + "fmla v20.4s, v12.4s, v2.s[1]\n" + "fmla v22.4s, v12.4s, v3.s[1]\n" + "fmla v24.4s, v12.4s, v4.s[1]\n" + "fmla v26.4s, v12.4s, v5.s[1]\n" + "fmla v28.4s, v12.4s, v6.s[1]\n" + "fmla v30.4s, v12.4s, v7.s[1]\n" + "movi v12.16b, #0x0\n" + "fmla v17.4s, v13.4s, v0.s[1]\n" + "fmla v19.4s, v13.4s, v1.s[1]\n" + "fmla v21.4s, v13.4s, v2.s[1]\n" + "fmla v23.4s, v13.4s, v3.s[1]\n" + "fmla v25.4s, v13.4s, v4.s[1]\n" + "fmla v27.4s, v13.4s, v5.s[1]\n" + "fmla v29.4s, v13.4s, v6.s[1]\n" + "fmla v31.4s, v13.4s, v7.s[1]\n" + "fmlal v11.4s, v10.4h, v15.4h\n" + "fmla v16.4s, v8.4s, v0.s[2]\n" + "fmla v18.4s, v8.4s, v1.s[2]\n" + "fmla v20.4s, v8.4s, v2.s[2]\n" + "fmla v22.4s, v8.4s, v3.s[2]\n" + "fmla v24.4s, v8.4s, v4.s[2]\n" + "fmla v26.4s, v8.4s, v5.s[2]\n" + "fmla v28.4s, v8.4s, v6.s[2]\n" + "fmla v30.4s, v8.4s, v7.s[2]\n" + "fmla v17.4s, v9.4s, v0.s[2]\n" + "fmla v19.4s, v9.4s, v1.s[2]\n" + "fmla v21.4s, v9.4s, v2.s[2]\n" + "fmla v23.4s, v9.4s, v3.s[2]\n" + "fmla v25.4s, v9.4s, v4.s[2]\n" + "fmla v27.4s, v9.4s, v5.s[2]\n" + "fmla v29.4s, v9.4s, v6.s[2]\n" + "fmla v31.4s, v9.4s, v7.s[2]\n" + "fmlal2 v12.4s, v10.4h, v15.4h\n" + "fmla v16.4s, v11.4s, v0.s[3]\n" + "fmla v18.4s, v11.4s, v1.s[3]\n" + "fmla v20.4s, v11.4s, v2.s[3]\n" + "fmla v22.4s, v11.4s, v3.s[3]\n" + "fmla v24.4s, v11.4s, v4.s[3]\n" + "fmla v26.4s, v11.4s, v5.s[3]\n" + "fmla v28.4s, v11.4s, v6.s[3]\n" + "fmla v30.4s, v11.4s, v7.s[3]\n" + "fmla v17.4s, v12.4s, v0.s[3]\n" + "fmla v19.4s, v12.4s, v1.s[3]\n" + "fmla v21.4s, v12.4s, v2.s[3]\n" + "fmla v23.4s, v12.4s, v3.s[3]\n" + "fmla v25.4s, v12.4s, v4.s[3]\n" + "fmla v27.4s, v12.4s, v5.s[3]\n" + "fmla v29.4s, v12.4s, v6.s[3]\n" + "fmla v31.4s, v12.4s, v7.s[3]\n" + "7:" // Height 8: Multiply loop: Main loop skip + "cbz x28, 9f\n" + "8:" // Height 8: Multiply loop: Odd block loop + "ldr q13, [x10, #0x0]\n" + "ldr s0, [x27], #0x4\n" + "movi v14.16b, #0x0\n" + "movi v8.16b, #0x0\n" + "ldr s1, [x26], #0x4\n" + "ldr s2, [x25], #0x4\n" + "sub x28, x28, #0x1\n" + "add x10, x10, #0x10\n" + "ldr s3, [x24], #0x4\n" + "ldr s4, [x23], #0x4\n" + "ldr s5, [x22], #0x4\n" + "ldr s6, [x21], #0x4\n" + "fmlal v14.4s, v13.4h, v15.4h\n" + "fmlal2 v8.4s, v13.4h, v15.4h\n" + "ldr s7, [x20], #0x4\n" + "fmla v16.4s, v14.4s, v0.s[0]\n" + "fmla v18.4s, v14.4s, v1.s[0]\n" + "fmla v20.4s, v14.4s, v2.s[0]\n" + "fmla v22.4s, v14.4s, v3.s[0]\n" + "fmla v24.4s, v14.4s, v4.s[0]\n" + "fmla v26.4s, v14.4s, v5.s[0]\n" + "fmla v28.4s, v14.4s, v6.s[0]\n" + "fmla v30.4s, v14.4s, v7.s[0]\n" + "fmla v17.4s, v8.4s, v0.s[0]\n" + "fmla v19.4s, v8.4s, v1.s[0]\n" + "fmla v21.4s, v8.4s, v2.s[0]\n" + "fmla v23.4s, v8.4s, v3.s[0]\n" + "fmla v25.4s, v8.4s, v4.s[0]\n" + "fmla v27.4s, v8.4s, v5.s[0]\n" + "fmla v29.4s, v8.4s, v6.s[0]\n" + "fmla v31.4s, v8.4s, v7.s[0]\n" + "cbnz x28, 8b\n" + "9:" // Height 8: Multiply loop: No odd multiplies + "ldr x26, [%x[gp], %[offsetof_ldc]]\n" + "prfm pstl1keep, [x9, #0x0]\n" + "str q16, [x9, #0x0]\n" + "str q17, [x9, #0x10]\n" + "add x20, x9, x26\n" + "add x9, x9, #0x20\n" + "prfm pstl1keep, [x20, #0x0]\n" + "str q18, [x20, #0x0]\n" + "add x25, x20, x26\n" + "add x24, x25, x26\n" + "add x23, x24, x26\n" + "prfm pstl1keep, [x25, #0x0]\n" + "prfm pstl1keep, [x24, #0x0]\n" + "str q19, [x20, #0x10]\n" + "add x22, x23, x26\n" + "prfm pstl1keep, [x23, #0x0]\n" + "str q20, [x25, #0x0]\n" + "add x21, x22, x26\n" + "add x20, x21, x26\n" + "prfm pstl1keep, [x22, #0x0]\n" + "prfm pstl1keep, [x21, #0x0]\n" + "str q21, [x25, #0x10]\n" + "prfm pstl1keep, [x20, #0x0]\n" + "str q22, [x24, #0x0]\n" + "str q23, [x24, #0x10]\n" + "str q24, [x23, #0x0]\n" + "str q25, [x23, #0x10]\n" + "str q26, [x22, #0x0]\n" + "str q27, [x22, #0x10]\n" + "str q28, [x21, #0x0]\n" + "str q29, [x21, #0x10]\n" + "str q30, [x20, #0x0]\n" + "str q31, [x20, #0x10]\n" + "subs x11, x11, #0x1\n" + "bgt 1b\n" + : + : [gp] "r"(gp), + [offsetof_A] "I"(offsetof(GemmParamsFP16, A)), + [offsetof_B] "I"(offsetof(GemmParamsFP16, B)), + [offsetof_C] "I"(offsetof(GemmParamsFP16, C)), + [offsetof_b_block_cols] "I"(offsetof(GemmParamsFP16, b_block_cols)), + [offsetof_beta] "I"(offsetof(GemmParamsFP16, beta)), + [offsetof_k] "I"(offsetof(GemmParamsFP16, k)), + [offsetof_lda] "I"(offsetof(GemmParamsFP16, lda)), + [offsetof_ldc] "I"(offsetof(GemmParamsFP16, ldc)) + : "cc", + "memory", + "v0", + "v1", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v2", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25", + "v26", + "v27", + "v28", + "v29", + "v3", + "v30", + "v31", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "x10", + "x11", + "x12", + "x20", + "x21", + "x22", + "x23", + "x24", + "x25", + "x26", + "x27", + "x28", + "x9"); +#endif // __aarch64__ +} + +} // namespace kleidiai + +#endif diff --git a/src/KleidiAIFP16UKernelsNeon.h b/src/KleidiAIFP16UKernelsNeon.h new file mode 100644 index 000000000..3d265dfb9 --- /dev/null +++ b/src/KleidiAIFP16UKernelsNeon.h @@ -0,0 +1,29 @@ +/* + * @lint-ignore-every LICENSELINT + * SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliate + * SPDX-License-Identifier: BSD-3-Clause + */ +#ifdef FBGEMM_ENABLE_KLEIDIAI + +#pragma once +#include +#include "fbgemm/FbgemmBuild.h" +#include "fbgemm/FbgemmFPCommon.h" +#include "fbgemm/Types.h" + +namespace kleidiai { + +using GemmParamsFP16 = fbgemm::GemmParams; + +void NOINLINE gemmkernel_1x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp); +void NOINLINE gemmkernel_2x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp); +void NOINLINE gemmkernel_3x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp); +void NOINLINE gemmkernel_4x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp); +void NOINLINE gemmkernel_5x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp); +void NOINLINE gemmkernel_6x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp); +void NOINLINE gemmkernel_7x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp); +void NOINLINE gemmkernel_8x1_Neon_fp16_fA0fB0fC0(GemmParamsFP16* gp); + +} // namespace kleidiai + +#endif From c02799e7dd6871ab3a1930b369e7eb89476c6ba3 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Tue, 17 Dec 2024 22:06:24 -0800 Subject: [PATCH 03/16] Optimzed backward pass for ROCm devices (#3488) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/592 X-link: https://github.com/facebookresearch/FBGEMM/pull/568 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3488 - Break up D66310520 into backend and frontend diffs Reviewed By: leitian Differential Revision: D66986498 fbshipit-source-id: 1779a9a2a4611eda1298afc0e840839c7da46b10 --- fbgemm_gpu/cmake/TbeTraining.cmake | 10 + fbgemm_gpu/cmake/tbe_sources.py | 19 + .../genscript/generate_backward_split.py | 22 + .../embedding_backward_dense_host_cpu.cpp | 7 +- ...embedding_backward_split_host_template.cpp | 15 +- ...ing_backward_split_kernel_warp_template.cu | 309 ++++++++++ ...embedding_backward_split_meta_template.cpp | 3 + .../embedding_backward_split_template.cu | 136 ++++- ..._backward_split_device_kernel_template.hip | 461 +++++++++++++++ ...t_table_batched_embeddings_ops_training.py | 1 + fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py | 1 - .../include/fbgemm_gpu/rocm/cdna_guard.h | 51 ++ .../fbgemm_gpu/rocm/split_embeddings_common.h | 550 ++++++++++++++++++ fbgemm_gpu/test/tbe/cache/cache_common.py | 3 +- fbgemm_gpu/test/tbe/cache/cache_test.py | 6 + .../tbe/training/backward_optimizers_test.py | 76 +++ fbgemm_gpu/test/test_utils.py | 20 + 17 files changed, 1679 insertions(+), 11 deletions(-) create mode 100644 fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip create mode 100644 fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h create mode 100644 fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h diff --git a/fbgemm_gpu/cmake/TbeTraining.cmake b/fbgemm_gpu/cmake/TbeTraining.cmake index 866b481e6..b02c7b582 100644 --- a/fbgemm_gpu/cmake/TbeTraining.cmake +++ b/fbgemm_gpu/cmake/TbeTraining.cmake @@ -42,6 +42,14 @@ handle_genfiles(gen_py_files_training) handle_genfiles(gen_py_files_defused_optim) +################################################################################ +# FBGEMM_GPU Generated HIP-Specific Sources +################################################################################ + +get_tbe_sources_list(gen_hip_files_training) +handle_genfiles_rocm(gen_hip_files_training) + + ################################################################################ # TBE C++ Training Targets ################################################################################ @@ -152,6 +160,8 @@ gpu_cpp_library( ${gen_cpu_files_training} GPU_SRCS ${gen_gpu_files_training} + HIP_SPECIFIC_SRCS + ${gen_hip_files_training} GPU_FLAGS ${TORCH_CUDA_OPTIONS} DEPS diff --git a/fbgemm_gpu/cmake/tbe_sources.py b/fbgemm_gpu/cmake/tbe_sources.py index c36bd06d2..ecbdf8063 100644 --- a/fbgemm_gpu/cmake/tbe_sources.py +++ b/fbgemm_gpu/cmake/tbe_sources.py @@ -472,6 +472,25 @@ ] ) +gen_hip_files_training = [ + "gen_embedding_backward_split_{}{}_device_kernel_hip.hip".format( + "weighted" if weighted else "unweighted", + "_nobag" if nobag else "", + ) + for nobag in [ + True, + False, + ] + for weighted in ( + [ + True, + False, + ] + if not nobag + else [False] + ) +] + ################################################################################ # Python Training Code ################################################################################ diff --git a/fbgemm_gpu/codegen/genscript/generate_backward_split.py b/fbgemm_gpu/codegen/genscript/generate_backward_split.py index ac60a8dad..c97714857 100644 --- a/fbgemm_gpu/codegen/genscript/generate_backward_split.py +++ b/fbgemm_gpu/codegen/genscript/generate_backward_split.py @@ -310,6 +310,27 @@ def generate_backward_indices() -> None: ssd=ssd, ) + @staticmethod + def generate_rocm_backward_split(**kwargs: Any) -> None: + # Generate backward device kernels based on weighted (True/False), VBE + # (True/False), no bag (True/False) + template_filepath = ( + "training/backward/rocm/embedding_backward_split_device_kernel_template.hip" + ) + + BackwardSplitGenerator.render_backward_templates( + template_filepath, + "", + "{}gen_embedding_backward_{}_device_kernel_hip.hip", + { + "has_gpu_support": True, + "has_vbe_support": False, + "has_ssd_support": False, + "dense": False, + "gen_once": False, + }, + ) + @staticmethod def generate_python_sources( all_optimizers: List[str], ssd_optimizers: List[str] @@ -369,6 +390,7 @@ def generate() -> None: BackwardSplitGenerator.generate_backward_split( ssd_tensors=ssd_tensors, **optimizer ) + BackwardSplitGenerator.generate_rocm_backward_split() # Generate common device kernels for backwards BackwardSplitGenerator.generate_backward_device() diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp index ee608e83e..626838e93 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp @@ -171,7 +171,8 @@ Tensor split_embedding_codegen_lookup_dense_function( Tensor>& /* vbe_B_offsets_rank_per_feature = std::nullopt */, c10::SymInt /* max_B = -1 */, c10::SymInt /* max_B_feature_rank = -1 */, - c10::SymInt /* vbe_output_size = -1 */) { + c10::SymInt /* vbe_output_size = -1 */, + bool /* mixed_D = true */) { return SplitLookupFunction_Dense_Op::apply( host_weights, weights_offsets, @@ -190,7 +191,7 @@ Tensor split_embedding_codegen_lookup_dense_function( // Deprecated for fb namespace! Please use fbgemm namespace instead! TORCH_LIBRARY_FRAGMENT(fb, m) { m.def( - "dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1) -> Tensor"); + "dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1, bool mixed_D=True) -> Tensor"); DISPATCH_TO_CPU( "dense_embedding_codegen_lookup_function", split_embedding_codegen_lookup_dense_function); @@ -198,7 +199,7 @@ TORCH_LIBRARY_FRAGMENT(fb, m) { TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( - "dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1) -> Tensor"); + "dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1, bool mixed_D=True) -> Tensor"); DISPATCH_TO_CPU( "dense_embedding_codegen_lookup_function", split_embedding_codegen_lookup_dense_function); diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp index 63fa373a5..3efec2527 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp @@ -152,6 +152,7 @@ enum SSDTensor { {%- else %} D_offsets, max_D, + mixed_D, {%- endif %} {# /* if nobag */ #} hash_size_cumsum, total_hash_size_bits, @@ -224,6 +225,7 @@ enum SSDTensor { Variable(), // D_offsets Variable(), // total_D Variable(), // max_D + Variable(), // mixed_D {%- endif %} Variable(), // hash_size_cumsum Variable(), //total_hash_size_bits @@ -304,6 +306,7 @@ enum SSDTensor { D_offsets, total_D, max_D, + mixed_D, {%- endif %} hash_size_cumsum, total_hash_size_bits, @@ -484,6 +487,7 @@ Tensor {%- else %} const Tensor& D_offsets, const c10::SymInt max_D, + const bool mixed_D, {%- endif %} const Tensor& hash_size_cumsum, const int64_t total_hash_size_bits, @@ -566,6 +570,7 @@ class {{ autograd_func }} : const Tensor& D_offsets, const c10::SymInt total_D, const c10::SymInt max_D, + const bool mixed_D, {%- else %} const c10::SymInt D, {%- endif %} @@ -762,6 +767,7 @@ class {{ autograd_func }} : {%- if not nobag %} ctx->saved_data["max_D"] = max_D; + ctx->saved_data["mixed_D"] = mixed_D; ctx->saved_data["pooling_mode"] = pooling_mode; {%- else %} ctx->saved_data["D"] = D; @@ -877,6 +883,7 @@ class {{ autograd_func }} : {%- if not nobag %} auto max_D = ctx->saved_data["max_D"].toSymInt(); + const auto mixed_D = ctx->saved_data["mixed_D"].toBool(); auto pooling_mode = ctx->saved_data["pooling_mode"].toInt(); {%- else %} auto D = ctx->saved_data["D"].toSymInt(); @@ -1072,10 +1079,11 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function( {%- if ssd %} const std::optional& ssd_tensors = std::nullopt, {%- endif %} - const double gwd_lower_bound = 0 + const double gwd_lower_bound = 0, {%- else %} - const c10::SymInt vbe_output_size = -1 + const c10::SymInt vbe_output_size = -1, {%- endif %} + const bool mixed_D = true ) { // TODO: refactor into macro {%- if has_gpu_support %} @@ -1191,7 +1199,8 @@ TORCH_LIBRARY_FRAGMENT({{ lib_name }}, m) { {%- if ssd %} " Tensor[]? ssd_tensors=None," {%- endif %} - " float gwd_lower_bound=0 " + " float gwd_lower_bound=0, " + " bool mixed_D=True" ") -> Tensor", {PT2_COMPLIANT_TAG}); diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index 3b230b010..0e4f552eb 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -521,5 +521,314 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row #endif //////////////////////////////////////////////////////////////////////////////// +{%- endif %} + +{%- if is_rocm and not is_index_select and optimizer == "rowwise_adagrad" and not dense and not is_gwd_kernel and not vbe and not ssd %} +#include +#include +#include "fbgemm_gpu/rocm/split_embeddings_common.h" +#include "gen_embedding_backward_split_{{ desc_suffix }}{{ ndesc }}_device_kernel_hip.hip" + +template < + typename emb_t, + typename grad_t, + typename cache_t, + int32_t kFixedMaxVecsPerThread, + int32_t kThreadGroupSize, + bool kUseVecBlocking, + int32_t embedding_dim, + int32_t weight_decay_mode_v> +__global__ void +hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_warp_per_row_1( + const pta::PackedTensorAccessor64 grad_output, + {%- if optimizer != "none" %} + pta::PackedTensorAccessor64 dev_weights, + {%- if not dense %} + pta::PackedTensorAccessor64 uvm_weights, + pta::PackedTensorAccessor64 lxu_cache_weights, + const pta::PackedTensorAccessor32 weights_placements, + {%- endif %} + {%- endif %} + const pta::PackedTensorAccessor32 weights_offsets, + {%- if not nobag or is_index_select %} + const pta::PackedTensorAccessor32 D_offsets, + {%- else %} + int64_t D, + {%- endif %} + const pta::PackedTensorAccessor32 hash_size_cumsum, + const pta::PackedTensorAccessor32 sorted_linear_indices_run, + const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, + {%- if not nobag %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- else %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- endif %} + {%- if not dense %} + const pta::PackedTensorAccessor32 sorted_lxu_cache_locations, + const bool use_uniq_cache_locations, + const pta::PackedTensorAccessor32 table_unique_indices_offsets, + {%- endif %} + {%- if weighted %} + const pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> sorted_indice_weights, + {%- endif %} + const pta::PackedTensorAccessor32 sorted_linear_indices_num_runs, + int32_t max_segment_length_per_warp, + {%- if not dense and optimizer != "none" %} + bool stochastic_rounding, + at::PhiloxCudaState stochastic_rounding_philox_args, + {%- else %} + pta::PackedTensorAccessor64 grad_dev_weights, + {%- endif %} // if not dense and optimizer != "none" + {%- if not nobag and vbe %} + const pta::PackedTensorAccessor32 B_offsets, + const pta::PackedTensorAccessor32 row_output_offsets, + {%- endif %} + {%- if not nobag %} + const int32_t info_B_num_bits, + const uint32_t info_B_mask, + {%- endif %} + const int32_t max_D, + const int32_t max_vecs_per_thread, + {%- if is_index_select %} + const at::PackedTensorAccessor32 grad_offsets, + const bool permute_output_dim_0_1 + {%- else %} + {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} + {%- endif %} +) { + {%- if not nobag %} + int32_t T = D_offsets.size(0) - 1; + {%- else %} + int32_t T = weights_offsets.size(0); + {%- endif %} + + auto p_output_grad = grad_output.data(); + auto p_emb_table = dev_weights.data(); + auto p_hash_size_cumsum = hash_size_cumsum.data(); + auto p_sorted_linear_indices_run = sorted_linear_indices_run.data(); + auto p_sorted_linear_indices_cumulative_run_lengths = sorted_linear_indices_cumulative_run_lengths.data(); + auto p_sorted_linear_indices_num_runs = sorted_linear_indices_num_runs.data(); + auto p_sorted_infos = sorted_infos.data(); + {%- if weighted %} + auto p_indice_weights_sorted = sorted_indice_weights.data(); + {%- endif %} + auto emb_dim = embedding_dim; + constexpr int32_t segment_prefetch = 2; + constexpr int32_t segment_unroll = 8; + constexpr int32_t segment_split = 0; + auto batch = grad_output.size(0); + auto num_rows = dev_weights.size(0) / T / max_D; + {%- if weighted %} + constexpr bool is_weighted = true; + {%- else %} + constexpr bool is_weighted = false; + {%- endif %} + rocm::{{optimizer}}_kernel_arg_t opt_karg; + opt_karg.p_momentum = momentum1_dev.data(); + opt_karg.eps = eps; + opt_karg.learning_rate = learning_rate; + // weight_decay(_mode) is supplied as args.split_function_args_no_defaults + opt_karg.weight_decay_mode = weight_decay_mode_v; + opt_karg.weight_decay = weight_decay; + auto batch_mdiv = [](uint32_t d) -> rocm::magic_div_u32_t { + assert(d >= 1 && d <= INT32_MAX); + uint8_t shift; + for(shift = 0; shift < 32; shift++) + if((1U << shift) >= d) + break; + + uint64_t one = 1; + uint64_t magic = ((one << 32) * ((one << shift) - d)) / d + 1; + assert(magic <= 0xffffffffUL); + + rocm::magic_div_u32_t result; + result.magic = magic; + result.shift = shift; + return result; + }(batch); + rocm::split_tbe_backward_hip_kernel_{{kdesc}}< + rocm::{{optimizer}}_optimizer_t, + rocm::{{optimizer}}_kernel_arg_t, + emb_t, + cache_t, + grad_t, + BLOCK_SIZE, + embedding_dim, + segment_prefetch, + segment_unroll, + segment_split, + is_weighted>(p_output_grad, + p_emb_table, + p_hash_size_cumsum, + p_sorted_linear_indices_run, + p_sorted_linear_indices_cumulative_run_lengths, + p_sorted_linear_indices_num_runs, + {%- if not nobag %} + info_B_num_bits, + info_B_mask, + {%- endif %} + p_sorted_infos, + batch_mdiv, + max_segment_length_per_warp, + emb_dim, + batch, + num_rows, + T, + opt_karg + {%- if weighted %} + , p_indice_weights_sorted + {%- endif %}); +} + +{%- macro hip_template_instantiation( + emb_type, + grad_type, + cache_type, + kFixedMaxVecsPerThread, + kThreadGroupSize, + kUseVecBlocking, + kEmbeddingDim, + kWeighDecayMode + ) +%} +template __global__ __launch_bounds__(kBackwardMaxThreads) void +hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_warp_per_row_1 +< {{ emb_type }}, + {{ grad_type }}, + {{ cache_type }}, + {{ kFixedMaxVecsPerThread }}, + {{ kThreadGroupSize }}, + {{ kUseVecBlocking }}, + {{ kEmbeddingDim }}, + {{ kWeighDecayMode }} +> ( + const pta::PackedTensorAccessor64<{{ grad_type }}, {{ "1" if is_index_select else "2" }}, at::RestrictPtrTraits> grad_output, + {%- if optimizer != "none" %} + pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> dev_weights, + {%- if not dense %} + pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> uvm_weights, + pta::PackedTensorAccessor64<{{ cache_type }}, 2, at::RestrictPtrTraits> lxu_cache_weights, + const pta::PackedTensorAccessor32 weights_placements, + {%- endif %} + {%- endif %} + const pta::PackedTensorAccessor32 weights_offsets, + {%- if not nobag or is_index_select %} + const pta::PackedTensorAccessor32 D_offsets, + {%- else %} + int64_t D, + {%- endif %} + const pta::PackedTensorAccessor32 hash_size_cumsum, + const pta::PackedTensorAccessor32 sorted_linear_indices_run, + const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, + {%- if not nobag %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- else %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- endif %} + {%- if not dense %} + const pta::PackedTensorAccessor32 sorted_lxu_cache_locations, + const bool use_uniq_cache_locations, + const pta::PackedTensorAccessor32 table_unique_indices_offsets, + {%- endif %} + {%- if weighted %} + const pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> sorted_indice_weights, + {%- endif %} + const pta::PackedTensorAccessor32 sorted_linear_indices_num_runs, + int32_t max_segment_length_per_warp, + {%- if not dense and optimizer != "none" %} + bool stochastic_rounding, + at::PhiloxCudaState stochastic_rounding_philox_args, + {%- else %} + pta::PackedTensorAccessor64< {{ emb_type }}, 1, at::RestrictPtrTraits> grad_dev_weights, + {%- endif %} // if not dense and optimizer != "none" + {%- if not nobag and vbe %} + const pta::PackedTensorAccessor32 B_offsets, + const pta::PackedTensorAccessor32 row_output_offsets, + {%- endif %} + {%- if not nobag %} + const int32_t info_B_num_bits, + const uint32_t info_B_mask, + {%- endif %} + const int32_t max_D, + const int32_t max_vecs_per_thread, + {%- if is_index_select %} + const at::PackedTensorAccessor32 grad_offsets, + const bool permute_output_dim_0_1 + {%- else %} + {{ args.split_kernel_args_no_defaults | replace_pta_namespace() | join(",\n ") | replace("cache_t", cache_type) }} + {%- endif %} +); +{%- endmacro %} + +{%- macro hip_bulk_template_instantiations(kFixedMaxVecsPerThread, kThreadGroupSize, kUseVecBlocking) %} + {%- for grad_type in ['float', 'at::Half', 'at::BFloat16'] %} + {%- for emb_type in ['float', 'at::Half'] %} + {%- for cache_type in ['float', 'at::Half'] %} + {%- for kEmbeddingDim in [64, 128, 160, 192, 256] %} + {%- for kWeighDecayMode in [0, 1, 2] %} + {{ hip_template_instantiation( + emb_type, + grad_type, + cache_type, + kFixedMaxVecsPerThread, + kThreadGroupSize, + kUseVecBlocking, + kEmbeddingDim, + kWeighDecayMode + ) + }} + {%- endfor %} + {%- endfor %} + {%- endfor %} + {%- endfor %} + {%- endfor %} +{%- endmacro %} + +{%- macro hip_instantiate_templates(use_subwarp_shuffle) %} +{%- for (kFixedMaxVecsPerThread, kThreadGroupSize, kUseVecBlocking) + in get_max_vecs_template_configs( + items_per_warp, + fixed_max_vecs_per_thread["backward"], + use_subwarp_shuffle, + use_vec_blocking=True, + ) +%} + {{ + hip_bulk_template_instantiations( + kFixedMaxVecsPerThread, + kThreadGroupSize, + kUseVecBlocking, + ) + }} +{%- endfor %} +{%- endmacro %} + +//////////////////////////////////////////////////////////////////////////////// +#ifdef FBGEMM_USE_SUBWARP_SHUFFLE +//////////////////////////////////////////////////////////////////////////////// + +{#- /* + Explicitly instantiate kernels for the FBGEMM_USE_SUBWARP_SHUFFLE case + Please see get_max_vecs_template_configs in + codegen/embedding_common_code_generator.py for more details +*/ #} + +{{ hip_instantiate_templates(use_subwarp_shuffle=True) }} + +//////////////////////////////////////////////////////////////////////////////// +#else +//////////////////////////////////////////////////////////////////////////////// + +{#- /* + Explicitly instantiate kernels for the non-FBGEMM_USE_SUBWARP_SHUFFLE case + Please see get_max_vecs_template_configs in + codegen/embedding_common_code_generator.py for more details +*/ #} + +{{ hip_instantiate_templates(use_subwarp_shuffle=False) }} + +//////////////////////////////////////////////////////////////////////////////// +#endif +//////////////////////////////////////////////////////////////////////////////// {%- endif %} // clang-format on diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_meta_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_meta_template.cpp index 6b3d5604d..def21bd39 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_meta_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_meta_template.cpp @@ -72,6 +72,9 @@ Tensor {{ mdesc }}_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ desc {%- else %} const c10::SymInt D, {%- endif %} + {%- if not nobag and not is_index_select %} + const bool mixed_D, + {%- endif %} const Tensor& hash_size_cumsum, const int64_t total_hash_size_bits, const Tensor& indices, diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index fdd9c0f79..5029a382a 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -26,6 +26,10 @@ #include "fbgemm_gpu/split_embeddings_utils.cuh" #include "fbgemm_gpu/utils/ops_utils.h" +{%- if is_rocm %} +#include "fbgemm_gpu/rocm/cdna_guard.h" +{%- endif %} + using Tensor = at::Tensor; using namespace fbgemm_gpu; @@ -211,6 +215,78 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} {%- endif %} ); + +{%- if is_rocm and optimizer == "rowwise_adagrad" and not dense and not is_index_select + and not is_gwd_kernel and not vbe and not ssd %} +#include "fbgemm_gpu/rocm/split_embeddings_common.h" +template < + typename emb_t, + typename grad_t, + typename cache_t, + int32_t kFixedMaxVecsPerThread, + int32_t kThreadGroupSize, + bool kUseVecBlocking, + int32_t embedding_dim, + int32_t weight_decay_mode_v> +__global__ void +hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_warp_per_row_1( + const pta::PackedTensorAccessor64 grad_output, + {%- if optimizer != "none" %} + pta::PackedTensorAccessor64 dev_weights, + {%- if not dense %} + pta::PackedTensorAccessor64 uvm_weights, + pta::PackedTensorAccessor64 lxu_cache_weights, + const pta::PackedTensorAccessor32 weights_placements, + {%- endif %} + {%- endif %} + const pta::PackedTensorAccessor32 weights_offsets, + {%- if not nobag or is_index_select %} + const pta::PackedTensorAccessor32 D_offsets, + {%- else %} + int64_t D, + {%- endif %} + const pta::PackedTensorAccessor32 hash_size_cumsum, + const pta::PackedTensorAccessor32 sorted_linear_indices_run, + const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, + {%- if not nobag %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- else %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- endif %} + {%- if not dense %} + const pta::PackedTensorAccessor32 sorted_lxu_cache_locations, + const bool use_uniq_cache_locations, + const pta::PackedTensorAccessor32 table_unique_indices_offsets, + {%- endif %} + {%- if weighted %} + const pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> sorted_indice_weights, + {%- endif %} + const pta::PackedTensorAccessor32 sorted_linear_indices_num_runs, + int32_t max_segment_length_per_warp, + {%- if not dense and optimizer != "none" %} + bool stochastic_rounding, + at::PhiloxCudaState stochastic_rounding_philox_args, + {%- else %} + pta::PackedTensorAccessor64 grad_dev_weights, + {%- endif %} // if not dense and optimizer != "none" + {%- if not nobag and vbe %} + const pta::PackedTensorAccessor32 B_offsets, + const pta::PackedTensorAccessor32 row_output_offsets, + {%- endif %} + {%- if not nobag %} + const int32_t info_B_num_bits, + const uint32_t info_B_mask, + {%- endif %} + const int32_t max_D, + const int32_t max_vecs_per_thread, + {%- if is_index_select %} + const at::PackedTensorAccessor32 grad_offsets, + const bool permute_output_dim_0_1 + {%- else %} + {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} + {%- endif %} +); +{%- endif %} {% if is_index_select %} namespace index_select { {% else %} @@ -452,6 +528,9 @@ Tensor {{ embedding_cuda_op }}( {%- else %} const c10::SymInt D_, {%- endif %} + {%- if not nobag and not is_index_select %} + const bool mixed_D, + {%- endif %} const Tensor& hash_size_cumsum, const int64_t total_hash_size_bits, const Tensor& indices, @@ -775,6 +854,17 @@ Tensor {{ embedding_cuda_op }}( } {%- endif %} + {%- if is_rocm and optimizer == "rowwise_adagrad" and not dense and not is_index_select + and not is_gwd_kernel and not vbe and not ssd %} + {%- set hip_kernel = "hip_split_embedding{}_backward_codegen_{}_{}{}_kernel_warp_per_row_1".format( + ndesc, + optimizer, + wdesc, + vdesc, + ) + %} + {%- endif %} + DISPATCH_EMB_GRAD_CACHE_TYPES( dev_weights.scalar_type(), aligned_grad_output.scalar_type(), @@ -1070,7 +1160,7 @@ Tensor {{ embedding_cuda_op }}( desc_suffix, ) %} - const auto backward_warp_per_row_kernel = + auto backward_warp_per_row_kernel = {{ warp_kernel }} (), segments_per_workgroup); + blockSize = dim3(256); + warp_per_row_smem_bytes = 0; + + backward_warp_per_row_kernel = + {{ hip_kernel }} + ; + } + {%- endfor %} + {%- endfor %} + } + {%- endif %} +#endif + + #ifdef FBGEMM_GPU_MEMCHECK const auto func_name4 = "{{ warp_kernel }}"; #endif backward_warp_per_row_kernel <<>>( grad_output_accessor, @@ -1222,6 +1349,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { {%- else %} " SymInt D, " {%- endif %} + {%- if not nobag and not is_index_select %} + " bool mixed_D, " + {%- endif %} " Tensor hash_size_cumsum, " " int total_hash_size_bits, " " Tensor indices, " diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip new file mode 100644 index 000000000..0374a1724 --- /dev/null +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -0,0 +1,461 @@ +/******************************************************************************* + * Copyright (c) 2016 - 2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + ******************************************************************************/ + +#include +#include + +#include "fbgemm_gpu/rocm/split_embeddings_common.h" + +namespace fbgemm_gpu::rocm { +template +struct rowwise_adagrad_optimizer_t +{ + __device__ rowwise_adagrad_optimizer_t(const rowwise_adagrad_kernel_arg_t& karg_) + : karg(karg_) + { + } + + template + __device__ void update(cache_t* acc, emb_t* weight, uint32_t row_index) + { + if constexpr(segment_split == 0) + { + cache_t * p_momentum = reinterpret_cast(karg.p_momentum); + cache_t momentum = p_momentum[row_index]; // should be s_load + // compute per row square sum + cache_t local_sum_squre = .0f; + if constexpr(weight_decay_mode == 1) + { +#pragma unroll + for(auto i = 0; i < thread_length; i++) + { + cache_t w = static_cast(weight[i]); + cache_t a = acc[i] + w * karg.weight_decay; + local_sum_squre += a * a; + } + } + else + { +#pragma unroll + for(auto i = 0; i < thread_length; i++) + { + cache_t a = acc[i]; + local_sum_squre += a * a; + } + } + + cache_t avg_square = + wave_reduce, cache_t, AMDGCN_WAVE_SIZE>(local_sum_squre) / + embedding_dim; + + cache_t momentum_new = momentum + avg_square; + + cache_t multiplier = karg.learning_rate / (sqrtf(momentum_new) + karg.eps); + cache_t correction; + + if constexpr(weight_decay_mode == 1) + { + correction = 1.0 - multiplier * karg.weight_decay; + } + else if constexpr(weight_decay_mode == 2) + { + correction = 1.0 - karg.learning_rate * karg.weight_decay; + } + else + { + correction = 1.0; + } + +// update new weight value +#pragma unroll + for(auto i = 0; i < thread_length; i++) + { + cache_t w = static_cast(weight[i]); + cache_t a = acc[i]; + w = correction * w - multiplier * a; + weight[i] = static_cast(w); + } + + p_momentum[row_index] = momentum_new; + } + } + + rowwise_adagrad_kernel_arg_t karg; +}; + +template +__device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( + const grad_t* p_output_grad, + emb_t* p_emb_table, + const int64_t* p_hash_size_cumsum, + const int64_t* p_sorted_linear_indices_run, + const int32_t* p_sorted_linear_indices_cumulative_run_lengths, + const int32_t* p_sorted_linear_indices_num_runs, + {%- if not nobag %} + const int32_t info_B_num_bits, + const uint32_t info_B_mask, + {%- endif %} + {%- if not nobag %} + const int32_t* p_sorted_infos, + {%- else %} + const int64_t* p_sorted_infos, + {%- endif %} + magic_div_u32_t batch_mdiv, + uint32_t max_segment_length_per_warp, + uint32_t emb_dim, + uint32_t batch, + uint32_t num_rows, + uint32_t num_tables, + optimizer_karg_t opt_karg, + const float * p_sorted_indice_weights = nullptr) +{ + constexpr uint32_t dword_per_row = (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; + constexpr uint32_t waves_per_block = block_size / AMDGCN_WAVE_SIZE; + constexpr uint32_t length_mask = ~(segment_unroll - 1); + const uint32_t wave_id = __builtin_amdgcn_readfirstlane(threadIdx.x / AMDGCN_WAVE_SIZE); + const uint32_t lane_id = threadIdx.x % AMDGCN_WAVE_SIZE; + const uint32_t run_id = wave_id + blockIdx.x * waves_per_block; + + if(run_id >= p_sorted_linear_indices_num_runs[0]) + { + return; + } + + const int64_t linear_index = p_sorted_linear_indices_run[run_id]; + + const int32_t segment_start = p_sorted_linear_indices_cumulative_run_lengths[run_id]; + const int32_t segment_end = p_sorted_linear_indices_cumulative_run_lengths[run_id + 1]; + + {%- if nobag %} + const auto info_0 = p_sorted_infos[segment_start]; + int32_t t_0 = info_0 % num_tables; + {%- else %} + const auto info_0 = reinterpret_cast(&p_sorted_infos[0])[segment_start]; + const auto t_0 = info_0 >> info_B_num_bits; + {%- endif %} + int64_t hash_size = p_hash_size_cumsum[t_0]; + + const int64_t emb_idx = linear_index - hash_size; + + p_emb_table += hash_size * emb_dim; + opt_karg.p_momentum = reinterpret_cast(reinterpret_cast(opt_karg.p_momentum) + hash_size); + + const int32_t segment_length = segment_end - segment_start; + + if(segment_length >= max_segment_length_per_warp) + return; + + const int32_t segment_length_mod = segment_length & length_mask; + + cache_t grad_acc[dword_per_row]; + int32_t infos[segment_unroll]; + grad_t grad_data[dword_per_row * segment_prefetch]; + emb_t emb_data[dword_per_row]; + float indice_weights[segment_unroll]; + + #pragma unroll + for(int i=0; i < dword_per_row; i++) + { + grad_acc[i] = .0f; + } + + int itr = 0; + if(segment_length_mod == 0) + goto L_tail_grad_acc; + + if constexpr (!weighted) { + #pragma unroll + for(int i = 0; i < segment_unroll; i++) + { + infos[i] = p_sorted_infos[segment_start + i]; + } + } else { + for(int i = 0; i < segment_unroll; i++) + { + infos[i] = p_sorted_infos[segment_start + i]; + indice_weights[i] = p_sorted_indice_weights[segment_start + i]; + } + } + + itr += segment_unroll; + p_sorted_infos += segment_unroll; + + if constexpr (weighted) { + p_sorted_indice_weights += segment_unroll; + } + + uint32_t bag_index; + uint32_t table_index; + + // LOOP + for(; itr < segment_length_mod; itr += segment_unroll) + { + {%- if nobag %} + magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); + {%- else %} + table_index = infos[0] >> info_B_num_bits; + bag_index = infos[0] & info_B_mask; + {%- endif %} + load_row_per_warp::run( + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + + {%- if nobag %} + magic_div_u32_run_with_mod(batch_mdiv, infos[1], batch, table_index, bag_index); + {%- else %} + table_index = infos[1] >> info_B_num_bits; + bag_index = infos[1] & info_B_mask; + {%- endif %} + load_row_per_warp::run( + &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + if constexpr (!weighted){ + #pragma unroll + for(int j = 2; j < segment_unroll; j += 2) + { + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id); + {%- if nobag %} + magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); + {%- else %} + table_index = infos[j] >> info_B_num_bits; + bag_index = infos[j] & info_B_mask; + {%- endif %} + load_row_per_warp::run( + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[dword_per_row], lane_id); + {%- if nobag %} + magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); + {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; + bag_index = infos[j + 1] & info_B_mask; + {%- endif %} + load_row_per_warp::run( + &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + } + + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id); + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[dword_per_row], lane_id); + + #pragma unroll + for(int i = 0; i < segment_unroll; i++) + { + infos[i] = p_sorted_infos[segment_start + i]; + } + p_sorted_infos += segment_unroll; + + + } else { + #pragma unroll + for(int j = 2; j < segment_unroll; j += 2) + { + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id, indice_weights[j-2]); + {%- if nobag %} + magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); + {%- else %} + table_index = infos[j] >> info_B_num_bits; + bag_index = infos[j] & info_B_mask; + {%- endif %} + load_row_per_warp::run( + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); + {%- if nobag %} + magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); + {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; + bag_index = infos[j + 1] & info_B_mask; + {%- endif %} + load_row_per_warp::run( + &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + } + + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id, indice_weights[segment_unroll-2]); + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[segment_unroll-1]); + + #pragma unroll + for(int i = 0; i < segment_unroll; i++) + { + infos[i] = p_sorted_infos[segment_start + i]; + indice_weights[i] = p_sorted_indice_weights[segment_start + i]; + } + p_sorted_infos += segment_unroll; + p_sorted_indice_weights += segment_unroll; + } + } + + // LAST + {%- if nobag %} + magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); + {%- else %} + table_index = infos[0] >> info_B_num_bits; + bag_index = infos[0] & info_B_mask; + {%- endif %} + load_row_per_warp::run( + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + + {%- if nobag %} + magic_div_u32_run_with_mod(batch_mdiv, infos[1], batch, table_index, bag_index); + {%- else %} + table_index = infos[1] >> info_B_num_bits; + bag_index = infos[1] & info_B_mask; + {%- endif %} + load_row_per_warp::run( + &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + + if constexpr (!weighted) { + #pragma unroll + for(int j = 2; j < segment_unroll; j += 2) + { + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id); + {%- if nobag %} + magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); + {%- else %} + table_index = infos[j] >> info_B_num_bits; + bag_index = infos[j] & info_B_mask; + {%- endif %} + load_row_per_warp::run( + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[dword_per_row], lane_id); + {%- if nobag %} + magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); + {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; + bag_index = infos[j + 1] & info_B_mask; + {%- endif %} + load_row_per_warp::run( + &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + } + + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id); + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[dword_per_row], lane_id); + } else { + #pragma unroll + for(int j = 2; j < segment_unroll; j += 2) + { + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id, indice_weights[j-2]); + {%- if nobag %} + magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); + {%- else %} + table_index = infos[j] >> info_B_num_bits; + bag_index = infos[j] & info_B_mask; + {%- endif %} + load_row_per_warp::run( + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); + {%- if nobag %} + magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); + {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; + bag_index = infos[j + 1] & info_B_mask; + {%- endif %} + load_row_per_warp::run( + &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + } + + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id, indice_weights[segment_unroll-2]); + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[segment_unroll-1]); + } + +L_tail_grad_acc: + if(segment_length & (segment_unroll - 1)) + { + if constexpr (!weighted){ + // last, load one by one + do + { + infos[0] = p_sorted_infos[segment_start]; + p_sorted_infos++; + + {%- if nobag %} + magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); + {%- else %} + table_index = infos[0] >> info_B_num_bits; + bag_index = infos[0] & info_B_mask; + {%- endif %} + load_row_per_warp::run( + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id); + + itr++; + } while(itr < segment_length); + } else { + do + { + infos[0] = p_sorted_infos[segment_start]; + indice_weights[0] = p_sorted_indice_weights[segment_start]; + p_sorted_infos++; + p_sorted_indice_weights++; + + {%- if nobag %} + magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); + {%- else %} + table_index = infos[0] >> info_B_num_bits; + bag_index = infos[0] & info_B_mask; + {%- endif %} + load_row_per_warp::run( + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id, indice_weights[0]); + + itr++; + } while(itr < segment_length); + } + } + + // load the old emb weight data + load_row_per_warp::run( + &emb_data[0], emb_idx, p_emb_table, lane_id); + optimizer_t optimizer(opt_karg); + optimizer.template update(grad_acc, emb_data, emb_idx); + + store_row_per_warp::run(&emb_data[0], p_emb_table + emb_idx * embedding_dim, lane_id); +} +} // namespace fbgemm_gpu::rocm \ No newline at end of file diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index 01785a00b..d8667abe0 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -3565,6 +3565,7 @@ def __init__( torch.tensor(D_offsets, device=self.current_device, dtype=torch.int32), ) assert self.D_offsets.numel() == T + 1 + # Required for VBE self.register_buffer( "feature_dims", diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py index eb716ea6d..067e8c89f 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py @@ -180,7 +180,6 @@ def __init__( "D_offsets", torch.tensor(D_offsets, device=self.current_device, dtype=torch.int32), ) - assert self.D_offsets.numel() == T + 1 hash_size_cumsum = [0] + list(itertools.accumulate(rows)) if hash_size_cumsum[-1] == 0: diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h new file mode 100644 index 000000000..b55fd72fc --- /dev/null +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h @@ -0,0 +1,51 @@ +/******************************************************************************* + * Copyright (c) 2016 - 2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + ******************************************************************************/ +#pragma once + +#include +#include +#include + +#define HIP_CHECK(c) \ + { \ + if (c != hipSuccess) { \ + printf("HIP Error : %s", hipGetErrorString(c)); \ + printf(" %s %d\n", __FILE__, __LINE__); \ + exit(c); \ + } \ + } + +namespace fbgemm_gpu::rocm { + +[[nodiscard]] inline bool is_supported_cdna() { + const std::set supported_archs{"gfx942", "gfx90a"}; + int device_id = 0; + HIP_CHECK(hipGetDevice(&device_id)); + hipDeviceProp_t dev_props; + HIP_CHECK(hipGetDeviceProperties(&dev_props, device_id)); + std::string gcn_arch = dev_props.gcnArchName; + gcn_arch = gcn_arch.substr(0, gcn_arch.find(":")); + return supported_archs.contains(gcn_arch); +} + +} // namespace fbgemm_gpu::rocm diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h new file mode 100644 index 000000000..b3a56c4b5 --- /dev/null +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -0,0 +1,550 @@ +/******************************************************************************* + * Copyright (c) 2016 - 2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + ******************************************************************************/ +#pragma once +#include +#include +#include + +/******************************************************************************/ +typedef int32_t int32x4_t __attribute__((ext_vector_type(4))); +typedef float floatx2_t __attribute__((ext_vector_type(2))); +#define AMDGCN_BUFFER_RES_3 0x00027000 +#define AMDGCN_WAVE_SIZE 64 +#define THREADS_PER_ROW 64 +#define BLOCK_SIZE 256 + +namespace fbgemm_gpu::rocm { +template +union amdgcn_buffer_resource { + // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions + int32x4_t content; + struct { + T* address; + int32_t range; + int32_t config; + }; +}; + +template +__device__ int32x4_t amdgcn_make_buffer_resource(const T* addr) { + amdgcn_buffer_resource buffer_resource; + buffer_resource.address = const_cast(addr); + buffer_resource.range = 0xffffffff; + buffer_resource.config = AMDGCN_BUFFER_RES_3; // for gfx9 + + return buffer_resource.content; +} + +// buffer load fp32 +__device__ half llvm_amdgcn_raw_buffer_load_fp16( + int32x4_t srsrc, + int32_t voffset, + int32_t soffset, + int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16"); + +__device__ float llvm_amdgcn_raw_buffer_load_fp32( + int32x4_t srsrc, + int32_t voffset, + int32_t soffset, + int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32"); + +__device__ half2 llvm_amdgcn_raw_buffer_load_fp16x2( + int32x4_t srsrc, + int32_t voffset, + int32_t soffset, + int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16"); + +__device__ void llvm_amdgcn_raw_buffer_store_fp32( + float vdata, + int32x4_t rsrc, + int32_t voffset, + int32_t soffset, + int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32"); + +__device__ void llvm_amdgcn_raw_buffer_store_fp32x2( + floatx2_t vdata, + int32x4_t rsrc, + int32_t voffset, + int32_t soffset, + int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32"); + +/******************************************************************************/ + +template +struct load_row_per_warp { + static __device__ void run( + emb_t* emb_data, + index_t row_index, + const emb_t* p_emb_table, + int lane_id) {} +}; + +template +struct load_row_per_warp { + static constexpr int dword_per_row = + (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; + static __device__ void run( + float* emb_data, + index_t row_index, + const float* p_emb_table, + int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * embedding_dim); +#pragma unroll + for (int i = 0; i < dword_per_row; i++) { + if constexpr (embedding_dim == 160) { + if ((lane_id + i * THREADS_PER_ROW) % 192 < 160) { + emb_data[i] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + i * THREADS_PER_ROW) * sizeof(float), 0, 0); + } else { + emb_data[i] = 0.f; + } + } else { + emb_data[i] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + i * THREADS_PER_ROW) * sizeof(float), 0, 0); + } + } + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 64); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp16(emb_res, lane_id * sizeof(half), 0, 0); + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 128); + *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, lane_id * sizeof(half2), 0, 0); + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 192); + *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, lane_id * sizeof(half2), 0, 0); + if ((lane_id + 128) % 192 < 160) { + emb_data[2] = llvm_amdgcn_raw_buffer_load_fp16( + emb_res, (lane_id + 128) * sizeof(half), 0, 0); + } else { + emb_data[2] = __float2half(0.0); + } + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 192); + *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, lane_id * sizeof(half2), 0, 0); + emb_data[2] = llvm_amdgcn_raw_buffer_load_fp16( + emb_res, (lane_id + 128) * sizeof(half), 0, 0); + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 256); + *reinterpret_cast(&emb_data[0]) = + llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, lane_id * sizeof(half2), 0, 0); + *reinterpret_cast(&emb_data[2]) = + llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, (lane_id + 64) * sizeof(half2), 0, 0); + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 512); + *reinterpret_cast(&emb_data[0]) = + llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, lane_id * sizeof(half2), 0, 0); + *reinterpret_cast(&emb_data[2]) = + llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, (lane_id + 64) * sizeof(half2), 0, 0); + *reinterpret_cast(&emb_data[4]) = + llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, (lane_id + 64 * 2) * sizeof(half2), 0, 0); + *reinterpret_cast(&emb_data[6]) = + llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, (lane_id + 64 * 3) * sizeof(half2), 0, 0); + } +}; + +template < + typename emb_t, + int32_t embedding_dim, + typename output_t, + bool weighted> +struct accumulate_row_per_warp { + static constexpr int dword_per_row = + (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; + static __device__ void + run(output_t* acc, emb_t* emb_data, int lane_id, float row_weight = 1.0) { + if constexpr (!weighted) { +#pragma unroll + for (int i = 0; i < dword_per_row; i++) { + acc[i] += static_cast(emb_data[i]); + } + } else { +#pragma unroll + for (int i = 0; i < dword_per_row; i++) { + acc[i] += static_cast((float)emb_data[i] * row_weight); + } + } + } +}; + +template +struct store_row_per_warp { + static constexpr int dword_per_row = + (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; + static __device__ void run(output_t* acc, output_t* p_output, int lane_id) { + if constexpr (embedding_dim == 160) { + for (int i = 0; i < dword_per_row; i++) { + if ((lane_id + i * THREADS_PER_ROW) % 192 < 160) { + p_output[lane_id + i * THREADS_PER_ROW] = acc[i]; + } + } + } else { +#pragma unroll + for (int i = 0; i < dword_per_row; i++) { + p_output[lane_id + i * THREADS_PER_ROW] = acc[i]; + } + } + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(float* acc, float* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp32x2( + *reinterpret_cast(acc), + out_res, + lane_id * sizeof(floatx2_t), + 0, + 0); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(float* acc, float* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp32x2( + *reinterpret_cast(acc), + out_res, + lane_id * sizeof(floatx2_t), + 0, + 0); + if ((lane_id + 128) % 192 < 160) { + llvm_amdgcn_raw_buffer_store_fp32( + acc[2], out_res, (lane_id + 128) * sizeof(float), 0, 0); + } + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(float* acc, float* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp32x2( + *reinterpret_cast(acc), + out_res, + lane_id * sizeof(floatx2_t), + 0, + 0); + llvm_amdgcn_raw_buffer_store_fp32( + acc[2], out_res, (lane_id + 128) * sizeof(float), 0, 0); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(float* acc, float* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp32x2( + *reinterpret_cast(acc), + out_res, + lane_id * sizeof(floatx2_t), + 0, + 0); + llvm_amdgcn_raw_buffer_store_fp32x2( + *reinterpret_cast(&acc[2]), + out_res, + (lane_id + 64) * sizeof(floatx2_t), + 0, + 0); + } +}; + +// Helper function to pack fp16 and fp32 into int to further pass +// into mov_dpp and readfirstlane() +template + requires( + (sizeof(to_t) == 4 || sizeof(to_t) == 2) && + (sizeof(from_t) == 4 || sizeof(from_t) == 2)) +__device__ to_t pack(const from_t& v) { + to_t result = 0; + if constexpr (sizeof(to_t) == sizeof(from_t)) { + result = __builtin_bit_cast(to_t, v); + return result; + } + + memcpy(&result, &v, 2); + + return result; +} + +namespace reduce_op { +struct sum {}; +struct sub {}; +struct mul {}; +struct div {}; +} // namespace reduce_op + +template +struct reduce_op_sum_t { + __device__ data_t operator()(const data_t& a, const data_t& b) { + return a + b; + } +}; + +#define DPP_REDUCE(OP, TYPE) \ + __asm__ volatile( \ + "v_nop\n" \ + "v_nop\n" \ + "v_nop\n" \ + "v_nop\n" \ + "v_" #OP "_" #TYPE \ + "_dpp %0 %0 %0 quad_perm:[1,0,3,2]\n" \ + "v_nop\n" \ + "v_nop\n" \ + "v_" #OP "_" #TYPE \ + "_dpp %0 %0 %0 quad_perm:[2,3,0,1]\n" \ + "v_nop\n" \ + "v_nop\n" \ + "v_" #OP "_" #TYPE \ + "_dpp %0 %0 %0 row_shr:4\n" \ + "v_nop\n" \ + "v_nop\n" \ + "v_" #OP "_" #TYPE \ + "_dpp %0 %0 %0 row_shr:8\n" \ + "v_nop\n" \ + "v_nop\n" \ + "v_" #OP "_" #TYPE \ + "_dpp %0 %0 %0 row_bcast:15\n" \ + "v_nop\n" \ + "v_nop\n" \ + "v_" #OP "_" #TYPE \ + "_dpp %0 %0 %0 row_bcast:31\n" \ + "v_nop\n" \ + "v_nop\n" \ + : "=v"(result) \ + : "0"(result)) + +#define DPP_REDUCE_F16_F32(OP) \ + if constexpr (std::is_same_v) { \ + DPP_REDUCE(OP, f32); \ + } \ + \ + if constexpr (std::is_same_v) { \ + DPP_REDUCE(OP, f16); \ + } + +template +__device__ __forceinline__ void generic_dpp_reduction(data_t& result) { + constexpr int row_mask = 0xf; + constexpr int bank_mask = 0xf; + constexpr bool bound_ctrl = false; + + reduce_op_t reduce_op; + + if constexpr (wave_size > 1) { + result = reduce_op( + result, + pack(__builtin_amdgcn_mov_dpp( + pack(result), + 0xb1, + row_mask, + bank_mask, + bound_ctrl))); // quad_perm:[1,0,3,2] + } + if constexpr (wave_size > 2) { + result = reduce_op( + result, + pack(__builtin_amdgcn_mov_dpp( + pack(result), + 0x4e, + row_mask, + bank_mask, + bound_ctrl))); // quad_perm:[2,3,0,1] + } + if constexpr (wave_size > 4) { + result = reduce_op( + result, + pack(__builtin_amdgcn_mov_dpp( + pack(result), + 0x114, + row_mask, + bank_mask, + bound_ctrl))); // row_shr:4 + } + if constexpr (wave_size > 8) { + result = reduce_op( + result, + pack(__builtin_amdgcn_mov_dpp( + pack(result), + 0x118, + row_mask, + bank_mask, + bound_ctrl))); // row_shr:8 + } + if constexpr (wave_size > 16) { + result = reduce_op( + result, + pack(__builtin_amdgcn_mov_dpp( + pack(result), + 0x142, + row_mask, + bank_mask, + bound_ctrl))); // row_bcast:15 + } + if constexpr (wave_size > 32) { + result = reduce_op( + result, + pack(__builtin_amdgcn_mov_dpp( + pack(result), + 0x143, + row_mask, + bank_mask, + bound_ctrl))); // row_bcast:31 + } +} + +// Use corresponding assebly instruction for dpp reduction in case +// of trivial operation with an option to use custom operation +template +__device__ __forceinline__ void dpp_reduction(data_t& result) { +#if defined(__gfx942__) || defined(__gfx90a__) + if constexpr (std::is_same_v) { + DPP_REDUCE_F16_F32(add); + return; + } else if constexpr (std::is_same_v) { + DPP_REDUCE_F16_F32(sub); + return; + } else if constexpr (std::is_same_v) { + DPP_REDUCE_F16_F32(mul); + return; + } else if constexpr (std::is_same_v) { + DPP_REDUCE_F16_F32(div); + return; + } else { + generic_dpp_reduction(result); + } +#endif +} + +template +__device__ inline data_t wave_reduce(const data_t& thread_data) { + data_t result = thread_data; + + // now the reduced value is in the last lane of wave + dpp_reduction(result); + return pack( + __builtin_amdgcn_readlane(pack(result), wave_size - 1)); +} + +struct rowwise_adagrad_kernel_arg_t { + void* p_momentum; + float eps; + float learning_rate; + float weight_decay; + int64_t weight_decay_mode; +}; + +typedef struct { + uint32_t magic; + uint32_t shift; // actually 8 bit is enough +} magic_div_u32_t; + +static inline magic_div_u32_t magic_div_u32_gen(uint32_t d) { + assert(d >= 1 && d <= INT32_MAX); + uint8_t shift; + for (shift = 0; shift < 32; shift++) + if ((1U << shift) >= d) + break; + + uint64_t one = 1; + uint64_t magic = ((one << 32) * ((one << shift) - d)) / d + 1; + assert(magic <= 0xffffffffUL); + + magic_div_u32_t result; + result.magic = magic; + result.shift = shift; + return result; +} + +// numer / denom = quotient, reminder +__device__ inline uint32_t magic_div_u32_run( + const magic_div_u32_t& mdiv, + const uint32_t& n) { + uint32_t tmp = __umulhi(n, mdiv.magic); + return (tmp + n) >> mdiv.shift; +} + +__device__ inline void magic_div_u32_run_with_mod( + const magic_div_u32_t& mdiv, + const uint32_t& n, + const uint32_t d, + uint32_t& quo, + uint32_t& rem) { + quo = magic_div_u32_run(mdiv, n); + rem = n - quo * d; +} +} // namespace fbgemm_gpu::rocm diff --git a/fbgemm_gpu/test/tbe/cache/cache_common.py b/fbgemm_gpu/test/tbe/cache/cache_common.py index f74418669..48b1df66e 100644 --- a/fbgemm_gpu/test/tbe/cache/cache_common.py +++ b/fbgemm_gpu/test/tbe/cache/cache_common.py @@ -33,11 +33,12 @@ if open_source: # pyre-ignore[21] - from test_utils import gpu_unavailable, optests, running_on_rocm + from test_utils import gpu_unavailable, optests, running_on_github, running_on_rocm else: from fbgemm_gpu.test.test_utils import ( # noqa: F401 gpu_unavailable, # noqa: F401 optests, # noqa: F401 + running_on_github, # noqa: F401 running_on_rocm, # noqa: F401 ) diff --git a/fbgemm_gpu/test/tbe/cache/cache_test.py b/fbgemm_gpu/test/tbe/cache/cache_test.py index 6250c529a..a19579bd9 100644 --- a/fbgemm_gpu/test/tbe/cache/cache_test.py +++ b/fbgemm_gpu/test/tbe/cache/cache_test.py @@ -43,6 +43,7 @@ generate_cache_tbes, gpu_unavailable, optests, + running_on_github, running_on_rocm, TestingStatsReporter, TestingStatsReporterConfig, @@ -77,6 +78,7 @@ def _compute_grad_output_shape( @optests.dontGenerateOpCheckTests("Serial OOM") @unittest.skipIf(*gpu_unavailable) + @unittest.skipIf(*running_on_github) @unittest.skipIf(*running_on_rocm) @given( T=st.integers(min_value=1, max_value=5), @@ -450,6 +452,7 @@ def assert_event_not_exist(event_name: str) -> None: @optests.dontGenerateOpCheckTests("Serial OOM") @unittest.skipIf(*gpu_unavailable) + @unittest.skipIf(*running_on_github) @unittest.skipIf(*running_on_rocm) @given( T=st.integers(min_value=1, max_value=5), @@ -478,6 +481,7 @@ def test_cache_prefetch_pipeline( @optests.dontGenerateOpCheckTests("Serial OOM") @unittest.skipIf(*gpu_unavailable) + @unittest.skipIf(*running_on_github) @unittest.skipIf(*running_on_rocm) @given( T=st.integers(min_value=1, max_value=5), @@ -507,6 +511,7 @@ def test_cache_prefetch_pipeline_stream_1( @optests.dontGenerateOpCheckTests("Serial OOM") @unittest.skipIf(*gpu_unavailable) + @unittest.skipIf(*running_on_github) @unittest.skipIf(*running_on_rocm) @given( T=st.integers(min_value=1, max_value=5), @@ -588,6 +593,7 @@ def test_get_prefetch_passes( self.assertTrue(torch.equal(torch.full_like(output_tensor, 1), output_tensor)) @unittest.skipIf(*gpu_unavailable) + @unittest.skipIf(*running_on_github) @given( L=st.integers(min_value=0, max_value=16), H=st.integers(min_value=512, max_value=1024), diff --git a/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py b/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py index cf7f0cbd8..5e65e40bf 100644 --- a/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py +++ b/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py @@ -58,6 +58,7 @@ additional_decorators, gpu_unavailable, optests, + skipIfNotRocm, TEST_WITH_ROCM, use_cpu_strategy, ) @@ -66,6 +67,7 @@ additional_decorators, gpu_unavailable, optests, + skipIfNotRocm, TEST_WITH_ROCM, use_cpu_strategy, ) @@ -1080,6 +1082,80 @@ def test_backward_optimizers_adagrad( # noqa C901 weight_decay_mode, ) + @given( + T=st.integers(min_value=1, max_value=5), + D=st.sampled_from([16, 32, 40, 48, 64]), + B=st.integers(min_value=1, max_value=128), + log_E=st.integers(min_value=3, max_value=5), + L=st.integers(min_value=2, max_value=20), + weighted=st.booleans(), + mixed=st.just(False), + mixed_B=st.just(False), + optimizer=st.sampled_from( + [ + OptimType.EXACT_ROWWISE_ADAGRAD, + ] + ), + long_segments=st.booleans(), + pooling_mode=st.sampled_from( + [ + PoolingMode.SUM, + ] + ), + use_cpu=st.just(False), + weight_decay_mode=st.sampled_from( + [ + WeightDecayMode.NONE, + WeightDecayMode.L2, + WeightDecayMode.DECOUPLE, + ] + ), + ) + @settings( + verbosity=VERBOSITY, + max_examples=MAX_EXAMPLES_LONG_RUNNING, + deadline=None, + suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], + ) + @unittest.skipIf(*gpu_unavailable) + @skipIfNotRocm("Test only evaluates ROCm optimized kernels") + def test_new_bwd_kernel( # noqa C901 + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + weighted: bool, + mixed: bool, + mixed_B: bool, + optimizer: OptimType, + long_segments: bool, + pooling_mode: PoolingMode, + use_cpu: bool, + weight_decay_mode: WeightDecayMode, + ) -> None: + if ( + pooling_mode == PoolingMode.NONE + or optimizer != OptimType.EXACT_ROWWISE_ADAGRAD + ): + mixed_B = False + self.execute_backward_optimizers_( + T, + D, + B, + log_E, + L, + weighted, + mixed, + mixed_B, + optimizer, + long_segments, + pooling_mode, + use_cpu, + weight_decay_mode, + ) + @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256), diff --git a/fbgemm_gpu/test/test_utils.py b/fbgemm_gpu/test/test_utils.py index 853b2d070..e073f7a38 100644 --- a/fbgemm_gpu/test/test_utils.py +++ b/fbgemm_gpu/test/test_utils.py @@ -254,6 +254,26 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: return decorator +# pyre-fixme[3]: Return annotation cannot be `Any`. +def skipIfNotRocm( + reason: str = "Test currently doesn work only on the ROCm stack", +) -> Any: + # pyre-fixme[3]: Return annotation cannot be `Any`. + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. + def decorator(fn: Callable) -> Any: + @wraps(fn) + # pyre-fixme[3]: Return annotation cannot be `Any`. + def wrapper(*args: Any, **kwargs: Any) -> Any: + if TEST_WITH_ROCM: + fn(*args, **kwargs) + else: + raise unittest.SkipTest(reason) + + return wrapper + + return decorator + + # pyre-fixme[3]: Return annotation cannot be `Any`. def skipIfRocmLessThan(min_version: int) -> Any: # pyre-fixme[3]: Return annotation cannot be `Any`. From 62f9db77c3a8c09cd167cc18930a49e94537ff5e Mon Sep 17 00:00:00 2001 From: Sarunya Pumma Date: Wed, 18 Dec 2024 13:57:22 -0800 Subject: [PATCH 04/16] Fix grid size overflow in generate_vbe_metadata (#3484) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3484 X-link: https://github.com/facebookresearch/FBGEMM/pull/565 Use 3D grid to reduce the risk of running into grid size overflow in generate_vbe_metadata Reviewed By: r-barnes Differential Revision: D66948760 fbshipit-source-id: 505d9b72e0d74d1707e4aa0ab9af48f26cf18b4a --- .../generate_vbe_metadata.cu | 85 ++++++++++++------- 1 file changed, 55 insertions(+), 30 deletions(-) diff --git a/fbgemm_gpu/src/split_embeddings_utils/generate_vbe_metadata.cu b/fbgemm_gpu/src/split_embeddings_utils/generate_vbe_metadata.cu index 81672f369..17905e0e1 100644 --- a/fbgemm_gpu/src/split_embeddings_utils/generate_vbe_metadata.cu +++ b/fbgemm_gpu/src/split_embeddings_utils/generate_vbe_metadata.cu @@ -33,27 +33,15 @@ __launch_bounds__(kMaxThreads) void generate_vbe_metadata_foreach_sample_kernel( D_offsets, const int32_t D, const bool nobag, - FixedDivisor fd_max_B, - FixedDivisor fd_max_B_T, const int32_t info_B_num_bits) { - const auto r_b_t = blockIdx.x * blockDim.x + threadIdx.x; - const auto T = B_offsets.size(0) - 1; // Num tables - const auto R = B_offsets_rank_per_feature.size(1) - 1; // Num ranks - - int32_t b_t; - int32_t r; // Rank ID - int32_t t; // Table ID - int32_t b; // Relative sample ID in the rank-table matrix - - fd_max_B_T.DivMod(r_b_t, &r, &b_t); - if (r >= R) { - return; - } - - fd_max_B.DivMod(b_t, &t, &b); - if (t >= T) { - return; - } + // Relative sample ID in the rank-table matrix + const auto b = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + // Rank ID + const auto r = blockIdx.y; + // Table ID + const auto t = blockIdx.z; + // Num tables + const auto T = B_offsets.size(0) - 1; const auto B_start_r_t = B_offsets_rank_per_feature[t][r]; const auto B_r_t = B_offsets_rank_per_feature[t][r + 1] - B_start_r_t; @@ -61,22 +49,36 @@ __launch_bounds__(kMaxThreads) void generate_vbe_metadata_foreach_sample_kernel( return; } + const auto* __restrict__ output_offsets_feature = + &output_offsets_feature_rank[r * T]; + const auto B_start_t = B_offsets[t]; - // Update b_t - b_t = B_start_t + B_start_r_t + b; - const auto D_ = nobag ? D : D_offsets[t + 1] - D_offsets[t]; - row_output_offsets[b_t] = output_offsets_feature_rank[r * T + t] + b * D_; + const auto b_t = + static_cast(B_start_t) + static_cast(B_start_r_t) + b; + const auto D_ = nobag ? D : (D_offsets[t + 1] - D_offsets[t]); + row_output_offsets[b_t] = + output_offsets_feature[t] + b * static_cast(D_); // Relative sample ID in the table const auto b_ = B_start_r_t + b; // b_t is always positive. *reinterpret_cast(&b_t_map[b_t]) = - (reinterpret_cast(&t)[0] << info_B_num_bits) | + (reinterpret_cast(&t)[0] << info_B_num_bits) | reinterpret_cast(&b_)[0]; } } // namespace +std::tuple get_max_grid_size(int device) { + static auto max_grid = [&]() -> std::tuple { + cudaDeviceProp prop; + C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, at::cuda::current_device())); + return {prop.maxGridSize[0], prop.maxGridSize[1], prop.maxGridSize[2]}; + }(); + + return max_grid; +} + /// Generate VBE metadata namely output_offsets and b_t_map /// /// row_output_offsets A 1D tensor that contains the output offset of each b @@ -121,7 +123,8 @@ generate_vbe_metadata( TENSOR_NDIM_EQUALS(B_offsets_rank_per_feature, 2); TENSOR_NDIM_EQUALS(output_offsets_feature_rank, 1); - const int32_t T = B_offsets.numel() - 1; + const auto T = B_offsets.numel() - 1; + if (!nobag) { TENSOR_ON_CUDA_GPU(D_offsets); TENSORS_ON_SAME_DEVICE(B_offsets, D_offsets); @@ -129,6 +132,14 @@ generate_vbe_metadata( } const auto num_ranks = B_offsets_rank_per_feature.size(1) - 1; + TORCH_CHECK( + num_ranks > 0, "generate_vbe_metadata: Invalid num_ranks ", num_ranks); + TORCH_CHECK(T > 0, "generate_vbe_metadata: Invalid T ", T); + TORCH_CHECK( + max_B_feature_rank > 0, + "generate_vbe_metadata: Invalid max_B_feature_rank ", + max_B_feature_rank); + TORCH_CHECK(B_offsets_rank_per_feature.size(0) == T); TORCH_CHECK(output_offsets_feature_rank.numel() == num_ranks * T + 1); @@ -138,13 +149,29 @@ generate_vbe_metadata( at::empty({total_B}, output_offsets_feature_rank.options()); Tensor b_t_map = at::empty({total_B}, B_offsets.options()); + const auto grid_dim_x = div_round_up(max_B_feature_rank, kMaxThreads); + const dim3 grid_size(grid_dim_x, num_ranks, T); + const auto& [max_grid_x, max_grid_y, max_grid_z] = + get_max_grid_size(at::cuda::current_device()); + TORCH_CHECK( + grid_size.x > 0 && grid_size.x <= max_grid_x, + "generate_vbe_metadata: Invalid grid_size.x ", + grid_size.x); + TORCH_CHECK( + grid_size.y > 0 && grid_size.y <= max_grid_y, + "generate_vbe_metadata: Invalid grid_size.y ", + grid_size.y); + TORCH_CHECK( + grid_size.z > 0 && grid_size.z <= max_grid_z, + "generate_vbe_metadata: Invalid grid_size.z ", + grid_size.z); + #ifdef FBGEMM_GPU_MEMCHECK const auto func_name = "generate_vbe_metadata_foreach_sample_kernel"; #endif - // Over allocate total number of threads to avoid using binary search generate_vbe_metadata_foreach_sample_kernel<<< - div_round_up(max_B_feature_rank * T * num_ranks, kMaxThreads), + grid_size, kMaxThreads, 0, at::cuda::getCurrentCUDAStream()>>>( @@ -157,8 +184,6 @@ generate_vbe_metadata( MAKE_PTA_WITH_NAME(func_name, D_offsets, int32_t, 1, 32), D, nobag, - FixedDivisor(max_B_feature_rank), - FixedDivisor(max_B_feature_rank * T), info_B_num_bits); C10_CUDA_KERNEL_LAUNCH_CHECK(); From 0b1739c5321a1d4406ca1642048e122952ec46fa Mon Sep 17 00:00:00 2001 From: Fei Yu Date: Wed, 18 Dec 2024 14:52:06 -0800 Subject: [PATCH 05/16] Support config based bound check version via extended modes (#3454) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3454 X-link: https://github.com/facebookresearch/FBGEMM/pull/538 2/2 of enabling bounds check V2 for APS FM, following APS principles, we would like to surface the V2 switch up to the APS user config, hence in this diff we are extending existing BoundsCheckMode with V2 counterparts, and pass the version flag into the operator. this diff enabled v2 via backward compatible modes update with V2 prefix which is intuitive for user to switch More context can be found in https://docs.google.com/document/d/1hEhk2isMOXuWPyQJxiOzNq0ivfECsZUT7kT_IBmou_I/edit?tab=t.0#heading=h.q89rllowo3eb Reviewed By: sryap Differential Revision: D66512098 fbshipit-source-id: d2181a82462ca1c2c93360d4108766edeb38d000 --- .../split_table_batched_embeddings_ops_common.py | 6 ++++++ ...split_table_batched_embeddings_ops_training.py | 15 +++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py index 069f66b02..82e9c9f06 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py @@ -67,6 +67,12 @@ class BoundsCheckMode(enum.IntEnum): IGNORE = 2 # No bounds checks. NONE = 3 + # IGNORE with V2 enabled + V2_IGNORE = 4 + # WARNING with V2 enabled + V2_WARNING = 5 + # FATAL with V2 enabled + V2_FATAL = 6 class EmbeddingSpecInfo(enum.IntEnum): diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index d8667abe0..85ebd69f2 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -638,6 +638,20 @@ def __init__( # noqa C901 self.pooling_mode = pooling_mode self.is_nobag: bool = self.pooling_mode == PoolingMode.NONE # If environment variable is set, it overwrites the default bounds check mode. + self.bounds_check_version: int = 1 + if bounds_check_mode.name.startswith("V2_"): + self.bounds_check_version = 2 + if bounds_check_mode == BoundsCheckMode.V2_IGNORE: + bounds_check_mode = BoundsCheckMode.IGNORE + elif bounds_check_mode == BoundsCheckMode.V2_WARNING: + bounds_check_mode = BoundsCheckMode.WARNING + elif bounds_check_mode == BoundsCheckMode.V2_FATAL: + bounds_check_mode = BoundsCheckMode.FATAL + else: + raise NotImplementedError( + f"Did not recognize V2 bounds check mode: {bounds_check_mode}" + ) + self.bounds_check_mode_int: int = int( os.environ.get("FBGEMM_TBE_BOUNDS_CHECK_MODE", bounds_check_mode.value) ) @@ -3352,6 +3366,7 @@ def prepare_inputs( b_t_map=b_t_map, info_B_num_bits=info_B_num_bits, info_B_mask=info_B_mask, + bounds_check_version=self.bounds_check_version, ) return indices, offsets, per_sample_weights, vbe_metadata From 804a499c24650c3ecdbb4f46cea850e80376a92a Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Wed, 18 Dec 2024 16:48:42 -0800 Subject: [PATCH 06/16] Enable dynamic M grouped gemm (#3444) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3444 X-link: https://github.com/facebookresearch/FBGEMM/pull/530 This diff adds support for true dynamic M as is found in grouped_gemm. To do so, we add a new `zero_start_index_M` argument that must be provided by the user and indicates the number of non-zero M in each tensor. One nice thing about this approach is that we can now do a single kernel call to set up the gemm arguments. We make `zero_start_index_M` optional as it requires fixed N and K. When N and K vary across group, we use the previous static shape approach. Reviewed By: bradleyhd, jiawenliu64 Differential Revision: D66682886 fbshipit-source-id: 9c4554dba9becf33fcc87cd1b01266fead716916 --- .../gen_ai/bench/quantize_bench.py | 21 +- .../experimental/gen_ai/bench/quantize_ops.py | 77 ++++++- .../fp8_rowwise_grouped_gemm.hip | 199 ++++++++++++++---- .../gen_ai/src/quantize/quantize.cpp | 3 +- 4 files changed, 247 insertions(+), 53 deletions(-) diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py index 850a8c257..fa8dc2142 100644 --- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py @@ -62,6 +62,7 @@ def benchmark_grouped( kernels: Optional[List[str]] = None, bench_quantize: bool = False, use_rotating_buffer_bench: bool = False, + use_cuda_graph: bool = True, ) -> Dict[str, Any]: num_groups = len(m) # Create input tensors. @@ -92,6 +93,8 @@ def benchmark_grouped( quantized_vals = quantize_op.quantize(A, B) # Compute the output given quantized values. output = quantize_op.compute(*quantized_vals) + # Some kernels may pad output, just take the first m values of each row. + output = [o[: m[i]] for i, o in enumerate(output)] # Compare the quantize op output to reference as a sanity check. sim_check: float = 0 for i in range(num_groups): @@ -107,14 +110,14 @@ def benchmark_grouped( B, bench_quantize=True, use_rotating_buffer_bench=use_rotating_buffer_bench, - use_cuda_graph=True, + use_cuda_graph=use_cuda_graph, ) else: ms_runtime = quantize_op.benchmark( *quantized_vals, bench_quantize=False, use_rotating_buffer_bench=use_rotating_buffer_bench, - use_cuda_graph=True, + use_cuda_graph=use_cuda_graph, ) # Print out results for this op. @@ -124,8 +127,8 @@ def benchmark_grouped( tflops += 2 * b[i] * m[i] * n[i] * k[i] / (ms_runtime / 1e3) / 1e12 gbps += ( ( - quantized_vals[0][i].numel() - * quantized_vals[0][i].element_size() + quantized_vals[0][i][: m[i]].numel() + * quantized_vals[0][i][: m[i]].element_size() + quantized_vals[1][i].numel() * quantized_vals[1][i].element_size() + output[i].numel() * output[i].element_size() @@ -156,6 +159,7 @@ def benchmark( kernels: Optional[List[str]] = None, bench_quantize: bool = False, use_rotating_buffer_bench: bool = False, + use_cuda_graph: bool = True, ) -> Dict[str, Any]: # Create input tensors. if b > 1: @@ -192,12 +196,14 @@ def benchmark( B, bench_quantize=True, use_rotating_buffer_bench=use_rotating_buffer_bench, + use_cuda_graph=use_cuda_graph, ) else: ms_runtime = quantize_op.benchmark( *quantized_vals, bench_quantize=False, use_rotating_buffer_bench=use_rotating_buffer_bench, + use_cuda_graph=use_cuda_graph, ) # Print out results for this op. @@ -316,6 +322,7 @@ def main(args: Any): kernels, args.bench_quantize, args.use_rotating_buffer_bench, + not args.no_cuda_graph, ) benchmark_results.append(quantize_measurements) if args.export_csv: @@ -377,6 +384,12 @@ def invoke_main() -> None: help="If set, do grouped gemm. In this mode, M, N, and K are interpreted " "as the size of groups. The length of each must be the same.", ) + parser.add_argument( + "--no_cuda_graph", + default=False, + action="store_true", + help="If set, do not use cuda graph for benchmarking.", + ) parser.add_argument( "--use_rotating_buffer_bench", default=False, diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py index d30878032..b749e800e 100644 --- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py @@ -9,6 +9,7 @@ from typing import List, Tuple import fbgemm_gpu.experimental.gen_ai # noqa: F401 +import numpy as np import torch import triton # @manual=//triton:triton @@ -467,24 +468,84 @@ class FP8RowwiseGroupedGemm(QuantizeOpBase): FP8 grouped matmul with rowwise scaling. """ + def quantize_fixed_nk(self, x, w): + group_size = len(x) + m_values = [i.shape[0] for i in x] + # Inputs for fixed nk mode must be contiguous, however in the benchmark + # script they typically are not. Do a little special processing to make them + # work. In practice this wont be needed. + # Start by padding along m dimension with zeros. + max_m = max(m_values) + xq = [ + torch.nn.functional.pad(i, (0, 0, 0, max_m - i.shape[0]), value=0) + for i in x + ] + # Stack inputs into groups. + xq = torch.stack(xq).contiguous() + wq = torch.stack(w).contiguous() + # Allocate output tensor. + output = torch.empty( + [xq.shape[0], xq.shape[1], wq.shape[1]], + dtype=torch.bfloat16, + device=xq.device, + ) + # Apply quantization. + xq, x_scale = quantize_fp8_row(xq) + wq, w_scale = quantize_fp8_row(wq) + # View these unified tensors as lists of tensors. + xq = [x.squeeze() for x in xq.split(1, dim=0)] + wq = [w.squeeze() for w in wq.split(1, dim=0)] + output = [o.squeeze() for o in output.split(1, dim=0)] + x_scale = [xs.squeeze() for xs in x_scale.view(group_size, -1).split(1, dim=0)] + w_scale = [ws.squeeze() for ws in w_scale.view(group_size, -1).split(1, dim=0)] + + # Return processed tensors. + return ( + xq, + wq, + x_scale, + w_scale, + torch.tensor(m_values).to(dtype=torch.int32, device=xq[0].device), + output, + ) + def quantize(self, x, w): - # Quantize both input tensors. - # Handle both grouped and standard gemm. assert isinstance( x, (list, tuple) ), "Inputs to group gemm must be a list of tensors." + + # First check if N and K are fixed. + m_values = [i.shape[0] for i in x] + n_values = [i.shape[0] for i in w] + k_values = [i.shape[1] for i in w] + # if so, do specialized version of initialization. + if len(np.unique(n_values)) == 1 and len(np.unique(k_values)) == 1: + return self.quantize_fixed_nk(x, w) + + # Otherwise handle in eager mode. xq, x_scale = zip(*[quantize_fp8_row(i) for i in x]) wq, w_scale = zip(*[quantize_fp8_row(i) for i in w]) - return xq, wq, x_scale, w_scale - - def compute(self, xq, wq, x_scale, w_scale, kernel_name=None): + output = [ + torch.empty(m, n, device=xq[0].device, dtype=torch.bfloat16) + for m, n in zip(m_values, n_values) + ] + m_values = None + return xq, wq, x_scale, w_scale, m_values, output + + def compute(self, xq, wq, x_scale, w_scale, m_values, output, kernel_name=None): return torch.ops.fbgemm.f8f8bf16_rowwise_grouped( - xq, wq, x_scale, w_scale, kernel_name=kernel_name + xq, + wq, + x_scale, + w_scale, + zero_start_index_M=m_values, + output=output, + kernel_name=kernel_name, ) def quantize_and_compute(self, x, w): - xq, wq, x_scale, w_scale = self.quantize(x, w) - return self.compute(xq, wq, x_scale, w_scale) + xq, wq, x_scale, w_scale, m_values, output = self.quantize(x, w) + return self.compute(xq, wq, x_scale, w_scale, m_values, output) @property def name(self) -> str: diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip index 7474e759d..d59e6db95 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip @@ -80,63 +80,164 @@ RowwiseGroupedKernel rowwise_grouped_heuristic_dispatch(int M, int N, int K) { return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1; } -__global__ void set_kernel_args_kernel( +__global__ void set_kernel_args_fixed_nk_kernel( KernelArguments* kernel_args, ADataType* XQ, BDataType* WQ, D0DataType* w_scale, D1DataType* x_scale, EDataType* output, + int32_t* prepad_M, int M, int N, - int K) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - // Each kernel annoyingly can only set the kernel args for one group. - // This could only be avoided with complicated memory management. - if (idx == 0) { - // Write kernel arguments directly to memory. + int K, + int group_count) { + int group_idx = blockIdx.x * blockDim.x + threadIdx.x; + // Each thread is responsible for setting up the arguments for one group. + if (group_idx < group_count) { + // Compute offsets for this group. + int group_M = prepad_M[group_idx]; KernelArguments kernel_group_args = { - XQ, WQ, {w_scale, x_scale}, output, M, N, K, K, K, {0, 0}, N}; - kernel_args[0] = kernel_group_args; + XQ + (group_idx * M * K), + WQ + (group_idx * N * K), + {w_scale + (group_idx * N), x_scale + (group_idx * M)}, + output + (group_idx * M * N), + group_M, + N, + K, + K, + K, + {0, 0}, + N}; + // Write kernel args to memory. + kernel_args[group_idx] = kernel_group_args; } } -void set_grouped_kernel_args( +at::Tensor get_grouped_kernel_args( at::TensorList XQ, at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, - at::Tensor kernel_args, + std::optional zero_start_index_M, std::vector output) { - TORCH_CHECK( - XQ.size() == WQ.size() && XQ.size() == x_scale.size() && - XQ.size() == w_scale.size(), - "All inputs must have the same number of groups."); int group_count = XQ.size(); - // We use the smallest reasonable block size since we effectively need only 1 thread. - int blockSize = 32; - int numBlocks = 1; auto stream = at::cuda::getCurrentHIPStream().stream(); - // Launch a kernel for each group to set kernel memory on device. - for (int i = 0; i < group_count; i++) { - int M = XQ[i].size(0); - int K = XQ[i].size(1); - int N = WQ[i].size(0); - // Launch kernel to set kernel arguments. - set_kernel_args_kernel<<>>( - reinterpret_cast( - reinterpret_cast(kernel_args.data_ptr()) + - (i * sizeof(KernelArguments))), - reinterpret_cast(XQ[i].data_ptr()), - reinterpret_cast(WQ[i].data_ptr()), - reinterpret_cast(w_scale[i].data_ptr()), - reinterpret_cast(x_scale[i].data_ptr()), - reinterpret_cast(output[i].data_ptr()), + // Get space on device for the kernel argument tensor. + at::Tensor kernel_args = at::empty( + {static_cast(group_count * sizeof(KernelArguments))}, + XQ[0].options().dtype(at::kByte)); + + // There are two different modes for this kernel. + // When zero_start_index_M is provided, we assume that data is sequential and + // that N and K are constants. This allows a more efficient kernel + // launch and is best suited to MOE use cases where M is truly dynamic. + // When zero_start_index_M is not provided, we assume M, N, and K can all vary + // and set them for each group. It is important to note that this does not + // work well with cuda graphs and runtime dynamism so if possible we recommend + // using zero_start_index_M. + + if (zero_start_index_M.has_value()) { + // Make sure zero_start_index_M is configured properly. + at::Tensor prepad_M = zero_start_index_M.value(); + // Confirm M is on the proper device. + TORCH_CHECK( + XQ[0].device() == prepad_M.device(), + "zero_start_index_M and inputs must be on the same device."); + TORCH_CHECK( + prepad_M.size(0) == group_count, + "zero_start_index_M must have an entry for each group."); + + // We assume that M, N, and K are fixed across groups. + // The actual m values are sstored in the passed M tensor. + int M = XQ[0].size(0); + int K = XQ[0].size(1); + int N = WQ[0].size(0); + + // Make sure that inputs are allocated in sequential memory as required by + // this mode. + for (int i = 1; i < group_count; i++) { + // Check that all inputs are allocated directly following preceding input. + TORCH_CHECK( + XQ[i].data_ptr() == + (reinterpret_cast(XQ[i - 1].data_ptr()) + (M * K)), + "Inputs must be sequential in memory to support dynamic M, but XQ is not."); + TORCH_CHECK( + WQ[i].data_ptr() == + (reinterpret_cast(WQ[i - 1].data_ptr()) + (N * K)), + "Inputs must be sequential in memory to support dynamic M, but WQ is not."); + TORCH_CHECK( + x_scale[i].data_ptr() == + (reinterpret_cast(x_scale[i - 1].data_ptr()) + (M)), + "Inputs must be sequential in memory to support dynamic M, but x_scale is not."); + TORCH_CHECK( + w_scale[i].data_ptr() == + (reinterpret_cast(w_scale[i - 1].data_ptr()) + (N)), + "Inputs must be sequential in memory to support dynamic M, but w_scale is not."); + TORCH_CHECK( + output[i].data_ptr() == + (reinterpret_cast(output[i - 1].data_ptr()) + + (M * N)), + "Inputs must be sequential in memory to support dynamic M, but output is not."); + } + + // Launch a kernel that sets kernel argument memory. + int const blockSize = std::min(1024, group_count); + int const numBlocks = (group_count + blockSize - 1) / blockSize; + set_kernel_args_fixed_nk_kernel<<>>( + reinterpret_cast(kernel_args.data_ptr()), + reinterpret_cast(XQ[0].data_ptr()), + reinterpret_cast(WQ[0].data_ptr()), + reinterpret_cast(w_scale[0].data_ptr()), + reinterpret_cast(x_scale[0].data_ptr()), + reinterpret_cast(output[0].data_ptr()), + reinterpret_cast(prepad_M.data_ptr()), M, N, - K); + K, + group_count); + return kernel_args; + } else { + // When running in eager mode, we assume we can directly interact with host + // values. + // Note that this version is not supported with cuda graphs. + TORCH_CHECK( + stream == 0, + "f8f8bf16_rowwise_grouped eager mode is not supported with cuda graphs."); + + std::vector ggemm_kargs; + ggemm_kargs.reserve(group_count); + + // Iterate over inputs and get group information. + for (int i = 0; i < group_count; i++) { + int M = XQ[i].size(0); + int K = XQ[i].size(1); + int N = WQ[i].size(0); + KernelArguments group_args = { + reinterpret_cast(XQ[i].data_ptr()), + reinterpret_cast(WQ[i].data_ptr()), + {reinterpret_cast(w_scale[i].data_ptr()), + reinterpret_cast(x_scale[i].data_ptr())}, + reinterpret_cast(output[i].data_ptr()), + M, + N, + K, + K, + K, + {0, 0}, + N}; + ggemm_kargs.push_back(group_args); + } + // Copy data onto device. + hipMemcpy( + kernel_args.data_ptr(), // Destination + ggemm_kargs.data(), // Source + sizeof(KernelArguments) * group_count, // Number of bytes + hipMemcpyHostToDevice); // Copy Type } + + return kernel_args; } std::vector f8f8bf16_rowwise_grouped( @@ -144,6 +245,7 @@ std::vector f8f8bf16_rowwise_grouped( at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, + std::optional zero_start_index_M = std::nullopt, std::optional> output = std::nullopt, std::optional kernel_name = std::nullopt) { // Check that input datatypes are valid. @@ -167,6 +269,9 @@ std::vector f8f8bf16_rowwise_grouped( TORCH_CHECK( w.dtype() == at::kFloat8_e4m3fnuz, "Inputs must be type float8_e4m3fnuz."); + TORCH_CHECK( + w.size(0) >= 512 && w.size(1) >= 512, + "N and K must be at least 512 for grouped gemm. For smaller inputs, consider unrolling."); } for (at::Tensor xs : x_scale) { TORCH_CHECK(xs.dtype() == at::kFloat, "Scales must be float32."); @@ -194,16 +299,30 @@ std::vector f8f8bf16_rowwise_grouped( Y[i].dtype() == at::kBFloat16, "Output dtype must be bfloat16."); } } else { - for (int i = 0; i < group_count; i++) { - int M = XQ[i].size(0); - int N = WQ[i].size(0); - Y.push_back(at::empty({M, N}, XQ[i].options().dtype(at::kBFloat16))); + // Two modes for allocating output. When m_values is provided, we need + // the output tensor to be contiguous and can assume M, N, and K are the + // same across groups. Otherwise, we can allocate each output separately. + if (zero_start_index_M.has_value()) { + int M = XQ[0].size(0); + int N = WQ[0].size(0); + // Fill output with zeros to simplify integration. This prevents nans from + // showing up in the tensor. + at::Tensor Y_full = + at::zeros({group_count, M, N}, XQ[0].options().dtype(at::kBFloat16)); + // Split the output into groups. + Y = at::unbind(Y_full, 0); + } else { + for (int i = 0; i < group_count; i++) { + int M = XQ[i].size(0); + int N = WQ[i].size(0); + Y.push_back(at::empty({M, N}, XQ[i].options().dtype(at::kBFloat16))); + } } } // Prepare kernel arguments by copying them to the proper device location. - at::Tensor kernel_args = at::empty({1000}, XQ[0].options().dtype(at::kByte)); - set_grouped_kernel_args(XQ, WQ, x_scale, w_scale, kernel_args, Y); + at::Tensor kernel_args = + get_grouped_kernel_args(XQ, WQ, x_scale, w_scale, zero_start_index_M, Y); // If provided a specific kernel implementation, dispatch to it. if (kernel_name.has_value()) { diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp index 5cc5e851d..78bcde568 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp @@ -93,6 +93,7 @@ std::vector f8f8bf16_rowwise_grouped( at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, + std::optional zero_start_index_M = std::nullopt, std::optional> output = std::nullopt, std::optional kernel_name = std::nullopt); std::vector get_f8f8bf16_rowwise_grouped_kernels(); @@ -188,7 +189,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { #endif #ifdef USE_ROCM m.def( - "f8f8bf16_rowwise_grouped(Tensor[] XQ, Tensor[] WQ, Tensor[] x_scale, Tensor[] w_scale, Tensor[](a!)? output=None, str? kernel_name=None) -> Tensor[]"); + "f8f8bf16_rowwise_grouped(Tensor[] XQ, Tensor[] WQ, Tensor[] x_scale, Tensor[] w_scale, Tensor? zero_start_index_M=None, Tensor[](a!)? output=None, str? kernel_name=None) -> Tensor[]"); m.def("get_f8f8bf16_rowwise_grouped_kernels() -> str[]"); m.impl( "get_f8f8bf16_rowwise_grouped_kernels", From cc1bad168715ed9d0074515f1418e60eb8ffe66a Mon Sep 17 00:00:00 2001 From: Jingyuan Fan Date: Wed, 18 Dec 2024 17:12:55 -0800 Subject: [PATCH 07/16] fix mx4 illegal memory access (#3509) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3509 X-link: https://github.com/facebookresearch/FBGEMM/pull/593 when calaculting num_thread and group_per_thread to distribute work, rounding gets accumulated and effectively expand the input space. for example (the new UT), when input tensor is (1, 2^31 - 8), ``` a.numel: 2147483640 num_threads: 46341 groups_per_thread: 1449 num_groups: 67108864 num_threads * groups_per_threads= 67148109 > num_groups ``` in kernel, when we try to access memory, input_start = num_threads * groups_per_threads * pid, so when pid is large, we end up visiting data outside the input Reviewed By: jwfromm Differential Revision: D67369392 fbshipit-source-id: 62c28fe3a94911a10921e233ff5ae42097e9dbb4 --- fbgemm_gpu/fbgemm_gpu/triton/quantize.py | 4 +++- fbgemm_gpu/test/quantize/mx4_test.py | 15 +++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/fbgemm_gpu/triton/quantize.py b/fbgemm_gpu/fbgemm_gpu/triton/quantize.py index e86f0a18e..4b02a3a50 100644 --- a/fbgemm_gpu/fbgemm_gpu/triton/quantize.py +++ b/fbgemm_gpu/fbgemm_gpu/triton/quantize.py @@ -183,6 +183,7 @@ def _kernel_quantize_mx4( # When theres no padding we can simplify indexing. else: padded_input_offset = input_offset + # Load a block of values. a = tl.load( A + padded_input_offset, @@ -434,7 +435,8 @@ def triton_quantize_mx4( rand_bits = None # Check if we need to use int64 for indexing. - use_int64 = a.numel() > 2**31 - 1 + use_int64 = num_threads * groups_per_thread * group_size > 2**31 - 1 + # Invoke triton quantization kernel over rows. grid = (num_threads,) _kernel_quantize_mx4[grid]( diff --git a/fbgemm_gpu/test/quantize/mx4_test.py b/fbgemm_gpu/test/quantize/mx4_test.py index 643e44fb9..03b160811 100644 --- a/fbgemm_gpu/test/quantize/mx4_test.py +++ b/fbgemm_gpu/test/quantize/mx4_test.py @@ -304,6 +304,21 @@ def test_mx4_index_overflow(self) -> None: # We just need to check that everything ran without an illegal memory access. assert mx_dequantized[0] == 0 + # pyre-fixme[56]: + @unittest.skipIf( + not ( + torch.cuda.is_available() and torch.cuda.mem_get_info()[0] / (1024**3) >= 32 + ), + "Test requires a gpu with at least 32GB of memory.", + ) + def test_mx4_index_overflow_large_input(self) -> None: + """Tests that mx4 quantization kernels can handle inputs that would overflow int32 indices.""" + large_input = torch.zeros((1, 2**31 - 2**3), dtype=torch.float32).to("cuda") + mx_quantized = fp32_to_mx4(large_input, 32) + mx_dequantized = mx4_to_fp32(mx_quantized, 32) + # We just need to check that everything ran without an illegal memory access. + assert mx_dequantized[0][0] == 0 + if __name__ == "__main__": unittest.main() From eaa0961a681c68d4c3fe095aa0b9252832415d8f Mon Sep 17 00:00:00 2001 From: Jingyuan Fan Date: Wed, 18 Dec 2024 17:15:44 -0800 Subject: [PATCH 08/16] support quantize_fp8_row for up to 4d non contiguous tensor (#3508) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3508 X-link: https://github.com/facebookresearch/FBGEMM/pull/589 reland D66990975 with fix for the NaN issued observed during LLaMa4 17B model run with fp8_rowwise FFN Specifically, offset was not properly updated when loading/storing data. Reviewed By: jwfromm Differential Revision: D67303282 fbshipit-source-id: 334d32019424de6daff4261b1d5ebe3c977fdabd --- .../experimental/gemm/test/fp8_gemm_test.py | 53 +++++++++--- .../experimental/gemm/triton_gemm/fp8_gemm.py | 85 +++++++++++++------ 2 files changed, 97 insertions(+), 41 deletions(-) diff --git a/fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py b/fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py index b950dafab..5ba74185b 100644 --- a/fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py +++ b/fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py @@ -6,6 +6,7 @@ # pyre-strict +import itertools import unittest from typing import Optional, Tuple @@ -37,38 +38,62 @@ def _test_quantize_fp8_row( device: torch.device, output_device: Optional[torch.device] = None, use_scale_ub: bool = False, + transpose_inputs: bool = False, ) -> None: a = torch.randn(shape, dtype=torch.bfloat16, device=device) - + inputs = [a] + # if transpose_inputs is true, get all possible dimension combinations + # of the input tensor and transposes each pair + if transpose_inputs: + dims = range(a.ndim) + for dim1, dim2 in itertools.combinations(dims, 2): + dims_list = list(dims) + dims_list[dim1], dims_list[dim2] = dims_list[dim2], dims_list[dim1] + inputs.append(a.permute(dims_list)) scale_ub = ( torch.tensor([1200], dtype=torch.float, device=device) if use_scale_ub else None ) + for input_a in inputs: + a_fp8, a_scale = quantize_fp8_row( + input_a, + scale_ub=scale_ub, + use_triton=use_triton, + output_device=output_device, + ) - a_fp8, a_scale = quantize_fp8_row( - a, scale_ub=scale_ub, use_triton=use_triton, output_device=output_device - ) + # Undo scaling. + a_torch = a_fp8.to(torch.bfloat16) + broadcast_shape = list(a_torch.shape[:-1]) + [-1] + a_torch *= a_scale.view(broadcast_shape) - # Undo scaling. - a_torch = a_fp8.to(torch.bfloat16) - broadcast_shape = list(a_torch.shape[:-1]) + [-1] - a_torch *= a_scale.view(broadcast_shape) - - self.assertTrue( - torch.allclose( - a.to(device=output_device), a_torch, atol=2e-1, rtol=1e-1 + self.assertTrue( + torch.allclose( + input_a.to(device=output_device), a_torch, atol=2e-1, rtol=1e-1 + ) ) - ) - _test_quantize_fp8_row((2, 3), True, torch.device("cuda")) + for n_col in range(1, 9000, 100): + _test_quantize_fp8_row((2, n_col), True, torch.device("cuda")) # Test with batched input. _test_quantize_fp8_row((4, 2, 3), True, torch.device("cuda")) + _test_quantize_fp8_row((6, 4, 2, 3), True, torch.device("cuda")) + # Test with non-contiguous input + _test_quantize_fp8_row( + (4, 2, 3), True, torch.device("cuda"), transpose_inputs=True + ) + _test_quantize_fp8_row( + (6, 4, 2, 3), True, torch.device("cuda"), transpose_inputs=True + ) _test_quantize_fp8_row((2, 3), True, torch.device("cuda"), use_scale_ub=True) + # Test with cpu _test_quantize_fp8_row((2, 3), False, torch.device("cpu"), torch.device("cuda")) _test_quantize_fp8_row( (2, 3), False, torch.device("cpu"), torch.device("cuda"), use_scale_ub=True ) + _test_quantize_fp8_row((4, 2, 3), True, torch.device("cpu")) + _test_quantize_fp8_row((6, 4, 2, 3), True, torch.device("cpu")) def test_scale_fp8_row(self) -> None: def _test_scale_fp8_row( diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py index 2cb9b3b47..23b501d45 100644 --- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py @@ -1945,7 +1945,7 @@ def prep_matmul( Config({"BLOCK_SIZE": 4096}), Config({"BLOCK_SIZE": 8192}), ], - key=["N"], + key=["K"], ) @triton.jit def _kernel_quantize_fp8_row( @@ -1953,12 +1953,18 @@ def _kernel_quantize_fp8_row( A_scale, A_fp8, scale_ub, + B, M, N, + K, + stride_ab, stride_am, stride_an, + stride_ak, + stride_ob, stride_om, stride_on, + stride_ok, TL_FP8_DTYPE: tl.constexpr, MAX_FP8: tl.constexpr, EPS: tl.constexpr, @@ -1977,16 +1983,22 @@ def _kernel_quantize_fp8_row( * Better tiling schemes. Args: - A (Tensor): [m, n] higher precision input tensor. - A_scale (Tensor): [m] reciprocal scale tensor per row. - A_fp8 (Tensor): [m, n] fp8 scaled tensor. A_fp8 = A / a_scale + A (Tensor): higher precision input tensor of 4 dimension. + A_scale (Tensor): [B * M * N] reciprocal scale tensor per row. + A_fp8 (Tensor): fp8 scaled tensor. A_fp8 = A / a_scale scale_ub (Tensor): [1] Maximum value allowed for scale. - M (int): Number of rows. - N (int): Number of columns. + B (int): Size of dimenion 0 + M (int): Size of dimenion 1 + N (int): Size of dimenion 2 + K (int): Size of dimenion 3 + stride_ab (int): Stride of b dimension of A. stride_am (int): Stride of m dimension of A. stride_an (int): Stride of n dimension of A. + stride_ak (int): Stride of k dimension of A. + stride_ob (int): Stride of b dimension of output. stride_om (int): Stride of m dimension of output. stride_on (int): Stride of n dimension of output. + stride_ok (int): Stride of k dimension of output. TL_FP8_DTYPE (tl.dtype): Target fp8 datatype. MAX_FP8 (float): Maxmimum expressible value for FP8. EPS (float): Epsilon value for numerical stability. @@ -2000,16 +2012,25 @@ def _kernel_quantize_fp8_row( if USE_INT64: pid = pid.to(tl.int64) n_offset = tl.arange(0, BLOCK_SIZE) + a_offset_base = ( + pid // (M * N) * stride_ab + + (pid % (M * N)) // N * stride_am + + (pid % (M * N)) % N * stride_an + ) + a_fp8_offset_base = ( + pid // (M * N) * stride_ob + + (pid % (M * N)) // N * stride_om + + (pid % (M * N)) % N * stride_on + ) # Calculate max. cur_max = 0.0 - for _k in range(0, tl.cdiv(N, BLOCK_SIZE)): + for _k in range(0, tl.cdiv(K, BLOCK_SIZE)): a = tl.load( - A + pid * stride_am + n_offset * stride_an, mask=n_offset < N, other=0.0 + A + a_offset_base + n_offset * stride_ak, mask=n_offset < K, other=0.0 ) tile_max = tl.max(tl.abs(a)) cur_max = tl.maximum(tile_max, cur_max) - n_offset += BLOCK_SIZE # Clamp max value appropriately. @@ -2022,9 +2043,10 @@ def _kernel_quantize_fp8_row( a_scale = MAX_FP8 / cur_max tl.store(A_scale + pid, 1.0 / a_scale) n_offset = tl.arange(0, BLOCK_SIZE) - for _k in range(0, tl.cdiv(N, BLOCK_SIZE)): + + for _k in range(0, tl.cdiv(K, BLOCK_SIZE)): a = tl.load( - A + pid * stride_am + n_offset * stride_an, mask=n_offset < N, other=0.0 + A + a_offset_base + n_offset * stride_ak, mask=n_offset < K, other=0.0 ) a_fp8 = a * a_scale # Clamp A to fp8 range to make sure there's no overflow. @@ -2033,7 +2055,7 @@ def _kernel_quantize_fp8_row( a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8) a_fp8.to(TL_FP8_DTYPE) tl.store( - A_fp8 + pid * stride_om + n_offset * stride_on, a_fp8, mask=n_offset < N + A_fp8 + a_fp8_offset_base + n_offset * stride_ok, a_fp8, mask=n_offset < K ) n_offset += BLOCK_SIZE @@ -2045,20 +2067,18 @@ def triton_quantize_fp8_row( Call the triton quantize fp8 row kernel to quantize a tensor to fp8 with row-wise scalings. Args: - a (Tensor): [m, n] higher precision input tensor. + a (Tensor): higher precision input tensor of 4 dimension. scale_ub (Tensor): Maximum allowed value for scale. Returns: torch.Tensor: fp8 scaled tensor. torch.Tensor: reciprocal scale tensor per row. """ - a_shape = a.shape - a = a.view(-1, a.size(-1)) # Get constant values. pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants() - num_rows = a.shape[0] + num_rows = a.numel() // a.shape[-1] a_scale = torch.empty((num_rows), dtype=torch.float32, device=a.device) - a_fp8 = torch.empty((a.shape[0], a.shape[1]), device=a.device, dtype=pt_dtype) + a_fp8 = torch.empty(a.shape, device=a.device, dtype=pt_dtype) # If input tensor is sufficiently large, we need to use int64 indexing. use_int64 = a.numel() > (2**31 - 1) @@ -2070,10 +2090,16 @@ def triton_quantize_fp8_row( scale_ub, a.shape[0], a.shape[1], + a.shape[2], + a.shape[3], a.stride(0), a.stride(1), + a.stride(2), + a.stride(3), a_fp8.stride(0), a_fp8.stride(1), + a_fp8.stride(2), + a_fp8.stride(3), TL_FP8_DTYPE=tl_dtype, MAX_FP8=max_fp8, EPS=eps, @@ -2081,7 +2107,7 @@ def triton_quantize_fp8_row( USE_INT64=use_int64, ) - return a_fp8.view(a_shape), a_scale + return a_fp8, a_scale @torch.library.custom_op("triton::quantize_fp8_row", mutates_args=()) @@ -2095,7 +2121,7 @@ def quantize_fp8_row( Quantize a to fp8 with row-wise scalings and optionally move to output device. Args: - a (Tensor): Input high precision tensor. + a (Tensor): Input high precision tensor. Required to have no more than 4 dimension scale_ub (Tensor): Maximum allowed value for scale. use_triton (bool): Whether to use triton kernel or pytorch. output_device (torch.device): Device to optionally move the scaled tensors to. @@ -2104,36 +2130,41 @@ def quantize_fp8_row( torch.Tensor: fp8 scaled tensor. torch.Tensor: The reciprocal scale tensor per row. """ - a_shape = a.shape - a = a.view(-1, a.size(-1)) + if a.device == torch.device("cpu"): logger.info("Triton does not support cpu, falling back to torch ops.") use_triton = False if use_triton: - aq, a_scale = triton_quantize_fp8_row(a, scale_ub) - return aq.view(a_shape), a_scale + assert ( + a.dim() <= 4 + ), "Only up to 4 dimension input tensor is supported if use_triton is True" + a_shape = a.shape + while a.dim() < 4: + a = a.unsqueeze(0) + a_fp8, a_scale = triton_quantize_fp8_row(a, scale_ub) + return a_fp8.view(a_shape), a_scale # else use pytorch implementation. if not output_device: output_device = a.device # Get constants. pt_dtype, _, max_fp8, eps = get_fp8_constants() - row_max: torch.Tensor = torch.max(torch.abs(a), dim=1)[0] + row_max: torch.Tensor = torch.max(torch.abs(a), dim=-1)[0] # Apply clamping. if scale_ub is not None: row_max = torch.clamp(row_max, min=eps, max=scale_ub.item()) else: # pyre-ignore[6]: Incompatible parameter type [6] row_max = torch.clamp(row_max, min=eps) - a_scale = torch.empty((a.shape[0]), dtype=torch.float32, device=output_device) + a_scale = torch.empty((a.shape[:-1]), dtype=torch.float32, device=output_device) a_scale = max_fp8 / row_max.to(torch.float32) # pyre-ignore a_scale[a_scale == float("inf")] = 1.0 # pyre-ignore - a_fp8 = a * a_scale[:, None] # pyre-ignore + a_fp8 = a * a_scale[..., None] # pyre-ignore # Cast and move data to output device (for cpu weight loading). a_fp8 = a_fp8.to(device=output_device, dtype=pt_dtype) a_scale = a_scale.to(output_device) # pyre-ignore del a - return a_fp8.view(a_shape), 1 / a_scale # pyre-ignore + return a_fp8, 1 / a_scale # pyre-ignore @quantize_fp8_row.register_fake From ca4ea00d4c471d752dde1789fa90e8dcbacfe4f3 Mon Sep 17 00:00:00 2001 From: Feng Shi Date: Thu, 19 Dec 2024 10:15:10 -0800 Subject: [PATCH 09/16] MX4 group size configuration for pyper (#3516) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/597 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3516 added pyper configuration for mx4 goup size. Reviewed By: irobert0126, renganxu Differential Revision: D67407064 fbshipit-source-id: a23765777879491836fcb9f1a00ba8f1e1b26b76 --- fbgemm_gpu/fbgemm_gpu/quantize_comm.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/quantize_comm.py b/fbgemm_gpu/fbgemm_gpu/quantize_comm.py index fb7737ec2..a03cf965b 100644 --- a/fbgemm_gpu/fbgemm_gpu/quantize_comm.py +++ b/fbgemm_gpu/fbgemm_gpu/quantize_comm.py @@ -183,6 +183,8 @@ def __init__( self._loss_scale = loss_scale self._is_fwd = is_fwd self._row_dim: int = -1 if row_dim is None else row_dim + if self._comm_precision == SparseType.MX4: + self._row_dim = MX_GROUP_SIZE_DEFAULT if row_dim is None else row_dim def encode( self, input_tensor: torch.Tensor, ctx: Optional[QuantizationContext] = None @@ -252,11 +254,12 @@ def quantized_dtype(self) -> torch.dtype: def create_context(self) -> Optional[QuantizationContext]: # fp8 rowwise is activated when row_dim > 0 - if ( - self._comm_precision == SparseType.FP8 - or self._comm_precision == SparseType.MX4 - ): + if self._comm_precision == SparseType.FP8: return QuantizationContext(self._row_dim) + if self._comm_precision == SparseType.MX4: + return QuantizationContext( + row_dim=self._row_dim, mx_group_size=self._row_dim + ) # int8 rowwise is default return QuantizationContext() From 7d1c763551120ccfa4592fbf4c198b78ec94761b Mon Sep 17 00:00:00 2001 From: "Yanan Cao (PyTorch)" Date: Thu, 19 Dec 2024 11:10:28 -0800 Subject: [PATCH 10/16] deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai/test/quantize (#3512) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/596 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3512 Reviewed By: avikchaudhuri Differential Revision: D67381311 fbshipit-source-id: 345264f99d6f4b77508b4ea95fe20b3482ad1f04 --- .../experimental/gen_ai/test/quantize/quantize_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py index 14999c96d..3f121aaf9 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py @@ -134,7 +134,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: model = TestModule().cuda() # bf16 required here - _ = torch.export.export(model, (torch.randn(32, 32).to(torch.bfloat16).cuda(),)) + _ = torch.export.export( + model, (torch.randn(32, 32).to(torch.bfloat16).cuda(),), strict=True + ) def test_f8f8bf16_export(self) -> None: class TestModule(torch.nn.Module): @@ -161,7 +163,7 @@ def forward(self, xq: torch.Tensor, wq: torch.Tensor) -> torch.Tensor: fp8_dtype = torch.float8_e4m3fnuz xq = torch.randn(M, K).to(fp8_dtype).cuda() wq = torch.randn(N, K).to(fp8_dtype).cuda() - _ = torch.export.export(model, (xq, wq)) + _ = torch.export.export(model, (xq, wq), strict=True) @unittest.skipIf( From a75d8fe03003b24fc0ac635c723fe0c12c7ffc98 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Thu, 19 Dec 2024 14:07:59 -0800 Subject: [PATCH 11/16] OSS build fixes (#3514) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/599 - Fix the CMake minimum version in conda install - Fix issue with missing `librhash.so.0` when installing `gcc` - Fix build issues with bazel, and upgrade bazel version to latest Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3514 Reviewed By: spcyppt Differential Revision: D67435456 Pulled By: q10 fbshipit-source-id: 2fe53c59251df3633771b2b6b0d97c15a33df7b6 --- .github/scripts/utils_build.bash | 11 +++++++++-- MODULE.bazel | 2 +- WORKSPACE.bazel | 12 ++++++++++++ bench/CMakeLists.txt | 2 +- fbgemm_gpu/CMakeLists.txt | 2 +- test/CMakeLists.txt | 2 +- 6 files changed, 25 insertions(+), 6 deletions(-) diff --git a/.github/scripts/utils_build.bash b/.github/scripts/utils_build.bash index f4096caa1..69cc36ca9 100644 --- a/.github/scripts/utils_build.bash +++ b/.github/scripts/utils_build.bash @@ -14,7 +14,7 @@ ################################################################################ setup_bazel () { - local bazel_version="${1:-6.1.1}" + local bazel_version="${1:-8.0.0}" echo "################################################################################" echo "# Setup Bazel" echo "#" @@ -294,12 +294,13 @@ install_build_tools () { # $CONDA_PREFIX/include directory, which is required for FBGEMM tests # # - ncurses is needed to silence libtinfo6.so errors for ROCm+Clang builds + # - rhash is needed bc newer versions of GXX package don't come packaged with this library anymore # # shellcheck disable=SC2086 (exec_with_retries 3 conda install ${env_prefix} -c conda-forge -y \ bazel \ click \ - cmake \ + 'cmake>=3.30' \ hypothesis \ jinja2 \ make \ @@ -307,9 +308,15 @@ install_build_tools () { ninja \ openblas \ patchelf \ + rhash \ scikit-build \ wheel) || return 1 + echo "[INSTALL] Adding symlink librhash.so.0, which is needed by Cmake ..." + # shellcheck disable=SC2155,SC2086 + local conda_prefix=$(conda run ${env_prefix} printenv CONDA_PREFIX) + (print_exec ln -s "${conda_prefix}/lib/librhash.so" "${conda_prefix}/lib/librhash.so.0") || return 1 + # For some reason, the build package for Python 3.12 is missing from Conda, so # we have to install through PyPI instead. # diff --git a/MODULE.bazel b/MODULE.bazel index 02f67d0ee..43f2f033d 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -5,4 +5,4 @@ module(name = "fbgemm") -bazel_dep(name = "bazel_skylib", version = "1.5.0") +bazel_dep(name = "bazel_skylib", version = "1.7.1") diff --git a/WORKSPACE.bazel b/WORKSPACE.bazel index 9a9f9535a..6161eef18 100644 --- a/WORKSPACE.bazel +++ b/WORKSPACE.bazel @@ -7,6 +7,15 @@ workspace(name = "fbgemm") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +http_archive( + name = "bazel_skylib", + sha256 = "bc283cdfcd526a52c3201279cda4bc298652efa898b10b4db0837dc51652756f", + urls = [ + "https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.7.1/bazel-skylib-1.7.1.tar.gz", + "https://github.com/bazelbuild/bazel-skylib/releases/download/1.7.1/bazel-skylib-1.7.1.tar.gz", + ], +) + http_archive( name = "com_google_googletest", strip_prefix = "googletest-1.14.0", @@ -29,3 +38,6 @@ new_local_repository( build_file = "@//external:asmjit.BUILD", path = "external/asmjit", ) + +load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace") +bazel_skylib_workspace() diff --git a/bench/CMakeLists.txt b/bench/CMakeLists.txt index bd2575100..385fe6cbb 100644 --- a/bench/CMakeLists.txt +++ b/bench/CMakeLists.txt @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -cmake_minimum_required(VERSION 3.16 FATAL_ERROR) +cmake_minimum_required(VERSION 3.25 FATAL_ERROR) set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_EXTENSIONS OFF) diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt index 922b3f207..20fe463d8 100644 --- a/fbgemm_gpu/CMakeLists.txt +++ b/fbgemm_gpu/CMakeLists.txt @@ -8,7 +8,7 @@ # CMake Prelude ################################################################################ -cmake_minimum_required(VERSION 3.25.0 FATAL_ERROR) +cmake_minimum_required(VERSION 3.25 FATAL_ERROR) set(CMAKEMODULES ${CMAKE_CURRENT_SOURCE_DIR}/../cmake/modules) set(FBGEMM_GPU ${CMAKE_CURRENT_SOURCE_DIR}) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index da0be7fe7..bf2d8dbeb 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -cmake_minimum_required(VERSION 3.21 FATAL_ERROR) +cmake_minimum_required(VERSION 3.25 FATAL_ERROR) set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_EXTENSIONS OFF) From 6da23d51c2075d485e2593ea66b02ee0d3f49562 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Fri, 20 Dec 2024 09:53:55 -0800 Subject: [PATCH 12/16] Fix index overflow for superlarge inputs (#3519) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3519 X-link: https://github.com/facebookresearch/FBGEMM/pull/601 For extremely large inputs, we found that boundary check values were sufficiently large that they were causing integer overflow. This resulted in triton triggering masking for all loads and stores which lead to garbage outputs. This diff fixes the issue by more carefully doing int64 upcasting for super large tensors. After this change, all super large tests pass. Reviewed By: qchip Differential Revision: D67495115 fbshipit-source-id: dcea639a7343d5782823f103a0572870aa496b05 --- fbgemm_gpu/fbgemm_gpu/triton/quantize.py | 29 +++++++----- fbgemm_gpu/test/quantize/mx4_test.py | 57 +++++++++++++++++++----- 2 files changed, 66 insertions(+), 20 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/triton/quantize.py b/fbgemm_gpu/fbgemm_gpu/triton/quantize.py index 4b02a3a50..9835c351b 100644 --- a/fbgemm_gpu/fbgemm_gpu/triton/quantize.py +++ b/fbgemm_gpu/fbgemm_gpu/triton/quantize.py @@ -141,17 +141,22 @@ def _kernel_quantize_mx4( EXPONENT_OVERFLOW_THRESHOLD: tl.constexpr = (1 << EBITS) - 1 # type: ignore[Incompatible variable type] IMPLICIT_1_MASK = (1 << (MBITS_IMPLICIT - 1)) - 1 RAND_MASK: tl.constexpr = (1 << (FP32_EXP_OFFSET - MBITS)) - 1 # type: ignore[Incompatible variable type] - # Boundaries for writing to output tensor. - PACKED_GROUP_SIZE: tl.constexpr = GROUP_SIZE // 2 + 1 # type: ignore[Incompatible variable type] - NUM_GROUPS = M * GROUPS_PER_ROW - OUTPUT_CHUNK_SIZE = (GROUPS_PER_THREAD * GROUP_SIZE) // 2 + GROUPS_PER_THREAD - OUTPUT_SIZE = (GROUP_SIZE * NUM_GROUPS) // 2 + NUM_GROUPS # Get the current thread number. pid = tl.program_id(0) # For very large inputs, we need to use int64 indexes. This is slower but necessary. if USE_INT64: pid = pid.to(tl.int64) + M = tl.cast(M, tl.int64) + K = tl.cast(K, tl.int64) + GROUPS_PER_THREAD = tl.cast(GROUPS_PER_THREAD, tl.int64) + + # Boundaries for writing to output tensor. + PACKED_GROUP_SIZE: tl.constexpr = GROUP_SIZE // 2 + 1 # type: ignore[Incompatible variable type] + NUM_GROUPS = M * GROUPS_PER_ROW + OUTPUT_CHUNK_SIZE = (GROUPS_PER_THREAD * GROUP_SIZE) // 2 + GROUPS_PER_THREAD + OUTPUT_SIZE = (GROUP_SIZE * NUM_GROUPS) // 2 + NUM_GROUPS + # Find starting offsets for this thread. These are calculated before adjusting for padding. input_start = pid * (GROUPS_PER_THREAD * GROUP_SIZE) output_start = pid * OUTPUT_CHUNK_SIZE @@ -501,16 +506,20 @@ def _kernel_dequantize_mx4( MX4_BIT_MASK: tl.constexpr = 0xF # type: ignore[Incompatible variable type] FP32_EXP_BIAS: tl.constexpr = 127 # type: ignore[Incompatible variable type] PACKED_GROUP_SIZE: tl.constexpr = GROUP_SIZE // 2 + 1 # type: ignore[Incompatible variable type] - # Boundaries for reading input and writing to output tensor. - INPUT_CHUNK_SIZE = GROUPS_PER_THREAD * PACKED_GROUP_SIZE - OUTPUT_CHUNK_SIZE = GROUPS_PER_THREAD * GROUP_SIZE - OUTPUT_SIZE = (M // PACKED_GROUP_SIZE) * GROUP_SIZE # Get the current thread number. pid = tl.program_id(0) # For very large tensors, use int64 for indexing. This is slower but necessary. if USE_INT64: pid = pid.to(tl.int64) + M = tl.cast(M, tl.int64) + GROUPS_PER_THREAD = tl.cast(GROUPS_PER_THREAD, tl.int64) + + # Boundaries for reading input and writing to output tensor. + INPUT_CHUNK_SIZE = GROUPS_PER_THREAD * PACKED_GROUP_SIZE + OUTPUT_CHUNK_SIZE = GROUPS_PER_THREAD * GROUP_SIZE + OUTPUT_SIZE = (M // PACKED_GROUP_SIZE) * GROUP_SIZE + # Find the starting offsets for this thread. input_start = pid * (GROUPS_PER_THREAD * PACKED_GROUP_SIZE) exp_start = input_start + GROUP_SIZE // 2 @@ -617,7 +626,7 @@ def triton_dequantize_mx4( output_elems = num_groups * group_size out = torch.empty([output_elems], device=a.device, dtype=torch.float) # Check if we need to use int64 for indexing. - use_int64 = a.numel() > 2**31 - 1 + use_int64 = num_threads * groups_per_thread * group_size > 2**31 - 1 # Invoke triton dequantization kernel over rows. grid = (num_threads,) _kernel_dequantize_mx4[grid]( diff --git a/fbgemm_gpu/test/quantize/mx4_test.py b/fbgemm_gpu/test/quantize/mx4_test.py index 03b160811..eddc6368b 100644 --- a/fbgemm_gpu/test/quantize/mx4_test.py +++ b/fbgemm_gpu/test/quantize/mx4_test.py @@ -304,20 +304,57 @@ def test_mx4_index_overflow(self) -> None: # We just need to check that everything ran without an illegal memory access. assert mx_dequantized[0] == 0 - # pyre-fixme[56]: @unittest.skipIf( not ( - torch.cuda.is_available() and torch.cuda.mem_get_info()[0] / (1024**3) >= 32 + torch.cuda.is_available() and torch.cuda.mem_get_info()[0] / (1024**3) >= 64 ), - "Test requires a gpu with at least 32GB of memory.", + "Test requires a gpu with at least 64GB of memory.", ) - def test_mx4_index_overflow_large_input(self) -> None: - """Tests that mx4 quantization kernels can handle inputs that would overflow int32 indices.""" - large_input = torch.zeros((1, 2**31 - 2**3), dtype=torch.float32).to("cuda") - mx_quantized = fp32_to_mx4(large_input, 32) - mx_dequantized = mx4_to_fp32(mx_quantized, 32) - # We just need to check that everything ran without an illegal memory access. - assert mx_dequantized[0][0] == 0 + # pyre-fixme[56]: + @given( + shape=st.sampled_from([[1024 * 1024, 2020]]), + group_size=st.sampled_from([32]), + rounding_mode=st.sampled_from([RoundingMode.even]), + magnitude=st.sampled_from([1e6]), + mx4_format=st.sampled_from([(2, 1)]), + device=st.sampled_from(["cuda"]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None) + def test_mx4_large_cases( + self, + shape: List[int], + group_size: int, + rounding_mode: RoundingMode, + magnitude: int, + mx4_format: Tuple[int, int], + device: str, + ) -> None: + """Test correctness of mx4 routines with random inputs and shapes that overflow int32.""" + # We only want to consider total sizes that are divisible by group_size. + ebits, mbits = mx4_format + + # Generate a random input with the specified magnitude. + input = torch.randn(shape, device=device, dtype=torch.float32) * magnitude + + # Perform quant then dequant to check that proper shape is maintained and + # outputs are reasonably correct. + mx_quantized = fp32_to_mx4( + input, group_size, rounding_mode=rounding_mode, ebits=ebits, mbits=mbits + ) + mx_dequantized = mx4_to_fp32(mx_quantized, group_size, ebits=ebits, mbits=mbits) + + # If the rows of input are not divisible by group_size, we expect the output + # to be padded. + if input.shape[-1] % group_size != 0: + pad = group_size - (input.shape[-1] % group_size) + input = torch.nn.functional.pad(input, (0, pad)) + + # Check that output shape matches input shape. + assert mx_dequantized.shape == input.shape + + # Check that values are reasonably close, based on expected variance. + # I give quite a bit of wiggle room to make sure this isnt flaky. + torch.testing.assert_close(input, mx_dequantized, rtol=1.0, atol=magnitude / 2) if __name__ == "__main__": From 5d6dd9299f4c95bebf9b1494e4fe86c701430abc Mon Sep 17 00:00:00 2001 From: Marko Radmilac Date: Fri, 20 Dec 2024 15:58:03 -0800 Subject: [PATCH 13/16] Async initialization of RockDB SSD tensors (#3520) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/602 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3520 In the diff D66465811 we introduced a bulk initialization function `_insert_all_kv` for ssd tensors. However, large tensors take a long time to fully initialize, and ideally this can happen in the background so it doesn't increase TTFB of the training jobs. This change does exactly that, moves this initialization to a separate thread, allowing other initialization in the training job, like reading data, to happen concurrently. In order to avoid pushing synchronization to the user space, this change introduces getter and setter for ssd_db, which ensure initialization is fully done before weights are used. Reviewed By: duduyi2013, drdarshan, jiayulu Differential Revision: D67480511 fbshipit-source-id: 6faf54621fc6e26a9791ac23e48aa7890329077a --- fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py | 58 +++++++++++++++++-- .../tbe/ssd/ssd_split_tbe_training_test.py | 11 ++++ 2 files changed, 65 insertions(+), 4 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py index 067e8c89f..911e1fc49 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py @@ -14,6 +14,7 @@ import logging import os import tempfile +import threading import time from math import log2 from typing import Any, Callable, List, Optional, Tuple, Type, Union @@ -206,6 +207,7 @@ def __init__( f"TBE will allocate a UVM buffer with is_host_mapped={uvm_host_mapped}" ) self.bulk_init_chunk_size = bulk_init_chunk_size + self.lazy_init_thread: threading.Thread | None = None # Buffers for bounds check self.register_buffer( @@ -444,7 +446,7 @@ def __init__( ) # pyre-fixme[4]: Attribute must be annotated. # pyre-ignore[16] - self.ssd_db = torch.classes.fbgemm.EmbeddingRocksDBWrapper( + self._ssd_db = torch.classes.fbgemm.EmbeddingRocksDBWrapper( ssd_directory, ssd_rocksdb_shards, ssd_rocksdb_shards, @@ -469,11 +471,11 @@ def __init__( if self.bulk_init_chunk_size > 0: self.ssd_uniform_init_lower: float = ssd_uniform_init_lower self.ssd_uniform_init_upper: float = ssd_uniform_init_upper - self._insert_all_kv() + self._lazy_initialize_ssd_tbe() else: # pyre-fixme[4]: Attribute must be annotated. # pyre-ignore[16] - self.ssd_db = torch.classes.fbgemm.EmbeddingParameterServerWrapper( + self._ssd_db = torch.classes.fbgemm.EmbeddingParameterServerWrapper( [host[0] for host in ps_hosts], [host[1] for host in ps_hosts], tbe_unique_id, @@ -707,6 +709,51 @@ def __init__( self.stats_reporter.register_stats(self.l2_cache_free_mem_stats_name) self.stats_reporter.register_stats(self.l2_cache_capacity_stats_name) + @property + # pyre-ignore + def ssd_db(self): + """Intercept the ssd_db property to make sure it is fully initialized before use. + This is needed because random weights are initialized in a separate thread""" + if self.lazy_init_thread is not None: + self.lazy_init_thread.join() + self.lazy_init_thread = None + logging.info("lazy ssd tbe initialization completed, weights are ready") + + return self._ssd_db + + @ssd_db.setter + # pyre-ignore + def ssd_db(self, value): + """Setter for ssd_db property.""" + if self.lazy_init_thread is not None: + # This is essentially a copy assignment operation, since the thread is + # already existing, and we are assigning a new ssd_db to it. Complete + # the initialization first, then assign the new value to it. + self.lazy_init_thread.join() + self.lazy_init_thread = None + logging.info( + "lazy ssd tbe initialization completed, ssd_db will now get overridden" + ) + + self._ssd_db = value + + def _lazy_initialize_ssd_tbe(self) -> None: + """ + Initialize the SSD TBE with random weights. This function should only be + called once at initialization time. + """ + if self.bulk_init_chunk_size > 0: + self.lazy_init_thread = threading.Thread(target=self._insert_all_kv) + # pyre-ignore + self.lazy_init_thread.start() + logging.info( + f"lazy ssd tbe initialization started since bulk_init_chunk_size is set to {self.bulk_init_chunk_size}" + ) + else: + logging.debug( + "bulk_init_chunk_size is not set, skipping lazy initialization" + ) + @torch.jit.ignore def _insert_all_kv(self) -> None: """ @@ -719,6 +766,7 @@ def _insert_all_kv(self) -> None: total_dim0 = 0 for dim0, _ in self.embedding_specs: total_dim0 += dim0 + start_ts = time.time() chunk_tensor = torch.empty( chunk_size, @@ -734,7 +782,9 @@ def _insert_all_kv(self) -> None: ) cpu_tensor.copy_(chunk_tensor, non_blocking=False) rand_val = cpu_tensor[:actual_dim0, :] - self.ssd_db.set_range_to_storage(rand_val, row_offset, actual_dim0) + # This code is intentionally not calling through the getter property + # to avoid the lazy initialization thread from joining with itself. + self._ssd_db.set_range_to_storage(rand_val, row_offset, actual_dim0) end_ts = time.time() elapsed = int((end_ts - start_ts) * 1e6) logging.info(f"TBE bulk initialization took {elapsed:_} us") diff --git a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py index faf751c1a..b99e7c950 100644 --- a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py +++ b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py @@ -304,6 +304,17 @@ def generate_ssd_tbes( bulk_init_chunk_size=bulk_init_chunk_size, ).cuda() + if bulk_init_chunk_size > 0: + self.assertIsNotNone( + emb.lazy_init_thread, + "if bulk_init_chunk_size > 0, lazy_init_thread must be set and it should not be force-synchronized yet", + ) + + # By doing the check for ssd_db being None below, we also access the getter property of ssd_db, which will + # force the synchronization of lazy_init_thread, and then reset it to None. + if emb.ssd_db is not None: + self.assertIsNone(emb.lazy_init_thread) + # A list to keep the CPU tensor alive until `set` (called inside # `set_cuda`) is complete. Note that `set_cuda` is non-blocking # asynchronous From 12a2246a19aeeedcb1ca7797797387ad7d269990 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Fri, 20 Dec 2024 19:40:18 -0800 Subject: [PATCH 14/16] Explicitly update manylinux version (#3521) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/605 - Set the --plat-name explicitly to `manylinux_2_28` Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3521 Reviewed By: spcyppt Differential Revision: D67538191 Pulled By: q10 fbshipit-source-id: b2f8cc0b81c7e46bd2e380c03a6fa68da11786d6 --- .github/scripts/fbgemm_gpu_build.bash | 32 +++++++++++++++++-- .github/scripts/utils_build.bash | 1 + .github/scripts/utils_pytorch.bash | 4 +++ cmake/modules/GpuCppLibrary.cmake | 2 +- .../BuildInstructions.rst | 2 +- 5 files changed, 36 insertions(+), 5 deletions(-) diff --git a/.github/scripts/fbgemm_gpu_build.bash b/.github/scripts/fbgemm_gpu_build.bash index e5044159c..a23bcd445 100644 --- a/.github/scripts/fbgemm_gpu_build.bash +++ b/.github/scripts/fbgemm_gpu_build.bash @@ -324,9 +324,13 @@ __build_fbgemm_gpu_set_python_plat_name () { fi elif [[ $KERN_NAME == 'Linux' ]]; then - # manylinux2014 is specified, bc manylinux1 does not support aarch64 - # See https://github.com/pypa/manylinux - export python_plat_name="manylinux2014_${MACHINE_NAME}" + # NOTE: manylinux2014 is the minimum platform tag specified, bc + # manylinux1 does not support aarch64; see https://github.com/pypa/manylinux + # + # As of 2024-12, upstream torch has switched to manylinux_2_28: + # https://dev-discuss.pytorch.org/t/pytorch-linux-wheels-switching-to-new-wheel-build-platform-manylinux-2-28-on-november-12-2024/2581 + # https://github.com/pytorch/pytorch/pull/143423 + export python_plat_name="manylinux_2_28_${MACHINE_NAME}" else echo "[BUILD] Unsupported OS platform: ${KERN_NAME}" @@ -519,6 +523,24 @@ run_fbgemm_gpu_postbuild_checks () { __verify_library_symbols || return 1 } +run_fbgemm_gpu_audit_wheel () { + fbgemm_wheel="$1" + if [ "$fbgemm_wheel" == "" ]; then + echo "Usage: ${FUNCNAME[0]} FBGEMM_WHEEL_PATH" + echo "Example(s):" + echo " ${FUNCNAME[0]} dist/fbgemm_gpu_nightly_cpu-2024.12.20-cp39-cp39-manylinux_2_28_x86_64.whl" + return 1 + fi + + echo "################################################################################" + echo "[BUILD] Wheel Audit: ${fbgemm_wheel}" + echo "" + + print_exec conda run --no-capture-output ${env_prefix} auditwheel show "${fbgemm_wheel}" + echo "" + echo "################################################################################" +} + ################################################################################ # FBGEMM_GPU Build Functions ################################################################################ @@ -578,6 +600,10 @@ build_fbgemm_gpu_package () { # Run checks on the built libraries (run_fbgemm_gpu_postbuild_checks "${fbgemm_variant}") || return 1 + for wheelfile in dist/*.whl; do + run_fbgemm_gpu_audit_wheel "${wheelfile}" + done + echo "[BUILD] Enumerating the built wheels ..." print_exec ls -lth dist/*.whl || return 1 diff --git a/.github/scripts/utils_build.bash b/.github/scripts/utils_build.bash index 69cc36ca9..a2e9152ed 100644 --- a/.github/scripts/utils_build.bash +++ b/.github/scripts/utils_build.bash @@ -298,6 +298,7 @@ install_build_tools () { # # shellcheck disable=SC2086 (exec_with_retries 3 conda install ${env_prefix} -c conda-forge -y \ + auditwheel \ bazel \ click \ 'cmake>=3.30' \ diff --git a/.github/scripts/utils_pytorch.bash b/.github/scripts/utils_pytorch.bash index f2ecba7b9..d66d0f2a6 100644 --- a/.github/scripts/utils_pytorch.bash +++ b/.github/scripts/utils_pytorch.bash @@ -143,6 +143,10 @@ install_pytorch_pip () { local installed_pytorch_version=$(conda run ${env_prefix} python -c "import torch; print(torch.__version__)") echo "[CHECK] NOTE: The installed version is: ${installed_pytorch_version}" + echo "[CHECK] NOTE: Checking _GLIBCXX_USE_CXX11_ABI ..." + # shellcheck disable=SC2086,SC2155 + conda run ${env_prefix} python -c 'import torch; print(torch._C._GLIBCXX_USE_CXX11_ABI); print(torch.compiled_with_cxx11_abi())' + if [ "$pytorch_variant_type" == "cuda" ]; then # Ensure that the PyTorch-CUDA headers are properly installed (test_filepath "${env_name}" cuda_cmake_macros.h) || return 1 diff --git a/cmake/modules/GpuCppLibrary.cmake b/cmake/modules/GpuCppLibrary.cmake index d2ebde6d2..8c2d34e0a 100644 --- a/cmake/modules/GpuCppLibrary.cmake +++ b/cmake/modules/GpuCppLibrary.cmake @@ -354,7 +354,7 @@ function(gpu_cpp_library) "CUDA_SPECIFIC_SRCS:" "${args_CUDA_SPECIFIC_SRCS}" " " - "HIP_SPECIFIC_SRCS" + "HIP_SPECIFIC_SRCS:" "${args_HIP_SPECIFIC_SRCS}" " " "OTHER_SRCS:" diff --git a/fbgemm_gpu/docs/src/fbgemm_gpu-development/BuildInstructions.rst b/fbgemm_gpu/docs/src/fbgemm_gpu-development/BuildInstructions.rst index 25949b6f1..f538d4e19 100644 --- a/fbgemm_gpu/docs/src/fbgemm_gpu-development/BuildInstructions.rst +++ b/fbgemm_gpu/docs/src/fbgemm_gpu-development/BuildInstructions.rst @@ -506,7 +506,7 @@ Python platform name must first be properly set: export ARCH=$(uname -m) # Set the Python platform name for the Linux case - export python_plat_name="manylinux2014_${ARCH}" + export python_plat_name="manylinux_2_28_${ARCH}" # For the macOS (x86_64) case export python_plat_name="macosx_10_9_${ARCH}" # For the macOS (arm64) case From fe980ab54a6e28818d81c8694b6564e7f804418b Mon Sep 17 00:00:00 2001 From: Supadchaya Puangpontip Date: Fri, 20 Dec 2024 19:58:50 -0800 Subject: [PATCH 15/16] Add new optimizer state `row_counter` for Adam [Backend] (#3342) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3342 X-link: https://github.com/facebookresearch/FBGEMM/pull/436 A new optional optimizer state `row_counter` is added to Adam to perform bias correction per embedding row. `row_counter` serves as the iteration counter when a row (an index) occurs and used to do bias correction. Without rowwise bias correction (existing Adam), ``` m_hat_t = m_t / (1.0 - powf(beta1, iter)); v_hat_t = v_t / (1.0 - powf(beta2, iter)); ``` With rowwise bias correction enabled. ``` // when index `idx` occurs _row_counter = row_counter[idx] + 1; m_hat_t = m_t / (1.0 - powf(beta1, _row_counter)); v_hat_t = v_t / (1.0 - powf(beta2, _row_counter)); ``` This request is from IG to allow all the models to be scaled on sparse features with expected 1.5% NE on Stories. ------- **__The functionality is not set by default.__** Frontend: D64848802 To enable the bias correction, `use_rowwise_bias_correction` needs to be set to True through extra_optimizer_config. ``` extra_optimizer_config = UserEnabledConfigDefinition(use_rowwise_bias_correction=True) emb_op = SplitTableBatchedEmbeddingBagsCodegen ( embedding_specs=[ (E, D, M, compute_device) for (E, D, M) in zip(Es, Ds, managed) ], optimizer=OptimType.Adam extra_optimizer_config=extra_optimizer_config, ... ) ``` ------ **__Performance from Kineto__** (unweighted) ``` Baseline* | default** | enabled*** forward | cpu | 2.293 s | 2.188 s | 2.043 s | cuda | 12.512 ms | 12.539 ms | 12.547 ms backward | cpu | 69.861 ms | 66.546 ms | 65.880 ms | cuda | 103.429 ms | 103.395 ms | 103.130 ms ``` \* Baseline: before changes \** default: default setting; use_bias_correction = False \*** enabled: use_bias_correction = True Reviewed By: sryap Differential Revision: D64808460 fbshipit-source-id: 9706bcc4601b370f4d67c81b833fb1cd46377a6c --- .../codegen/genscript/optimizer_args.py | 353 ++++++++++++++---- fbgemm_gpu/codegen/genscript/optimizers.py | 37 +- .../codegen/genscript/torch_type_utils.py | 1 + ...dding_backward_split_host_cpu_template.cpp | 6 +- ...embedding_backward_split_host_template.cpp | 6 +- ...dding_split_host_pt2_autograd_template.cpp | 44 ++- ..._embedding_codegen_lookup_invoker.template | 6 +- fbgemm_gpu/test/tbe/training/forward_test.py | 7 + 8 files changed, 367 insertions(+), 93 deletions(-) diff --git a/fbgemm_gpu/codegen/genscript/optimizer_args.py b/fbgemm_gpu/codegen/genscript/optimizer_args.py index f290e0f62..669b1a44f 100644 --- a/fbgemm_gpu/codegen/genscript/optimizer_args.py +++ b/fbgemm_gpu/codegen/genscript/optimizer_args.py @@ -39,6 +39,7 @@ class OptimizerArgsSetItem: name: str default: Union[float, ArgType] = 0 # DEFAULT_ARG_VAL ph_tys: Optional[List[ArgType]] = None # placeholder types + is_optional: bool = False # optional variable # Alias b/c the name is too long @@ -192,6 +193,42 @@ def schema_tensor_list_arg_no_default(name: str) -> str: return f"Tensor[] {name}" +def bool_arg(name: str, default: bool = False) -> str: + return f"bool {name} = {'true' if default else 'false'}" + + +def bool_arg_no_default(name: str) -> str: + return f"bool {name}" + + +def schema_bool_arg(name: str, default: bool = False) -> str: + return f"bool {name} = {default}" + + +def optional_tensor_arg(name: str) -> str: + return f"std::optional {name} = std::nullopt" + + +def optional_tensor_arg_no_default(name: str) -> str: + return f"std::optional {name}" + + +def schema_optional_tensor_arg(name: str) -> str: + return f"Tensor? {name} = None" + + +def optional_tensorlist_arg(name: str) -> str: + return f"std::optional {name} = std::nullopt" + + +def optional_tensorlist_arg_no_default(name: str) -> str: + return f"std::optional {name}" + + +def schema_optional_tensorlist_arg(name: str) -> str: + return f"Tensor[]? {name} = None" + + def make_kernel_arg( # pyre-fixme[11]: Annotation `ArgType` is not defined as a type. ty: ArgType, @@ -199,9 +236,6 @@ def make_kernel_arg( default: Union[int, float, None], pass_by_ref: bool = False, ) -> str: - if name == "learning_rate_tensor": - ty = ArgType.FLOAT - name = "learning_rate" return { ArgType.TENSOR: lambda x: acc_cache_tensor_arg(x, pass_by_ref=pass_by_ref), ArgType.INT_TENSOR: lambda x: int_tensor_arg(x, pass_by_ref=pass_by_ref), @@ -224,14 +258,15 @@ def make_kernel_arg( if default is not None else float_arg_no_default ), + ArgType.BOOL: ( + (lambda x: bool_arg(x, default=bool(default))) + if default is not None + else bool_arg_no_default + ), }[ty](name) def make_kernel_arg_constructor(ty: ArgType, name: str) -> str: - # learning_rate is a float in kernels - if name == "learning_rate_tensor": - ty = ArgType.FLOAT - name = "learning_rate" return { ArgType.TENSOR: acc_cache_tensor_arg_constructor, ArgType.INT_TENSOR: int_tensor_arg_constructor, @@ -240,14 +275,11 @@ def make_kernel_arg_constructor(ty: ArgType, name: str) -> str: ArgType.INT: lambda x: x, ArgType.FLOAT: lambda x: x, ArgType.SYM_INT: lambda x: x, + ArgType.BOOL: lambda x: x, }[ty](name) def make_cpu_kernel_arg(ty: ArgType, name: str, default: Union[int, float]) -> str: - # learning_rate is a float in kernels - if name == "learning_rate_tensor": - ty = ArgType.FLOAT - name = "learning_rate" return { ArgType.TENSOR: lambda x: acc_cache_tensor_arg(x, gpu=False), ArgType.INT_TENSOR: lambda x: int_tensor_arg(x, gpu=False), @@ -256,14 +288,11 @@ def make_cpu_kernel_arg(ty: ArgType, name: str, default: Union[int, float]) -> s ArgType.INT: lambda x: int64_arg(x, default=int(default)), ArgType.FLOAT: lambda x: float_arg(x, default=default), ArgType.SYM_INT: lambda x: sym_int_arg(x, default=int(default)), + ArgType.BOOL: lambda x: bool_arg(x, default=bool(default)), }[ty](name) def make_cpu_kernel_arg_constructor(ty: ArgType, name: str) -> str: - # learning_rate is a float in kernels - if name == "learning_rate_tensor": - ty = ArgType.FLOAT - name = "learning_rate" return { ArgType.TENSOR: lambda x: acc_cache_tensor_arg_constructor(x, gpu=False), ArgType.INT_TENSOR: lambda x: int_tensor_arg_constructor(x, gpu=False), @@ -274,17 +303,53 @@ def make_cpu_kernel_arg_constructor(ty: ArgType, name: str) -> str: ArgType.INT: lambda x: x, ArgType.FLOAT: lambda x: x, ArgType.SYM_INT: lambda x: x, + ArgType.BOOL: lambda x: x, }[ty](name) def make_function_arg( - ty: ArgType, name: str, default: Optional[Union[int, float]] + ty: ArgType, + name: str, + default: Optional[Union[int, float]], + is_optional: bool = False, ) -> str: return { - ArgType.TENSOR: tensor_arg, - ArgType.INT_TENSOR: tensor_arg, - ArgType.LONG_TENSOR: tensor_arg, - ArgType.PLACEHOLDER_TENSOR: tensor_arg, + ArgType.TENSOR: ( + (lambda x: tensor_arg(x)) + if not is_optional + else ( + optional_tensor_arg + if default is not None + else optional_tensor_arg_no_default + ) + ), + ArgType.INT_TENSOR: ( + (lambda x: tensor_arg(x)) + if not is_optional + else ( + optional_tensor_arg + if default is not None + else optional_tensor_arg_no_default + ) + ), + ArgType.LONG_TENSOR: ( + (lambda x: tensor_arg(x)) + if not is_optional + else ( + optional_tensor_arg + if default is not None + else optional_tensor_arg_no_default + ) + ), + ArgType.PLACEHOLDER_TENSOR: ( + (lambda x: tensor_arg(x)) + if not is_optional + else ( + optional_tensor_arg + if default is not None + else optional_tensor_arg_no_default + ) + ), ArgType.INT: ( (lambda x: int64_arg(x, default=int(default))) if default is not None @@ -300,6 +365,11 @@ def make_function_arg( if default is not None else sym_int_arg_no_default ), + ArgType.BOOL: ( + (lambda x: bool_arg(x, default=bool(default))) + if default is not None + else bool_arg_no_default + ), }[ty](name) @@ -313,10 +383,11 @@ def make_function_schema_arg(ty: ArgType, name: str, default: Union[int, float]) ArgType.FLOAT: lambda x: float_arg(x, default=default), # pyre-fixme[6]: For 2nd argument expected `int` but got `Union[float, int]`. ArgType.SYM_INT: lambda x: schema_sym_int_arg(x, default=default), + ArgType.BOOL: lambda x: schema_bool_arg(x, default=bool(default)), }[ty](name) -def _extend_tensor_str(name: str, is_cuda: bool) -> str: +def _extend_tensor_str(name: str, is_cuda: bool, optional: bool) -> str: """ Take a tensor name and extend for cpu or cuda @@ -327,10 +398,12 @@ def _extend_tensor_str(name: str, is_cuda: bool) -> str: Returns: String of extended tensors """ + opt = "?" if optional else "" + default = " = None" if optional else "" if is_cuda: - return f"Tensor {name}_dev, Tensor {name}_uvm, Tensor {name}_placements, Tensor {name}_offsets" + return f"Tensor{opt} {name}_dev {default}, Tensor{opt} {name}_uvm {default}, Tensor{opt} {name}_placements {default}, Tensor{opt} {name}_offsets {default}" else: - return f"Tensor {name}_host, Tensor {name}_placements, Tensor {name}_offsets" + return f"Tensor{opt} {name}_host {default}, Tensor{opt} {name}_placements {default}, Tensor{opt} {name}_offsets {default}" def extend_tensors_args_from_str(args_str: str, example_tensor: str) -> str: @@ -350,13 +423,18 @@ def extend_tensors_args_from_str(args_str: str, example_tensor: str) -> str: num_tensors = args_str.count("Tensor") if num_tensors > 0: is_cuda = "_dev" in example_tensor - args = args_str.split(", ", num_tensors) - tensors_args = args[:num_tensors] - non_tensors_args = args[-1] - extended_tensors_args = [ - _extend_tensor_str(t.split(" ")[1], is_cuda) for t in tensors_args - ] - return ", ".join(extended_tensors_args + [non_tensors_args]) + args = args_str.split(", ") + extended_tensors_args = [] + for arg in args: + ty = arg.split(" ")[0] + name = arg.split(" ")[1] + if ty == "Tensor": + extended_tensors_args.append(_extend_tensor_str(name, is_cuda, False)) + elif ty == "Tensor?": + extended_tensors_args.append(_extend_tensor_str(name, is_cuda, True)) + else: + extended_tensors_args.append(arg) + return ", ".join(extended_tensors_args) else: return args_str @@ -378,6 +456,9 @@ def make_split_function_args_v1(args_str: str) -> str: args_str.replace("int", "int64_t") .replace("SymInt", "c10::SymInt") .replace("float", "double") + .replace("Tensor?", "std::optional") + .replace("None", "std::nullopt") + .replace("False", "false") ) @@ -386,20 +467,49 @@ def make_ivalue_cast(ty: ArgType) -> str: ArgType.INT: "toInt", ArgType.FLOAT: "toDouble", ArgType.SYM_INT: "toSymInt", + ArgType.BOOL: "toBool", }[ty] +def reorder_args(split_arg_spec: List[OptimItem]) -> List[OptimItem]: + """ + Reorder such that tensor arguments come first. This is used in backend, wrapper and kernels where tensors are no longer optional. + We need to pass tensor arguments before other types which have default arguments. + + Parameters: + split_arg_spec (List[OptimItem]): List of argument items + + Return: + reordered of split_arg_spec + """ + tensor_args = [] + non_tensor_args = [] + for s in split_arg_spec: + if s.ty in ( + ArgType.TENSOR, + ArgType.INT_TENSOR, + ArgType.LONG_TENSOR, + ArgType.PLACEHOLDER_TENSOR, + ): + tensor_args.append(s) + else: + non_tensor_args.append(s) + + return tensor_args + non_tensor_args + + @dataclass class PT2ArgsSet: split_function_args: List[str] split_function_arg_names: List[str] split_function_schemas: List[str] - split_saved_tensor_list: List[str] + split_saved_tensorlist: List[str] + split_saved_tensorlist_optional: List[str] @staticmethod # pyre-ignore[3] def create( - split_arg_spec: List[OptimItem], + arg_spec: List[OptimItem], ): """ PT2ArgsSet.create() is a method that creates different formats given the optimization arguments @@ -410,24 +520,28 @@ def create( e.g., instead of passing `momentum_host, `momentum_dev`, etc, we pass `momentum` Parameters: - split_arg_spec: List[OptimItem] - list of argument specs + arg_spec: List[OptimItem] - list of argument specs Returns: PT2ArgsSet object with the following attributes: - split_function_args: List[str] - List of function arguments + split_function_args: List[str] - List of function arguments used in unified lookup and autograd functions + Tensors will be packed and pass as TensorList e.g., ['at::TensorList momentum1', 'double eps', 'double weight_decay']. - split_function_arg_names: List[str] - List of argument names + split_function_arg_names: List[str] - List of argument names used in unified lookup and autograd functions e.g., ['momentum1', 'eps', 'weight_decay']. - split_function_schemas: List[str] - List of arguments in the schema format + split_function_schemas: List[str] - List of arguments used in unified lookup and autograd functions in the schema format e.g., ['Tensor[] momentum1', 'float eps', 'float weight_decay']. - split_saved_tensor_list: List[str] - List of saved tensors for the split function - e.g., ['momentum1']. + split_saved_tensorlist: List[str] - List of tensor names that are packed into tensorlist and will be unpacked in + PT2 autograd function. e.g., ['momentum1']. + split_saved_tensorlist_optional: List[str] - List of tensor names that are packed into tensorlist but are optional + and will be unpacked in PT2 autograd function e.g., ['row_counter']. """ split_function_arg_names = [] split_function_args = [] split_function_schemas = [] - split_saved_tensor_list = [] - for s in split_arg_spec: + split_saved_tensorlist = [] + split_saved_tensorlist_optional = [] + for s in arg_spec: if s.name == "learning_rate_tensor": split_function_arg_names.append(s.name) split_function_args.append(tensor_arg(s.name)) @@ -438,16 +552,20 @@ def create( ArgType.LONG_TENSOR, ArgType.PLACEHOLDER_TENSOR, ): - name = s.name.rsplit("_", 1)[0] - if name not in split_function_arg_names: - split_function_arg_names.append(name) - split_saved_tensor_list.append(name) + name = s.name + split_function_arg_names.append(name) + if s.is_optional: + split_function_args.append(optional_tensorlist_arg(name)) + split_function_schemas.append(schema_optional_tensorlist_arg(name)) + split_saved_tensorlist_optional.append(name) + else: split_function_args.append( tensor_list_arg_no_default(name, pass_by_ref=False) ) split_function_schemas.append( schema_tensor_list_arg_no_default(name) ) + split_saved_tensorlist.append(name) else: split_function_arg_names.append(s.name) split_function_args.append(make_function_arg(s.ty, s.name, s.default)) @@ -458,7 +576,8 @@ def create( split_function_args=split_function_args, split_function_arg_names=split_function_arg_names, split_function_schemas=split_function_schemas, - split_saved_tensor_list=split_saved_tensor_list, + split_saved_tensorlist=split_saved_tensorlist, + split_saved_tensorlist_optional=split_saved_tensorlist_optional, ) @@ -489,6 +608,9 @@ class OptimizerArgs: placeholder_type_combos: Union[List[Dict[str, TensorType]], List[None]] unified_pt2: PT2ArgsSet split_kernel_arg_names: List[str] + split_function_args_autograd: List[str] + split_function_arg_names_autograd: List[str] + split_saved_tensors_optional: List[str] split_function_args_v1: Optional[str] = None split_function_schemas_v1: Optional[str] = None @@ -499,6 +621,30 @@ def create( arg_spec: List[OptimItem], additional_spec: Optional[dict[str, Any]] = None, ): + # Keep the argument order for forward/backward compatibility + # Arg order: non-optional tensors, learning_rate_tensor, non-tensors, optional tensors + # This is used in lookup and autograd functions + frontend_split_arg_spec = split_arg_spec.copy() + + has_optional_tensors: bool = False + # Create another spec for kernels where learning_rate is float + # This is used in kernels + kernel_split_arg_spec = split_arg_spec.copy() + for i, s in enumerate(kernel_split_arg_spec): + if s.name == "learning_rate_tensor": + # pyre-ignore[6] + kernel_split_arg_spec[i] = OptimItem(ArgType.FLOAT, "learning_rate") + if s.is_optional: + has_optional_tensors = True + + # Optional tensors are converted to tensor in autograd functions + # Reorganize arguments for wrapper, backend and kernel functions + if has_optional_tensors: + # Arg order: non-optional tensors, learning_rate_tensor, optional tensors as tensors, non-tensors, + split_arg_spec = reorder_args(split_arg_spec) + # Arg order: non-optional tensors, optional tensors as tensors, learning rate (float), non-tensors + kernel_split_arg_spec = reorder_args(kernel_split_arg_spec) + # Compute placeholder tensor combinations ph_tensor_names = [ s.name for s in arg_spec if s.ty == ArgType.PLACEHOLDER_TENSOR @@ -529,6 +675,36 @@ def create( ArgType.PLACEHOLDER_TENSOR, ) ] + # Create empty tensors based on weights + # weights name convention is different between v1 and pt2 unified interface (v2) + # i.e., host_weights, dev_weights uvm_weights, weights_placements, weights_offsets in v1 and weights_{} in v2 + # This is only used in v1, so we fix the name based on v1 + create_empty_tensor = { + "host": "host_weights.options()", + "dev": "dev_weights.options()", + "uvm": "uvm_weights.options()", + "placements": "weights_placements.options()", + "offsets": "weights_offsets.options()", + } + split_saved_tensors_optional = [ + ( + f"{s.name}.has_value() ? {s.name}.value() : at::empty(" + + "{0}, " + + create_empty_tensor[s.name.rsplit("_", 1)[1]] + + ")" + if s.is_optional + else s.name + ) + for s in split_arg_spec + if s.ty + in ( + ArgType.TENSOR, + ArgType.INT_TENSOR, + ArgType.LONG_TENSOR, + ArgType.PLACEHOLDER_TENSOR, + ) + ] + # Create function args and schemas for V1 interface for backward compatibility # V1 interface refers to separate CPU/CUDA lookup functions # e.g., split_embedding_codegen_lookup_{}_funtion and split_embedding_codegen_lookup_{}_funtion_cpu) @@ -548,20 +724,22 @@ def create( return OptimizerArgs( # GPU kernel args split_kernel_args=[ - make_kernel_arg(s.ty, s.name, s.default) for s in split_arg_spec + make_kernel_arg(s.ty, s.name, s.default) for s in kernel_split_arg_spec ], split_kernel_args_no_defaults=[ - make_kernel_arg(s.ty, s.name, None) for s in split_arg_spec + make_kernel_arg(s.ty, s.name, None) for s in kernel_split_arg_spec ], split_kernel_arg_constructors=[ - make_kernel_arg_constructor(s.ty, s.name) for s in split_arg_spec + make_kernel_arg_constructor(s.ty, s.name) for s in kernel_split_arg_spec ], # CPU kernel args split_cpu_kernel_args=[ - make_cpu_kernel_arg(s.ty, s.name, s.default) for s in split_arg_spec + make_cpu_kernel_arg(s.ty, s.name, s.default) + for s in kernel_split_arg_spec ], split_cpu_kernel_arg_constructors=[ - make_cpu_kernel_arg_constructor(s.ty, s.name) for s in split_arg_spec + make_cpu_kernel_arg_constructor(s.ty, s.name) + for s in kernel_split_arg_spec ], # Function args split_function_args=[ @@ -574,8 +752,11 @@ def create( split_tensors=[ s.name for s in arg_spec - if (s.ty in (ArgType.TENSOR, ArgType.PLACEHOLDER_TENSOR)) - and s.name != "learning_rate_tensor" + if ( + s.ty in (ArgType.TENSOR, ArgType.PLACEHOLDER_TENSOR) + and s.name != "learning_rate_tensor" + and not s.is_optional + ) ], split_tensor_types={ s.name: ( @@ -584,7 +765,11 @@ def create( else (s.name + "_ph_t") ) for s in arg_spec - if s.ty in (ArgType.TENSOR, ArgType.PLACEHOLDER_TENSOR) + if ( + s.ty in (ArgType.TENSOR, ArgType.PLACEHOLDER_TENSOR) + and s.name != "learning_rate_tensor" + and not s.is_optional + ) }, split_saved_tensors=split_saved_tensors, saved_data=[ @@ -600,16 +785,22 @@ def create( split_variables=["Variable()" for _ in split_arg_spec], split_ref_kernel_args=[ make_kernel_arg(s.ty, s.name, s.default, pass_by_ref=True) - for s in split_arg_spec + for s in kernel_split_arg_spec ], placeholder_tensor_names=ph_tensor_names, placeholder_type_combos=ph_combos, - unified_pt2=PT2ArgsSet.create(split_arg_spec), + unified_pt2=PT2ArgsSet.create(arg_spec), # learning rate remains float in kernels split_kernel_arg_names=[ "learning_rate" if s.name == "learning_rate_tensor" else s.name - for s in split_arg_spec + for s in kernel_split_arg_spec ], + split_function_args_autograd=[ + make_function_arg(s.ty, s.name, s.default, s.is_optional) + for s in frontend_split_arg_spec + ], + split_function_arg_names_autograd=[s.name for s in frontend_split_arg_spec], + split_saved_tensors_optional=split_saved_tensors_optional, split_function_args_v1=split_function_args_v1, split_function_schemas_v1=split_function_schemas_v1, ) @@ -636,7 +827,7 @@ def create_optim_args( for s in arg_spec: # no cpu/cuda extension for learning_rate if ( - s.ty in (ArgType.FLOAT, ArgType.INT, ArgType.SYM_INT) + s.ty in (ArgType.FLOAT, ArgType.INT, ArgType.SYM_INT, ArgType.BOOL) or s.name == "learning_rate_tensor" ): # pyre-fixme[19]: Expected 1 positional argument. @@ -651,13 +842,21 @@ def create_optim_args( def extend_for_cpu(spec: OptimItem) -> List[OptimItem]: name = spec.name default = spec.default + is_optional = spec.is_optional return [ # pyre-fixme[19]: Expected 1 positional argument. - OptimItem(ArgType.TENSOR, f"{name}_host", default), + OptimItem(ArgType.TENSOR, f"{name}_host", default, is_optional=is_optional), # pyre-fixme[19]: Expected 1 positional argument. - OptimItem(ArgType.INT_TENSOR, f"{name}_placements", default), + OptimItem( + ArgType.INT_TENSOR, + f"{name}_placements", + default, + is_optional=is_optional, + ), # pyre-fixme[19]: Expected 1 positional argument. - OptimItem(ArgType.LONG_TENSOR, f"{name}_offsets", default), + OptimItem( + ArgType.LONG_TENSOR, f"{name}_offsets", default, is_optional=is_optional + ), ] @staticmethod @@ -666,15 +865,23 @@ def extend_for_cuda(spec: OptimItem) -> List[OptimItem]: default = spec.default ty = spec.ty ph_tys = spec.ph_tys + is_optional = spec.is_optional return [ # pyre-fixme[19]: Expected 1 positional argument. - OptimItem(ty, f"{name}_dev", default, ph_tys), + OptimItem(ty, f"{name}_dev", default, ph_tys, is_optional), # pyre-fixme[19]: Expected 1 positional argument. - OptimItem(ty, f"{name}_uvm", default, ph_tys), + OptimItem(ty, f"{name}_uvm", default, ph_tys, is_optional), # pyre-fixme[19]: Expected 1 positional argument. - OptimItem(ArgType.INT_TENSOR, f"{name}_placements", default), + OptimItem( + ArgType.INT_TENSOR, + f"{name}_placements", + default, + is_optional=is_optional, + ), # pyre-fixme[19]: Expected 1 positional argument. - OptimItem(ArgType.LONG_TENSOR, f"{name}_offsets", default), + OptimItem( + ArgType.LONG_TENSOR, f"{name}_offsets", default, is_optional=is_optional + ), ] @staticmethod @@ -683,17 +890,25 @@ def extend_for_any(spec: OptimItem) -> List[OptimItem]: default = spec.default ty = spec.ty ph_tys = spec.ph_tys + is_optional = spec.is_optional return [ # pyre-fixme[19]: Expected 1 positional argument. - OptimItem(ArgType.TENSOR, f"{name}_host", default), + OptimItem(ArgType.TENSOR, f"{name}_host", default, is_optional=is_optional), # pyre-fixme[19]: Expected 1 positional argument. - OptimItem(ty, f"{name}_dev", default, ph_tys), + OptimItem(ty, f"{name}_dev", default, ph_tys, is_optional=is_optional), # pyre-fixme[19]: Expected 1 positional argument. - OptimItem(ty, f"{name}_uvm", default, ph_tys), + OptimItem(ty, f"{name}_uvm", default, ph_tys, is_optional=is_optional), # pyre-fixme[19]: Expected 1 positional argument. - OptimItem(ArgType.INT_TENSOR, f"{name}_placements", default), + OptimItem( + ArgType.INT_TENSOR, + f"{name}_placements", + default, + is_optional=is_optional, + ), # pyre-fixme[19]: Expected 1 positional argument. - OptimItem(ArgType.LONG_TENSOR, f"{name}_offsets", default), + OptimItem( + ArgType.LONG_TENSOR, f"{name}_offsets", default, is_optional=is_optional + ), ] @staticmethod diff --git a/fbgemm_gpu/codegen/genscript/optimizers.py b/fbgemm_gpu/codegen/genscript/optimizers.py index a17506131..3d5de9bb4 100644 --- a/fbgemm_gpu/codegen/genscript/optimizers.py +++ b/fbgemm_gpu/codegen/genscript/optimizers.py @@ -1001,6 +1001,29 @@ def partial_rowwise_lamb() -> Dict[str, Any]: def adam() -> Dict[str, Any]: + split_precomputation = """ + at::acc_type* __restrict__ row_counter; + at::acc_type _row_counter = iter; + if (use_rowwise_bias_correction) { + const auto row_counter_placement = static_cast(row_counter_placements[t]); + const int64_t row_counter_offset = row_counter_offsets[t]; + if (row_counter_placement == PlacementType::DEVICE) { + row_counter = &row_counter_dev[row_counter_offset]; + } else { + row_counter = &row_counter_uvm[row_counter_offset]; + } + + // need to compute bias correction for each row + if (threadIdx.x == 0) { + _row_counter = row_counter[idx] + 1; + row_counter[idx] = _row_counter; + } + + // broadcast bias correction to all threads + _row_counter = SHFL_SYNC(_row_counter, 0); + } + """ + split_weight_update = """ Vec4T m_t(&momentum1[idx * D + d]); m_t.acc.x *= beta1; @@ -1023,10 +1046,10 @@ def adam() -> Dict[str, Any]: v_t.fma_(grad, 1.0 - beta2); v_t.store(&momentum2[idx * D + d]); - weight_new.acc.x -= learning_rate * (m_t.acc.x / (1.0 - powf(beta1, iter)) / (sqrtf((v_t.acc.x / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight_new.acc.x); - weight_new.acc.y -= learning_rate * (m_t.acc.y / (1.0 - powf(beta1, iter)) / (sqrtf((v_t.acc.y / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight_new.acc.y); - weight_new.acc.z -= learning_rate * (m_t.acc.z / (1.0 - powf(beta1, iter)) / (sqrtf((v_t.acc.z / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight_new.acc.z); - weight_new.acc.w -= learning_rate * (m_t.acc.w / (1.0 - powf(beta1, iter)) / (sqrtf((v_t.acc.w / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight_new.acc.w); + weight_new.acc.x -= learning_rate * (m_t.acc.x / (1.0 - powf(beta1, _row_counter)) / (sqrtf((v_t.acc.x / (1.0 - powf(beta2, _row_counter)))) + eps) + weight_decay * weight_new.acc.x); + weight_new.acc.y -= learning_rate * (m_t.acc.y / (1.0 - powf(beta1, _row_counter)) / (sqrtf((v_t.acc.y / (1.0 - powf(beta2, _row_counter)))) + eps) + weight_decay * weight_new.acc.y); + weight_new.acc.z -= learning_rate * (m_t.acc.z / (1.0 - powf(beta1, _row_counter)) / (sqrtf((v_t.acc.z / (1.0 - powf(beta2, _row_counter)))) + eps) + weight_decay * weight_new.acc.z); + weight_new.acc.w -= learning_rate * (m_t.acc.w / (1.0 - powf(beta1, _row_counter)) / (sqrtf((v_t.acc.w / (1.0 - powf(beta2, _row_counter)))) + eps) + weight_decay * weight_new.acc.w); """ split_weight_update_cpu = "" # TODO @@ -1043,12 +1066,14 @@ def adam() -> Dict[str, Any]: OptimItem(ArgType.FLOAT, "beta2"), OptimItem(ArgType.FLOAT, "weight_decay"), OptimItem(ArgType.INT, "iter"), + OptimItem(ArgType.BOOL, "use_rowwise_bias_correction"), + OptimItem(ArgType.TENSOR, "row_counter", is_optional=True), ], { - "v1": "Tensor momentum1, Tensor momentum2, float learning_rate = 0, float eps = 0, float beta1 = 0, float beta2 = 0, float weight_decay = 0, int iter = 0" + "v1": "Tensor momentum1, Tensor momentum2, float learning_rate = 0, float eps = 0, float beta1 = 0, float beta2 = 0, float weight_decay = 0, int iter = 0, bool use_rowwise_bias_correction = False, Tensor? row_counter = None", }, ), - "split_precomputation": "", + "split_precomputation": split_precomputation, "split_weight_update": split_weight_update, "split_post_update": "", "split_weight_update_cpu": split_weight_update_cpu, diff --git a/fbgemm_gpu/codegen/genscript/torch_type_utils.py b/fbgemm_gpu/codegen/genscript/torch_type_utils.py index ebd4e0220..aa442ad37 100644 --- a/fbgemm_gpu/codegen/genscript/torch_type_utils.py +++ b/fbgemm_gpu/codegen/genscript/torch_type_utils.py @@ -26,6 +26,7 @@ class ArgType(IntEnum): INT = 7 FLOAT = 8 SYM_INT = 9 + BOOL = 10 @dataclass diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_cpu_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_cpu_template.cpp index e3b459a7b..fc7d8a58f 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_cpu_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_cpu_template.cpp @@ -64,14 +64,14 @@ class SplitLookupFunction_{{ optimizer }}_Op : public torch::autograd::Function< bool gradient_clipping, double max_gradient, bool stochastic_rounding, - {{ args.split_function_args | join(", ") }}, + {{ args.split_function_args_autograd | join(", ") }}, int64_t output_dtype = static_cast(SparseType::FP32)) { Tensor indice_weights_value = indice_weights.value_or(Tensor()); Tensor feature_requires_grad_value = feature_requires_grad.value_or(Tensor()); ctx->save_for_backward({ host_weights, weights_placements, weights_offsets, D_offsets, hash_size_cumsum, - indices, offsets, indice_weights_value, feature_requires_grad_value, {{ args.split_saved_tensors | join(", ") }} }); + indices, offsets, indice_weights_value, feature_requires_grad_value, {{ args.split_saved_tensors_optional | join(", ") }} }); ctx->saved_data["total_D"] = total_D; ctx->saved_data["max_D"] = max_D; @@ -242,7 +242,7 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function_cpu( gradient_clipping, max_gradient, stochastic_rounding, - {{ args.split_function_arg_names | join(", ") }}, + {{ args.split_function_arg_names_autograd | join(", ") }}, output_dtype)[0]; {% else %} TORCH_CHECK(false, "split_embedding_codegen_lookup_{{ optimizer }}_function_cpu is deprecated. Please see https://github.com/pytorch/FBGEMM/discussions/1727 for more detail."); diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp index 3efec2527..6efccefb8 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp @@ -363,7 +363,7 @@ enum SSDTensor { {%- if ssd %} ssd_tensors.value(), {%- endif %} - {{ args.split_function_arg_names | join(", ") }} + {{ args.split_function_arg_names_autograd | join(", ") }} {%- endif %} )[0]; {%- endmacro %} @@ -623,7 +623,7 @@ class {{ autograd_func }} : {%- if ssd %} const at::TensorList& ssd_tensors, {%- endif %} - {{ args.split_function_args | join(", ") }} + {{ args.split_function_args_autograd | join(", ") }} {%- else %} {%- if vbe %} const std::optional& B_offsets, @@ -762,7 +762,7 @@ class {{ autograd_func }} : ssd_tensors[SSDTensor::{{ tensor | upper }}], {%- endfor %} {%- endif %} - {{ args.split_saved_tensors | join(", ") }} + {{ args.split_saved_tensors_optional | join(", ") }} }); {%- if not nobag %} diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index e7f6258be..b224c3e70 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -476,12 +476,35 @@ enum SSDTensor { /* This macro generates a code blob for unpacking the tensor list */ -{%- macro unpack_tensor_list(tensor_list) %} - const Tensor {{ tensor_list }}_host = {{ tensor_list }}[0]; - const Tensor {{ tensor_list }}_dev = {{ tensor_list }}[1]; - const Tensor {{ tensor_list }}_uvm = {{ tensor_list }}[2]; - const Tensor {{ tensor_list }}_placements = {{ tensor_list }}[3]; - const Tensor {{ tensor_list }}_offsets = {{ tensor_list }}[4]; +{%- macro unpack_tensorlist(name) %} + const Tensor {{ name }}_host = {{ name }}[0]; + const Tensor {{ name }}_dev = {{ name }}[1]; + const Tensor {{ name }}_uvm = {{ name }}[2]; + const Tensor {{ name }}_placements = {{ name }}[3]; + const Tensor {{ name }}_offsets = {{ name }}[4]; +{%- endmacro %} + +{%- macro unpack_tensorlist_optional(name) %} + Tensor {{ name }}_host; + Tensor {{ name }}_dev; + Tensor {{ name }}_uvm; + Tensor {{ name }}_placements; + Tensor {{ name }}_offsets; + if ({{ name }}.has_value()) { + at::TensorList _{{ name }} = {{ name }}.value(); + {{ name }}_host = _{{ name }}[0]; + {{ name }}_dev = _{{ name }}[1]; + {{ name }}_uvm = _{{ name }}[2]; + {{ name }}_placements = _{{ name }}[3]; + {{ name }}_offsets = _{{ name }}[4]; + } + else{ + {{ name }}_host = at::empty({0}, weights_host.options()); + {{ name }}_dev = at::empty({0}, weights_dev.options()); + {{ name }}_uvm = at::empty({0}, weights_uvm.options()); + {{ name }}_placements = at::empty({0}, weights_placements.options()); + {{ name }}_offsets = at::empty({0}, weights_offsets.options()); + } {%- endmacro %} @@ -582,9 +605,12 @@ class {{ autograd_func }} : {{ args_pt2.unified_pt2.split_function_args | join(", ") }}) { // unpack Tensor lists - {{ unpack_tensor_list("weights") }} - {%- for arg_name in args_pt2.unified_pt2.split_saved_tensor_list %} - {{ unpack_tensor_list(arg_name) }} + {{ unpack_tensorlist("weights") }} + {%- for arg_name in args_pt2.unified_pt2.split_saved_tensorlist %} + {{ unpack_tensorlist(arg_name) }} + {%- endfor %} + {%- for arg_name in args_pt2.unified_pt2.split_saved_tensorlist_optional %} + {{ unpack_tensorlist_optional(arg_name) }} {%- endfor %} const auto T = weights_offsets.sym_numel(); diff --git a/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template b/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template index e86b27b2d..c69837291 100644 --- a/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template +++ b/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template @@ -60,7 +60,7 @@ def invoke( {%- if "prev_iter_dev" in args.split_function_arg_names %} prev_iter: Momentum, {%- endif %} - {%- if "row_counter_dev" in args.split_function_arg_names %} + {%- if "row_counter_dev" in args.split_function_arg_names and "row_counter" not in args_pt2.unified_pt2.split_saved_tensorlist_optional %} row_counter: Momentum, {%- endif %} {%- if "iter" in args.split_function_arg_names %} @@ -209,7 +209,7 @@ def invoke( prev_iter_placements=prev_iter.placements, {%- endif %} # row_counter - {%- if "row_counter_dev" in args.split_function_arg_names %} + {%- if "row_counter_dev" in args.split_function_arg_names and "row_counter" not in args_pt2.unified_pt2.split_saved_tensorlist_optional %} row_counter_host=row_counter.host, row_counter_offsets=row_counter.offsets, row_counter_placements=row_counter.placements, @@ -387,7 +387,7 @@ def invoke( prev_iter_dev=prev_iter_dev, {%- endif %} # row_counter - {%- if "row_counter_dev" in args.split_function_arg_names %} + {%- if "row_counter_dev" in args.split_function_arg_names and "row_counter" not in args_pt2.unified_pt2.split_saved_tensorlist_optional %} row_counter_dev=row_counter.dev, row_counter_uvm=row_counter.uvm, row_counter_offsets=row_counter.offsets, diff --git a/fbgemm_gpu/test/tbe/training/forward_test.py b/fbgemm_gpu/test/tbe/training/forward_test.py index 5ea2ff723..e4e54fc99 100644 --- a/fbgemm_gpu/test/tbe/training/forward_test.py +++ b/fbgemm_gpu/test/tbe/training/forward_test.py @@ -76,6 +76,13 @@ "test_faketensor__test_forward_gpu_uvm_cache_int8": [ unittest.skip("Operator not implemented for Meta tensors"), ], + # learning rate tensor needs to be on CPU to avoid D->H sync point since it will be used as float in the kernel + # this fails fake_tensor test as the test expects all tensors to be on the same device + "test_pt2_compliant_tag_fbgemm_split_embedding_codegen_lookup_rowwise_adagrad_function": [ + unittest.skip( + "Operator failed on FakeTensor test since learning rate tensor is always on CPU regardless of other tensors" + ), + ], } ) From 64f8378d4101b7ad903c0299120efb93f4f63f66 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 23 Dec 2024 19:15:16 +0000 Subject: [PATCH 16/16] profile with kineto to eliminate the CPU overhead in benchmark --- fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py index f439ed678..4772a3f71 100644 --- a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py @@ -1330,6 +1330,10 @@ def nbit_device( # noqa C901 def _kineto_trace_handler(p: profile, phase: str) -> None: p.export_chrome_trace( trace_url.format(tbe_type=tbe_type, phase=phase, ospid=os.getpid()) + #print(p.key_averages()) + # averges the sum of all kernels + total_cuda_time = sum(event.device_time*event.count/(iters+1) for event in p.key_averages() if event.cpu_time == 0.0) + print(f"Total CUDA time: {total_cuda_time:.3f} ") ) # pyre-ignore[3]