Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][Linalg] Re-land linalg.matmul move to ODS + Remove/update failing obsolete OpDSL tests for linalg.matmul. #115319

Merged
merged 4 commits into from
Nov 7, 2024

Conversation

shahidact
Copy link
Contributor

The earlier PR(#104783) which introduces
transpose and broadcast semantic to linalg.matmul was reverted due to two failing
OpDSL test for linalg.matmul.

Since linalg.matmul is now defined using TableGen ODS instead of Python-based OpDSL,
these test started failing and needs to be removed/updated.

This commit removes/updates the failing obsolete tests from below files. All other files
were part of earlier PR and just cherry picked.
"mlir/test/python/integration/dialects/linalg/opsrun.py"
"mlir/test/python/integration/dialects/transform.py"

shahidact and others added 3 commits November 6, 2024 23:53
…ul' ops. (llvm#104783)

The main goal of this patch is to extend the semantic of 'linalg.matmul'
named op to include per operand transpose semantic while also laying out
a way to move ops definition from OpDSL to tablegen. Hence, it is
implemented in tablegen. Transpose semantic is as follows.

By default 'linalg.matmul' behavior will remain as is.
Transpose/broadcast semantics can be appiled explicitly by specifying
the optional indexing_map attribute. By default, no transpose/broadcast
is mandated.

    Example:
    ```

linalg.matmul indexing_maps = [
                   affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
                   affine_map<(d0, d1, d2) -> (d2, d1)>,
                   affine_map<(d0, d1, d2) -> (d0, d1)>
                   ]
                   ins(%arg0, %arg1 : memref<5x3xf32>,memref<5x7xf32>)
                   outs(%arg2: memref<3x7xf32>)

linalg.matmul indexing_maps = [
                   affine_map<(d0, d1, d2) -> (d2)>,     // broadcast
                   affine_map<(d0, d1, d2) -> (d2, d1)>,
                   affine_map<(d0, d1, d2) -> (d0, d1)>
                  ]
                  ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>)
                  outs(%arg2: memref<3x7xf32>)
    ```
…atmul.

The earlier PR(llvm#104783) was reverted
due to two failing OpDSL test for linalg.matmul.

Since linalg.matmul is now defined using TableGen ODS instead of Python-based
OpDSL, these test started failing and needs to be removed/updated.

This commit removes/updates the failing obsolete tests from below tests.
"mlir/test/python/integration/dialects/linalg/opsrun.py"
"mlir/test/python/integration/dialects/transform.py"
@llvmbot
Copy link
Member

llvmbot commented Nov 7, 2024

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir-nvgpu

Author: Md Asghar Ahmad Shahid (shahidact)

Changes

The earlier PR(#104783) which introduces
transpose and broadcast semantic to linalg.matmul was reverted due to two failing
OpDSL test for linalg.matmul.

Since linalg.matmul is now defined using TableGen ODS instead of Python-based OpDSL,
these test started failing and needs to be removed/updated.

This commit removes/updates the failing obsolete tests from below files. All other files
were part of earlier PR and just cherry picked.
"mlir/test/python/integration/dialects/linalg/opsrun.py"
"mlir/test/python/integration/dialects/transform.py"


Patch is 67.57 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/115319.diff

16 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td (+10)
  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml (-72)
  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+134)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp (+12-5)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+251-12)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp (+7)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+5)
  • (modified) mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp (+6)
  • (modified) mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py (-17)
  • (modified) mlir/test/Dialect/Linalg/generalize-named-ops.mlir (+111)
  • (modified) mlir/test/Dialect/Linalg/invalid.mlir (+159)
  • (modified) mlir/test/Dialect/Linalg/named-ops.mlir (+243)
  • (modified) mlir/test/python/dialects/linalg/ops.py (-75)
  • (modified) mlir/test/python/integration/dialects/linalg/opsrun.py (-115)
  • (modified) mlir/test/python/integration/dialects/transform.py (+14-12)
  • (modified) mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp (+5-1)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index b81a4c9c8760cf..c0eff99c850752 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -708,6 +708,16 @@ def LinalgStructuredInterface
         return;
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return true if the user has supplied an explicit indexing maps for this op.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"hasUserDefinedMaps",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{ return false; }]
+    >,
     //===------------------------------------------------------------------===//
     // Linalg generalization hooks.
     //===------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index bf2f26de26e9ed..ee88ca516de6ff 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1065,78 +1065,6 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_arg: rhs
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: matmul
-  cpp_class_name: MatmulOp
-  doc: |-
-    Performs a matrix multiplication of two 2D inputs.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-  implements:
-  - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: A
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
-  - !LinalgOperandDefConfig
-    name: B
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
-  - !LinalgOperandDefConfig
-    name: C
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
-  - !LinalgOperandDefConfig
-    name: cast
-    kind: type_fn_attr
-    default_fn: cast_signed
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
-  iterator_types:
-  - parallel
-  - parallel
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: C
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: C
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: A
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: B
---- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: quantized_matmul
   cpp_class_name: QuantizedMatmulOp
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index c2fee8ea55c960..2b47414ff5e924 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -554,6 +554,140 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Op definition for MatmulOp
+//===----------------------------------------------------------------------===//
+
+def MatmulOp : LinalgStructuredBase_Op<"matmul", [
+               AttrSizedOperandSegments,
+               LinalgContractionOpInterface]> {
+    
+  let summary = [{
+    Performs a matrix multiplication of two 2D inputs without broadcast or transpose.
+    }];
+  let description = [{
+    Numeric casting is performed on the operands to the inner multiply,
+    promoting them to the same data type as the accumulator/output.
+
+    Broadcast and Transpose semantics can be appiled by specifying the explicit attribute
+    'indexing_maps' as shown below.This is a list attribute, so the list must include all
+    the maps if specified.
+
+    Example Transpose:
+    ```
+    linalg.matmul indexing_maps = [
+                   affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
+                   affine_map<(d0, d1, d2) -> (d2, d1)>,
+                   affine_map<(d0, d1, d2) -> (d0, d1)>
+                   ]
+                   ins(%arg0, %arg1 : memref<5x3xf32>,memref<5x7xf32>)
+                   outs(%arg2: memref<3x7xf32>)
+     ```
+
+    Example Broadcast:
+     ```
+    linalg.matmul indexing_maps = [
+                   affine_map<(d0, d1, d2) -> (d2)>,     // broadcast
+                   affine_map<(d0, d1, d2) -> (d2, d1)>,
+                   affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>)
+                  outs(%arg2: memref<3x7xf32>)
+     ```
+
+     Example Broadcast and transpose:
+     ```
+     linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
+                       affine_map<(d0, d1, d2) -> (d2)>,     // broadcast
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>)
+    }];
+
+    let arguments = (ins
+      Variadic<AnyType>:$inputs,
+      Variadic<AnyShaped>:$outputs,
+      DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
+      DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
+    );
+    let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+    let regions = (region AnyRegion:$region);
+
+    let skipDefaultBuilders = 1;
+    let builders = [
+      OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+          attributes, MatmulOp::getRegionBuilder());
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+            "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildStructuredOp($_builder, $_state, resultTensorTypes,
+          inputs, outputs, attributes, MatmulOp::getRegionBuilder());
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addOperands(operands);
+        $_state.addAttributes(attributes);
+        $_state.addTypes(resultTensorTypes);
+        (void)$_state.addRegion();
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+       "ValueRange":$outputs,
+       "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("cast", cast);
+        buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
+          attributes, MatmulOp::getRegionBuilder());
+      }]>
+
+    ];
+    let hasCustomAssemblyFormat = 1;
+    let hasFolder = 1;
+    let hasVerifier = 1;
+
+    let extraClassDeclaration = structuredOpsBaseDecls # [{
+      SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+      /// Implements the block region builder.
+      static void regionBuilder(ImplicitLocOpBuilder &b,
+                                Block &block, ArrayRef<NamedAttribute> attrs);
+
+      /// Returns a list of AffineMap with the typical matmul indexing charactristic.
+      SmallVector<AffineMap> getDefaultIndexingMaps();
+
+      /// Returns true if the given broadcast map \p bcastMap is valid for this op.
+      bool isValidLhsRhsBroadcastMap(AffineMap bcastMap);
+
+      static std::function<void(ImplicitLocOpBuilder &,
+                                Block &, ArrayRef<NamedAttribute>)>
+      getRegionBuilder() {
+        return regionBuilder;
+      }
+
+      ::mlir::MutableOperandRange getDpsInitsMutable() {
+        return getOutputsMutable();
+      }
+
+      // Generic methods.
+      static unsigned getNumRegionArgs();
+      std::string getLibraryCallName();
+      bool hasDynamicIndexingMaps();
+      /// Check if the op has broadcast and/or transpose semantic. Returns true if the
+      /// user defined indexing maps are not equal to default map.
+      bool hasUserDefinedMaps();
+    }];
+}
+
 //===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index bd77965194b27f..0cffadf8fb64a0 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -15,14 +15,21 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/raw_ostream.h"
 #include <algorithm>
 #include <numeric>
+#include <optional>
 
 using namespace mlir;
 using namespace mlir::linalg;
@@ -1211,7 +1218,6 @@ int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
 
 LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   LinalgOp linalgOp = cast<LinalgOp>(op);
-
   // Mixed tensor/buffer operands are not allowed.
   if (!linalgOp.hasPureTensorSemantics() &&
       !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
@@ -1231,6 +1237,8 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
            << ") to be equal to the number of input/output operands ("
            << linalgOp->getNumOperands() << ")";
 
+  // Set this flag if this op has user defined maps. This is required to guard
+  // the below error condition which assume default indexing maps.
   for (OpOperand &opOperand : linalgOp->getOpOperands()) {
     AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
 
@@ -1247,13 +1255,13 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
              << " dim(s) to match the number of loops";
 
     int64_t rank = linalgOp.getRank(&opOperand);
+
     if (indexingMap.getNumResults() != rank)
       return op->emitOpError("expected operand rank (")
              << rank << ") to match the result rank of indexing_map #"
              << opOperand.getOperandNumber() << " ("
              << indexingMap.getNumResults() << ")";
   }
-
   SmallVector<unsigned> redDims;
   linalgOp.getReductionDims(redDims);
 
@@ -1263,9 +1271,8 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   // Check if given shapes match to inferred shapes.
   SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
   SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);
-
-  // Verify only static cases since we can't get exact dimension sizes and loop
-  // ranges for dynamic cases in this stage.
+  // Verify only static cases since we can't get exact dimension sizes and
+  // loop ranges for dynamic cases in this stage.
   if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
     for (int64_t &range : endLoopRangeValues)
       range -= 1;
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 730c478c2883ef..c909d13e4314b4 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -27,6 +27,7 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Matchers.h"
@@ -37,12 +38,17 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
+#include <cassert>
 #include <optional>
 
 using namespace mlir;
@@ -149,15 +155,36 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
   // iterator_types is an auto-generated method.
 }
 
+/// Helper to create a typical indexing map for MatmulOp. Returns a list of
+/// AffineMap.
+static SmallVector<AffineMap, 3>
+getDefaultIndexingMapsForMatmul(MLIRContext *context) {
+  AffineExpr d0, d1, d2;
+  SmallVector<AffineMap, 3> indexingMaps;
+  bindDims(context, d0, d1, d2);
+  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
+  indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
+  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
+  return indexingMaps;
+}
+
+/// Wrapper to return the typical indexing map array attribute for MatmulOp.
+static SmallVector<Attribute> getDefaultIndexingMapAttr(MLIRContext *context) {
+  return llvm::map_to_vector(
+      getDefaultIndexingMapsForMatmul(context),
+      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+}
+
 /// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
 /// The result types are derived automatically if `resultTensorTypes` is none.
 /// The body of the operation is filled using `regionBuilder`. All ods-gen
 /// created structured operations use the method to implement their builders.
-static void buildStructuredOp(OpBuilder &b, OperationState &state,
-                              std::optional<TypeRange> resultTensorTypes,
-                              ValueRange inputs, ValueRange outputs,
-                              ArrayRef<NamedAttribute> attributes,
-                              RegionBuilderFn regionBuilder) {
+static void buildStructuredOp(
+    OpBuilder &b, OperationState &state,
+    std::optional<TypeRange> resultTensorTypes, ValueRange inputs,
+    ValueRange outputs, ArrayRef<NamedAttribute> attributes,
+    RegionBuilderFn regionBuilder,
+    std::optional<ArrayRef<AffineMap>> indexingMaps = std::nullopt) {
   // Derive the result types if needed.
   SmallVector<Type> derivedResultTypes =
       resultTensorTypes.value_or(TypeRange());
@@ -168,6 +195,20 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state,
   state.addOperands(inputs);
   state.addOperands(outputs);
   state.addTypes(derivedResultTypes);
+
+  // Initialize indexingMaps, for MatmulOp.
+  SmallVector<Attribute, 3> indexingMapsAttrVal;
+  if (indexingMaps.has_value()) {
+    for (mlir::AffineMap map : *indexingMaps) {
+      // Convert each AffineMap to an AffineMapAttr
+      indexingMapsAttrVal.push_back(AffineMapAttr::get(map));
+    }
+    state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+  } else {
+    indexingMapsAttrVal = getDefaultIndexingMapAttr(b.getContext());
+    state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+  }
+
   state.addAttributes(attributes);
   state.addAttribute(
       "operandSegmentSizes",
@@ -299,11 +340,48 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                           OperationState &result,
                                           unsigned numRegionArgs,
                                           RegionBuilderFn regionBuilder) {
+
+  SmallVector<Attribute, 3> indexingMapsAttr;
+  Attribute mapAttr;
+  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+    if (parser.parseEqual())
+      return failure();
+
+    if (parser.parseLSquare())
+      return failure();
+
+    do {
+      if (parser.parseAttribute(mapAttr))
+        return failure();
+      if (!isa<AffineMapAttr>(mapAttr)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                                "expected affine map attribute");
+      }
+      indexingMapsAttr.push_back(mapAttr);
+
+      if (parser.parseOptionalComma())
+        break;
+    } while (true);
+
+    if (parser.parseRSquare())
+      return failure();
+  }
+  // Initialize indexingMaps, if not supplied explicitly.
+  if (indexingMapsAttr.empty()) {
+    indexingMapsAttr = getDefaultIndexingMapAttr(result.getContext());
+  }
+  result.addAttribute("indexing_maps",
+                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
+
   // TODO: Enable when ods-gen supports captures.
   SmallVector<Type, 1> inputTypes, outputTypes;
   if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
     return failure();
 
+  // Parse optional attributes.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
   // TODO: consider merging results parsing into region parsing.
   // Need to wait for declarative assembly resolution to decide.
   SmallVector<Type, 1> outputTensorsTypes;
@@ -329,13 +407,9 @@ static void printNamedStructuredOpResults(OpAsmPrinter &p,
 }
 
 static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
-                                   ValueRange inputs, ValueRange outputs) {
-  p.printOptionalAttrDict(
-      op->getAttrs(),
-      /*elidedAttrs=*/{"operandSegmentSizes",
-                       // See generated code in
-                       // LinalgNamedStructuredOps.yamlgen.cpp.inc
-                       "linalg.memoized_indexing_maps"});
+                                   ValueRange inputs, ValueRange outputs,
+                                   ArrayRef<StringRef> elidedAttrs = {}) {
+  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
 
   // Printing is shared with generic ops, except for the region and
   // attributes.
@@ -3382,3 +3456,168 @@ Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                               Location loc) {
   return arith::ConstantOp::materialize(builder, value, type, loc);
 }
+
+/// Returns true if the result AffineExpr of the \p explicitMap is same as \p
+/// defaultMap.
+static bool isValidResultDimExprs(AffineMap explictMap, AffineMap defaultMap) {
+  auto explicitRange = explictMap.getResults();
+  auto defaultRange = defaultMap.getResults();
+  DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
+  DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
+  llvm::set_union(explicitSet, defaultSet);
+  return explicitSet == defaultSet;
+}
+
+/// Returns true if the \p explictMap is broadcasted with respect to the
+/// \p defaultMap.
+static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap) {
+  return explictMap.getNumResults() < defaultMap.getNumResults();
+}
+
+/// Verifies the broadcast and transpose semantic sepecified by the explicit
+/// indexing map for the MatmulOp \p op for each operand specified by \p
+/// opIndex.
+static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
+                                                  unsigned opIndex) {
+  SmallVector<AffineMap, 3> opInde...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Nov 7, 2024

@llvm/pr-subscribers-mlir

Author: Md Asghar Ahmad Shahid (shahidact)

Changes

The earlier PR(#104783) which introduces
transpose and broadcast semantic to linalg.matmul was reverted due to two failing
OpDSL test for linalg.matmul.

Since linalg.matmul is now defined using TableGen ODS instead of Python-based OpDSL,
these test started failing and needs to be removed/updated.

This commit removes/updates the failing obsolete tests from below files. All other files
were part of earlier PR and just cherry picked.
"mlir/test/python/integration/dialects/linalg/opsrun.py"
"mlir/test/python/integration/dialects/transform.py"


Patch is 67.57 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/115319.diff

16 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td (+10)
  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml (-72)
  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+134)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp (+12-5)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+251-12)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp (+7)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+5)
  • (modified) mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp (+6)
  • (modified) mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py (-17)
  • (modified) mlir/test/Dialect/Linalg/generalize-named-ops.mlir (+111)
  • (modified) mlir/test/Dialect/Linalg/invalid.mlir (+159)
  • (modified) mlir/test/Dialect/Linalg/named-ops.mlir (+243)
  • (modified) mlir/test/python/dialects/linalg/ops.py (-75)
  • (modified) mlir/test/python/integration/dialects/linalg/opsrun.py (-115)
  • (modified) mlir/test/python/integration/dialects/transform.py (+14-12)
  • (modified) mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp (+5-1)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index b81a4c9c8760cf..c0eff99c850752 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -708,6 +708,16 @@ def LinalgStructuredInterface
         return;
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return true if the user has supplied an explicit indexing maps for this op.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"hasUserDefinedMaps",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{ return false; }]
+    >,
     //===------------------------------------------------------------------===//
     // Linalg generalization hooks.
     //===------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index bf2f26de26e9ed..ee88ca516de6ff 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1065,78 +1065,6 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_arg: rhs
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: matmul
-  cpp_class_name: MatmulOp
-  doc: |-
-    Performs a matrix multiplication of two 2D inputs.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-  implements:
-  - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: A
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
-  - !LinalgOperandDefConfig
-    name: B
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
-  - !LinalgOperandDefConfig
-    name: C
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
-  - !LinalgOperandDefConfig
-    name: cast
-    kind: type_fn_attr
-    default_fn: cast_signed
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
-  iterator_types:
-  - parallel
-  - parallel
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: C
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: C
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: A
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: B
---- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: quantized_matmul
   cpp_class_name: QuantizedMatmulOp
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index c2fee8ea55c960..2b47414ff5e924 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -554,6 +554,140 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Op definition for MatmulOp
+//===----------------------------------------------------------------------===//
+
+def MatmulOp : LinalgStructuredBase_Op<"matmul", [
+               AttrSizedOperandSegments,
+               LinalgContractionOpInterface]> {
+    
+  let summary = [{
+    Performs a matrix multiplication of two 2D inputs without broadcast or transpose.
+    }];
+  let description = [{
+    Numeric casting is performed on the operands to the inner multiply,
+    promoting them to the same data type as the accumulator/output.
+
+    Broadcast and Transpose semantics can be appiled by specifying the explicit attribute
+    'indexing_maps' as shown below.This is a list attribute, so the list must include all
+    the maps if specified.
+
+    Example Transpose:
+    ```
+    linalg.matmul indexing_maps = [
+                   affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
+                   affine_map<(d0, d1, d2) -> (d2, d1)>,
+                   affine_map<(d0, d1, d2) -> (d0, d1)>
+                   ]
+                   ins(%arg0, %arg1 : memref<5x3xf32>,memref<5x7xf32>)
+                   outs(%arg2: memref<3x7xf32>)
+     ```
+
+    Example Broadcast:
+     ```
+    linalg.matmul indexing_maps = [
+                   affine_map<(d0, d1, d2) -> (d2)>,     // broadcast
+                   affine_map<(d0, d1, d2) -> (d2, d1)>,
+                   affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>)
+                  outs(%arg2: memref<3x7xf32>)
+     ```
+
+     Example Broadcast and transpose:
+     ```
+     linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
+                       affine_map<(d0, d1, d2) -> (d2)>,     // broadcast
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>)
+    }];
+
+    let arguments = (ins
+      Variadic<AnyType>:$inputs,
+      Variadic<AnyShaped>:$outputs,
+      DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
+      DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
+    );
+    let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+    let regions = (region AnyRegion:$region);
+
+    let skipDefaultBuilders = 1;
+    let builders = [
+      OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+          attributes, MatmulOp::getRegionBuilder());
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+            "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildStructuredOp($_builder, $_state, resultTensorTypes,
+          inputs, outputs, attributes, MatmulOp::getRegionBuilder());
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addOperands(operands);
+        $_state.addAttributes(attributes);
+        $_state.addTypes(resultTensorTypes);
+        (void)$_state.addRegion();
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+       "ValueRange":$outputs,
+       "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("cast", cast);
+        buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
+          attributes, MatmulOp::getRegionBuilder());
+      }]>
+
+    ];
+    let hasCustomAssemblyFormat = 1;
+    let hasFolder = 1;
+    let hasVerifier = 1;
+
+    let extraClassDeclaration = structuredOpsBaseDecls # [{
+      SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+      /// Implements the block region builder.
+      static void regionBuilder(ImplicitLocOpBuilder &b,
+                                Block &block, ArrayRef<NamedAttribute> attrs);
+
+      /// Returns a list of AffineMap with the typical matmul indexing charactristic.
+      SmallVector<AffineMap> getDefaultIndexingMaps();
+
+      /// Returns true if the given broadcast map \p bcastMap is valid for this op.
+      bool isValidLhsRhsBroadcastMap(AffineMap bcastMap);
+
+      static std::function<void(ImplicitLocOpBuilder &,
+                                Block &, ArrayRef<NamedAttribute>)>
+      getRegionBuilder() {
+        return regionBuilder;
+      }
+
+      ::mlir::MutableOperandRange getDpsInitsMutable() {
+        return getOutputsMutable();
+      }
+
+      // Generic methods.
+      static unsigned getNumRegionArgs();
+      std::string getLibraryCallName();
+      bool hasDynamicIndexingMaps();
+      /// Check if the op has broadcast and/or transpose semantic. Returns true if the
+      /// user defined indexing maps are not equal to default map.
+      bool hasUserDefinedMaps();
+    }];
+}
+
 //===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index bd77965194b27f..0cffadf8fb64a0 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -15,14 +15,21 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/raw_ostream.h"
 #include <algorithm>
 #include <numeric>
