| //===-- NVPTXInstPrinter.cpp - PTX assembly instruction printing ----------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // Print MCInst instructions to .ptx format. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "MCTargetDesc/NVPTXInstPrinter.h" |
| #include "NVPTX.h" |
| #include "NVPTXUtilities.h" |
| #include "llvm/ADT/StringRef.h" |
| #include "llvm/IR/NVVMIntrinsicUtils.h" |
| #include "llvm/MC/MCExpr.h" |
| #include "llvm/MC/MCInst.h" |
| #include "llvm/MC/MCInstrInfo.h" |
| #include "llvm/MC/MCSubtargetInfo.h" |
| #include "llvm/MC/MCSymbol.h" |
| #include "llvm/Support/ErrorHandling.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include <cctype> |
| using namespace llvm; |
| |
| #define DEBUG_TYPE "asm-printer" |
| |
| #include "NVPTXGenAsmWriter.inc" |
| |
| NVPTXInstPrinter::NVPTXInstPrinter(const MCAsmInfo &MAI, const MCInstrInfo &MII, |
| const MCRegisterInfo &MRI) |
| : MCInstPrinter(MAI, MII, MRI) {} |
| |
| void NVPTXInstPrinter::printRegName(raw_ostream &OS, MCRegister Reg) { |
| // Decode the virtual register |
| // Must be kept in sync with NVPTXAsmPrinter::encodeVirtualRegister |
| unsigned RCId = (Reg.id() >> 28); |
| switch (RCId) { |
| default: report_fatal_error("Bad virtual register encoding"); |
| case 0: |
| // This is actually a physical register, so defer to the autogenerated |
| // register printer |
| OS << getRegisterName(Reg); |
| return; |
| case 1: |
| OS << "%p"; |
| break; |
| case 2: |
| OS << "%rs"; |
| break; |
| case 3: |
| OS << "%r"; |
| break; |
| case 4: |
| OS << "%rd"; |
| break; |
| case 5: |
| OS << "%f"; |
| break; |
| case 6: |
| OS << "%fd"; |
| break; |
| case 7: |
| OS << "%rq"; |
| break; |
| } |
| |
| unsigned VReg = Reg.id() & 0x0FFFFFFF; |
| OS << VReg; |
| } |
| |
| void NVPTXInstPrinter::printInst(const MCInst *MI, uint64_t Address, |
| StringRef Annot, const MCSubtargetInfo &STI, |
| raw_ostream &OS) { |
| printInstruction(MI, Address, OS); |
| |
| // Next always print the annotation. |
| printAnnotation(OS, Annot); |
| } |
| |
| void NVPTXInstPrinter::printOperand(const MCInst *MI, unsigned OpNo, |
| raw_ostream &O) { |
| const MCOperand &Op = MI->getOperand(OpNo); |
| if (Op.isReg()) { |
| unsigned Reg = Op.getReg(); |
| printRegName(O, Reg); |
| } else if (Op.isImm()) { |
| markup(O, Markup::Immediate) << formatImm(Op.getImm()); |
| } else { |
| assert(Op.isExpr() && "Unknown operand kind in printOperand"); |
| Op.getExpr()->print(O, &MAI); |
| } |
| } |
| |
| void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O, |
| const char *M) { |
| const MCOperand &MO = MI->getOperand(OpNum); |
| int64_t Imm = MO.getImm(); |
| llvm::StringRef Modifier(M); |
| |
| if (Modifier == "ftz") { |
| // FTZ flag |
| if (Imm & NVPTX::PTXCvtMode::FTZ_FLAG) |
| O << ".ftz"; |
| return; |
| } else if (Modifier == "sat") { |
| // SAT flag |
| if (Imm & NVPTX::PTXCvtMode::SAT_FLAG) |
| O << ".sat"; |
| return; |
| } else if (Modifier == "relu") { |
| // RELU flag |
| if (Imm & NVPTX::PTXCvtMode::RELU_FLAG) |
| O << ".relu"; |
| return; |
| } else if (Modifier == "base") { |
| // Default operand |
| switch (Imm & NVPTX::PTXCvtMode::BASE_MASK) { |
| default: |
| return; |
| case NVPTX::PTXCvtMode::NONE: |
| return; |
| case NVPTX::PTXCvtMode::RNI: |
| O << ".rni"; |
| return; |
| case NVPTX::PTXCvtMode::RZI: |
| O << ".rzi"; |
| return; |
| case NVPTX::PTXCvtMode::RMI: |
| O << ".rmi"; |
| return; |
| case NVPTX::PTXCvtMode::RPI: |
| O << ".rpi"; |
| return; |
| case NVPTX::PTXCvtMode::RN: |
| O << ".rn"; |
| return; |
| case NVPTX::PTXCvtMode::RZ: |
| O << ".rz"; |
| return; |
| case NVPTX::PTXCvtMode::RM: |
| O << ".rm"; |
| return; |
| case NVPTX::PTXCvtMode::RP: |
| O << ".rp"; |
| return; |
| case NVPTX::PTXCvtMode::RNA: |
| O << ".rna"; |
| return; |
| } |
| } |
| llvm_unreachable("Invalid conversion modifier"); |
| } |
| |
| void NVPTXInstPrinter::printCmpMode(const MCInst *MI, int OpNum, raw_ostream &O, |
| const char *M) { |
| const MCOperand &MO = MI->getOperand(OpNum); |
| int64_t Imm = MO.getImm(); |
| llvm::StringRef Modifier(M); |
| |
| if (Modifier == "ftz") { |
| // FTZ flag |
| if (Imm & NVPTX::PTXCmpMode::FTZ_FLAG) |
| O << ".ftz"; |
| return; |
| } else if (Modifier == "base") { |
| switch (Imm & NVPTX::PTXCmpMode::BASE_MASK) { |
| default: |
| return; |
| case NVPTX::PTXCmpMode::EQ: |
| O << ".eq"; |
| return; |
| case NVPTX::PTXCmpMode::NE: |
| O << ".ne"; |
| return; |
| case NVPTX::PTXCmpMode::LT: |
| O << ".lt"; |
| return; |
| case NVPTX::PTXCmpMode::LE: |
| O << ".le"; |
| return; |
| case NVPTX::PTXCmpMode::GT: |
| O << ".gt"; |
| return; |
| case NVPTX::PTXCmpMode::GE: |
| O << ".ge"; |
| return; |
| case NVPTX::PTXCmpMode::LO: |
| O << ".lo"; |
| return; |
| case NVPTX::PTXCmpMode::LS: |
| O << ".ls"; |
| return; |
| case NVPTX::PTXCmpMode::HI: |
| O << ".hi"; |
| return; |
| case NVPTX::PTXCmpMode::HS: |
| O << ".hs"; |
| return; |
| case NVPTX::PTXCmpMode::EQU: |
| O << ".equ"; |
| return; |
| case NVPTX::PTXCmpMode::NEU: |
| O << ".neu"; |
| return; |
| case NVPTX::PTXCmpMode::LTU: |
| O << ".ltu"; |
| return; |
| case NVPTX::PTXCmpMode::LEU: |
| O << ".leu"; |
| return; |
| case NVPTX::PTXCmpMode::GTU: |
| O << ".gtu"; |
| return; |
| case NVPTX::PTXCmpMode::GEU: |
| O << ".geu"; |
| return; |
| case NVPTX::PTXCmpMode::NUM: |
| O << ".num"; |
| return; |
| case NVPTX::PTXCmpMode::NotANumber: |
| O << ".nan"; |
| return; |
| } |
| } |
| llvm_unreachable("Empty Modifier"); |
| } |
| |
| void NVPTXInstPrinter::printLdStCode(const MCInst *MI, int OpNum, |
| raw_ostream &O, const char *M) { |
| llvm::StringRef Modifier(M); |
| const MCOperand &MO = MI->getOperand(OpNum); |
| int Imm = (int)MO.getImm(); |
| if (Modifier == "sem") { |
| auto Ordering = NVPTX::Ordering(Imm); |
| switch (Ordering) { |
| case NVPTX::Ordering::NotAtomic: |
| return; |
| case NVPTX::Ordering::Relaxed: |
| O << ".relaxed"; |
| return; |
| case NVPTX::Ordering::Acquire: |
| O << ".acquire"; |
| return; |
| case NVPTX::Ordering::Release: |
| O << ".release"; |
| return; |
| case NVPTX::Ordering::Volatile: |
| O << ".volatile"; |
| return; |
| case NVPTX::Ordering::RelaxedMMIO: |
| O << ".mmio.relaxed"; |
| return; |
| default: |
| report_fatal_error(formatv( |
| "NVPTX LdStCode Printer does not support \"{}\" sem modifier. " |
| "Loads/Stores cannot be AcquireRelease or SequentiallyConsistent.", |
| OrderingToString(Ordering))); |
| } |
| } else if (Modifier == "scope") { |
| auto S = NVPTX::Scope(Imm); |
| switch (S) { |
| case NVPTX::Scope::Thread: |
| return; |
| case NVPTX::Scope::System: |
| O << ".sys"; |
| return; |
| case NVPTX::Scope::Block: |
| O << ".cta"; |
| return; |
| case NVPTX::Scope::Cluster: |
| O << ".cluster"; |
| return; |
| case NVPTX::Scope::Device: |
| O << ".gpu"; |
| return; |
| } |
| report_fatal_error( |
| formatv("NVPTX LdStCode Printer does not support \"{}\" sco modifier.", |
| ScopeToString(S))); |
| } else if (Modifier == "addsp") { |
| auto A = NVPTX::AddressSpace(Imm); |
| switch (A) { |
| case NVPTX::AddressSpace::Generic: |
| return; |
| case NVPTX::AddressSpace::Global: |
| case NVPTX::AddressSpace::Const: |
| case NVPTX::AddressSpace::Shared: |
| case NVPTX::AddressSpace::Param: |
| case NVPTX::AddressSpace::Local: |
| O << "." << A; |
| return; |
| } |
| report_fatal_error(formatv( |
| "NVPTX LdStCode Printer does not support \"{}\" addsp modifier.", |
| AddressSpaceToString(A))); |
| } else if (Modifier == "sign") { |
| switch (Imm) { |
| case NVPTX::PTXLdStInstCode::Signed: |
| O << "s"; |
| return; |
| case NVPTX::PTXLdStInstCode::Unsigned: |
| O << "u"; |
| return; |
| case NVPTX::PTXLdStInstCode::Untyped: |
| O << "b"; |
| return; |
| case NVPTX::PTXLdStInstCode::Float: |
| O << "f"; |
| return; |
| default: |
| llvm_unreachable("Unknown register type"); |
| } |
| } else if (Modifier == "vec") { |
| switch (Imm) { |
| case NVPTX::PTXLdStInstCode::V2: |
| O << ".v2"; |
| return; |
| case NVPTX::PTXLdStInstCode::V4: |
| O << ".v4"; |
| return; |
| } |
| // TODO: evaluate whether cases not covered by this switch are bugs |
| return; |
| } |
| llvm_unreachable(formatv("Unknown Modifier: {}", Modifier).str().c_str()); |
| } |
| |
| void NVPTXInstPrinter::printMmaCode(const MCInst *MI, int OpNum, raw_ostream &O, |
| const char *M) { |
| const MCOperand &MO = MI->getOperand(OpNum); |
| int Imm = (int)MO.getImm(); |
| llvm::StringRef Modifier(M); |
| if (Modifier.empty() || Modifier == "version") { |
| O << Imm; // Just print out PTX version |
| return; |
| } else if (Modifier == "aligned") { |
| // PTX63 requires '.aligned' in the name of the instruction. |
| if (Imm >= 63) |
| O << ".aligned"; |
| return; |
| } |
| llvm_unreachable("Unknown Modifier"); |
| } |
| |
| void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum, |
| raw_ostream &O, const char *M) { |
| printOperand(MI, OpNum, O); |
| llvm::StringRef Modifier(M); |
| |
| if (Modifier == "add") { |
| O << ", "; |
| printOperand(MI, OpNum + 1, O); |
| } else { |
| if (MI->getOperand(OpNum + 1).isImm() && |
| MI->getOperand(OpNum + 1).getImm() == 0) |
| return; // don't print ',0' or '+0' |
| O << "+"; |
| printOperand(MI, OpNum + 1, O); |
| } |
| } |
| |
| void NVPTXInstPrinter::printOffseti32imm(const MCInst *MI, int OpNum, |
| raw_ostream &O, const char *Modifier) { |
| auto &Op = MI->getOperand(OpNum); |
| assert(Op.isImm() && "Invalid operand"); |
| if (Op.getImm() != 0) { |
| O << "+"; |
| printOperand(MI, OpNum, O); |
| } |
| } |
| |
| void NVPTXInstPrinter::printHexu32imm(const MCInst *MI, int OpNum, |
| raw_ostream &O, const char *Modifier) { |
| int64_t Imm = MI->getOperand(OpNum).getImm(); |
| O << formatHex(Imm) << "U"; |
| } |
| |
| void NVPTXInstPrinter::printProtoIdent(const MCInst *MI, int OpNum, |
| raw_ostream &O, const char *Modifier) { |
| const MCOperand &Op = MI->getOperand(OpNum); |
| assert(Op.isExpr() && "Call prototype is not an MCExpr?"); |
| const MCExpr *Expr = Op.getExpr(); |
| const MCSymbol &Sym = cast<MCSymbolRefExpr>(Expr)->getSymbol(); |
| O << Sym.getName(); |
| } |
| |
| void NVPTXInstPrinter::printPrmtMode(const MCInst *MI, int OpNum, |
| raw_ostream &O, const char *Modifier) { |
| const MCOperand &MO = MI->getOperand(OpNum); |
| int64_t Imm = MO.getImm(); |
| |
| switch (Imm) { |
| default: |
| return; |
| case NVPTX::PTXPrmtMode::NONE: |
| return; |
| case NVPTX::PTXPrmtMode::F4E: |
| O << ".f4e"; |
| return; |
| case NVPTX::PTXPrmtMode::B4E: |
| O << ".b4e"; |
| return; |
| case NVPTX::PTXPrmtMode::RC8: |
| O << ".rc8"; |
| return; |
| case NVPTX::PTXPrmtMode::ECL: |
| O << ".ecl"; |
| return; |
| case NVPTX::PTXPrmtMode::ECR: |
| O << ".ecr"; |
| return; |
| case NVPTX::PTXPrmtMode::RC16: |
| O << ".rc16"; |
| return; |
| } |
| } |
| |
| void NVPTXInstPrinter::printTmaReductionMode(const MCInst *MI, int OpNum, |
| raw_ostream &O, |
| const char *Modifier) { |
| const MCOperand &MO = MI->getOperand(OpNum); |
| using RedTy = llvm::nvvm::TMAReductionOp; |
| |
| switch (static_cast<RedTy>(MO.getImm())) { |
| case RedTy::ADD: |
| O << ".add"; |
| return; |
| case RedTy::MIN: |
| O << ".min"; |
| return; |
| case RedTy::MAX: |
| O << ".max"; |
| return; |
| case RedTy::INC: |
| O << ".inc"; |
| return; |
| case RedTy::DEC: |
| O << ".dec"; |
| return; |
| case RedTy::AND: |
| O << ".and"; |
| return; |
| case RedTy::OR: |
| O << ".or"; |
| return; |
| case RedTy::XOR: |
| O << ".xor"; |
| return; |
| } |
| llvm_unreachable( |
| "Invalid Reduction Op in printCpAsyncBulkTensorReductionMode"); |
| } |