Skip to content

Commit

Permalink
gfx12 support (#9) (#384)
Browse files Browse the repository at this point in the history
Add gfx12 support
  • Loading branch information
stanleytsang-amd authored Jul 9, 2024
1 parent 5050406 commit 41f6679
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 64 deletions.
3 changes: 1 addition & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,9 @@ if(NOT (CMAKE_CXX_COMPILER MATCHES ".*nvcc$" OR "${CMAKE_CXX_COMPILER_ID}" STREQ
)
else()
rocm_check_target_ids(DEFAULT_AMDGPU_TARGETS
TARGETS "gfx803;gfx900:xnack-;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack-;gfx90a:xnack+;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102"
TARGETS "gfx803;gfx900:xnack-;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack-;gfx90a:xnack+;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
)
endif()

set(GPU_TARGETS "${DEFAULT_AMDGPU_TARGETS}" CACHE STRING "GPU architectures to compile for" FORCE)
endif()
endif()
Expand Down
61 changes: 34 additions & 27 deletions hipcub/include/hipcub/backend/rocprim/thread/thread_load.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,49 +60,56 @@ HIPCUB_DEVICE __forceinline__ T AsmThreadLoad(void * ptr)
interim_type, \
asm_operator, \
output_modifier, \
wait_inst, \
wait_cmd) \
template<> \
HIPCUB_DEVICE __forceinline__ type AsmThreadLoad<cache_modifier, type>(void * ptr) \
{ \
interim_type retval; \
asm volatile( \
#asm_operator " %0, %1 " llvm_cache_modifier "\n" \
"\ts_waitcnt " wait_cmd "(0)" : "=" #output_modifier(retval) : "v"(ptr) \
); \
asm volatile(#asm_operator " %0, %1 " llvm_cache_modifier "\n\t" \
wait_inst wait_cmd "(%2)" \
: "=" #output_modifier(retval) \
: "v"(ptr), "I"(0x00)); \
return retval; \
}

// TODO Add specialization for custom larger data types
#define HIPCUB_ASM_THREAD_LOAD_GROUP(cache_modifier, llvm_cache_modifier, wait_cmd) \
HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, int8_t, int16_t, flat_load_sbyte, v, wait_cmd); \
HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, int16_t, int16_t, flat_load_sshort, v, wait_cmd); \
HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, uint8_t, uint16_t, flat_load_ubyte, v, wait_cmd); \
HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, uint16_t, uint16_t, flat_load_ushort, v, wait_cmd); \
HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, uint32_t, uint32_t, flat_load_dword, v, wait_cmd); \
HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, float, uint32_t, flat_load_dword, v, wait_cmd); \
HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, uint64_t, uint64_t, flat_load_dwordx2, v, wait_cmd); \
HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, double, uint64_t, flat_load_dwordx2, v, wait_cmd);
#define HIPCUB_ASM_THREAD_LOAD_GROUP(cache_modifier, llvm_cache_modifier, wait_inst, wait_cmd) \
HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, int8_t, int16_t, flat_load_sbyte, v, wait_inst, wait_cmd); \
HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, int16_t, int16_t, flat_load_sshort, v, wait_inst, wait_cmd); \
HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, uint8_t, uint16_t, flat_load_ubyte, v, wait_inst, wait_cmd); \
HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, uint16_t, uint16_t, flat_load_ushort, v, wait_inst, wait_cmd); \
HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, uint32_t, uint32_t, flat_load_dword, v, wait_inst, wait_cmd); \
HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, float, uint32_t, flat_load_dword, v, wait_inst, wait_cmd); \
HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, uint64_t, uint64_t, flat_load_dwordx2, v, wait_inst, wait_cmd); \
HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, double, uint64_t, flat_load_dwordx2, v, wait_inst, wait_cmd);


