diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index d6126647..37173af0 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -3011,6 +3011,21 @@ struct ConvertSimplify : public OpRewritePattern { struct SliceSimplify : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; + static int64_t getSizeInBytes(Type type) { + if (auto shapedType = dyn_cast(type)) + return shapedType.getNumElements() * + getSizeInBytes(shapedType.getElementType()); + + if (type.isIntOrFloat()) + return std::max(type.getIntOrFloatBitWidth(), (unsigned)8) / 8; + + if (auto complexType = dyn_cast(type)) + return getSizeInBytes(complexType.getElementType()) * 2; + + report_fatal_error( + invalidArgument("Unsupported type: %s", debugString(type).c_str())); + } + LogicalResult matchAndRewrite(mlir::stablehlo::SliceOp op, PatternRewriter &rewriter) const final { DenseElementsAttr inp; @@ -3020,12 +3035,41 @@ struct SliceSimplify : public OpRewritePattern { if (inp.isSplat()) { out = inp.resizeSplat(op.getType()); } else { + bool contiguous = true; + size_t offset = 0; + auto inshape = op.getOperand().getType().getShape(); + size_t total = 1; + for (int i = inshape.size() - 1; i >= 0; i--) { + if (op.getStrides()[i] != 1) { + contiguous = false; + } + auto start = op.getStartIndices()[i]; + auto lim = op.getLimitIndices()[i]; + if (start != 0 || lim != inshape[i]) { + if (offset != 0) { + contiguous = false; + } + } + offset *= inshape[i]; + offset += start; + total *= inshape[i]; + } + auto ten = mlir::stablehlo::constantOp(inp); - out = fromTensor(mlir::stablehlo::sliceOp( - ten, stablehlo::Sizes(op.getStartIndices()), - stablehlo::Sizes(op.getStrides()), op.getType())); - } + if (contiguous) { + auto elementType = op.getOperand().getType().getElementType(); + const char *elementPtr = + out.getRawData().data() + getSizeInBytes(elementType) * offset; + + auto values = ArrayRef((char *)elementPtr, total); + out = DenseIntOrFPElementsAttr::getFromRawBuffer(op.getType(), + floatValues); + } else + out = fromTensor(mlir::stablehlo::sliceOp( + ten, stablehlo::Sizes(op.getStartIndices()), + stablehlo::Sizes(op.getStrides()), op.getType())); + } rewriter.replaceOpWithNewOp(op, op.getType(), out); return success(); }