Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
q10 committed May 22, 2024
1 parent adca417 commit 19cfbc6
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 25 deletions.
35 changes: 26 additions & 9 deletions .github/scripts/fbgemm_gpu_build.bash
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,22 @@ __configure_fbgemm_gpu_build_nvcc () {
# shellcheck disable=SC2155
local env_prefix=$(env_name_or_prefix "${env_name}")

echo "[BUILD] Looking up CUDA version ..."
# shellcheck disable=SC2155,SC2086
local cxx_path=$(conda run ${env_prefix} which c++)
# shellcheck disable=SC2155,SC2086
local cuda_version=$(conda run ${env_prefix} nvcc --version | sed -n 's/^.*release \([0-9]\+\.[0-9]\+\).*$/\1/p')
# shellcheck disable=SC2206
local cuda_version_arr=(${cuda_version//./ })

echo "[BUILD] Looking up NCCL path ..."
# shellcheck disable=SC2155,SC2086
local conda_prefix=$(conda run ${env_prefix} printenv CONDA_PREFIX)
# shellcheck disable=SC2155,SC2086
local nccl_lib=$(conda run ${env_prefix} find ${conda_prefix} -name "libnccl.so*")
# shellcheck disable=SC2155,SC2086
local nccl_path=$(dirname "$(dirname ${nccl_lib})")

# Only NVCC 12+ supports C++20
if [[ ${cuda_version_arr[0]} -lt 12 ]]; then
local cppstd_ver=17
Expand All @@ -109,11 +118,16 @@ __configure_fbgemm_gpu_build_nvcc () {
# shellcheck disable=SC2086
print_exec conda env config vars set ${env_prefix} NVCC_PREPEND_FLAGS=\"${nvcc_prepend_flags}\"

echo "[BUILD] Setting CUDA build args ..."
# shellcheck disable=SC2155
local cxx_flags="-DNCCL_INCLUDE_DIR=${nccl_path}/include -DNCCL_LIB_DIR=${nccl_path}/lib"

# shellcheck disable=SC2206
build_args+=(
# Override CMake configuration
-DCMAKE_CXX_STANDARD="${cppstd_ver}"
-DHIP_STANDARD="${cppstd_ver}"
-DCMAKE_C_FLAGS="'${cxx_flags}'"
-DCMAKE_CXX_FLAGS="'${cxx_flags}'"
)
}

Expand Down Expand Up @@ -158,14 +172,17 @@ __configure_fbgemm_gpu_build_rocm () {
print_exec conda env config vars set ${env_prefix} PYTORCH_ROCM_ARCH="${arch_list}"

echo "[BUILD] Setting ROCm build args ..."
# shellcheck disable=SC2155
local cxx_flags="-DTORCH_USE_HIP_DSA"

build_args=(
--package_variant=rocm
# HIP_ROOT_DIR now required for HIP to be correctly detected by CMake
-DHIP_ROOT_DIR=/opt/rocm
# Enable device-side assertions in HIP
# https://stackoverflow.com/questions/44284275/passing-compiler-options-in-cmake-command-line
-DCMAKE_C_FLAGS="-DTORCH_USE_HIP_DSA"
-DCMAKE_CXX_FLAGS="-DTORCH_USE_HIP_DSA"
-DCMAKE_C_FLAGS="'${cxx_flags}'"
-DCMAKE_CXX_FLAGS="'${cxx_flags}'"
)
}

Expand Down Expand Up @@ -473,18 +490,18 @@ build_fbgemm_gpu_package () {
# shellcheck disable=SC2086
print_exec conda run --no-capture-output ${env_prefix} \
python -m build --wheel --no-isolation \
"${build_args[@]}"
"${build_args[@]}" || return 1

# Run checks on the built libraries
(run_fbgemm_gpu_postbuild_checks "${fbgemm_variant}") || return 1

echo "[BUILD] Enumerating the built wheels ..."
print_exec ls -lth dist/*.whl
print_exec ls -lth dist/*.whl || return 1

echo "[BUILD] Enumerating the wheel SHAs ..."
print_exec sha1sum dist/*.whl
print_exec sha256sum dist/*.whl
print_exec md5sum dist/*.whl
print_exec sha1sum dist/*.whl || return 1
print_exec sha256sum dist/*.whl || return 1
print_exec md5sum dist/*.whl || return 1

echo "[BUILD] FBGEMM-GPU build + package completed"
}
Expand Down Expand Up @@ -524,7 +541,7 @@ build_fbgemm_gpu_install () {
# shellcheck disable=SC2086
print_exec conda run --no-capture-output ${env_prefix} \
python setup.py "${run_multicore}" install \
"${build_args[@]}"
"${build_args[@]}" || return 1

# Run checks on the built libraries
(run_fbgemm_gpu_postbuild_checks "${fbgemm_variant}") || return 1
Expand Down
13 changes: 9 additions & 4 deletions fbgemm_gpu/experimental/example/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,14 @@ if(FBGEMM_GENAI_ONLY)
${CMAKE_CURRENT_SOURCE_DIR}../..
${CMAKE_CURRENT_SOURCE_DIR}../../include
${CMAKE_CURRENT_SOURCE_DIR}../../../include
# PyTorch
${TORCH_INCLUDE_DIRS}
# Third-party
${THIRDPARTY}/asmjit/src
${THIRDPARTY}/cpuinfo/include
${THIRDPARTY}/cutlass/include
${THIRDPARTY}/cutlass/tools/util/include)
${THIRDPARTY}/cutlass/tools/util/include
${NCCL_INCLUDE_DIR})

set(third_party_include_directories
${THIRDPARTY}/asmjit/src
Expand All @@ -32,7 +35,8 @@ endif()

set(experimental_example_cpp_source_files
src/cutlass_sgemm_nn.cu
src/example_ops.cpp)
src/example_ops.cpp
src/nccl_example.cpp)

set_source_files_properties(${experimental_example_cpp_source_files}
PROPERTIES INCLUDE_DIRECTORIES
Expand All @@ -50,8 +54,9 @@ set(experimental_example_python_source_files
add_library(fbgemm_gpu_experimental_example_py MODULE
${experimental_example_cpp_source_files})

target_include_directories(fbgemm_gpu_experimental_example_py PRIVATE ${TORCH_INCLUDE_DIRS})
target_link_libraries(fbgemm_gpu_experimental_example_py ${TORCH_LIBRARIES})
target_link_libraries(fbgemm_gpu_experimental_example_py
${TORCH_LIBRARIES}
${NCCL_LIB_DIR})

# Remove `lib` from the output artifact name
set_target_properties(fbgemm_gpu_experimental_example_py PROPERTIES PREFIX "")
Expand Down
23 changes: 23 additions & 0 deletions fbgemm_gpu/experimental/example/src/nccl_example.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <stdexcept>
#include <string>

#include <nccl.h>

namespace fbgemm_gpu::experimental {

void example_nccl_code() {
ncclComm_t comms[4];
int devs[4] = { 0, 1, 2, 3 };
ncclCommInitAll(comms, 4, devs);

for (int i=0; i<4; i++) {
ncclCommDestroy(comms[i]);
}
}

} // namespace fbgemm_gpu::experimental
71 changes: 71 additions & 0 deletions fbgemm_gpu/experimental/example/test/triton_example_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest

import torch
import triton
import triton.language as tl


@triton.jit
def softmax_triton(Y, stride_ym, stride_yn, X, stride_xm, stride_xn, M, N):
    # Row-wise softmax: each program instance reads one row of X and writes the
    # corresponding row of Y. M is accepted but not used by the kernel body.

    # Width of the column tile. Only columns with index < BLOCK_SIZE are ever
    # touched, so the kernel is only correct for matrices with at most
    # BLOCK_SIZE columns.
    BLOCK_SIZE = 1024

    # One program per row; every program covers the full column tile.
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)

    # Lanes past the real row width are masked out of both the load and the store.
    in_bounds = cols < N

    # Addresses of this row's input elements.
    x_ptrs = X + row * stride_xm + cols * stride_xn

    # Masked-off lanes read as -inf, so they can never be the row maximum and
    # contribute exp(-inf) = 0 to the denominator.
    x = tl.load(x_ptrs, mask=in_bounds, other=-float('inf'))

    # Numerically stable softmax: shift by the row maximum before exponentiating.
    shifted = x - tl.max(x, axis=0)
    exp_shifted = tl.exp(shifted)
    total = tl.sum(exp_shifted, axis=0)
    y = exp_shifted / total

    # Addresses of this row's output elements; write back the valid lanes only.
    y_ptrs = Y + row * stride_ym + cols * stride_yn
    tl.store(y_ptrs, y, mask=in_bounds)


@torch.jit.script
def softmax_torch(x: torch.Tensor) -> torch.Tensor:
    """Reference row-wise softmax over dim 1 of ``x``.

    The per-row maximum is subtracted before exponentiating so that large
    inputs do not overflow ``exp`` (which would turn the result into NaN).
    The shift cancels between numerator and denominator, so the result is
    mathematically the plain softmax.

    Args:
        x: Input tensor; reduced along dim 1 and indexed as 2-D below.

    Returns:
        Tensor of the same shape as ``x`` whose rows each sum to 1.
    """
    # max(dim=1) returns (values, indices); only the values are needed.
    x_max = x.max(dim=1)[0]
    numerator = torch.exp(x - x_max[:, None])
    denominator = numerator.sum(dim=1)
    return numerator / denominator[:, None]


@unittest.skipIf(
    not torch.cuda.is_available(),
    "Requires CUDA to run",
)
class TestTriton(unittest.TestCase):
    """Checks the Triton softmax kernel against the TorchScript reference."""

    def test_triton_example(self) -> None:
        # Random input on the GPU plus an uninitialized output of the same
        # shape, dtype and device.
        inputs = torch.normal(0, 1, size=(583, 931), device='cuda')
        outputs = torch.empty_like(inputs)

        num_rows, num_cols = inputs.shape[0], inputs.shape[1]

        # One program instance per input row.
        launch_grid = (num_rows,)

        # Launch the kernel; arguments are (output, its strides, input, its
        # strides, row count, column count).
        softmax_triton[launch_grid](
            outputs, outputs.stride(0), outputs.stride(1),
            inputs, inputs.stride(0), inputs.stride(1),
            num_rows, num_cols,
        )

        # Compare on the host against the reference implementation.
        actual = outputs.cpu()
        expected = softmax_torch(inputs).cpu()
        torch.testing.assert_close(actual, expected)
20 changes: 8 additions & 12 deletions fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,14 @@ if(FBGEMM_GENAI_ONLY)
${CMAKE_CURRENT_SOURCE_DIR}/../..
${CMAKE_CURRENT_SOURCE_DIR}/../../include
${CMAKE_CURRENT_SOURCE_DIR}/../../../include
# PyTorch
${TORCH_INCLUDE_DIRS}
# Third-party
${THIRDPARTY}/asmjit/src
${THIRDPARTY}/cpuinfo/include
${THIRDPARTY}/cutlass/include
${THIRDPARTY}/cutlass/tools/util/include)

set(third_party_include_directories
${THIRDPARTY}/asmjit/src
${THIRDPARTY}/cpuinfo/include
${THIRDPARTY}/cutlass/include)
${THIRDPARTY}/cutlass/tools/util/include
${NCCL_INCLUDE_DIR})
endif()

set(attention_ops_sources
Expand Down Expand Up @@ -93,18 +91,16 @@ if(USE_ROCM)
${ROCRAND_INCLUDE}
${ROCM_SMI_INCLUDE})

list(GET TORCH_INCLUDE_DIRS 0 TORCH_PATH)

else()
# Else create a CUDA library
add_library(fbgemm_gpu_experimental_gen_ai_py MODULE
${experimental_gen_ai_cpp_source_files})
endif()

# Link to PyTorch
target_include_directories(fbgemm_gpu_experimental_gen_ai_py
PRIVATE ${TORCH_INCLUDE_DIRS})
target_link_libraries(fbgemm_gpu_experimental_gen_ai_py ${TORCH_LIBRARIES})
# Link to PyTorch and NCCL
target_link_libraries(fbgemm_gpu_experimental_gen_ai_py
${TORCH_LIBRARIES}
${NCCL_LIB_DIR})

# Remove `lib` from the output artifact name
set_target_properties(fbgemm_gpu_experimental_gen_ai_py PROPERTIES PREFIX "")
Expand Down

0 comments on commit 19cfbc6

Please sign in to comment.