#if defined(__gfx940__) || defined(__gfx941__)
HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CA, "sc0", "");
HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CG, "sc1", "");
HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CV, "sc0 sc1", "vmcnt");
HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_VOLATILE, "sc0 sc1", "vmcnt");
HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CA, "sc0", "s_waitcnt", "");
HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CG, "sc1", "s_waitcnt", "");
HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CV, "sc0 sc1", "s_waitcnt", "vmcnt");
HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_VOLATILE, "sc0 sc1", "s_waitcnt", "vmcnt");
#elif defined(__gfx942__)
HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CA, "sc0", "");
HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CG, "sc0 nt", "");
HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CV, "sc0", "vmcnt");
HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_VOLATILE, "sc0", "vmcnt");
HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CA, "sc0", "s_waitcnt", "");
HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CG, "sc0 nt", "s_waitcnt", "");
HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CV, "sc0", "s_waitcnt", "vmcnt");
HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_VOLATILE, "sc0", "s_waitcnt", "vmcnt");
#elif defined(__gfx1200__) || defined(__gfx1201__)
HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CA, "scope:SCOPE_DEV", "s_wait_loadcnt_dscnt", "");
HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CG, "th:TH_DEFAULT scope:SCOPE_DEV", "s_wait_loadcnt_dscnt", "");
HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CV, "th:TH_DEFAULT scope:SCOPE_DEV", "s_wait_loadcnt_dscnt", "");
HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_VOLATILE, "th:TH_DEFAULT scope:SCOPE_DEV", "s_wait_loadcnt_dscnt", "");
#else
HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CA, "glc", "");
HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CG, "glc slc", "");
HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CV, "glc", "vmcnt");
HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_VOLATILE, "glc", "vmcnt");
HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CA, "glc", "s_waitcnt", "");
HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CG, "glc slc", "s_waitcnt", "");
HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CV, "glc", "s_waitcnt", "vmcnt");
HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_VOLATILE, "glc", "s_waitcnt", "vmcnt");
#endif

// TODO find correct modifiers to match these
HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_LDG, "", "");
HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CS, "", "");
HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_LDG, "", "", "");
HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CS, "", "", "");

#endif

Expand Down
58 changes: 32 additions & 26 deletions hipcub/include/hipcub/backend/rocprim/thread/thread_store.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,46 +62,52 @@ HIPCUB_DEVICE __forceinline__ void AsmThreadStore(void * ptr, T val)
interim_type, \
asm_operator, \
output_modifier, \
wait_inst, \
wait_cmd) \
template<> \
HIPCUB_DEVICE __forceinline__ void AsmThreadStore<cache_modifier, type>(void * ptr, type val) \
{ \
interim_type temp_val = val; \
asm volatile(#asm_operator " %0, %1 " llvm_cache_modifier : : "v"(ptr), #output_modifier(temp_val)); \
asm volatile("s_waitcnt " wait_cmd "(%0)" : : "I"(0x00)); \
interim_type temp_val = val; \
asm volatile(#asm_operator " %0, %1 " llvm_cache_modifier "\n\t" \
wait_inst wait_cmd "(%2)" \
: : "v"(ptr), #output_modifier(temp_val), "I"(0x00)); \
}

// TODO fix flat_store_ubyte and flat_store_sbyte issues
// TODO Add specialization for custom larger data types
#define HIPCUB_ASM_THREAD_STORE_GROUP(cache_modifier, llvm_cache_modifier, wait_cmd) \
HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, int8_t, int16_t, flat_store_byte, v, wait_cmd); \
HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, int16_t, int16_t, flat_store_short, v, wait_cmd); \
HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, uint8_t, uint16_t, flat_store_byte, v, wait_cmd); \
HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, uint16_t, uint16_t, flat_store_short, v, wait_cmd); \
HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, uint32_t, uint32_t, flat_store_dword, v, wait_cmd); \
HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, float, uint32_t, flat_store_dword, v, wait_cmd); \
HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, uint64_t, uint64_t, flat_store_dwordx2, v, wait_cmd); \
HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, double, uint64_t, flat_store_dwordx2, v, wait_cmd);
#define HIPCUB_ASM_THREAD_STORE_GROUP(cache_modifier, llvm_cache_modifier, wait_inst, wait_cmd) \
HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, int8_t, int16_t, flat_store_byte, v, wait_inst, wait_cmd); \
HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, int16_t, int16_t, flat_store_short, v, wait_inst, wait_cmd); \
HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, uint8_t, uint16_t, flat_store_byte, v, wait_inst, wait_cmd); \
HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, uint16_t, uint16_t, flat_store_short, v, wait_inst, wait_cmd); \
HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, uint32_t, uint32_t, flat_store_dword, v, wait_inst, wait_cmd); \
HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, float, uint32_t, flat_store_dword, v, wait_inst, wait_cmd); \
HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, uint64_t, uint64_t, flat_store_dwordx2, v, wait_inst, wait_cmd); \
HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, double, uint64_t, flat_store_dwordx2, v, wait_inst, wait_cmd);

