Skip to content

Commit

Permalink
repo-sync-2024-07-15T10:35:29+0800
Browse files Browse the repository at this point in the history
  • Loading branch information
anakinxc committed Jul 15, 2024
1 parent fcef2ff commit 7d6355d
Show file tree
Hide file tree
Showing 7 changed files with 169 additions and 0 deletions.
1 change: 1 addition & 0 deletions libspu/compiler/core/core.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ void Core::buildPipeline(mlir::PassManager *pm) {
}

optPM.addPass(mlir::createLoopInvariantCodeMotionPass());
optPM.addPass(mlir::spu::pphlo::createRegionAccessFixture());
optPM.addPass(mlir::createCSEPass());

if (!options.disable_deallocation_insertion()) {
Expand Down
11 changes: 11 additions & 0 deletions libspu/device/pphlo/pphlo_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,17 @@ void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope,
addValue(sscope, op.getOutput(), std::move(iota_ret), opts);
}

void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope,
mlir::spu::pphlo::BroadcastShapeAsOp &op,
const ExecutionOptions &opts) {
// Start indices
const auto &lhs = lookupValue(sscope, op.getLhs(), opts);
const auto &rhs = lookupValue(sscope, op.getRhs(), opts);

addValue(sscope, op.getResult(),
kernel::hlo::Broadcast(sctx, lhs, rhs.shape(), {}), opts);
}

void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope,
mlir::spu::pphlo::RemOp &op, const ExecutionOptions &opts) {
// FIXME: When hal has a remainder, use that
Expand Down
1 change: 1 addition & 0 deletions libspu/device/pphlo/pphlo_verifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ class PPHloVerifier {
NO_VERIFY_DEFN(ImagOp)
NO_VERIFY_DEFN(ComplexOp)
NO_VERIFY_DEFN(SimpleSortOp)
NO_VERIFY_DEFN(BroadcastShapeAsOp)

#undef NO_VERIFY_DEFN
};
Expand Down
6 changes: 6 additions & 0 deletions libspu/dialect/pphlo/IR/ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,12 @@ def PPHLO_BroadcastOp
}];
}


def PPHLO_BroadcastShapeAsOp: PPHLO_BinaryElementwiseOp<"broadcast_as", [Pure,
SameOperandsAndResultShape], PPHLO_Tensor> {
let summary = "BroadcastShapeAs operator";
}

def PPHLO_ClampOp
: PPHLO_Op<"clamp", [Pure, SameOperandsAndResultShape]> {
let summary = "Clamp operator";
Expand Down
3 changes: 3 additions & 0 deletions libspu/dialect/pphlo/transforms/passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ std::unique_ptr<OperationPass<func::FuncOp>> createInlineSecretControlFlow();
// Convert signbit pattern to SignOp
std::unique_ptr<OperationPass<func::FuncOp>> createRewriteSignbitPatterns();

// Fix region access shape mismatch
std::unique_ptr<OperationPass<func::FuncOp>> createRegionAccessFixture();

} // namespace spu::pphlo

} // namespace mlir
6 changes: 6 additions & 0 deletions libspu/dialect/pphlo/transforms/passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,9 @@ def InlineSecretControlFlow: Pass<"inline-secret-control-flow", "func::FuncOp">
let constructor = "createInlineSecretControlFlow()";
let dependentDialects = ["pphlo::PPHloDialect"];
}

def RegionAccessFixture: Pass<"region-access-fixture", "func::FuncOp"> {
let summary = "Fix region access mismatched shape";
let constructor = "createRegionAccessFixture()";
let dependentDialects = ["pphlo::PPHloDialect"];
}
141 changes: 141 additions & 0 deletions libspu/dialect/pphlo/transforms/region_access_fixture.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
// Copyright 2024 Ant Group Co., Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>

#include "mlir/Pass/Pass.h"

#include "libspu/dialect/pphlo/IR/ops.h"
#include "libspu/dialect/pphlo/transforms/pass_details.h"

namespace mlir::spu::pphlo {

namespace {

struct Deallocator {
public:
LogicalResult transformOp(Operation *op) {
for (auto &r : op->getRegions()) {
if (failed(transformRegion(r))) {
return failure();
}
}

const auto &operands = op->getOperands();

if (op->getNumOperands() < 2 ||
!op->hasTrait<::mlir::OpTrait::Elementwise>() ||
std::all_of(operands.begin(), operands.end(), [](const auto &operand) {
return operand.template getDefiningOp<ConstantOp>();
})) {
return success();
}

auto *op_region = op->getParentRegion();

Value base_val;
llvm::SmallVector<int64_t, 2> values_to_update;

OpBuilder builder(op->getContext());
builder.setInsertionPoint(op);

for (const auto &[idx, operand] : llvm::enumerate(operands)) {
// Get defining region
auto *defining_op = operand.getDefiningOp();

Region *defining_region = nullptr;

if (defining_op != nullptr) {
defining_region = defining_op->getParentRegion();
}

if (defining_op == nullptr || defining_region == op_region) {
// BlockArg or op defined in current region can be a base val
base_val = operand;
continue;
}

if (defining_region != op_region) {
// This op is accessing a variable out of op's region.
// Insert a broadcast as to fix runtime shape mismatch during simd
// region execution
values_to_update.emplace_back(idx);
}
}

if (!base_val) {
return values_to_update.empty()
? failure() // same region however failed to pick base value
: success(); // can't pick base value since multi-level
// nesting
}

for (const auto &idx : values_to_update) {
auto op_to_broadcast = op->getOperand(idx);
auto b = builder.create<BroadcastShapeAsOp>(
op->getLoc(), op_to_broadcast.getType(), op_to_broadcast, base_val);
op->setOperand(idx, b);
}

return success();
}

LogicalResult transformBlock(Block &block) {
for (auto &op : llvm::make_early_inc_range(block.without_terminator())) {
auto opResult = transformOp(&op);
if (failed(opResult)) {
return failure();
}
}
return success();
}

LogicalResult transformRegion(Region &r) {
for (auto &b : r.getBlocks()) {
if (failed(transformBlock(b))) {
return failure();
}
}
return success();
}

LogicalResult transformFuncOp(func::FuncOp op) {
if (op->getNumRegions() == 0) {
return success();
}

// Transform function body.
if (failed(transformRegion(op.getBody()))) {
return failure();
}

return success();
}
};

struct RegionAccessFixture
: public RegionAccessFixtureBase<RegionAccessFixture> {
void runOnOperation() override {
if (failed(Deallocator().transformFuncOp(getOperation()))) {
signalPassFailure();
}
}
};
} // namespace

std::unique_ptr<OperationPass<func::FuncOp>> createRegionAccessFixture() {
return std::make_unique<RegionAccessFixture>();
}

} // namespace mlir::spu::pphlo

0 comments on commit 7d6355d

Please sign in to comment.