Skip to content

Commit

Permalink
repo-sync-2023-08-25T22:24:06+0800
Browse files Browse the repository at this point in the history
  • Loading branch information
anakinxc committed Aug 25, 2023
1 parent cac1c0b commit 342c47c
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 18 deletions.
2 changes: 1 addition & 1 deletion libspu/core/pt_buffer_view.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ struct PtBufferView {
void const* const ptr; // Pointer to the underlying storage
PtType const pt_type; // Plaintext data type.
Shape const shape; // Shape of the tensor.
Strides const strides; // Strides in byte.
Strides const strides; // Strides in number of elements.

// We have to take a concrete buffer as a view.
PtBufferView() = delete;
Expand Down
87 changes: 70 additions & 17 deletions libspu/device/pphlo/pphlo_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,38 +58,49 @@ std::string mlirObjectToString(T &&mlir_obj) {
return buf;
}

spu::PtType getPtTypeFromMlirType(mlir::Type mlir_ty) {
std::pair<spu::PtType, bool> getPtTypeFromMlirType(mlir::Type mlir_ty) {
mlir::pphlo::TypeTools tool;
auto express_type = tool.getExpressedType(mlir_ty);

if (auto ft = express_type.dyn_cast<mlir::FloatType>()) {
switch (ft.getWidth()) {
case 16:
return spu::PT_F16;
return {spu::PT_F16, false};
case 32:
return spu::PT_F32;
return {spu::PT_F32, false};
case 64:
return spu::PT_F64;
return {spu::PT_F64, false};
}
} else if (auto it = express_type.dyn_cast<mlir::IntegerType>()) {
if (it.getWidth() == 1) {
return spu::PT_BOOL;
return {spu::PT_BOOL, false};
}
// In mlir, isSigned is for si[1-9][0-9]* type, isUnsigned is for
// ui[1-9][0-9]*, i[1-9][0-9]* is signless IntegerType... So here, we only
// check for isUnsigned, signless we treat it as signed.
// See https://reviews.llvm.org/D72533
switch (it.getWidth()) {
case 8:
return it.isUnsigned() ? spu::PT_U8 : spu::PT_I8;
return it.isUnsigned() ? std::make_pair(spu::PT_U8, false)
: std::make_pair(spu::PT_I8, false);
case 16:
return it.isUnsigned() ? spu::PT_U16 : spu::PT_I16;
return it.isUnsigned() ? std::make_pair(spu::PT_U16, false)
: std::make_pair(spu::PT_I16, false);
case 32:
return it.isUnsigned() ? spu::PT_U32 : spu::PT_I32;
return it.isUnsigned() ? std::make_pair(spu::PT_U32, false)
: std::make_pair(spu::PT_I32, false);
case 64:
return it.isUnsigned() ? spu::PT_U64 : spu::PT_I64;
return it.isUnsigned() ? std::make_pair(spu::PT_U64, false)
: std::make_pair(spu::PT_I64, false);
}
} else if (auto ct = express_type.dyn_cast<mlir::ComplexType>()) {
if (ct.getElementType().isF32()) {
return {spu::PT_F32, true};
} else if (ct.getElementType().isF64()) {
return {spu::PT_F64, true};
}
}

SPU_THROW("invalid type {}", mlirObjectToString(mlir_ty));
}

Expand Down Expand Up @@ -122,6 +133,12 @@ spu::DataType getDtypeFromMlirType(mlir::Type mlir_ty) {
default:
SPU_THROW("unsupported fp type {}", mlirObjectToString(flp_ty));
}
} else if (auto ct = express_type.dyn_cast<mlir::ComplexType>()) {
if (ct.getElementType().isF32()) {
return spu::DT_F32;
} else if (ct.getElementType().isF64()) {
return spu::DT_F64;
}
}
SPU_THROW("invalid type {}", mlirObjectToString(mlir_ty));
}
Expand Down Expand Up @@ -177,6 +194,9 @@ void do_type_checker(mlir::Value key, const spu::Value &val,
auto expectedType = getDtypeFromMlirType(mlir_type);
SPU_ENFORCE(expectedType == val.dtype(), "Expected mlir_type {}, got {}",
expectedType, val.dtype());
if (mlir_type.isa<mlir::ComplexType>()) {
SPU_ENFORCE(val.imag().has_value(), "Expected complex type");
}

// Check vtype
if (tool.isMPCType<mlir::pphlo::PublicType>(mlir_type)) {
Expand Down Expand Up @@ -686,13 +706,21 @@ void execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope,
auto ret_el_type = type_tools.getExpressedType(ret_type);
auto pt_type = getPtTypeFromMlirType(ret_el_type);

spu::Value iota_ret = kernel::hlo::Iota(sctx, getEncodeType(pt_type), numel);
spu::Value iota_ret =
kernel::hlo::Iota(sctx, getEncodeType(pt_type.first), numel);

if (ret_type.getShape().size() > 1) {
// Need a broadcast
iota_ret = kernel::hlo::Broadcast(sctx, iota_ret, ret_type.getShape(), {});
}

if (pt_type.second) {
// Complex
auto zeros = kernel::hlo::Constant(sctx, 0.0F, ret_type.getShape());
zeros = kernel::hlo::Cast(sctx, zeros, iota_ret.vtype(), iota_ret.dtype());
iota_ret = kernel::hlo::Complex(sctx, iota_ret, zeros);
}

addValue(sscope, op.getOutput(), std::move(iota_ret), opts);
}

Expand Down Expand Up @@ -1029,6 +1057,7 @@ void execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope,
// https://github.com/llvm/llvm-project/blob/3696941dae5cc5bb379c50eae6190e29f7edbbb1/mlir/include/mlir/IR/BuiltinAttributes.h#L188
// We need to normalize the value to 0,1
if (dea.getElementType().isInteger(1)) {
SPU_ENFORCE(pt_type.second == false);
if (dea.isSplat()) {
addValue(
sscope, op.getResult(),
Expand All @@ -1039,19 +1068,43 @@ void execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope,
for (const auto &v : llvm::enumerate(dea.getValues<bool>())) {
buf[v.index()] = static_cast<uint8_t>(v.value());
}
PtBufferView view(reinterpret_cast<const bool *>(buf.data()), pt_type,
dst_shape, makeCompactStrides(dst_shape));
PtBufferView view(reinterpret_cast<const bool *>(buf.data()),
pt_type.first, dst_shape,
makeCompactStrides(dst_shape));

addValue(sscope, op.getResult(),
kernel::hlo::Constant(sctx, view, dst_shape), opts);
}
} else {
PtBufferView view(
dea.getRawData().data(), pt_type, dea.isSplat() ? Shape() : dst_shape,
dea.isSplat() ? Strides() : makeCompactStrides(dst_shape));
if (!pt_type.second) {
// Real numbers
PtBufferView view(
dea.getRawData().data(), pt_type.first,
dea.isSplat() ? Shape() : dst_shape,
dea.isSplat() ? Strides() : makeCompactStrides(dst_shape));

addValue(sscope, op.getResult(),
kernel::hlo::Constant(sctx, view, dst_shape), opts);
addValue(sscope, op.getResult(),
kernel::hlo::Constant(sctx, view, dst_shape), opts);
} else {
// Complex constant
// real view
auto cs = makeCompactStrides(dst_shape);
if (!cs.empty()) {
cs.back() *= 2;
}
PtBufferView real_view(dea.getRawData().data(), pt_type.first,
dea.isSplat() ? Shape() : dst_shape,
dea.isSplat() ? Strides() : cs);
PtBufferView imag_view(dea.getRawData().data() + SizeOf(pt_type.first),
pt_type.first, dea.isSplat() ? Shape() : dst_shape,
dea.isSplat() ? Strides() : cs);

auto real = kernel::hlo::Constant(sctx, real_view, dst_shape);
auto imag = kernel::hlo::Constant(sctx, imag_view, dst_shape);

addValue(sscope, op.getResult(), kernel::hlo::Complex(sctx, real, imag),
opts);
}
}
}

Expand Down
30 changes: 30 additions & 0 deletions libspu/device/pphlo/pphlo_executor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,22 @@ func.func @main() -> (tensor<2x!pphlo.pub<i32>>) {
r.verifyOutput(expected.data());
}

TEST_P(ExecutorTest, ComplexConstant) {
Runner r(std::get<0>(GetParam()), std::get<1>(GetParam()),
std::get<2>(GetParam()));

r.getConfig().set_enable_type_checker(false);

r.run(R"(
func.func @main() -> (tensor<2x!pphlo.pub<complex<f32>>>) {
%0 = "pphlo.constant"() {value = dense<[(1.5, 2.5), (3.5, 4.5)]> : tensor<2xcomplex<f32>>} : () -> tensor<2x!pphlo.pub<complex<f32>>>
return %0 : tensor<2x!pphlo.pub<complex<f32>>>
})");

std::vector<std::complex<float>> expected = {{1.5, 2.5}, {3.5, 4.5}};
r.verifyOutput(expected.data());
}

TEST_P(ExecutorTest, InvalidIR) {
Runner r(std::get<0>(GetParam()), std::get<1>(GetParam()),
std::get<2>(GetParam()));
Expand Down Expand Up @@ -765,6 +781,20 @@ func.func @main() -> (tensor<4x2x!pphlo.pub<i32>>) {
r.verifyOutput(expect.data());
}

TEST_P(ExecutorTest, IotaComplex) {
Runner r(std::get<0>(GetParam()), std::get<1>(GetParam()),
std::get<2>(GetParam()));

r.run(R"(
func.func @main() -> (tensor<4x!pphlo.pub<complex<f32>>>) {
%0 = "pphlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4x!pphlo.pub<complex<f32>>>
return %0 : tensor<4x!pphlo.pub<complex<f32>>>
})");

std::vector<std::complex<float>> expect = {{0, 0}, {1, 0}, {2, 0}, {3, 0}};
r.verifyOutput(expect.data());
}

TEST_P(ExecutorTest, SimpleBitcast) {
GTEST_SKIP();

Expand Down
11 changes: 11 additions & 0 deletions libspu/device/pphlo/pphlo_executor_test_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@

namespace spu::device::pphlo::test {

template <typename T>
struct is_complex_t : public std::false_type {};

template <typename T>
struct is_complex_t<std::complex<T>> : public std::true_type {};

class Runner {
public:
Runner(size_t world_size, FieldType field, ProtocolKind protocol);
Expand Down Expand Up @@ -51,6 +57,11 @@ class Runner {
for (size_t i = 0; i < numel; ++i) {
if constexpr (std::is_integral_v<T>) {
EXPECT_EQ(_out[i], expected[i]) << "i = " << i << "\n";
} else if constexpr (is_complex_t<T>::value) {
EXPECT_TRUE(std::abs(_out[i].real() - expected[i].real()) <= 1e-2 &&
std::abs(_out[i].imag() - expected[i].imag()) <= 1e-2)
<< "i = " << i << " in = " << _out[i]
<< " expected = " << expected[i] << "\n";
} else {
EXPECT_TRUE(std::abs(_out[i] - expected[i]) <= 1e-2)
<< "i = " << i << " in = " << _out[i]
Expand Down

0 comments on commit 342c47c

Please sign in to comment.