#if defined(__gfx940__) || defined(__gfx941__)
HIPCUB_ASM_THREAD_STORE_GROUP(STORE_WB, "sc0 sc1", "");
HIPCUB_ASM_THREAD_STORE_GROUP(STORE_CG, "sc0 sc1", "");
HIPCUB_ASM_THREAD_STORE_GROUP(STORE_WT, "sc0 sc1", "vmcnt");
HIPCUB_ASM_THREAD_STORE_GROUP(STORE_VOLATILE, "sc0 sc1", "vmcnt");
HIPCUB_ASM_THREAD_STORE_GROUP(STORE_WB, "sc0 sc1", "s_waitcnt", "");
HIPCUB_ASM_THREAD_STORE_GROUP(STORE_CG, "sc0 sc1", "s_waitcnt", "");
HIPCUB_ASM_THREAD_STORE_GROUP(STORE_WT, "sc0 sc1", "s_waitcnt", "vmcnt");
HIPCUB_ASM_THREAD_STORE_GROUP(STORE_VOLATILE, "sc0 sc1", "s_waitcnt", "vmcnt");
#elif defined(__gfx942__)
HIPCUB_ASM_THREAD_STORE_GROUP(STORE_WB, "sc0", "");
HIPCUB_ASM_THREAD_STORE_GROUP(STORE_CG, "sc0 nt", "");
HIPCUB_ASM_THREAD_STORE_GROUP(STORE_WT, "sc0", "vmcnt");
HIPCUB_ASM_THREAD_STORE_GROUP(STORE_VOLATILE, "sc0", "vmcnt");
HIPCUB_ASM_THREAD_STORE_GROUP(STORE_WB, "sc0", "s_waitcnt", "");
HIPCUB_ASM_THREAD_STORE_GROUP(STORE_CG, "sc0 nt", "s_waitcnt", "");
HIPCUB_ASM_THREAD_STORE_GROUP(STORE_WT, "sc0", "s_waitcnt", "vmcnt");
HIPCUB_ASM_THREAD_STORE_GROUP(STORE_VOLATILE, "sc0", "s_waitcnt", "vmcnt");
#elif defined(__gfx1200__) || defined(__gfx1201__)
HIPCUB_ASM_THREAD_STORE_GROUP(STORE_WB, "scope:SCOPE_DEV", "s_wait_storecnt_dscnt", "");
HIPCUB_ASM_THREAD_STORE_GROUP(STORE_CG, "th:TH_DEFAULT scope:SCOPE_DEV", "s_wait_storecnt_dscnt", "");
HIPCUB_ASM_THREAD_STORE_GROUP(STORE_WT, "scope:SCOPE_DEV", "s_wait_storecnt_dscnt", "");
HIPCUB_ASM_THREAD_STORE_GROUP(STORE_VOLATILE, "scope:SCOPE_DEV", "s_wait_storecnt_dscnt", "");
#else
HIPCUB_ASM_THREAD_STORE_GROUP(STORE_WB, "glc", "");
HIPCUB_ASM_THREAD_STORE_GROUP(STORE_CG, "glc slc", "");
HIPCUB_ASM_THREAD_STORE_GROUP(STORE_WT, "glc", "vmcnt");
HIPCUB_ASM_THREAD_STORE_GROUP(STORE_VOLATILE, "glc", "vmcnt");
HIPCUB_ASM_THREAD_STORE_GROUP(STORE_WB, "glc", "s_waitcnt", "");
HIPCUB_ASM_THREAD_STORE_GROUP(STORE_CG, "glc slc", "s_waitcnt", "");
HIPCUB_ASM_THREAD_STORE_GROUP(STORE_WT, "glc", "s_waitcnt", "vmcnt");
HIPCUB_ASM_THREAD_STORE_GROUP(STORE_VOLATILE, "glc", "s_waitcnt", "vmcnt");
#endif