+#include <optional>
 
 using namespace mlir;
 using namespace mlir::linalg;
@@ -1211,7 +1218,6 @@ int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
 
 LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   LinalgOp linalgOp = cast<LinalgOp>(op);
-
   // Mixed tensor/buffer operands are not allowed.
   if (!linalgOp.hasPureTensorSemantics() &&
       !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
@@ -1231,6 +1237,8 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
            << ") to be equal to the number of input/output operands ("
            << linalgOp->getNumOperands() << ")";
 
+  // Set this flag if this op has user defined maps. This is required to guard
+  // the below error condition which assume default indexing maps.
   for (OpOperand &opOperand : linalgOp->getOpOperands()) {
     AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
 
@@ -1247,13 +1255,13 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
              << " dim(s) to match the number of loops";
 
     int64_t rank = linalgOp.getRank(&opOperand);
+
     if (indexingMap.getNumResults() != rank)
       return op->emitOpError("expected operand rank (")
              << rank << ") to match the result rank of indexing_map #"
              << opOperand.getOperandNumber() << " ("
              << indexingMap.getNumResults() << ")";
   }
-
   SmallVector<unsigned> redDims;
   linalgOp.getReductionDims(redDims);
 
@@ -1263,9 +1271,8 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   // Check if given shapes match to inferred shapes.
   SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
   SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);
