Skip to content

Commit

Permalink
repo-sync-2024-07-08T11:48:19+0800
Browse files Browse the repository at this point in the history
  • Loading branch information
anakinxc committed Jul 8, 2024
1 parent c03eccf commit 2bbc132
Show file tree
Hide file tree
Showing 21 changed files with 321 additions and 89 deletions.
12 changes: 6 additions & 6 deletions bazel/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ def _yacl():
http_archive,
name = "yacl",
urls = [
"https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b3.tar.gz",
"https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b1.tar.gz",
],
strip_prefix = "yacl-0.4.5b3",
sha256 = "bd89d63312e5e83eff5e001e2cf2135baff321c4b72a309f7d00cc53ce02e1a1",
strip_prefix = "yacl-0.4.5b1",
sha256 = "28064053b9add0db8e1e8e648421a0579f1d3e7ee8a4bbd7bd5959cb59598088",
)

def _libpsi():
Expand Down Expand Up @@ -169,10 +169,10 @@ def _com_github_pybind11():
http_archive,
name = "pybind11",
build_file = "@pybind11_bazel//:pybind11.BUILD",
sha256 = "bf8f242abd1abcd375d516a7067490fb71abd79519a282d22b6e4d19282185a7",
strip_prefix = "pybind11-2.12.0",
sha256 = "51631e88960a8856f9c497027f55c9f2f9115cafb08c0005439838a05ba17bfc",
strip_prefix = "pybind11-2.13.1",
urls = [
"https://github.com/pybind/pybind11/archive/refs/tags/v2.12.0.tar.gz",
"https://github.com/pybind/pybind11/archive/refs/tags/v2.13.1.tar.gz",
],
)

Expand Down
2 changes: 1 addition & 1 deletion docs/reference/np_op_status.json

Large diffs are not rendered by default.

14 changes: 14 additions & 0 deletions docs/reference/np_op_status.md
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,20 @@ Please check *Supported Dtypes* as well.
- uint16
- uint32

## bitwise_count

JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.bitwise_count.html
### Status

**PASS**
Please check *Supported Dtypes* as well.
### Supported Dtypes

- int16
- int32
- uint16
- uint32

## bitwise_not

JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.bitwise_not.html
Expand Down
6 changes: 3 additions & 3 deletions docs/reference/pphlo_doc.rst
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
PPHlo API reference
PPHLO API reference
===================

PPHlo is short for (SPU High level ops), it's the assembly language of SPU.
PPHLO is short for (Privacy-Preserving High-Level Operations), it's the assembly language of SPU.

PPHlo is built on `MLIR <https://mlir.llvm.org/>`_ infrastructure, the concrete ops definition could be found :spu_code_host:`here <spu/blob/main/libspu/dialect/pphlo/IR/ops.td>`.
PPHLO is built on `MLIR <https://mlir.llvm.org/>`_ infrastructure, the concrete ops definition could be found :spu_code_host:`here <spu/blob/main/libspu/dialect/pphlo/IR/ops.td>`.

Op List
~~~~~~~
Expand Down
58 changes: 38 additions & 20 deletions docs/reference/pphlo_op_doc.md
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,7 @@ Ref https://www.tensorflow.org/xla/operation_semantics#dot.

Traits: `AlwaysSpeculatableImplTrait`

Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`
Interfaces: `ConditionallySpeculatable`, `InferShapedTypeOpInterface`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)`

Effects: `MemoryEffects::Effect{}`

Expand Down Expand Up @@ -1626,55 +1626,63 @@ Effects: `MemoryEffects::Effect{}`
| :----: | ----------- |
| `result` | statically shaped tensor of 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values values

### `pphlo.power` (spu::pphlo::PowOp)
### `pphlo.popcnt` (spu::pphlo::PopcntOp)

_Power operator_
_Popcnt operator, ties away from zero_


Syntax:

```
operation ::= `pphlo.power` $lhs `,` $rhs attr-dict
`:` custom<SameOperandsAndResultType>(type($lhs), type($rhs), type($result))
operation ::= `pphlo.popcnt` $operand attr-dict `:` custom<SameOperandsAndResultType>(type($operand), type($result))
```

Performs element-wise exponentiation of `lhs` tensor by `rhs` tensor and produces a `result` tensor.
Performs element-wise count of the number of bits set in the `operand` tensor and produces a `result` tensor.

Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#power
Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#popcnt

Traits: `AlwaysSpeculatableImplTrait`, `Elementwise`, `SameOperandsAndResultShape`
Traits: `AlwaysSpeculatableImplTrait`, `Elementwise`, `SameOperandsAndResultShape`, `SameOperandsAndResultType`

Interfaces: `ConditionallySpeculatable`, `InferShapedTypeOpInterface`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)`
Interfaces: `ConditionallySpeculatable`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)`

Effects: `MemoryEffects::Effect{}`

#### Attributes:

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>bits</code></td><td>::mlir::IntegerAttr</td><td>64-bit signless integer attribute</td></tr>
</table>

#### Operands:

| Operand | Description |
| :-----: | ----------- |
| `lhs` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values
| `rhs` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values
| `operand` | statically shaped tensor of 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values values

