diff --git a/library/src/amd_detail/hipblaslt-ext.cpp b/library/src/amd_detail/hipblaslt-ext.cpp index 4b59e9108..f4023cdb5 100644 --- a/library/src/amd_detail/hipblaslt-ext.cpp +++ b/library/src/amd_detail/hipblaslt-ext.cpp @@ -27,6 +27,7 @@ #include "hipblaslt-ext.hpp" #include "exceptions.hpp" #include "hipblaslt_internal.hpp" +#include #include #include #include @@ -89,21 +90,21 @@ namespace hipblaslt_ext { } - GemmProblemTypeV2::GemmProblemTypeV2(hipblasOperation_t opA, - hipblasOperation_t opB, - hipDataType typeA, - hipDataType typeB, - hipDataType typeC, - hipDataType typeD, + GemmProblemTypeV2::GemmProblemTypeV2(hipblasOperation_t opA, + hipblasOperation_t opB, + hipDataType typeA, + hipDataType typeB, + hipDataType typeC, + hipDataType typeD, hipblasComputeType_t typeCompute) : pimpl(std::make_unique()) { - pimpl->op_a = opA; - pimpl->op_b = opB; - pimpl->type_a = typeA; - pimpl->type_b = typeB; - pimpl->type_c = typeC; - pimpl->type_d = typeD; + pimpl->op_a = opA; + pimpl->op_b = opB; + pimpl->type_a = typeA; + pimpl->type_b = typeB; + pimpl->type_c = typeC; + pimpl->type_d = typeD; pimpl->type_compute = typeCompute; } @@ -289,7 +290,7 @@ namespace hipblaslt_ext { public: u_int16_t splitK = 0; - int16_t wgm = 0; + int16_t wgm = 0; }; GemmTuningV2::GemmTuningV2() @@ -595,6 +596,10 @@ namespace hipblaslt_ext return m_workspace_bytes; } + //////////////////////////////////////////////////////////// + // Gemm Instance + //////////////////////////////////////////////////////////// + GemmInstance::GemmInstance(hipblasLtHandle_t handle, GemmType type) : m_gemm_type(type) , m_handle(handle) @@ -619,19 +624,25 @@ namespace hipblaslt_ext const GemmPreference& pref, std::vector& heuristicResults) { + rocblaslt::Debug::Instance().markerStart("hipblasLtAlgoGetHeuristicCpp"); if(m_gemm_count == 0) + { + rocblaslt::Debug::Instance().markerStop(); return HIPBLAS_STATUS_INVALID_VALUE; + } auto gemmType = static_cast(m_gemm_type); auto results = reinterpret_cast*>(&heuristicResults); results->clear(); - return RocBlasLtStatusToHIPStatus( + auto status = RocBlasLtStatusToHIPStatus( rocblaslt_algo_get_heuristic_cpp((rocblaslt_handle)m_handle, gemmType, m_data, pref.getMaxWorkspaceBytes(), requestedAlgoCount, *results)); + rocblaslt::Debug::Instance().markerStop(); + return status; } hipblasStatus_t GemmInstance::algoGetHeuristic( @@ -639,30 +650,39 @@ namespace hipblaslt_ext const GemmPreferenceV2& pref, std::vector& heuristicResults) { + rocblaslt::Debug::Instance().markerStart("hipblasLtAlgoGetHeuristicV2Cpp"); if(m_gemm_count == 0) + { + rocblaslt::Debug::Instance().markerStop(); return HIPBLAS_STATUS_INVALID_VALUE; + } auto gemmType = static_cast(m_gemm_type); auto results = reinterpret_cast*>(&heuristicResults); results->clear(); - return RocBlasLtStatusToHIPStatus( + auto status = RocBlasLtStatusToHIPStatus( rocblaslt_algo_get_heuristic_cpp((rocblaslt_handle)m_handle, gemmType, m_data, pref.pimpl->workspace_bytes, requestedAlgoCount, *results)); + rocblaslt::Debug::Instance().markerStop(); + return status; } hipblasStatus_t GemmInstance::isAlgoSupported(hipblasLtMatmulAlgo_t& algo, size_t& workspaceSizeInBytes) try { - auto gemmType = static_cast(m_gemm_type); - auto rocalgo = reinterpret_cast(&algo); - rocblaslt::RocTuningV2 *tuning = nullptr; - return RocBlasLtStatusToHIPStatus(rocblaslt_is_algo_supported_cpp( + rocblaslt::Debug::Instance().markerStart("hipblasLtIsAlgoSupportedCpp"); + auto gemmType = static_cast(m_gemm_type); + auto rocalgo = reinterpret_cast(&algo); + rocblaslt::RocTuningV2* tuning = nullptr; + auto status = RocBlasLtStatusToHIPStatus(rocblaslt_is_algo_supported_cpp( (rocblaslt_handle)m_handle, gemmType, m_data, *rocalgo, tuning, workspaceSizeInBytes)); + rocblaslt::Debug::Instance().markerStop(); + return status; } catch(...) { @@ -674,16 +694,19 @@ namespace hipblaslt_ext size_t& workspaceSizeInBytes) try { + rocblaslt::Debug::Instance().markerStart("hipblasLtIsAlgoSupportedTuningCpp"); auto gemmType = static_cast(m_gemm_type); auto rocalgo = reinterpret_cast(&algo); auto roctuning = reinterpret_cast(&tuning); - return RocBlasLtStatusToHIPStatus( - rocblaslt_is_algo_supported_cpp((rocblaslt_handle)m_handle, - gemmType, - m_data, - *rocalgo, - roctuning, - workspaceSizeInBytes)); + auto status + = RocBlasLtStatusToHIPStatus(rocblaslt_is_algo_supported_cpp((rocblaslt_handle)m_handle, + gemmType, + m_data, + *rocalgo, + roctuning, + workspaceSizeInBytes)); + rocblaslt::Debug::Instance().markerStop(); + return status; } catch(...) { @@ -695,16 +718,19 @@ namespace hipblaslt_ext size_t& workspaceSizeInBytes) try { + rocblaslt::Debug::Instance().markerStart("hipblasLtIsAlgoSupportedTuningV2Cpp"); auto gemmType = static_cast(m_gemm_type); auto rocalgo = reinterpret_cast(&algo); auto roctuning = reinterpret_cast(&tuning); - return RocBlasLtStatusToHIPStatus( - rocblaslt_is_algo_supported_cpp((rocblaslt_handle)m_handle, - gemmType, - m_data, - *rocalgo, - roctuning, - workspaceSizeInBytes)); + auto status + = RocBlasLtStatusToHIPStatus(rocblaslt_is_algo_supported_cpp((rocblaslt_handle)m_handle, + gemmType, + m_data, + *rocalgo, + roctuning, + workspaceSizeInBytes)); + rocblaslt::Debug::Instance().markerStop(); + return status; } catch(...) { @@ -717,19 +743,26 @@ namespace hipblaslt_ext hipStream_t stream) try { + rocblaslt::Debug::Instance().markerStart("hipblasLtInitializeCpp"); if(m_gemm_count == 0) + { + rocblaslt::Debug::Instance().markerStop(); return HIPBLAS_STATUS_INVALID_VALUE; - auto gemmType = static_cast(m_gemm_type); - auto rocalgo = reinterpret_cast(&algo); - rocblaslt::RocTuningV2 *tuning = nullptr; - return RocBlasLtStatusToHIPStatus(rocblaslt_makeArgument_cpp((rocblaslt_handle)m_handle, - gemmType, - *rocalgo, - tuning, - workspace, - useUserArgs, - stream, - m_data)); + } + auto gemmType = static_cast(m_gemm_type); + auto rocalgo = reinterpret_cast(&algo); + rocblaslt::RocTuningV2* tuning = nullptr; + auto status + = RocBlasLtStatusToHIPStatus(rocblaslt_makeArgument_cpp((rocblaslt_handle)m_handle, + gemmType, + *rocalgo, + tuning, + workspace, + useUserArgs, + stream, + m_data)); + rocblaslt::Debug::Instance().markerStop(); + return status; } catch(...) { @@ -743,19 +776,26 @@ namespace hipblaslt_ext hipStream_t stream) try { + rocblaslt::Debug::Instance().markerStart("hipblasLtInitializeTuningCpp"); if(m_gemm_count == 0) + { + rocblaslt::Debug::Instance().markerStop(); return HIPBLAS_STATUS_INVALID_VALUE; + } auto gemmType = static_cast(m_gemm_type); auto rocalgo = reinterpret_cast(&algo); auto roctuning = reinterpret_cast(&tuning); - return RocBlasLtStatusToHIPStatus(rocblaslt_makeArgument_cpp((rocblaslt_handle)m_handle, - gemmType, - *rocalgo, - roctuning, - workspace, - useUserArgs, - stream, - m_data)); + auto status + = RocBlasLtStatusToHIPStatus(rocblaslt_makeArgument_cpp((rocblaslt_handle)m_handle, + gemmType, + *rocalgo, + roctuning, + workspace, + useUserArgs, + stream, + m_data)); + rocblaslt::Debug::Instance().markerStop(); + return status; } catch(...) { @@ -769,19 +809,26 @@ namespace hipblaslt_ext hipStream_t stream) try { + rocblaslt::Debug::Instance().markerStart("hipblasLtInitializeTuningV2Cpp"); if(m_gemm_count == 0) + { + rocblaslt::Debug::Instance().markerStop(); return HIPBLAS_STATUS_INVALID_VALUE; + } auto gemmType = static_cast(m_gemm_type); auto rocalgo = reinterpret_cast(&algo); auto roctuning = reinterpret_cast(&tuning); - return RocBlasLtStatusToHIPStatus(rocblaslt_makeArgument_cpp((rocblaslt_handle)m_handle, - gemmType, - *rocalgo, - roctuning, - workspace, - useUserArgs, - stream, - m_data)); + auto status + = RocBlasLtStatusToHIPStatus(rocblaslt_makeArgument_cpp((rocblaslt_handle)m_handle, + gemmType, + *rocalgo, + roctuning, + workspace, + useUserArgs, + stream, + m_data)); + rocblaslt::Debug::Instance().markerStop(); + return status; } catch(...) { @@ -791,13 +838,17 @@ namespace hipblaslt_ext hipblasStatus_t GemmInstance::run(hipStream_t stream, hipEvent_t start, hipEvent_t stop) try { + rocblaslt::Debug::Instance().markerStart("hipblasLtRunCpp"); if(m_gemm_count == 0) + { + rocblaslt::Debug::Instance().markerStop(); return HIPBLAS_STATUS_INVALID_VALUE; + } auto gemmType = static_cast(m_gemm_type); auto status = RocBlasLtStatusToHIPStatus( rocblaslt_run_cpp((rocblaslt_handle)m_handle, gemmType, m_data, stream, start, stop)); - + rocblaslt::Debug::Instance().markerStop(); return status; } catch(...) @@ -829,6 +880,7 @@ namespace hipblaslt_ext hipblasComputeType_t typeCompute) : GemmInstance(handle, GemmType::HIPBLASLT_GEMM) { + rocblaslt::Debug::Instance().markerStart("hipblasLtCreateGemmCpp"); m_problem_types.push_back({opA, opB, typeA, typeB, typeC, typeD, typeCompute}); rocblaslt_init_gemmData((rocblaslt_handle)m_handle, static_cast(m_gemm_type), @@ -841,6 +893,7 @@ namespace hipblaslt_ext (rocblaslt_compute_type)typeCompute, 0, m_data); + rocblaslt::Debug::Instance().markerStop(); } Gemm::Gemm(hipblasLtHandle_t handle, @@ -857,11 +910,13 @@ namespace hipblaslt_ext hipblasLtMatrixLayout_t matD) : GemmInstance(handle, GemmType::HIPBLASLT_GEMM) { + rocblaslt::Debug::Instance().markerStart("hipblasLtCreateGemmCAPICpp"); auto status = setProblem(matmul_descr, alpha, A, matA, B, matB, beta, C, matC, D, matD); if(status != HIPBLAS_STATUS_SUCCESS) { std::cout << "Failed to create instance " << status << std::endl; } + rocblaslt::Debug::Instance().markerStop(); } Gemm::Gemm(Gemm&&) noexcept = default; @@ -874,8 +929,10 @@ namespace hipblaslt_ext GemmEpilogue& epilogue, GemmInputs& inputs) { + rocblaslt::Debug::Instance().markerStart("hipblasLtGemmSetProblemCpp"); if(n == 0 || m == 0) { + rocblaslt::Debug::Instance().markerStop(); return HIPBLAS_STATUS_INVALID_VALUE; } @@ -885,21 +942,23 @@ namespace hipblaslt_ext int64_t strideA = m * k; int64_t strideB = n * k; int64_t strideC = m * n; - return setProblem(m, - n, - k, - batch_count, - lda, - ldb, - ldc, - ldc, - strideA, - strideB, - strideC, - strideC, - epilogue, - inputs, - m_problem_types[0]); + auto status = setProblem(m, + n, + k, + batch_count, + lda, + ldb, + ldc, + ldc, + strideA, + strideB, + strideC, + strideC, + epilogue, + inputs, + m_problem_types[0]); + rocblaslt::Debug::Instance().markerStop(); + return status; } hipblasStatus_t Gemm::setProblem(int64_t m, @@ -909,17 +968,19 @@ namespace hipblaslt_ext GemmEpilogueV2& epilogue, GemmInputsV2& inputs) { + rocblaslt::Debug::Instance().markerStart("hipblasLtGemmSetProblemV2Cpp"); if(n == 0 || m == 0) { + rocblaslt::Debug::Instance().markerStop(); return HIPBLAS_STATUS_INVALID_VALUE; } - int64_t lda = m_problem_types[0].op_a == HIPBLAS_OP_N ? m : k; - int64_t ldb = m_problem_types[0].op_b == HIPBLAS_OP_N ? k : n; - int64_t ldc = m; - int64_t strideA = m * k; - int64_t strideB = n * k; - int64_t strideC = m * n; + int64_t lda = m_problem_types[0].op_a == HIPBLAS_OP_N ? m : k; + int64_t ldb = m_problem_types[0].op_b == HIPBLAS_OP_N ? k : n; + int64_t ldc = m; + int64_t strideA = m * k; + int64_t strideB = n * k; + int64_t strideC = m * n; GemmProblemTypeV2 prob(m_problem_types[0].op_a, m_problem_types[0].op_b, m_problem_types[0].type_a, @@ -927,39 +988,42 @@ namespace hipblaslt_ext m_problem_types[0].type_c, m_problem_types[0].type_d, m_problem_types[0].type_compute); - return setProblem(m, - n, - k, - batch_count, - lda, - ldb, - ldc, - ldc, - strideA, - strideB, - strideC, - strideC, - epilogue, - inputs, - prob); + auto status = setProblem(m, + n, + k, + batch_count, + lda, + ldb, + ldc, + ldc, + strideA, + strideB, + strideC, + strideC, + epilogue, + inputs, + prob); + rocblaslt::Debug::Instance().markerStop(); + return status; } - hipblasStatus_t Gemm::setProblem(int64_t m, - int64_t n, - int64_t k, - int64_t batch_count, - int64_t lda, - int64_t ldb, - int64_t ldc, - int64_t ldd, - int64_t strideA, - int64_t strideB, - int64_t strideC, - int64_t strideD, - GemmEpilogue& epilogue, - GemmInputs& inputs, - GemmProblemType& problemtype) - { + hipblasStatus_t Gemm::setProblem(int64_t m, + int64_t n, + int64_t k, + int64_t batch_count, + int64_t lda, + int64_t ldb, + int64_t ldc, + int64_t ldd, + int64_t strideA, + int64_t strideB, + int64_t strideC, + int64_t strideD, + GemmEpilogue& epilogue, + GemmInputs& inputs, + GemmProblemType& problemtype) + { + rocblaslt::Debug::Instance().markerStart("hipblasLtGemmSetProblemFullCpp"); GemmInputs gemmInputs = inputs; GemmProblemType gemmProblemType = problemtype; auto rocepilogue = reinterpret_cast(&epilogue); @@ -988,6 +1052,7 @@ namespace hipblaslt_ext { m_problem_types[0] = problemtype; } + rocblaslt::Debug::Instance().markerStop(); return status; } @@ -1007,9 +1072,11 @@ namespace hipblaslt_ext GemmInputsV2& inputs, GemmProblemTypeV2& problemtype) { - auto rocepilogue = reinterpret_cast(epilogue.pimpl.get()); - auto rocepinputs = reinterpret_cast(inputs.pimpl.get()); - auto rocproblemtype = reinterpret_cast(problemtype.pimpl.get()); + rocblaslt::Debug::Instance().markerStart("hipblasLtGemmSetProblemFullV2Cpp"); + auto rocepilogue = reinterpret_cast(epilogue.pimpl.get()); + auto rocepinputs = reinterpret_cast(inputs.pimpl.get()); + auto rocproblemtype + = reinterpret_cast(problemtype.pimpl.get()); auto status = RocBlasLtStatusToHIPStatus(rocblaslt_gemm_create_cpp((rocblaslt_handle)m_handle, m, @@ -1039,6 +1106,7 @@ namespace hipblaslt_ext problemtype.getTypeD(), problemtype.getTypeCompute()}; } + rocblaslt::Debug::Instance().markerStop(); return status; } @@ -1054,9 +1122,10 @@ namespace hipblaslt_ext void* D, hipblasLtMatrixLayout_t matD) { + rocblaslt::Debug::Instance().markerStart("hipblasLtGemmSetProblemCAPICpp"); auto rocproblemtypes = reinterpret_cast*>(&m_problem_types); - return RocBlasLtStatusToHIPStatus( + auto status = RocBlasLtStatusToHIPStatus( rocblaslt_gemm_create_cpp((rocblaslt_handle)m_handle, (rocblaslt_matmul_desc)matmul_descr, alpha, @@ -1072,6 +1141,8 @@ namespace hipblaslt_ext (*rocproblemtypes)[0], m_data, m_gemm_count)); + rocblaslt::Debug::Instance().markerStop(); + return status; } GemmProblemType Gemm::getProblemTypes() @@ -1101,6 +1172,7 @@ namespace hipblaslt_ext hipblasComputeType_t typeCompute) : GemmInstance(handle, GemmType::HIPBLASLT_GROUPED_GEMM) { + rocblaslt::Debug::Instance().markerStart("hipblasLtCreateGroupedGemmCpp"); m_problem_types.push_back({opA, opB, typeA, typeB, typeC, typeD, typeCompute}); rocblaslt_init_gemmData((rocblaslt_handle)m_handle, static_cast(m_gemm_type), @@ -1113,6 +1185,7 @@ namespace hipblaslt_ext (rocblaslt_compute_type)typeCompute, 0, m_data); + rocblaslt::Debug::Instance().markerStop(); } GroupedGemm::GroupedGemm(GroupedGemm&&) noexcept = default; @@ -1132,11 +1205,13 @@ namespace hipblaslt_ext std::vector& matD) : GemmInstance(handle, GemmType::HIPBLASLT_GROUPED_GEMM) { + rocblaslt::Debug::Instance().markerStart("hipblasLtCreateGroupedGemmCAPICpp"); auto status = setProblem(matmul_descr, alpha, A, matA, B, matB, beta, C, matC, D, matD); if(status != HIPBLAS_STATUS_SUCCESS) { std::cout << "Failed to create instance " << status << std::endl; } + rocblaslt::Debug::Instance().markerStop(); } hipblasStatus_t GroupedGemm::setProblem(std::vector& m, @@ -1146,6 +1221,7 @@ namespace hipblaslt_ext std::vector& epilogue, std::vector& inputs) { + rocblaslt::Debug::Instance().markerStart("hipblasLtGroupedGemmSetProblemCpp"); std::vector lda; std::vector ldb; std::vector ldc; @@ -1166,21 +1242,23 @@ namespace hipblaslt_ext strideC.push_back(m[i] * k[i]); strideD.push_back(m[i] * k[i]); } - return setProblem(m, - n, - k, - batch_count, - lda, - ldb, - ldc, - ldd, - strideA, - strideB, - strideC, - strideD, - epilogue, - inputs, - m_problem_types[0]); + auto status = setProblem(m, + n, + k, + batch_count, + lda, + ldb, + ldc, + ldd, + strideA, + strideB, + strideC, + strideD, + epilogue, + inputs, + m_problem_types[0]); + rocblaslt::Debug::Instance().markerStop(); + return status; } hipblasStatus_t GroupedGemm::setProblem(std::vector& m, @@ -1190,6 +1268,7 @@ namespace hipblaslt_ext std::vector& epilogue, std::vector& inputs) { + rocblaslt::Debug::Instance().markerStart("hipblasLtGroupedGemmSetProblemV2Cpp"); std::vector lda; std::vector ldb; std::vector ldc; @@ -1217,21 +1296,23 @@ namespace hipblaslt_ext m_problem_types[0].type_c, m_problem_types[0].type_d, m_problem_types[0].type_compute); - return setProblem(m, - n, - k, - batch_count, - lda, - ldb, - ldc, - ldd, - strideA, - strideB, - strideC, - strideD, - epilogue, - inputs, - prob); + auto status = setProblem(m, + n, + k, + batch_count, + lda, + ldb, + ldc, + ldd, + strideA, + strideB, + strideC, + strideD, + epilogue, + inputs, + prob); + rocblaslt::Debug::Instance().markerStop(); + return status; } hipblasStatus_t GroupedGemm::setProblem(std::vector& m, @@ -1250,6 +1331,7 @@ namespace hipblaslt_ext std::vector& inputs, GemmProblemType& problemtype) { + rocblaslt::Debug::Instance().markerStart("hipblasLtGroupedGemmSetProblemFullCpp"); auto rocepilogue = reinterpret_cast*>(&epilogue); auto rocinputs = reinterpret_cast*>(&inputs); std::vector tmptype = {problemtype}; @@ -1278,6 +1360,7 @@ namespace hipblaslt_ext { m_problem_types = tmptype; } + rocblaslt::Debug::Instance().markerStop(); return status; } @@ -1297,6 +1380,7 @@ namespace hipblaslt_ext std::vector& inputs, GemmProblemTypeV2& problemtype) { + rocblaslt::Debug::Instance().markerStart("hipblasLtGroupedGemmSetProblemFullV2Cpp"); std::vector rocepilogue; for(auto& e : epilogue) { @@ -1308,8 +1392,9 @@ namespace hipblaslt_ext { rocinputs.push_back(*reinterpret_cast(i.pimpl.get())); } - GemmProblemTypeV2 tmp = problemtype; - std::vector rocproblemtype = {*reinterpret_cast(tmp.pimpl.get())}; + GemmProblemTypeV2 tmp = problemtype; + std::vector rocproblemtype + = {*reinterpret_cast(tmp.pimpl.get())}; auto status = RocBlasLtStatusToHIPStatus( rocblaslt_groupedgemm_create_cpp((rocblaslt_handle)m_handle, m, @@ -1339,6 +1424,7 @@ namespace hipblaslt_ext problemtype.getTypeD(), problemtype.getTypeCompute()}; } + rocblaslt::Debug::Instance().markerStop(); return status; } @@ -1354,6 +1440,7 @@ namespace hipblaslt_ext std::vector& D, std::vector& matD) { + rocblaslt::Debug::Instance().markerStart("hipblasLtGroupedGemmSetProblemCAPICpp"); auto matmul_descr_groupedGemm = reinterpret_cast*>(&matmul_descr); auto matA_groupedGemm = reinterpret_cast*>(&matA); @@ -1367,7 +1454,7 @@ namespace hipblaslt_ext auto beta_groupedGemm = reinterpret_cast*>(&beta); auto rocproblemtypes = reinterpret_cast*>(&m_problem_types); - return RocBlasLtStatusToHIPStatus( + auto status = RocBlasLtStatusToHIPStatus( rocblaslt_groupedgemm_create_cpp((rocblaslt_handle)m_handle, *matmul_descr_groupedGemm, *alpha_groupedGemm, @@ -1383,6 +1470,8 @@ namespace hipblaslt_ext (*rocproblemtypes), m_data, m_gemm_count)); + rocblaslt::Debug::Instance().markerStop(); + return status; } std::vector GroupedGemm::getProblemTypes() @@ -1409,18 +1498,27 @@ namespace hipblaslt_ext HIPBLASLT_EXPORT hipblasStatus_t GroupedGemm::getDefaultValueForDeviceUserArguments(void* hostDeviceUserArgs) { + rocblaslt::Debug::Instance().markerStart("hipblasLtGroupedGemmGetDefaultUserArgsCpp"); auto gemmType = static_cast(m_gemm_type); - return RocBlasLtStatusToHIPStatus(rocblaslt_get_default_user_args( + auto status = RocBlasLtStatusToHIPStatus(rocblaslt_get_default_user_args( (rocblaslt_handle)m_handle, gemmType, m_data, hostDeviceUserArgs)); + rocblaslt::Debug::Instance().markerStop(); + return status; } HIPBLASLT_EXPORT hipblasStatus_t GroupedGemm::run(void* deviceUserArgs, hipStream_t stream) { + rocblaslt::Debug::Instance().markerStart("hipblasLtGroupedGemmRunCpp"); if(m_gemm_count == 0) + { + rocblaslt::Debug::Instance().markerStop(); return HIPBLAS_STATUS_INVALID_VALUE; + } auto gemmType = static_cast(m_gemm_type); - return RocBlasLtStatusToHIPStatus(rocblaslt_run_user_args_cpp( + auto status = RocBlasLtStatusToHIPStatus(rocblaslt_run_user_args_cpp( (rocblaslt_handle)m_handle, gemmType, m_data, deviceUserArgs, stream)); + rocblaslt::Debug::Instance().markerStop(); + return status; } hipblasStatus_t matmulIsAlgoSupported(hipblasLtHandle_t handle, @@ -1435,7 +1533,8 @@ namespace hipblaslt_ext size_t& workspaceSizeInBytes) try { - return RocBlasLtStatusToHIPStatus( + rocblaslt::Debug::Instance().markerStart("hipblasLtMatMulIsAlgoSupportedCpp"); + auto status = RocBlasLtStatusToHIPStatus( rocblaslt_matmul_is_algo_supported((rocblaslt_handle)handle, (rocblaslt_matmul_desc)matmulDesc, alpha, @@ -1446,6 +1545,8 @@ namespace hipblaslt_ext (rocblaslt_matrix_layout)Ddesc, (rocblaslt_matmul_algo*)&algo, &workspaceSizeInBytes)); + rocblaslt::Debug::Instance().markerStop(); + return status; } catch(...) { @@ -1475,10 +1576,11 @@ namespace hipblaslt_ext std::vector& heuristicResults) try { + rocblaslt::Debug::Instance().markerStart("hipblasLtGetAllAlgosCpp"); auto results = reinterpret_cast*>(&heuristicResults); results->clear(); - return RocBlasLtStatusToHIPStatus( + auto status = RocBlasLtStatusToHIPStatus( rocblaslt_matmul_get_all_algos_cpp((rocblaslt_handle)handle, static_cast(typeGemm), opA, @@ -1489,6 +1591,8 @@ namespace hipblaslt_ext typeD, (rocblaslt_compute_type)typeCompute, *results)); + rocblaslt::Debug::Instance().markerStop(); + return status; } catch(...) { @@ -1532,17 +1636,23 @@ namespace hipblaslt_ext std::vector& algoIndex, std::vector& heuristicResults) { + rocblaslt::Debug::Instance().markerStart("hipblasLtGetAlgosFromIndexCpp"); auto results = reinterpret_cast*>(&heuristicResults); results->clear(); - return RocBlasLtStatusToHIPStatus(rocblaslt_matmul_get_algos_from_index_cpp( + auto status = RocBlasLtStatusToHIPStatus(rocblaslt_matmul_get_algos_from_index_cpp( (rocblaslt_handle)handle, algoIndex, *results)); + rocblaslt::Debug::Instance().markerStop(); + return status; } hipblasStatus_t copyMatmul(hipblasLtMatmulDesc_t src, hipblasLtMatmulDesc_t dst) { - return RocBlasLtStatusToHIPStatus( + rocblaslt::Debug::Instance().markerStart("hipblasLtCopyMatmulCpp"); + auto status = RocBlasLtStatusToHIPStatus( rocblaslt_copy_matmul((rocblaslt_matmul_desc)src, (rocblaslt_matmul_desc)dst)); + rocblaslt::Debug::Instance().markerStop(); + return status; } int matmulIsTuned(hipblasLtHandle_t handle, @@ -1552,12 +1662,15 @@ namespace hipblaslt_ext hipblasLtMatrixLayout_t Cdesc, hipblasLtMatrixLayout_t Ddesc) { - return rocblaslt_matmul_is_tuned((rocblaslt_handle)handle, - (rocblaslt_matmul_desc)matmulDesc, - (rocblaslt_matrix_layout)Adesc, - (rocblaslt_matrix_layout)Bdesc, - (rocblaslt_matrix_layout)Cdesc, - (rocblaslt_matrix_layout)Ddesc); + rocblaslt::Debug::Instance().markerStart("hipblasLtMatmulIsTunedCpp"); + auto status = rocblaslt_matmul_is_tuned((rocblaslt_handle)handle, + (rocblaslt_matmul_desc)matmulDesc, + (rocblaslt_matrix_layout)Adesc, + (rocblaslt_matrix_layout)Bdesc, + (rocblaslt_matrix_layout)Cdesc, + (rocblaslt_matrix_layout)Ddesc); + rocblaslt::Debug::Instance().markerStop(); + return status; } } // End of namespace hipblasltext