| //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements the OpenMP dialect and its operations. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/OpenMP/OpenMPDialect.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" |
| #include "mlir/IR/Attributes.h" |
| #include "mlir/IR/OpImplementation.h" |
| #include "mlir/IR/OperationSupport.h" |
| |
| #include "llvm/ADT/SmallString.h" |
| #include "llvm/ADT/StringRef.h" |
| #include "llvm/ADT/StringSwitch.h" |
| #include <cstddef> |
| |
| #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc" |
| |
| using namespace mlir; |
| using namespace mlir::omp; |
| |
/// Constructs the OpenMP dialect under the dialect namespace and registers the
/// operations listed by the GET_OP_LIST section of OpenMPOps.cpp.inc.
| OpenMPDialect::OpenMPDialect(MLIRContext *context) |
| : Dialect(getDialectNamespace(), context) { |
| addOperations< |
| #define GET_OP_LIST |
| #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" |
| >(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ParallelOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Parse a list of operands with types. |
| /// |
| /// operand-and-type-list ::= `(` ssa-id-and-type-list `)` |
| /// ssa-id-and-type-list ::= ssa-id-and-type | |
| /// ssa-id-and-type ',' ssa-id-and-type-list |
| /// ssa-id-and-type ::= ssa-id `:` type |
| static ParseResult |
| parseOperandAndTypeList(OpAsmParser &parser, |
| SmallVectorImpl<OpAsmParser::OperandType> &operands, |
| SmallVectorImpl<Type> &types) { |
| if (parser.parseLParen()) |
| return failure(); |
| |
| do { |
| OpAsmParser::OperandType operand; |
| Type type; |
| if (parser.parseOperand(operand) || parser.parseColonType(type)) |
| return failure(); |
| operands.push_back(operand); |
| types.push_back(type); |
| } while (succeeded(parser.parseOptionalComma())); |
| |
| if (parser.parseRParen()) |
| return failure(); |
| |
| return success(); |
| } |
| |
| static void printParallelOp(OpAsmPrinter &p, ParallelOp op) { |
| p << "omp.parallel"; |
| |
| if (auto ifCond = op.if_expr_var()) |
| p << " if(" << ifCond << ")"; |
| |
| if (auto threads = op.num_threads_var()) |
| p << " num_threads(" << threads << " : " << threads.getType() << ")"; |
| |
| // Print private, firstprivate, shared and copyin parameters |
| auto printDataVars = [&p](StringRef name, OperandRange vars) { |
| if (vars.size()) { |
| p << " " << name << "("; |
| for (unsigned i = 0; i < vars.size(); ++i) { |
| std::string separator = i == vars.size() - 1 ? ")" : ", "; |
| p << vars[i] << " : " << vars[i].getType() << separator; |
| } |
| } |
| }; |
| printDataVars("private", op.private_vars()); |
| printDataVars("firstprivate", op.firstprivate_vars()); |
| printDataVars("shared", op.shared_vars()); |
| printDataVars("copyin", op.copyin_vars()); |
| |
| if (auto def = op.default_val()) |
| p << " default(" << def->drop_front(3) << ")"; |
| |
| if (auto bind = op.proc_bind_val()) |
| p << " proc_bind(" << bind << ")"; |
| |
| p.printRegion(op.getRegion()); |
| } |
| |
| /// Emit an error if the same clause is present more than once on an operation. |
| static ParseResult allowedOnce(OpAsmParser &parser, llvm::StringRef clause, |
| llvm::StringRef operation) { |
| return parser.emitError(parser.getNameLoc()) |
| << " at most one " << clause << " clause can appear on the " |
| << operation << " operation"; |
| } |
| |
| /// Parses a parallel operation. |
| /// |
| /// operation ::= `omp.parallel` clause-list |
| /// clause-list ::= clause | clause clause-list |
| /// clause ::= if | numThreads | private | firstprivate | shared | copyin | |
| /// default | procBind |
| /// if ::= `if` `(` ssa-id `)` |
| /// numThreads ::= `num_threads` `(` ssa-id-and-type `)` |
| /// private ::= `private` operand-and-type-list |
| /// firstprivate ::= `firstprivate` operand-and-type-list |
| /// shared ::= `shared` operand-and-type-list |
| /// copyin ::= `copyin` operand-and-type-list |
| /// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`) |
| /// procBind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)` |
| /// |
| /// Note that each clause can only appear once in the clase-list. |
| static ParseResult parseParallelOp(OpAsmParser &parser, |
| OperationState &result) { |
| OpAsmParser::OperandType ifCond; |
| std::pair<OpAsmParser::OperandType, Type> numThreads; |
| llvm::SmallVector<OpAsmParser::OperandType, 4> privates; |
| llvm::SmallVector<Type, 4> privateTypes; |
| llvm::SmallVector<OpAsmParser::OperandType, 4> firstprivates; |
| llvm::SmallVector<Type, 4> firstprivateTypes; |
| llvm::SmallVector<OpAsmParser::OperandType, 4> shareds; |
| llvm::SmallVector<Type, 4> sharedTypes; |
| llvm::SmallVector<OpAsmParser::OperandType, 4> copyins; |
| llvm::SmallVector<Type, 4> copyinTypes; |
| std::array<int, 6> segments{0, 0, 0, 0, 0, 0}; |
| llvm::StringRef keyword; |
| bool defaultVal = false; |
| bool procBind = false; |
| |
| const int ifClausePos = 0; |
| const int numThreadsClausePos = 1; |
| const int privateClausePos = 2; |
| const int firstprivateClausePos = 3; |
| const int sharedClausePos = 4; |
| const int copyinClausePos = 5; |
| const llvm::StringRef opName = result.name.getStringRef(); |
| |
| while (succeeded(parser.parseOptionalKeyword(&keyword))) { |
| if (keyword == "if") { |
| // Fail if there was already another if condition |
| if (segments[ifClausePos]) |
| return allowedOnce(parser, "if", opName); |
| if (parser.parseLParen() || parser.parseOperand(ifCond) || |
| parser.parseRParen()) |
| return failure(); |
| segments[ifClausePos] = 1; |
| } else if (keyword == "num_threads") { |
| // fail if there was already another num_threads clause |
| if (segments[numThreadsClausePos]) |
| return allowedOnce(parser, "num_threads", opName); |
| if (parser.parseLParen() || parser.parseOperand(numThreads.first) || |
| parser.parseColonType(numThreads.second) || parser.parseRParen()) |
| return failure(); |
| segments[numThreadsClausePos] = 1; |
| } else if (keyword == "private") { |
| // fail if there was already another private clause |
| if (segments[privateClausePos]) |
| return allowedOnce(parser, "private", opName); |
| if (parseOperandAndTypeList(parser, privates, privateTypes)) |
| return failure(); |
| segments[privateClausePos] = privates.size(); |
| } else if (keyword == "firstprivate") { |
| // fail if there was already another firstprivate clause |
| if (segments[firstprivateClausePos]) |
| return allowedOnce(parser, "firstprivate", opName); |
| if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes)) |
| return failure(); |
| segments[firstprivateClausePos] = firstprivates.size(); |
| } else if (keyword == "shared") { |
| // fail if there was already another shared clause |
| if (segments[sharedClausePos]) |
| return allowedOnce(parser, "shared", opName); |
| if (parseOperandAndTypeList(parser, shareds, sharedTypes)) |
| return failure(); |
| segments[sharedClausePos] = shareds.size(); |
| } else if (keyword == "copyin") { |
| // fail if there was already another copyin clause |
| if (segments[copyinClausePos]) |
| return allowedOnce(parser, "copyin", opName); |
| if (parseOperandAndTypeList(parser, copyins, copyinTypes)) |
| return failure(); |
| segments[copyinClausePos] = copyins.size(); |
| } else if (keyword == "default") { |
| // fail if there was already another default clause |
| if (defaultVal) |
| return allowedOnce(parser, "default", opName); |
| defaultVal = true; |
| llvm::StringRef defval; |
| if (parser.parseLParen() || parser.parseKeyword(&defval) || |
| parser.parseRParen()) |
| return failure(); |
| llvm::SmallString<16> attrval; |
| // The def prefix is required for the attribute as "private" is a keyword |
| // in C++ |
| attrval += "def"; |
| attrval += defval; |
| auto attr = parser.getBuilder().getStringAttr(attrval); |
| result.addAttribute("default_val", attr); |
| } else if (keyword == "proc_bind") { |
| // fail if there was already another default clause |
| if (procBind) |
| return allowedOnce(parser, "proc_bind", opName); |
| procBind = true; |
| llvm::StringRef bind; |
| if (parser.parseLParen() || parser.parseKeyword(&bind) || |
| parser.parseRParen()) |
| return failure(); |
| auto attr = parser.getBuilder().getStringAttr(bind); |
| result.addAttribute("proc_bind_val", attr); |
| } else { |
| return parser.emitError(parser.getNameLoc()) |
| << keyword << " is not a valid clause for the " << opName |
| << " operation"; |
| } |
| } |
| |
| // Add if parameter |
| if (segments[ifClausePos]) { |
| parser.resolveOperand(ifCond, parser.getBuilder().getI1Type(), |
| result.operands); |
| } |
| |
| // Add num_threads parameter |
| if (segments[numThreadsClausePos]) { |
| parser.resolveOperand(numThreads.first, numThreads.second, result.operands); |
| } |
| |
| // Add private parameters |
| if (segments[privateClausePos]) { |
| parser.resolveOperands(privates, privateTypes, privates[0].location, |
| result.operands); |
| } |
| |
| // Add firstprivate parameters |
| if (segments[firstprivateClausePos]) { |
| parser.resolveOperands(firstprivates, firstprivateTypes, |
| firstprivates[0].location, result.operands); |
| } |
| |
| // Add shared parameters |
| if (segments[sharedClausePos]) { |
| parser.resolveOperands(shareds, sharedTypes, shareds[0].location, |
| result.operands); |
| } |
| |
| // Add copyin parameters |
| if (segments[copyinClausePos]) { |
| parser.resolveOperands(copyins, copyinTypes, copyins[0].location, |
| result.operands); |
| } |
| |
| result.addAttribute("operand_segment_sizes", |
| parser.getBuilder().getI32VectorAttr(segments)); |
| |
| Region *body = result.addRegion(); |
| llvm::SmallVector<OpAsmParser::OperandType, 4> regionArgs; |
| llvm::SmallVector<Type, 4> regionArgTypes; |
| if (parser.parseRegion(*body, regionArgs, regionArgTypes)) |
| return failure(); |
| return success(); |
| } |
| |
// Include the GET_OP_CLASSES section of OpenMPOps.cpp.inc inside the
// mlir::omp namespace.
| namespace mlir { |
| namespace omp { |
| #define GET_OP_CLASSES |
| #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" |
| } // namespace omp |
| } // namespace mlir |