From 51fa3edaf7c44581509b2cdbe953f304795cb684 Mon Sep 17 00:00:00 2001 From: "shanzhu.cjm" Date: Tue, 30 Jan 2024 11:49:54 +0800 Subject: [PATCH] repo-sync-2024-01-30T11:49:42+0800 --- bazel/patches/zlib.patch | 2 +- yacl/crypto/primitives/dpf/mpfss.cc | 2 +- yacl/math/gadget.h | 2 + yacl/math/galois_field/BUILD.bazel | 1 + yacl/math/galois_field/gf_scalar.h | 142 +----------------- .../mpint_field/mpint_field_test.cc | 2 +- yacl/utils/spi/BUILD.bazel | 2 +- yacl/utils/spi/argument/BUILD.bazel | 2 +- yacl/utils/spi/item.h | 100 ++++++++---- yacl/utils/spi/item_test.cc | 96 +++++++++++- yacl/utils/spi/sketch/BUILD.bazel | 32 ++++ yacl/utils/spi/sketch/scalar_call.h | 121 +++++++++++++++ yacl/utils/spi/sketch/scalar_define.h | 77 ++++++++++ 13 files changed, 404 insertions(+), 177 deletions(-) create mode 100644 yacl/utils/spi/sketch/BUILD.bazel create mode 100644 yacl/utils/spi/sketch/scalar_call.h create mode 100644 yacl/utils/spi/sketch/scalar_define.h diff --git a/bazel/patches/zlib.patch b/bazel/patches/zlib.patch index efeaed42..98ed9550 100644 --- a/bazel/patches/zlib.patch +++ b/bazel/patches/zlib.patch @@ -7,7 +7,7 @@ index 15ceebe..1d19c83 100644 include(CheckIncludeFile) include(CheckCSourceCompiles) -enable_testing() - + check_include_file(sys/types.h HAVE_SYS_TYPES_H) check_include_file(stdint.h HAVE_STDINT_H) @@ -193,26 +192,3 @@ endif() diff --git a/yacl/crypto/primitives/dpf/mpfss.cc b/yacl/crypto/primitives/dpf/mpfss.cc index 5dffc90d..208b72ea 100644 --- a/yacl/crypto/primitives/dpf/mpfss.cc +++ b/yacl/crypto/primitives/dpf/mpfss.cc @@ -169,7 +169,7 @@ void MpfssSend(const std::shared_ptr& ctx, } ctx->SendAsync( ctx->NextRank(), - ByteContainerView(send_msgs.data(), send_msgs.size() * sizeof(uint128_t)), + ByteContainerView(send_msgs.data(), send_msgs.size() * sizeof(uint64_t)), "MpVole_msg"); } diff --git a/yacl/math/gadget.h b/yacl/math/gadget.h index 19040dcd..336765a0 100644 --- a/yacl/math/gadget.h +++ b/yacl/math/gadget.h @@ -86,6 +86,8 @@ uint128_t inline GfMul(uint64_t a, uint128_t b) { // f2k-Universal Hash // ------------------------ +// see difference between universal hash and collision-resistent hash functions: +// https://crypto.stackexchange.com/a/88247/61581 template T UniversalHash(T seed, absl::Span data) { T ret = 0; diff --git a/yacl/math/galois_field/BUILD.bazel b/yacl/math/galois_field/BUILD.bazel index e0c8e508..43b560b8 100644 --- a/yacl/math/galois_field/BUILD.bazel +++ b/yacl/math/galois_field/BUILD.bazel @@ -34,6 +34,7 @@ yacl_cc_library( "//yacl/io/msgpack:buffer", "//yacl/io/msgpack:spec_traits", "//yacl/utils:parallel", + "//yacl/utils/spi/sketch", "@com_google_absl//absl/types:span", ], ) diff --git a/yacl/math/galois_field/gf_scalar.h b/yacl/math/galois_field/gf_scalar.h index e829e065..2dddaab9 100644 --- a/yacl/math/galois_field/gf_scalar.h +++ b/yacl/math/galois_field/gf_scalar.h @@ -20,6 +20,7 @@ #include "yacl/io/msgpack/spec_traits.h" #include "yacl/math/galois_field/gf_spi.h" #include "yacl/utils/parallel.h" +#include "yacl/utils/spi/sketch/scalar_define.h" namespace yacl::math { @@ -85,28 +86,6 @@ class GFScalarSketch : public GaloisField { virtual T DeserializeT(ByteContainerView buffer) const = 0; private: -#define DefineBoolUnaryFunc(FuncName) \ - Item FuncName(const Item& x) const override { \ - if (x.IsArray()) { \ - auto xsp = x.AsSpan(); \ - /* std::vector cannot write in parallel */ \ - std::vector res; \ - res.resize(xsp.length()); \ - yacl::parallel_for(0, xsp.length(), [&](int64_t beg, int64_t end) { \ - for (int64_t i = beg; i < end; ++i) { \ - res[i] = FuncName(xsp[i]); \ - } \ - }); \ - /* convert std::vector to std::vector */ \ - std::vector bv; \ - bv.resize(res.size()); \ - std::copy(res.begin(), res.end(), bv.begin()); \ - return Item::Take(std::move(bv)); \ - } else { \ - return FuncName(x.As()); \ - } \ - } - // if x is scalar, returns bool // if x is vectored, returns std::vector DefineBoolUnaryFunc(IsIdentityOne); @@ -151,39 +130,6 @@ class GFScalarSketch : public GaloisField { // operations defined on set // //================================// -#define DefineUnaryFunc(FuncName) \ - Item FuncName(const Item& x) const override { \ - using RES_T = decltype(FuncName(std::declval())); \ - \ - if (x.IsArray()) { \ - auto xsp = x.AsSpan(); \ - std::vector res; \ - res.resize(xsp.length()); \ - yacl::parallel_for(0, xsp.length(), [&](int64_t beg, int64_t end) { \ - for (int64_t i = beg; i < end; ++i) { \ - res[i] = FuncName(xsp[i]); \ - } \ - }); \ - return Item::Take(std::move(res)); \ - } else { \ - return FuncName(x.As()); \ - } \ - } - -#define DefineUnaryInplaceFunc(FuncName) \ - void FuncName(Item* x) const override { \ - if (x->IsArray()) { \ - auto xsp = x->AsSpan(); \ - yacl::parallel_for(0, xsp.length(), [&](int64_t beg, int64_t end) { \ - for (int64_t i = beg; i < end; ++i) { \ - FuncName(&xsp[i]); \ - } \ - }); \ - } else { \ - FuncName(x->As()); \ - } \ - } - // get the additive inverse −a for all elements in set DefineUnaryFunc(Neg); DefineUnaryInplaceFunc(NegInplace); @@ -192,65 +138,6 @@ class GFScalarSketch : public GaloisField { DefineUnaryFunc(Inv); DefineUnaryInplaceFunc(InvInplace); -#define DefineBinaryFunc(FuncName) \ - Item FuncName(const Item& x, const Item& y) const override { \ - using RES_T = \ - decltype(FuncName(std::declval(), std::declval())); \ - \ - switch (x, y) { \ - case OperandType::Scalar2Scalar: { \ - return FuncName(x.As(), y.As()); \ - } \ - case OperandType::Vector2Vector: { \ - auto xsp = x.AsSpan(); \ - auto ysp = y.AsSpan(); \ - YACL_ENFORCE_EQ( \ - xsp.length(), ysp.length(), \ - "operands must have the same length, x.len={}, y.len={}", \ - xsp.length(), ysp.length()); \ - \ - std::vector res; \ - res.resize(xsp.length()); \ - yacl::parallel_for(0, xsp.length(), [&](int64_t beg, int64_t end) { \ - for (int64_t i = beg; i < end; ++i) { \ - res[i] = FuncName(xsp[i], ysp[i]); \ - } \ - }); \ - return Item::Take(std::move(res)); \ - } \ - default: \ - YACL_THROW("GFScalarSketch method [{}] doesn't support broadcast now", \ - #FuncName); \ - } \ - } - -#define DefineBinaryInplaceFunc(FuncName) \ - void FuncName(yacl::Item* x, const yacl::Item& y) const override { \ - switch (*x, y) { \ - case OperandType::Scalar2Scalar: { \ - FuncName(x->As(), y.As()); \ - return; \ - } \ - case OperandType::Vector2Vector: { \ - auto xsp = x->AsSpan(); \ - auto ysp = y.AsSpan(); \ - YACL_ENFORCE_EQ( \ - xsp.length(), ysp.length(), \ - "operands must have the same length, x.len={}, y.len={}", \ - xsp.length(), ysp.length()); \ - yacl::parallel_for(0, xsp.length(), [&](int64_t beg, int64_t end) { \ - for (int64_t i = beg; i < end; ++i) { \ - FuncName(&xsp[i], ysp[i]); \ - } \ - }); \ - return; \ - } \ - default: \ - YACL_THROW("GFScalarSketch method [{}] doesn't support broadcast now", \ - #FuncName); \ - } \ - } - DefineBinaryFunc(Add); DefineBinaryInplaceFunc(AddInplace); @@ -264,34 +151,11 @@ class GFScalarSketch : public GaloisField { DefineBinaryInplaceFunc(DivInplace); Item Pow(const Item& x, const MPInt& y) const override { - if (x.IsArray()) { - auto xsp = x.AsSpan(); - std::vector res; - res.resize(xsp.length()); - yacl::parallel_for(0, xsp.length(), [&](int64_t beg, int64_t end) { - for (int64_t i = beg; i < end; ++i) { - res[i] = Pow(xsp[i], y); - } - }); - return Item::Take(std::move(res)); - } else { - return Pow(x.As(), y); - } + CallUnaryFunc(Pow, T, x, y); } void PowInplace(Item* x, const MPInt& y) const override { - if (x->IsArray()) { - auto xsp = x->AsSpan(); - yacl::parallel_for(0, xsp.length(), [&](int64_t beg, int64_t end) { - for (int64_t i = beg; i < end; ++i) { - PowInplace(&xsp[i], y); - } - }); - return; - } else { - PowInplace(x->As(), y); - return; - } + CallUnaryInplaceFunc(PowInplace, T, x, y); } Item Random() const override { return RandomT(); } diff --git a/yacl/math/galois_field/mpint_field/mpint_field_test.cc b/yacl/math/galois_field/mpint_field/mpint_field_test.cc index 7b34ab06..5f6c3de4 100644 --- a/yacl/math/galois_field/mpint_field/mpint_field_test.cc +++ b/yacl/math/galois_field/mpint_field/mpint_field_test.cc @@ -213,7 +213,7 @@ TEST_F(MPIntFieldTest, VectorIoWorks) { // subspan auto item1 = Item::Take({0_mp, 1_mp, 2_mp, 3_mp}); - auto item2 = item1.SubSpan(0, 1); + auto item2 = item1.SubItem(0, 1); ASSERT_TRUE(item2.IsView()); ASSERT_FALSE(item2.IsReadOnly()); ASSERT_TRUE(item2.WrappedTypeIs>()); diff --git a/yacl/utils/spi/BUILD.bazel b/yacl/utils/spi/BUILD.bazel index 96705766..855a78a0 100644 --- a/yacl/utils/spi/BUILD.bazel +++ b/yacl/utils/spi/BUILD.bazel @@ -1,4 +1,4 @@ -# Copyright 2022 Ant Group Co., Ltd. +# Copyright 2023 Ant Group Co., Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/yacl/utils/spi/argument/BUILD.bazel b/yacl/utils/spi/argument/BUILD.bazel index 04e4473a..3f4c1b7f 100644 --- a/yacl/utils/spi/argument/BUILD.bazel +++ b/yacl/utils/spi/argument/BUILD.bazel @@ -1,4 +1,4 @@ -# Copyright 2022 Ant Group Co., Ltd. +# Copyright 2023 Ant Group Co., Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/yacl/utils/spi/item.h b/yacl/utils/spi/item.h index bffea4a3..c47d3207 100644 --- a/yacl/utils/spi/item.h +++ b/yacl/utils/spi/item.h @@ -46,8 +46,11 @@ struct is_container< template constexpr bool is_container_v = is_container::value; +// If T is a reference type, then is_const::value is always false. The +// proper way to check a potential-reference type for const-ness is to +// remove the reference: is_const::type> template -constexpr bool is_const_t = +constexpr bool is_const_v = std::is_const_v>>; template @@ -62,38 +65,35 @@ enum class OperandType : int { class Item { public: + // Take or copy scalar template >> - /* implicit */ Item(T&& value) : v_(std::forward(value)) { + /* implicit */ constexpr Item(T&& value) : v_(std::forward(value)) { Setup(false, false, false); } + // Ref a vector/span template >> static Item Ref(T& c) { - Item item; - item.v_ = absl::MakeSpan(c); - // If T is a reference type then is_const::value is always false. The - // proper way to check a potentially-reference type for const-ness is to - // remove the reference: is_const::type> - item.Setup(true, true, is_const_t); - return item; + return Item{absl::MakeSpan(c)}; } template static Item Ref(T* ptr, size_t len) { - Item item; - item.v_ = absl::MakeSpan(ptr, len); - item.Setup(true, true, is_const_t); - return item; + return Item{absl::MakeSpan(ptr, len)}; } + // Take vector template static Item Take(std::vector&& v) { - Item item; - item.v_ = std::move(v); - item.Setup(true, false, false); - return item; + return Item{std::move(v)}; } + Item(const Item&) = default; + Item(Item&&) = default; + Item& operator=(const Item&) = default; + Item& operator=(Item&&) = default; + virtual ~Item() = default; + template explicit operator T() const& { YACL_ENFORCE(v_.type() == typeid(T), "Type mismatch: convert {} to {} fail", @@ -172,7 +172,7 @@ class Item { // const array case if (IsReadOnly()) { YACL_ENFORCE( - is_const_t, + is_const_v, "This is a read-only item, please use AsSpan instead"); // const value -> const T @@ -220,6 +220,22 @@ class Item { } } + template + absl::Span ResizeAndSpan(size_t expected_size) { + auto sp = AsSpan(); + if (sp.size() == expected_size) { + return sp; + } + + // size doesn't match, now we try to resize + YACL_ENFORCE(IsArray() && !IsView() && !IsReadOnly(), + "Resize item fail, actual size={}, resize={}, item_info: {}", + sp.size(), expected_size, ToString()); + auto& vec = As>(); + vec.resize(expected_size); + return absl::MakeSpan(vec); + } + bool HasValue() const noexcept { return v_.has_value(); } // Check the type that directly stored, which is, container type + data type @@ -267,24 +283,50 @@ class Item { // operations only for array template - Item SubSpan(size_t pos, size_t len = absl::Span::npos) { + Item SubItem(size_t pos, size_t len = absl::Span::npos) { YACL_ENFORCE(IsArray(), "You cannot do slice for scalar value"); - if (IsReadOnly() || is_const_t) { + if (IsReadOnly() || is_const_v) { // force const - return SubConstSpanImpl>(pos, len); + return SubConstItemImpl>(pos, len); } - return SubSpanImpl>(pos, len); + return SubItemImpl>(pos, len); } template - Item SubSpan(size_t pos, size_t len = absl::Span::npos) const { + Item SubItem(size_t pos, size_t len = absl::Span::npos) const { YACL_ENFORCE(IsArray(), "You cannot do slice for scalar value"); - return SubConstSpanImpl>(pos, len); + return SubConstItemImpl>(pos, len); } - std::string ToString() const; + virtual std::string ToString() const; + + protected: + template + constexpr explicit Item(const absl::Span& span) : v_(span) { + Setup(true, true, is_const_v); + } + + template + constexpr explicit Item(std::vector&& vec) : v_(std::move(vec)) { + Setup(true, false, false); + } + + template + constexpr void SetSlot(uint8_t value) { + static_assert(slot >= 3 && slot + len <= sizeof(meta_) * 8); + constexpr auto mask = ((static_cast(1) << len) - 1) + << slot; + meta_ &= ~mask; // set target bits to zero + meta_ |= ((value << slot) & mask); + } + + template + constexpr uint8_t GetSlot() const { + static_assert(slot >= 3 && slot + len <= sizeof(meta_) * 8); + return (meta_ >> slot) & ((static_cast(1) << len) - 1); + } private: constexpr Item() {} @@ -297,9 +339,9 @@ class Item { // For developer: This is not a const func. DO NOT add const qualifier template - Item SubSpanImpl(size_t pos, size_t len) { + Item SubItemImpl(size_t pos, size_t len) { YACL_ENFORCE(!IsReadOnly(), - "Cannot make a read-write subspan of a const span"); + "Cannot make a read-write sub-item of a const span"); Item item; if (IsView()) { @@ -314,7 +356,7 @@ class Item { } template - Item SubConstSpanImpl(size_t pos, size_t len) const { + Item SubConstItemImpl(size_t pos, size_t len) const { Item item; if (IsView()) { if (IsReadOnly()) { @@ -352,12 +394,12 @@ class Item { return res.load(); } + std::any v_; // The format of meta: // bit 0 -> is array? 0 - scalar; 1 - array // bit 1 -> is view ? 0 - hold value; 1 - ref/view // bit 2 -> is const? 0 - rw; 1 - read-only uint8_t meta_ = 0; - std::any v_; }; std::ostream& operator<<(std::ostream& os, const Item& a); diff --git a/yacl/utils/spi/item_test.cc b/yacl/utils/spi/item_test.cc index 2f55e062..de2798be 100644 --- a/yacl/utils/spi/item_test.cc +++ b/yacl/utils/spi/item_test.cc @@ -36,7 +36,7 @@ TEST(ItemTest, RefRW) { EXPECT_EQ(v_ref2[2], 30); // sub span - auto s_item = item.SubSpan(1); + auto s_item = item.SubItem(1); EXPECT_TRUE(s_item.IsArray()); EXPECT_TRUE(s_item.IsView()); EXPECT_FALSE(s_item.IsReadOnly()); @@ -50,7 +50,7 @@ TEST(ItemTest, RefRW) { EXPECT_EQ(s_ref2.size(), 2); // sub const span - auto c_item = item.SubSpan(1, 1000); + auto c_item = item.SubItem(1, 1000); EXPECT_TRUE(c_item.IsArray()); EXPECT_TRUE(c_item.IsView()); EXPECT_TRUE(c_item.IsReadOnly()); @@ -66,6 +66,8 @@ TEST(ItemTest, RefRW) { TEST(ItemTest, RefRO) { const std::vector v = {1, 2, 3}; + ASSERT_TRUE(is_container_v); + Item item = Item::Ref(v); EXPECT_TRUE(item.IsArray()); EXPECT_TRUE(item.IsView()); @@ -83,7 +85,7 @@ TEST(ItemTest, RefRO) { EXPECT_EQ(v_ref2[2], 3); // sub const span - auto s_item = item.SubSpan(1); + auto s_item = item.SubItem(1); EXPECT_TRUE(s_item.IsArray()); EXPECT_TRUE(s_item.IsView()); EXPECT_TRUE(s_item.IsReadOnly()); @@ -92,7 +94,7 @@ TEST(ItemTest, RefRO) { auto s_ref = s_item.AsSpan(); EXPECT_EQ(s_ref[0], 2); - const auto s_item2 = item2.SubSpan(2); + const auto s_item2 = item2.SubItem(2); EXPECT_TRUE(s_item2.IsArray()); EXPECT_TRUE(s_item2.IsView()); EXPECT_TRUE(s_item2.IsReadOnly()); @@ -100,4 +102,90 @@ TEST(ItemTest, RefRO) { EXPECT_EQ(s_ref2[0], 3); } +TEST(ItemTest, RefPtr) { + int arr[] = {1, 2, 3}; + auto item = Item::Ref(arr, 3); + EXPECT_TRUE(item.IsArray()); + EXPECT_TRUE(item.IsView()); + EXPECT_FALSE(item.IsReadOnly()); + + const int arr2[] = {1, 2, 3}; + item = Item::Ref(arr2, 3); + EXPECT_TRUE(item.IsArray()); + EXPECT_TRUE(item.IsView()); + EXPECT_TRUE(item.IsReadOnly()); +} + +TEST(ItemTest, ResizeAndSpan) { + auto item = Item::Take(std::vector()); + EXPECT_EQ(item.AsSpan().size(), 0); + + auto sp = item.ResizeAndSpan(100); + EXPECT_EQ(sp.size(), 100); + sp[29] = 456; + + EXPECT_EQ(item.SubItem(29, 20).AsSpan()[0], 456); +} + +class DummyItem : public Item { + public: + using Item::Item; + + // make function public + template + constexpr void ProxySetSlot(uint8_t value) { + Item::SetSlot(value); + } + + template + constexpr uint8_t ProxyGetSlot() const { + return Item::GetSlot(); + } +}; + +TEST(ItemTest, SlotWorks) { + DummyItem item = 123456; + + item.ProxySetSlot<6>(1); + EXPECT_EQ(item.ProxyGetSlot<6>(), 1); + item.ProxySetSlot<6>(0); + EXPECT_EQ(item.ProxyGetSlot<6>(), 0); + + item.ProxySetSlot<3>(1); + item.ProxySetSlot<5, 2>(0b11); // slot 5, 6 = 1, 1 + EXPECT_EQ((item.ProxyGetSlot<5, 2>()), 0b11); + EXPECT_EQ(item.ProxyGetSlot<3>(), 1); + EXPECT_EQ(item.ProxyGetSlot<4>(), 0); + EXPECT_EQ(item.ProxyGetSlot<5>(), 1); + EXPECT_EQ(item.ProxyGetSlot<6>(), 1); + EXPECT_EQ(item.ProxyGetSlot<4>(), 0); + + item.ProxySetSlot<5, 2>(0b01); // slot 5, 6 = 1, 0 + EXPECT_EQ((item.ProxyGetSlot<5, 2>()), 0b01); + EXPECT_EQ(item.ProxyGetSlot<3>(), 1); + EXPECT_EQ(item.ProxyGetSlot<4>(), 0); + EXPECT_EQ(item.ProxyGetSlot<5>(), 1); + EXPECT_EQ(item.ProxyGetSlot<6>(), 0); + EXPECT_EQ(item.ProxyGetSlot<4>(), 0); + + item.ProxySetSlot<5, 3>(0); + EXPECT_EQ((item.ProxyGetSlot<5, 2>()), 0); + EXPECT_EQ((item.ProxyGetSlot<5, 3>()), 0); + + EXPECT_FALSE(item.IsArray()); + EXPECT_FALSE(item.IsView()); + EXPECT_FALSE(item.IsReadOnly()); +} + +TEST(ItemTest, ItemInContainer) { + std::vector items; + EXPECT_NO_THROW(items.emplace_back(123)); + EXPECT_NO_THROW(items.emplace_back("haha")); + EXPECT_NO_THROW(items.emplace_back(true)); + + EXPECT_NO_THROW(items.push_back(456ull)); + EXPECT_NO_THROW( + items.push_back(Item::Take(std::vector{"h", "e", "l", "l", "o"}))); +} + } // namespace yacl::test diff --git a/yacl/utils/spi/sketch/BUILD.bazel b/yacl/utils/spi/sketch/BUILD.bazel new file mode 100644 index 00000000..c9b54a61 --- /dev/null +++ b/yacl/utils/spi/sketch/BUILD.bazel @@ -0,0 +1,32 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//bazel:yacl.bzl", "yacl_cc_library") + +package(default_visibility = ["//visibility:public"]) + +yacl_cc_library( + name = "sketch", + deps = [ + ":sketch_macros", + ], +) + +yacl_cc_library( + name = "sketch_macros", + hdrs = [ + "scalar_call.h", + "scalar_define.h", + ], +) diff --git a/yacl/utils/spi/sketch/scalar_call.h b/yacl/utils/spi/sketch/scalar_call.h new file mode 100644 index 00000000..3102e81a --- /dev/null +++ b/yacl/utils/spi/sketch/scalar_call.h @@ -0,0 +1,121 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +// Who are using those macros? +// - galois field SPI (YACL) +// - HE SPI (HEU) + +// Call: +// virtual T FuncName(const T& x) const = 0; +// virtual T FuncName(const T& x, ...) const = 0; +#define CallUnaryFunc(FuncName, T, x, ...) \ + do { \ + using RES_T = decltype(FuncName(std::declval(), ##__VA_ARGS__)); \ + \ + if (x.IsArray()) { \ + auto xsp = x.AsSpan(); \ + std::vector res; \ + res.resize(xsp.length()); \ + yacl::parallel_for(0, xsp.length(), [&](int64_t beg, int64_t end) { \ + for (int64_t i = beg; i < end; ++i) { \ + res[i] = FuncName(xsp[i], ##__VA_ARGS__); \ + } \ + }); \ + return Item::Take(std::move(res)); \ + } else { \ + return FuncName(x.As(), ##__VA_ARGS__); \ + } \ + } while (0) + +// Call: +// virtual void FuncName(T* x) const = 0; +// virtual void FuncName(T* x, ...) const = 0; +#define CallUnaryInplaceFunc(FuncName, T, x, ...) \ + do { \ + if (x->IsArray()) { \ + auto xsp = x->AsSpan(); \ + yacl::parallel_for(0, xsp.length(), [&](int64_t beg, int64_t end) { \ + for (int64_t i = beg; i < end; ++i) { \ + FuncName(&xsp[i], ##__VA_ARGS__); \ + } \ + }); \ + } else { \ + FuncName(x->As(), ##__VA_ARGS__); \ + } \ + } while (0) + +// Call: +// virtual T FuncName(const TX& x, const TY& y) const = 0; +#define CallBinaryFunc(FuncName, TX, TY) \ + do { \ + using RES_T = decltype(FuncName(std::declval(), \ + std::declval())); \ + \ + switch (x, y) { \ + case yacl::OperandType::Scalar2Scalar: { \ + return FuncName(x.As(), y.As()); \ + } \ + case yacl::OperandType::Vector2Vector: { \ + auto xsp = x.AsSpan(); \ + auto ysp = y.AsSpan(); \ + YACL_ENFORCE_EQ( \ + xsp.length(), ysp.length(), \ + "operands must have the same length, x.len={}, y.len={}", \ + xsp.length(), ysp.length()); \ + \ + std::vector res; \ + res.resize(xsp.length()); \ + yacl::parallel_for(0, xsp.length(), [&](int64_t beg, int64_t end) { \ + for (int64_t i = beg; i < end; ++i) { \ + res[i] = FuncName(xsp[i], ysp[i]); \ + } \ + }); \ + return Item::Take(std::move(res)); \ + } \ + default: \ + YACL_THROW("Scalar sketch method [{}] doesn't support broadcast now", \ + #FuncName); \ + } \ + } while (0) + +// Call: +// virtual void FuncName(TX* x, const TY& y) const = 0; +#define CallBinaryInplaceFunc(FuncName, TX, TY) \ + do { \ + switch (*x, y) { \ + case yacl::OperandType::Scalar2Scalar: { \ + FuncName(x->As(), y.As()); \ + return; \ + } \ + case yacl::OperandType::Vector2Vector: { \ + auto xsp = x->AsSpan(); \ + auto ysp = y.AsSpan(); \ + YACL_ENFORCE_EQ( \ + xsp.length(), ysp.length(), \ + "operands must have the same length, x.len={}, y.len={}", \ + xsp.length(), ysp.length()); \ + yacl::parallel_for(0, xsp.length(), [&](int64_t beg, int64_t end) { \ + for (int64_t i = beg; i < end; ++i) { \ + FuncName(&xsp[i], ysp[i]); \ + } \ + }); \ + return; \ + } \ + default: \ + YACL_THROW("Scalar sketch method [{}] doesn't support broadcast now", \ + #FuncName); \ + } \ + } while (0) diff --git a/yacl/utils/spi/sketch/scalar_define.h b/yacl/utils/spi/sketch/scalar_define.h new file mode 100644 index 00000000..7d73a6a4 --- /dev/null +++ b/yacl/utils/spi/sketch/scalar_define.h @@ -0,0 +1,77 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yacl/utils/spi/sketch/scalar_call.h" + +// From: +// virtual Item FuncName(const Item& x) const = 0; +// To: +// virtual bool FuncName(const T& x) const = 0; +#define DefineBoolUnaryFunc(FuncName) \ + Item FuncName(const Item& x) const override { \ + if (x.IsArray()) { \ + auto xsp = x.AsSpan(); \ + /* std::vector cannot write in parallel */ \ + std::vector res; \ + res.resize(xsp.length()); \ + yacl::parallel_for(0, xsp.length(), [&](int64_t beg, int64_t end) { \ + for (int64_t i = beg; i < end; ++i) { \ + res[i] = FuncName(xsp[i]); \ + } \ + }); \ + /* convert std::vector to std::vector */ \ + std::vector bv; \ + bv.resize(res.size()); \ + std::copy(res.begin(), res.end(), bv.begin()); \ + return Item::Take(std::move(bv)); \ + } else { \ + return FuncName(x.As()); \ + } \ + } + +// From: +// virtual Item FuncName(const Item& x) const = 0; +// To: +// virtual T FuncName(const T& x) const = 0; +#define DefineUnaryFunc(FuncName) \ + Item FuncName(const Item& x) const override { CallUnaryFunc(FuncName, T, x); } + +// From: +// virtual void FuncName(Item* x) const = 0; +// To: +// virtual void FuncName(T* x) const = 0; +#define DefineUnaryInplaceFunc(FuncName) \ + void FuncName(Item* x) const override { \ + CallUnaryInplaceFunc(FuncName, T, x); \ + } + +// From: +// virtual Item FuncName(const Item& x, const Item& y) const = 0; +// To: +// virtual T FuncName(const T& x, const T& y) const = 0; +#define DefineBinaryFunc(FuncName) \ + Item FuncName(const Item& x, const Item& y) const override { \ + CallBinaryFunc(FuncName, T, T); \ + } + +// From: +// virtual void FuncName(Item* x, const Item& y) const = 0; +// To: +// virtual void FuncName(T* x, const T& y) const = 0; +#define DefineBinaryInplaceFunc(FuncName) \ + void FuncName(Item* x, const Item& y) const override { \ + CallBinaryInplaceFunc(FuncName, T, T); \ + }