From f4710c1b3b477838498cc71c7e6d3cf296f3ad82 Mon Sep 17 00:00:00 2001 From: Aya Ibrahim Date: Thu, 3 Oct 2024 12:36:19 -0700 Subject: [PATCH] FP8 KV + Disagg unit test (#3218) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3218 X-link: https://github.com/facebookresearch/FBGEMM/pull/315 Adding the FP8 KV cache to the disagg test for mp2. Changes include switching to the 7B Llama model: the small model has a D_H of 64, which does not work with the dequantization kernel (the issue will be investigated in another diff). TODO: add FP8 KV cache + paged KV to the test Reviewed By: jianyuh Differential Revision: D62772678 fbshipit-source-id: 775f572e2c345354844e24d80e2481284ac6f1a3 --- fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu index 787c0547c..00974a9fe 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu @@ -1437,6 +1437,7 @@ __global__ void dequantize_fp8_cache_kernel( auto MAX_T = cache_K.size(1); auto D_H = cache_K_dq.size(3); auto D_H_q = cache_K.size(3); + // TODO: support D_H < 128 for small model used in testing. CUDA_KERNEL_ASSERT(D_H == 128); auto b = blockIdx.x