Skip to content

Commit

Permalink
Add a linear algebra library dependency
Browse files Browse the repository at this point in the history
Add `mdspan` and `stdBLAS` as dependencies in order to allow use of a
linear algebra library.

Change-Id: I9d2e1f43f6df315a922753d1d88e5904888b3d16
  • Loading branch information
oliverlee committed Jul 22, 2023
1 parent cf9f724 commit da086b7
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 2 deletions.
55 changes: 55 additions & 0 deletions WORKSPACE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,61 @@ http_archive(
url = "https://github.com/fmtlib/fmt/archive/%s.tar.gz" % FMT_VERSION,
)

MDSPAN_VERSION = "9d0a451e11177cbdeaef035c7914b0aa73ddd1e2"

http_archive(
name = "mdspan",
build_file_content = """
load("@rules_cc//cc:defs.bzl", "cc_library")
cc_library(
name = "mdspan",
hdrs = glob(["include/**"]),
includes = ["include"],
defines = [
"MDSPAN_USE_BRACKET_OPERATOR=1",
"MDSPAN_USE_PAREN_OPERATOR=1", # needed by stdBLAS
],
visibility = ["//visibility:public"],
)
""",
sha256 = "d7751653cd93f3e73118796fd5f42f0424ef20ffeaa224005b4008b8987e6d81",
strip_prefix = "mdspan-%s" % MDSPAN_VERSION,
url = "https://github.com/kokkos/mdspan/archive/%s.tar.gz" % MDSPAN_VERSION,
)

LINALG_VERSION = "d1a1a116a1a62a03d726d70220e573aa7c3dba68"

http_archive(
name = "linalg",
build_file_content = """
load("@rules_cc//cc:defs.bzl", "cc_library")
# It appears this can be empty if we don't use atomic ref, blas, or kokkos
# https://github.com/kokkos/stdBLAS/blob/main/include/experimental/__p1673_bits/linalg_config.h.in
genrule(
name = "linalg_config",
outs = ["include/experimental/__p1673_bits/linalg_config.h"],
cmd = "touch $@",
)
cc_library(
name = "linalg",
srcs = ["linalg_config"],
hdrs = glob(["include/**"]),
includes = [
"include",
"include/experimental",
],
deps = ["@mdspan"],
visibility = ["//visibility:public"],
)
""",
sha256 = "2cff70d080186949dcfa6da509eab41c3ba25df23c15bcbca8d810d0a1370bdb",
strip_prefix = "stdBLAS-%s" % LINALG_VERSION,
url = "https://github.com/kokkos/stdblas/archive/%s.tar.gz" % LINALG_VERSION,
)

BOOST_UT_VERSION = "e53a47d37bc594e80bd5f1b8dc1ade8dce4429d3"

http_archive(
Expand Down
2 changes: 2 additions & 0 deletions test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,7 @@ cc_test(
deps = [
"@boost_ut",
"@fmt",
"@linalg",
"@mdspan",
],
)
76 changes: 74 additions & 2 deletions test/dummy_test.cpp
Original file line number Diff line number Diff line change
@@ -1,15 +1,87 @@
#include <boost/ut.hpp>

#include <array>
#include <cassert>
#include <cstddef>
#include <experimental/linalg>
#include <experimental/mdspan>
#include <fmt/core.h>

namespace stx = std::experimental;

template <class T, std::size_t R, std::size_t C>
struct matrix
{
static_assert(R != std::dynamic_extent);
static_assert(C != std::dynamic_extent);

using extents_type = stx::extents<std::size_t, R, C>;
using mdspan_type = stx::mdspan<T, extents_type>;
using const_mdspan_type = stx::mdspan<const T, extents_type>;

std::array<T, R * C> data{};

matrix() = default;

matrix(std::initializer_list<std::initializer_list<T>> init)
{
assert(R == init.size());

auto i = std::size_t{};
for (const auto& row : init) {
assert(C == row.size());

auto j = std::size_t{};
for (const auto& elem : row) {
span()[i, j] = elem;

++j;
}

++i;
}
}

[[nodiscard]]
constexpr auto span() -> mdspan_type { return mdspan_type{data.data()}; }
[[nodiscard]]
constexpr auto span() const -> const_mdspan_type
{
return const_mdspan_type{data.data()};
}

[[nodiscard]]
constexpr auto
operator[](std::size_t i, std::size_t j) -> mdspan_type::reference
{
return span()[i, j];
}
[[nodiscard]]
constexpr auto
operator[](std::size_t i, std::size_t j) const -> const_mdspan_type::reference
{
return span()[i, j];
}
};

auto main() -> int
{
using ::boost::ut::expect;
using ::boost::ut::test;

// NOLINTBEGIN(readability-magic-numbers)
test("true is true") = [] {
const auto s = fmt::format("The answer is {}.", 42);
using Mat = matrix<int, 2, 2>;

auto A = Mat{{1, 2}, {3, 4}};
auto B = Mat{{5, 6}, {7, 8}};
auto C = Mat{};

stx::linalg::matrix_product(A.span(), B.span(), C.span());

const auto s = fmt::format("The answer is {}.", C[0, 0]);

expect("The answer is 42." == s);
expect("The answer is 19." == s);
};
// NOLINTEND(readability-magic-numbers)
}

0 comments on commit da086b7

Please sign in to comment.