diff --git a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt index a5051e04b..214dc8bb7 100644 --- a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt +++ b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt @@ -37,9 +37,14 @@ set(quantize_ops_sources src/quantize/quantize.cu src/quantize/quantize.cpp) +set(comm_ops_sources + src/comm/car.cu + src/comm/car.cpp) + set(experimental_gen_ai_cpp_source_files ${attention_ops_sources} - ${quantize_ops_sources}) + ${quantize_ops_sources} + ${comm_ops_sources}) set_source_files_properties(${experimental_gen_ai_cpp_source_files} PROPERTIES INCLUDE_DIRECTORIES diff --git a/fbgemm_gpu/experimental/gen_ai/src/comm/car.cpp b/fbgemm_gpu/experimental/gen_ai/src/comm/car.cpp new file mode 100644 index 000000000..e5678ff81 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/comm/car.cpp @@ -0,0 +1,264 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +#include "c10/core/ScalarType.h" +#include "c10/cuda/CUDADeviceAssertionHost.h" +#include "c10/cuda/CUDAFunctions.h" +#include "c10/cuda/CUDAStream.h" +#include "c10/util/Optional.h" +#include "folly/futures/Future.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include "c10/cuda/CUDAException.h" +#include "c10/util/Exception.h" + +namespace fbgemm_gpu { + +namespace { + +static_assert(sizeof(cudaIpcMemHandle_t) == 64, ""); + +constexpr size_t kMaxNumNcclComms = 3; + +static ncclComm_t* get_nccl_comm(int64_t comm_idx) { + static ncclComm_t comms[kMaxNumNcclComms]; + + CHECK_GE(comm_idx, 0); + CHECK_LT(comm_idx, kMaxNumNcclComms); + return &comms[comm_idx]; +} + +void nccl_init( + int64_t rank, + int64_t world_size, + std::string rendevouz, + int64_t comm_idx) { + using namespace c10d; + ncclUniqueId id; + if (rank == 0) { + C10D_NCCL_CHECK(ncclGetUniqueId(&id), "ncclGetUniqueId"); + auto* f = fopen(rendevouz.c_str(), "w"); + fwrite(&id, sizeof(id), 1, f); + fclose(f); + } else { + auto check_size = [&]() { + struct stat s; + memset(&s, 0, sizeof(s)); + stat(rendevouz.c_str(), &s); + return s.st_size; + }; + while ((unsigned long)(check_size()) < sizeof(ncclUniqueId)) { + usleep(1000); + } + auto* f = fopen(rendevouz.c_str(), "r"); + fread(&id, sizeof(id), 1, f); + fclose(f); + } + C10D_NCCL_CHECK( + ncclCommInitRank(get_nccl_comm(comm_idx), world_size, id, rank), + "ncclCommInitRank"); + return; +} + +at::Tensor nccl_get_unique_id() { + using namespace c10d; + ncclUniqueId id; + static_assert(sizeof(ncclUniqueId) == 128, ""); + C10D_NCCL_CHECK(ncclGetUniqueId(&id), "ncclGetUniqueId"); + auto id_ = at::empty({128}, at::TensorOptions().dtype(at::kChar)); + std::memcpy(id_.data_ptr(), &id, sizeof(id)); + return id_; +} + +void nccl_comm_init_rank( + int64_t world_size, + int64_t rank, + at::Tensor id_, + int64_t comm_idx) { + using namespace c10d; + ncclUniqueId id; + static_assert(sizeof(ncclUniqueId) == 128, ""); + std::memcpy(&id, id_.data_ptr(), sizeof(id)); + C10D_NCCL_CHECK( + ncclCommInitRank(get_nccl_comm(comm_idx), world_size, id, rank), + "ncclCommInitRank"); +} + +void nccl_allgather(at::Tensor y_allgather, at::Tensor y, int64_t comm_idx) { + using namespace c10d; + TORCH_CHECK(y.is_contiguous()); + TORCH_CHECK(y_allgather.is_contiguous()); + ncclDataType_t type; + switch (y.scalar_type()) { + case 
at::kFloat: + type = ncclDataType_t::ncclFloat; + break; + case at::kHalf: + type = ncclDataType_t::ncclHalf; + break; + case at::kBFloat16: + type = ncclDataType_t::ncclBfloat16; + break; + default: + TORCH_CHECK(false, "unsupported type: ", y.scalar_type()); + } + C10D_NCCL_CHECK( + ncclAllGather( + y.data_ptr(), + y_allgather.data_ptr(), + y.numel(), + type, + *get_nccl_comm(comm_idx), + at::cuda::getCurrentCUDAStream()), + "ncclAllGather"); +} + +void nccl_reducescatter( + at::Tensor y_reducescatter, + at::Tensor y, + int64_t comm_idx) { + using namespace c10d; + TORCH_CHECK(y.is_contiguous()); + TORCH_CHECK(y_reducescatter.is_contiguous()); + TORCH_CHECK(y.dtype() == at::ScalarType::BFloat16); + TORCH_CHECK(y_reducescatter.dtype() == at::ScalarType::BFloat16); + + C10D_NCCL_CHECK( + ncclReduceScatter( + y.data_ptr(), + y_reducescatter.data_ptr(), + y_reducescatter.numel(), + ncclDataType_t::ncclBfloat16, + ncclSum, + *get_nccl_comm(comm_idx), + at::cuda::getCurrentCUDAStream()), + "ncclReduceScatter"); +} + +void nccl_allreduce( + at::Tensor y_allreduce, + at::Tensor y, + std::optional z, + int64_t comm_idx) { + using namespace c10d; + TORCH_CHECK(y.is_contiguous()); + TORCH_CHECK(y_allreduce.is_contiguous()); + TORCH_CHECK(y_allreduce.dtype() == y.dtype()); + ncclDataType_t type; + switch (y.scalar_type()) { + case at::kFloat: + type = ncclDataType_t::ncclFloat; + break; + case at::kHalf: + type = ncclDataType_t::ncclHalf; + break; +#ifdef IS_NCCLX_MSCCL + case at::kFloat8_e4m3fn: + type = ncclDataType_t::ncclFp8E4M3; + break; +#endif + case at::kBFloat16: + type = ncclDataType_t::ncclBfloat16; + break; + default: + TORCH_CHECK(false, "unsupported type: ", y.scalar_type()); + } + C10D_NCCL_CHECK( + ncclAllReduce( + y.data_ptr(), + y_allreduce.data_ptr(), + y.numel(), + type, + ncclSum, + *get_nccl_comm(comm_idx), + at::cuda::getCurrentCUDAStream()), + "ncclAllReduce"); + if (z) { + y_allreduce.add_(*z); + } +} + +} // namespace + +at::Tensor car_ipc_handle(at::Tensor x); +void car_init( + int64_t rank, + int64_t world_size, + at::Tensor local_barrier, + std::vector all_barrier_handles, + at::Tensor local_buffer, + std::vector all_buffer_handles); +void one_shot_car_allreduce( + at::Tensor y_allreduce, + at::Tensor y, + std::optional z, + int64_t comm_idx); +void two_shot_car_allreduce( + at::Tensor y_allreduce, + at::Tensor y, + std::optional z, + int64_t comm_idx); + +at::Tensor car_tensor(); + +TORCH_LIBRARY_FRAGMENT(fbgemm, m) { + m.def( + "nccl_init(int rank, int world_size, str rendevouz, int comm_idx=0) -> ()"); + m.impl("nccl_init", nccl_init); + + m.def("nccl_get_unique_id() -> Tensor"); + m.impl("nccl_get_unique_id", nccl_get_unique_id); + + m.def( + "nccl_comm_init_rank(int world_size, int rank, Tensor id_, int comm_idx=0) -> ()"); + m.impl("nccl_comm_init_rank", nccl_comm_init_rank); + + m.def("nccl_allgather(Tensor y_allgather, Tensor y, int comm_idx=0) -> ()"); + m.impl("nccl_allgather", nccl_allgather); + + m.def( + "nccl_reducescatter(Tensor y_reducescatter, Tensor y, int comm_idx=0) -> ()"); + m.impl("nccl_reducescatter", nccl_reducescatter); + + m.def( + "nccl_allreduce(Tensor y_allreduce, Tensor y, Tensor? 
z=None, int comm_idx=0) -> ()"); + m.impl("nccl_allreduce", nccl_allreduce); + + // car: customized all reduce + m.def("car_tensor() -> Tensor"); + m.impl("car_tensor", car_tensor); + + m.def("car_ipc_handle(Tensor buffer) -> Tensor"); + m.impl("car_ipc_handle", car_ipc_handle); + + m.def( + "car_init(int rank, int world_size, Tensor local_barrier, Tensor[] all_barrier_handles, Tensor local_buffer, Tensor[] all_buffer_handles) -> ()"); + m.impl("car_init", car_init); + + m.def( + "one_shot_car_allreduce(Tensor y_allreduce, Tensor y, Tensor? z=None, int comm_idx=0) -> ()"); + m.impl("one_shot_car_allreduce", one_shot_car_allreduce); + + m.def( + "two_shot_car_allreduce(Tensor y_allreduce, Tensor y, Tensor? z=None, int comm_idx=0) -> ()"); + m.impl("two_shot_car_allreduce", two_shot_car_allreduce); +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/comm/car.cu b/fbgemm_gpu/experimental/gen_ai/src/comm/car.cu new file mode 100644 index 000000000..9ef0c4efe --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/comm/car.cu @@ -0,0 +1,652 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "c10/core/ScalarType.h" +#include "c10/util/BFloat16.h" + +#if !( \ + defined(USE_ROCM) || \ + ((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) +#include +#elif (defined(USE_ROCM)) +#include +#endif + +#ifndef USE_ROCM +#include +#endif +#include +// #include "cuda_dispatch_utils.h" + +#include + +#include + +#if ( \ + defined(__CUDA_ARCH__) && \ + ((__CUDA_ARCH__ == 800) || (__CUDA_ARCH__ == 900))) +#define USE_WMMA_FRAG +#endif + +namespace fbgemm_gpu { + +#define DEVICE_INLINE __device__ inline __attribute__((always_inline)) + +static __host__ DEVICE_INLINE int32_t div_up(int32_t a, int32_t b) { + return (a + b - 1) / b; +}; + +#ifdef __HIP_PLATFORM_AMD__ +constexpr int32_t kThreadsPerWarp = 64; +#else +constexpr int32_t kThreadsPerWarp = 32; +#endif + +#ifdef __HIP_PLATFORM_AMD__ +using __nv_bfloat16 = hip_bfloat16; + +typedef struct __align__(4) { + uint16_t x; + uint16_t y; +} +__nv_bfloat162_raw; + +struct __align__(4) __nv_bfloat162 { + __nv_bfloat16 x; + __nv_bfloat16 y; +}; + +// the descriptions of __float2bfloat16 and __float2bfloat16_rn are identical +// https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH____BFLOAT16__MISC.html#group__CUDA__MATH____BFLOAT16__MISC +static __host__ __device__ __nv_bfloat16 __float2bfloat16(float f) { + __nv_bfloat16 output; + return output.round_to_bfloat16(f); +} + +static __host__ __device__ __nv_bfloat16 __float2bfloat16_rn(float f) { + __nv_bfloat16 output; + return output.round_to_bfloat16(f); +} + +static __host__ __device__ float __bfloat162float(__nv_bfloat16 f) { + // float output; + // https://docs.amd.com/projects/HIP/en/docs-5.0.0/doxygen/html/hip__bfloat16_8h_source.html + return float(f); +} + +static __host__ __device__ __nv_bfloat162 +__floats2bfloat162_rn(float x, float y) { + __nv_bfloat162 output; + output.x = __float2bfloat16_rn(x); + output.y = __float2bfloat16_rn(y); + return output; +} + +#endif + +struct __align__(16) bf16x8 { + __nv_bfloat162 vals[4]; +}; + +DEVICE_INLINE __nv_bfloat162 +bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) { +#if defined(__CUDA_ARCH__) && 
__CUDA_ARCH__ < 800 + float fxl, fxh, fyl, fyh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + return __floats2bfloat162_rn(fxl + fyl, fxh + fyh); +#elif defined(USE_ROCM) + float fxl, fxh, fyl, fyh; + fxl = __bfloat162float(x.x); + fxh = __bfloat162float(x.y); + fyl = __bfloat162float(y.x); + fyh = __bfloat162float(y.y); + return __floats2bfloat162_rn(fxl + fyl, fxh + fyh); +#else + return __hadd2(x, y); +#endif +} + +DEVICE_INLINE bf16x8 add_bf16x8(bf16x8 a, bf16x8 b) { + bf16x8 c; + c.vals[0] = bf16hadd2(a.vals[0], b.vals[0]); + c.vals[1] = bf16hadd2(a.vals[1], b.vals[1]); + c.vals[2] = bf16hadd2(a.vals[2], b.vals[2]); + c.vals[3] = bf16hadd2(a.vals[3], b.vals[3]); + return c; +} + +template +__global__ void one_shot_all_reduce( + int32_t rank, + int32_t world_size, + int32_t flag, + std::array barriers, + std::array inputs, + at::BFloat16* acc, + at::BFloat16* output, + int32_t N) { + // Synchronize the ranks. + volatile int32_t* barrier_d = barriers[rank]; + if (threadIdx.x < kWorldSize) { + // The 1st block notifies the other ranks. + if (blockIdx.x == 0) { +#if defined(USE_ROCM) + __atomic_store_n(barriers[threadIdx.x] + rank, flag, __ATOMIC_RELEASE); +#else + barriers[threadIdx.x][rank] = flag; +#endif + } + + // Busy-wait until all ranks are ready. +#if defined(USE_ROCM) + while (__atomic_load_n(barrier_d + threadIdx.x, __ATOMIC_ACQUIRE) != flag) { + } +#else + while (barrier_d[threadIdx.x] != flag) { + } +#endif + } + + // Make sure we can move on... + __syncthreads(); + // The source pointers. Distributed round-robin for the different warps. + const at::BFloat16* src_d[kWorldSize]; +#pragma unroll kWorldSize + for (int ii = 0; ii < kWorldSize; ++ii) { + int src_rank = (rank + ii) % kWorldSize; + src_d[ii] = inputs[src_rank]; + } + + // Each block accumulates the values from the different GPUs on the same + // node. + for (size_t i = blockDim.x * blockIdx.x * 8 + threadIdx.x * 8; i < N; + i += blockDim.x * gridDim.x * 8) { + // Iterate over the different ranks/devices on the node to load the + // values. + bf16x8 vals[kWorldSize]; +#pragma unroll kWorldSize + for (int ii = 0; ii < kWorldSize; ++ii) { + *reinterpret_cast(&vals[ii]) = + reinterpret_cast(&src_d[ii][i])[0]; + } + + // Sum the values from the different ranks. + bf16x8 sums; + if (acc) { + *reinterpret_cast(&sums) = + *reinterpret_cast(&acc[i]); + } else { + memset(reinterpret_cast(&sums), 0, sizeof(sums)); + } + +#pragma unroll kWorldSize + for (int ii = 0; ii < kWorldSize; ++ii) { + sums = add_bf16x8(sums, vals[ii]); + } + + // Store to the destination buffer. + *reinterpret_cast(&output[i]) = + *reinterpret_cast(&sums); + } + + // barrier to sync with all other ranks on the same blockIdx + // this is needed to ensure this-rank won't override its inputs buffer + // (as we always do memcpy from srcbuff to inputs buffer first) + // while other ranks are still reading them. 
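+  // The publish/spin sequence below implements that barrier: each rank writes
+  // its flag at this block's offset on every peer and waits until all peers
+  // have done the same.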
+ __syncthreads(); + + if (threadIdx.x < kWorldSize) { + // notify all other blocks this blockIdx is ready + const int32_t flag_block_offset = kWorldSize + blockIdx.x * kWorldSize; + +#if defined(USE_ROCM) + __atomic_store_n( + barriers[threadIdx.x] + flag_block_offset + rank, + flag, + __ATOMIC_RELEASE); +#else + barriers[threadIdx.x][flag_block_offset + rank] = flag; +#endif + + // busy-wait until all ranks are ready +#if defined(USE_ROCM) + while (__atomic_load_n( + barrier_d + flag_block_offset + threadIdx.x, __ATOMIC_ACQUIRE) != + flag) { + } +#else + while (barrier_d[flag_block_offset + threadIdx.x] != flag) { + } +#endif + } +} + +struct CustomAllReduceState { + std::vector barriers_; + std::vector buffers_; + + int32_t rank_; + int32_t world_size_; + int32_t flag_{0}; +}; + +CustomAllReduceState* get_car_state() { + static auto* r = new CustomAllReduceState(); + return r; +} +constexpr int64_t kMaxCAR = 50 * 1024 * 1024; + +void car_init( + int64_t rank, + int64_t world_size, + at::Tensor local_barrier, + std::vector all_barrier_handles, + at::Tensor local_buffer, + std::vector all_buffer_handles) { + at::OptionalDeviceGuard guard(local_buffer.device()); + auto to_handle = [](at::Tensor r) { + cudaIpcMemHandle_t handle; + std::memcpy(&handle, r.data_ptr(), sizeof(handle)); + return handle; + }; + + auto state = get_car_state(); + state->rank_ = rank; + state->world_size_ = world_size; + state->flag_ = 0; + state->buffers_.resize(world_size); + state->barriers_.resize(world_size); + TORCH_CHECK(world_size == all_buffer_handles.size()); + TORCH_CHECK(world_size == all_barrier_handles.size()); + + for (auto ii = 0; ii < world_size; ++ii) { + void* ptr = nullptr; + if (ii != rank) { + AT_CUDA_CHECK(cudaIpcOpenMemHandle( + &ptr, + to_handle(all_buffer_handles[ii]), + cudaIpcMemLazyEnablePeerAccess)); + } else { + ptr = local_buffer.data_ptr(); + } +#ifndef __HIP_PLATFORM_AMD__ + auto target_rank = rank; +#else + /* + * This is to mitigate an issue for ROCm where the + * device for the data ptr from hipIpcOpenMemHandle + * is always 0, tracked in FBA-288 + */ + auto target_rank = (ii == rank ? rank : 0); +#endif + state->buffers_[ii] = at::from_blob( + ptr, + {kMaxCAR}, + at::TensorOptions() + .dtype(at::kBFloat16) + .device(at::Device(at::kCUDA, target_rank))); + } + for (auto ii = 0; ii < world_size; ++ii) { + void* ptr = nullptr; +#ifndef __HIP_PLATFORM_AMD__ + auto target_rank = rank; +#else + auto target_rank = (ii == rank ? rank : 0); +#endif + if (ii != rank) { + AT_CUDA_CHECK(cudaIpcOpenMemHandle( + &ptr, + to_handle(all_barrier_handles[ii]), + cudaIpcMemLazyEnablePeerAccess)); + } else { + ptr = local_barrier.data_ptr(); + } + state->barriers_[ii] = at::from_blob( + ptr, + {kMaxCAR}, + at::TensorOptions().dtype(at::kInt).device( + at::Device(at::kCUDA, target_rank))); + } +} + +at::Tensor car_ipc_handle(at::Tensor x) { + cudaIpcMemHandle_t handle; + AT_CUDA_CHECK(cudaIpcGetMemHandle(&handle, x.data_ptr())); + auto r = at::empty( + sizeof(cudaIpcMemHandle_t), at::TensorOptions().dtype(at::kChar)); + std::memcpy(r.data_ptr(), &handle, sizeof(handle)); + return r; +} + +// need to cudaMalloc ourselves to avoid caching allocator handing out wrong +// base pointer. 
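+// (cudaIpcGetMemHandle exports a handle for the base of the underlying
+// cudaMalloc allocation, so a sub-pointer handed out by the caching allocator
+// could export a handle whose base does not match the tensor's start.)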
+at::Tensor car_tensor() { + void* ptr = nullptr; + // 1M N +#if defined(USE_ROCM) + // for MI300, we need to allocate uncached (fine-grained) memory so that the + // barrier value will be visible within the kernel instead of at the kernel + // boundary + int flag = hipDeviceMallocUncached; + C10_CUDA_CHECK(hipExtMallocWithFlags(&ptr, kMaxCAR * 2, flag)); +#else + C10_CUDA_CHECK(cudaMalloc(&ptr, kMaxCAR * 2)); +#endif + return at::from_blob( + ptr, + {kMaxCAR}, + at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA)); +} + +static DEVICE_INLINE void st_flag_release(int32_t& flag, int32_t* flag_addr) { +#if defined(USE_ROCM) + __atomic_store_n(flag_addr, flag, __ATOMIC_RELEASE); +#elif __CUDA_ARCH__ >= 700 + asm volatile( + "st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); +#else + __threadfence_system(); + asm volatile("st.global.volatile.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); +#endif +} + +static DEVICE_INLINE void ld_flag_acquire(int32_t& flag, int32_t* flag_addr) { +#if defined(USE_ROCM) + flag = __atomic_load_n(flag_addr, __ATOMIC_ACQUIRE); +#elif __CUDA_ARCH__ >= 700 + asm volatile("ld.global.acquire.sys.b32 %0, [%1];" + : "=r"(flag) + : "l"(flag_addr)); +#else + asm volatile("ld.global.volatile.b32 %0, [%1];" + : "=r"(flag) + : "l"(flag_addr)); +#endif +} + +template +__launch_bounds__(1024) __global__ void two_shot_all_reduce( + int32_t rank, + int32_t world_size, + int32_t flag, + std::array barriers, + std::array inputs, + at::BFloat16* acc, + at::BFloat16* output, + int32_t N) { + int32_t N_per_rank = N / kWorldSize; + int32_t N_start = N_per_rank * rank; + + // Synchronize the ranks. + volatile int32_t* barrier_d = barriers[rank]; + if (threadIdx.x < kWorldSize) { + // The 1st block notifies the other ranks. + if (blockIdx.x == 0) { +#if defined(USE_ROCM) + __atomic_store_n(barriers[threadIdx.x] + rank, flag, __ATOMIC_RELEASE); +#else + barriers[threadIdx.x][rank] = flag; +#endif + } + + // Busy-wait until all ranks are ready. +#if defined(USE_ROCM) + while (__atomic_load_n(barrier_d + threadIdx.x, __ATOMIC_ACQUIRE) != flag) { + } +#else + while (barrier_d[threadIdx.x] != flag) { + } +#endif + } + + __syncthreads(); + + at::BFloat16* src_d[kWorldSize]; + int dst_rank[kWorldSize]; + +#pragma unroll kWorldSize + for (int ii = 0; ii < kWorldSize; ++ii) { + int d_rank = (rank + ii) % kWorldSize; + src_d[ii] = inputs[d_rank]; + dst_rank[ii] = d_rank; + } + + // Each block accumulates the values from the different GPUs on the same + // node. + for (size_t i = threadIdx.x * 8 + blockIdx.x * blockDim.x * 8; i < N_per_rank; + i += gridDim.x * blockDim.x * 8) { + bf16x8 vals[kWorldSize]; +#pragma unroll kWorldSize + for (int ii = 0; ii < kWorldSize; ++ii) { + *reinterpret_cast(&vals[ii]) = + reinterpret_cast(&src_d[ii][i + N_start])[0]; + } + + bf16x8 sums; + if (acc) { + *reinterpret_cast(&sums) = + *reinterpret_cast(&acc[i + N_start]); + } else { + memset(reinterpret_cast(&sums), 0, sizeof(sums)); + } + +#pragma unroll kWorldSize + for (int ii = 0; ii < kWorldSize; ++ii) { + sums = add_bf16x8(sums, vals[ii]); + } + + // Store to the local buffer. + *reinterpret_cast(&src_d[0][i + N_start]) = + *reinterpret_cast(&sums); + } + + __syncthreads(); + + // barreris among the blocks with the same idx (release-acuqire semantics) + if (threadIdx.x < kWorldSize) { + // The all blocks notifies the other ranks. 
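+    // Each rank releases its flag at this block's offset on every peer, then
+    // acquires the peers' flags, so the partial sums written above are visible
+    // to all ranks before the gather phase below.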
+ int32_t flag_block_offset = kWorldSize + blockIdx.x * kWorldSize; + st_flag_release(flag, barriers[threadIdx.x] + flag_block_offset + rank); + + // Busy-wait until all ranks are ready. + int32_t rank_barrier = 0; + int32_t* peer_barrier_d = barriers[rank] + flag_block_offset + threadIdx.x; + do { + ld_flag_acquire(rank_barrier, peer_barrier_d); + } while (rank_barrier != flag); + } + + __syncthreads(); + + // Gather all needed elts from other intra-node ranks + for (size_t i = threadIdx.x * 8 + blockIdx.x * blockDim.x * 8; i < N_per_rank; + i += gridDim.x * blockDim.x * 8) { +#pragma unroll kWorldSize + for (int ii = 0; ii < kWorldSize; ++ii) { + int i_r = N_start + i + (dst_rank[ii] - rank) * N_per_rank; + *reinterpret_cast(&output[i_r]) = + reinterpret_cast(&src_d[ii][i_r])[0]; + } + } +} + +void one_shot_car_allreduce( + at::Tensor y_allreduce, + at::Tensor y, + std::optional z, + int64_t comm_idx) { // match the API with nccl_allreduce in + // https://fburl.com/code/v538vig9 + c10::cuda::CUDAGuard gg(y_allreduce.device()); + TORCH_CHECK(y_allreduce.is_contiguous()); + TORCH_CHECK(y.is_contiguous()); + TORCH_CHECK(y.numel() == y_allreduce.numel()); + TORCH_CHECK(y.numel() % 8 == 0); + TORCH_CHECK(y.numel() < kMaxCAR); + const auto N = y.numel(); + if (z) { + TORCH_CHECK(z->numel() == y.numel()); + } + auto state = get_car_state(); + ++state->flag_; + + std::array inputs; + for (auto ii = 0; ii < state->world_size_; ++ii) { + inputs[ii] = state->buffers_[ii].data_ptr(); + } + + std::array barriers; + for (auto ii = 0; ii < state->world_size_; ++ii) { + barriers[ii] = state->barriers_[ii].data_ptr(); + } + + AT_CUDA_CHECK(cudaMemcpyAsync( + inputs[state->rank_], + y.data_ptr(), + y.numel() * y.element_size(), + cudaMemcpyDeviceToDevice, + at::cuda::getCurrentCUDAStream())); + + constexpr int32_t N_per_thread = 8; + constexpr int32_t N_per_warp = N_per_thread * kThreadsPerWarp; + TORCH_CHECK(N % N_per_warp == 0); + constexpr int32_t kThreadsPerBlock = 1024; + constexpr int32_t kMaxBlocks = 24; + + dim3 threads(0, 1, 1); + dim3 blocks(0, 1, 1); + if (N < N_per_thread * kThreadsPerBlock) { + threads.x = div_up(N, N_per_warp) * kThreadsPerWarp; + blocks.x = 1; + } else { + auto warps_required = div_up(N, N_per_warp); + blocks.x = std::min( + cuda_calc_block_count(div_up(N, N_per_thread), kThreadsPerBlock), + kMaxBlocks); + auto warps_per_block = div_up(warps_required, blocks.x); + auto threads_per_block = + std::min(kThreadsPerBlock, warps_per_block * kThreadsPerWarp); + + threads.x = threads_per_block; + } + +#define X(kWorldSize) \ + if (state->world_size_ == kWorldSize) { \ + one_shot_all_reduce \ + <<>>( \ + state->rank_, \ + state->world_size_, \ + state->flag_ * state->world_size_, \ + barriers, \ + inputs, \ + z ? 
z->data_ptr() : nullptr, \ + y_allreduce.data_ptr(), \ + N); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + return; \ + } + + TORCH_CHECK( + state->world_size_ == 2 || state->world_size_ == 4 || + state->world_size_ == 8); + X(2); + X(4); + X(8); + +#undef X + return; +} + +void two_shot_car_allreduce( + at::Tensor y_allreduce, + at::Tensor y, + std::optional z, + int64_t comm_idx) { // match the API with nccl_allreduce in + // https://fburl.com/code/v538vig9 + c10::cuda::CUDAGuard gg(y_allreduce.device()); + TORCH_CHECK(y_allreduce.is_contiguous()); + TORCH_CHECK(y.is_contiguous()); + TORCH_CHECK(y.numel() == y_allreduce.numel()); + TORCH_CHECK(y.numel() % 8 == 0); + TORCH_CHECK(y.numel() < kMaxCAR); + const auto N = y.numel(); + if (z) { + TORCH_CHECK(z->numel() == y.numel()); + } + auto state = get_car_state(); + ++state->flag_; + + std::array inputs; + for (auto ii = 0; ii < state->world_size_; ++ii) { + inputs[ii] = state->buffers_[ii].data_ptr(); + } + + std::array barriers; + for (auto ii = 0; ii < state->world_size_; ++ii) { + barriers[ii] = state->barriers_[ii].data_ptr(); + } + + AT_CUDA_CHECK(cudaMemcpyAsync( + inputs[state->rank_], + y.data_ptr(), + y.numel() * y.element_size(), + cudaMemcpyDeviceToDevice, + at::cuda::getCurrentCUDAStream())); + + constexpr int32_t N_per_thread = 8; + TORCH_CHECK(N % state->world_size_ == 0); + const auto N_per_rank = N / state->world_size_; + + TORCH_CHECK(N_per_rank % N_per_thread == 0); + auto threads_per_rank = N_per_rank / N_per_thread; + + constexpr int32_t kThreadsPerBlock = 1024; + constexpr int32_t kMaxBlocks = 24; + + auto blocks = std::min( + cuda_calc_block_count(threads_per_rank, kThreadsPerBlock), kMaxBlocks); + +#define X(kWorldSize) \ + if (state->world_size_ == kWorldSize) { \ + two_shot_all_reduce \ + <<>>( \ + state->rank_, \ + state->world_size_, \ + state->flag_ * state->world_size_, \ + barriers, \ + inputs, \ + z ? z->data_ptr() : nullptr, \ + y_allreduce.data_ptr(), \ + N); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + return; \ + } + + TORCH_CHECK( + state->world_size_ == 2 || state->world_size_ == 4 || + state->world_size_ == 8); + X(2); + X(4); + X(8); + +#undef X + return; +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/test/comm/multi_gpu_car_test.py b/fbgemm_gpu/experimental/gen_ai/test/comm/multi_gpu_car_test.py new file mode 100644 index 000000000..f02e8cade --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/test/comm/multi_gpu_car_test.py @@ -0,0 +1,313 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
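+
+# Multi-GPU tests for the NCCL wrapper ops and the custom all-reduce (CAR) ops
+# added in src/comm/car.cpp and src/comm/car.cu; each test spawns one process
+# per local GPU via torchelastic (elastic_launch).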
+ +# pyre-strict +# pyre-ignore-all-errors[56] + +import functools +import logging +import math +import os +import tempfile +import unittest +import uuid + +import fbgemm_gpu.experimental.gen_ai # noqa: F401 + +import numpy as np +import torch +from torch.distributed.launcher.api import elastic_launch, LaunchConfig + +logger: logging.Logger = logging.getLogger() +logger.setLevel(logging.INFO) + + +@functools.lru_cache +def has_nvswitch() -> bool: + import subprocess + + model = subprocess.check_output( + "cat /etc/fbwhoami | grep MODEL_NAME", shell=True + ).decode("utf-8") + return "GRANDTETON" in model or "SUPERMICRO" in model + + +def _run_allgather_inner(rdvz: str) -> None: + rank = int(os.environ["LOCAL_RANK"]) + W = int(os.environ["WORLD_SIZE"]) + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0" + torch.ops.fbgemm.nccl_init(rank, W, rdvz) + # torch.distributed.init_process_group(backend="nccl") + + B, T, D = 2, 4096, 1024 + y = torch.randn(size=(B, T, D), dtype=torch.bfloat16, device="cuda") + y[:] = rank + y_gather = torch.zeros(size=(W, B, T, D), dtype=torch.bfloat16, device="cuda") + y_gather[:] = -1 + torch.ops.fbgemm.nccl_allgather(y_gather, y) + for w in range(W): + torch.testing.assert_close( + y_gather[w], + torch.full( + size=(B, T, D), fill_value=w, dtype=torch.bfloat16, device=y.device + ), + ) + + for _ in range(20): + torch.ops.fbgemm.nccl_allgather(y_gather, y) + + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + torch.ops.fbgemm.nccl_allgather(y_gather, y) + + for _ in range(10): + g.replay() + + +def _run_allreduce_inner(path: str) -> None: + rank = int(os.environ["LOCAL_RANK"]) + W = int(os.environ["WORLD_SIZE"]) + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0" + torch.ops.fbgemm.nccl_init(rank, W, os.path.join(path, "rdvz")) + + torch.distributed.init_process_group( + backend="cpu:gloo,cuda:nccl", + init_method=f"file://{os.path.join(path, 'gloo_rdvz')}", + world_size=W, + rank=rank, + ) + + buffer = torch.ops.fbgemm.car_tensor() + barrier = torch.ops.fbgemm.car_tensor() + barrier.zero_() + + buffer_handle = torch.ops.fbgemm.car_ipc_handle(buffer) + all_buffer_handles = [torch.empty_like(buffer_handle) for _ in range(W)] + torch.distributed.all_gather(all_buffer_handles, buffer_handle) + + barrier_handle = torch.ops.fbgemm.car_ipc_handle(barrier) + all_barrier_handles = [torch.empty_like(barrier_handle) for _ in range(W)] + torch.distributed.all_gather(all_barrier_handles, barrier_handle) + + torch.ops.fbgemm.car_init( + rank, W, barrier, all_barrier_handles, buffer, all_buffer_handles + ) + torch.cuda.synchronize() + torch.distributed.barrier() + + for N in np.logspace(10, 24, num=20, base=2).tolist(): + N = int(N) + y = torch.zeros(size=(N,), dtype=torch.bfloat16, device="cuda") + y[:] = rank + y_allreduce = torch.empty_like(y) + torch.ops.fbgemm.nccl_allreduce(y_allreduce, y) + torch.testing.assert_close( + y_allreduce, + torch.full( + size=(N,), + fill_value=(W * (W - 1) // 2), + dtype=torch.bfloat16, + device=y.device, + ), + ) + + z = torch.ones(size=(N,), dtype=torch.bfloat16, device="cuda") + torch.ops.fbgemm.nccl_allreduce(y_allreduce, y, z) + torch.testing.assert_close( + y_allreduce, + torch.full( + size=(N,), + fill_value=(W * (W - 1) // 2), + dtype=torch.bfloat16, + device=y.device, + ) + + 1, + ) + + def round_up(a: int, b: int) -> int: + return int(math.ceil(a / b)) * b + + N = round_up(N, 256) + y = 
torch.zeros(size=(N,), dtype=torch.bfloat16, device="cuda") + y[:] = rank + y_allreduce = torch.empty_like(y) + torch.ops.fbgemm.one_shot_car_allreduce(y_allreduce, y) + torch.testing.assert_close( + y_allreduce, + torch.full( + size=(N,), + fill_value=(W * (W - 1) // 2), + dtype=torch.bfloat16, + device=y.device, + ), + ) + z = torch.ones(size=(N,), dtype=torch.bfloat16, device="cuda") + torch.ops.fbgemm.one_shot_car_allreduce(y_allreduce, y, z) + torch.testing.assert_close( + y_allreduce, + torch.full( + size=(N,), + fill_value=(W * (W - 1) // 2), + dtype=torch.bfloat16, + device=y.device, + ) + + 1, + ) + if has_nvswitch() or (not has_nvswitch() and N < 16 * 1024): + N = round_up(N, 1024) + y = torch.zeros(size=(N,), dtype=torch.bfloat16, device="cuda") + y[:] = rank + y_allreduce = torch.empty_like(y) + torch.ops.fbgemm.two_shot_car_allreduce(y_allreduce, y) + torch.testing.assert_close( + y_allreduce, + torch.full( + size=(N,), + fill_value=(W * (W - 1) // 2), + dtype=torch.bfloat16, + device=y.device, + ), + ) + z = torch.ones(size=(N,), dtype=torch.bfloat16, device="cuda") + torch.ops.fbgemm.two_shot_car_allreduce(y_allreduce, y, z) + torch.testing.assert_close( + y_allreduce, + torch.full( + size=(N,), + fill_value=(W * (W - 1) // 2), + dtype=torch.bfloat16, + device=y.device, + ) + + 1, + ) + + +def _run_oneshot_car_stress_inner(path: str) -> None: + rank = int(os.environ["LOCAL_RANK"]) + W = int(os.environ["WORLD_SIZE"]) + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0" + torch.ops.fbgemm.nccl_init(rank, W, os.path.join(path, "rdvz")) + + torch.distributed.init_process_group( + backend="cpu:gloo,cuda:nccl", + init_method=f"file://{os.path.join(path, 'gloo_rdvz')}", + world_size=W, + rank=rank, + ) + + buffer = torch.ops.fbgemm.car_tensor() + barrier = torch.ops.fbgemm.car_tensor() + barrier.zero_() + + buffer_handle = torch.ops.fbgemm.car_ipc_handle(buffer) + all_buffer_handles = [torch.empty_like(buffer_handle) for _ in range(W)] + torch.distributed.all_gather(all_buffer_handles, buffer_handle) + + barrier_handle = torch.ops.fbgemm.car_ipc_handle(barrier) + all_barrier_handles = [torch.empty_like(barrier_handle) for _ in range(W)] + torch.distributed.all_gather(all_barrier_handles, barrier_handle) + + torch.ops.fbgemm.car_init( + rank, W, barrier, all_barrier_handles, buffer, all_buffer_handles + ) + torch.cuda.synchronize() + torch.distributed.barrier() + + ITER = 1000 + for idx, N in enumerate(np.logspace(4, 24, num=20, base=2).tolist()): + N = int(N) + + def round_up(a: int, b: int) -> int: + return int(math.ceil(a / b)) * b + + N = round_up(N, 256) + if rank == 0: + print(f"N: {N}") + for iterId in range(ITER): + y = torch.zeros(size=(N,), dtype=torch.bfloat16, device="cuda") + y[:] = rank + idx + iterId + y_allreduce = torch.empty_like(y) + torch.ops.fbgemm.one_shot_car_allreduce(y_allreduce, y) + torch.testing.assert_close( + y_allreduce, + torch.full( + size=(N,), + fill_value=(W * (W - 1) // 2), + dtype=torch.bfloat16, + device=y.device, + ) + + (idx + iterId) * W, + ) + + +@unittest.skipIf( + not torch.cuda.is_available(), + "Skip when CUDA is not available", +) +class LLamaMultiGpuTests(unittest.TestCase): + @unittest.skipIf( + torch.cuda.device_count() < 2, + "Not enough GPUs, this test requires at least two GPUs", + ) + def test_allgather(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir, tempfile.TemporaryDirectory() as path: + lc = LaunchConfig( + min_nodes=1, + max_nodes=1, + 
nproc_per_node=torch.cuda.device_count(), + run_id=str(uuid.uuid4()), + rdzv_backend="c10d", + rdzv_endpoint=os.path.join(tmpdir, "rdzv"), + rdzv_configs={"store_type": "file"}, + start_method="spawn", + monitor_interval=1, + max_restarts=0, + ) + elastic_launch(config=lc, entrypoint=_run_allgather_inner)( + os.path.join(path, "rdvz") + ) + + @unittest.skipIf( + torch.cuda.device_count() < 2, + "Not enough GPUs, this test requires at least two GPUs", + ) + def test_allreduce(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir, tempfile.TemporaryDirectory() as path: + lc = LaunchConfig( + min_nodes=1, + max_nodes=1, + nproc_per_node=torch.cuda.device_count(), + run_id=str(uuid.uuid4()), + rdzv_backend="c10d", + rdzv_endpoint=os.path.join(tmpdir, "rdzv"), + rdzv_configs={"store_type": "file"}, + start_method="spawn", + monitor_interval=1, + max_restarts=0, + ) + elastic_launch(config=lc, entrypoint=_run_allreduce_inner)(path) + + def test_oneshot_car_stress(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir, tempfile.TemporaryDirectory() as path: + lc = LaunchConfig( + min_nodes=1, + max_nodes=1, + nproc_per_node=torch.cuda.device_count(), + run_id=str(uuid.uuid4()), + rdzv_backend="c10d", + rdzv_endpoint=os.path.join(tmpdir, "rdzv"), + rdzv_configs={"store_type": "file"}, + start_method="spawn", + monitor_interval=1, + max_restarts=0, + ) + elastic_launch(config=lc, entrypoint=_run_oneshot_car_stress_inner)(path)
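
Note on usage: the tests above drive communicator setup only through the file-based `nccl_init` rendezvous. The diff also registers a tensor-based path (`nccl_get_unique_id` / `nccl_comm_init_rank`). Below is a minimal sketch, not part of the diff, of how that path could be wired up, assuming `torch.distributed` is already initialized so the id can be broadcast; the helper name `init_fbgemm_nccl_comm` is illustrative.

    import fbgemm_gpu.experimental.gen_ai  # noqa: F401  # registers torch.ops.fbgemm.*
    import torch
    import torch.distributed as dist


    def init_fbgemm_nccl_comm(rank: int, world_size: int, comm_idx: int = 0) -> None:
        # Rank 0 generates the 128-byte NCCL unique id as a CPU int8 tensor;
        # the other ranks receive it over an existing (e.g. gloo) process group.
        if rank == 0:
            nccl_id = torch.ops.fbgemm.nccl_get_unique_id()
        else:
            nccl_id = torch.empty(128, dtype=torch.int8)
        dist.broadcast(nccl_id, src=0)
        # Registered schema: nccl_comm_init_rank(int world_size, int rank, Tensor id_, int comm_idx=0)
        torch.ops.fbgemm.nccl_comm_init_rank(world_size, rank, nccl_id, comm_idx)

After this call, the `nccl_allgather` / `nccl_reducescatter` / `nccl_allreduce` ops exercised in the tests can be used with the same `comm_idx`.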