Skip to content

Commit

Permalink
repo-sync-2024-07-29T15:33:18+0800
Browse files Browse the repository at this point in the history
  • Loading branch information
anakinxc committed Jul 29, 2024
1 parent 54ec5f1 commit 94c9bd5
Show file tree
Hide file tree
Showing 21 changed files with 46 additions and 109 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
## staging
>
> please add your unreleased change here.
- [Feature] Add more send/recv actions profiling

## 20240716
Expand Down
16 changes: 8 additions & 8 deletions bazel/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,21 @@ def _yacl():
http_archive,
name = "yacl",
urls = [
"https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b3.tar.gz",
"https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b3_nightly_20240722.tar.gz",
],
strip_prefix = "yacl-0.4.5b3",
sha256 = "bd89d63312e5e83eff5e001e2cf2135baff321c4b72a309f7d00cc53ce02e1a1",
strip_prefix = "yacl-0.4.5b3_nightly_20240722",
sha256 = "ccca599e6ded6089c5afbb87c8f5e09383195af256caacd50089f0c7443e8604",
)

def _libpsi():
maybe(
http_archive,
name = "psi",
urls = [
"https://github.com/secretflow/psi/archive/refs/tags/v0.4.0beta.tar.gz",
"https://github.com/secretflow/psi/archive/refs/tags/v0.4.1.dev240722.tar.gz",
],
strip_prefix = "psi-0.4.0beta",
sha256 = "c2fbf486a66eca9d3ec1725a81d93a7c6e80a9206ef1c9263a1608e0bef95e1a",
strip_prefix = "psi-0.4.1.dev240722",
sha256 = "878cd8af2c7b9850944a27adf91f21dd4937d09d38e8365baad3b5165db8b39a",
)

def _rules_proto_grpc():
Expand Down Expand Up @@ -136,8 +136,8 @@ def _bazel_skylib():
)

def _com_github_openxla_xla():
OPENXLA_COMMIT = "8533a6869ae02fb3b15a8a12739a982fc3c9f6e7"
OPENXLA_SHA256 = "d5b076825c992f59542f6b94e5480c7e7c6c627cd18c80ec60b6d5b295c160d4"
OPENXLA_COMMIT = "04f2bfe797408c9efe742b89e2e4db6cf526ebb7"
OPENXLA_SHA256 = "7e1d24737815be7607eed5f02fe7f81d97ffe358dfb7b4876f97bce8f48b3b3e"

