Add NCCL integration (#2624)
Summary:
- Add an example Triton test to exercise the Triton integration
- Add NCCL integration


Differential Revision: D57741140

Pulled By: q10
q10 authored and facebook-github-bot committed May 23, 2024
1 parent 7930859 commit d29c203
Showing 9 changed files with 165 additions and 30 deletions.
32 changes: 23 additions & 9 deletions .github/scripts/fbgemm_gpu_build.bash
@@ -77,13 +77,22 @@ __configure_fbgemm_gpu_build_nvcc () {
# shellcheck disable=SC2155
local env_prefix=$(env_name_or_prefix "${env_name}")

echo "[BUILD] Looking up CUDA version ..."
# shellcheck disable=SC2155,SC2086
local cxx_path=$(conda run ${env_prefix} which c++)
# shellcheck disable=SC2155,SC2086
local cuda_version=$(conda run ${env_prefix} nvcc --version | sed -n 's/^.*release \([0-9]\+\.[0-9]\+\).*$/\1/p')
# shellcheck disable=SC2206
local cuda_version_arr=(${cuda_version//./ })

echo "[BUILD] Looking up NCCL path ..."
# shellcheck disable=SC2155,SC2086
local conda_prefix=$(conda run ${env_prefix} printenv CONDA_PREFIX)
# shellcheck disable=SC2155,SC2086
local nccl_lib=$(conda run ${env_prefix} find ${conda_prefix} -name "libnccl.so*")
# shellcheck disable=SC2155,SC2086
local nccl_path=$(dirname "$(dirname ${nccl_lib})")

# Only NVCC 12+ supports C++20
if [[ ${cuda_version_arr[0]} -lt 12 ]]; then
local cppstd_ver=17
@@ -109,11 +118,13 @@ __configure_fbgemm_gpu_build_nvcc () {
# shellcheck disable=SC2086
print_exec conda env config vars set ${env_prefix} NVCC_PREPEND_FLAGS=\"${nvcc_prepend_flags}\"

echo "[BUILD] Setting CUDA build args ..."
# shellcheck disable=SC2206
build_args+=(
# Override CMake configuration
-DCMAKE_CXX_STANDARD="${cppstd_ver}"
-DHIP_STANDARD="${cppstd_ver}"
-DNCCL_INCLUDE_DIR=${nccl_path}/include
-DNCCL_LIB_DIR=${nccl_path}/lib
)
}

@@ -158,14 +169,17 @@ __configure_fbgemm_gpu_build_rocm () {
print_exec conda env config vars set ${env_prefix} PYTORCH_ROCM_ARCH="${arch_list}"

echo "[BUILD] Setting ROCm build args ..."
# shellcheck disable=SC2155
local cxx_flags="-DTORCH_USE_HIP_DSA"

build_args=(
--package_variant=rocm
# HIP_ROOT_DIR now required for HIP to be correctly detected by CMake
-DHIP_ROOT_DIR=/opt/rocm
# Enable device-side assertions in HIP
# https://stackoverflow.com/questions/44284275/passing-compiler-options-in-cmake-command-line
-DCMAKE_C_FLAGS="-DTORCH_USE_HIP_DSA"
-DCMAKE_CXX_FLAGS="-DTORCH_USE_HIP_DSA"
-DCMAKE_C_FLAGS="'${cxx_flags}'"
-DCMAKE_CXX_FLAGS="'${cxx_flags}'"
)
}

@@ -473,18 +487,18 @@ build_fbgemm_gpu_package () {
# shellcheck disable=SC2086
print_exec conda run --no-capture-output ${env_prefix} \
python -m build --wheel --no-isolation \
"${build_args[@]}"
"${build_args[@]}" || return 1

# Run checks on the built libraries
(run_fbgemm_gpu_postbuild_checks "${fbgemm_variant}") || return 1

echo "[BUILD] Enumerating the built wheels ..."
print_exec ls -lth dist/*.whl
print_exec ls -lth dist/*.whl || return 1

echo "[BUILD] Enumerating the wheel SHAs ..."
print_exec sha1sum dist/*.whl
print_exec sha256sum dist/*.whl
print_exec md5sum dist/*.whl
print_exec sha1sum dist/*.whl || return 1
print_exec sha256sum dist/*.whl || return 1
print_exec md5sum dist/*.whl || return 1

echo "[BUILD] FBGEMM-GPU build + package completed"
}
@@ -524,7 +538,7 @@ build_fbgemm_gpu_install () {
# shellcheck disable=SC2086
print_exec conda run --no-capture-output ${env_prefix} \
python setup.py "${run_multicore}" install \
"${build_args[@]}"
"${build_args[@]}" || return 1

# Run checks on the built libraries
(run_fbgemm_gpu_postbuild_checks "${fbgemm_variant}") || return 1
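
For reference, the NCCL lookup added to the build script above can be expressed as a small standalone sketch. This is not part of the commit; it assumes an active conda environment that exposes CONDA_PREFIX and ships libnccl, and it simply derives the include/ and lib/ paths that are handed to CMake as -DNCCL_INCLUDE_DIR and -DNCCL_LIB_DIR.

import glob
import os

# Assumes the conda environment is active so that CONDA_PREFIX is set.
conda_prefix = os.environ["CONDA_PREFIX"]

# Locate libnccl.so*, mirroring the `find` call in the build script.
candidates = sorted(
    glob.glob(os.path.join(conda_prefix, "**", "libnccl.so*"), recursive=True)
)

if candidates:
    # <prefix>/lib/libnccl.so.2  ->  <prefix>
    nccl_path = os.path.dirname(os.path.dirname(candidates[0]))
    print(f"-DNCCL_INCLUDE_DIR={nccl_path}/include")
    print(f"-DNCCL_LIB_DIR={nccl_path}/lib")
else:
    print(f"libnccl.so not found under {conda_prefix}")
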
7 changes: 7 additions & 0 deletions cmake/modules/CudaSetup.cmake
@@ -11,6 +11,13 @@ include(${CMAKE_CURRENT_SOURCE_DIR}/../cmake/modules/Utilities.cmake)
# CUDA Setup
################################################################################

BLOCK_PRINT(
"NCCL flags"
""
"NCCL_INCLUDE_DIR=${NCCL_INCLUDE_DIR}"
"NCCL_LIB_DIR=${NCCL_LIB_DIR}"
)

# Set NVML_LIB_PATH if provided, or detect the default lib path
if(NOT NVML_LIB_PATH)
set(DEFAULT_NVML_LIB_PATH
4 changes: 3 additions & 1 deletion fbgemm_gpu/CMakeLists.txt
@@ -95,10 +95,12 @@ endif()
################################################################################

if(NOT FBGEMM_CPU_ONLY)
add_subdirectory(experimental/example)
add_subdirectory(experimental/gemm)

if(NOT USE_ROCM)
# TODO: Figure out NCCL/RCCL integration with ROCm
add_subdirectory(experimental/example)

# CUTLASS currently doesn't build on ROCm and CK hasn't yet been added:
#
# 2024-05-06T23:09:35.5730483Z /__w/FBGEMM/FBGEMM/fbgemm_gpu/../third_party/cutlass/include/cutlass/half.h:73:10: fatal error: 'cuda_fp16.h' file not found
13 changes: 10 additions & 3 deletions fbgemm_gpu/FbgemmGpu.cmake
@@ -24,11 +24,14 @@ set(fbgemm_sources_include_directories
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/include
${CMAKE_CURRENT_SOURCE_DIR}/../include
# PyTorch
${TORCH_INCLUDE_DIRS}
# Third-party
${THIRDPARTY}/asmjit/src
${THIRDPARTY}/cpuinfo/include
${THIRDPARTY}/cutlass/include
${THIRDPARTY}/cutlass/tools/util/include)
${THIRDPARTY}/cutlass/tools/util/include
${NCCL_INCLUDE_DIR})


################################################################################
@@ -624,13 +627,17 @@ else()
endif()

# Add PyTorch include/
target_include_directories(fbgemm_gpu_py PRIVATE ${TORCH_INCLUDE_DIRS})
target_include_directories(fbgemm_gpu_py PRIVATE
${TORCH_INCLUDE_DIRS}
${NCCL_INCLUDE_DIR})

# Remove `lib` from the output artifact name `libfbgemm_gpu_py.so`
set_target_properties(fbgemm_gpu_py PROPERTIES PREFIX "")

# Link to PyTorch
target_link_libraries(fbgemm_gpu_py ${TORCH_LIBRARIES})
target_link_libraries(fbgemm_gpu_py
${TORCH_LIBRARIES}
${NCCL_LIB_DIR})

# Link to NVML
if(NVML_LIB_PATH)
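
After building and installing the wheel, one way to spot-check the new NCCL include and link settings is to inspect the built extension modules for a libnccl dependency. A hedged sketch, not part of the commit: the fbgemm_gpu/**/*.so glob is an assumption about the installed package layout, and ldd is assumed to be available (Linux only).

import glob
import subprocess
import sysconfig

# Look for built extension modules inside the installed package (layout is an assumption).
site_packages = sysconfig.get_paths()["purelib"]
for so in sorted(glob.glob(f"{site_packages}/fbgemm_gpu/**/*.so", recursive=True)):
    # ldd lists the shared-library dependencies recorded in the ELF file.
    ldd_out = subprocess.run(["ldd", so], capture_output=True, text=True).stdout
    status = "links libnccl" if "libnccl" in ldd_out else "no libnccl dependency"
    print(f"{so}: {status}")
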
17 changes: 13 additions & 4 deletions fbgemm_gpu/experimental/example/CMakeLists.txt
@@ -18,11 +18,14 @@ if(FBGEMM_GENAI_ONLY)
${CMAKE_CURRENT_SOURCE_DIR}../..
${CMAKE_CURRENT_SOURCE_DIR}../../include
${CMAKE_CURRENT_SOURCE_DIR}../../../include
# PyTorch
${TORCH_INCLUDE_DIRS}
# Third-party
${THIRDPARTY}/asmjit/src
${THIRDPARTY}/cpuinfo/include
${THIRDPARTY}/cutlass/include
${THIRDPARTY}/cutlass/tools/util/include)
${THIRDPARTY}/cutlass/tools/util/include
${NCCL_INCLUDE_DIR})

set(third_party_include_directories
${THIRDPARTY}/asmjit/src
@@ -32,7 +35,8 @@ endif()

set(experimental_example_cpp_source_files
src/cutlass_sgemm_nn.cu
src/example_ops.cpp)
src/example_ops.cpp
src/nccl_example.cpp)

set_source_files_properties(${experimental_example_cpp_source_files}
PROPERTIES INCLUDE_DIRECTORIES
@@ -50,8 +54,13 @@ set(experimental_example_python_source_files
add_library(fbgemm_gpu_experimental_example_py MODULE
${experimental_example_cpp_source_files})

target_include_directories(fbgemm_gpu_experimental_example_py PRIVATE ${TORCH_INCLUDE_DIRS})
target_link_libraries(fbgemm_gpu_experimental_example_py ${TORCH_LIBRARIES})
target_include_directories(fbgemm_gpu_experimental_example_py PRIVATE
${TORCH_INCLUDE_DIRS}
${NCCL_INCLUDE_DIR})

target_link_libraries(fbgemm_gpu_experimental_example_py
${TORCH_LIBRARIES}
${NCCL_LIB_DIR})

# Remove `lib` from the output artifact name
set_target_properties(fbgemm_gpu_experimental_example_py PROPERTIES PREFIX "")
23 changes: 23 additions & 0 deletions fbgemm_gpu/experimental/example/src/nccl_example.cpp
@@ -0,0 +1,23 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <nccl.h>

namespace fbgemm_gpu::experimental {

void example_nccl_code() {
ncclComm_t comms[4];
int devs[4] = {0, 1, 2, 3};
ncclCommInitAll(comms, 4, devs);

for (int i = 0; i < 4; i++) {
ncclCommDestroy(comms[i]);
}
}

} // namespace fbgemm_gpu::experimental
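
The example above only initializes four single-process communicators and tears them down again. Before exercising it on a multi-GPU host, a quick Python-side sanity check of NCCL availability through PyTorch can be useful; a minimal sketch, not part of the commit, assuming a CUDA build of PyTorch.

import torch

# Both checks guard against CPU-only or NCCL-less PyTorch builds.
if torch.cuda.is_available() and torch.distributed.is_nccl_available():
    print("NCCL version reported by PyTorch:", torch.cuda.nccl.version())
else:
    print("CUDA/NCCL not available in this PyTorch build")
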
75 changes: 75 additions & 0 deletions fbgemm_gpu/experimental/example/test/triton_example_test.py
@@ -0,0 +1,75 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest

import torch
import triton
import triton.language as tl


@triton.jit
def triton_add_kernel(x_ptr, y_ptr, z_ptr, n_elements, BLOCK_SIZE: tl.constexpr):

# We use a 1D launch grid so axis is 0.
pid = tl.program_id(axis=0)

# Compute the offsets in BLOCK_SIZE chunks.
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

# Create a mask to guard memory operations against out-of-bounds accesses.
mask = offsets < n_elements

# Load x and y from DRAM.
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)

# Sum and write back to DRAM.
output = x + y
tl.store(z_ptr + offsets, output, mask=mask)


def triton_add(x: torch.Tensor, y: torch.Tensor):
# Pre-allocate the output.
output = torch.empty_like(x)
assert x.is_cuda and y.is_cuda and output.is_cuda

# Create the SPMD launch grid. It can be either Tuple[int], or
# Callable(metaparameters) -> Tuple[int]. In this case, we use a 1D grid
# where the size is the number of blocks:
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) # noqa: E731


# Launch the kernel.
Check failure on line 49 (GitHub Actions / run-lint (3.11)): E303 too many blank lines (2)
#
# Each torch.tensor object is implicitly converted into a pointer to its
# first element.
#
# `triton.jit`'ed functions can be indexed with a launch grid to obtain a
# callable GPU kernel.
#
# Pass meta-parameters as keyword arguments.
triton_add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)

# We return a handle to z but, since `torch.cuda.synchronize()` hasn't been
# called, the kernel is still running asynchronously at this point.
return output


@unittest.skipIf(
not torch.cuda.is_available(),
"Requires CUDA to run",
)
class TestTriton(unittest.TestCase):
def test_triton_example(self) -> None:
size = 98432
X = torch.rand(size, device="cuda")
Y = torch.rand(size, device="cuda")

torch.testing.assert_close(triton_add(X, Y).cpu(), (X + Y).cpu())
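
If the kernel is worth timing as well as testing, Triton ships a small benchmarking helper. A minimal sketch, not part of the commit, that reuses triton_add from the test above and assumes a CUDA device is present.

import torch
import triton

size = 98432
x = torch.rand(size, device="cuda")
y = torch.rand(size, device="cuda")

# do_bench reports the kernel runtime in milliseconds over several warmed-up launches.
ms = triton.testing.do_bench(lambda: triton_add(x, y))
print(f"triton_add over {size} elements: {ms:.4f} ms")
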
23 changes: 11 additions & 12 deletions fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
@@ -16,16 +16,14 @@ if(FBGEMM_GENAI_ONLY)
${CMAKE_CURRENT_SOURCE_DIR}/../..
${CMAKE_CURRENT_SOURCE_DIR}/../../include
${CMAKE_CURRENT_SOURCE_DIR}/../../../include
# PyTorch
${TORCH_INCLUDE_DIRS}
# Third-party
${THIRDPARTY}/asmjit/src
${THIRDPARTY}/cpuinfo/include
${THIRDPARTY}/cutlass/include
${THIRDPARTY}/cutlass/tools/util/include)

set(third_party_include_directories
${THIRDPARTY}/asmjit/src
${THIRDPARTY}/cpuinfo/include
${THIRDPARTY}/cutlass/include)
${THIRDPARTY}/cutlass/tools/util/include
${NCCL_INCLUDE_DIR})
endif()

set(attention_ops_sources
@@ -98,18 +96,19 @@ if(USE_ROCM)
${ROCRAND_INCLUDE}
${ROCM_SMI_INCLUDE})

list(GET TORCH_INCLUDE_DIRS 0 TORCH_PATH)

else()
# Else create a CUDA library
add_library(fbgemm_gpu_experimental_gen_ai_py MODULE
${experimental_gen_ai_cpp_source_files})
endif()

# Link to PyTorch
target_include_directories(fbgemm_gpu_experimental_gen_ai_py
PRIVATE ${TORCH_INCLUDE_DIRS})
target_link_libraries(fbgemm_gpu_experimental_gen_ai_py ${TORCH_LIBRARIES})
target_include_directories(fbgemm_gpu_experimental_gen_ai_py PRIVATE
${TORCH_INCLUDE_DIRS}
${NCCL_INCLUDE_DIR})

target_link_libraries(fbgemm_gpu_experimental_gen_ai_py
${TORCH_LIBRARIES}
${NCCL_LIB_DIR})

# Remove `lib` from the output artifact name
set_target_properties(fbgemm_gpu_experimental_gen_ai_py PROPERTIES PREFIX "")
(ninth changed file; file name not shown in this view)
@@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.

# pyre-strict

# pyre-ignore-all-errors[56]

import unittest
