Skip to content

Commit

Permalink
repo-sync-2024-03-26T13:27:15+0800
Browse files Browse the repository at this point in the history
  • Loading branch information
anakinxc committed Mar 26, 2024
1 parent be4c6c4 commit 743d105
Show file tree
Hide file tree
Showing 12 changed files with 151 additions and 43 deletions.
3 changes: 2 additions & 1 deletion examples/python/ml/ml_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@
import pandas as pd

import spu.utils.distributed as ppd
import spu.utils.distributed_impl as ppd_impl
from spu.utils.polyfill import Process

with open("examples/python/conf/3pc.json", 'r') as file:
conf = json.load(file)

logger = logging.getLogger(ppd.__name__)
logger = logging.getLogger(ppd_impl.__name__)
logger.setLevel(level=logging.WARN)

_test_perf_table = pd.DataFrame({'name': [], 'duration': []})
Expand Down
5 changes: 4 additions & 1 deletion libspu/compiler/passes/hlo_legalize_to_pphlo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1265,7 +1265,10 @@ class HloToPPHloOpConverter<stablehlo::CustomCallOp>
op, result_types, materializeInputs(op, adaptor.getOperands()),
op.getCallTargetName(), op.getHasSideEffect());

new_op->setAttr("mhlo.attributes", op->getAttr("mhlo.attributes"));
auto attr = op->getAttr("mhlo.attributes");
if (attr) {
new_op->setAttr("mhlo.attributes", attr);
}