-
-  // Verify only static cases since we can't get exact dimension sizes and loop
-  // ranges for dynamic cases in this stage.
+  // Verify only static cases since we can't get exact dimension sizes and
+  // loop ranges for dynamic cases in this stage.
   if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
     for (int64_t &range : endLoopRangeValues)
       range -= 1;
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 730c478c2883ef..c909d13e4314b4 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -27,6 +27,7 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Matchers.h"
@@ -37,12 +38,17 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
+#include <cassert>
 #include <optional>
 
 using namespace mlir;
@@ -149,15 +155,36 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
   // iterator_types is an auto-generated method.
 }
 
+/// Helper to create a typical indexing map for MatmulOp. Returns a list of
+/// AffineMap.
+static SmallVector<AffineMap, 3>
+getDefaultIndexingMapsForMatmul(MLIRContext *context) {
+  AffineExpr d0, d1, d2;
+  SmallVector<AffineMap, 3> indexingMaps;
+  bindDims(context, d0, d1, d2);
+  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
+  indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
+  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
+  return indexingMaps;
+}
+
+/// Wrapper to return the typical indexing map array attribute for MatmulOp.
+static SmallVector<Attribute> getDefaultIndexingMapAttr(MLIRContext *context) {
+  return llvm::map_to_vector(
+      getDefaultIndexingMapsForMatmul(context),
+      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+}
+
 /// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
 /// The result types are derived automatically if `resultTensorTypes` is none.
 /// The body of the operation is filled using `regionBuilder`. All ods-gen
 /// created structured operations use the method to implement their builders.
-static void buildStructuredOp(OpBuilder &b, OperationState &state,
-                              std::optional<TypeRange> resultTensorTypes,
-                              ValueRange inputs, ValueRange outputs,
-                              ArrayRef<NamedAttribute> attributes,
-                              RegionBuilderFn regionBuilder) {
+static void buildStructuredOp(
+    OpBuilder &b, OperationState &state,
+    std::optional<TypeRange> resultTensorTypes, ValueRange inputs,
+    ValueRange outputs, ArrayRef<NamedAttribute> attributes,
+    RegionBuilderFn regionBuilder,
+    std::optional<ArrayRef<AffineMap>> indexingMaps = std::nullopt) {
   // Derive the result types if needed.
   SmallVector<Type> derivedResultTypes =
       resultTensorTypes.value_or(TypeRange());
@@ -168,6 +195,20 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state,
   state.addOperands(inputs);
   state.addOperands(outputs);
   state.addTypes(derivedResultTypes);
+
+  // Initialize indexingMaps, for MatmulOp.
+  SmallVector<Attribute, 3> indexingMapsAttrVal;
+  if (indexingMaps.has_value()) {
+    for (mlir::AffineMap map : *indexingMaps) {
+      // Convert each AffineMap to an AffineMapAttr
+      indexingMapsAttrVal.push_back(AffineMapAttr::get(map));
+    }
+    state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+  } else {
+    indexingMapsAttrVal = getDefaultIndexingMapAttr(b.getContext());
+    state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+  }
+
   state.addAttributes(attributes);
   state.addAttribute(
       "operandSegmentSizes",
@@ -299,11 +340,48 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                           OperationState &result,
                                           unsigned numRegionArgs,
                                           RegionBuilderFn regionBuilder) {
+
+  SmallVector<Attribute, 3> indexingMapsAttr;
+  Attribute mapAttr;
+  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+    if (parser.parseEqual())
+      return failure();
+
+    if (parser.parseLSquare())
+      return failure();
+
+    do {
+      if (parser.parseAttribute(mapAttr))
+        return failure();
+      if (!isa<AffineMapAttr>(mapAttr)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                                "expected affine map attribute");
+      }
+      indexingMapsAttr.push_back(mapAttr);
+
+      if (parser.parseOptionalComma())
+        break;
+    } while (true);
+
+    if (parser.parseRSquare())
+      return failure();
+  }
+  // Initialize indexingMaps, if not supplied explicitly.
+  if (indexingMapsAttr.empty()) {
+    indexingMapsAttr = getDefaultIndexingMapAttr(result.getContext());
+  }
+  result.addAttribute("indexing_maps",
+                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
+
   // TODO: Enable when ods-gen supports captures.
   SmallVector<Type, 1> inputTypes, outputTypes;
   if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
     return failure();
 
+  // Parse optional attributes.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
   // TODO: consider merging results parsing into region parsing.
   // Need to wait for declarative assembly resolution to decide.
   SmallVector<Type, 1> outputTensorsTypes;
@@ -329,13 +407,9 @@ static void printNamedStructuredOpResults(OpAsmPrinter &p,
 }
 
 static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
-                                   ValueRange inputs, ValueRange outputs) {
-  p.printOptionalAttrDict(
-      op->getAttrs(),
-      /*elidedAttrs=*/{"operandSegmentSizes",
-                       // See generated code in
-                       // LinalgNamedStructuredOps.yamlgen.cpp.inc
-                       "linalg.memoized_indexing_maps"});
+                                   ValueRange inputs, ValueRange outputs,
+                                   ArrayRef<StringRef> elidedAttrs = {}) {
+  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
 
   // Printing is shared with generic ops, except for the region and
   // attributes.
@@ -3382,3 +3456,168 @@ Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                               Location loc) {
   return arith::ConstantOp::materialize(builder, value, type, loc);
 }
+
+/// Returns true if the result AffineExpr of the \p explicitMap is same as \p
+/// defaultMap.
+static bool isValidResultDimExprs(AffineMap explictMap, AffineMap defaultMap) {
+  auto explicitRange = explictMap.getResults();
+  auto defaultRange = defaultMap.getResults();
+  DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
+  DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
+  llvm::set_union(explicitSet, defaultSet);
+  return explicitSet == defaultSet;
+}
+
+/// Returns true if the \p explictMap is broadcasted with respect to the
+/// \p defaultMap.
+static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap) {
+  return explictMap.getNumResults() < defaultMap.getNumResults();
+}
+
+/// Verifies the broadcast and transpose semantic sepecified by the explicit
+/// indexing map for the MatmulOp \p op for each operand specified by \p
+/// opIndex.
+static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
+                                                  unsigned opIndex) {
+  SmallVector<AffineMap, 3> opInde...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Nov 7, 2024

@llvm/pr-subscribers-mlir-gpu

Author: Md Asghar Ahmad Shahid (shahidact)

Changes

The earlier PR(#104783) which introduces
transpose and broadcast semantic to linalg.matmul was reverted due to two failing
OpDSL test for linalg.matmul.

Since linalg.matmul is now defined using TableGen ODS instead of Python-based OpDSL,
these test started failing and needs to be removed/updated.

This commit removes/updates the failing obsolete tests from below files. All other files
were part of earlier PR and just cherry picked.
"mlir/test/python/integration/dialects/linalg/opsrun.py"
"mlir/test/python/integration/dialects/transform.py"


Patch is 67.57 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/115319.diff

16 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td (+10)
  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml (-72)
  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+134)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp (+12-5)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+251-12)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp (+7)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+5)
  • (modified) mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp (+6)
  • (modified) mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py (-17)
  • (modified) mlir/test/Dialect/Linalg/generalize-named-ops.mlir (+111)
  • (modified) mlir/test/Dialect/Linalg/invalid.mlir (+159)
  • (modified) mlir/test/Dialect/Linalg/named-ops.mlir (+243)
  • (modified) mlir/test/python/dialects/linalg/ops.py (-75)
  • (modified) mlir/test/python/integration/dialects/linalg/opsrun.py (-115)
  • (modified) mlir/test/python/integration/dialects/transform.py (+14-12)
  • (modified) mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp (+5-1)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index b81a4c9c8760cf..c0eff99c850752 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -708,6 +708,16 @@ def LinalgStructuredInterface
         return;
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return true if the user has supplied an explicit indexing maps for this op.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"hasUserDefinedMaps",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{ return false; }]
+    >,
     //===------------------------------------------------------------------===//
     // Linalg generalization hooks.
     //===------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index bf2f26de26e9ed..ee88ca516de6ff 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1065,78 +1065,6 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_arg: rhs
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: matmul
-  cpp_class_name: MatmulOp
-  doc: |-
-    Performs a matrix multiplication of two 2D inputs.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-  implements:
-  - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: A
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
-  - !LinalgOperandDefConfig
-    name: B
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
-  - !LinalgOperandDefConfig
-    name: C
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
-  - !LinalgOperandDefConfig
-    name: cast
-    kind: type_fn_attr
-    default_fn: cast_signed
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
-  iterator_types:
-  - parallel
-  - parallel
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: C
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: C
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: A
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: B
---- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: quantized_matmul
   cpp_class_name: QuantizedMatmulOp
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index c2fee8ea55c960..2b47414ff5e924 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -554,6 +554,140 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Op definition for MatmulOp
+//===----------------------------------------------------------------------===//
+
+def MatmulOp : LinalgStructuredBase_Op<"matmul", [
+               AttrSizedOperandSegments,
+               LinalgContractionOpInterface]> {
+    
+  let summary = [{
+    Performs a matrix multiplication of two 2D inputs without broadcast or transpose.
+    }];
+  let description = [{
+    Numeric casting is performed on the operands to the inner multiply,
+    promoting them to the same data type as the accumulator/output.
+
+    Broadcast and Transpose semantics can be appiled by specifying the explicit attribute
+    'indexing_maps' as shown below.This is a list attribute, so the list must include all
+    the maps if specified.
+
+    Example Transpose:
+    ```
+    linalg.matmul indexing_maps = [
+                   affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
+                   affine_map<(d0, d1, d2) -> (d2, d1)>,
+                   affine_map<(d0, d1, d2) -> (d0, d1)>
+                   ]
+                   ins(%arg0, %arg1 : memref<5x3xf32>,memref<5x7xf32>)
+                   outs(%arg2: memref<3x7xf32>)
+     ```
+
+    Example Broadcast:
+     ```
+    linalg.matmul indexing_maps = [
+                   affine_map<(d0, d1, d2) -> (d2)>,     // broadcast
+                   affine_map<(d0, d1, d2) -> (d2, d1)>,
+                   affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>)
+                  outs(%arg2: memref<3x7xf32>)
+     ```
+
+     Example Broadcast and transpose:
+     ```
+     linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
+                       affine_map<(d0, d1, d2) -> (d2)>,     // broadcast
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>)
+    }];
+
+    let arguments = (ins
+      Variadic<AnyType>:$inputs,
+      Variadic<AnyShaped>:$outputs,
+      DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
+      DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
+    );
+    let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+    let regions = (region AnyRegion:$region);
+
+    let skipDefaultBuilders = 1;
+    let builders = [
+      OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+          attributes, MatmulOp::getRegionBuilder());
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+            "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildStructuredOp($_builder, $_state, resultTensorTypes,
+          inputs, outputs, attributes, MatmulOp::getRegionBuilder());
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addOperands(operands);
+        $_state.addAttributes(attributes);
+        $_state.addTypes(resultTensorTypes);
+        (void)$_state.addRegion();
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+       "ValueRange":$outputs,
+       "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("cast", cast);
+        buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
+          attributes, MatmulOp::getRegionBuilder());
+      }]>
+
+    ];
+    let hasCustomAssemblyFormat = 1;
+    let hasFolder = 1;
+    let hasVerifier = 1;
+
+    let extraClassDeclaration = structuredOpsBaseDecls # [{
+      SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+      /// Implements the block region builder.
+      static void regionBuilder(ImplicitLocOpBuilder &b,
+                                Block &block, ArrayRef<NamedAttribute> attrs);
+
+      /// Returns a list of AffineMap with the typical matmul indexing charactristic.
+      SmallVector<AffineMap> getDefaultIndexingMaps();
+
+      /// Returns true if the given broadcast map \p bcastMap is valid for this op.
+      bool isValidLhsRhsBroadcastMap(AffineMap bcastMap);
+
+      static std::function<void(ImplicitLocOpBuilder &,
+                                Block &, ArrayRef<NamedAttribute>)>
+      getRegionBuilder() {
+        return regionBuilder;
+      }
+
+      ::mlir::MutableOperandRange getDpsInitsMutable() {
+        return getOutputsMutable();
+      }
+
+      // Generic methods.
+      static unsigned getNumRegionArgs();
+      std::string getLibraryCallName();
+      bool hasDynamicIndexingMaps();
+      /// Check if the op has broadcast and/or transpose semantic. Returns true if the
+      /// user defined indexing maps are not equal to default map.
+      bool hasUserDefinedMaps();
+    }];
+}
+
 //===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index bd77965194b27f..0cffadf8fb64a0 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -15,14 +15,21 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/raw_ostream.h"
 #include <algorithm>
 #include <numeric>
+#include <optional>
 
 using namespace mlir;
 using namespace mlir::linalg;
@@ -1211,7 +1218,6 @@ int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
 
 LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   LinalgOp linalgOp = cast<LinalgOp>(op);
-
   // Mixed tensor/buffer operands are not allowed.
   if (!linalgOp.hasPureTensorSemantics() &&
       !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
@@ -1231,6 +1237,8 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
            << ") to be equal to the number of input/output operands ("
            << linalgOp->getNumOperands() << ")";
 
+  // Set this flag if this op has user defined maps. This is required to guard
+  // the below error condition which assume default indexing maps.
   for (OpOperand &opOperand : linalgOp->getOpOperands()) {
     AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
 
@@ -1247,13 +1255,13 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
              << " dim(s) to match the number of loops";
 
     int64_t rank = linalgOp.getRank(&opOperand);
+
     if (indexingMap.getNumResults() != rank)
       return op->emitOpError("expected operand rank (")
              << rank << ") to match the result rank of indexing_map #"
              << opOperand.getOperandNumber() << " ("
              << indexingMap.getNumResults() << ")";
   }
-
   SmallVector<unsigned> redDims;
   linalgOp.getReductionDims(redDims);
 
@@ -1263,9 +1271,8 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   // Check if given shapes match to inferred shapes.
   SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
   SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);
-
-  // Verify only static cases since we can't get exact dimension sizes and loop
-  // ranges for dynamic cases in this stage.
+  // Verify only static cases since we can't get exact dimension sizes and
+  // loop ranges for dynamic cases in this stage.
   if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
     for (int64_t &range : endLoopRangeValues)
       range -= 1;
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 730c478c2883ef..c909d13e4314b4 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -27,6 +27,7 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Matchers.h"
@@ -37,12 +38,17 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
+#include <cassert>
 #include <optional>
 
 using namespace mlir;
@@ -149,15 +155,36 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
   // iterator_types is an auto-generated method.
 }
 
+/// Helper to create a typical indexing map for MatmulOp. Returns a list of
+/// AffineMap.
+static SmallVector<AffineMap, 3>
+getDefaultIndexingMapsForMatmul(MLIRContext *context) {
+  AffineExpr d0, d1, d2;
+  SmallVector<AffineMap, 3> indexingMaps;
+  bindDims(context, d0, d1, d2);
+  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
+  indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
+  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
+  return indexingMaps;
+}
+
+/// Wrapper to return the typical indexing map array attribute for MatmulOp.
+static SmallVector<Attribute> getDefaultIndexingMapAttr(MLIRContext *context) {
+  return llvm::map_to_vector(
+      getDefaultIndexingMapsForMatmul(context),
+      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+}
+
 /// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
 /// The result types are derived automatically if `resultTensorTypes` is none.
 /// The body of the operation is filled using `regionBuilder`. All ods-gen
 /// created structured operations use the method to implement their builders.
-static void buildStructuredOp(OpBuilder &b, OperationState &state,
-                              std::optional<TypeRange> resultTensorTypes,
-                              ValueRange inputs, ValueRange outputs,
-                              ArrayRef<NamedAttribute> attributes,
-                              RegionBuilderFn regionBuilder) {
+static void buildStructuredOp(
+    OpBuilder &b, OperationState &state,
+    std::optional<TypeRange> resultTensorTypes, ValueRange inputs,
+    ValueRange outputs, ArrayRef<NamedAttribute> attributes,
+    RegionBuilderFn regionBuilder,
+    std::optional<ArrayRef<AffineMap>> indexingMaps = std::nullopt) {
   // Derive the result types if needed.
   SmallVector<Type> derivedResultTypes =
       resultTensorTypes.value_or(TypeRange());
@@ -168,6 +195,20 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state,
   state.addOperands(inputs);
   state.addOperands(outputs);
   state.addTypes(derivedResultTypes);
+
+  // Initialize indexingMaps, for MatmulOp.
+  SmallVector<Attribute, 3> indexingMapsAttrVal;
+  if (indexingMaps.has_value()) {
+    for (mlir::AffineMap map : *indexingMaps) {
+      // Convert each AffineMap to an AffineMapAttr
+      indexingMapsAttrVal.push_back(AffineMapAttr::get(map));
+    }
+    state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+  } else {
+    indexingMapsAttrVal = getDefaultIndexingMapAttr(b.getContext());
+    state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+  }
+
   state.addAttributes(attributes);
   state.addAttribute(
       "operandSegmentSizes",
@@ -299,11 +340,48 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                           OperationState &result,
                                           unsigned numRegionArgs,
                                           RegionBuilderFn regionBuilder) {
+
+  SmallVector<Attribute, 3> indexingMapsAttr;
+  Attribute mapAttr;
+  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+    if (parser.parseEqual())
+      return failure();
+
+    if (parser.parseLSquare())
+      return failure();
+
+    do {
+      if (parser.parseAttribute(mapAttr))
+        return failure();
+      if (!isa<AffineMapAttr>(mapAttr)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                                "expected affine map attribute");
+      }
+      indexingMapsAttr.push_back(mapAttr);
+
+      if (parser.parseOptionalComma())
+        break;
+    } while (true);
+
+    if (parser.parseRSquare())
+      return failure();
+  }
+  // Initialize indexingMaps, if not supplied explicitly.
+  if (indexingMapsAttr.empty()) {
+    indexingMapsAttr = getDefaultIndexingMapAttr(result.getContext());
+  }
+  result.addAttribute("indexing_maps",
+                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
+
   // TODO: Enable when ods-gen supports captures.
   SmallVector<Type, 1> inputTypes, outputTypes;
   if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
     return failure();
 
+  // Parse optional attributes.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
   // TODO: consider merging results parsing into region parsing.
   // Need to wait for declarative assembly resolution to decide.
   SmallVector<Type, 1> outputTensorsTypes;
@@ -329,13 +407,9 @@ static void printNamedStructuredOpResults(OpAsmPrinter &p,
 }
 
 static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
-                                   ValueRange inputs, ValueRange outputs) {
-  p.printOptionalAttrDict(
-      op->getAttrs(),
-      /*elidedAttrs=*/{"operandSegmentSizes",
-                       // See generated code in
-                       // LinalgNamedStructuredOps.yamlgen.cpp.inc
-                       "linalg.memoized_indexing_maps"});
+                                   ValueRange inputs, ValueRange outputs,
+                                   ArrayRef<StringRef> elidedAttrs = {}) {
+  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
 
   // Printing is shared with generic ops, except for the region and
   // attributes.
@@ -3382,3 +3456,168 @@ Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                               Location loc) {
   return arith::ConstantOp::materialize(builder, value, type, loc);
 }
+
+/// Returns true if the result AffineExpr of the \p explicitMap is same as \p
+/// defaultMap.
+static bool isValidResultDimExprs(AffineMap explictMap, AffineMap defaultMap) {
+  auto explicitRange = explictMap.getResults();
+  auto defaultRange = defaultMap.getResults();
+  DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
+  DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
+  llvm::set_union(explicitSet, defaultSet);
+  return explicitSet == defaultSet;
+}
+
+/// Returns true if the \p explictMap is broadcasted with respect to the
+/// \p defaultMap.
+static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap) {
+  return explictMap.getNumResults() < defaultMap.getNumResults();
+}
+
+/// Verifies the broadcast and transpose semantic sepecified by the explicit
+/// indexing map for the MatmulOp \p op for each operand specified by \p
+/// opIndex.
+static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
+                                                  unsigned opIndex) {
+  SmallVector<AffineMap, 3> opInde...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Nov 7, 2024

@llvm/pr-subscribers-mlir-linalg

Author: Md Asghar Ahmad Shahid (shahidact)

Changes

The earlier PR(#104783) which introduces
transpose and broadcast semantic to linalg.matmul was reverted due to two failing
OpDSL test for linalg.matmul.

Since linalg.matmul is now defined using TableGen ODS instead of Python-based OpDSL,
these test started failing and needs to be removed/updated.

This commit removes/updates the failing obsolete tests from below files. All other files
were part of earlier PR and just cherry picked.
"mlir/test/python/integration/dialects/linalg/opsrun.py"
"mlir/test/python/integration/dialects/transform.py"


Patch is 67.57 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/115319.diff

16 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td (+10)
  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml (-72)
  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+134)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp (+12-5)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+251-12)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp (+7)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+5)
  • (modified) mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp (+6)
  • (modified) mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py (-17)
  • (modified) mlir/test/Dialect/Linalg/generalize-named-ops.mlir (+111)
  • (modified) mlir/test/Dialect/Linalg/invalid.mlir (+159)
  • (modified) mlir/test/Dialect/Linalg/named-ops.mlir (+243)
  • (modified) mlir/test/python/dialects/linalg/ops.py (-75)
  • (modified) mlir/test/python/integration/dialects/linalg/opsrun.py (-115)
  • (modified) mlir/test/python/integration/dialects/transform.py (+14-12)
  • (modified) mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp (+5-1)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index b81a4c9c8760cf..c0eff99c850752 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -708,6 +708,16 @@ def LinalgStructuredInterface
         return;
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return true if the user has supplied an explicit indexing maps for this op.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"hasUserDefinedMaps",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{ return false; }]
+    >,
     //===------------------------------------------------------------------===//
     // Linalg generalization hooks.
     //===------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index bf2f26de26e9ed..ee88ca516de6ff 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1065,78 +1065,6 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_arg: rhs
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: matmul
-  cpp_class_name: MatmulOp
-  doc: |-
-    Performs a matrix multiplication of two 2D inputs.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-  implements:
-  - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: A
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
-  - !LinalgOperandDefConfig
-    name: B
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
-  - !LinalgOperandDefConfig
-    name: C
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
-  - !LinalgOperandDefConfig
-    name: cast
-    kind: type_fn_attr
-    default_fn: cast_signed
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
-  iterator_types:
-  - parallel
-  - parallel
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: C
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: C
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: A
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: B
---- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: quantized_matmul
   cpp_class_name: QuantizedMatmulOp
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index c2fee8ea55c960..2b47414ff5e924 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -554,6 +554,140 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Op definition for MatmulOp
+//===----------------------------------------------------------------------===//
+
+def MatmulOp : LinalgStructuredBase_Op<"matmul", [
+               AttrSizedOperandSegments,
+               LinalgContractionOpInterface]> {
+    
+  let summary = [{
+    Performs a matrix multiplication of two 2D inputs without broadcast or transpose.
+    }];
+  let description = [{
+    Numeric casting is performed on the operands to the inner multiply,
+    promoting them to the same data type as the accumulator/output.
+
+    Broadcast and Transpose semantics can be appiled by specifying the explicit attribute
+    'indexing_maps' as shown below.This is a list attribute, so the list must include all
+    the maps if specified.
+
+    Example Transpose:
+    ```
+    linalg.matmul indexing_maps = [
+                   affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
+                   affine_map<(d0, d1, d2) -> (d2, d1)>,
+                   affine_map<(d0, d1, d2) -> (d0, d1)>
+                   ]
+                   ins(%arg0, %arg1 : memref<5x3xf32>,memref<5x7xf32>)
+                   outs(%arg2: memref<3x7xf32>)
+     ```
+
+    Example Broadcast:
+     ```
+    linalg.matmul indexing_maps = [
+                   affine_map<(d0, d1, d2) -> (d2)>,     // broadcast
+                   affine_map<(d0, d1, d2) -> (d2, d1)>,
+                   affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>)
+                  outs(%arg2: memref<3x7xf32>)
+     ```
+
+     Example Broadcast and transpose:
+     ```
+     linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
+                       affine_map<(d0, d1, d2) -> (d2)>,     // broadcast
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>)
+    }];
+
+    let arguments = (ins
+      Variadic<AnyType>:$inputs,
+      Variadic<AnyShaped>:$outputs,
+      DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
+      DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
+    );
+    let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+    let regions = (region AnyRegion:$region);
+
+    let skipDefaultBuilders = 1;
+    let builders = [
+      OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+          attributes, MatmulOp::getRegionBuilder());
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+            "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildStructuredOp($_builder, $_state, resultTensorTypes,
+          inputs, outputs, attributes, MatmulOp::getRegionBuilder());
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addOperands(operands);
+        $_state.addAttributes(attributes);
+        $_state.addTypes(resultTensorTypes);
+        (void)$_state.addRegion();
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+       "ValueRange":$outputs,
+       "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("cast", cast);
+        buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
+          attributes, MatmulOp::getRegionBuilder());
+      }]>
+
+    ];
+    let hasCustomAssemblyFormat = 1;
+    let hasFolder = 1;
+    let hasVerifier = 1;
+
+    let extraClassDeclaration = structuredOpsBaseDecls # [{
+      SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+      /// Implements the block region builder.
+      static void regionBuilder(ImplicitLocOpBuilder &b,
+                                Block &block, ArrayRef<NamedAttribute> attrs);
+
+      /// Returns a list of AffineMap with the typical matmul indexing charactristic.
+      SmallVector<AffineMap> getDefaultIndexingMaps();
+
+      /// Returns true if the given broadcast map \p bcastMap is valid for this op.
+      bool isValidLhsRhsBroadcastMap(AffineMap bcastMap);
+
+      static std::function<void(ImplicitLocOpBuilder &,
+                                Block &, ArrayRef<NamedAttribute>)>
+      getRegionBuilder() {
+        return regionBuilder;
+      }
+
+      ::mlir::MutableOperandRange getDpsInitsMutable() {
+        return getOutputsMutable();
+      }
+
+      // Generic methods.
+      static unsigned getNumRegionArgs();
+      std::string getLibraryCallName();
+      bool hasDynamicIndexingMaps();
+      /// Check if the op has broadcast and/or transpose semantic. Returns true if the
+      /// user defined indexing maps are not equal to default map.
+      bool hasUserDefinedMaps();
+    }];
+}
+
 //===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index bd77965194b27f..0cffadf8fb64a0 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -15,14 +15,21 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/raw_ostream.h"
 #include <algorithm>
 #include <numeric>
