From 7d055af14b7dd7e782b87fb883205eda65e8bd44 Mon Sep 17 00:00:00 2001 From: Joshua Cao Date: Mon, 5 Feb 2024 22:59:03 -0800 Subject: [PATCH] [mlir][Symbol] Add verification that symbol's parent is a SymbolTable (#80590) Following the discussion in https://discourse.llvm.org/t/symboltable-and-symbol-parent-child-relationship/75446, we should enforce that a symbol's immediate parent is a symbol table. I changed some tests to pass the verification. In most cases, we can wrap the func with a module, change the func to another op with regions i.e. scf.if, or change the expected error message. --------- Co-authored-by: Mehdi Amini --- mlir/include/mlir/IR/SymbolInterfaces.td | 5 +++ mlir/test/Dialect/LLVMIR/global.mlir | 2 +- .../Dialect/Linalg/transform-op-replace.mlir | 6 ++-- mlir/test/Dialect/Transform/ops-invalid.mlir | 3 +- mlir/test/IR/invalid-func-op.mlir | 4 +-- mlir/test/IR/region.mlir | 7 ++-- mlir/test/IR/traits.mlir | 33 +++++++++-------- mlir/test/Transforms/canonicalize-dce.mlir | 14 ++++---- mlir/test/Transforms/canonicalize.mlir | 13 ++++--- mlir/test/Transforms/constant-fold.mlir | 11 +++--- mlir/test/Transforms/cse.mlir | 11 +++--- mlir/test/Transforms/test-legalizer-full.mlir | 8 +++-- mlir/test/python/ir/value.py | 36 +++---------------- 13 files changed, 68 insertions(+), 85 deletions(-) diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td index 844601f8f6837c..60b38185fa8ccb 100644 --- a/mlir/include/mlir/IR/SymbolInterfaces.td +++ b/mlir/include/mlir/IR/SymbolInterfaces.td @@ -171,6 +171,11 @@ def Symbol : OpInterface<"SymbolOpInterface"> { if (concreteOp.isDeclaration() && concreteOp.isPublic()) return concreteOp.emitOpError("symbol declaration cannot have public " "visibility"); + auto parent = $_op->getParentOp(); + if (parent && !parent->hasTrait() && parent->isRegistered()) { + return concreteOp.emitOpError("symbol's parent must have the SymbolTable " + "trait"); + } return success(); }]; diff --git a/mlir/test/Dialect/LLVMIR/global.mlir b/mlir/test/Dialect/LLVMIR/global.mlir index 0649e814bfdfc0..3fa7636d4dd686 100644 --- a/mlir/test/Dialect/LLVMIR/global.mlir +++ b/mlir/test/Dialect/LLVMIR/global.mlir @@ -132,7 +132,7 @@ llvm.mlir.global internal constant @constant(37.0) : !llvm.label // ----- func.func @foo() { - // expected-error @+1 {{must appear at the module level}} + // expected-error @+1 {{op symbol's parent must have the SymbolTable trait}} llvm.mlir.global internal @bar(42) : i32 return diff --git a/mlir/test/Dialect/Linalg/transform-op-replace.mlir b/mlir/test/Dialect/Linalg/transform-op-replace.mlir index 2801522e81ac2c..1a40912977dec2 100644 --- a/mlir/test/Dialect/Linalg/transform-op-replace.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-replace.mlir @@ -12,8 +12,10 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op transform.structured.replace %0 { - func.func @foo() { - "dummy_op"() : () -> () + builtin.module { + func.func @foo() { + "dummy_op"() : () -> () + } } } : (!transform.any_op) -> !transform.any_op transform.yield diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir index e3f5bcf403f2ad..73a5f36af92952 100644 --- a/mlir/test/Dialect/Transform/ops-invalid.mlir +++ b/mlir/test/Dialect/Transform/ops-invalid.mlir @@ -433,10 +433,9 @@ module { // ----- module attributes { transform.with_named_sequence} { - // expected-note @below {{ancestor transform op}} transform.sequence failures(suppress) { ^bb0(%arg0: !transform.any_op): - // expected-error @below {{cannot be defined inside another transform op}} + // expected-error @below {{op symbol's parent must have the SymbolTable trai}} transform.named_sequence @nested() { transform.yield } diff --git a/mlir/test/IR/invalid-func-op.mlir b/mlir/test/IR/invalid-func-op.mlir index d995689ebb8d0b..8fd7af22e9598b 100644 --- a/mlir/test/IR/invalid-func-op.mlir +++ b/mlir/test/IR/invalid-func-op.mlir @@ -31,7 +31,7 @@ func.func @func_op() { // ----- func.func @func_op() { - // expected-error@+1 {{entry block must have 1 arguments to match function signature}} + // expected-error@+1 {{op symbol's parent must have the SymbolTable trait}} func.func @mixed_named_arguments(f32) { ^entry: return @@ -42,7 +42,7 @@ func.func @func_op() { // ----- func.func @func_op() { - // expected-error@+1 {{type of entry block argument #0('i32') must match the type of the corresponding argument in function signature('f32')}} + // expected-error@+1 {{op symbol's parent must have the SymbolTable trait}} func.func @mixed_named_arguments(f32) { ^entry(%arg : i32): return diff --git a/mlir/test/IR/region.mlir b/mlir/test/IR/region.mlir index bf4b1bb4e5ab1d..0b959915d6bbbe 100644 --- a/mlir/test/IR/region.mlir +++ b/mlir/test/IR/region.mlir @@ -87,18 +87,17 @@ func.func @named_region_has_wrong_number_of_blocks() { // CHECK: test.single_no_terminator_op "test.single_no_terminator_op"() ( { - func.func @foo1() { return } - func.func @foo2() { return } + %foo = arith.constant 1 : i32 } ) : () -> () // CHECK: test.variadic_no_terminator_op "test.variadic_no_terminator_op"() ( { - func.func @foo1() { return } + %foo = arith.constant 1 : i32 }, { - func.func @foo2() { return } + %bar = arith.constant 1 : i32 } ) : () -> () diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir index 0402ebe7587508..1e046706379cdb 100644 --- a/mlir/test/IR/traits.mlir +++ b/mlir/test/IR/traits.mlir @@ -572,15 +572,13 @@ func.func @failedHasDominanceScopeOutsideDominanceFreeScope() -> () { // Ensure that SSACFG regions of operations in GRAPH regions are // checked for dominance -func.func @illegalInsideDominanceFreeScope() -> () { +func.func @illegalInsideDominanceFreeScope(%cond: i1) -> () { test.graph_region { - func.func @test() -> i1 { - ^bb1: + scf.if %cond { // expected-error @+1 {{operand #0 does not dominate this use}} %2:3 = "bar"(%1) : (i64) -> (i1,i1,i1) // expected-note @+1 {{operand defined here}} - %1 = "baz"(%2#0) : (i1) -> (i64) - return %2#1 : i1 + %1 = "baz"(%2#0) : (i1) -> (i64) } "terminator"() : () -> () } @@ -591,20 +589,21 @@ func.func @illegalInsideDominanceFreeScope() -> () { // Ensure that SSACFG regions of operations in GRAPH regions are // checked for dominance -func.func @illegalCDFGInsideDominanceFreeScope() -> () { +func.func @illegalCFGInsideDominanceFreeScope(%cond: i1) -> () { test.graph_region { - func.func @test() -> i1 { - ^bb1: - // expected-error @+1 {{operand #0 does not dominate this use}} - %2:3 = "bar"(%1) : (i64) -> (i1,i1,i1) - cf.br ^bb4 - ^bb2: - cf.br ^bb2 - ^bb4: - %1 = "foo"() : ()->i64 // expected-note {{operand defined here}} - return %2#1 : i1 + scf.if %cond { + "test.ssacfg_region"() ({ + ^bb1: + // expected-error @+1 {{operand #0 does not dominate this use}} + %2:3 = "bar"(%1) : (i64) -> (i1,i1,i1) + cf.br ^bb4 + ^bb2: + cf.br ^bb2 + ^bb4: + %1 = "foo"() : ()->i64 // expected-note {{operand defined here}} + }) : () -> () } - "terminator"() : () -> () + "terminator"() : () -> () } return } diff --git a/mlir/test/Transforms/canonicalize-dce.mlir b/mlir/test/Transforms/canonicalize-dce.mlir index 46545d2e9fd510..3048a7fed636b5 100644 --- a/mlir/test/Transforms/canonicalize-dce.mlir +++ b/mlir/test/Transforms/canonicalize-dce.mlir @@ -77,15 +77,15 @@ func.func @f(%arg0: f32, %pred: i1) { // Test case: Recursively DCE into enclosed regions. -// CHECK: func @f(%arg0: f32) -// CHECK-NEXT: func @g(%arg1: f32) -// CHECK-NEXT: return +// CHECK: func.func @f(%arg0: f32) +// CHECK-NOT: arith.addf func.func @f(%arg0: f32) { - func.func @g(%arg1: f32) { - %0 = "arith.addf"(%arg1, %arg1) : (f32, f32) -> f32 - return - } + "test.region"() ( + { + %0 = "arith.addf"(%arg0, %arg0) : (f32, f32) -> f32 + } + ) : () -> () return } diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir index 9b578e6c2631a7..2cf86b50d432f6 100644 --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -424,16 +424,15 @@ func.func @write_only_alloca_fold(%v: f32) { // CHECK-LABEL: func @dead_block_elim func.func @dead_block_elim() { // CHECK-NOT: ^bb - func.func @nested() { - return + builtin.module { + func.func @nested() { + return - ^bb1: - return + ^bb1: + return + } } return - -^bb1: - return } // CHECK-LABEL: func @dyn_shape_fold(%arg0: index, %arg1: index) diff --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir index 45ee03fa31d25f..253163f2af9110 100644 --- a/mlir/test/Transforms/constant-fold.mlir +++ b/mlir/test/Transforms/constant-fold.mlir @@ -756,12 +756,15 @@ func.func @cmpf_inf() -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1 // CHECK-LABEL: func @nested_isolated_region func.func @nested_isolated_region() { + // CHECK-NEXT: builtin.module { // CHECK-NEXT: func @isolated_op // CHECK-NEXT: arith.constant 2 - func.func @isolated_op() { - %0 = arith.constant 1 : i32 - %2 = arith.addi %0, %0 : i32 - "foo.yield"(%2) : (i32) -> () + builtin.module { + func.func @isolated_op() { + %0 = arith.constant 1 : i32 + %2 = arith.addi %0, %0 : i32 + "foo.yield"(%2) : (i32) -> () + } } // CHECK: "foo.unknown_region" diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir index c764d2b9bd57d8..11a33102684733 100644 --- a/mlir/test/Transforms/cse.mlir +++ b/mlir/test/Transforms/cse.mlir @@ -228,11 +228,14 @@ func.func @nested_isolated() -> i32 { // CHECK-NEXT: arith.constant 1 %0 = arith.constant 1 : i32 + // CHECK-NEXT: builtin.module // CHECK-NEXT: @nested_func - func.func @nested_func() { - // CHECK-NEXT: arith.constant 1 - %foo = arith.constant 1 : i32 - "foo.yield"(%foo) : (i32) -> () + builtin.module { + func.func @nested_func() { + // CHECK-NEXT: arith.constant 1 + %foo = arith.constant 1 : i32 + "foo.yield"(%foo) : (i32) -> () + } } // CHECK: "foo.region" diff --git a/mlir/test/Transforms/test-legalizer-full.mlir b/mlir/test/Transforms/test-legalizer-full.mlir index 74f312e8144a02..5f1148cac65012 100644 --- a/mlir/test/Transforms/test-legalizer-full.mlir +++ b/mlir/test/Transforms/test-legalizer-full.mlir @@ -37,9 +37,11 @@ func.func @recursively_legal_invalid_op() { } /// Operation that is dynamically legal, i.e. the function has a pattern /// applied to legalize the argument type before it becomes recursively legal. - func.func @dynamic_func(%arg: i64) attributes {test.recursively_legal} { - %ignored = "test.illegal_op_f"() : () -> (i32) - "test.return"() : () -> () + builtin.module { + func.func @dynamic_func(%arg: i64) attributes {test.recursively_legal} { + %ignored = "test.illegal_op_f"() : () -> (i32) + "test.return"() : () -> () + } } "test.return"() : () -> () diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py index acbf463113a6d5..28ef0f2ef3e25c 100644 --- a/mlir/test/python/ir/value.py +++ b/mlir/test/python/ir/value.py @@ -167,28 +167,15 @@ def testValuePrintAsOperand(): print(value2) topFn = func.FuncOp("test", ([i32, i32], [])) - entry_block1 = Block.create_at_start(topFn.operation.regions[0], [i32, i32]) + entry_block = Block.create_at_start(topFn.operation.regions[0], [i32, i32]) - with InsertionPoint(entry_block1): + with InsertionPoint(entry_block): value3 = Operation.create("custom.op3", results=[i32]).results[0] # CHECK: Value(%[[VAL3:.*]] = "custom.op3"() : () -> i32) print(value3) value4 = Operation.create("custom.op4", results=[i32]).results[0] # CHECK: Value(%[[VAL4:.*]] = "custom.op4"() : () -> i32) print(value4) - - f = func.FuncOp("test", ([i32, i32], [])) - entry_block2 = Block.create_at_start(f.operation.regions[0], [i32, i32]) - with InsertionPoint(entry_block2): - value5 = Operation.create("custom.op5", results=[i32]).results[0] - # CHECK: Value(%[[VAL5:.*]] = "custom.op5"() : () -> i32) - print(value5) - value6 = Operation.create("custom.op6", results=[i32]).results[0] - # CHECK: Value(%[[VAL6:.*]] = "custom.op6"() : () -> i32) - print(value6) - - func.ReturnOp([]) - func.ReturnOp([]) # CHECK: %[[VAL1]] @@ -215,20 +202,10 @@ def testValuePrintAsOperand(): # CHECK: %1 print(value4.get_name(use_local_scope=True)) - # CHECK: %[[VAL5]] - print(value5.get_name()) - # CHECK: %[[VAL6]] - print(value6.get_name()) - # CHECK: %[[ARG0:.*]] - print(entry_block1.arguments[0].get_name()) + print(entry_block.arguments[0].get_name()) # CHECK: %[[ARG1:.*]] - print(entry_block1.arguments[1].get_name()) - - # CHECK: %[[ARG2:.*]] - print(entry_block2.arguments[0].get_name()) - # CHECK: %[[ARG3:.*]] - print(entry_block2.arguments[1].get_name()) + print(entry_block.arguments[1].get_name()) # CHECK: module { # CHECK: %[[VAL1]] = "custom.op1"() : () -> i32 @@ -236,11 +213,6 @@ def testValuePrintAsOperand(): # CHECK: func.func @test(%[[ARG0]]: i32, %[[ARG1]]: i32) { # CHECK: %[[VAL3]] = "custom.op3"() : () -> i32 # CHECK: %[[VAL4]] = "custom.op4"() : () -> i32 - # CHECK: func @test(%[[ARG2]]: i32, %[[ARG3]]: i32) { - # CHECK: %[[VAL5]] = "custom.op5"() : () -> i32 - # CHECK: %[[VAL6]] = "custom.op6"() : () -> i32 - # CHECK: return - # CHECK: } # CHECK: return # CHECK: } # CHECK: }