| # RUN: %PYTHON %s | FileCheck %s |
| |
| from mlir.ir import * |
| from mlir.dialects.pdl import * |
| |
| |
def constructAndPrintInModule(f):
    """Decorator: build the IR emitted by ``f`` into a fresh module and print it.

    ``f`` is called immediately, at decoration time. The call happens inside a
    new ``Context`` with ``Location.unknown()`` as the default location, and
    with the insertion point set to the body of a newly created ``Module``.

    Before ``f`` runs, a ``TEST: <function name>`` header line is printed.
    After ``f`` returns, the textual form of the module is printed.

    Returns:
        ``f`` itself, unchanged, so the decorated name still refers to it.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    with Context(), Location.unknown():
        mod = Module.create()
        with InsertionPoint(mod.body):
            f()
        # Dump the module once the body insertion point has been exited.
        print(mod)
    return f
| |
| |
| # CHECK: module { |
| # CHECK: pdl.pattern @operations : benefit(1) { |
| # CHECK: %0 = attribute |
| # CHECK: %1 = type |
| # CHECK: %2 = operation {"attr" = %0} -> (%1 : !pdl.type) |
| # CHECK: %3 = result 0 of %2 |
| # CHECK: %4 = operand |
| # CHECK: %5 = operation(%3, %4 : !pdl.value, !pdl.value) |
| # CHECK: rewrite %5 with "rewriter" |
| # CHECK: } |
| # CHECK: } |
@constructAndPrintInModule
def test_operations():
    """Match a two-operation chain and hand the root to a named rewriter.

    The first matched operation carries an attribute under the key ``"attr"``
    and has one result type. Its result 0, together with an unconstrained
    operand, feeds the root operation. The root is rewritten through the
    rewriter named ``"rewriter"``. The pattern has benefit 1.
    """
    pattern = PatternOp(1, "operations")
    with InsertionPoint(pattern.body):
        attr = AttributeOp()
        ty = TypeOp()
        op0 = OperationOp(attributes={"attr": attr}, types=[ty])
        op0_result = ResultOp(op0, 0)
        # Named `operand` rather than `input` to avoid shadowing the builtin.
        operand = OperandOp()
        root = OperationOp(args=[op0_result, operand])
        RewriteOp(root, "rewriter")
| |
| |
| # CHECK: module { |
| # CHECK: pdl.pattern @rewrite_with_args : benefit(1) { |
| # CHECK: %0 = operand |
| # CHECK: %1 = operation(%0 : !pdl.value) |
| # CHECK: rewrite %1 with "rewriter"(%0 : !pdl.value) |
| # CHECK: } |
| # CHECK: } |
@constructAndPrintInModule
def test_rewrite_with_args():
    """Forward an extra matched value to a named rewriter.

    The root operation consumes a single matched operand. That same operand
    is also passed as an argument to ``"rewriter"``, alongside the root.
    """
    pattern = PatternOp(1, "rewrite_with_args")
    with InsertionPoint(pattern.body):
        # Named `operand` rather than `input` to avoid shadowing the builtin.
        operand = OperandOp()
        root = OperationOp(args=[operand])
        RewriteOp(root, "rewriter", args=[operand])
| |
| |
| # CHECK: module { |
| # CHECK: pdl.pattern @rewrite_multi_root_optimal : benefit(1) { |
| # CHECK: %0 = operand |
| # CHECK: %1 = operand |
| # CHECK: %2 = type |
| # CHECK: %3 = operation(%0 : !pdl.value) -> (%2 : !pdl.type) |
| # CHECK: %4 = result 0 of %3 |
| # CHECK: %5 = operation(%4 : !pdl.value) |
| # CHECK: %6 = operation(%1 : !pdl.value) -> (%2 : !pdl.type) |
| # CHECK: %7 = result 0 of %6 |
| # CHECK: %8 = operation(%4, %7 : !pdl.value, !pdl.value) |
| # CHECK: rewrite with "rewriter"(%5, %8 : !pdl.operation, !pdl.operation) |
| # CHECK: } |
| # CHECK: } |
@constructAndPrintInModule
def test_rewrite_multi_root_optimal():
    """Rewrite two roots without naming either as the pattern root.

    Two producers each take one operand and yield a value of a shared type.
    The first root consumes the first producer's result. The second root
    consumes both producers' results.

    ``RewriteOp`` is given no root. Both roots are passed only as arguments
    to ``"rewriter"``, leaving root selection unspecified.
    """
    pattern = PatternOp(1, "rewrite_multi_root_optimal")
    with InsertionPoint(pattern.body):
        lhs_in = OperandOp()
        rhs_in = OperandOp()
        shared_ty = TypeOp()
        # First producer, plus the root that uses only its result.
        lhs_producer = OperationOp(args=[lhs_in], types=[shared_ty])
        lhs_value = ResultOp(lhs_producer, 0)
        first_root = OperationOp(args=[lhs_value])
        # Second producer; the second root consumes both producer results.
        rhs_producer = OperationOp(args=[rhs_in], types=[shared_ty])
        rhs_value = ResultOp(rhs_producer, 0)
        second_root = OperationOp(args=[lhs_value, rhs_value])
        RewriteOp(name="rewriter", args=[first_root, second_root])
| |
| |
| # CHECK: module { |
| # CHECK: pdl.pattern @rewrite_multi_root_forced : benefit(1) { |
| # CHECK: %0 = operand |
| # CHECK: %1 = operand |
| # CHECK: %2 = type |
| # CHECK: %3 = operation(%0 : !pdl.value) -> (%2 : !pdl.type) |
| # CHECK: %4 = result 0 of %3 |
| # CHECK: %5 = operation(%4 : !pdl.value) |
| # CHECK: %6 = operation(%1 : !pdl.value) -> (%2 : !pdl.type) |
| # CHECK: %7 = result 0 of %6 |
| # CHECK: %8 = operation(%4, %7 : !pdl.value, !pdl.value) |
| # CHECK: rewrite %5 with "rewriter"(%8 : !pdl.operation) |
| # CHECK: } |
| # CHECK: } |
@constructAndPrintInModule
def test_rewrite_multi_root_forced():
    """Rewrite two roots, explicitly designating the first as the root.

    The IR shape matches the ``optimal`` variant. Here, however, the first
    root is passed as ``RewriteOp``'s root, and only the second root is
    forwarded as an argument to ``"rewriter"``.
    """
    pattern = PatternOp(1, "rewrite_multi_root_forced")
    with InsertionPoint(pattern.body):
        lhs_in = OperandOp()
        rhs_in = OperandOp()
        shared_ty = TypeOp()
        # First producer, plus the root that uses only its result.
        lhs_producer = OperationOp(args=[lhs_in], types=[shared_ty])
        lhs_value = ResultOp(lhs_producer, 0)
        first_root = OperationOp(args=[lhs_value])
        # Second producer; the second root consumes both producer results.
        rhs_producer = OperationOp(args=[rhs_in], types=[shared_ty])
        rhs_value = ResultOp(rhs_producer, 0)
        second_root = OperationOp(args=[lhs_value, rhs_value])
        RewriteOp(first_root, name="rewriter", args=[second_root])
| |
| |
| # CHECK: module { |
| # CHECK: pdl.pattern @rewrite_add_body : benefit(1) { |
| # CHECK: %0 = type : i32 |
| # CHECK: %1 = type |
| # CHECK: %2 = operation -> (%0, %1 : !pdl.type, !pdl.type) |
| # CHECK: rewrite %2 { |
| # CHECK: %3 = type |
| # CHECK: %4 = operation "foo.op" -> (%0, %3 : !pdl.type, !pdl.type) |
| # CHECK: replace %2 with %4 |
| # CHECK: } |
| # CHECK: } |
| # CHECK: } |
@constructAndPrintInModule
def test_rewrite_add_body():
    """Replace the matched root with a ``foo.op`` built inside a rewrite body.

    The root has two result types: ``i32`` and an unconstrained type. The
    replacement reuses the ``i32`` type together with a fresh unconstrained
    type created inside the rewrite region.
    """
    pattern = PatternOp(1, "rewrite_add_body")
    with InsertionPoint(pattern.body):
        i32_ty = TypeOp(IntegerType.get_signless(32))
        any_ty = TypeOp()
        root = OperationOp(types=[i32_ty, any_ty])
        rewrite = RewriteOp(root)
        rewrite_block = rewrite.add_body()
        with InsertionPoint(rewrite_block):
            fresh_ty = TypeOp()
            replacement = OperationOp(name="foo.op", types=[i32_ty, fresh_ty])
            ReplaceOp(root, with_op=replacement)
| |
| |
| # CHECK: module { |
| # CHECK: pdl.pattern @rewrite_type : benefit(1) { |
| # CHECK: %0 = type : i32 |
| # CHECK: %1 = type |
| # CHECK: %2 = operation -> (%0, %1 : !pdl.type, !pdl.type) |
| # CHECK: rewrite %2 { |
| # CHECK: %3 = operation "foo.op" -> (%0, %1 : !pdl.type, !pdl.type) |
| # CHECK: } |
| # CHECK: } |
| # CHECK: } |
@constructAndPrintInModule
def test_rewrite_type():
    """Reuse matched ``!pdl.type`` handles when creating an op in a rewrite.

    The root is matched with result types ``i32`` and an unconstrained type.
    The rewrite body creates a ``foo.op`` with those same two type handles.
    """
    pattern = PatternOp(1, "rewrite_type")
    with InsertionPoint(pattern.body):
        ty1 = TypeOp(IntegerType.get_signless(32))
        ty2 = TypeOp()
        root = OperationOp(types=[ty1, ty2])
        rewrite = RewriteOp(root)
        with InsertionPoint(rewrite.add_body()):
            # Built only for its effect on the IR; the handle is not needed,
            # so it is not bound to an (unused) local.
            OperationOp(name="foo.op", types=[ty1, ty2])
| |
| |
| # CHECK: module { |
| # CHECK: pdl.pattern @rewrite_types : benefit(1) { |
| # CHECK: %0 = types |
| # CHECK: %1 = operation -> (%0 : !pdl.range<type>) |
| # CHECK: rewrite %1 { |
| # CHECK: %2 = types : [i32, i64] |
| # CHECK: %3 = operation "foo.op" -> (%0, %2 : !pdl.range<type>, !pdl.range<type>) |
| # CHECK: } |
| # CHECK: } |
| # CHECK: } |
@constructAndPrintInModule
def test_rewrite_types():
    """Mix a matched type range with a constant type range in a rewrite.

    The root is matched with an unconstrained ``TypesOp`` range. The rewrite
    body creates a constant range ``[i32, i64]``. It then builds a
    ``foo.op`` whose result types are both ranges.
    """
    pattern = PatternOp(1, "rewrite_types")
    with InsertionPoint(pattern.body):
        types = TypesOp()
        root = OperationOp(types=[types])
        rewrite = RewriteOp(root)
        with InsertionPoint(rewrite.add_body()):
            otherTypes = TypesOp(
                [IntegerType.get_signless(32), IntegerType.get_signless(64)]
            )
            # Built only for its effect on the IR; no need to bind it.
            OperationOp(name="foo.op", types=[types, otherTypes])
| |
| |
| # CHECK: module { |
| # CHECK: pdl.pattern @rewrite_operands : benefit(1) { |
| # CHECK: %0 = types |
| # CHECK: %1 = operands : %0 |
| # CHECK: %2 = operation(%1 : !pdl.range<value>) |
| # CHECK: rewrite %2 { |
| # CHECK: %3 = operation "foo.op" -> (%0 : !pdl.range<type>) |
| # CHECK: } |
| # CHECK: } |
| # CHECK: } |
@constructAndPrintInModule
def test_rewrite_operands():
    """Match an operand range typed by a type range, then reuse the types.

    The root's operands come from an ``OperandsOp`` constrained by a
    ``TypesOp`` range. The rewrite body creates a ``foo.op`` whose result
    types are that same type range.
    """
    pattern = PatternOp(1, "rewrite_operands")
    with InsertionPoint(pattern.body):
        types = TypesOp()
        operands = OperandsOp(types)
        root = OperationOp(args=[operands])
        rewrite = RewriteOp(root)
        with InsertionPoint(rewrite.add_body()):
            # Built only for its effect on the IR; no need to bind it.
            OperationOp(name="foo.op", types=[types])
| |
| |
| # CHECK: module { |
| # CHECK: pdl.pattern @native_rewrite : benefit(1) { |
| # CHECK: %0 = operation |
| # CHECK: rewrite %0 { |
| # CHECK: apply_native_rewrite "NativeRewrite"(%0 : !pdl.operation) |
| # CHECK: } |
| # CHECK: } |
| # CHECK: } |
@constructAndPrintInModule
def test_native_rewrite():
    """Invoke a native rewrite named ``"NativeRewrite"`` on the matched root.

    The native rewrite is declared with an empty result list and receives
    the root operation as its only argument.
    """
    pattern = PatternOp(1, "native_rewrite")
    with InsertionPoint(pattern.body):
        matched = OperationOp()
        body = RewriteOp(matched).add_body()
        with InsertionPoint(body):
            ApplyNativeRewriteOp([], "NativeRewrite", args=[matched])
| |
| |
| # CHECK: module { |
| # CHECK: pdl.pattern @attribute_with_value : benefit(1) { |
| # CHECK: %0 = operation |
| # CHECK: rewrite %0 { |
| # CHECK: %1 = attribute = "value" |
| # CHECK: apply_native_rewrite "NativeRewrite"(%1 : !pdl.attribute) |
| # CHECK: } |
| # CHECK: } |
| # CHECK: } |
@constructAndPrintInModule
def test_attribute_with_value():
    """Create a constant-valued ``pdl.attribute`` inside a rewrite body.

    The attribute holds the parsed string attribute ``"value"``. It is
    passed to the native rewrite ``"NativeRewrite"``, which is declared
    with no results.
    """
    pattern = PatternOp(1, "attribute_with_value")
    with InsertionPoint(pattern.body):
        matched = OperationOp()
        rewriter = RewriteOp(matched)
        body = rewriter.add_body()
        with InsertionPoint(body):
            const_value = Attribute.parse('"value"')
            const_attr = AttributeOp(value=const_value)
            ApplyNativeRewriteOp([], "NativeRewrite", args=[const_attr])
| |
| |
| # CHECK: module { |
| # CHECK: pdl.pattern @erase : benefit(1) { |
| # CHECK: %0 = operation |
| # CHECK: rewrite %0 { |
| # CHECK: erase %0 |
| # CHECK: } |
| # CHECK: } |
| # CHECK: } |
@constructAndPrintInModule
def test_erase():
    """Erase the matched root from within a rewrite body."""
    pattern = PatternOp(1, "erase")
    with InsertionPoint(pattern.body):
        matched = OperationOp()
        body = RewriteOp(matched).add_body()
        with InsertionPoint(body):
            EraseOp(matched)
| |
| |
| # CHECK: module { |
| # CHECK: pdl.pattern @operation_results : benefit(1) { |
| # CHECK: %0 = types |
| # CHECK: %1 = operation -> (%0 : !pdl.range<type>) |
| # CHECK: %2 = results of %1 |
| # CHECK: %3 = operation(%2 : !pdl.range<value>) |
| # CHECK: rewrite %3 with "rewriter" |
| # CHECK: } |
| # CHECK: } |
@constructAndPrintInModule
def test_operation_results():
    """Feed all results of one matched op into the operands of the root.

    A ``ResultsOp`` of type ``!pdl.range<value>`` selects every result of
    the producer. That range becomes the root's operand list. The root is
    rewritten through ``"rewriter"``.
    """
    # Result type of ResultsOp: a range of PDL values.
    value_range_ty = RangeType.get(ValueType.get())
    pattern = PatternOp(1, "operation_results")
    with InsertionPoint(pattern.body):
        producer_types = TypesOp()
        producer = OperationOp(types=[producer_types])
        all_results = ResultsOp(value_range_ty, producer)
        root = OperationOp(args=[all_results])
        RewriteOp(root, name="rewriter")
| |
| |
| # CHECK: module { |
| # CHECK: pdl.pattern : benefit(1) { |
| # CHECK: %0 = type |
| # CHECK: apply_native_constraint "typeConstraint"(%0 : !pdl.type) |
| # CHECK: %1 = operation -> (%0 : !pdl.type) |
| # CHECK: rewrite %1 with "rewrite" |
| # CHECK: } |
| # CHECK: } |
@constructAndPrintInModule
def test_apply_native_constraint():
    """Guard a matched type with a native constraint before matching the root.

    The pattern is created without a symbol name, only a benefit of 1. The
    type handle is checked by ``"typeConstraint"`` and then used as the
    root's result type. The root is rewritten through ``"rewrite"``.
    """
    pattern = PatternOp(1)
    with InsertionPoint(pattern.body):
        result_ty = TypeOp()
        ApplyNativeConstraintOp("typeConstraint", args=[result_ty])
        constrained_root = OperationOp(types=[result_ty])
        RewriteOp(constrained_root, name="rewrite")