+#include <optional>
 
 using namespace mlir;
 using namespace mlir::linalg;
@@ -1211,7 +1218,6 @@ int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
 
 LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   LinalgOp linalgOp = cast<LinalgOp>(op);
-
   // Mixed tensor/buffer operands are not allowed.
   if (!linalgOp.hasPureTensorSemantics() &&
       !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
@@ -1231,6 +1237,8 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
            << ") to be equal to the number of input/output operands ("
            << linalgOp->getNumOperands() << ")";
 
+  // Set this flag if this op has user defined maps. This is required to guard
+  // the below error condition which assume default indexing maps.
   for (OpOperand &opOperand : linalgOp->getOpOperands()) {
     AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
 
@@ -1247,13 +1255,13 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
              << " dim(s) to match the number of loops";
 
     int64_t rank = linalgOp.getRank(&opOperand);
+
     if (indexingMap.getNumResults() != rank)
       return op->emitOpError("expected operand rank (")
              << rank << ") to match the result rank of indexing_map #"
              << opOperand.getOperandNumber() << " ("
              << indexingMap.getNumResults() << ")";
   }
-
   SmallVector<unsigned> redDims;
   linalgOp.getReductionDims(redDims);
 
@@ -1263,9 +1271,8 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   // Check if given shapes match to inferred shapes.
   SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
   SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);
