Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
+ backport to gh changes from the temporary.patch
  • Loading branch information
abhigunj authored Dec 6, 2024
1 parent 74e5cf4 commit b3d3cac
Show file tree
Hide file tree
Showing 13 changed files with 158 additions and 14 deletions.
1 change: 1 addition & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,7 @@ cc_library(
":stablehlo_serialization",
":version",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:BytecodeReader",
"@llvm-project//mlir:CAPIIR",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
Expand Down
4 changes: 2 additions & 2 deletions WORKSPACE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ workspace(name = "stablehlo")

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

LLVM_COMMIT = "bd92e46204331b9af296f53abb708317e72ab7a8"
LLVM_COMMIT = "2ccf7ed277df28651b94bbee9fccefdf22fb074f"

LLVM_SHA256 = "60f71fc5b237e10729edbed8cbe23b7081dabe254fbcb1ea82db8789cb7eaecf"
LLVM_SHA256 = "ca68a54dcd12c0dde32732a90899bf57e0f3f96fc43d8d1124d95a5eae627508"

http_archive(
name = "llvm-raw",
Expand Down
2 changes: 1 addition & 1 deletion build_tools/llvm_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
bd92e46204331b9af296f53abb708317e72ab7a8
2ccf7ed277df28651b94bbee9fccefdf22fb074f
36 changes: 36 additions & 0 deletions stablehlo/dialect/Serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ limitations under the License.

#include "stablehlo/dialect/Serialization.h"

#include "llvm/Support/Debug.h"
#include "llvm/Support/MemoryBuffer.h"
#include "mlir/Bytecode/BytecodeReader.h"
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
Expand All @@ -29,6 +32,8 @@ limitations under the License.
#include "stablehlo/dialect/VhloOps.h"
#include "stablehlo/transforms/Passes.h"

#define DEBUG_TYPE "compat-passes"

namespace mlir {
namespace stablehlo {

Expand Down Expand Up @@ -89,5 +94,36 @@ OwningOpRef<ModuleOp> deserializePortableArtifact(StringRef sourceStr,
return module;
}

FailureOr<vhlo::Version> getPortableArtifactVersion(llvm::StringRef bytecode) {
auto logFailure = [&](llvm::StringRef message) {
LLVM_DEBUG(llvm::dbgs() << "Failed to get portable artifact version: "
<< message << "\n");
return failure();
};
// Must start with MLiRxStableHLO_vX.Y.Z, minimum length of 19.
constexpr size_t minHeaderLength = 19;
if (bytecode.size() < minHeaderLength) return logFailure("min header");

// Truncate to the end of the null-terminated producer string.
size_t pos = bytecode.find('\0');
if (pos == llvm::StringRef::npos) return logFailure("no terminator");
bytecode = bytecode.substr(0, pos);

// Check if the bytecode is valid, starts with MLiR magic number.
if (!isBytecode(
llvm::MemoryBuffer::getMemBuffer(bytecode)->getMemBufferRef()))
return logFailure("not bytecode");

// Skip 4 bytes for the magic number.
std::string stablehloHeader = "StableHLO_v";
size_t stablehloPos = bytecode.find(stablehloHeader);
if (stablehloPos == llvm::StringRef::npos)
return logFailure("not a StableHLO portable artifact");

// Skip the 11 bytes for StableHLO_v to get the StableHLO version to parse.
StringRef version = bytecode.substr(stablehloPos + stablehloHeader.size());
return vhlo::Version::fromString(version);
}

} // namespace stablehlo
} // namespace mlir
12 changes: 12 additions & 0 deletions stablehlo/dialect/Serialization.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Support/LogicalResult.h"
#include "stablehlo/dialect/Version.h"

namespace mlir {
namespace stablehlo {
Expand All @@ -43,6 +44,17 @@ LogicalResult serializePortableArtifact(ModuleOp module,
OwningOpRef<ModuleOp> deserializePortableArtifact(StringRef sourceStr,
MLIRContext* context);

// Get portable artifact version from the producer string after the MLIR
// Bytecode magic number `MLïRStableHLO_vX.Y.Z` -> X.Y.Z
// Returns failure if input string is not a valid portable artifact produced by
// serializePortableArtifact APIs, which would cause the bytecode artifact to
// not have the proper producer string.
//
// This method should be safe, since any changes to the bytecode format would
// warrant a bytecode version bump, and MLIR bytecode gives the option to
// specify a forward compatible bytecode version to target.
FailureOr<vhlo::Version> getPortableArtifactVersion(llvm::StringRef bytecode);

} // namespace stablehlo
} // namespace mlir

Expand Down
28 changes: 24 additions & 4 deletions stablehlo/dialect/TypeInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ LogicalResult checkDimsDistinct(std::optional<Location> loc,
LogicalResult checkDimInBounds(std::optional<Location> loc, int64_t dim,
int64_t upperBound, StringRef dimName,
StringRef upperBoundName,
bool upperBoundInclusive = false) {
bool upperBoundInclusive) {
StringRef rangeEnd = upperBoundInclusive ? "]" : ")";
if (dim < 0 || dim >= upperBound + (upperBoundInclusive ? 1 : 0))
return emitOptionalError(loc, "Expects ", dimName, " to be in range [0, ",
Expand Down Expand Up @@ -2280,14 +2280,13 @@ LogicalResult inferDotOp(
return success();
}

LogicalResult inferDotGeneralOp(
LogicalResult checkDotGeneralConstraints(
std::optional<Location> location, Type lhsType, Type rhsType,
ArrayRef<int64_t> lhsBatchingDimensions,
ArrayRef<int64_t> rhsBatchingDimensions,
ArrayRef<int64_t> lhsContractingDimensions,
ArrayRef<int64_t> rhsContractingDimensions,
std::optional<ArrayAttr> precisionConfig,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
std::optional<ArrayAttr> precisionConfig) {
// dot_general_c11
if (failed(verifyPrecisionConfig(location, precisionConfig)))
return failure();
Expand Down Expand Up @@ -2366,9 +2365,30 @@ LogicalResult inferDotGeneralOp(
"contracting dimension sizes must "
"match for lhs/rhs");
}
return success();
}

LogicalResult inferDotGeneralOp(
std::optional<Location> location, Type lhsType, Type rhsType,
ArrayRef<int64_t> lhsBatchingDimensions,
ArrayRef<int64_t> rhsBatchingDimensions,
ArrayRef<int64_t> lhsContractingDimensions,
ArrayRef<int64_t> rhsContractingDimensions,
std::optional<ArrayAttr> precisionConfig,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
if (failed(checkDotGeneralConstraints(
location, lhsType, rhsType, lhsBatchingDimensions,
rhsBatchingDimensions, lhsContractingDimensions,
rhsContractingDimensions, precisionConfig))) {
return failure();
}

// Infer the output dimensions of the operation.
SmallVector<int64_t> dimensions;
auto lhsRankedType = cast<RankedTensorType>(lhsType);
auto rhsRankedType = cast<RankedTensorType>(rhsType);
auto lhsShape = lhsRankedType.getShape();
auto rhsShape = rhsRankedType.getShape();
for (const int64_t lhsBatchingDim : lhsBatchingDimensions)
dimensions.push_back(lhsShape[lhsBatchingDim]);
for (int64_t i = 0; i < lhsRankedType.getRank(); i++)
Expand Down
23 changes: 23 additions & 0 deletions stablehlo/dialect/TypeInference.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,18 @@ FailureOr<SmallVector<bool>> convertWindowReversalAttribute(
std::optional<DenseElementsAttr> optionalAttr, std::optional<Location> loc,
StringRef attrName);

LogicalResult checkDimInBounds(std::optional<Location> loc, int64_t dim,
int64_t upperBound, StringRef dimName,
StringRef upperBoundName,
bool upperBoundInclusive = false);

LogicalResult checkDimsDistinct(std::optional<Location> loc,
ArrayRef<int64_t> lhsDims,
ArrayRef<int64_t> rhsDims, llvm::StringRef lhs,
llvm::StringRef rhs);

bool verifyCompatibleDims(int64_t dimSize1, int64_t dimSize2);

// WindowDimension described how the kernel window moves across the base area
// in a particular dimension.
// Describes the windowing in an operation such as convolution.
Expand Down Expand Up @@ -86,6 +98,9 @@ LogicalResult verifyReplicaGroups(std::optional<Location> location,
bool useGlobalDeviceIds,
std::optional<size_t> expectedGroupSize);

LogicalResult verifyPrecisionConfig(std::optional<Location> loc,
std::optional<ArrayAttr> maybeArrayAttr);

LogicalResult verifyConvolutionAttributes(
std::optional<Location> location, Type lhsType, Type rhsType,
int64_t inputBatchDimension, int64_t inputFeatureDimension,
Expand Down Expand Up @@ -207,6 +222,14 @@ LogicalResult inferDotOp(
RankedTensorType rhsType, std::optional<ArrayAttr> precisionConfig,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes);

LogicalResult checkDotGeneralConstraints(
std::optional<Location> location, Type lhsType, Type rhsType,
ArrayRef<int64_t> lhsBatchingDimensions,
ArrayRef<int64_t> rhsBatchingDimensions,
ArrayRef<int64_t> lhsContractingDimensions,
ArrayRef<int64_t> rhsContractingDimensions,
std::optional<ArrayAttr> precisionConfig);

LogicalResult inferDotGeneralOp(
std::optional<Location> location, Type lhsType, Type rhsType,
ArrayRef<int64_t> lhsBatchingDimensions,
Expand Down
4 changes: 2 additions & 2 deletions stablehlo/tests/print_types_invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,15 @@ func.func @tuple_type_mismatch(%arg0: tensor<1xf64>) -> tensor<1xf64> {
// -----

func.func @tuple_count_mismatch(%arg0: tensor<1xf64>) -> tensor<1xf64> {
// expected-error @+1 {{custom op 'stablehlo.tuple' 2 operands present, but expected 1}}
// expected-error @+1 {{custom op 'stablehlo.tuple' number of operands and types do not match: got 2 operands and 1 types}}
%0 = stablehlo.tuple %arg0, %arg0 : tuple<tensor<1xf64>>
func.return %0 : tensor<1xf64>
}

// -----

func.func @pairwise_count_mismatch(%arg0: tensor<1xf64>) -> tensor<1xf64> {
// expected-error @+1 {{custom op 'stablehlo.optimization_barrier' 2 operands present, but expected 1}}
// expected-error @+1 {{custom op 'stablehlo.optimization_barrier' number of operands and types do not match: got 2 operands and 1 types}}
%0 = stablehlo.optimization_barrier %arg0, %arg0 : tensor<1xf64>
func.return %0 : tensor<1xf64>
}
Expand Down
15 changes: 11 additions & 4 deletions stablehlo/tests/transforms/stablehlo_refine_parameters.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

// RUN: not stablehlo-opt --stablehlo-refine-arguments='types=tensor<f32>,tensor<1xf32>,tensor<?xf32>,tensor<*xf32>,tensor<*xf32>,!stablehlo.token' %s 2>&1 | FileCheck %s --check-prefixes=UNRANKED-ERROR
func.func @main(%arg0: tensor<f32>, %arg1: tensor<1xf32>, %arg2: tensor<?xf32>, %arg3: tensor<1x?x?xf32>, %arg4: tensor<*xf32>, %arg5: !stablehlo.token) {
// UNRANKED-ERROR: invalid refinement for argument 3, refinement must be ranked in 'tensor<1x?x?xf32>'->'tensor<*xf32>'
// UNRANKED-ERROR: invalid refinement for argument 3, refinement must be ranked in tensor<1x?x?xf32> -> tensor<*xf32>
return
}

Expand Down Expand Up @@ -43,21 +43,28 @@ func.func @refine_arguments_invalid_arg_num_mismatch(%arg0: tensor<f32>) {

// -----

// expected-error @+1 {{invalid refinement for argument 5, refinement must be a tensor in 'tensor<f32>'->'!stablehlo.token'}}
// expected-error @+1 {{invalid refinement for argument 5, refinement must be a tensor in tensor<f32> -> !stablehlo.token}}
func.func @refine_arguments_invalid_type_mismatch(%arg0: tensor<f32>, %arg1: tensor<1xf32>, %arg2: tensor<?xf32>, %arg3: tensor<1x?x?xf32>, %arg4: tensor<*xf32>, %arg5: tensor<f32>) {
return
}

// -----

// expected-error @+1 {{invalid refinement for argument 1, refinement rank must match operand rank in 'tensor<f32>'->'tensor<1xf32>'}}
// expected-error @+1 {{invalid refinement for argument 1, refinement element types must match in tensor<1xi32> -> tensor<1xf32>}}
func.func @refine_arguments_invalid_element_type_mismatch(%arg0: tensor<f32>, %arg1: tensor<1xi32>, %arg2: tensor<?xf32>, %arg3: tensor<1x?x?xf32>, %arg4: tensor<*xf32>, %arg5: tensor<f32>) {
return
}

// -----

// expected-error @+1 {{invalid refinement for argument 1, refinement rank must match operand rank in tensor<f32> -> tensor<1xf32>}}
func.func @refine_arguments_invalid_refine_rank_mismatch(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<?xf32>, %arg3: tensor<1x?x?xf32>, %arg4: tensor<*xf32>, %arg5: !stablehlo.token) {
return
}

// -----

// expected-error @+1 {{invalid refinement for argument 1, refinement dimension sizes must match for static dimensions in 'tensor<2xf32>'->'tensor<1xf32>'}}
// expected-error @+1 {{invalid refinement for argument 1, refinement dimension sizes must match for static dimensions in tensor<2xf32> -> tensor<1xf32>}}
func.func @refine_arguments_invalid_static_dim_mismatch(%arg0: tensor<f32>, %arg1: tensor<2xf32>, %arg2: tensor<?xf32>, %arg3: tensor<1x?x?xf32>, %arg4: tensor<*xf32>, %arg5: !stablehlo.token) {
return
}
19 changes: 19 additions & 0 deletions stablehlo/tests/vhlo/vhlo_emit_version_api.1_1_0.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// RUN: stablehlo-translate --deserialize --print-stablehlo-version %s.bc | FileCheck %s --check-prefix=CHECK-VERSION
// RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize --print-stablehlo-version | FileCheck %s --check-prefix=CHECK-VERSION-LATEST
// RUN: stablehlo-translate --deserialize --print-stablehlo-version %s | FileCheck %s --check-prefix=CHECK-VERSION-NOT-BYTECODE

// This file tests the `getPortableArtifactVersion` Serialization API.
// Any breakages to this file likely indicate that the MLIR Bytecode Format
// has changed, or that the StableHLO producer string emit by
// `serializePortableArtifact` has changed.
//
// See the `getPortableArtifactVersion` doc comments for more details.

// CHECK-VERSION: // Reading portable artifact with StableHLO version: 1.1.0
// CHECK-VERSION-NOT-BYTECODE: // Failed parsing StableHLO version from portable artifact
// CHECK-VERSION-LATEST: // Reading portable artifact with StableHLO version: {{.*}}

func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
%0 = stablehlo.add %arg0, %arg0 : tensor<f32>
func.return %0 : tensor<f32>
}
Binary file not shown.
18 changes: 18 additions & 0 deletions stablehlo/tools/StablehloTranslateMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License.
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
Expand Down Expand Up @@ -65,6 +66,12 @@ llvm::cl::opt<bool> stripDebuginfoOption(
"strip-debuginfo", llvm::cl::desc("Strip debug info from all operations"),
llvm::cl::init(false));

llvm::cl::opt<bool> printStablehloVersion(
"print-stablehlo-version",
llvm::cl::desc(
"When deserializing a portable artifact, print the StableHLO version"),
llvm::cl::init(false));

llvm::cl::opt<std::string> targetOption(
"target", llvm::cl::desc("Target version for serialization"),
llvm::cl::init(""));
Expand Down Expand Up @@ -306,6 +313,17 @@ TranslateFromMLIRRegistration serializeRegistration(
TranslateToMLIRRegistration deserializeRegistration(
"deserialize", "Deserialize a portable artifact into a StableHLO program",
[](llvm::StringRef input, mlir::MLIRContext *context) {
if (printStablehloVersion.getValue()) {
auto version = stablehlo::getPortableArtifactVersion(input);
if (failed(version)) {
llvm::outs()
<< "// Failed parsing StableHLO version from portable artifact\n";
} else {
llvm::outs()
<< "// Reading portable artifact with StableHLO version: "
<< *version << "\n";
}
}
return stablehlo::deserializePortableArtifact(input, context);
},
[](DialectRegistry &registry) {
Expand Down
10 changes: 9 additions & 1 deletion stablehlo/transforms/StablehloRefineArguments.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ limitations under the License.
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/DebugStringHelper.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "stablehlo/dialect/StablehloOps.h"
Expand Down Expand Up @@ -79,7 +80,8 @@ LogicalResult refinementError(func::FuncOp func, int64_t idx, Type argType,
Type refinedType, StringRef msg) {
return func.emitOpError()
<< "invalid refinement for argument " << idx << ", refinement " << msg
<< " in " << argType << "->" << refinedType;
<< " in " << mlir::debugString(argType) << " -> "
<< mlir::debugString(refinedType);
}

// Validates refinement types:
Expand Down Expand Up @@ -113,6 +115,12 @@ LogicalResult validateRefinedTypes(func::FuncOp func, TypeRange refinedTypes) {
return refinementError(func, i, type, refinedType, "must be a tensor");
}

// Check that element types match
if (tensorType.getElementType() != refinedTensorType.getElementType()) {
return refinementError(func, i, type, refinedType,
"element types must match");
}

// Refined rank cannot be unranked if mismatch
if (isa<UnrankedTensorType>(refinedType)) {
return refinementError(func, i, type, refinedType, "must be ranked");
Expand Down

0 comments on commit b3d3cac

Please sign in to comment.