diff --git a/README.rst b/README.rst index 4ae21c2e..0514ee48 100644 --- a/README.rst +++ b/README.rst @@ -33,6 +33,9 @@ Documentation `onnx-extended Source are available on `github/onnx-extended `_. +Use C++ implementation of existing operators +++++++++++++++++++++++++++++++++++++++++++++ + .. code-block:: python import timeit @@ -89,3 +92,25 @@ Source are available on `github/onnx-extended difference: 0.0 onnx: 0.024006774998269975 onnx-extended: 0.0002316169993719086 + +Build with CUDA, openmp ++++++++++++++++++++++++ + +The package also contains some dummy example on how to +build with C++ functions (`pybind11 `_, +`cython `_), with `openmp +`_, +with or without CUDA. +The build will automatically link with CUDA if it is found. +If not, some extensions might not be available. + +:: + + python setup.py build_ext --inplace + +`NVTX `_ +can be enabled with the following command: + +:: + + python setup.py build_ext --inplace --enable_nvtx 1 diff --git a/_doc/_static/vector_sum6.png b/_doc/_static/vector_sum6.png new file mode 100644 index 00000000..83177433 Binary files /dev/null and b/_doc/_static/vector_sum6.png differ diff --git a/_doc/_static/vector_sum6_results.png b/_doc/_static/vector_sum6_results.png new file mode 100644 index 00000000..5dfa3910 Binary files /dev/null and b/_doc/_static/vector_sum6_results.png differ diff --git a/_doc/api/reference.rst b/_doc/api/reference.rst index 57932005..2f05be19 100644 --- a/_doc/api/reference.rst +++ b/_doc/api/reference.rst @@ -20,6 +20,10 @@ ai.onnx ai.onnx.ml ++++++++++ -.. autoclass:: onnx_extended.reference.c_ops.c_op_tree_ensemble_classifier.TreeEnsembleClassifier +.. autoclass:: onnx_extended.reference.c_ops.c_op_tree_ensemble_classifier.TreeEnsembleClassifier_1 -.. autoclass:: onnx_extended.reference.c_ops.c_op_tree_ensemble_regressor.TreeEnsembleRegresspr +.. autoclass:: onnx_extended.reference.c_ops.c_op_tree_ensemble_classifier.TreeEnsembleClassifier_3 + +.. 
autoclass:: onnx_extended.reference.c_ops.c_op_tree_ensemble_regressor.TreeEnsembleRegressor_1 + +.. autoclass:: onnx_extended.reference.c_ops.c_op_tree_ensemble_regressor.TreeEnsembleRegressor_3 diff --git a/_doc/api/validation.rst b/_doc/api/validation.rst index 8dce6c49..ce2260f6 100644 --- a/_doc/api/validation.rst +++ b/_doc/api/validation.rst @@ -10,32 +10,34 @@ C API _validation +++++++++++ -.. autoclass:: onnx_extended.validation._validation.ElementTime +.. autoclass:: onnx_extended.validation.cpu._validation.ElementTime -.. autofunction:: onnx_extended.validation._validation.benchmark_cache +.. autofunction:: onnx_extended.validation.cpu._validation.benchmark_cache -.. autofunction:: onnx_extended.validation._validation.benchmark_cache_tree +.. autofunction:: onnx_extended.validation.cpu._validation.benchmark_cache_tree -.. autofunction:: onnx_extended.validation._validation.vector_add +.. autofunction:: onnx_extended.validation.cpu._validation.vector_add -.. autofunction:: onnx_extended.validation._validation.vector_sum +.. autofunction:: onnx_extended.validation.cpu._validation.vector_sum -.. autofunction:: onnx_extended.validation._validation.vector_sum_array +.. autofunction:: onnx_extended.validation.cpu._validation.vector_sum_array -.. autofunction:: onnx_extended.validation._validation.vector_sum_array_parallel +.. autofunction:: onnx_extended.validation.cpu._validation.vector_sum_array_parallel -.. autofunction:: onnx_extended.validation._validation.vector_sum_array_avx +.. autofunction:: onnx_extended.validation.cpu._validation.vector_sum_array_avx -.. autofunction:: onnx_extended.validation._validation.vector_sum_array_avx_parallel +.. autofunction:: onnx_extended.validation.cpu._validation.vector_sum_array_avx_parallel cuda_example_py +++++++++++++++ -.. autofunction:: onnx_extended.cuda_example_py.vector_add +.. autofunction:: onnx_extended.validation.cuda.cuda_example_py.vector_add -.. 
autofunction:: onnx_extended.cuda_example_py.vector_sum +.. autofunction:: onnx_extended.validation.cuda.cuda_example_py.vector_sum0 + +.. autofunction:: onnx_extended.validation.cuda.cuda_example_py.vector_sum6 vector_function_cy ++++++++++++++++++ -.. autofunction:: onnx_extended.vector_function_cy.vector_add_c +.. autofunction:: onnx_extended.validation.cython.vector_function_cy.vector_add_c diff --git a/_doc/examples/plot_bench_gpu_vector_sum_gpu.py b/_doc/examples/plot_bench_gpu_vector_sum_gpu.py index 25286097..667e8c1b 100644 --- a/_doc/examples/plot_bench_gpu_vector_sum_gpu.py +++ b/_doc/examples/plot_bench_gpu_vector_sum_gpu.py @@ -23,6 +23,7 @@ try: from onnx_extended.validation.cuda.cuda_example_py import ( vector_sum0, + vector_sum6, vector_sum_atomic, ) except ImportError: @@ -82,29 +83,43 @@ ) ) - diff = abs(vector_sum0(values, 128) - dim**2) - res = measure_time(lambda: vector_sum0(values, 128), max_time=0.5) + diff = abs(vector_sum_atomic(values, 32) - dim**2) + res = measure_time(lambda: vector_sum_atomic(values, 32), max_time=0.5) obs.append( dict( dim=dim, size=values.size, time=res["average"], - direction="0cuda128", + direction="Acuda32", time_per_element=res["average"] / dim**2, diff=diff, ) ) - diff = abs(vector_sum_atomic(values, 32) - dim**2) - res = measure_time(lambda: vector_sum_atomic(values, 32), max_time=0.5) + diff = abs(vector_sum6(values, 32) - dim**2) + res = measure_time(lambda: vector_sum6(values, 32), max_time=0.5) obs.append( dict( dim=dim, size=values.size, time=res["average"], - direction="Acuda32", + direction="6cuda32", + time_per_element=res["average"] / dim**2, + diff=diff, + ) + ) + + diff = abs(vector_sum6(values, 256) - dim**2) + res = measure_time(lambda: vector_sum6(values, 256), max_time=0.5) + + obs.append( + dict( + dim=dim, + size=values.size, + time=res["average"], + direction="6cuda256", time_per_element=res["average"] / dim**2, diff=diff, ) @@ -126,7 +141,45 @@ piv.plot(ax=ax[0], logx=True, title="Comparison 
between two summation") piv_diff.plot(ax=ax[1], logx=True, logy=True, title="Summation errors") piv_time.plot(ax=ax[2], logx=True, logy=True, title="Total time") -fig.savefig("plot_bench_cpu_vector_sum_avx_parallel.png") +fig.savefig("plot_bench_gpu_vector_sum_gpu.png") ############################################## -# AVX is faster. +# The results should look like the following. +# +# .. image:: ../_static/vector_sum6_results.png +# +# AVX is still faster. Let's try to understand why. +# +# Profiling +# +++++++++ +# +# The profiling indicates where the program spends most of its time. +# It shows when the GPU is waiting and when the memory is copied +# from host (CPU) to device (GPU) and the other way around. These are +# the two steps we need to reduce or avoid to make use of the GPU. +# +# Profiling with `nsight-compute `_: +# +# :: +# +# nsys profile --trace=cuda,cudnn,cublas,osrt,nvtx,openmp python your_script.py +# +# If `nsys` fails to find `python`, the command `which python` should locate it. +# `your_script.py` can be `plot_bench_gpu_vector_sum_gpu.py` for example. +# +# Then the command `nsys-ui` starts the visual interface of the profiler. +# A screenshot shows the following after loading the profile. +# +# .. image:: ../_static/vector_sum6.png +# +# Most of the time is spent copying the data from CPU memory to GPU memory. +# In our case, GPU is not really useful because just copying the data from CPU +# to GPU takes more time than processing it with CPU and AVX instructions. +# +# GPU is useful for deep learning because many operations can be chained and +# the data stays on GPU memory until the very end. When multiple tools are involved, +# torch, numpy, onnxruntime, the `DLPack `_ +# avoids copying the data when switching. +# +# The copy of a big tensor can happen by block. The computation may start +# before the data is fully copied.
diff --git a/_doc/tutorial/index.rst b/_doc/tutorial/index.rst index d72b9327..4e4ba4b5 100644 --- a/_doc/tutorial/index.rst +++ b/_doc/tutorial/index.rst @@ -16,7 +16,7 @@ Operators .. toctree:: :maxdepth: 1 - ../autoexemples/plot_conv + ../auto_examples/plot_conv Validation, Experiments +++++++++++++++++++++++ @@ -24,7 +24,8 @@ Validation, Experiments .. toctree:: :maxdepth: 1 - ../autoexemples/plot_bench_cpu - ../autoexemples/plot_bench_cpu_vector_sum - ../autoexemples/plot_bench_cpu_vector_sum_parallel - ../autoexemples/plot_bench_cpu_vector_sum_avx_parallel + ../auto_examples/plot_bench_cpu + ../auto_examples/plot_bench_cpu_vector_sum + ../auto_examples/plot_bench_cpu_vector_sum_parallel + ../auto_examples/plot_bench_cpu_vector_sum_avx_parallel + ../auto_examples/plot_bench_gpu_vector_sum_gpu diff --git a/_unittests/ut_validation/test_vector_cuda.py b/_unittests/ut_validation/test_vector_cuda.py index e9128837..f7b4fc3f 100644 --- a/_unittests/ut_validation/test_vector_cuda.py +++ b/_unittests/ut_validation/test_vector_cuda.py @@ -8,11 +8,13 @@ vector_sum0, vector_add, vector_sum_atomic, + vector_sum6, ) else: vector_sum0 = None vector_add = None vector_sum_atomic = None + vector_sum6 = None class TestVectorCuda(ExtTestCase): @@ -75,6 +77,18 @@ def test_vector_sum_atomic_cuda(self): def test_vector_sum_atomic_cud_bigger(self): values = numpy.random.randn(30, 224, 224).astype(numpy.float32) t = vector_sum_atomic(values) + self.assertAlmostEqual(t, values.sum().astype(numpy.float32), rtol=1e-3) + + @unittest.skipIf(vector_sum6 is None, reason="CUDA not available") + def test_vector_sum6_cuda(self): + values = numpy.array([[10, 1, 4, 5, 6, 7]], dtype=numpy.float32) + t = vector_sum6(values) + self.assertEqual(t, values.sum().astype(numpy.float32)) + + @unittest.skipIf(vector_sum6 is None, reason="CUDA not available") + def test_vector_sum6_cud_bigger(self): + values = numpy.random.randn(30, 224, 224).astype(numpy.float32) + t = vector_sum6(values) 
self.assertAlmostEqual(t, values.sum().astype(numpy.float32), rtol=1e-4) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index fa3b1b29..f98083d6 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -34,7 +34,8 @@ jobs: black --diff . displayName: 'Black' - script: | - cmake-lint cmake/* --disabled-codes C0103 C0113 + cmake-lint cmake/Find* --disabled-codes C0103 C0113 + cmake-lint cmake/CMake* --disabled-codes C0103 C0113 displayName: 'cmake-lint' - script: | # python -m pip install -e . diff --git a/clang_format.sh b/clang_format.sh index 44655232..0e85066f 100644 --- a/clang_format.sh +++ b/clang_format.sh @@ -1,30 +1,6 @@ -clang-format --length 88 -i onnx_extended/reference/c_ops/c_op_common_parallel.hpp -clang-format --length 88 -i onnx_extended/reference/c_ops/c_op_common.cpp -clang-format --length 88 -i onnx_extended/reference/c_ops/c_op_common.h +#!/bin/bash -clang-format --length 88 -i onnx_extended/reference/c_ops/c_op_conv_.cpp -clang-format --length 88 -i onnx_extended/reference/c_ops/c_op_conv_common.h -clang-format --length 88 -i onnx_extended/reference/c_ops/c_op_conv.h - -clang-format --length 88 -i onnx_extended/reference/c_ops/c_op_tree_ensemble_common_.hpp -clang-format --length 88 -i onnx_extended/reference/c_ops/c_op_tree_ensemble_common_agg_.hpp -clang-format --length 88 -i onnx_extended/reference/c_ops/c_op_tree_ensemble_common_classifier_.hpp -clang-format --length 88 -i onnx_extended/reference/c_ops/c_op_tree_ensemble_py_.cpp -clang-format --length 88 -i onnx_extended/reference/c_ops/c_op_tree_ensemble_py_classifier_.hpp - -clang-format --length 88 -i onnx_extended/validation/vector_function.h -clang-format --length 88 -i onnx_extended/validation/vector_function.cpp -clang-format --length 88 -i onnx_extended/validation/vector_sum.h -clang-format --length 88 -i onnx_extended/validation/vector_sum.cpp - -clang-format --length 88 -i onnx_extended/validation/speed_metrics.cpp -clang-format --length 88 -i 
onnx_extended/validation/speed_metrics.h - -clang-format --length 88 -i onnx_extended/validation/_validation.cpp - -clang-format --length 88 -i onnx_extended/validation/cuda_utils.h -clang-format --length 88 -i onnx_extended/validation/cuda_example.h -clang-format --length 88 -i onnx_extended/validation/cuda_example.cpp -clang-format --length 88 -i onnx_extended/validation/cuda_example.cu -clang-format --length 88 -i onnx_extended/validation/cuda_example.cuh -clang-format --length 88 -i onnx_extended/validation/cuda_example_py.cpp +find onnx_extended -type f \( -name "*.h" -o -name "*.hpp" -o -name "*.cuh" -o -name "*.cpp" -o -name "*.cc" -o -name "*.cu" \) | while read f; do + echo "Processing '$f'"; + clang-format --length 88 -i $f; +done diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index cb2186bf..702bbf3f 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -26,6 +26,7 @@ message(STATUS "PYTHON_LIBRARY_DIR=${PYTHON_LIBRARY_DIR}") message(STATUS "PYTHON_NUMPY_INCLUDE_DIR=${PYTHON_NUMPY_INCLUDE_DIR}") message(STATUS "PYTHON_MODULE_EXTENSION=${PYTHON_MODULE_EXTENSION}") message(STATUS "PYTHON_NUMPY_VERSION=${PYTHON_NUMPY_VERSION}") +message(STATUS "USE_NVTX=${USE_NVTX}") message(STATUS "ENV-PATH=$ENV{PATH}") message(STATUS "ENV-PYTHONPATH=$ENV{PYTHONPATH}") @@ -124,6 +125,8 @@ if(CUDA_FOUND) message(STATUS "CUDA_cusparse_LIBRARY=${CUDA_cusparse_LIBRARY}") message(STATUS "CUDA_nvToolsExt_LIBRARY=${CUDA_nvToolsExt_LIBRARY}") message(STATUS "CUDA_OpenCL_LIBRARY=${CUDA_OpenCL_LIBRARY}") + message(STATUS "CUDA NVTX_LINK_C=${NVTX_LINK_C}") + message(STATUS "CUDA NVTX_LINK_CPP=${NVTX_LINK_CPP}") set(CUDA_AVAILABLE 1) else() message(STATUS "Module CudaExtension is not installed.") @@ -182,7 +185,8 @@ if(CUDA_AVAILABLE) cuda_example_py ../onnx_extended/validation/cuda/cuda_example_py.cpp ../onnx_extended/validation/cuda/cuda_example.cpp - ../onnx_extended/validation/cuda/cuda_example.cu) + ../onnx_extended/validation/cuda/cuda_example.cu + 
../onnx_extended/validation/cuda/cuda_example_reduce.cu) else() set(config_content "HAS_CUDA = 0") diff --git a/cmake/CPM.cmake b/cmake/CPM.cmake new file mode 100644 index 00000000..70aebf10 --- /dev/null +++ b/cmake/CPM.cmake @@ -0,0 +1,1154 @@ +# CPM.cmake - CMake's missing package manager +# =========================================== +# See https://github.com/cpm-cmake/CPM.cmake for usage and update instructions. +# +# MIT License +# ----------- +#[[ + Copyright (c) 2019-2022 Lars Melchior and contributors + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. 
+]] + +cmake_minimum_required(VERSION 3.14 FATAL_ERROR) + +# Initialize logging prefix +if(NOT CPM_INDENT) + set(CPM_INDENT + "CPM:" + CACHE INTERNAL "" + ) +endif() + +if(NOT COMMAND cpm_message) + function(cpm_message) + message(${ARGV}) + endfunction() +endif() + +set(CURRENT_CPM_VERSION 0.38.1) + +get_filename_component(CPM_CURRENT_DIRECTORY "${CMAKE_CURRENT_LIST_DIR}" REALPATH) +if(CPM_DIRECTORY) + if(NOT CPM_DIRECTORY STREQUAL CPM_CURRENT_DIRECTORY) + if(CPM_VERSION VERSION_LESS CURRENT_CPM_VERSION) + message( + AUTHOR_WARNING + "${CPM_INDENT} \ +A dependency is using a more recent CPM version (${CURRENT_CPM_VERSION}) than the current project (${CPM_VERSION}). \ +It is recommended to upgrade CPM to the most recent version. \ +See https://github.com/cpm-cmake/CPM.cmake for more information." + ) + endif() + if(${CMAKE_VERSION} VERSION_LESS "3.17.0") + include(FetchContent) + endif() + return() + endif() + + get_property( + CPM_INITIALIZED GLOBAL "" + PROPERTY CPM_INITIALIZED + SET + ) + if(CPM_INITIALIZED) + return() + endif() +endif() + +if(CURRENT_CPM_VERSION MATCHES "development-version") + message( + WARNING "${CPM_INDENT} Your project is using an unstable development version of CPM.cmake. \ +Please update to a recent release if possible. \ +See https://github.com/cpm-cmake/CPM.cmake for details." + ) +endif() + +set_property(GLOBAL PROPERTY CPM_INITIALIZED true) + +macro(cpm_set_policies) + # the policy allows us to change options without caching + cmake_policy(SET CMP0077 NEW) + set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) + + # the policy allows us to change set(CACHE) without caching + if(POLICY CMP0126) + cmake_policy(SET CMP0126 NEW) + set(CMAKE_POLICY_DEFAULT_CMP0126 NEW) + endif() + + # The policy uses the download time for timestamp, instead of the timestamp in the archive. 
This + # allows for proper rebuilds when a projects url changes + if(POLICY CMP0135) + cmake_policy(SET CMP0135 NEW) + set(CMAKE_POLICY_DEFAULT_CMP0135 NEW) + endif() +endmacro() +cpm_set_policies() + +option(CPM_USE_LOCAL_PACKAGES "Always try to use `find_package` to get dependencies" + $ENV{CPM_USE_LOCAL_PACKAGES} +) +option(CPM_LOCAL_PACKAGES_ONLY "Only use `find_package` to get dependencies" + $ENV{CPM_LOCAL_PACKAGES_ONLY} +) +option(CPM_DOWNLOAD_ALL "Always download dependencies from source" $ENV{CPM_DOWNLOAD_ALL}) +option(CPM_DONT_UPDATE_MODULE_PATH "Don't update the module path to allow using find_package" + $ENV{CPM_DONT_UPDATE_MODULE_PATH} +) +option(CPM_DONT_CREATE_PACKAGE_LOCK "Don't create a package lock file in the binary path" + $ENV{CPM_DONT_CREATE_PACKAGE_LOCK} +) +option(CPM_INCLUDE_ALL_IN_PACKAGE_LOCK + "Add all packages added through CPM.cmake to the package lock" + $ENV{CPM_INCLUDE_ALL_IN_PACKAGE_LOCK} +) +option(CPM_USE_NAMED_CACHE_DIRECTORIES + "Use additional directory of package name in cache on the most nested level." 
+ $ENV{CPM_USE_NAMED_CACHE_DIRECTORIES} +) + +set(CPM_VERSION + ${CURRENT_CPM_VERSION} + CACHE INTERNAL "" +) +set(CPM_DIRECTORY + ${CPM_CURRENT_DIRECTORY} + CACHE INTERNAL "" +) +set(CPM_FILE + ${CMAKE_CURRENT_LIST_FILE} + CACHE INTERNAL "" +) +set(CPM_PACKAGES + "" + CACHE INTERNAL "" +) +set(CPM_DRY_RUN + OFF + CACHE INTERNAL "Don't download or configure dependencies (for testing)" +) + +if(DEFINED ENV{CPM_SOURCE_CACHE}) + set(CPM_SOURCE_CACHE_DEFAULT $ENV{CPM_SOURCE_CACHE}) +else() + set(CPM_SOURCE_CACHE_DEFAULT OFF) +endif() + +set(CPM_SOURCE_CACHE + ${CPM_SOURCE_CACHE_DEFAULT} + CACHE PATH "Directory to download CPM dependencies" +) + +if(NOT CPM_DONT_UPDATE_MODULE_PATH) + set(CPM_MODULE_PATH + "${CMAKE_BINARY_DIR}/CPM_modules" + CACHE INTERNAL "" + ) + # remove old modules + file(REMOVE_RECURSE ${CPM_MODULE_PATH}) + file(MAKE_DIRECTORY ${CPM_MODULE_PATH}) + # locally added CPM modules should override global packages + set(CMAKE_MODULE_PATH "${CPM_MODULE_PATH};${CMAKE_MODULE_PATH}") +endif() + +if(NOT CPM_DONT_CREATE_PACKAGE_LOCK) + set(CPM_PACKAGE_LOCK_FILE + "${CMAKE_BINARY_DIR}/cpm-package-lock.cmake" + CACHE INTERNAL "" + ) + file(WRITE ${CPM_PACKAGE_LOCK_FILE} + "# CPM Package Lock\n# This file should be committed to version control\n\n" + ) +endif() + +include(FetchContent) + +# Try to infer package name from git repository uri (path or url) +function(cpm_package_name_from_git_uri URI RESULT) + if("${URI}" MATCHES "([^/:]+)/?.git/?$") + set(${RESULT} + ${CMAKE_MATCH_1} + PARENT_SCOPE + ) + else() + unset(${RESULT} PARENT_SCOPE) + endif() +endfunction() + +# Try to infer package name and version from a url +function(cpm_package_name_and_ver_from_url url outName outVer) + if(url MATCHES "[/\\?]([a-zA-Z0-9_\\.-]+)\\.(tar|tar\\.gz|tar\\.bz2|zip|ZIP)(\\?|/|$)") + # We matched an archive + set(filename "${CMAKE_MATCH_1}") + + if(filename MATCHES "([a-zA-Z0-9_\\.-]+)[_-]v?(([0-9]+\\.)*[0-9]+[a-zA-Z0-9]*)") + # We matched - (ie foo-1.2.3) + set(${outName} + 
"${CMAKE_MATCH_1}" + PARENT_SCOPE + ) + set(${outVer} + "${CMAKE_MATCH_2}" + PARENT_SCOPE + ) + elseif(filename MATCHES "(([0-9]+\\.)+[0-9]+[a-zA-Z0-9]*)") + # We couldn't find a name, but we found a version + # + # In many cases (which we don't handle here) the url would look something like + # `irrelevant/ACTUAL_PACKAGE_NAME/irrelevant/1.2.3.zip`. In such a case we can't possibly + # distinguish the package name from the irrelevant bits. Moreover if we try to match the + # package name from the filename, we'd get bogus at best. + unset(${outName} PARENT_SCOPE) + set(${outVer} + "${CMAKE_MATCH_1}" + PARENT_SCOPE + ) + else() + # Boldly assume that the file name is the package name. + # + # Yes, something like `irrelevant/ACTUAL_NAME/irrelevant/download.zip` will ruin our day, but + # such cases should be quite rare. No popular service does this... we think. + set(${outName} + "${filename}" + PARENT_SCOPE + ) + unset(${outVer} PARENT_SCOPE) + endif() + else() + # No ideas yet what to do with non-archives + unset(${outName} PARENT_SCOPE) + unset(${outVer} PARENT_SCOPE) + endif() +endfunction() + +function(cpm_find_package NAME VERSION) + string(REPLACE " " ";" EXTRA_ARGS "${ARGN}") + find_package(${NAME} ${VERSION} ${EXTRA_ARGS} QUIET) + if(${CPM_ARGS_NAME}_FOUND) + if(DEFINED ${CPM_ARGS_NAME}_VERSION) + set(VERSION ${${CPM_ARGS_NAME}_VERSION}) + endif() + cpm_message(STATUS "${CPM_INDENT} Using local package ${CPM_ARGS_NAME}@${VERSION}") + CPMRegisterPackage(${CPM_ARGS_NAME} "${VERSION}") + set(CPM_PACKAGE_FOUND + YES + PARENT_SCOPE + ) + else() + set(CPM_PACKAGE_FOUND + NO + PARENT_SCOPE + ) + endif() +endfunction() + +# Create a custom FindXXX.cmake module for a CPM package This prevents `find_package(NAME)` from +# finding the system library +function(cpm_create_module_file Name) + if(NOT CPM_DONT_UPDATE_MODULE_PATH) + # erase any previous modules + file(WRITE ${CPM_MODULE_PATH}/Find${Name}.cmake + "include(\"${CPM_FILE}\")\n${ARGN}\nset(${Name}_FOUND TRUE)" + ) 
+ endif() +endfunction() + +# Find a package locally or fallback to CPMAddPackage +function(CPMFindPackage) + set(oneValueArgs NAME VERSION GIT_TAG FIND_PACKAGE_ARGUMENTS) + + cmake_parse_arguments(CPM_ARGS "" "${oneValueArgs}" "" ${ARGN}) + + if(NOT DEFINED CPM_ARGS_VERSION) + if(DEFINED CPM_ARGS_GIT_TAG) + cpm_get_version_from_git_tag("${CPM_ARGS_GIT_TAG}" CPM_ARGS_VERSION) + endif() + endif() + + set(downloadPackage ${CPM_DOWNLOAD_ALL}) + if(DEFINED CPM_DOWNLOAD_${CPM_ARGS_NAME}) + set(downloadPackage ${CPM_DOWNLOAD_${CPM_ARGS_NAME}}) + elseif(DEFINED ENV{CPM_DOWNLOAD_${CPM_ARGS_NAME}}) + set(downloadPackage $ENV{CPM_DOWNLOAD_${CPM_ARGS_NAME}}) + endif() + if(downloadPackage) + CPMAddPackage(${ARGN}) + cpm_export_variables(${CPM_ARGS_NAME}) + return() + endif() + + cpm_check_if_package_already_added(${CPM_ARGS_NAME} "${CPM_ARGS_VERSION}") + if(CPM_PACKAGE_ALREADY_ADDED) + cpm_export_variables(${CPM_ARGS_NAME}) + return() + endif() + + cpm_find_package(${CPM_ARGS_NAME} "${CPM_ARGS_VERSION}" ${CPM_ARGS_FIND_PACKAGE_ARGUMENTS}) + + if(NOT CPM_PACKAGE_FOUND) + CPMAddPackage(${ARGN}) + cpm_export_variables(${CPM_ARGS_NAME}) + endif() + +endfunction() + +# checks if a package has been added before +function(cpm_check_if_package_already_added CPM_ARGS_NAME CPM_ARGS_VERSION) + if("${CPM_ARGS_NAME}" IN_LIST CPM_PACKAGES) + CPMGetPackageVersion(${CPM_ARGS_NAME} CPM_PACKAGE_VERSION) + if("${CPM_PACKAGE_VERSION}" VERSION_LESS "${CPM_ARGS_VERSION}") + message( + WARNING + "${CPM_INDENT} Requires a newer version of ${CPM_ARGS_NAME} (${CPM_ARGS_VERSION}) than currently included (${CPM_PACKAGE_VERSION})." 
+ ) + endif() + cpm_get_fetch_properties(${CPM_ARGS_NAME}) + set(${CPM_ARGS_NAME}_ADDED NO) + set(CPM_PACKAGE_ALREADY_ADDED + YES + PARENT_SCOPE + ) + cpm_export_variables(${CPM_ARGS_NAME}) + else() + set(CPM_PACKAGE_ALREADY_ADDED + NO + PARENT_SCOPE + ) + endif() +endfunction() + +# Parse the argument of CPMAddPackage in case a single one was provided and convert it to a list of +# arguments which can then be parsed idiomatically. For example gh:foo/bar@1.2.3 will be converted +# to: GITHUB_REPOSITORY;foo/bar;VERSION;1.2.3 +function(cpm_parse_add_package_single_arg arg outArgs) + # Look for a scheme + if("${arg}" MATCHES "^([a-zA-Z]+):(.+)$") + string(TOLOWER "${CMAKE_MATCH_1}" scheme) + set(uri "${CMAKE_MATCH_2}") + + # Check for CPM-specific schemes + if(scheme STREQUAL "gh") + set(out "GITHUB_REPOSITORY;${uri}") + set(packageType "git") + elseif(scheme STREQUAL "gl") + set(out "GITLAB_REPOSITORY;${uri}") + set(packageType "git") + elseif(scheme STREQUAL "bb") + set(out "BITBUCKET_REPOSITORY;${uri}") + set(packageType "git") + # A CPM-specific scheme was not found. Looks like this is a generic URL so try to determine + # type + elseif(arg MATCHES ".git/?(@|#|$)") + set(out "GIT_REPOSITORY;${arg}") + set(packageType "git") + else() + # Fall back to a URL + set(out "URL;${arg}") + set(packageType "archive") + + # We could also check for SVN since FetchContent supports it, but SVN is so rare these days. + # We just won't bother with the additional complexity it will induce in this function. SVN is + # done by multi-arg + endif() + else() + if(arg MATCHES ".git/?(@|#|$)") + set(out "GIT_REPOSITORY;${arg}") + set(packageType "git") + else() + # Give up + message(FATAL_ERROR "${CPM_INDENT} Can't determine package type of '${arg}'") + endif() + endif() + + # For all packages we interpret @... as version. Only replace the last occurrence. 
Thus URIs + # containing '@' can be used + string(REGEX REPLACE "@([^@]+)$" ";VERSION;\\1" out "${out}") + + # Parse the rest according to package type + if(packageType STREQUAL "git") + # For git repos we interpret #... as a tag or branch or commit hash + string(REGEX REPLACE "#([^#]+)$" ";GIT_TAG;\\1" out "${out}") + elseif(packageType STREQUAL "archive") + # For archives we interpret #... as a URL hash. + string(REGEX REPLACE "#([^#]+)$" ";URL_HASH;\\1" out "${out}") + # We don't try to parse the version if it's not provided explicitly. cpm_get_version_from_url + # should do this at a later point + else() + # We should never get here. This is an assertion and hitting it means there's a bug in the code + # above. A packageType was set, but not handled by this if-else. + message(FATAL_ERROR "${CPM_INDENT} Unsupported package type '${packageType}' of '${arg}'") + endif() + + set(${outArgs} + ${out} + PARENT_SCOPE + ) +endfunction() + +# Check that the working directory for a git repo is clean +function(cpm_check_git_working_dir_is_clean repoPath gitTag isClean) + + find_package(Git REQUIRED) + + if(NOT GIT_EXECUTABLE) + # No git executable, assume directory is clean + set(${isClean} + TRUE + PARENT_SCOPE + ) + return() + endif() + + # check for uncommitted changes + execute_process( + COMMAND ${GIT_EXECUTABLE} status --porcelain + RESULT_VARIABLE resultGitStatus + OUTPUT_VARIABLE repoStatus + OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET + WORKING_DIRECTORY ${repoPath} + ) + if(resultGitStatus) + # not supposed to happen, assume clean anyway + message(WARNING "${CPM_INDENT} Calling git status on folder ${repoPath} failed") + set(${isClean} + TRUE + PARENT_SCOPE + ) + return() + endif() + + if(NOT "${repoStatus}" STREQUAL "") + set(${isClean} + FALSE + PARENT_SCOPE + ) + return() + endif() + + # check for committed changes + execute_process( + COMMAND ${GIT_EXECUTABLE} diff -s --exit-code ${gitTag} + RESULT_VARIABLE resultGitDiff + OUTPUT_STRIP_TRAILING_WHITESPACE 
OUTPUT_QUIET + WORKING_DIRECTORY ${repoPath} + ) + + if(${resultGitDiff} EQUAL 0) + set(${isClean} + TRUE + PARENT_SCOPE + ) + else() + set(${isClean} + FALSE + PARENT_SCOPE + ) + endif() + +endfunction() + +# method to overwrite internal FetchContent properties, to allow using CPM.cmake to overload +# FetchContent calls. As these are internal cmake properties, this method should be used carefully +# and may need modification in future CMake versions. Source: +# https://github.com/Kitware/CMake/blob/dc3d0b5a0a7d26d43d6cfeb511e224533b5d188f/Modules/FetchContent.cmake#L1152 +function(cpm_override_fetchcontent contentName) + cmake_parse_arguments(PARSE_ARGV 1 arg "" "SOURCE_DIR;BINARY_DIR" "") + if(NOT "${arg_UNPARSED_ARGUMENTS}" STREQUAL "") + message(FATAL_ERROR "${CPM_INDENT} Unsupported arguments: ${arg_UNPARSED_ARGUMENTS}") + endif() + + string(TOLOWER ${contentName} contentNameLower) + set(prefix "_FetchContent_${contentNameLower}") + + set(propertyName "${prefix}_sourceDir") + define_property( + GLOBAL + PROPERTY ${propertyName} + BRIEF_DOCS "Internal implementation detail of FetchContent_Populate()" + FULL_DOCS "Details used by FetchContent_Populate() for ${contentName}" + ) + set_property(GLOBAL PROPERTY ${propertyName} "${arg_SOURCE_DIR}") + + set(propertyName "${prefix}_binaryDir") + define_property( + GLOBAL + PROPERTY ${propertyName} + BRIEF_DOCS "Internal implementation detail of FetchContent_Populate()" + FULL_DOCS "Details used by FetchContent_Populate() for ${contentName}" + ) + set_property(GLOBAL PROPERTY ${propertyName} "${arg_BINARY_DIR}") + + set(propertyName "${prefix}_populated") + define_property( + GLOBAL + PROPERTY ${propertyName} + BRIEF_DOCS "Internal implementation detail of FetchContent_Populate()" + FULL_DOCS "Details used by FetchContent_Populate() for ${contentName}" + ) + set_property(GLOBAL PROPERTY ${propertyName} TRUE) +endfunction() + +# Download and add a package from source +function(CPMAddPackage) + cpm_set_policies() + + 
list(LENGTH ARGN argnLength) + if(argnLength EQUAL 1) + cpm_parse_add_package_single_arg("${ARGN}" ARGN) + + # The shorthand syntax implies EXCLUDE_FROM_ALL and SYSTEM + set(ARGN "${ARGN};EXCLUDE_FROM_ALL;YES;SYSTEM;YES;") + endif() + + set(oneValueArgs + NAME + FORCE + VERSION + GIT_TAG + DOWNLOAD_ONLY + GITHUB_REPOSITORY + GITLAB_REPOSITORY + BITBUCKET_REPOSITORY + GIT_REPOSITORY + SOURCE_DIR + DOWNLOAD_COMMAND + FIND_PACKAGE_ARGUMENTS + NO_CACHE + SYSTEM + GIT_SHALLOW + EXCLUDE_FROM_ALL + SOURCE_SUBDIR + ) + + set(multiValueArgs URL OPTIONS) + + cmake_parse_arguments(CPM_ARGS "" "${oneValueArgs}" "${multiValueArgs}" "${ARGN}") + + # Set default values for arguments + + if(NOT DEFINED CPM_ARGS_VERSION) + if(DEFINED CPM_ARGS_GIT_TAG) + cpm_get_version_from_git_tag("${CPM_ARGS_GIT_TAG}" CPM_ARGS_VERSION) + endif() + endif() + + if(CPM_ARGS_DOWNLOAD_ONLY) + set(DOWNLOAD_ONLY ${CPM_ARGS_DOWNLOAD_ONLY}) + else() + set(DOWNLOAD_ONLY NO) + endif() + + if(DEFINED CPM_ARGS_GITHUB_REPOSITORY) + set(CPM_ARGS_GIT_REPOSITORY "https://github.com/${CPM_ARGS_GITHUB_REPOSITORY}.git") + elseif(DEFINED CPM_ARGS_GITLAB_REPOSITORY) + set(CPM_ARGS_GIT_REPOSITORY "https://gitlab.com/${CPM_ARGS_GITLAB_REPOSITORY}.git") + elseif(DEFINED CPM_ARGS_BITBUCKET_REPOSITORY) + set(CPM_ARGS_GIT_REPOSITORY "https://bitbucket.org/${CPM_ARGS_BITBUCKET_REPOSITORY}.git") + endif() + + if(DEFINED CPM_ARGS_GIT_REPOSITORY) + list(APPEND CPM_ARGS_UNPARSED_ARGUMENTS GIT_REPOSITORY ${CPM_ARGS_GIT_REPOSITORY}) + if(NOT DEFINED CPM_ARGS_GIT_TAG) + set(CPM_ARGS_GIT_TAG v${CPM_ARGS_VERSION}) + endif() + + # If a name wasn't provided, try to infer it from the git repo + if(NOT DEFINED CPM_ARGS_NAME) + cpm_package_name_from_git_uri(${CPM_ARGS_GIT_REPOSITORY} CPM_ARGS_NAME) + endif() + endif() + + set(CPM_SKIP_FETCH FALSE) + + if(DEFINED CPM_ARGS_GIT_TAG) + list(APPEND CPM_ARGS_UNPARSED_ARGUMENTS GIT_TAG ${CPM_ARGS_GIT_TAG}) + # If GIT_SHALLOW is explicitly specified, honor the value. 
+ if(DEFINED CPM_ARGS_GIT_SHALLOW) + list(APPEND CPM_ARGS_UNPARSED_ARGUMENTS GIT_SHALLOW ${CPM_ARGS_GIT_SHALLOW}) + endif() + endif() + + if(DEFINED CPM_ARGS_URL) + # If a name or version aren't provided, try to infer them from the URL + list(GET CPM_ARGS_URL 0 firstUrl) + cpm_package_name_and_ver_from_url(${firstUrl} nameFromUrl verFromUrl) + # If we fail to obtain name and version from the first URL, we could try other URLs if any. + # However multiple URLs are expected to be quite rare, so for now we won't bother. + + # If the caller provided their own name and version, they trump the inferred ones. + if(NOT DEFINED CPM_ARGS_NAME) + set(CPM_ARGS_NAME ${nameFromUrl}) + endif() + if(NOT DEFINED CPM_ARGS_VERSION) + set(CPM_ARGS_VERSION ${verFromUrl}) + endif() + + list(APPEND CPM_ARGS_UNPARSED_ARGUMENTS URL "${CPM_ARGS_URL}") + endif() + + # Check for required arguments + + if(NOT DEFINED CPM_ARGS_NAME) + message( + FATAL_ERROR + "${CPM_INDENT} 'NAME' was not provided and couldn't be automatically inferred for package added with arguments: '${ARGN}'" + ) + endif() + + # Check if package has been added before + cpm_check_if_package_already_added(${CPM_ARGS_NAME} "${CPM_ARGS_VERSION}") + if(CPM_PACKAGE_ALREADY_ADDED) + cpm_export_variables(${CPM_ARGS_NAME}) + return() + endif() + + # Check for manual overrides + if(NOT CPM_ARGS_FORCE AND NOT "${CPM_${CPM_ARGS_NAME}_SOURCE}" STREQUAL "") + set(PACKAGE_SOURCE ${CPM_${CPM_ARGS_NAME}_SOURCE}) + set(CPM_${CPM_ARGS_NAME}_SOURCE "") + CPMAddPackage( + NAME "${CPM_ARGS_NAME}" + SOURCE_DIR "${PACKAGE_SOURCE}" + EXCLUDE_FROM_ALL "${CPM_ARGS_EXCLUDE_FROM_ALL}" + SYSTEM "${CPM_ARGS_SYSTEM}" + OPTIONS "${CPM_ARGS_OPTIONS}" + SOURCE_SUBDIR "${CPM_ARGS_SOURCE_SUBDIR}" + DOWNLOAD_ONLY "${DOWNLOAD_ONLY}" + FORCE True + ) + cpm_export_variables(${CPM_ARGS_NAME}) + return() + endif() + + # Check for available declaration + if(NOT CPM_ARGS_FORCE AND NOT "${CPM_DECLARATION_${CPM_ARGS_NAME}}" STREQUAL "") + set(declaration 
${CPM_DECLARATION_${CPM_ARGS_NAME}}) + set(CPM_DECLARATION_${CPM_ARGS_NAME} "") + CPMAddPackage(${declaration}) + cpm_export_variables(${CPM_ARGS_NAME}) + # checking again to ensure version and option compatibility + cpm_check_if_package_already_added(${CPM_ARGS_NAME} "${CPM_ARGS_VERSION}") + return() + endif() + + if(NOT CPM_ARGS_FORCE) + if(CPM_USE_LOCAL_PACKAGES OR CPM_LOCAL_PACKAGES_ONLY) + cpm_find_package(${CPM_ARGS_NAME} "${CPM_ARGS_VERSION}" ${CPM_ARGS_FIND_PACKAGE_ARGUMENTS}) + + if(CPM_PACKAGE_FOUND) + cpm_export_variables(${CPM_ARGS_NAME}) + return() + endif() + + if(CPM_LOCAL_PACKAGES_ONLY) + message( + SEND_ERROR + "${CPM_INDENT} ${CPM_ARGS_NAME} not found via find_package(${CPM_ARGS_NAME} ${CPM_ARGS_VERSION})" + ) + endif() + endif() + endif() + + CPMRegisterPackage("${CPM_ARGS_NAME}" "${CPM_ARGS_VERSION}") + + if(DEFINED CPM_ARGS_GIT_TAG) + set(PACKAGE_INFO "${CPM_ARGS_GIT_TAG}") + elseif(DEFINED CPM_ARGS_SOURCE_DIR) + set(PACKAGE_INFO "${CPM_ARGS_SOURCE_DIR}") + else() + set(PACKAGE_INFO "${CPM_ARGS_VERSION}") + endif() + + if(DEFINED FETCHCONTENT_BASE_DIR) + # respect user's FETCHCONTENT_BASE_DIR if set + set(CPM_FETCHCONTENT_BASE_DIR ${FETCHCONTENT_BASE_DIR}) + else() + set(CPM_FETCHCONTENT_BASE_DIR ${CMAKE_BINARY_DIR}/_deps) + endif() + + if(DEFINED CPM_ARGS_DOWNLOAD_COMMAND) + list(APPEND CPM_ARGS_UNPARSED_ARGUMENTS DOWNLOAD_COMMAND ${CPM_ARGS_DOWNLOAD_COMMAND}) + elseif(DEFINED CPM_ARGS_SOURCE_DIR) + list(APPEND CPM_ARGS_UNPARSED_ARGUMENTS SOURCE_DIR ${CPM_ARGS_SOURCE_DIR}) + if(NOT IS_ABSOLUTE ${CPM_ARGS_SOURCE_DIR}) + # Expand `CPM_ARGS_SOURCE_DIR` relative path. This is important because EXISTS doesn't work + # for relative paths. 
+ get_filename_component( + source_directory ${CPM_ARGS_SOURCE_DIR} REALPATH BASE_DIR ${CMAKE_CURRENT_BINARY_DIR} + ) + else() + set(source_directory ${CPM_ARGS_SOURCE_DIR}) + endif() + if(NOT EXISTS ${source_directory}) + string(TOLOWER ${CPM_ARGS_NAME} lower_case_name) + # remove timestamps so CMake will re-download the dependency + file(REMOVE_RECURSE "${CPM_FETCHCONTENT_BASE_DIR}/${lower_case_name}-subbuild") + endif() + elseif(CPM_SOURCE_CACHE AND NOT CPM_ARGS_NO_CACHE) + string(TOLOWER ${CPM_ARGS_NAME} lower_case_name) + set(origin_parameters ${CPM_ARGS_UNPARSED_ARGUMENTS}) + list(SORT origin_parameters) + if(CPM_USE_NAMED_CACHE_DIRECTORIES) + string(SHA1 origin_hash "${origin_parameters};NEW_CACHE_STRUCTURE_TAG") + set(download_directory ${CPM_SOURCE_CACHE}/${lower_case_name}/${origin_hash}/${CPM_ARGS_NAME}) + else() + string(SHA1 origin_hash "${origin_parameters}") + set(download_directory ${CPM_SOURCE_CACHE}/${lower_case_name}/${origin_hash}) + endif() + # Expand `download_directory` relative path. This is important because EXISTS doesn't work for + # relative paths. 
+ get_filename_component(download_directory ${download_directory} ABSOLUTE) + list(APPEND CPM_ARGS_UNPARSED_ARGUMENTS SOURCE_DIR ${download_directory}) + + if(CPM_SOURCE_CACHE) + file(LOCK ${download_directory}/../cmake.lock) + endif() + + if(EXISTS ${download_directory}) + if(CPM_SOURCE_CACHE) + file(LOCK ${download_directory}/../cmake.lock RELEASE) + endif() + + cpm_store_fetch_properties( + ${CPM_ARGS_NAME} "${download_directory}" + "${CPM_FETCHCONTENT_BASE_DIR}/${lower_case_name}-build" + ) + cpm_get_fetch_properties("${CPM_ARGS_NAME}") + + if(DEFINED CPM_ARGS_GIT_TAG AND NOT (PATCH_COMMAND IN_LIST CPM_ARGS_UNPARSED_ARGUMENTS)) + # warn if cache has been changed since checkout + cpm_check_git_working_dir_is_clean(${download_directory} ${CPM_ARGS_GIT_TAG} IS_CLEAN) + if(NOT ${IS_CLEAN}) + message( + WARNING "${CPM_INDENT} Cache for ${CPM_ARGS_NAME} (${download_directory}) is dirty" + ) + endif() + endif() + + cpm_add_subdirectory( + "${CPM_ARGS_NAME}" + "${DOWNLOAD_ONLY}" + "${${CPM_ARGS_NAME}_SOURCE_DIR}/${CPM_ARGS_SOURCE_SUBDIR}" + "${${CPM_ARGS_NAME}_BINARY_DIR}" + "${CPM_ARGS_EXCLUDE_FROM_ALL}" + "${CPM_ARGS_SYSTEM}" + "${CPM_ARGS_OPTIONS}" + ) + set(PACKAGE_INFO "${PACKAGE_INFO} at ${download_directory}") + + # As the source dir is already cached/populated, we override the call to FetchContent. + set(CPM_SKIP_FETCH TRUE) + cpm_override_fetchcontent( + "${lower_case_name}" SOURCE_DIR "${${CPM_ARGS_NAME}_SOURCE_DIR}/${CPM_ARGS_SOURCE_SUBDIR}" + BINARY_DIR "${${CPM_ARGS_NAME}_BINARY_DIR}" + ) + + else() + # Enable shallow clone when GIT_TAG is not a commit hash. Our guess may not be accurate, but + # it should guarantee no commit hash get mis-detected. 
+ if(NOT DEFINED CPM_ARGS_GIT_SHALLOW) + cpm_is_git_tag_commit_hash("${CPM_ARGS_GIT_TAG}" IS_HASH) + if(NOT ${IS_HASH}) + list(APPEND CPM_ARGS_UNPARSED_ARGUMENTS GIT_SHALLOW TRUE) + endif() + endif() + + # remove timestamps so CMake will re-download the dependency + file(REMOVE_RECURSE ${CPM_FETCHCONTENT_BASE_DIR}/${lower_case_name}-subbuild) + set(PACKAGE_INFO "${PACKAGE_INFO} to ${download_directory}") + endif() + endif() + + cpm_create_module_file(${CPM_ARGS_NAME} "CPMAddPackage(\"${ARGN}\")") + + if(CPM_PACKAGE_LOCK_ENABLED) + if((CPM_ARGS_VERSION AND NOT CPM_ARGS_SOURCE_DIR) OR CPM_INCLUDE_ALL_IN_PACKAGE_LOCK) + cpm_add_to_package_lock(${CPM_ARGS_NAME} "${ARGN}") + elseif(CPM_ARGS_SOURCE_DIR) + cpm_add_comment_to_package_lock(${CPM_ARGS_NAME} "local directory") + else() + cpm_add_comment_to_package_lock(${CPM_ARGS_NAME} "${ARGN}") + endif() + endif() + + cpm_message( + STATUS "${CPM_INDENT} Adding package ${CPM_ARGS_NAME}@${CPM_ARGS_VERSION} (${PACKAGE_INFO})" + ) + + if(NOT CPM_SKIP_FETCH) + cpm_declare_fetch( + "${CPM_ARGS_NAME}" "${CPM_ARGS_VERSION}" "${PACKAGE_INFO}" "${CPM_ARGS_UNPARSED_ARGUMENTS}" + ) + cpm_fetch_package("${CPM_ARGS_NAME}" populated) + if(CPM_CACHE_SOURCE AND download_directory) + file(LOCK ${download_directory}/../cmake.lock RELEASE) + endif() + if(${populated}) + cpm_add_subdirectory( + "${CPM_ARGS_NAME}" + "${DOWNLOAD_ONLY}" + "${${CPM_ARGS_NAME}_SOURCE_DIR}/${CPM_ARGS_SOURCE_SUBDIR}" + "${${CPM_ARGS_NAME}_BINARY_DIR}" + "${CPM_ARGS_EXCLUDE_FROM_ALL}" + "${CPM_ARGS_SYSTEM}" + "${CPM_ARGS_OPTIONS}" + ) + endif() + cpm_get_fetch_properties("${CPM_ARGS_NAME}") + endif() + + set(${CPM_ARGS_NAME}_ADDED YES) + cpm_export_variables("${CPM_ARGS_NAME}") +endfunction() + +# Fetch a previously declared package +macro(CPMGetPackage Name) + if(DEFINED "CPM_DECLARATION_${Name}") + CPMAddPackage(NAME ${Name}) + else() + message(SEND_ERROR "${CPM_INDENT} Cannot retrieve package ${Name}: no declaration available") + endif() +endmacro() + +# export 
variables available to the caller to the parent scope expects ${CPM_ARGS_NAME} to be set +macro(cpm_export_variables name) + set(${name}_SOURCE_DIR + "${${name}_SOURCE_DIR}" + PARENT_SCOPE + ) + set(${name}_BINARY_DIR + "${${name}_BINARY_DIR}" + PARENT_SCOPE + ) + set(${name}_ADDED + "${${name}_ADDED}" + PARENT_SCOPE + ) + set(CPM_LAST_PACKAGE_NAME + "${name}" + PARENT_SCOPE + ) +endmacro() + +# declares a package, so that any call to CPMAddPackage for the package name will use these +# arguments instead. Previous declarations will not be overridden. +macro(CPMDeclarePackage Name) + if(NOT DEFINED "CPM_DECLARATION_${Name}") + set("CPM_DECLARATION_${Name}" "${ARGN}") + endif() +endmacro() + +function(cpm_add_to_package_lock Name) + if(NOT CPM_DONT_CREATE_PACKAGE_LOCK) + cpm_prettify_package_arguments(PRETTY_ARGN false ${ARGN}) + file(APPEND ${CPM_PACKAGE_LOCK_FILE} "# ${Name}\nCPMDeclarePackage(${Name}\n${PRETTY_ARGN})\n") + endif() +endfunction() + +function(cpm_add_comment_to_package_lock Name) + if(NOT CPM_DONT_CREATE_PACKAGE_LOCK) + cpm_prettify_package_arguments(PRETTY_ARGN true ${ARGN}) + file(APPEND ${CPM_PACKAGE_LOCK_FILE} + "# ${Name} (unversioned)\n# CPMDeclarePackage(${Name}\n${PRETTY_ARGN}#)\n" + ) + endif() +endfunction() + +# includes the package lock file if it exists and creates a target `cpm-update-package-lock` to +# update it +macro(CPMUsePackageLock file) + if(NOT CPM_DONT_CREATE_PACKAGE_LOCK) + get_filename_component(CPM_ABSOLUTE_PACKAGE_LOCK_PATH ${file} ABSOLUTE) + if(EXISTS ${CPM_ABSOLUTE_PACKAGE_LOCK_PATH}) + include(${CPM_ABSOLUTE_PACKAGE_LOCK_PATH}) + endif() + if(NOT TARGET cpm-update-package-lock) + add_custom_target( + cpm-update-package-lock COMMAND ${CMAKE_COMMAND} -E copy ${CPM_PACKAGE_LOCK_FILE} + ${CPM_ABSOLUTE_PACKAGE_LOCK_PATH} + ) + endif() + set(CPM_PACKAGE_LOCK_ENABLED true) + endif() +endmacro() + +# registers a package that has been added to CPM +function(CPMRegisterPackage PACKAGE VERSION) + list(APPEND CPM_PACKAGES 
${PACKAGE}) + set(CPM_PACKAGES + ${CPM_PACKAGES} + CACHE INTERNAL "" + ) + set("CPM_PACKAGE_${PACKAGE}_VERSION" + ${VERSION} + CACHE INTERNAL "" + ) +endfunction() + +# retrieve the current version of the package to ${OUTPUT} +function(CPMGetPackageVersion PACKAGE OUTPUT) + set(${OUTPUT} + "${CPM_PACKAGE_${PACKAGE}_VERSION}" + PARENT_SCOPE + ) +endfunction() + +# declares a package in FetchContent_Declare +function(cpm_declare_fetch PACKAGE VERSION INFO) + if(${CPM_DRY_RUN}) + cpm_message(STATUS "${CPM_INDENT} Package not declared (dry run)") + return() + endif() + + FetchContent_Declare(${PACKAGE} ${ARGN}) +endfunction() + +# returns properties for a package previously defined by cpm_declare_fetch +function(cpm_get_fetch_properties PACKAGE) + if(${CPM_DRY_RUN}) + return() + endif() + + set(${PACKAGE}_SOURCE_DIR + "${CPM_PACKAGE_${PACKAGE}_SOURCE_DIR}" + PARENT_SCOPE + ) + set(${PACKAGE}_BINARY_DIR + "${CPM_PACKAGE_${PACKAGE}_BINARY_DIR}" + PARENT_SCOPE + ) +endfunction() + +function(cpm_store_fetch_properties PACKAGE source_dir binary_dir) + if(${CPM_DRY_RUN}) + return() + endif() + + set(CPM_PACKAGE_${PACKAGE}_SOURCE_DIR + "${source_dir}" + CACHE INTERNAL "" + ) + set(CPM_PACKAGE_${PACKAGE}_BINARY_DIR + "${binary_dir}" + CACHE INTERNAL "" + ) +endfunction() + +# adds a package as a subdirectory if viable, according to provided options +function( + cpm_add_subdirectory + PACKAGE + DOWNLOAD_ONLY + SOURCE_DIR + BINARY_DIR + EXCLUDE + SYSTEM + OPTIONS +) + + if(NOT DOWNLOAD_ONLY AND EXISTS ${SOURCE_DIR}/CMakeLists.txt) + set(addSubdirectoryExtraArgs "") + if(EXCLUDE) + list(APPEND addSubdirectoryExtraArgs EXCLUDE_FROM_ALL) + endif() + if("${SYSTEM}" AND "${CMAKE_VERSION}" VERSION_GREATER_EQUAL "3.25") + # https://cmake.org/cmake/help/latest/prop_dir/SYSTEM.html#prop_dir:SYSTEM + list(APPEND addSubdirectoryExtraArgs SYSTEM) + endif() + if(OPTIONS) + foreach(OPTION ${OPTIONS}) + cpm_parse_option("${OPTION}") + set(${OPTION_KEY} "${OPTION_VALUE}") + endforeach() + 
endif() + set(CPM_OLD_INDENT "${CPM_INDENT}") + set(CPM_INDENT "${CPM_INDENT} ${PACKAGE}:") + add_subdirectory(${SOURCE_DIR} ${BINARY_DIR} ${addSubdirectoryExtraArgs}) + set(CPM_INDENT "${CPM_OLD_INDENT}") + endif() +endfunction() + +# downloads a previously declared package via FetchContent and exports the variables +# `${PACKAGE}_SOURCE_DIR` and `${PACKAGE}_BINARY_DIR` to the parent scope +function(cpm_fetch_package PACKAGE populated) + set(${populated} + FALSE + PARENT_SCOPE + ) + if(${CPM_DRY_RUN}) + cpm_message(STATUS "${CPM_INDENT} Package ${PACKAGE} not fetched (dry run)") + return() + endif() + + FetchContent_GetProperties(${PACKAGE}) + + string(TOLOWER "${PACKAGE}" lower_case_name) + + if(NOT ${lower_case_name}_POPULATED) + FetchContent_Populate(${PACKAGE}) + set(${populated} + TRUE + PARENT_SCOPE + ) + endif() + + cpm_store_fetch_properties( + ${CPM_ARGS_NAME} ${${lower_case_name}_SOURCE_DIR} ${${lower_case_name}_BINARY_DIR} + ) + + set(${PACKAGE}_SOURCE_DIR + ${${lower_case_name}_SOURCE_DIR} + PARENT_SCOPE + ) + set(${PACKAGE}_BINARY_DIR + ${${lower_case_name}_BINARY_DIR} + PARENT_SCOPE + ) +endfunction() + +# splits a package option +function(cpm_parse_option OPTION) + string(REGEX MATCH "^[^ ]+" OPTION_KEY "${OPTION}") + string(LENGTH "${OPTION}" OPTION_LENGTH) + string(LENGTH "${OPTION_KEY}" OPTION_KEY_LENGTH) + if(OPTION_KEY_LENGTH STREQUAL OPTION_LENGTH) + # no value for key provided, assume user wants to set option to "ON" + set(OPTION_VALUE "ON") + else() + math(EXPR OPTION_KEY_LENGTH "${OPTION_KEY_LENGTH}+1") + string(SUBSTRING "${OPTION}" "${OPTION_KEY_LENGTH}" "-1" OPTION_VALUE) + endif() + set(OPTION_KEY + "${OPTION_KEY}" + PARENT_SCOPE + ) + set(OPTION_VALUE + "${OPTION_VALUE}" + PARENT_SCOPE + ) +endfunction() + +# guesses the package version from a git tag +function(cpm_get_version_from_git_tag GIT_TAG RESULT) + string(LENGTH ${GIT_TAG} length) + if(length EQUAL 40) + # GIT_TAG is probably a git hash + set(${RESULT} + 0 + PARENT_SCOPE + ) + 
else() + string(REGEX MATCH "v?([0123456789.]*).*" _ ${GIT_TAG}) + set(${RESULT} + ${CMAKE_MATCH_1} + PARENT_SCOPE + ) + endif() +endfunction() + +# guesses if the git tag is a commit hash or an actual tag or a branch name. +function(cpm_is_git_tag_commit_hash GIT_TAG RESULT) + string(LENGTH "${GIT_TAG}" length) + # full hash has 40 characters, and short hash has at least 7 characters. + if(length LESS 7 OR length GREATER 40) + set(${RESULT} + 0 + PARENT_SCOPE + ) + else() + if(${GIT_TAG} MATCHES "^[a-fA-F0-9]+$") + set(${RESULT} + 1 + PARENT_SCOPE + ) + else() + set(${RESULT} + 0 + PARENT_SCOPE + ) + endif() + endif() +endfunction() + +function(cpm_prettify_package_arguments OUT_VAR IS_IN_COMMENT) + set(oneValueArgs + NAME + FORCE + VERSION + GIT_TAG + DOWNLOAD_ONLY + GITHUB_REPOSITORY + GITLAB_REPOSITORY + GIT_REPOSITORY + SOURCE_DIR + DOWNLOAD_COMMAND + FIND_PACKAGE_ARGUMENTS + NO_CACHE + SYSTEM + GIT_SHALLOW + ) + set(multiValueArgs OPTIONS) + cmake_parse_arguments(CPM_ARGS "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + foreach(oneArgName ${oneValueArgs}) + if(DEFINED CPM_ARGS_${oneArgName}) + if(${IS_IN_COMMENT}) + string(APPEND PRETTY_OUT_VAR "#") + endif() + if(${oneArgName} STREQUAL "SOURCE_DIR") + string(REPLACE ${CMAKE_SOURCE_DIR} "\${CMAKE_SOURCE_DIR}" CPM_ARGS_${oneArgName} + ${CPM_ARGS_${oneArgName}} + ) + endif() + string(APPEND PRETTY_OUT_VAR " ${oneArgName} ${CPM_ARGS_${oneArgName}}\n") + endif() + endforeach() + foreach(multiArgName ${multiValueArgs}) + if(DEFINED CPM_ARGS_${multiArgName}) + if(${IS_IN_COMMENT}) + string(APPEND PRETTY_OUT_VAR "#") + endif() + string(APPEND PRETTY_OUT_VAR " ${multiArgName}\n") + foreach(singleOption ${CPM_ARGS_${multiArgName}}) + if(${IS_IN_COMMENT}) + string(APPEND PRETTY_OUT_VAR "#") + endif() + string(APPEND PRETTY_OUT_VAR " \"${singleOption}\"\n") + endforeach() + endif() + endforeach() + + if(NOT "${CPM_ARGS_UNPARSED_ARGUMENTS}" STREQUAL "") + if(${IS_IN_COMMENT}) + string(APPEND PRETTY_OUT_VAR "#") + 
endif() + string(APPEND PRETTY_OUT_VAR " ") + foreach(CPM_ARGS_UNPARSED_ARGUMENT ${CPM_ARGS_UNPARSED_ARGUMENTS}) + string(APPEND PRETTY_OUT_VAR " ${CPM_ARGS_UNPARSED_ARGUMENT}") + endforeach() + string(APPEND PRETTY_OUT_VAR "\n") + endif() + + set(${OUT_VAR} + ${PRETTY_OUT_VAR} + PARENT_SCOPE + ) + +endfunction() diff --git a/cmake/FindCudaExtension.cmake b/cmake/FindCudaExtension.cmake index a8fd20b3..fdff22cb 100644 --- a/cmake/FindCudaExtension.cmake +++ b/cmake/FindCudaExtension.cmake @@ -7,10 +7,30 @@ # find_package(CUDA) -include(FindPackageHandleStandardArgs) if(CUDA_FOUND) + if(USE_NVTX) + # see https://github.com/NVIDIA/NVTX + include(CPM.cmake) + + CPMAddPackage( + NAME NVTX + GITHUB_REPOSITORY NVIDIA/NVTX + GIT_TAG v3.1.0-c-cpp + GIT_SHALLOW TRUE) + + message(STATUS "CUDA NTVX_FOUND=${NTVX_FOUND}") + set(NVTX_LINK_C "nvtx3-c") + set(NVTX_LINK_CPP "nvtx3-cpp") + add_compile_definitions("ENABLE_NVTX") + else() + set(NVTX_LINK_C "") + set(NVTX_LINK_CPP "") + message(STATUS "CUDA NTVX not added.") + endif() + + include(FindPackageHandleStandardArgs) find_package_handle_standard_args( CudaExtension VERSION_VAR "0.1" @@ -38,6 +58,12 @@ function(cuda_pybind11_add_module name pybindfile) message(STATUS "CU ${pybindfile}") message(STATUS "CU ${ARGN}") cuda_add_library(${cuda_name} STATIC ${ARGN}) + target_include_directories( + ${cuda_name} PRIVATE + ${CPM_PACKAGE_NVTX_SOURCE_DIR}/include) local_pybind11_add_module(${name} "" ${pybindfile}) target_link_libraries(${name} PRIVATE ${cuda_name} stdc++) + if(USE_NVTX) + target_link_libraries(${name} PRIVATE nvtx3-cpp) + endif() endfunction() diff --git a/onnx_extended/validation/cuda/cuda_example.cu b/onnx_extended/validation/cuda/cuda_example.cu index 89f91163..ae5f0b63 100644 --- a/onnx_extended/validation/cuda/cuda_example.cu +++ b/onnx_extended/validation/cuda/cuda_example.cu @@ -1,4 +1,5 @@ #include "cuda_example.cuh" +#include "cuda_nvtx.cuh" #include "cuda_utils.h" #include #include @@ -25,6 +26,7 @@ void 
kernel_vector_add(unsigned int size, const float* gpu_ptr1, const float* gp void vector_add(unsigned int size, const float* ptr1, const float* ptr2, float* br, int cudaDevice) { // copy memory from CPU memory to CUDA memory + NVTX_SCOPE("vector_add") checkCudaErrors(cudaSetDevice(cudaDevice)); float *gpu_ptr1, *gpu_ptr2, *gpu_res; checkCudaErrors(cudaMalloc(&gpu_ptr1, size * sizeof(float))); @@ -106,6 +108,7 @@ float kernel_vector_sum_reduce0(float* gpu_ptr, unsigned int size, int maxThread float vector_sum0(unsigned int size, const float* ptr, int maxThreads, int cudaDevice) { // copy memory from CPU memory to CUDA memory + NVTX_SCOPE("vector_sum0") float *gpu_ptr; checkCudaErrors(cudaSetDevice(cudaDevice)); checkCudaErrors(cudaMalloc(&gpu_ptr, size * sizeof(float))); @@ -132,6 +135,7 @@ __global__ void vector_sum(float *input, float *output, unsigned int size) { float vector_sum_atomic(unsigned int size, const float* ptr, int maxThreads, int cudaDevice) { + NVTX_SCOPE("vector_sum_atomic") float *input, *output; float sum = 0.0f; cudaMalloc(&input, size * sizeof(float)); diff --git a/onnx_extended/validation/cuda/cuda_example.cuh b/onnx_extended/validation/cuda/cuda_example.cuh index 702acdaa..d325000c 100644 --- a/onnx_extended/validation/cuda/cuda_example.cuh +++ b/onnx_extended/validation/cuda/cuda_example.cuh @@ -1,5 +1,7 @@ namespace cuda_example { +unsigned int nextPow2(unsigned int x); + void vector_add(unsigned int size, const float *ptr1, const float *ptr2, float *ptr3, int cudaDevice); diff --git a/onnx_extended/validation/cuda/cuda_example_py.cpp b/onnx_extended/validation/cuda/cuda_example_py.cpp index a5fba912..24d12bb8 100644 --- a/onnx_extended/validation/cuda/cuda_example_py.cpp +++ b/onnx_extended/validation/cuda/cuda_example_py.cpp @@ -1,4 +1,5 @@ #include "cuda_example.cuh" +#include "cuda_example_reduce.cuh" #include #include #include @@ -74,6 +75,21 @@ of the same size with CUDA. 
}, py::arg("vect"), py::arg("max_threads") = 256, py::arg("cuda_device") = 0, R"pbdoc(Computes the sum of all coefficients with CUDA. Uses atomicAdd +:param vect: array +:param max_threads: number of threads to use (it must be a power of 2) +:param cuda_device: device id (if mulitple one) +:return: sum +)pbdoc"); + + m.def("vector_sum6", [](const py_array_float& vect, int max_threads, int cuda_device) -> float { + if (vect.size() == 0) + return 0; + auto ha = vect.request(); + const float* ptr = reinterpret_cast(ha.ptr); + return vector_sum6(static_cast(vect.size()), ptr, max_threads, cuda_device); + }, py::arg("vect"), py::arg("max_threads") = 256, py::arg("cuda_device") = 0, + R"pbdoc(Computes the sum of all coefficients with CUDA. More efficient method. + :param vect: array :param max_threads: number of threads to use (it must be a power of 2) :param cuda_device: device id (if mulitple one) diff --git a/onnx_extended/validation/cuda/cuda_example_reduce.cu b/onnx_extended/validation/cuda/cuda_example_reduce.cu new file mode 100644 index 00000000..debbc5ef --- /dev/null +++ b/onnx_extended/validation/cuda/cuda_example_reduce.cu @@ -0,0 +1,161 @@ +#include "cuda_example.cuh" +#include "cuda_example_reduce.cuh" +#include "cuda_nvtx.cuh" +#include "cuda_utils.h" +#include +#include +#include + +// https://github.com/mark-poscablo/gpu-sum-reduction/blob/master/sum_reduction/reduce.cu +// https://developer.download.nvidia.com/assets/cuda/files/reduction.pdf +// https://github.com/zchee/cuda-sample/blob/master/6_Advanced/reduction/reduction_kernel.cu +// https://github.com/NVIDIA/cuda-samples/blob/master/Samples/2_Concepts_and_Techniques/reduction/reduction_kernel.cu + +namespace cuda_example { + +#define reduce6_block_and_sync(I,I2) \ + if ((blockSize >= I) && (tid < I2)) { \ + sdata[tid] = mySum = mySum + sdata[tid + I2]; \ + } \ + __syncthreads(); + +template +__global__ void kernel_reduce6(const T *g_idata, T *g_odata, unsigned int n) { + extern __shared__ T 
sdata[]; + + unsigned int tid = threadIdx.x; + unsigned int i = blockIdx.x * blockSize * 2 + threadIdx.x; + unsigned int gridSize = blockSize * 2 * gridDim.x; + + // reduction per threads on all blocks + T mySum = 0; + while (i < n) { + mySum += g_idata[i]; + + if (nIsPow2 || i + blockSize < n) { + mySum += g_idata[i + blockSize]; + } + + i += gridSize; + } + + // using shared memory to store the reduction + sdata[tid] = mySum; + __syncthreads(); + + + // reduction within a block in shared memory + reduce6_block_and_sync(512, 256); + reduce6_block_and_sync(256, 128); + reduce6_block_and_sync(128, 64); + +#if (__CUDA_ARCH__ >= 300 ) + if (tid < 32) { + if (blockSize >= 64) { + mySum += sdata[tid + 32]; + } + // Reduce final warp using shuffle + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + // https://developer.nvidia.com/blog/faster-parallel-reductions-kepler/ + mySum += __shfl_down_sync(0xFFFFFFFF, mySum, offset); + } + } +#else + // fully unroll reduction within a single warp + reduce6_block_and_sync(64, 32); + reduce6_block_and_sync(32, 16); + reduce6_block_and_sync(16, 8); + reduce6_block_and_sync(8, 4); + reduce6_block_and_sync(4, 2); + reduce6_block_and_sync(2, 1); +#endif + + // write result for this block to global mem + if (tid == 0) { + g_odata[blockIdx.x] = mySum; + } +} + +bool isPow2(unsigned int n) { + if (n == 0) + return false; + return (n & (n - 1)) == 0; +} + +#define case_vector_sum_6_block(T, I, B) \ + case I: \ + kernel_reduce6<<>>(gpu_ptr, gpu_block_ptr, size); \ + break; + +float kernel_vector_sum_6(unsigned int size, const float* gpu_ptr, int maxThreads) { + + int threads = (size < maxThreads) ? nextPow2(size) : maxThreads; + int blocks = (size + threads - 1) / threads; + dim3 dimBlock(threads, 1, 1); + dim3 dimGrid(blocks, 1, 1); + float* gpu_block_ptr; + checkCudaErrors(cudaMalloc(&gpu_block_ptr, blocks * sizeof(float))); + int smemSize = (threads <= 32) ? 
2 * threads * sizeof(float) : threads * sizeof(float); + + if (isPow2(size)) { + switch (threads) { + case_vector_sum_6_block(float, 512, true); + case_vector_sum_6_block(float, 256, true); + case_vector_sum_6_block(float, 128, true); + case_vector_sum_6_block(float, 64, true); + case_vector_sum_6_block(float, 32, true); + case_vector_sum_6_block(float, 16, true); + case_vector_sum_6_block(float, 8, true); + case_vector_sum_6_block(float, 4, true); + case_vector_sum_6_block(float, 2, true); + case_vector_sum_6_block(float, 1, true); + } + } + else { + switch (threads) { + case_vector_sum_6_block(float, 512, false); + case_vector_sum_6_block(float, 256, false); + case_vector_sum_6_block(float, 128, false); + case_vector_sum_6_block(float, 64, false); + case_vector_sum_6_block(float, 32, false); + case_vector_sum_6_block(float, 16, false); + case_vector_sum_6_block(float, 8, false); + case_vector_sum_6_block(float, 4, false); + case_vector_sum_6_block(float, 2, false); + case_vector_sum_6_block(float, 1, false); + } + } + + // the last reduction happens on CPU, the first step is to move + // the data from GPU to CPU. 
+ float* cpu_ptr = new float[blocks]; + checkCudaErrors(cudaMemcpy(cpu_ptr, gpu_block_ptr, blocks * sizeof(float), cudaMemcpyDeviceToHost)); + float gpu_result = 0; + for (int i = 0; i < blocks; ++i) { + gpu_result += cpu_ptr[i]; + } + checkCudaErrors(cudaFree(gpu_block_ptr)); + delete[] cpu_ptr; + return gpu_result; + +} + +float vector_sum6(unsigned int size, const float* ptr, int maxThreads, + int cudaDevice) { + // copy memory from CPU memory to CUDA memory + NVTX_SCOPE("vector_sum6") + float *gpu_ptr; + checkCudaErrors(cudaSetDevice(cudaDevice)); + checkCudaErrors(cudaMalloc(&gpu_ptr, size * sizeof(float))); + checkCudaErrors(cudaMemcpy(gpu_ptr, ptr, size * sizeof(float), + cudaMemcpyHostToDevice)); + + // execute the code + float result = kernel_vector_sum_6(size, gpu_ptr, maxThreads); + + // free the allocated vectors + checkCudaErrors(cudaFree(gpu_ptr)); + return result; +} + +} // namespace cuda_example diff --git a/onnx_extended/validation/cuda/cuda_example_reduce.cuh b/onnx_extended/validation/cuda/cuda_example_reduce.cuh new file mode 100644 index 00000000..37b848f6 --- /dev/null +++ b/onnx_extended/validation/cuda/cuda_example_reduce.cuh @@ -0,0 +1,6 @@ +namespace cuda_example { + +float vector_sum6(unsigned int size, const float *ptr, int maxThreads, + int cudaDevice); + +} // namespace cuda_example diff --git a/onnx_extended/validation/cuda/cuda_nvtx.cuh b/onnx_extended/validation/cuda/cuda_nvtx.cuh new file mode 100644 index 00000000..62ba1141 --- /dev/null +++ b/onnx_extended/validation/cuda/cuda_nvtx.cuh @@ -0,0 +1,8 @@ +#pragma once + +#if defined(ENABLE_NVTX) +#include +#define NVTX_SCOPE(msg) nvtx3::scoped_range r{msg}; +#else +#define NVTX_SCOPE(msg) +#endif diff --git a/setup.py b/setup.py index c25a2e65..816e90de 100644 --- a/setup.py +++ b/setup.py @@ -192,6 +192,22 @@ def __init__(self, name: str, library: str = "") -> None: class cmake_build_ext(build_ext): + user_options = [ + *build_ext.user_options, + ("enable-nvtx=", None, "Enables 
compilation with NVTX events."), + ] + + def initialize_options(self): + self.enable_nvtx = None + build_ext.initialize_options(self) + + def finalize_options(self): + b_values = {None, 0, 1, "1", "0", True, False} + if self.enable_nvtx not in b_values: + raise ValueError(f"enable_nvtx={self.enable_nvtx!r} must be in {b_values}.") + self.enable_nvtx = self.enable_nvtx in {1, "1", True, "True"} + build_ext.finalize_options(self) + def build_extensions(self): # Ensure that CMake is present and working try: @@ -220,6 +236,8 @@ def build_extensions(self): f"-DPYTHON_VERSION_MM={versmm}", f"-DPYTHON_MODULE_EXTENSION={module_ext}", ] + if os.environ.get("USE_NVTX", "0") in (1, "1") or self.enable_nvtx: + cmake_args.append("-DUSE_NVTX=1") if iswin or isdar: include_dir = sysconfig.get_paths()["include"].replace("\\", "/") lib_dir = (