
symm-mem-recipes

This repository includes:

  • Usage and benchmarks of SymmetricMemory-based multi-GPU algorithms in PyTorch.
  • Examples and benchmarks of multi-GPU algorithms built with SymmetricMemory + Triton.

symm_mem_all_reduce.py

This script demonstrates SymmetricMemory-based NVLink all-reduce implementations and benchmarks their performance. The available variants are:

  • multimem_all_reduce (PyTorch op available in nightly)
  • one_shot_all_reduce (PyTorch op available in nightly)
  • two_shot_all_reduce (PyTorch op available in nightly)
  • triton_multimem_all_reduce (Triton kernel defined in this repo)
  • triton_one_shot_all_reduce (Triton kernel defined in this repo)

Usage:

torchrun \
--nnodes 1 --nproc-per-node 8 \
--rdzv-backend c10d --rdzv-endpoint localhost:0 \
--no_python python3 symm_mem_all_reduce.py --impl multimem_all_reduce
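
Under the hood, these ops operate on tensors allocated from the symmetric-memory allocator. Below is a minimal sketch of invoking one of the PyTorch ops directly, assuming a recent nightly; the torch.distributed._symmetric_memory module path and op signatures shown here reflect current nightlies and may change:

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

# Allocate from the symmetric-memory allocator and establish peer
# mappings across the group before calling any symm_mem op.
t = symm_mem.empty(4096, dtype=torch.bfloat16, device="cuda")
symm_mem.rendezvous(t, dist.group.WORLD)

t.fill_(rank)
# In-place multimem (NVLS) all-reduce; the one-shot and two-shot
# variants take the same (tensor, reduce_op, group_name) form.
torch.ops.symm_mem.multimem_all_reduce_(t, "sum", dist.group.WORLD.group_name)

dist.destroy_process_group()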

Some benchmarks on 8xH100 with NVSwitch:

[benchmark chart]

triton_all_gather_matmul.py

This is a fused all-gather matmul example using Triton + SymmetricMemory, based on the tma_persistent Triton tutorial with slight modifications.

This example requires PyTorch Nightly and Triton 3.0.0+ to run.

Usage:

torchrun \
--nnodes 1 --nproc-per-node 8 \
--rdzv-backend c10d --rdzv-endpoint localhost:0 \
--no_python python3 triton_all_gather_matmul.py \
--M 16384 --N 6656 --K 16384 --BLOCK_SIZE_M 128 --BLOCK_SIZE_N 256 --BLOCK_SIZE_K 64
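
For reference, the "cuBLAS + NCCL" column in the tables below corresponds to the unfused baseline: an NCCL all-gather of the sharded operand followed by a regular matmul. A rough sketch of that baseline (the function name and the assumption that A is sharded along M are illustrative):

import torch
import torch.distributed as dist

def all_gather_matmul_baseline(a_shard: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Gather the M-sharded A from all ranks (NCCL), then multiply (cuBLAS).
    world_size = dist.get_world_size()
    a = a_shard.new_empty(a_shard.shape[0] * world_size, a_shard.shape[1])
    dist.all_gather_into_tensor(a, a_shard)
    return a @ b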

Some benchmarks on 8xH100 (special version with HBM2e, at 650W) with NVSwitch:

Llama 3 8B (N=1792, K=4096)

| Problem Size (M) | Config¹ | cuBLAS MM Only (µs) | Triton MM Only (µs) | cuBLAS + NCCL (µs) | Triton Fused (µs) | Speedup |
|------------------|--------------|-----|-----|-----|-----|--------|
| 4096             | 64,128,128,4 | 100 | 142 | 223 | 211 | 1.05x² |
| 8192             | 128,128,64,6 | 186 | 198 | 393 | 293 | 1.34x  |
| 16384            | 128,256,64,3 | 363 | 363 | 748 | 485 | 1.54x  |

Llama 3 70B (N=3584, K=8192)

| Problem Size (M) | Config¹ | cuBLAS MM Only (µs) | Triton MM Only (µs) | cuBLAS + NCCL (µs) | Triton Fused (µs) | Speedup |
|------------------|--------------|------|------|------|------|-------|
| 4096             | 128,128,64,6 | 376  | 392  | 587  | 453  | 1.29x |
| 8192             | 128,256,64,3 | 746  | 706  | 1168 | 821  | 1.42x |
| 16384            | 128,256,64,3 | 1502 | 1403 | 2306 | 1566 | 1.47x |

Llama 3 105B (N=6656, K=16384)

| Problem Size (M) | Config¹ | cuBLAS MM Only (µs) | Triton MM Only (µs) | cuBLAS + NCCL (µs) | Triton Fused (µs) | Speedup |
|------------------|--------------|------|------|------|------|-------|
| 4096             | 128,256,64,3 | 1358 | 1425 | 1858 | 1615 | 1.15x |
| 8192             | 128,256,64,3 | 2567 | 2656 | 3533 | 2907 | 1.22x |
| 16384            | 128,256,64,3 | 5249 | 5375 | 6982 | 5814 | 1.20x |

¹ Config refers to BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, and num_stages.

² For this problem size, using multicast all-gather would be a more suitable optimization.
