Skip to content

Commit

Permalink
Transpose batch
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Oct 12, 2024
1 parent 7e38527 commit 8b25dcb
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,40 @@ struct SHLOConstantOpBatchInterface
}
};

struct SHLOTransposeOpBatchInterface
: public BatchOpInterface::ExternalModel<SHLOTransposeOpBatchInterface,
TransposeOp> {

mlir::Operation *createBatch(Operation *src, IRMapping &mapper,
Operation::CloneOptions options,
std::map<Operation *, Operation *> &opMap,
ArrayRef<int64_t> batchSizes) const {
SmallVector<Type> resultTypes(src->getResultTypes().begin(),
src->getResultTypes().end());
for (auto &Ty : resultTypes) {
auto T = cast<TensorType>(Ty);
SmallVector<int64_t> shape(batchSizes.begin(), batchSizes.end());
shape.append(T.getShape().begin(), T.getShape().end());
Ty = T.clone(shape);
}
mlir::NamedAttrList attrs;
for (auto attr : src->getAttrs()) {
auto eattr = cast<DenseI64ArrayAttr>(attr.getValue());
SmallVector<int64_t> shape;
for (size_t i = 0; i < batchSizes.size(); i++)
shape.push_back(i);
for (auto val : eattr.asArrayRef())
shape.push_back(val + batchSizes.size());
attr.setValue(DenseI64ArrayAttr::get(src->getContext(), shape));
attrs.append(attr);
}
auto cop = mlir::Operation::create(
src->getLoc(), src->getName(), resultTypes, {}, std::move(attrs),
OpaqueProperties(nullptr), mlir::BlockRange(), 0);
return cop;
}
};

struct ADDataFlowSortOp
: public ADDataFlowOpInterface::ExternalModel<ADDataFlowSortOp, SortOp> {

Expand Down Expand Up @@ -981,5 +1015,6 @@ void mlir::enzyme::registerStableHLODialectAutoDiffInterface(
ReduceOp::attachInterface<AutoDiffReduceRev>(*context);
ConcatenateOp::attachInterface<AutoDiffConcatenateRev>(*context);
ConstantOp::attachInterface<SHLOConstantOpBatchInterface>(*context);
TransposeOp::attachInterface<SHLOTransposeOpBatchInterface>(*context);
});
}
15 changes: 15 additions & 0 deletions test/lit_tests/batchtests/transpose.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: enzymexlamlir-opt %s --enzyme-batch | FileCheck %s

func.func private @relu_broadcast_scalar(%arg0: tensor<3x4xf64>) -> (tensor<4x3xf64>) {
%1 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<3x4xf64>) -> tensor<4x3xf64>
return %1 : tensor<4x3xf64>
}
func.func @main(%arg0: tensor<2x5x3x4xf64>) -> (tensor<2x5x4x3xf64>) {
%1 = enzyme.batch @relu_broadcast_scalar(%arg0) {batch_shape = array<i64: 2, 5>} : (tensor<2x5x3x4xf64>) -> (tensor<2x5x4x3xf64>)
return %1 : tensor<2x5x4x3xf64>
}

// CHECK: func.func private @batched_relu_broadcast_scalar(%arg0: tensor<2x5x3x4xf64>) -> tensor<2x5x4x3xf64> {
// CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [0, 1, 3, 2] : (tensor<2x5x3x4xf64>) -> tensor<2x5x4x3xf64>
// CHECK-NEXT: return %0 : tensor<2x5x4x3xf64>
// CHECK-NEXT: }

0 comments on commit 8b25dcb

Please sign in to comment.