| //===- mlir-rewrite.cpp - MLIR Rewrite Driver -----------------------------===// | 
 | // | 
 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | 
 | // See https://llvm.org/LICENSE.txt for license information. | 
 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | 
 | // | 
 | //===----------------------------------------------------------------------===// | 
 | // | 
 | // Main entry function for mlir-rewrite. | 
 | // | 
 | //===----------------------------------------------------------------------===// | 
 |  | 
 | #include "mlir/AsmParser/AsmParser.h" | 
 | #include "mlir/AsmParser/AsmParserState.h" | 
 | #include "mlir/IR/AsmState.h" | 
 | #include "mlir/IR/Dialect.h" | 
 | #include "mlir/IR/MLIRContext.h" | 
 | #include "mlir/InitAllDialects.h" | 
 | #include "mlir/Pass/Pass.h" | 
 | #include "mlir/Pass/PassManager.h" | 
 | #include "mlir/Support/FileUtilities.h" | 
 | #include "mlir/Tools/ParseUtilities.h" | 
 | #include "llvm/ADT/RewriteBuffer.h" | 
 | #include "llvm/Support/CommandLine.h" | 
 | #include "llvm/Support/InitLLVM.h" | 
 | #include "llvm/Support/LineIterator.h" | 
 | #include "llvm/Support/Regex.h" | 
 | #include "llvm/Support/SourceMgr.h" | 
 | #include "llvm/Support/ToolOutputFile.h" | 
 |  | 
 | using namespace mlir; | 
 |  | 
 | namespace mlir { | 
 | using OperationDefinition = AsmParserState::OperationDefinition; | 
 |  | 
 | /// Return the source code associated with the OperationDefinition. | 
 | SMRange getOpRange(const OperationDefinition &op) { | 
 |   const char *startOp = op.scopeLoc.Start.getPointer(); | 
 |   const char *endOp = op.scopeLoc.End.getPointer(); | 
 |  | 
 |   for (const auto &res : op.resultGroups) { | 
 |     SMRange range = res.definition.loc; | 
 |     startOp = std::min(startOp, range.Start.getPointer()); | 
 |   } | 
 |   return {SMLoc::getFromPointer(startOp), SMLoc::getFromPointer(endOp)}; | 
 | } | 
 |  | 
 | /// Helper to simplify rewriting the source file. | 
 | class RewritePad { | 
 | public: | 
 |   static std::unique_ptr<RewritePad> init(StringRef inputFilename, | 
 |                                           StringRef outputFilename); | 
 |  | 
 |   /// Return the context the file was parsed into. | 
 |   MLIRContext *getContext() { return &context; } | 
 |  | 
 |   /// Return the OperationDefinition's of the operation's parsed. | 
 |   iterator_range<AsmParserState::OperationDefIterator> getOpDefs() { | 
 |     return asmState.getOpDefs(); | 
 |   } | 
 |  | 
 |   /// Insert the specified string at the specified location in the original | 
 |   /// buffer. | 
 |   void insertText(SMLoc pos, StringRef str, bool insertAfter = true) { | 
 |     rewriteBuffer.InsertText(pos.getPointer() - start, str, insertAfter); | 
 |   } | 
 |  | 
 |   /// Replace the range of the source text with the corresponding string in the | 
 |   /// output. | 
 |   void replaceRange(SMRange range, StringRef str) { | 
 |     rewriteBuffer.ReplaceText(range.Start.getPointer() - start, | 
 |                               range.End.getPointer() - range.Start.getPointer(), | 
 |                               str); | 
 |   } | 
 |  | 
 |   /// Replace the range of the operation in the source text with the | 
 |   /// corresponding string in the output. | 
 |   void replaceDef(const OperationDefinition &opDef, StringRef newDef) { | 
 |     replaceRange(getOpRange(opDef), newDef); | 
 |   } | 
 |  | 
 |   /// Return the source string corresponding to the source range. | 
 |   StringRef getSourceString(SMRange range) { | 
 |     return StringRef(range.Start.getPointer(), | 
 |                      range.End.getPointer() - range.Start.getPointer()); | 
 |   } | 
 |  | 
 |   /// Return the source string corresponding to operation definition. | 
 |   StringRef getSourceString(const OperationDefinition &opDef) { | 
 |     auto range = getOpRange(opDef); | 
 |     return getSourceString(range); | 
 |   } | 
 |  | 
 |   /// Write to stream the result of applying all changes to the | 
 |   /// original buffer. | 
 |   /// Note that it isn't safe to use this function to overwrite memory mapped | 
 |   /// files in-place (PR17960). | 
 |   /// | 
 |   /// The original buffer is not actually changed. | 
 |   raw_ostream &write(raw_ostream &stream) const { | 
 |     return rewriteBuffer.write(stream); | 
 |   } | 
 |  | 
 |   /// Return lines that are purely comments. | 
 |   SmallVector<SMRange> getSingleLineComments() { | 
 |     unsigned curBuf = sourceMgr.getMainFileID(); | 
 |     const llvm::MemoryBuffer *curMB = sourceMgr.getMemoryBuffer(curBuf); | 
 |     llvm::line_iterator lineIterator(*curMB); | 
 |     SmallVector<SMRange> ret; | 
 |     for (; !lineIterator.is_at_end(); ++lineIterator) { | 
 |       StringRef trimmed = lineIterator->ltrim(); | 
 |       if (trimmed.starts_with("//")) { | 
 |         ret.emplace_back( | 
 |             SMLoc::getFromPointer(trimmed.data()), | 
 |             SMLoc::getFromPointer(trimmed.data() + trimmed.size())); | 
 |       } | 
 |     } | 
 |     return ret; | 
 |   } | 
 |  | 
 |   /// Return the IR from parsed file. | 
 |   Block *getParsed() { return &parsedIR; } | 
 |  | 
 |   /// Return the definition for the given operation, or nullptr if the given | 
 |   /// operation does not have a definition. | 
 |   const OperationDefinition &getOpDef(Operation *op) const { | 
 |     return *asmState.getOpDef(op); | 
 |   } | 
 |  | 
 | private: | 
 |   // The context and state required to parse. | 
 |   MLIRContext context; | 
 |   llvm::SourceMgr sourceMgr; | 
 |   DialectRegistry registry; | 
 |   FallbackAsmResourceMap fallbackResourceMap; | 
 |  | 
 |   // Storage of textual parsing results. | 
 |   AsmParserState asmState; | 
 |  | 
 |   // Parsed IR. | 
 |   Block parsedIR; | 
 |  | 
 |   // The RewriteBuffer  is doing most of the real work. | 
 |   llvm::RewriteBuffer rewriteBuffer; | 
 |  | 
 |   // Start of the original input, used to compute offset. | 
 |   const char *start; | 
 | }; | 
 |  | 
 | std::unique_ptr<RewritePad> RewritePad::init(StringRef inputFilename, | 
 |                                              StringRef outputFilename) { | 
 |   std::unique_ptr<RewritePad> r = std::make_unique<RewritePad>(); | 
 |  | 
 |   // Register all the dialects needed. | 
 |   registerAllDialects(r->registry); | 
 |  | 
 |   // Set up the input file. | 
 |   std::string errorMessage; | 
 |   std::unique_ptr<llvm::MemoryBuffer> file = | 
 |       openInputFile(inputFilename, &errorMessage); | 
 |   if (!file) { | 
 |     llvm::errs() << errorMessage << "\n"; | 
 |     return nullptr; | 
 |   } | 
 |   r->sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc()); | 
 |  | 
 |   // Set up the MLIR context and error handling. | 
 |   r->context.appendDialectRegistry(r->registry); | 
 |  | 
 |   // Record the start of the buffer to compute offsets with. | 
 |   unsigned curBuf = r->sourceMgr.getMainFileID(); | 
 |   const llvm::MemoryBuffer *curMB = r->sourceMgr.getMemoryBuffer(curBuf); | 
 |   r->start = curMB->getBufferStart(); | 
 |   r->rewriteBuffer.Initialize(curMB->getBuffer()); | 
 |  | 
 |   // Parse and populate the AsmParserState. | 
 |   ParserConfig parseConfig(&r->context, /*verifyAfterParse=*/true, | 
 |                            &r->fallbackResourceMap); | 
 |   // Always allow unregistered. | 
 |   r->context.allowUnregisteredDialects(true); | 
 |   if (failed(parseAsmSourceFile(r->sourceMgr, &r->parsedIR, parseConfig, | 
 |                                 &r->asmState))) | 
 |     return nullptr; | 
 |  | 
 |   return r; | 
 | } | 
 |  | 
 | /// Return the source code associated with the operation name. | 
 | SMRange getOpNameRange(const OperationDefinition &op) { return op.loc; } | 
 |  | 
 | /// Return whether the operation was printed using generic syntax in original | 
 | /// buffer. | 
 | bool isGeneric(const OperationDefinition &op) { | 
 |   return op.loc.Start.getPointer()[0] == '"'; | 
 | } | 
 |  | 
 | inline int asMainReturnCode(LogicalResult r) { | 
 |   return r.succeeded() ? EXIT_SUCCESS : EXIT_FAILURE; | 
 | } | 
 |  | 
 | /// Reriter function to invoke. | 
 | using RewriterFunction = std::function<mlir::LogicalResult( | 
 |     mlir::RewritePad &rewriteState, llvm::raw_ostream &os)>; | 
 |  | 
 | /// Structure to group information about a rewriter (argument to invoke via | 
 | /// mlir-tblgen, description, and rewriter function). | 
 | class RewriterInfo { | 
 | public: | 
 |   /// RewriterInfo constructor should not be invoked directly, instead use | 
 |   /// RewriterRegistration or registerRewriter. | 
 |   RewriterInfo(StringRef arg, StringRef description, RewriterFunction rewriter) | 
 |       : arg(arg), description(description), rewriter(std::move(rewriter)) {} | 
 |  | 
 |   /// Invokes the rewriter and returns whether the rewriter failed. | 
 |   LogicalResult invoke(mlir::RewritePad &rewriteState, raw_ostream &os) const { | 
 |     assert(rewriter && "Cannot call rewriter with null rewriter"); | 
 |     return rewriter(rewriteState, os); | 
 |   } | 
 |  | 
 |   /// Returns the command line option that may be passed to 'mlir-rewrite' to | 
 |   /// invoke this rewriter. | 
 |   StringRef getRewriterArgument() const { return arg; } | 
 |  | 
 |   /// Returns a description for the rewriter. | 
 |   StringRef getRewriterDescription() const { return description; } | 
 |  | 
 | private: | 
 |   // The argument with which to invoke the rewriter via mlir-tblgen. | 
 |   StringRef arg; | 
 |  | 
 |   // Description of the rewriter. | 
 |   StringRef description; | 
 |  | 
 |   // Rewritererator function. | 
 |   RewriterFunction rewriter; | 
 | }; | 
 |  | 
 | static llvm::ManagedStatic<std::vector<RewriterInfo>> rewriterRegistry; | 
 |  | 
 | /// Adds command line option for each registered rewriter. | 
 | struct RewriterNameParser : public llvm::cl::parser<const RewriterInfo *> { | 
 |   RewriterNameParser(llvm::cl::Option &opt); | 
 |  | 
 |   void printOptionInfo(const llvm::cl::Option &o, | 
 |                        size_t globalWidth) const override; | 
 | }; | 
 |  | 
 | /// RewriterRegistration provides a global initializer that registers a rewriter | 
 | /// function. | 
 | struct RewriterRegistration { | 
 |   RewriterRegistration(StringRef arg, StringRef description, | 
 |                        const RewriterFunction &function); | 
 | }; | 
 |  | 
 | RewriterRegistration::RewriterRegistration(StringRef arg, StringRef description, | 
 |                                            const RewriterFunction &function) { | 
 |   rewriterRegistry->emplace_back(arg, description, function); | 
 | } | 
 |  | 
 | RewriterNameParser::RewriterNameParser(llvm::cl::Option &opt) | 
 |     : llvm::cl::parser<const RewriterInfo *>(opt) { | 
 |   for (const auto &kv : *rewriterRegistry) { | 
 |     addLiteralOption(kv.getRewriterArgument(), &kv, | 
 |                      kv.getRewriterDescription()); | 
 |   } | 
 | } | 
 |  | 
 | void RewriterNameParser::printOptionInfo(const llvm::cl::Option &o, | 
 |                                          size_t globalWidth) const { | 
 |   RewriterNameParser *tp = const_cast<RewriterNameParser *>(this); | 
 |   llvm::array_pod_sort(tp->Values.begin(), tp->Values.end(), | 
 |                        [](const RewriterNameParser::OptionInfo *vT1, | 
 |                           const RewriterNameParser::OptionInfo *vT2) { | 
 |                          return vT1->Name.compare(vT2->Name); | 
 |                        }); | 
 |   using llvm::cl::parser; | 
 |   parser<const RewriterInfo *>::printOptionInfo(o, globalWidth); | 
 | } | 
 |  | 
 | } // namespace mlir | 
 |  | 
 | // TODO: Make these injectable too in non-global way. | 
 | static llvm::cl::OptionCategory clSimpleRenameCategory{"simple-rename options"}; | 
 | static llvm::cl::opt<std::string> simpleRenameOpName{ | 
 |     "simple-rename-op-name", llvm::cl::desc("Name of op to match on"), | 
 |     llvm::cl::cat(clSimpleRenameCategory)}; | 
 | static llvm::cl::opt<std::string> simpleRenameMatch{ | 
 |     "simple-rename-match", llvm::cl::desc("Match string for rename"), | 
 |     llvm::cl::cat(clSimpleRenameCategory)}; | 
 | static llvm::cl::opt<std::string> simpleRenameReplace{ | 
 |     "simple-rename-replace", llvm::cl::desc("Replace string for rename"), | 
 |     llvm::cl::cat(clSimpleRenameCategory)}; | 
 |  | 
 | // Rewriter that does simple renames. | 
 | LogicalResult simpleRename(RewritePad &rewriteState, raw_ostream &os) { | 
 |   StringRef opName = simpleRenameOpName; | 
 |   StringRef match = simpleRenameMatch; | 
 |   StringRef replace = simpleRenameReplace; | 
 |   llvm::Regex regex(match); | 
 |  | 
 |   rewriteState.getParsed()->walk([&](Operation *op) { | 
 |     if (op->getName().getStringRef() != opName) | 
 |       return; | 
 |  | 
 |     const OperationDefinition &opDef = rewriteState.getOpDef(op); | 
 |     SMRange range = getOpRange(opDef); | 
 |     // This is a little bit overkill for simple. | 
 |     std::string str = regex.sub(replace, rewriteState.getSourceString(range)); | 
 |     rewriteState.replaceRange(range, str); | 
 |   }); | 
 |   return success(); | 
 | } | 
 |  | 
 | static mlir::RewriterRegistration rewriteSimpleRename("simple-rename", | 
 |                                                       "Perform a simple rename", | 
 |                                                       simpleRename); | 
 |  | 
 | // Rewriter that insert range markers. | 
 | LogicalResult markRanges(RewritePad &rewriteState, raw_ostream &os) { | 
 |   for (const auto &it : rewriteState.getOpDefs()) { | 
 |     auto [startOp, endOp] = getOpRange(it); | 
 |  | 
 |     rewriteState.insertText(startOp, "<"); | 
 |     rewriteState.insertText(endOp, ">"); | 
 |  | 
 |     auto nameRange = getOpNameRange(it); | 
 |  | 
 |     if (isGeneric(it)) { | 
 |       rewriteState.insertText(nameRange.Start, "["); | 
 |       rewriteState.insertText(nameRange.End, "]"); | 
 |     } else { | 
 |       rewriteState.insertText(nameRange.Start, "!["); | 
 |       rewriteState.insertText(nameRange.End, "]!"); | 
 |     } | 
 |   } | 
 |  | 
 |   // Highlight all comment lines. | 
 |   // TODO: Could be replaced if this is kept in memory. | 
 |   for (auto commentLine : rewriteState.getSingleLineComments()) { | 
 |     rewriteState.insertText(commentLine.Start, "{"); | 
 |     rewriteState.insertText(commentLine.End, "}"); | 
 |   } | 
 |  | 
 |   return success(); | 
 | } | 
 |  | 
 | static mlir::RewriterRegistration | 
 |     rewriteMarkRanges("mark-ranges", "Indicate ranges parsed", markRanges); | 
 |  | 
 | int main(int argc, char **argv) { | 
 |   static llvm::cl::opt<std::string> inputFilename( | 
 |       llvm::cl::Positional, llvm::cl::desc("<input file>"), | 
 |       llvm::cl::init("-")); | 
 |  | 
 |   static llvm::cl::opt<std::string> outputFilename( | 
 |       "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"), | 
 |       llvm::cl::init("-")); | 
 |  | 
 |   llvm::cl::opt<const mlir::RewriterInfo *, false, mlir::RewriterNameParser> | 
 |       rewriter("", llvm::cl::desc("Rewriter to run")); | 
 |  | 
 |   std::string helpHeader = "mlir-rewrite"; | 
 |  | 
 |   llvm::cl::ParseCommandLineOptions(argc, argv, helpHeader); | 
 |  | 
 |   // If no rewriter has been selected, exit with error code. Could also just | 
 |   // return but its unlikely this was intentionally being used as `cp`. | 
 |   if (!rewriter) { | 
 |     llvm::errs() << "No rewriter selected!\n"; | 
 |     return mlir::asMainReturnCode(mlir::failure()); | 
 |   } | 
 |  | 
 |   // Set up rewrite buffer. | 
 |   auto rewriterOr = RewritePad::init(inputFilename, outputFilename); | 
 |   if (!rewriterOr) | 
 |     return mlir::asMainReturnCode(mlir::failure()); | 
 |  | 
 |   // Set up the output file. | 
 |   std::string errorMessage; | 
 |   auto output = openOutputFile(outputFilename, &errorMessage); | 
 |   if (!output) { | 
 |     llvm::errs() << errorMessage << "\n"; | 
 |     return mlir::asMainReturnCode(mlir::failure()); | 
 |   } | 
 |  | 
 |   LogicalResult result = rewriter->invoke(*rewriterOr, output->os()); | 
 |   if (succeeded(result)) { | 
 |     rewriterOr->write(output->os()); | 
 |     output->keep(); | 
 |   } | 
 |   return mlir::asMainReturnCode(result); | 
 | } |