From b3d3cacde8994df313297e68713ed74c2ca279ee Mon Sep 17 00:00:00 2001 From: Abhinav Gunjal Date: Fri, 6 Dec 2024 10:15:26 -0800 Subject: [PATCH] Integrate LLVM at llvm/llvm-project@2ccf7ed277df (#2661) + backport to gh changes from the temporary.patch --- BUILD.bazel | 1 + WORKSPACE.bazel | 4 +- build_tools/llvm_version.txt | 2 +- stablehlo/dialect/Serialization.cpp | 36 ++++++++++++++++++ stablehlo/dialect/Serialization.h | 12 ++++++ stablehlo/dialect/TypeInference.cpp | 28 ++++++++++++-- stablehlo/dialect/TypeInference.h | 23 +++++++++++ stablehlo/tests/print_types_invalid.mlir | 4 +- .../stablehlo_refine_parameters.mlir | 15 ++++++-- .../vhlo/vhlo_emit_version_api.1_1_0.mlir | 19 +++++++++ .../vhlo/vhlo_emit_version_api.1_1_0.mlir.bc | Bin 0 -> 294 bytes stablehlo/tools/StablehloTranslateMain.cpp | 18 +++++++++ .../transforms/StablehloRefineArguments.cpp | 10 ++++- 13 files changed, 158 insertions(+), 14 deletions(-) create mode 100644 stablehlo/tests/vhlo/vhlo_emit_version_api.1_1_0.mlir create mode 100644 stablehlo/tests/vhlo/vhlo_emit_version_api.1_1_0.mlir.bc diff --git a/BUILD.bazel b/BUILD.bazel index 53a82dd0b0b..50013d0faa7 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -925,6 +925,7 @@ cc_library( ":stablehlo_serialization", ":version", "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeReader", "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", diff --git a/WORKSPACE.bazel b/WORKSPACE.bazel index 9791a1b98fd..627e84b4173 100644 --- a/WORKSPACE.bazel +++ b/WORKSPACE.bazel @@ -17,9 +17,9 @@ workspace(name = "stablehlo") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") -LLVM_COMMIT = "bd92e46204331b9af296f53abb708317e72ab7a8" +LLVM_COMMIT = "2ccf7ed277df28651b94bbee9fccefdf22fb074f" -LLVM_SHA256 = "60f71fc5b237e10729edbed8cbe23b7081dabe254fbcb1ea82db8789cb7eaecf" +LLVM_SHA256 = "ca68a54dcd12c0dde32732a90899bf57e0f3f96fc43d8d1124d95a5eae627508" http_archive( name = "llvm-raw", diff --git a/build_tools/llvm_version.txt b/build_tools/llvm_version.txt index 9835e90f7ae..9f6d726f19c 100644 --- a/build_tools/llvm_version.txt +++ b/build_tools/llvm_version.txt @@ -1 +1 @@ -bd92e46204331b9af296f53abb708317e72ab7a8 +2ccf7ed277df28651b94bbee9fccefdf22fb074f diff --git a/stablehlo/dialect/Serialization.cpp b/stablehlo/dialect/Serialization.cpp index 3da045dad93..6bbdfacb3ac 100644 --- a/stablehlo/dialect/Serialization.cpp +++ b/stablehlo/dialect/Serialization.cpp @@ -15,6 +15,9 @@ limitations under the License. #include "stablehlo/dialect/Serialization.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MemoryBuffer.h" +#include "mlir/Bytecode/BytecodeReader.h" #include "mlir/Bytecode/BytecodeWriter.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Diagnostics.h" @@ -29,6 +32,8 @@ limitations under the License. #include "stablehlo/dialect/VhloOps.h" #include "stablehlo/transforms/Passes.h" +#define DEBUG_TYPE "compat-passes" + namespace mlir { namespace stablehlo { @@ -89,5 +94,36 @@ OwningOpRef deserializePortableArtifact(StringRef sourceStr, return module; } +FailureOr getPortableArtifactVersion(llvm::StringRef bytecode) { + auto logFailure = [&](llvm::StringRef message) { + LLVM_DEBUG(llvm::dbgs() << "Failed to get portable artifact version: " + << message << "\n"); + return failure(); + }; + // Must start with MLiRxStableHLO_vX.Y.Z, minimum length of 19. + constexpr size_t minHeaderLength = 19; + if (bytecode.size() < minHeaderLength) return logFailure("min header"); + + // Truncate to the end of the null-terminated producer string. + size_t pos = bytecode.find('\0'); + if (pos == llvm::StringRef::npos) return logFailure("no terminator"); + bytecode = bytecode.substr(0, pos); + + // Check if the bytecode is valid, starts with MLiR magic number. + if (!isBytecode( + llvm::MemoryBuffer::getMemBuffer(bytecode)->getMemBufferRef())) + return logFailure("not bytecode"); + + // Skip 4 bytes for the magic number. + std::string stablehloHeader = "StableHLO_v"; + size_t stablehloPos = bytecode.find(stablehloHeader); + if (stablehloPos == llvm::StringRef::npos) + return logFailure("not a StableHLO portable artifact"); + + // Skip the 11 bytes for StableHLO_v to get the StableHLO version to parse. + StringRef version = bytecode.substr(stablehloPos + stablehloHeader.size()); + return vhlo::Version::fromString(version); +} + } // namespace stablehlo } // namespace mlir diff --git a/stablehlo/dialect/Serialization.h b/stablehlo/dialect/Serialization.h index f51c7794749..abe95e63336 100644 --- a/stablehlo/dialect/Serialization.h +++ b/stablehlo/dialect/Serialization.h @@ -19,6 +19,7 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Support/LogicalResult.h" +#include "stablehlo/dialect/Version.h" namespace mlir { namespace stablehlo { @@ -43,6 +44,17 @@ LogicalResult serializePortableArtifact(ModuleOp module, OwningOpRef deserializePortableArtifact(StringRef sourceStr, MLIRContext* context); +// Get portable artifact version from the producer string after the MLIR +// Bytecode magic number `MLïRStableHLO_vX.Y.Z` -> X.Y.Z +// Returns failure if input string is not a valid portable artifact produced by +// serializePortableArtifact APIs, which would cause the bytecode artifact to +// not have the proper producer string. +// +// This method should be safe, since any changes to the bytecode format would +// warrant a bytecode version bump, and MLIR bytecode gives the option to +// specify a forward compatible bytecode version to target. +FailureOr getPortableArtifactVersion(llvm::StringRef bytecode); + } // namespace stablehlo } // namespace mlir diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp index ffe55992708..83ba24af635 100644 --- a/stablehlo/dialect/TypeInference.cpp +++ b/stablehlo/dialect/TypeInference.cpp @@ -392,7 +392,7 @@ LogicalResult checkDimsDistinct(std::optional loc, LogicalResult checkDimInBounds(std::optional loc, int64_t dim, int64_t upperBound, StringRef dimName, StringRef upperBoundName, - bool upperBoundInclusive = false) { + bool upperBoundInclusive) { StringRef rangeEnd = upperBoundInclusive ? "]" : ")"; if (dim < 0 || dim >= upperBound + (upperBoundInclusive ? 1 : 0)) return emitOptionalError(loc, "Expects ", dimName, " to be in range [0, ", @@ -2280,14 +2280,13 @@ LogicalResult inferDotOp( return success(); } -LogicalResult inferDotGeneralOp( +LogicalResult checkDotGeneralConstraints( std::optional location, Type lhsType, Type rhsType, ArrayRef lhsBatchingDimensions, ArrayRef rhsBatchingDimensions, ArrayRef lhsContractingDimensions, ArrayRef rhsContractingDimensions, - std::optional precisionConfig, - SmallVectorImpl& inferredReturnShapes) { + std::optional precisionConfig) { // dot_general_c11 if (failed(verifyPrecisionConfig(location, precisionConfig))) return failure(); @@ -2366,9 +2365,30 @@ LogicalResult inferDotGeneralOp( "contracting dimension sizes must " "match for lhs/rhs"); } + return success(); +} + +LogicalResult inferDotGeneralOp( + std::optional location, Type lhsType, Type rhsType, + ArrayRef lhsBatchingDimensions, + ArrayRef rhsBatchingDimensions, + ArrayRef lhsContractingDimensions, + ArrayRef rhsContractingDimensions, + std::optional precisionConfig, + SmallVectorImpl& inferredReturnShapes) { + if (failed(checkDotGeneralConstraints( + location, lhsType, rhsType, lhsBatchingDimensions, + rhsBatchingDimensions, lhsContractingDimensions, + rhsContractingDimensions, precisionConfig))) { + return failure(); + } // Infer the output dimensions of the operation. SmallVector dimensions; + auto lhsRankedType = cast(lhsType); + auto rhsRankedType = cast(rhsType); + auto lhsShape = lhsRankedType.getShape(); + auto rhsShape = rhsRankedType.getShape(); for (const int64_t lhsBatchingDim : lhsBatchingDimensions) dimensions.push_back(lhsShape[lhsBatchingDim]); for (int64_t i = 0; i < lhsRankedType.getRank(); i++) diff --git a/stablehlo/dialect/TypeInference.h b/stablehlo/dialect/TypeInference.h index e4710cbd95d..9c622acd761 100644 --- a/stablehlo/dialect/TypeInference.h +++ b/stablehlo/dialect/TypeInference.h @@ -54,6 +54,18 @@ FailureOr> convertWindowReversalAttribute( std::optional optionalAttr, std::optional loc, StringRef attrName); +LogicalResult checkDimInBounds(std::optional loc, int64_t dim, + int64_t upperBound, StringRef dimName, + StringRef upperBoundName, + bool upperBoundInclusive = false); + +LogicalResult checkDimsDistinct(std::optional loc, + ArrayRef lhsDims, + ArrayRef rhsDims, llvm::StringRef lhs, + llvm::StringRef rhs); + +bool verifyCompatibleDims(int64_t dimSize1, int64_t dimSize2); + // WindowDimension described how the kernel window moves across the base area // in a particular dimension. // Describes the windowing in an operation such as convolution. @@ -86,6 +98,9 @@ LogicalResult verifyReplicaGroups(std::optional location, bool useGlobalDeviceIds, std::optional expectedGroupSize); +LogicalResult verifyPrecisionConfig(std::optional loc, + std::optional maybeArrayAttr); + LogicalResult verifyConvolutionAttributes( std::optional location, Type lhsType, Type rhsType, int64_t inputBatchDimension, int64_t inputFeatureDimension, @@ -207,6 +222,14 @@ LogicalResult inferDotOp( RankedTensorType rhsType, std::optional precisionConfig, SmallVectorImpl& inferredReturnShapes); +LogicalResult checkDotGeneralConstraints( + std::optional location, Type lhsType, Type rhsType, + ArrayRef lhsBatchingDimensions, + ArrayRef rhsBatchingDimensions, + ArrayRef lhsContractingDimensions, + ArrayRef rhsContractingDimensions, + std::optional precisionConfig); + LogicalResult inferDotGeneralOp( std::optional location, Type lhsType, Type rhsType, ArrayRef lhsBatchingDimensions, diff --git a/stablehlo/tests/print_types_invalid.mlir b/stablehlo/tests/print_types_invalid.mlir index 2c694897b43..f87691e4c75 100644 --- a/stablehlo/tests/print_types_invalid.mlir +++ b/stablehlo/tests/print_types_invalid.mlir @@ -113,7 +113,7 @@ func.func @tuple_type_mismatch(%arg0: tensor<1xf64>) -> tensor<1xf64> { // ----- func.func @tuple_count_mismatch(%arg0: tensor<1xf64>) -> tensor<1xf64> { - // expected-error @+1 {{custom op 'stablehlo.tuple' 2 operands present, but expected 1}} + // expected-error @+1 {{custom op 'stablehlo.tuple' number of operands and types do not match: got 2 operands and 1 types}} %0 = stablehlo.tuple %arg0, %arg0 : tuple> func.return %0 : tensor<1xf64> } @@ -121,7 +121,7 @@ func.func @tuple_count_mismatch(%arg0: tensor<1xf64>) -> tensor<1xf64> { // ----- func.func @pairwise_count_mismatch(%arg0: tensor<1xf64>) -> tensor<1xf64> { - // expected-error @+1 {{custom op 'stablehlo.optimization_barrier' 2 operands present, but expected 1}} + // expected-error @+1 {{custom op 'stablehlo.optimization_barrier' number of operands and types do not match: got 2 operands and 1 types}} %0 = stablehlo.optimization_barrier %arg0, %arg0 : tensor<1xf64> func.return %0 : tensor<1xf64> } diff --git a/stablehlo/tests/transforms/stablehlo_refine_parameters.mlir b/stablehlo/tests/transforms/stablehlo_refine_parameters.mlir index bb45fbb9c09..541f2d79990 100644 --- a/stablehlo/tests/transforms/stablehlo_refine_parameters.mlir +++ b/stablehlo/tests/transforms/stablehlo_refine_parameters.mlir @@ -7,7 +7,7 @@ // RUN: not stablehlo-opt --stablehlo-refine-arguments='types=tensor,tensor<1xf32>,tensor,tensor<*xf32>,tensor<*xf32>,!stablehlo.token' %s 2>&1 | FileCheck %s --check-prefixes=UNRANKED-ERROR func.func @main(%arg0: tensor, %arg1: tensor<1xf32>, %arg2: tensor, %arg3: tensor<1x?x?xf32>, %arg4: tensor<*xf32>, %arg5: !stablehlo.token) { - // UNRANKED-ERROR: invalid refinement for argument 3, refinement must be ranked in 'tensor<1x?x?xf32>'->'tensor<*xf32>' + // UNRANKED-ERROR: invalid refinement for argument 3, refinement must be ranked in tensor<1x?x?xf32> -> tensor<*xf32> return } @@ -43,21 +43,28 @@ func.func @refine_arguments_invalid_arg_num_mismatch(%arg0: tensor) { // ----- -// expected-error @+1 {{invalid refinement for argument 5, refinement must be a tensor in 'tensor'->'!stablehlo.token'}} +// expected-error @+1 {{invalid refinement for argument 5, refinement must be a tensor in tensor -> !stablehlo.token}} func.func @refine_arguments_invalid_type_mismatch(%arg0: tensor, %arg1: tensor<1xf32>, %arg2: tensor, %arg3: tensor<1x?x?xf32>, %arg4: tensor<*xf32>, %arg5: tensor) { return } // ----- -// expected-error @+1 {{invalid refinement for argument 1, refinement rank must match operand rank in 'tensor'->'tensor<1xf32>'}} +// expected-error @+1 {{invalid refinement for argument 1, refinement element types must match in tensor<1xi32> -> tensor<1xf32>}} +func.func @refine_arguments_invalid_element_type_mismatch(%arg0: tensor, %arg1: tensor<1xi32>, %arg2: tensor, %arg3: tensor<1x?x?xf32>, %arg4: tensor<*xf32>, %arg5: tensor) { + return +} + +// ----- + +// expected-error @+1 {{invalid refinement for argument 1, refinement rank must match operand rank in tensor -> tensor<1xf32>}} func.func @refine_arguments_invalid_refine_rank_mismatch(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<1x?x?xf32>, %arg4: tensor<*xf32>, %arg5: !stablehlo.token) { return } // ----- -// expected-error @+1 {{invalid refinement for argument 1, refinement dimension sizes must match for static dimensions in 'tensor<2xf32>'->'tensor<1xf32>'}} +// expected-error @+1 {{invalid refinement for argument 1, refinement dimension sizes must match for static dimensions in tensor<2xf32> -> tensor<1xf32>}} func.func @refine_arguments_invalid_static_dim_mismatch(%arg0: tensor, %arg1: tensor<2xf32>, %arg2: tensor, %arg3: tensor<1x?x?xf32>, %arg4: tensor<*xf32>, %arg5: !stablehlo.token) { return } diff --git a/stablehlo/tests/vhlo/vhlo_emit_version_api.1_1_0.mlir b/stablehlo/tests/vhlo/vhlo_emit_version_api.1_1_0.mlir new file mode 100644 index 00000000000..a3c1ec1685e --- /dev/null +++ b/stablehlo/tests/vhlo/vhlo_emit_version_api.1_1_0.mlir @@ -0,0 +1,19 @@ +// RUN: stablehlo-translate --deserialize --print-stablehlo-version %s.bc | FileCheck %s --check-prefix=CHECK-VERSION +// RUN: stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize --print-stablehlo-version | FileCheck %s --check-prefix=CHECK-VERSION-LATEST +// RUN: stablehlo-translate --deserialize --print-stablehlo-version %s | FileCheck %s --check-prefix=CHECK-VERSION-NOT-BYTECODE + +// This file tests the `getPortableArtifactVersion` Serialization API. +// Any breakages to this file likely indicate that the MLIR Bytecode Format +// has changed, or that the StableHLO producer string emit by +// `serializePortableArtifact` has changed. +// +// See the `getPortableArtifactVersion` doc comments for more details. + +// CHECK-VERSION: // Reading portable artifact with StableHLO version: 1.1.0 +// CHECK-VERSION-NOT-BYTECODE: // Failed parsing StableHLO version from portable artifact +// CHECK-VERSION-LATEST: // Reading portable artifact with StableHLO version: {{.*}} + +func.func @main(%arg0: tensor) -> tensor { + %0 = stablehlo.add %arg0, %arg0 : tensor + func.return %0 : tensor +} diff --git a/stablehlo/tests/vhlo/vhlo_emit_version_api.1_1_0.mlir.bc b/stablehlo/tests/vhlo/vhlo_emit_version_api.1_1_0.mlir.bc new file mode 100644 index 0000000000000000000000000000000000000000..c0694bf0eef5742b655343ae6dcdc942948eb607 GIT binary patch literal 294 zcmYjMQEJ0547IAbUNBiOA(&DMgMN)cvYuf-`{-y{4-nkOqXs7_juY4+cGg{CcE9#O z(&Gt9Papf=&+O#Y+3NTGp_zBpu6iY~qCitXjw~-EzLX3yUPNHZ1A;8s;y^PA01_mI z*h*Ll32TP1#ZA5u&p-tNLYkZ?91&RZhCuO1ktH=!j>w+Upe$p8&<@HnQu4SckpGmq z$hkOA#(L9}`DzE!4eew#`I`Dmuu4?h#^SV}TpyXg8rL>A<@|R&-k%Cn|CraKAM3d4 io*La5-^|*LX6PGrGgW7eBb^HO$U{kkE1u< stripDebuginfoOption( "strip-debuginfo", llvm::cl::desc("Strip debug info from all operations"), llvm::cl::init(false)); +llvm::cl::opt printStablehloVersion( + "print-stablehlo-version", + llvm::cl::desc( + "When deserializing a portable artifact, print the StableHLO version"), + llvm::cl::init(false)); + llvm::cl::opt targetOption( "target", llvm::cl::desc("Target version for serialization"), llvm::cl::init("")); @@ -306,6 +313,17 @@ TranslateFromMLIRRegistration serializeRegistration( TranslateToMLIRRegistration deserializeRegistration( "deserialize", "Deserialize a portable artifact into a StableHLO program", [](llvm::StringRef input, mlir::MLIRContext *context) { + if (printStablehloVersion.getValue()) { + auto version = stablehlo::getPortableArtifactVersion(input); + if (failed(version)) { + llvm::outs() + << "// Failed parsing StableHLO version from portable artifact\n"; + } else { + llvm::outs() + << "// Reading portable artifact with StableHLO version: " + << *version << "\n"; + } + } return stablehlo::deserializePortableArtifact(input, context); }, [](DialectRegistry ®istry) { diff --git a/stablehlo/transforms/StablehloRefineArguments.cpp b/stablehlo/transforms/StablehloRefineArguments.cpp index 3eb7f5ce6bc..a903b7fea04 100644 --- a/stablehlo/transforms/StablehloRefineArguments.cpp +++ b/stablehlo/transforms/StablehloRefineArguments.cpp @@ -34,6 +34,7 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/DebugStringHelper.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "stablehlo/dialect/StablehloOps.h" @@ -79,7 +80,8 @@ LogicalResult refinementError(func::FuncOp func, int64_t idx, Type argType, Type refinedType, StringRef msg) { return func.emitOpError() << "invalid refinement for argument " << idx << ", refinement " << msg - << " in " << argType << "->" << refinedType; + << " in " << mlir::debugString(argType) << " -> " + << mlir::debugString(refinedType); } // Validates refinement types: @@ -113,6 +115,12 @@ LogicalResult validateRefinedTypes(func::FuncOp func, TypeRange refinedTypes) { return refinementError(func, i, type, refinedType, "must be a tensor"); } + // Check that element types match + if (tensorType.getElementType() != refinedTensorType.getElementType()) { + return refinementError(func, i, type, refinedType, + "element types must match"); + } + // Refined rank cannot be unranked if mismatch if (isa(refinedType)) { return refinementError(func, i, type, refinedType, "must be ranked");