// TODO find correct modifiers to match these
HIPCUB_ASM_THREAD_STORE_GROUP(STORE_CS, "", "");
HIPCUB_ASM_THREAD_STORE_GROUP(STORE_CS, "", "", "");

#endif

Expand Down
18 changes: 9 additions & 9 deletions test/hipcub/test_hipcub_iterators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,9 +344,9 @@ TYPED_TEST(HipcubIteratorTests, TestTexObj)
hipDeviceProp_t props;
HIP_CHECK(hipGetDeviceProperties(&props, device_id));
std::string deviceName = std::string(props.gcnArchName);
if (deviceName.rfind("gfx94", 0) == 0) {
// This is a gfx94x device, so skip this test
GTEST_SKIP() << "Test not run on gfx94x as texture cache API is not supported";
if (deviceName.rfind("gfx94", 0) == 0 || deviceName.rfind("gfx120") == 0) {
// This is a gfx94x or gfx120x device, so skip this test
GTEST_SKIP() << "Test not run on gfx94x or gfx120x as texture cache API is not supported";
}

HIP_CHECK(hipSetDevice(device_id));
Expand Down Expand Up @@ -411,9 +411,9 @@ TYPED_TEST(HipcubIteratorTests, TestTexRef)
hipDeviceProp_t props;
HIP_CHECK(hipGetDeviceProperties(&props, device_id));
std::string deviceName = std::string(props.gcnArchName);
if (deviceName.rfind("gfx94", 0) == 0) {
// This is a gfx94x device, so skip this test
GTEST_SKIP() << "Test not run on gfx94x as texture cache API is not supported";
if (deviceName.rfind("gfx94", 0) == 0 || deviceName.rfind("gfx120") == 0) {
// This is a gfx94x or gfx120x device, so skip this test
GTEST_SKIP() << "Test not run on gfx94x or gfx120x as texture cache API is not supported";
}

HIP_CHECK(hipSetDevice(device_id));
Expand Down Expand Up @@ -482,9 +482,9 @@ TYPED_TEST(HipcubIteratorTests, TestTexTransform)
hipDeviceProp_t props;
HIP_CHECK(hipGetDeviceProperties(&props, device_id));
std::string deviceName = std::string(props.gcnArchName);
if (deviceName.rfind("gfx94", 0) == 0) {
// This is a gfx94x device, so skip this test
GTEST_SKIP() << "Test not run on gfx94x as texture cache API is not supported";
if (deviceName.rfind("gfx94", 0) == 0 || deviceName.rfind("gfx120") == 0) {
// This is a gfx94x or gfx120x device, so skip this test
GTEST_SKIP() << "Test not run on gfx94x or gfx120x as texture cache API is not supported";
}

HIP_CHECK(hipSetDevice(device_id));
Expand Down

0 comments on commit 41f6679

Please sign in to comment.