#### Results:

| Result | Description |
| :----: | ----------- |
| `result` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values
| `result` | statically shaped tensor of 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values values

### `pphlo.prefer_a` (spu::pphlo::PreferAOp)
### `pphlo.power` (spu::pphlo::PowOp)

_Prefer AShare operator_
_Power operator_


Syntax:

```
operation ::= `pphlo.prefer_a` $operand attr-dict `:` custom<SameOperandsAndResultType>(type($operand), type($result))
operation ::= `pphlo.power` $lhs `,` $rhs attr-dict
`:` custom<SameOperandsAndResultType>(type($lhs), type($rhs), type($result))
```

Convert input to AShare if possible.
Performs element-wise exponentiation of `lhs` tensor by `rhs` tensor and produces a `result` tensor.

Traits: `AlwaysSpeculatableImplTrait`, `Elementwise`, `SameOperandsAndResultShape`, `SameOperandsAndResultType`
Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#power

Traits: `AlwaysSpeculatableImplTrait`, `Elementwise`, `SameOperandsAndResultShape`

Interfaces: `ConditionallySpeculatable`, `InferShapedTypeOpInterface`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)`

Expand All @@ -1684,7 +1692,8 @@ Effects: `MemoryEffects::Effect{}`

| Operand | Description |
| :-----: | ----------- |
| `operand` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values
| `lhs` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values
| `rhs` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values

#### Results:

Expand Down Expand Up @@ -2270,12 +2279,21 @@ Returns the sign of the `operand` element-wise and produces a `result` tensor.
Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sign
PPHLO Extension: when `ignore_zero` is set to true, sign does not enforce sign(0) to 0
Traits: `AlwaysSpeculatableImplTrait`, `Elementwise`, `SameOperandsAndResultShape`, `SameOperandsAndResultType`
Interfaces: `ConditionallySpeculatable`, `InferShapedTypeOpInterface`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)`
Effects: `MemoryEffects::Effect{}`
#### Attributes:
<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>ignore_zero</code></td><td>::mlir::BoolAttr</td><td>bool attribute</td></tr>
</table>
#### Operands:
| Operand | Description |
Expand Down Expand Up @@ -2377,7 +2395,7 @@ Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#slice
Traits: `AlwaysSpeculatableImplTrait`, `SameOperandsAndResultElementType`
Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`
Interfaces: `ConditionallySpeculatable`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)`
Effects: `MemoryEffects::Effect{}`
Expand Down Expand Up @@ -2551,7 +2569,7 @@ Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#transpose
Traits: `AlwaysSpeculatableImplTrait`, `SameOperandsAndResultElementType`
Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`
Interfaces: `ConditionallySpeculatable`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)`
Effects: `MemoryEffects::Effect{}`
Expand Down
3 changes: 2 additions & 1 deletion docs/reference/runtime_config.md
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,9 @@ The SPU runtime configuration.
| Field | Type | Description |
| ----- | ---- | ----------- |
| server_host | [ string](#string) | TrustedThirdParty beaver server's remote ip:port or load-balance uri. |
| session_id | [ string](#string) | if empty, use link id as session id. |
| adjust_rank | [ int32](#int32) | which rank do adjust rpc call, usually choose the rank closer to the server. |
| asym_crypto_schema | [ string](#string) | asym_crypto_schema: support ["SM2"] Will support 25519 in the future, after yacl supported it. |
| server_public_key | [ bytes](#bytes) | server's public key |
<!-- end Fields -->
<!-- end HasFields -->

Expand Down
2 changes: 1 addition & 1 deletion libspu/compiler/common/compilation_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace {

void SPUErrorHandler(void * /*use_data*/, const char *reason,
bool /*gen_crash_diag*/) {
SPU_THROW(reason);
SPU_THROW("{}", reason);
}

} // namespace
Expand Down
2 changes: 2 additions & 0 deletions libspu/compiler/common/ir_printer_config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ void IRPrinterConfig::printBeforeIfEnabled(Pass *pass, Operation *,
if (ec.value() != 0) {
spdlog::error("Open file {} failed, error = {}", file_name.c_str(),
ec.message());
return;
}
print_callback(f);
}
Expand All @@ -64,6 +65,7 @@ void IRPrinterConfig::printAfterIfEnabled(Pass *pass, Operation *,
if (ec.value() != 0) {
spdlog::error("Open file {} failed, error = {}", file_name.c_str(),
ec.message());
return;
}
print_callback(f);
}
Expand Down
2 changes: 2 additions & 0 deletions libspu/compiler/front_end/fe.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ mlir::OwningOpRef<mlir::ModuleOp> FE::doit(const CompilationSource &source) {
module = mlir::parseSourceString<mlir::ModuleOp>(source.ir_txt(),
ctx_->getMLIRContext());

SPU_ENFORCE(module, "MLIR parser failure");

// Convert stablehlo to mhlo first
mlir::PassManager pm(ctx_->getMLIRContext());
pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
Expand Down
6 changes: 3 additions & 3 deletions libspu/compiler/front_end/hlo_importer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,12 +196,12 @@ HloImporter::parseXlaModuleFromString(const std::string &content) {
auto module_config =
xla::HloModule::CreateModuleConfigFromProto(hlo_module, debug_options);
if (!module_config.status().ok()) {
SPU_THROW(module_config.status().message());
SPU_THROW("{}", module_config.status().message());
}

auto module = xla::HloModule::CreateFromProto(hlo_module, *module_config);
if (!module.status().ok()) {
SPU_THROW(module.status().message());
SPU_THROW("{}", module.status().message());
}

xla::runHloPasses((*module).get());
Expand All @@ -214,7 +214,7 @@ HloImporter::parseXlaModuleFromString(const std::string &content) {

auto status = importer.Import(**module);
if (!status.ok()) {
SPU_THROW(status.message());
SPU_THROW("{}", status.message());
}

return mlir_hlo;
Expand Down
2 changes: 1 addition & 1 deletion libspu/device/api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ void printProfilingData(spu::SPUContext *sctx, const std::string &name,
void SPUErrorHandler(void *use_data, const char *reason, bool gen_crash_diag) {
(void)use_data;
(void)gen_crash_diag;
SPU_THROW(reason);
SPU_THROW("{}", reason);
}

std::mutex ErrorHandlerMutex;
Expand Down
28 changes: 22 additions & 6 deletions libspu/kernel/hal/permute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "libspu/core/bit_utils.h"
#include "libspu/core/context.h"
#include "libspu/core/trace.h"
#include "libspu/core/vectorize.h"
#include "libspu/kernel/hal/constants.h"
#include "libspu/kernel/hal/polymorphic.h"
#include "libspu/kernel/hal/prot_wrapper.h"
Expand All @@ -43,6 +44,12 @@ inline bool _has_same_owner(const Value &x, const Value &y) {
return _get_owner(x) == _get_owner(y);
}

void _hint_nbits(const Value &a, size_t nbits) {
if (a.storage_type().isa<BShare>()) {
const_cast<Type &>(a.storage_type()).as<BShare>()->setNbits(nbits);
}
}

// generate inverse permutation
Index _inverse_index(const Index &p) {
Index q(p.size());
Expand Down Expand Up @@ -531,20 +538,29 @@ spu::Value _opt_apply_perm_ss(SPUContext *ctx, const spu::Value &perm,
std::vector<spu::Value> _bit_decompose(SPUContext *ctx, const spu::Value &x,
int64_t valid_bits) {
auto x_bshare = _prefer_b(ctx, x);
const auto k1 = _constant(ctx, 1U, x.shape());
std::vector<spu::Value> rets;
size_t nbits = valid_bits != -1
? static_cast<size_t>(valid_bits)
: x_bshare.storage_type().as<BShare>()->nbits();
rets.reserve(nbits);
_hint_nbits(x_bshare, nbits);
if (ctx->hasKernel("b2a_disassemble")) {
auto ret =
dynDispatch<std::vector<spu::Value>>(ctx, "b2a_disassemble", x_bshare);
return ret;
}

const auto k1 = _constant(ctx, 1U, x.shape());
std::vector<spu::Value> rets_b;
rets_b.reserve(nbits);

for (size_t bit = 0; bit < nbits; ++bit) {
auto x_bshare_shift = right_shift_logical(ctx, x_bshare, bit);
auto lowest_bit = _and(ctx, x_bshare_shift, k1);
rets.emplace_back(_prefer_a(ctx, lowest_bit));
rets_b.push_back(_and(ctx, x_bshare_shift, k1));
}

return rets;
std::vector<spu::Value> rets_a;
vmap(rets_b.begin(), rets_b.end(), std::back_inserter(rets_a),
[&](const Value &x) { return _prefer_a(ctx, x); });
return rets_a;
}

// Generate vector of bit decomposition of sorting keys
Expand Down
11 changes: 11 additions & 0 deletions libspu/mpc/kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,17 @@ void ConcateKernel::evaluate(KernelEvalContext* ctx) const {
ctx->pushOutput(WrapValue(z));
}

void DisassembleKernel::evaluate(KernelEvalContext* ctx) const {
const auto& in = ctx->getParam<Value>(0);
auto z = proc(ctx, UnwrapValue(in));

std::vector<Value> wrapped(z.size());
for (size_t idx = 0; idx < z.size(); ++idx) {
wrapped[idx] = WrapValue(z[idx]);
}
ctx->pushOutput(wrapped);
};

void OramOneHotKernel::evaluate(KernelEvalContext* ctx) const {
auto target = ctx->getParam<Value>(0);
auto s = ctx->getParam<int64_t>(1);
Expand Down
8 changes: 8 additions & 0 deletions libspu/mpc/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,4 +217,12 @@ class ConcateKernel : public Kernel {
int64_t axis) const = 0;
};

class DisassembleKernel : public Kernel {
public:
void evaluate(KernelEvalContext* ctx) const override;

virtual std::vector<NdArrayRef> proc(KernelEvalContext* ctx,
const NdArrayRef& in) const = 0;
};

} // namespace spu::mpc
Loading

0 comments on commit 2bbc132

Please sign in to comment.