return success();
}
Expand Down
7 changes: 6 additions & 1 deletion libspu/core/ndarray_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,12 @@ class NdArrayRef {
// create a compact clone.
NdArrayRef clone() const;

bool isCompact() const { return strides_ == makeCompactStrides(shape_); }
bool isCompact() const {
if (numel() < 2) {
return true;
}
return strides_ == makeCompactStrides(shape_);
}

// Test only
bool canUseFastIndexing() const { return use_fast_indexing_; }
Expand Down
11 changes: 11 additions & 0 deletions libspu/core/pt_buffer_view.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,17 @@

namespace spu {

namespace detail {

bool isCompact(const Strides& stride, const Shape& shape) {
if (shape.numel() < 2) {
return true;
}
return stride == makeCompactStrides(shape);
}

} // namespace detail

std::ostream& operator<<(std::ostream& out, PtBufferView v) {
out << fmt::format("PtBufferView<{},{}x{},{}>", v.ptr,
fmt::join(v.shape, "x"), v.pt_type,
Expand Down
8 changes: 5 additions & 3 deletions libspu/core/pt_buffer_view.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ constexpr bool
decltype(std::declval<T>().strides())>> =
true;

bool isCompact(const Strides& stride, const Shape& shape);

} // namespace detail

// A view of a plaintext buffer.
Expand All @@ -62,7 +64,7 @@ struct PtBufferView {
shape(std::move(in_shape)),
strides(std::move(in_strides)),
write_able(!std::is_const_v<std::remove_pointer_t<Pointer>>),
compacted(strides == makeCompactStrides(shape)),
compacted(detail::isCompact(strides, shape)),
is_bitset(is_bitset) {
static_assert(std::is_pointer_v<Pointer>);
if (is_bitset) {
Expand Down Expand Up @@ -103,7 +105,7 @@ struct PtBufferView {
pt_type(PtTypeToEnum<typename T::value_type>::value),
shape(t.shape().begin(), t.shape().end()),
strides(t.strides().begin(), t.strides().end()),
compacted(strides == makeCompactStrides(shape)) {}
compacted(detail::isCompact(strides, shape)) {}

template <typename T,
std::enable_if_t<detail::is_tensor_like_v<T>, bool> = true>
Expand All @@ -113,7 +115,7 @@ struct PtBufferView {
shape(t.shape().begin(), t.shape().end()),
strides(t.strides().begin(), t.strides().end()),
write_able(true),
compacted(strides == makeCompactStrides(shape)) {}
compacted(detail::isCompact(strides, shape)) {}

template <typename S = uint8_t>
const S& get(const Index& indices) const {
Expand Down
7 changes: 7 additions & 0 deletions libspu/core/pt_buffer_view_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ TEST(PtBufferView, Scalar) {
EXPECT_TRUE(bv_i1.strides.empty());
}

TEST(PtBufferView, Compact) {
int64_t i = 1;
PtBufferView view(&i, PT_I64, {1}, {1});

EXPECT_TRUE(view.isCompact());
}

TEST(PtBufferView, Vector) {
std::vector<int32_t> raw_i32(10, 0);
PtBufferView bv_i32(raw_i32);
Expand Down
9 changes: 8 additions & 1 deletion libspu/core/trace.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ struct ActionRecord final {
// the communication bytes information.
size_t send_bytes_start;
size_t send_bytes_end;
size_t recv_bytes_start;
size_t recv_bytes_end;
};

class ProfState final {
Expand Down Expand Up @@ -234,6 +236,8 @@ class TraceAction final {
// the action communication information.
size_t send_bytes_start_;
size_t send_bytes_end_;
size_t recv_bytes_start_;
size_t recv_bytes_end_;

int64_t saved_tracer_flag_;

Expand All @@ -242,6 +246,7 @@ class TraceAction final {
start_ = std::chrono::high_resolution_clock::now();
if (lctx_) {
send_bytes_start_ = lctx_->GetStats()->sent_bytes.load();
recv_bytes_start_ = lctx_->GetStats()->recv_bytes.load();
}
const auto flag = flag_ & tracer_->getFlag();
if ((flag & TR_LOGB) != 0) {
Expand All @@ -263,6 +268,7 @@ class TraceAction final {
end_ = std::chrono::high_resolution_clock::now();
if (lctx_) {
send_bytes_end_ = lctx_->GetStats()->sent_bytes.load();
recv_bytes_end_ = lctx_->GetStats()->recv_bytes.load();
}
const auto flag = flag_ & tracer_->getFlag();
if ((flag & TR_LOGE) != 0) {
Expand All @@ -272,7 +278,8 @@ class TraceAction final {
if ((flag & TR_REC) != 0 && (flag & TR_MODALL) != 0) {
tracer_->getProfState()->addRecord(
ActionRecord{id_, name_, std::move(detail_), flag_, start_, end_,
send_bytes_start_, send_bytes_end_});
send_bytes_start_, send_bytes_end_, recv_bytes_start_,
recv_bytes_end_});
}
}

Expand Down
36 changes: 27 additions & 9 deletions libspu/device/api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ struct ExecutionStats {

struct CommunicationStats {
size_t send_bytes = 0;
size_t recv_bytes = 0;
size_t send_actions = 0;

void reset(const std::shared_ptr<yacl::link::Context> &lctx) {
Expand All @@ -75,13 +76,15 @@ struct CommunicationStats {
}
send_actions = lctx->GetStats()->sent_actions;
send_bytes = lctx->GetStats()->sent_bytes;
recv_bytes = lctx->GetStats()->recv_bytes;
}

void diff(const std::shared_ptr<yacl::link::Context> &lctx) {
if (!lctx) {
return;
}
send_bytes = lctx->GetStats()->sent_bytes - send_bytes;
recv_bytes = lctx->GetStats()->recv_bytes - recv_bytes;
send_actions = lctx->GetStats()->sent_actions - send_actions;
}
};
Expand All @@ -101,6 +104,8 @@ struct ActionStats {
Duration total_time = {};
// total send bytes.
size_t send_bytes = 0;
// total recv bytes.
size_t recv_bytes = 0;

inline double getTotalTimeInSecond() const {
return std::chrono::duration_cast<std::chrono::duration<double>>(total_time)
Expand Down Expand Up @@ -175,27 +180,40 @@ void printProfilingData(spu::SPUContext *sctx, const std::string &name,
stat.total_time +=
std::chrono::duration_cast<Duration>(rec.end - rec.start);
stat.send_bytes += (rec.send_bytes_end - rec.send_bytes_start);
stat.recv_bytes += (rec.recv_bytes_end - rec.recv_bytes_start);
}

static std::map<int64_t, std::string> kModules = {
{TR_HLO, "HLO"}, {TR_HAL, "HAL"}, {TR_MPC, "MPC"}};

for (const auto &[mod_flag, mod_name] : kModules) {
if ((tracer->getFlag() & mod_flag) == 0) {
continue;
}

double total_time = 0.0;
std::vector<ActionKey> sorted_by_time;
for (const auto &[key, stat] : stats) {
if ((key.flag & mod_flag) != 0) {
total_time += stat.getTotalTimeInSecond();
sorted_by_time.emplace_back(key);
}
}
if ((tracer->getFlag() & mod_flag) != 0) {
SPDLOG_INFO("{} profiling: total time {}", mod_name, total_time);
for (const auto &[key, stat] : stats) {
if ((key.flag & mod_flag) != 0) {
SPDLOG_INFO("- {}, executed {} times, duration {}s, send bytes {}",
key.name, stat.count, stat.getTotalTimeInSecond(),
stat.send_bytes);
}
}

std::sort(sorted_by_time.begin(), sorted_by_time.end(),
[&](const auto &k0, const auto &k1) {
return stats.find(k0)->second.getTotalTimeInSecond() >
stats.find(k1)->second.getTotalTimeInSecond();
});

SPDLOG_INFO("{} profiling: total time {}", mod_name, total_time);
for (const auto &key : sorted_by_time) {
const auto &stat = stats.find(key)->second;
SPDLOG_INFO(
"- {}, executed {} times, duration {}s, send bytes {} recv "
"bytes {}",
key.name, stat.count, stat.getTotalTimeInSecond(), stat.send_bytes,
stat.recv_bytes);
}
}
}
Expand Down
40 changes: 40 additions & 0 deletions libspu/dialect/pphlo/ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -981,6 +981,46 @@ void CustomCallOp::getEffects(
effects.emplace_back(MemoryEffects::Read::get());
}

class MarkValueOnlyTopK : public OpRewritePattern<CustomCallOp> {
public:
using OpRewritePattern<CustomCallOp>::OpRewritePattern;

LogicalResult matchAndRewrite(pphlo::CustomCallOp op,
PatternRewriter& rewriter) const override {
if (op.getCallTargetName() != "mhlo.topk" || op->getNumResults() != 2) {
return failure();
}

auto indices = op.getResult(1);
if (!indices.use_empty()) {
return failure();
}

auto attr = op->getAttr("mhlo.attributes").dyn_cast<mlir::DictionaryAttr>();

auto new_op = rewriter.create<CustomCallOp>(
op->getLoc(), TypeRange{op->getResultTypes()[0]}, op->getOperands(),
op.getCallTargetName());

auto new_attr = DictionaryAttr::get(
op->getContext(),
{NamedAttribute(rewriter.getStringAttr("k"), attr.get("k")),
NamedAttribute(rewriter.getStringAttr("largest"), attr.get("largest")),
NamedAttribute(rewriter.getStringAttr("value_only"),
rewriter.getBoolAttr(true))});
new_op->setAttr("mhlo.attributes", new_attr);

rewriter.replaceAllUsesWith(op->getResult(0), new_op->getResult(0));

return success();
}
};

void CustomCallOp::getCanonicalizationPatterns(RewritePatternSet& results,
MLIRContext* context) {
results.add<MarkValueOnlyTopK>(context);
}

} // namespace mlir::spu::pphlo

#define GET_OP_CLASSES
Expand Down
1 change: 1 addition & 0 deletions libspu/dialect/pphlo/ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -1064,6 +1064,7 @@ def PPHLO_CustomCallOp: PPHLO_Op<"custom_call",
custom<CustomCallTarget>($call_target_name) `(` $inputs `)`
attr-dict `:` functional-type(operands, results)
}];
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
32 changes: 22 additions & 10 deletions libspu/mpc/ab_api_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -282,31 +282,43 @@ TEST_P(ArithmeticTest, MatMulAV) {
const auto factory = std::get<0>(GetParam());
const RuntimeConfig& conf = std::get<1>(GetParam());
const size_t npc = std::get<2>(GetParam());

const int64_t M = 3;
const int64_t K = 4;
const int64_t N = 3;
const Shape shape_A = {M, K};
const Shape shape_B = {K, N};
const Shape shape_C = {M, N};

utils::simulate(npc, [&](const std::shared_ptr<yacl::link::Context>& lctx) {
auto obj = factory(conf, lctx);
if (not obj->hasKernel("mmul_av")) {
return;
}

/* GIVEN */
auto p0 = rand_p(obj.get(), shape_A);
auto r = rand_p(obj.get(), shape_A);
auto b0 = p2b(obj.get(), r);
auto a0 = p2a(obj.get(), r);
auto p1 = rand_p(obj.get(), shape_B);
auto a0 = p2a(obj.get(), p0);
auto v1 = p2v(obj.get(), p1, 0);

/* WHEN */
auto prev = obj->prot()->getState<Communicator>()->getStats();
auto _tmp = mmul_av(obj.get(), a0, v1);
if (!_tmp.has_value()) {
return;
}
auto tmp = _tmp.value();
// mmul_sv -> a * v
auto tmp1 = mmul_sv(obj.get(), a0, v1);
// mmul_sv -> b * v -> a * v
auto tmp0 = mmul_sv(obj.get(), b0, v1);
auto cost = obj->prot()->getState<Communicator>()->getStats() - prev;
auto r_aa = a2p(obj.get(), tmp);
auto r_pp = mmul_pp(obj.get(), p0, p1);

auto r0_aa = a2p(obj.get(), tmp0);
auto r1_aa = a2p(obj.get(), tmp1);

auto r_pp = mmul_pp(obj.get(), b2p(obj.get(), b0), p1);

/* THEN */
EXPECT_VALUE_EQ(r_aa, r_pp);
EXPECT_VALUE_EQ(r0_aa, r_pp);
EXPECT_VALUE_EQ(r1_aa, r_pp);
ce::Params params = {{"K", SizeOf(conf.field()) * 8},
{"N", npc},
{"m", M},
Expand Down
Loading

0 comments on commit 743d105

Please sign in to comment.