blob: d34f45fcac0087509faed8bbcda82746941c892b [file] [log] [blame]
//===-- NVPTXInstPrinter.cpp - PTX assembly instruction printing ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Print MCInst instructions to .ptx format.
//
//===----------------------------------------------------------------------===//
#include "MCTargetDesc/NVPTXInstPrinter.h"
#include "NVPTX.h"
#include "NVPTXUtilities.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/NVVMIntrinsicUtils.h"
#include "llvm/MC/MCExpr.h"
#include "llvm/MC/MCInst.h"
#include "llvm/MC/MCInstrInfo.h"
#include "llvm/MC/MCSubtargetInfo.h"
#include "llvm/MC/MCSymbol.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include <cctype>
using namespace llvm;
#define DEBUG_TYPE "asm-printer"
#include "NVPTXGenAsmWriter.inc"
/// Construct the PTX instruction printer. All state lives in the
/// MCInstPrinter base; this class adds no members of its own.
NVPTXInstPrinter::NVPTXInstPrinter(const MCAsmInfo &MAI,
                                   const MCInstrInfo &MII,
                                   const MCRegisterInfo &MRI)
    : MCInstPrinter(MAI, MII, MRI) {
  // Nothing else to initialize.
}
/// Print a register name in PTX syntax.
///
/// A register id whose top four bits are zero is a physical register and is
/// printed by the TableGen-generated name table. Otherwise the top four bits
/// select a register-class prefix and the low 28 bits are the register number
/// within that class (e.g. "%rd12").
/// Must be kept in sync with NVPTXAsmPrinter::encodeVirtualRegister.
void NVPTXInstPrinter::printRegName(raw_ostream &OS, MCRegister Reg) {
  // Indexed by the class id stored in bits [31:28]; slot 0 is never used
  // because class 0 means "physical register".
  static const char *const ClassPrefixes[] = {
      nullptr, "%p", "%rs", "%r", "%rd", "%f", "%fd", "%rq"};

  const unsigned ClassId = Reg.id() >> 28;
  if (ClassId == 0) {
    // This is actually a physical register, so defer to the autogenerated
    // register printer.
    OS << getRegisterName(Reg);
    return;
  }
  if (ClassId >= std::size(ClassPrefixes))
    report_fatal_error("Bad virtual register encoding");

  const unsigned RegNumber = Reg.id() & 0x0FFFFFFF;
  OS << ClassPrefixes[ClassId] << RegNumber;
}
/// Print a complete instruction: the TableGen-generated instruction text,
/// followed by any annotation string attached to it.
void NVPTXInstPrinter::printInst(const MCInst *MI, uint64_t Address,
                                 StringRef Annot, const MCSubtargetInfo &STI,
                                 raw_ostream &OS) {
  printInstruction(MI, Address, OS); // Generated in NVPTXGenAsmWriter.inc.
  printAnnotation(OS, Annot);        // Always emitted, even if empty.
}
/// Print operand \p OpNo of \p MI.
///
/// Registers go through printRegName, immediates through formatImm (wrapped
/// in immediate markup), and anything else must be an MCExpr, which is
/// printed with the target's MCAsmInfo.
void NVPTXInstPrinter::printOperand(const MCInst *MI, unsigned OpNo,
                                    raw_ostream &O) {
  const MCOperand &Op = MI->getOperand(OpNo);
  if (Op.isReg()) {
    // Pass the register through without demoting it to a plain integer;
    // printRegName takes an MCRegister.
    printRegName(O, Op.getReg());
  } else if (Op.isImm()) {
    markup(O, Markup::Immediate) << formatImm(Op.getImm());
  } else {
    assert(Op.isExpr() && "Unknown operand kind in printOperand");
    Op.getExpr()->print(O, &MAI);
  }
}
void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O,
const char *M) {
const MCOperand &MO = MI->getOperand(OpNum);
int64_t Imm = MO.getImm();
llvm::StringRef Modifier(M);
if (Modifier == "ftz") {
// FTZ flag
if (Imm & NVPTX::PTXCvtMode::FTZ_FLAG)
O << ".ftz";
return;
} else if (Modifier == "sat") {
// SAT flag
if (Imm & NVPTX::PTXCvtMode::SAT_FLAG)
O << ".sat";
return;
} else if (Modifier == "relu") {
// RELU flag
if (Imm & NVPTX::PTXCvtMode::RELU_FLAG)
O << ".relu";
return;
} else if (Modifier == "base") {
// Default operand
switch (Imm & NVPTX::PTXCvtMode::BASE_MASK) {
default:
return;
case NVPTX::PTXCvtMode::NONE:
return;
case NVPTX::PTXCvtMode::RNI:
O << ".rni";
return;
case NVPTX::PTXCvtMode::RZI:
O << ".rzi";
return;
case NVPTX::PTXCvtMode::RMI:
O << ".rmi";
return;
case NVPTX::PTXCvtMode::RPI:
O << ".rpi";
return;
case NVPTX::PTXCvtMode::RN:
O << ".rn";
return;
case NVPTX::PTXCvtMode::RZ:
O << ".rz";
return;
case NVPTX::PTXCvtMode::RM:
O << ".rm";
return;
case NVPTX::PTXCvtMode::RP:
O << ".rp";
return;
case NVPTX::PTXCvtMode::RNA:
O << ".rna";
return;
}
}
llvm_unreachable("Invalid conversion modifier");
}
/// Print one component of a comparison instruction's mode operand.
///
/// \p M selects the component:
///   "ftz"  - print ".ftz" if FTZ_FLAG is set.
///   "base" - print the comparison operator held in BASE_MASK (".eq",
///            ".ltu", ".nan", ...), or nothing for unrecognized values.
/// Any other modifier string is a programming error.
void NVPTXInstPrinter::printCmpMode(const MCInst *MI, int OpNum, raw_ostream &O,
                                    const char *M) {
  const MCOperand &MO = MI->getOperand(OpNum);
  int64_t Imm = MO.getImm();
  llvm::StringRef Modifier(M);

  if (Modifier == "ftz") {
    // FTZ flag
    if (Imm & NVPTX::PTXCmpMode::FTZ_FLAG)
      O << ".ftz";
    return;
  }

  if (Modifier == "base") {
    switch (Imm & NVPTX::PTXCmpMode::BASE_MASK) {
    default:
      return;
    case NVPTX::PTXCmpMode::EQ:
      O << ".eq";
      return;
    case NVPTX::PTXCmpMode::NE:
      O << ".ne";
      return;
    case NVPTX::PTXCmpMode::LT:
      O << ".lt";
      return;
    case NVPTX::PTXCmpMode::LE:
      O << ".le";
      return;
    case NVPTX::PTXCmpMode::GT:
      O << ".gt";
      return;
    case NVPTX::PTXCmpMode::GE:
      O << ".ge";
      return;
    case NVPTX::PTXCmpMode::LO:
      O << ".lo";
      return;
    case NVPTX::PTXCmpMode::LS:
      O << ".ls";
      return;
    case NVPTX::PTXCmpMode::HI:
      O << ".hi";
      return;
    case NVPTX::PTXCmpMode::HS:
      O << ".hs";
      return;
    case NVPTX::PTXCmpMode::EQU:
      O << ".equ";
      return;
    case NVPTX::PTXCmpMode::NEU:
      O << ".neu";
      return;
    case NVPTX::PTXCmpMode::LTU:
      O << ".ltu";
      return;
    case NVPTX::PTXCmpMode::LEU:
      O << ".leu";
      return;
    case NVPTX::PTXCmpMode::GTU:
      O << ".gtu";
      return;
    case NVPTX::PTXCmpMode::GEU:
      O << ".geu";
      return;
    case NVPTX::PTXCmpMode::NUM:
      O << ".num";
      return;
    case NVPTX::PTXCmpMode::NotANumber:
      O << ".nan";
      return;
    }
  }

  // Reached for any modifier other than "ftz"/"base", not only an empty one.
  llvm_unreachable("Unknown Modifier");
}
void NVPTXInstPrinter::printLdStCode(const MCInst *MI, int OpNum,
raw_ostream &O, const char *M) {
llvm::StringRef Modifier(M);
const MCOperand &MO = MI->getOperand(OpNum);
int Imm = (int)MO.getImm();
if (Modifier == "sem") {
auto Ordering = NVPTX::Ordering(Imm);
switch (Ordering) {
case NVPTX::Ordering::NotAtomic:
return;
case NVPTX::Ordering::Relaxed:
O << ".relaxed";
return;
case NVPTX::Ordering::Acquire:
O << ".acquire";
return;
case NVPTX::Ordering::Release:
O << ".release";
return;
case NVPTX::Ordering::Volatile:
O << ".volatile";
return;
case NVPTX::Ordering::RelaxedMMIO:
O << ".mmio.relaxed";
return;
default:
report_fatal_error(formatv(
"NVPTX LdStCode Printer does not support \"{}\" sem modifier. "
"Loads/Stores cannot be AcquireRelease or SequentiallyConsistent.",
OrderingToString(Ordering)));
}
} else if (Modifier == "scope") {
auto S = NVPTX::Scope(Imm);
switch (S) {
case NVPTX::Scope::Thread:
return;
case NVPTX::Scope::System:
O << ".sys";
return;
case NVPTX::Scope::Block:
O << ".cta";
return;
case NVPTX::Scope::Cluster:
O << ".cluster";
return;
case NVPTX::Scope::Device:
O << ".gpu";
return;
}
report_fatal_error(
formatv("NVPTX LdStCode Printer does not support \"{}\" sco modifier.",
ScopeToString(S)));
} else if (Modifier == "addsp") {
auto A = NVPTX::AddressSpace(Imm);
switch (A) {
case NVPTX::AddressSpace::Generic:
return;
case NVPTX::AddressSpace::Global:
case NVPTX::AddressSpace::Const:
case NVPTX::AddressSpace::Shared:
case NVPTX::AddressSpace::Param:
case NVPTX::AddressSpace::Local:
O << "." << A;
return;
}
report_fatal_error(formatv(
"NVPTX LdStCode Printer does not support \"{}\" addsp modifier.",
AddressSpaceToString(A)));
} else if (Modifier == "sign") {
switch (Imm) {
case NVPTX::PTXLdStInstCode::Signed:
O << "s";
return;
case NVPTX::PTXLdStInstCode::Unsigned:
O << "u";
return;
case NVPTX::PTXLdStInstCode::Untyped:
O << "b";
return;
case NVPTX::PTXLdStInstCode::Float:
O << "f";
return;
default:
llvm_unreachable("Unknown register type");
}
} else if (Modifier == "vec") {
switch (Imm) {
case NVPTX::PTXLdStInstCode::V2:
O << ".v2";
return;
case NVPTX::PTXLdStInstCode::V4:
O << ".v4";
return;
}
// TODO: evaluate whether cases not covered by this switch are bugs
return;
}
llvm_unreachable(formatv("Unknown Modifier: {}", Modifier).str().c_str());
}
void NVPTXInstPrinter::printMmaCode(const MCInst *MI, int OpNum, raw_ostream &O,
const char *M) {
const MCOperand &MO = MI->getOperand(OpNum);
int Imm = (int)MO.getImm();
llvm::StringRef Modifier(M);
if (Modifier.empty() || Modifier == "version") {
O << Imm; // Just print out PTX version
return;
} else if (Modifier == "aligned") {
// PTX63 requires '.aligned' in the name of the instruction.
if (Imm >= 63)
O << ".aligned";
return;
}
llvm_unreachable("Unknown Modifier");
}
void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum,
raw_ostream &O, const char *M) {
printOperand(MI, OpNum, O);
llvm::StringRef Modifier(M);
if (Modifier == "add") {
O << ", ";
printOperand(MI, OpNum + 1, O);
} else {
if (MI->getOperand(OpNum + 1).isImm() &&
MI->getOperand(OpNum + 1).getImm() == 0)
return; // don't print ',0' or '+0'
O << "+";
printOperand(MI, OpNum + 1, O);
}
}
void NVPTXInstPrinter::printOffseti32imm(const MCInst *MI, int OpNum,
raw_ostream &O, const char *Modifier) {
auto &Op = MI->getOperand(OpNum);
assert(Op.isImm() && "Invalid operand");
if (Op.getImm() != 0) {
O << "+";
printOperand(MI, OpNum, O);
}
}
void NVPTXInstPrinter::printHexu32imm(const MCInst *MI, int OpNum,
raw_ostream &O, const char *Modifier) {
int64_t Imm = MI->getOperand(OpNum).getImm();
O << formatHex(Imm) << "U";
}
void NVPTXInstPrinter::printProtoIdent(const MCInst *MI, int OpNum,
raw_ostream &O, const char *Modifier) {
const MCOperand &Op = MI->getOperand(OpNum);
assert(Op.isExpr() && "Call prototype is not an MCExpr?");
const MCExpr *Expr = Op.getExpr();
const MCSymbol &Sym = cast<MCSymbolRefExpr>(Expr)->getSymbol();
O << Sym.getName();
}
void NVPTXInstPrinter::printPrmtMode(const MCInst *MI, int OpNum,
raw_ostream &O, const char *Modifier) {
const MCOperand &MO = MI->getOperand(OpNum);
int64_t Imm = MO.getImm();
switch (Imm) {
default:
return;
case NVPTX::PTXPrmtMode::NONE:
return;
case NVPTX::PTXPrmtMode::F4E:
O << ".f4e";
return;
case NVPTX::PTXPrmtMode::B4E:
O << ".b4e";
return;
case NVPTX::PTXPrmtMode::RC8:
O << ".rc8";
return;
case NVPTX::PTXPrmtMode::ECL:
O << ".ecl";
return;
case NVPTX::PTXPrmtMode::ECR:
O << ".ecr";
return;
case NVPTX::PTXPrmtMode::RC16:
O << ".rc16";
return;
}
}
/// Print the reduction-operation suffix (".add", ".min", ".max", ".inc",
/// ".dec", ".and", ".or", ".xor") for a TMA reduction instruction. The
/// immediate at \p OpNum is interpreted as an nvvm::TMAReductionOp.
void NVPTXInstPrinter::printTmaReductionMode(const MCInst *MI, int OpNum,
                                             raw_ostream &O,
                                             const char *Modifier) {
  const MCOperand &MO = MI->getOperand(OpNum);
  using RedTy = llvm::nvvm::TMAReductionOp;

  switch (static_cast<RedTy>(MO.getImm())) {
  case RedTy::ADD:
    O << ".add";
    return;
  case RedTy::MIN:
    O << ".min";
    return;
  case RedTy::MAX:
    O << ".max";
    return;
  case RedTy::INC:
    O << ".inc";
    return;
  case RedTy::DEC:
    O << ".dec";
    return;
  case RedTy::AND:
    O << ".and";
    return;
  case RedTy::OR:
    O << ".or";
    return;
  case RedTy::XOR:
    O << ".xor";
    return;
  }
  // Name the function that actually failed, so the assertion points to the
  // right place.
  llvm_unreachable("Invalid Reduction Op in printTmaReductionMode");
}