From da086b7c5c882167a8ded232771b6d2c14d520b5 Mon Sep 17 00:00:00 2001 From: Oliver Lee Date: Sat, 22 Jul 2023 22:30:03 +0200 Subject: [PATCH] Add a linear algebra library dependency Add `mdspan` and `stdBLAS` as dependencies in order to allow use of a linear algebra library. Change-Id: I9d2e1f43f6df315a922753d1d88e5904888b3d16 --- WORKSPACE.bazel | 55 ++++++++++++++++++++++++++++++++ test/BUILD.bazel | 2 ++ test/dummy_test.cpp | 76 +++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 131 insertions(+), 2 deletions(-) diff --git a/WORKSPACE.bazel b/WORKSPACE.bazel index c098d47..abcca43 100644 --- a/WORKSPACE.bazel +++ b/WORKSPACE.bazel @@ -124,6 +124,61 @@ http_archive( url = "https://github.com/fmtlib/fmt/archive/%s.tar.gz" % FMT_VERSION, ) +MDSPAN_VERSION = "9d0a451e11177cbdeaef035c7914b0aa73ddd1e2" + +http_archive( + name = "mdspan", + build_file_content = """ +load("@rules_cc//cc:defs.bzl", "cc_library") + +cc_library( + name = "mdspan", + hdrs = glob(["include/**"]), + includes = ["include"], + defines = [ + "MDSPAN_USE_BRACKET_OPERATOR=1", + "MDSPAN_USE_PAREN_OPERATOR=1", # needed by stdBLAS + ], + visibility = ["//visibility:public"], +) +""", + sha256 = "d7751653cd93f3e73118796fd5f42f0424ef20ffeaa224005b4008b8987e6d81", + strip_prefix = "mdspan-%s" % MDSPAN_VERSION, + url = "https://github.com/kokkos/mdspan/archive/%s.tar.gz" % MDSPAN_VERSION, +) + +LINALG_VERSION = "d1a1a116a1a62a03d726d70220e573aa7c3dba68" + +http_archive( + name = "linalg", + build_file_content = """ +load("@rules_cc//cc:defs.bzl", "cc_library") + +# It appears this can be empty if we don't use atomic ref, blas, or kokkos +# https://github.com/kokkos/stdBLAS/blob/main/include/experimental/__p1673_bits/linalg_config.h.in +genrule( + name = "linalg_config", + outs = ["include/experimental/__p1673_bits/linalg_config.h"], + cmd = "touch $@", +) + +cc_library( + name = "linalg", + srcs = ["linalg_config"], + hdrs = glob(["include/**"]), + includes = [ + "include", + "include/experimental", + ], + deps = ["@mdspan"], + visibility = ["//visibility:public"], +) +""", + sha256 = "2cff70d080186949dcfa6da509eab41c3ba25df23c15bcbca8d810d0a1370bdb", + strip_prefix = "stdBLAS-%s" % LINALG_VERSION, + url = "https://github.com/kokkos/stdblas/archive/%s.tar.gz" % LINALG_VERSION, +) + BOOST_UT_VERSION = "e53a47d37bc594e80bd5f1b8dc1ade8dce4429d3" http_archive( diff --git a/test/BUILD.bazel b/test/BUILD.bazel index 4366d60..c932793 100644 --- a/test/BUILD.bazel +++ b/test/BUILD.bazel @@ -7,5 +7,7 @@ cc_test( deps = [ "@boost_ut", "@fmt", + "@linalg", + "@mdspan", ], ) diff --git a/test/dummy_test.cpp b/test/dummy_test.cpp index f218cf6..c99a880 100644 --- a/test/dummy_test.cpp +++ b/test/dummy_test.cpp @@ -1,15 +1,87 @@ #include +#include +#include +#include +#include +#include #include +namespace stx = std::experimental; + +template +struct matrix +{ + static_assert(R != std::dynamic_extent); + static_assert(C != std::dynamic_extent); + + using extents_type = stx::extents; + using mdspan_type = stx::mdspan; + using const_mdspan_type = stx::mdspan; + + std::array data{}; + + matrix() = default; + + matrix(std::initializer_list> init) + { + assert(R == init.size()); + + auto i = std::size_t{}; + for (const auto& row : init) { + assert(C == row.size()); + + auto j = std::size_t{}; + for (const auto& elem : row) { + span()[i, j] = elem; + + ++j; + } + + ++i; + } + } + + [[nodiscard]] + constexpr auto span() -> mdspan_type { return mdspan_type{data.data()}; } + [[nodiscard]] + constexpr auto span() const -> const_mdspan_type + { + return const_mdspan_type{data.data()}; + } + + [[nodiscard]] + constexpr auto + operator[](std::size_t i, std::size_t j) -> mdspan_type::reference + { + return span()[i, j]; + } + [[nodiscard]] + constexpr auto + operator[](std::size_t i, std::size_t j) const -> const_mdspan_type::reference + { + return span()[i, j]; + } +}; + auto main() -> int { using ::boost::ut::expect; using ::boost::ut::test; + // NOLINTBEGIN(readability-magic-numbers) test("true is true") = [] { - const auto s = fmt::format("The answer is {}.", 42); + using Mat = matrix; + + auto A = Mat{{1, 2}, {3, 4}}; + auto B = Mat{{5, 6}, {7, 8}}; + auto C = Mat{}; + + stx::linalg::matrix_product(A.span(), B.span(), C.span()); + + const auto s = fmt::format("The answer is {}.", C[0, 0]); - expect("The answer is 42." == s); + expect("The answer is 19." == s); }; + // NOLINTEND(readability-magic-numbers) }