# We need openxla to handle xla/mhlo/stablehlo
maybe(
Expand Down
1 change: 0 additions & 1 deletion libspu/compiler/tests/interpret/and.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT

Expand Down
8 changes: 5 additions & 3 deletions libspu/compiler/tests/interpret/generate_mlir_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,11 @@ def TestCase(inputs, expected, checker='expect_eq', tol=None):
f.write(
"// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s\n"
)
f.write(
"// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s\n"
)
# FIXME: these tests are not stable for cheetah now
if test not in ["xor", "or", "and"]:
f.write(
"// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s\n"
)
# Some test values in max and min are not supported by protocol 5.
if test not in ["max", "min"]:
f.write(
Expand Down
1 change: 0 additions & 1 deletion libspu/compiler/tests/interpret/or.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT

Expand Down
1 change: 0 additions & 1 deletion libspu/compiler/tests/interpret/xor.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT

Expand Down
4 changes: 2 additions & 2 deletions libspu/compiler/tests/passes/optimizations/ops_negative.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func.func @main() -> tensor<i32> {

func.func @main() -> tensor<i32> {
%0 = pphlo.constant dense<[0.000000e+00, -3.40282347E+38]> : tensor<2xf32>
// expected-error @+1 {{op broadcast_dimensions contains invalid value -6 for result with rank 1}}
// expected-error @+1 {{broadcast_dimensions contains invalid value -6 for result with rank 1}}
%1 = pphlo.broadcast %0, dims = [-6] : (tensor<2xf32>) -> tensor<2xf32>
%2 = pphlo.constant dense<5> : tensor<i32>
pphlo.return %2 : tensor<i32>
Expand All @@ -33,7 +33,7 @@ func.func @main() -> tensor<i32> {
// -----

func.func @main() -> tensor<i32> {
// expected-error @+1 {{op iota dimension cannot go beyond the output rank or be negative}}
// expected-error @+1 {{iota dimension cannot go beyond the output rank}}
%0 = pphlo.iota dim = 1000 : tensor<1xi32>
%1 = pphlo.constant dense<5> : tensor<i32>
pphlo.return %1 : tensor<i32>
Expand Down
1 change: 1 addition & 0 deletions libspu/compiler/utils/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#pragma once

#include "llvm/ADT/Twine.h"
#include "mlir/Support/LogicalResult.h"

namespace mlir::spu {
Expand Down
80 changes: 4 additions & 76 deletions libspu/dialect/pphlo/IR/ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,12 @@
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"
#include "stablehlo/dialect/TypeInference.h"

#include "libspu/dialect/pphlo/IR/ops.h.inc"

namespace mlir::spu::pphlo {

namespace {

// Checks if the vector `nums` has duplicates.
bool hasDuplicates(const ArrayRef<int64_t> nums) {
llvm::SmallDenseSet<int64_t> set(nums.begin(), nums.end());
return set.size() != nums.size();
}

} // namespace

template <typename T>
static LogicalResult Verify(T /*op*/) {
return success();
Expand Down Expand Up @@ -386,75 +377,12 @@ LogicalResult ConcatenateOp::verify() {
}

LogicalResult BroadcastOp::verify() {
auto operandType = mlir::dyn_cast<RankedTensorType>(getOperand().getType());

auto operandRank = operandType.getRank();

if (getBroadcastDimensions().empty()) {
if (operandRank == 0) {
return success();
}
return emitOpError(
llvm::formatv("broadcast_dimensions is absent, but required because "
"operand has non-zero rank ({0})",
operandRank));
}

auto dimensionsSize = getBroadcastDimensions().size();
if (static_cast<int64_t>(dimensionsSize) != operandRank) {
return emitOpError(llvm::formatv(
"broadcast_dimensions size ({0}) does not match operand rank ({1})",
dimensionsSize, operandRank));
}

auto dimensions = getBroadcastDimensions();
if (hasDuplicates(dimensions)) {
return emitOpError("broadcast_dimensions should not have duplicates");
}

auto resultType = mlir::dyn_cast<RankedTensorType>(getResult().getType());
auto resultRank = resultType.getRank();

for (size_t i = 0; i != dimensionsSize; ++i) {
auto dimIndex = dimensions[i];
if ((dimIndex >= resultRank) || (dimIndex < 0)) {
return emitOpError(
llvm::formatv("broadcast_dimensions contains invalid value {0} for "
"result with rank {1}",
dimIndex, resultRank));
}

if (!operandType.isDynamicDim(i)) {
auto dimSize = operandType.getDimSize(i);
auto resultDimSize = resultType.getDimSize(dimIndex);
if (dimSize != 1 && dimSize != resultDimSize) {
return emitOpError(
llvm::formatv("size of operand dimension {0} ({1}) is not equal to "
"1 or size of result dimension {2} ({3})",
i, dimSize, dimIndex, resultDimSize));
}
}
}

return success();
return hlo::verifyBroadcastInDimOp(getLoc(), getOperand(),
getBroadcastDimensions(), getResult());
}

LogicalResult IotaOp::verify() {
auto shape = mlir::dyn_cast<ShapedType>(getType());
if (!shape.hasRank()) {
return success();
}

if (shape.getRank() == 0) {
return emitOpError() << "does not support scalars.";
}

auto iotaDimension = static_cast<int64_t>(this->getIotaDimension());
if (iotaDimension >= shape.getRank() || iotaDimension < 0) {
return emitOpError()
<< "iota dimension cannot go beyond the output rank or be negative.";
}
return success();
return hlo::verifyIotaOp(getLoc(), getIotaDimension(), getResult());
}

LogicalResult SliceOp::verify() {
Expand Down
10 changes: 10 additions & 0 deletions libspu/dialect/pphlo/IR/print_parse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,19 @@ ParseResult ConstantOp::parse(OpAsmParser& parser, OperationState& result) {
if (parser.parseRParen()) {
return failure();
}
// Parse optional properties
if (succeeded(parser.parseOptionalLess()) &&
(failed(parser.parseAttribute(result.propertiesAttr)) ||
failed(parser.parseGreater()))) {
return failure();
}

// Parse optional attributes
if (parser.parseOptionalAttrDict(result.attributes)) {
return failure();
}

// Parse type signature
if (parser.parseColon() || parser.parseLParen() || parser.parseRParen() ||
parser.parseArrow()) {
return failure();
Expand Down
5 changes: 2 additions & 3 deletions libspu/dialect/pphlo/transforms/inline_secret_control_flow.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,8 @@ class CaseConverter : public OpRewritePattern<CaseOp> {
if (target_type.getNumElements() == in_type.getNumElements()) {
return rewriter.create<ReshapeOp>(loc, broadcasted_mask_type, in);
} else {
return rewriter.create<BroadcastOp>(
loc, broadcasted_mask_type, in,
llvm::SmallVector<int64_t>(target_type.getRank(), 0));
return rewriter.create<BroadcastOp>(loc, broadcasted_mask_type, in,
llvm::SmallVector<int64_t>{0});
}
}

Expand Down
6 changes: 3 additions & 3 deletions libspu/mpc/aby3/boolean.cc
Original file line number Diff line number Diff line change
Expand Up @@ -348,9 +348,9 @@ NdArrayRef LShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in,

// TODO: the hal dtype should tell us about the max number of possible bits.
const auto field = ctx->getState<Z2kState>()->getDefaultField();
const size_t out_nbits =
std::min<size_t>(in_ty->nbits() + *std::max_element(bits.begin(), bits.end()),
SizeOf(field) * 8);
const size_t out_nbits = std::min<size_t>(
in_ty->nbits() + *std::max_element(bits.begin(), bits.end()),
SizeOf(field) * 8);
const PtType out_btype = calcBShareBacktype(out_nbits);
bool is_splat = bits.size() == 1;

Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/cheetah/ot/yacl/yacl_ote_adapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
#include "yacl/kernel/algorithms/base_ot.h"
#include "yacl/kernel/algorithms/ferret_ote.h"
#include "yacl/kernel/algorithms/iknp_ote.h"
#include "yacl/kernel/algorithms/ot_store.h"
#include "yacl/kernel/algorithms/softspoken_ote.h"
#include "yacl/kernel/type/ot_store.h"

#include "libspu/core/prelude.h"
#include "libspu/mpc/cheetah/ot/ot_util.h"
Expand Down
4 changes: 2 additions & 2 deletions libspu/mpc/securenn/boolean.cc
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,8 @@ NdArrayRef LShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in,

int64_t out_nbits = in.eltype().as<BShare>()->nbits() +
*std::max_element(shift.begin(), shift.end());
out_nbits =
std::clamp<int64_t>(out_nbits, 0L, static_cast<int64_t>(SizeOf(field) * 8));
out_nbits = std::clamp<int64_t>(out_nbits, 0L,
static_cast<int64_t>(SizeOf(field) * 8));

return makeBShare(ring_lshift(in, shift), field, out_nbits);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,7 @@ int main(int argc, char* argv[]) {
std::string key;
SPU_ENFORCE(
butil::Base64Decode(ttp_server_config::FLAGS_server_private_key, &key));
decode_private_key =
yacl::Buffer(decode_private_key.data(), decode_private_key.size());
decode_private_key = yacl::Buffer(key.data(), key.size());
}

spu::mpc::semi2k::beaver::ttp_server::ServerOptions ops{
Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/spdz2k/beaver/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ spu_cc_library(
"//libspu/mpc/spdz2k/ot:tiny_ot",
"//libspu/mpc/utils:ring_ops",
"@yacl//yacl/crypto/tools:prg",
"@yacl//yacl/kernel/algorithms:ot_store",
"@yacl//yacl/kernel/type:ot_store",
"@yacl//yacl/link",
"@yacl//yacl/utils:matrix_utils",
"@yacl//yacl/utils:serialize",
Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/spdz2k/beaver/beaver_tinyot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
#include "yacl/crypto/rand/rand.h"
#include "yacl/crypto/tools/prg.h"
#include "yacl/kernel/algorithms/base_ot.h"
#include "yacl/kernel/algorithms/ot_store.h"
#include "yacl/kernel/type/ot_store.h"
#include "yacl/utils/serialize.h"

#include "libspu/mpc/common/prg_tensor.h"
Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/spdz2k/beaver/beaver_tinyot.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

#pragma once

#include "yacl/kernel/algorithms/ot_store.h"
#include "yacl/kernel/type/ot_store.h"
#include "yacl/link/context.h"

#include "libspu/mpc/common/prg_state.h"
Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/spdz2k/ot/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ spu_cc_library(
"//libspu/mpc/utils:ring_ops",
"@com_github_emptoolkit_emp_tool//:emp-tool",
"@yacl//yacl/crypto/tools:prg",
"@yacl//yacl/kernel/algorithms:ot_store",
"@yacl//yacl/kernel/type:ot_store",
"@yacl//yacl/link",
],
)
2 changes: 1 addition & 1 deletion libspu/mpc/spdz2k/ot/kos_ote.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#pragma once
#include "absl/types/span.h"
#include "yacl/base/dynamic_bitset.h"
#include "yacl/kernel/algorithms/ot_store.h"
#include "yacl/kernel/type/ot_store.h"
#include "yacl/link/link.h"
namespace spu::mpc::spdz2k {

Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/spdz2k/ot/tiny_ot.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.
#include <vector>

#include "yacl/kernel/algorithms/ot_store.h"
#include "yacl/kernel/type/ot_store.h"

#include "libspu/mpc/common/communicator.h"

Expand Down

0 comments on commit 94c9bd5

Please sign in to comment.