diff --git a/examples/python/ml/ml_test.py b/examples/python/ml/ml_test.py index 6d70f875..58f94234 100644 --- a/examples/python/ml/ml_test.py +++ b/examples/python/ml/ml_test.py @@ -25,12 +25,13 @@ import pandas as pd import spu.utils.distributed as ppd +import spu.utils.distributed_impl as ppd_impl from spu.utils.polyfill import Process with open("examples/python/conf/3pc.json", 'r') as file: conf = json.load(file) -logger = logging.getLogger(ppd.__name__) +logger = logging.getLogger(ppd_impl.__name__) logger.setLevel(level=logging.WARN) _test_perf_table = pd.DataFrame({'name': [], 'duration': []}) diff --git a/libspu/compiler/passes/hlo_legalize_to_pphlo.cc b/libspu/compiler/passes/hlo_legalize_to_pphlo.cc index bd0d982a..d44c1e6a 100644 --- a/libspu/compiler/passes/hlo_legalize_to_pphlo.cc +++ b/libspu/compiler/passes/hlo_legalize_to_pphlo.cc @@ -1265,7 +1265,10 @@ class HloToPPHloOpConverter op, result_types, materializeInputs(op, adaptor.getOperands()), op.getCallTargetName(), op.getHasSideEffect()); - new_op->setAttr("mhlo.attributes", op->getAttr("mhlo.attributes")); + auto attr = op->getAttr("mhlo.attributes"); + if (attr) { + new_op->setAttr("mhlo.attributes", attr); + } return success(); } diff --git a/libspu/core/ndarray_ref.h b/libspu/core/ndarray_ref.h index b21d585d..208c5f6c 100644 --- a/libspu/core/ndarray_ref.h +++ b/libspu/core/ndarray_ref.h @@ -124,7 +124,12 @@ class NdArrayRef { // create a compact clone. NdArrayRef clone() const; - bool isCompact() const { return strides_ == makeCompactStrides(shape_); } + bool isCompact() const { + if (numel() < 2) { + return true; + } + return strides_ == makeCompactStrides(shape_); + } // Test only bool canUseFastIndexing() const { return use_fast_indexing_; } diff --git a/libspu/core/pt_buffer_view.cc b/libspu/core/pt_buffer_view.cc index 9be15040..50e6f891 100644 --- a/libspu/core/pt_buffer_view.cc +++ b/libspu/core/pt_buffer_view.cc @@ -19,6 +19,17 @@ namespace spu { +namespace detail { + +bool isCompact(const Strides& stride, const Shape& shape) { + if (shape.numel() < 2) { + return true; + } + return stride == makeCompactStrides(shape); +} + +} // namespace detail + std::ostream& operator<<(std::ostream& out, PtBufferView v) { out << fmt::format("PtBufferView<{},{}x{},{}>", v.ptr, fmt::join(v.shape, "x"), v.pt_type, diff --git a/libspu/core/pt_buffer_view.h b/libspu/core/pt_buffer_view.h index a76f4f01..cc5df06e 100644 --- a/libspu/core/pt_buffer_view.h +++ b/libspu/core/pt_buffer_view.h @@ -38,6 +38,8 @@ constexpr bool decltype(std::declval().strides())>> = true; +bool isCompact(const Strides& stride, const Shape& shape); + } // namespace detail // A view of a plaintext buffer. @@ -62,7 +64,7 @@ struct PtBufferView { shape(std::move(in_shape)), strides(std::move(in_strides)), write_able(!std::is_const_v>), - compacted(strides == makeCompactStrides(shape)), + compacted(detail::isCompact(strides, shape)), is_bitset(is_bitset) { static_assert(std::is_pointer_v); if (is_bitset) { @@ -103,7 +105,7 @@ struct PtBufferView { pt_type(PtTypeToEnum::value), shape(t.shape().begin(), t.shape().end()), strides(t.strides().begin(), t.strides().end()), - compacted(strides == makeCompactStrides(shape)) {} + compacted(detail::isCompact(strides, shape)) {} template , bool> = true> @@ -113,7 +115,7 @@ struct PtBufferView { shape(t.shape().begin(), t.shape().end()), strides(t.strides().begin(), t.strides().end()), write_able(true), - compacted(strides == makeCompactStrides(shape)) {} + compacted(detail::isCompact(strides, shape)) {} template const S& get(const Index& indices) const { diff --git a/libspu/core/pt_buffer_view_test.cc b/libspu/core/pt_buffer_view_test.cc index 84ff8934..d70c9ce4 100644 --- a/libspu/core/pt_buffer_view_test.cc +++ b/libspu/core/pt_buffer_view_test.cc @@ -48,6 +48,13 @@ TEST(PtBufferView, Scalar) { EXPECT_TRUE(bv_i1.strides.empty()); } +TEST(PtBufferView, Compact) { + int64_t i = 1; + PtBufferView view(&i, PT_I64, {1}, {1}); + + EXPECT_TRUE(view.isCompact()); +} + TEST(PtBufferView, Vector) { std::vector raw_i32(10, 0); PtBufferView bv_i32(raw_i32); diff --git a/libspu/core/trace.h b/libspu/core/trace.h index c62c0d3c..c67dbff0 100644 --- a/libspu/core/trace.h +++ b/libspu/core/trace.h @@ -152,6 +152,8 @@ struct ActionRecord final { // the communication bytes information. size_t send_bytes_start; size_t send_bytes_end; + size_t recv_bytes_start; + size_t recv_bytes_end; }; class ProfState final { @@ -234,6 +236,8 @@ class TraceAction final { // the action communication information. size_t send_bytes_start_; size_t send_bytes_end_; + size_t recv_bytes_start_; + size_t recv_bytes_end_; int64_t saved_tracer_flag_; @@ -242,6 +246,7 @@ class TraceAction final { start_ = std::chrono::high_resolution_clock::now(); if (lctx_) { send_bytes_start_ = lctx_->GetStats()->sent_bytes.load(); + recv_bytes_start_ = lctx_->GetStats()->recv_bytes.load(); } const auto flag = flag_ & tracer_->getFlag(); if ((flag & TR_LOGB) != 0) { @@ -263,6 +268,7 @@ class TraceAction final { end_ = std::chrono::high_resolution_clock::now(); if (lctx_) { send_bytes_end_ = lctx_->GetStats()->sent_bytes.load(); + recv_bytes_end_ = lctx_->GetStats()->recv_bytes.load(); } const auto flag = flag_ & tracer_->getFlag(); if ((flag & TR_LOGE) != 0) { @@ -272,7 +278,8 @@ class TraceAction final { if ((flag & TR_REC) != 0 && (flag & TR_MODALL) != 0) { tracer_->getProfState()->addRecord( ActionRecord{id_, name_, std::move(detail_), flag_, start_, end_, - send_bytes_start_, send_bytes_end_}); + send_bytes_start_, send_bytes_end_, recv_bytes_start_, + recv_bytes_end_}); } } diff --git a/libspu/device/api.cc b/libspu/device/api.cc index 1d4ba690..8a2fd7cf 100644 --- a/libspu/device/api.cc +++ b/libspu/device/api.cc @@ -67,6 +67,7 @@ struct ExecutionStats { struct CommunicationStats { size_t send_bytes = 0; + size_t recv_bytes = 0; size_t send_actions = 0; void reset(const std::shared_ptr &lctx) { @@ -75,6 +76,7 @@ struct CommunicationStats { } send_actions = lctx->GetStats()->sent_actions; send_bytes = lctx->GetStats()->sent_bytes; + recv_bytes = lctx->GetStats()->recv_bytes; } void diff(const std::shared_ptr &lctx) { @@ -82,6 +84,7 @@ struct CommunicationStats { return; } send_bytes = lctx->GetStats()->sent_bytes - send_bytes; + recv_bytes = lctx->GetStats()->recv_bytes - recv_bytes; send_actions = lctx->GetStats()->sent_actions - send_actions; } }; @@ -101,6 +104,8 @@ struct ActionStats { Duration total_time = {}; // total send bytes. size_t send_bytes = 0; + // total recv bytes. + size_t recv_bytes = 0; inline double getTotalTimeInSecond() const { return std::chrono::duration_cast>(total_time) @@ -175,27 +180,40 @@ void printProfilingData(spu::SPUContext *sctx, const std::string &name, stat.total_time += std::chrono::duration_cast(rec.end - rec.start); stat.send_bytes += (rec.send_bytes_end - rec.send_bytes_start); + stat.recv_bytes += (rec.recv_bytes_end - rec.recv_bytes_start); } static std::map kModules = { {TR_HLO, "HLO"}, {TR_HAL, "HAL"}, {TR_MPC, "MPC"}}; for (const auto &[mod_flag, mod_name] : kModules) { + if ((tracer->getFlag() & mod_flag) == 0) { + continue; + } + double total_time = 0.0; + std::vector sorted_by_time; for (const auto &[key, stat] : stats) { if ((key.flag & mod_flag) != 0) { total_time += stat.getTotalTimeInSecond(); + sorted_by_time.emplace_back(key); } } - if ((tracer->getFlag() & mod_flag) != 0) { - SPDLOG_INFO("{} profiling: total time {}", mod_name, total_time); - for (const auto &[key, stat] : stats) { - if ((key.flag & mod_flag) != 0) { - SPDLOG_INFO("- {}, executed {} times, duration {}s, send bytes {}", - key.name, stat.count, stat.getTotalTimeInSecond(), - stat.send_bytes); - } - } + + std::sort(sorted_by_time.begin(), sorted_by_time.end(), + [&](const auto &k0, const auto &k1) { + return stats.find(k0)->second.getTotalTimeInSecond() > + stats.find(k1)->second.getTotalTimeInSecond(); + }); + + SPDLOG_INFO("{} profiling: total time {}", mod_name, total_time); + for (const auto &key : sorted_by_time) { + const auto &stat = stats.find(key)->second; + SPDLOG_INFO( + "- {}, executed {} times, duration {}s, send bytes {} recv " + "bytes {}", + key.name, stat.count, stat.getTotalTimeInSecond(), stat.send_bytes, + stat.recv_bytes); } } } diff --git a/libspu/dialect/pphlo/ops.cc b/libspu/dialect/pphlo/ops.cc index c2c61369..5af62daa 100644 --- a/libspu/dialect/pphlo/ops.cc +++ b/libspu/dialect/pphlo/ops.cc @@ -981,6 +981,46 @@ void CustomCallOp::getEffects( effects.emplace_back(MemoryEffects::Read::get()); } +class MarkValueOnlyTopK : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(pphlo::CustomCallOp op, + PatternRewriter& rewriter) const override { + if (op.getCallTargetName() != "mhlo.topk" || op->getNumResults() != 2) { + return failure(); + } + + auto indices = op.getResult(1); + if (!indices.use_empty()) { + return failure(); + } + + auto attr = op->getAttr("mhlo.attributes").dyn_cast(); + + auto new_op = rewriter.create( + op->getLoc(), TypeRange{op->getResultTypes()[0]}, op->getOperands(), + op.getCallTargetName()); + + auto new_attr = DictionaryAttr::get( + op->getContext(), + {NamedAttribute(rewriter.getStringAttr("k"), attr.get("k")), + NamedAttribute(rewriter.getStringAttr("largest"), attr.get("largest")), + NamedAttribute(rewriter.getStringAttr("value_only"), + rewriter.getBoolAttr(true))}); + new_op->setAttr("mhlo.attributes", new_attr); + + rewriter.replaceAllUsesWith(op->getResult(0), new_op->getResult(0)); + + return success(); + } +}; + +void CustomCallOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { + results.add(context); +} + } // namespace mlir::spu::pphlo #define GET_OP_CLASSES diff --git a/libspu/dialect/pphlo/ops.td b/libspu/dialect/pphlo/ops.td index 5d6f9c7a..58ba702a 100644 --- a/libspu/dialect/pphlo/ops.td +++ b/libspu/dialect/pphlo/ops.td @@ -1064,6 +1064,7 @@ def PPHLO_CustomCallOp: PPHLO_Op<"custom_call", custom($call_target_name) `(` $inputs `)` attr-dict `:` functional-type(operands, results) }]; + let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// diff --git a/libspu/mpc/ab_api_test.cc b/libspu/mpc/ab_api_test.cc index d0d9aa27..c52e4664 100644 --- a/libspu/mpc/ab_api_test.cc +++ b/libspu/mpc/ab_api_test.cc @@ -282,31 +282,43 @@ TEST_P(ArithmeticTest, MatMulAV) { const auto factory = std::get<0>(GetParam()); const RuntimeConfig& conf = std::get<1>(GetParam()); const size_t npc = std::get<2>(GetParam()); + const int64_t M = 3; const int64_t K = 4; const int64_t N = 3; const Shape shape_A = {M, K}; const Shape shape_B = {K, N}; const Shape shape_C = {M, N}; + utils::simulate(npc, [&](const std::shared_ptr& lctx) { auto obj = factory(conf, lctx); + if (not obj->hasKernel("mmul_av")) { + return; + } + /* GIVEN */ - auto p0 = rand_p(obj.get(), shape_A); + auto r = rand_p(obj.get(), shape_A); + auto b0 = p2b(obj.get(), r); + auto a0 = p2a(obj.get(), r); auto p1 = rand_p(obj.get(), shape_B); - auto a0 = p2a(obj.get(), p0); auto v1 = p2v(obj.get(), p1, 0); + /* WHEN */ auto prev = obj->prot()->getState()->getStats(); - auto _tmp = mmul_av(obj.get(), a0, v1); - if (!_tmp.has_value()) { - return; - } - auto tmp = _tmp.value(); + // mmul_sv -> a * v + auto tmp1 = mmul_sv(obj.get(), a0, v1); + // mmul_sv -> b * v -> a * v + auto tmp0 = mmul_sv(obj.get(), b0, v1); auto cost = obj->prot()->getState()->getStats() - prev; - auto r_aa = a2p(obj.get(), tmp); - auto r_pp = mmul_pp(obj.get(), p0, p1); + + auto r0_aa = a2p(obj.get(), tmp0); + auto r1_aa = a2p(obj.get(), tmp1); + + auto r_pp = mmul_pp(obj.get(), b2p(obj.get(), b0), p1); + /* THEN */ - EXPECT_VALUE_EQ(r_aa, r_pp); + EXPECT_VALUE_EQ(r0_aa, r_pp); + EXPECT_VALUE_EQ(r1_aa, r_pp); ce::Params params = {{"K", SizeOf(conf.field()) * 8}, {"N", npc}, {"m", M}, diff --git a/libspu/mpc/api.cc b/libspu/mpc/api.cc index e3132b11..714b5a15 100644 --- a/libspu/mpc/api.cc +++ b/libspu/mpc/api.cc @@ -322,12 +322,12 @@ Value equal_pp(SPUContext* ctx, const Value& x, const Value& y) { OptionalAPI equal_sp(SPUContext* ctx, const Value& x, const Value& y) { SPU_TRACE_MPC_DISP(ctx, x, y); - TRY_DISPATCH(ctx, x, y); - if (IsA(x) && ctx->hasKernel("equal_ap")) { - return dynDispatch(ctx, "equal_ap", x, y); - } else if (IsB(x) && ctx->hasKernel("equal_bp")) { - return dynDispatch(ctx, "equal_bp", x, y); + if (IsA(x)) { + TRY_NAMED_DISPATCH(ctx, "equal_ap", x, y); + } + if (IsB(x)) { + TRY_NAMED_DISPATCH(ctx, "equal_bp", x, y); } return NotAvailable; @@ -335,21 +335,21 @@ OptionalAPI equal_sp(SPUContext* ctx, const Value& x, const Value& y) { OptionalAPI equal_ss(SPUContext* ctx, const Value& x, const Value& y) { SPU_TRACE_MPC_DISP(ctx, x, y); - TRY_DISPATCH(ctx, x, y); // try fast path // TODO: use cost model instead of hand-coded priority. - if (IsA(x) && IsA(y) && ctx->hasKernel("equal_aa")) { - return dynDispatch(ctx, "equal_aa", x, y); - } else if (IsB(x) && IsB(y) && ctx->hasKernel("equal_bb")) { - return dynDispatch(ctx, "equal_bb", x, y); + if (IsA(x) && IsA(y)) { + TRY_NAMED_DISPATCH(ctx, "equal_aa", x, y); + } else if (IsB(x) && IsB(y)) { + TRY_NAMED_DISPATCH(ctx, "equal_bb", x, y); } else if ((IsA(x) && IsB(y)) || (IsB(x) && IsA(y))) { // mixed a & b, both OK, hardcode to a. if (ctx->hasKernel("equal_aa")) { - return dynDispatch(ctx, "equal_aa", _2a(ctx, x), _2a(ctx, y)); + FORCE_NAMED_DISPATCH(ctx, "equal_aa", _2a(ctx, x), _2a(ctx, y)); } + if (ctx->hasKernel("equal_bb")) { - return dynDispatch(ctx, "equal_bb", _2b(ctx, x), _2b(ctx, y)); + FORCE_NAMED_DISPATCH(ctx, "equal_bb", _2b(ctx, x), _2b(ctx, y)); } } @@ -468,12 +468,13 @@ Value mmul_ss(SPUContext* ctx, const Value& x, const Value& y) { Value mmul_sv(SPUContext* ctx, const Value& x, const Value& y) { SPU_TRACE_MPC_DISP(ctx, x, y); - TRY_DISPATCH(ctx, x, y); - if (IsA(x)) { - if (auto res = mmul_av(ctx, x, y)) { - return res.value(); - } + + if (ctx->hasKernel("mmul_av")) { + // call a * v is available which is faster than calling a * a + FORCE_NAMED_DISPATCH(ctx, "mmul_av", _2a(ctx, x), y); } + + // b * a will finally call a * a return mmul_ss(ctx, x, v2s(ctx, y)); }