-
-  // Verify only static cases since we can't get exact dimension sizes and loop
-  // ranges for dynamic cases in this stage.
+  // Verify only static cases since we can't get exact dimension sizes and
+  // loop ranges for dynamic cases in this stage.
   if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
     for (int64_t &range : endLoopRangeValues)
       range -= 1;
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 730c478c2883ef..c909d13e4314b4 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -27,6 +27,7 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Matchers.h"
@@ -37,12 +38,17 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
+#include <cassert>
 #include <optional>
 
 using namespace mlir;
@@ -149,15 +155,36 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
   // iterator_types is an auto-generated method.
 }
 
+/// Helper to create a typical indexing map for MatmulOp. Returns a list of
+/// AffineMap.
+static SmallVector<AffineMap, 3>
+getDefaultIndexingMapsForMatmul(MLIRContext *context) {
+  AffineExpr d0, d1, d2;
+  SmallVector<AffineMap, 3> indexingMaps;
+  bindDims(context, d0, d1, d2);
+  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
+  indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
+  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
+  return indexingMaps;
+}
+
+/// Wrapper to return the typical indexing map array attribute for MatmulOp.
+static SmallVector<Attribute> getDefaultIndexingMapAttr(MLIRContext *context) {
+  return llvm::map_to_vector(
+      getDefaultIndexingMapsForMatmul(context),
+      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+}
+
 /// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
 /// The result types are derived automatically if `resultTensorTypes` is none.
 /// The body of the operation is filled using `regionBuilder`. All ods-gen
 /// created structured operations use the method to implement their builders.
-static void buildStructuredOp(OpBuilder &b, OperationState &state,
-                              std::optional<TypeRange> resultTensorTypes,
-                              ValueRange inputs, ValueRange outputs,
-                              ArrayRef<NamedAttribute> attributes,
-                              RegionBuilderFn regionBuilder) {
+static void buildStructuredOp(
+    OpBuilder &b, OperationState &state,
+    std::optional<TypeRange> resultTensorTypes, ValueRange inputs,
+    ValueRange outputs, ArrayRef<NamedAttribute> attributes,
+    RegionBuilderFn regionBuilder,
+    std::optional<ArrayRef<AffineMap>> indexingMaps = std::nullopt) {
   // Derive the result types if needed.
   SmallVector<Type> derivedResultTypes =
       resultTensorTypes.value_or(TypeRange());
@@ -168,6 +195,20 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state,
   state.addOperands(inputs);
   state.addOperands(outputs);
   state.addTypes(derivedResultTypes);
+
+  // Initialize indexingMaps, for MatmulOp.
+  SmallVector<Attribute, 3> indexingMapsAttrVal;
+  if (indexingMaps.has_value()) {
+    for (mlir::AffineMap map : *indexingMaps) {
+      // Convert each AffineMap to an AffineMapAttr
+      indexingMapsAttrVal.push_back(AffineMapAttr::get(map));
+    }
+    state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+  } else {
+    indexingMapsAttrVal = getDefaultIndexingMapAttr(b.getContext());
+    state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+  }
+
   state.addAttributes(attributes);
   state.addAttribute(
       "operandSegmentSizes",
@@ -299,11 +340,48 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                           OperationState &result,
                                           unsigned numRegionArgs,
                                           RegionBuilderFn regionBuilder) {
+
+  SmallVector<Attribute, 3> indexingMapsAttr;
+  Attribute mapAttr;
+  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+    if (parser.parseEqual())
+      return failure();
+
+    if (parser.parseLSquare())
+      return failure();
+
+    do {
+      if (parser.parseAttribute(mapAttr))
+        return failure();
+      if (!isa<AffineMapAttr>(mapAttr)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                                "expected affine map attribute");
+      }
+      indexingMapsAttr.push_back(mapAttr);
+
+      if (parser.parseOptionalComma())
+        break;
+    } while (true);
+
+    if (parser.parseRSquare())
+      return failure();
+  }
+  // Initialize indexingMaps, if not supplied explicitly.
+  if (indexingMapsAttr.empty()) {
+    indexingMapsAttr = getDefaultIndexingMapAttr(result.getContext());
+  }
+  result.addAttribute("indexing_maps",
+                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
+
   // TODO: Enable when ods-gen supports captures.
   SmallVector<Type, 1> inputTypes, outputTypes;
   if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
     return failure();
 
+  // Parse optional attributes.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
   // TODO: consider merging results parsing into region parsing.
   // Need to wait for declarative assembly resolution to decide.
   SmallVector<Type, 1> outputTensorsTypes;
@@ -329,13 +407,9 @@ static void printNamedStructuredOpResults(OpAsmPrinter &p,
 }
 
 static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
-                                   ValueRange inputs, ValueRange outputs) {
-  p.printOptionalAttrDict(
-      op->getAttrs(),
-      /*elidedAttrs=*/{"operandSegmentSizes",
-                       // See generated code in
-                       // LinalgNamedStructuredOps.yamlgen.cpp.inc
-                       "linalg.memoized_indexing_maps"});
+                                   ValueRange inputs, ValueRange outputs,
+                                   ArrayRef<StringRef> elidedAttrs = {}) {
+  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
 
   // Printing is shared with generic ops, except for the region and
   // attributes.
@@ -3382,3 +3456,168 @@ Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                               Location loc) {
   return arith::ConstantOp::materialize(builder, value, type, loc);
 }
+
+/// Returns true if the result AffineExpr of the \p explicitMap is same as \p
+/// defaultMap.
+static bool isValidResultDimExprs(AffineMap explictMap, AffineMap defaultMap) {
+  auto explicitRange = explictMap.getResults();
+  auto defaultRange = defaultMap.getResults();
+  DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
+  DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
+  llvm::set_union(explicitSet, defaultSet);
+  return explicitSet == defaultSet;
+}
+
+/// Returns true if the \p explictMap is broadcasted with respect to the
+/// \p defaultMap.
+static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap) {
+  return explictMap.getNumResults() < defaultMap.getNumResults();
+}
+
+/// Verifies the broadcast and transpose semantic sepecified by the explicit
+/// indexing map for the MatmulOp \p op for each operand specified by \p
+/// opIndex.
+static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
+                                                  unsigned opIndex) {
+  SmallVector<AffineMap, 3> opInde...
[truncated]

Copy link

github-actions bot commented Nov 7, 2024

✅ With the latest revision this PR passed the Python code formatter.

@@ -99,26 +99,28 @@ def basic(target: any_op_t()):
# CHECK-LABEL: TEST: test_apply_patterns
@construct_and_print_in_module
def test_apply_patterns(module_):
M, N, K = 3, 5, 3
b, M, N, K = 1, 3, 5, 3
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine for now, but remember, we're do the same with brgemm soon, so we'll need to (re)move this eventually for good.

Copy link
Member

@rengolin rengolin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's give it another go. The fixes here were suggested by the breaking PR, so I'll approve and merge. Thanks!

@rengolin rengolin merged commit 3ad0148 into llvm:main Nov 7, 2024
8 checks passed
s-perron pushed a commit to s-perron/llvm-project that referenced this pull request Nov 7, 2024
…ling obsolete OpDSL tests. (llvm#115319)

The earlier PR(llvm#104783) which
introduces
transpose and broadcast semantic to linalg.matmul was reverted due to
two failing
OpDSL test for linalg.matmul.

Since linalg.matmul is now defined using TableGen ODS instead of
Python-based OpDSL,
these test started failing and needs to be removed/updated.

This commit removes/updates the failing obsolete tests from below files.
All other files
were part of earlier PR and just cherry picked.
    "mlir/test/python/integration/dialects/linalg/opsrun.py"
    "mlir/test/python/integration/dialects/transform.py"

---------

Co-authored-by: Renato Golin <rengolin@systemcall.eu>
Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still dont understand whats happening in this change. Its a pretty big change but it is titled as if it is removing obsolete changes. Please describe what is happening with this change and why it is touching all these files. A 900+ line change without accurate title and description is strange.

@rengolin
Copy link
Member

rengolin commented Nov 8, 2024

I have changed the message before merge to reflect that this is a re-land of a previous PR(#104783) with the tests that broke a post-commit bot fixed, as recommended in #104783 (comment).

@MaheshRavishankar
Copy link
Contributor

I have changed the message before merge to reflect that this is a re-land of a previous PR(#104783) with the tests that broke a post-commit bot fixed, as recommended in #104783 (comment).

Ok I see now that this is a reland. I didnt click on the link for the previous PR in the description to see the changes there compared to in this PR and my read of the description/title didnt make that jump as well. Sorry for the noise!

rolfmorel added a commit that referenced this pull request Nov 11, 2024
…FC (#115763)

#115319 added a tablegen description for an op (MatmulOp) but missed
closing one of the code blocks in the description. As a result the
generated docs broke, i.e. https://mlir.llvm.org/docs/Dialects/Linalg
was broken after this code block.
@makslevental
Copy link
Contributor

makslevental commented Nov 11, 2024

FYI after this it's basically impossible to emit a named linalg.matmul from python since opdsl pulls double duty. And just restoring the deleted opdsl impl isn't sufficient because the linalg.MatMulOp now needs the indexing attributes or else you trip some assert somewhere (ie idx < size()). In order to restore linalg.matmul I had to

@linalg.linalg_structured_op
def _matmul_generic(
    A=linalg.TensorDef(linalg.T1, linalg.S.M, linalg.S.K),
    B=linalg.TensorDef(linalg.T2, linalg.S.K, linalg.S.N),
    C=linalg.TensorDef(linalg.U, linalg.S.M, linalg.S.N, output=True),
    cast=linalg.TypeFnAttrDef(default=linalg.TypeFn.cast_signed),
):
    linalg.domain(linalg.D.m, linalg.D.n, linalg.D.k)
    linalg.implements(linalg.ContractionOpInterface)
    C[linalg.D.m, linalg.D.n] += cast(linalg.U, A[linalg.D.m, linalg.D.k]) * cast(
        linalg.U, B[linalg.D.k, linalg.D.n]
    )


_matmul_generic.op_name = "matmul"


def matmul(A, B, C, *, loc=None, ip=None):
    op_configs = linalg.LinalgOpConfig.from_linalg_op_def(
        _matmul_generic.op_def, context=ir.Context.current
    )
    op_config = op_configs[0]
    (
        _all_arg_defs,
        _in_arg_defs,
        _out_arg_defs,
        _outs,
        result_types,
        _type_mapping,
        indexing_maps_attr,
        _iterator_types_attr,
        _index_attrs,
        _fn_attr_mapping,
        _block_arg_types,
    ) = linalg.opdsl.lang.emitter.prepare_common_structured_op(
        op_config.structured_op, A, B, outs=[C], loc=loc, ip=ip
    )
    named_op = linalg.MatmulOp(
        result_types,
        inputs=[A, B],
        outputs=[C],
        indexing_maps=indexing_maps_attr,
        cast=linalg.TypeFn.cast_signed,
        loc=loc,
        ip=ip,
    )
    linalg.fill_builtin_region(named_op.operation)
    return named_op.results

I don't know what to propose/suggest here because

  1. I haven't been keeping up with this stream
  2. I don't know if anyone even cares - is anyone (except me) using linalg from Python?

@MaheshRavishankar
Copy link
Contributor

I was afraid of such a thing when the decision was to call the new op added to TableGen the same as the OpDSL name. Previous suggestion was that we change the name of the TableGen version of the op to something else, but that wasnt accepted. I would still suggest doing that to avoid such breakages.

@rengolin
Copy link
Member

rengolin commented Nov 12, 2024

I was afraid of such a thing when the decision was to call the new op added to TableGen the same as the OpDSL name. Previous suggestion was that we change the name of the TableGen version of the op to something else, but that wasnt accepted. I would still suggest doing that to avoid such breakages.

Having two matmuls will create more confusion than this. We need to move away from OpDSL and the first op is the one that was bound to break things. If we leave the ops in OpDSL, then the Python code will use one set and the C++ code another, and that's way way worse than any temporary code churn.

It seems obvious to me that Python should be able to call other ODS operations. If so, we need to tie the same mechanism to matmul, and then the follow up operations. This seems like a mechanical solution would work here.

I don't know the Python code well enough, @makslevental, @ftynse can you help us here?

@makslevental
Copy link
Contributor

It seems obvious to me that Python should be able to call other ODS operations. If so, we need to tie the same mechanism to matmul, and then the follow up operations. This seems like a mechanical solution would work here.

It can - there are "binded" builders for every op in every dialect - eg linalg.MatMulOp or linalg.FillOp etc. The snippet above shows that it can be done (I landed exactly this "fix" in mlir-python-wheels).

What makes the current state onerous is that for whatever reason (no doubt due to changes going on in this stream), suddenly you need to explicitly pass the indexing maps even for linalg.MatMulOp whereas before you did not. Thus you have to build the indexing_maps_attr by hand and then call the binded builder.

But like I said: I highly doubt anyone is using Python to emit linalg other than me so I wouldn't call this a catastrophe. Just an FYI. And I agree with you @rengolin, I don't think introducing overlapping names (matmul_op_dsl, matmul_non_op_dsl) is a good idea and I also agree that OpDSL should go.

@stellaraccident
Copy link
Contributor

stellaraccident commented Nov 12, 2024

Some mechanical APIs for python would be what is called for here, and the presence or lack of them can't gate the core development.

Unless if a maintainer stands up who would be willing to pay to update the opdsl infrastructure and separate the concerns, then that path is dead. In the protracted prior debate, no one was willing to make such a commitment, so my strong suspicion is that the only path forward here is for interested parties to develop more mechanical API adaptors/helpers like you have done. I know that's not great, but I'm not sure there is another path open to the project which has costs that anyone will pay.

(I'm sorry - I know how frustrating this kind of thing is)

@rengolin
Copy link
Member

rengolin commented Nov 12, 2024

It can - there are "binded" builders for every op in every dialect - eg linalg.MatMulOp or linalg.FillOp etc. The snippet above shows that it can be done (I landed exactly this "fix" in mlir-python-wheels).

If this sounds like the right fix upstream, we should merge it and do the same for the upcoming ops. (@shahidact, @javedabsar, @adam-smnk).

What makes the current state onerous is that for whatever reason (no doubt due to changes going on in this stream), suddenly you need to explicitly pass the indexing maps even for linalg.MatMulOp whereas before you did not. Thus you have to build the indexing_maps_attr by hand and then call the binded builder.

This was not the intention. The parser accepts plain syntax and builds the affine map by default. We need to find a way for that not to be necessary from Python when the maps are standard. You can keep that logic to support broadcast/transpose maps, though (which is the new thing).

I highly doubt anyone is using Python to emit linalg other than me so I wouldn't call this a catastrophe.

This gives us time to fix it properly.

I'm not sure there is another path open to the project which has costs that anyone will pay.

If there is, I can't see it either. Neither could anyone else on the original discussion. So let's just all commit to this path and fix all the problems we find along the way.

Groverkss added a commit to Groverkss/llvm-project that referenced this pull request Nov 15, 2024
…date failing obsolete OpDSL tests. (llvm#115319)"

This reverts commit 3ad0148.
Groverkss added a commit to iree-org/llvm-project that referenced this pull request Nov 15, 2024
…date failing obsolete OpDSL tests. (llvm#115319)"

This reverts commit 3ad0148.
Groverkss pushed a commit to iree-org/llvm-project that referenced this pull request Nov 15, 2024
…ling obsolete OpDSL tests. (llvm#115319)

The earlier PR(llvm#104783) which
introduces
transpose and broadcast semantic to linalg.matmul was reverted due to
two failing
OpDSL test for linalg.matmul.

Since linalg.matmul is now defined using TableGen ODS instead of
Python-based OpDSL,
these test started failing and needs to be removed/updated.

This commit removes/updates the failing obsolete tests from below files.
All other files
were part of earlier PR and just cherry picked.
    "mlir/test/python/integration/dialects/linalg/opsrun.py"
    "mlir/test/python/integration/dialects/transform.py"

---------

Co-authored-by: Renato Golin <rengolin@systemcall.eu>
Groverkss pushed a commit to iree-org/llvm-project that referenced this pull request Nov 15, 2024
…FC (llvm#115763)

llvm#115319 added a tablegen description for an op (MatmulOp) but missed
closing one of the code blocks in the description. As a result the
generated docs broke, i.e. https://mlir.llvm.org/docs/Dialects/Linalg
was broken after this code block.
Groverkss added a commit to iree-org/llvm-project that referenced this pull request Nov 15, 2024
…date failing obsolete OpDSL tests. (llvm#115319)"

This reverts commit 3ad0148.
Groverkss added a commit to iree-org/llvm-project that referenced this pull request Nov 15, 2024
…date failing obsolete OpDSL tests. (llvm#115319)"

This reverts commit 3ad0148.
@rengolin rengolin changed the title [MLIR][Linalg] Remove/update failing obsolete OpDSL tests for linalg.matmul. [MLIR][Linalg] Re-land linalg.matmul move to ODS. + Remove/update failing obsolete OpDSL tests for linalg.matmul. Nov 18, 2024
@rengolin rengolin changed the title [MLIR][Linalg] Re-land linalg.matmul move to ODS. + Remove/update failing obsolete OpDSL tests for linalg.matmul. [MLIR][Linalg] Re-land linalg.matmul move to ODS + Remove/update failing obsolete OpDSL tests for linalg.matmul. Nov 18, 2024
Groverkss added a commit to iree-org/llvm-project that referenced this pull request Nov 18, 2024
…date failing obsolete OpDSL tests. (llvm#115319)"

This reverts commit 3ad0148.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants