//===- NVGPUDialect.cpp - MLIR NVGPU ops implementation -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the NVGPU dialect and its operations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::nvgpu;

#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc"

void nvgpu::NVGPUDialect::initialize() {
  addTypes<DeviceAsyncTokenType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc"
      >();
}
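
// Note: the only custom NVGPU type handled below is DeviceAsyncTokenType. Its
// mnemonic is `device.async.token`, so with the dialect prefix it is written
// `!nvgpu.device.async.token` in the IR.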
Type NVGPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();

  // Handle 'device async token' types.
  if (keyword == "device.async.token")
    return DeviceAsyncTokenType::get(context);

  parser.emitError(parser.getNameLoc(), "unknown nvgpu type: " + keyword);
  return Type();
}

void NVGPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<DeviceAsyncTokenType>([&](Type) { os << "device.async.token"; })
      .Default([](Type) { llvm_unreachable("unexpected 'nvgpu' type kind"); });
}

//===----------------------------------------------------------------------===//
// NVGPU_DeviceAsyncCopyOp
//===----------------------------------------------------------------------===//

/// Return true if the last dimension of the MemRefType has unit stride. Also
/// return true for memrefs with no strides.
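/// For example, a `memref<4x8xf32>` with the default identity layout has
/// strides `[8, 1]` and qualifies; a layout whose innermost stride is 2 does
/// not.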
static bool isLastMemrefDimUnitStride(MemRefType type) {
  int64_t offset;
  SmallVector<int64_t> strides;
  if (failed(getStridesAndOffset(type, strides, offset))) {
    return false;
  }
  // A rank-0 memref has no strides; per the contract above, treat it as
  // having unit stride.
  return strides.empty() || strides.back() == 1;
}
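
// Illustrative IR for a copy this verifier accepts (a sketch, not taken from
// the tests; see the op's TableGen description for the exact assembly form):
//   %token = nvgpu.device_async_copy %src[%i, %j], %dst[%k, %l, %m], 4
//       : memref<128x128xf32> to memref<3x16x128xf32, 3>
// The destination lives in workgroup (shared) memory and both memrefs have a
// unit-stride innermost dimension.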
LogicalResult DeviceAsyncCopyOp::verify() {
  auto srcMemref = getSrc().getType().cast<MemRefType>();
  auto dstMemref = getDst().getType().cast<MemRefType>();
  unsigned workgroupAddressSpace = gpu::GPUDialect::getWorkgroupAddressSpace();
  if (!isLastMemrefDimUnitStride(srcMemref))
    return emitError("source memref most minor dim must have unit stride");
  if (!isLastMemrefDimUnitStride(dstMemref))
    return emitError(
        "destination memref most minor dim must have unit stride");
  if (dstMemref.getMemorySpaceAsInt() != workgroupAddressSpace)
    return emitError("destination memref must have memory space ")
           << workgroupAddressSpace;
  if (dstMemref.getElementType() != srcMemref.getElementType())
    return emitError("source and destination must have the same element type");
  if (size_t(srcMemref.getRank()) != getSrcIndices().size())
    return emitOpError() << "expected " << srcMemref.getRank()
                         << " source indices, got " << getSrcIndices().size();
  if (size_t(dstMemref.getRank()) != getDstIndices().size())
    return emitOpError() << "expected " << dstMemref.getRank()
                         << " destination indices, got "
                         << getDstIndices().size();
  return success();
}

//===----------------------------------------------------------------------===//
// NVGPU_MmaSyncOp
//===----------------------------------------------------------------------===//

LogicalResult MmaSyncOp::verify() {
  // Fundamental tensor core mma.sync op
  // For the F32 (TF32), BF16, F16, S8, and S4 data types, the fundamental
  // tensor core operation has shape 8-by-8-by-128b; F64 is the exception. The
  // verification for mma.sync, covering the various shapes and data types, is
  // based on this fundamental tensor core operation.
  constexpr int kThreads = 32; // 32 threads per warp
  int64_t shapeM = 8;
  int64_t shapeN = 8;
  int64_t shapeK; // set based on data type (128b for all data types except F64)

  // Number of elements A, B, and C per thread per fundamental tensor core tile
  int64_t numElementA;    // set based on data type (32b except F64)
  int64_t numElementB;    // set based on data type (32b except F64)
  int64_t numElementC{2}; // two accumulator elements per fundamental tile

  // nvgpu.mma.sync vector operands (per thread)
  auto aVector = getMatrixA().getType().cast<VectorType>();
  auto bVector = getMatrixB().getType().cast<VectorType>();
  auto cVector = getMatrixC().getType().cast<VectorType>();

  // vector shapes
  ArrayRef<int64_t> aShape = aVector.getShape();
  ArrayRef<int64_t> bShape = bVector.getShape();
  ArrayRef<int64_t> cShape = cVector.getShape();

  // vector element type
  Type aType = aVector.getElementType();

  // nvgpu.mma.sync shape (per 32 threads or per warp)
  int64_t m = getMmaShape()[0].cast<IntegerAttr>().getInt();
  int64_t n = getMmaShape()[1].cast<IntegerAttr>().getInt();
  int64_t k = getMmaShape()[2].cast<IntegerAttr>().getInt();

  if (aType.isF64()) {
    // exception to 8-by-8-128b fundamental tensor core tile size
    shapeK = 4;
    numElementA = 1;
    numElementB = 1;
  } else if (aType.isF32() || aType.isBF16() || aType.isF16() ||
             aType.isInteger(8) || aType.isInteger(4)) {
    // 8-by-8-128b fundamental tensor core tile size
    int operandBitwidth = aType.getIntOrFloatBitWidth();
    shapeK = 128 / operandBitwidth;     // 128b wide shapeK
    numElementA = 32 / operandBitwidth; // 32b wide operand A
    numElementB = 32 / operandBitwidth; // 32b wide operand B
  } else {
    return emitError() << "expected input data type (i4,i8,f16,bf16,tf32,f64) "
                          "supported by nvgpu.mma.sync";
  }
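
  // Illustrative check of the math above (not from the source): for f16 with
  // mmaShape = [16, 8, 16], shapeK = 128/16 = 8 and numElementA = numElementB
  // = 32/16 = 2, so mTile = 2, nTile = 1, kTile = 2 below. The expected
  // per-thread operands are then A = vector<4x2xf16>, B = vector<2x2xf16>, and
  // a 2x2 accumulator C (e.g. vector<2x2xf16>), which also satisfy the
  // warp-wide element counts 256 = m*k, 128 = k*n, and 128 = m*n.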

  //
  // Basic verification
  //

  // verify warp-wide size for vector a
  if (aShape[0] * aShape[1] * kThreads != m * k)
    return emitOpError() << "expected " << m * k
                         << " warp-wide matrix A elements";

  // verify warp-wide size for vector b
  if (bShape[0] * bShape[1] * kThreads != k * n)
    return emitOpError() << "expected " << k * n
                         << " warp-wide matrix B elements";

  // verify warp-wide size for vector c
  if (cShape[0] * cShape[1] * kThreads != m * n)
    return emitOpError() << "expected " << m * n
                         << " warp-wide matrix C elements";

  //
  // Extended verification
  //

  // tiles of fundamental tensor core operations
  int64_t mTile = m / shapeM;
  int64_t nTile = n / shapeN;
  int64_t kTile = k / shapeK;

  // verify shape of aVector
  if (!((aShape[0] == mTile * kTile) && (aShape[1] == numElementA)))
    return emitOpError() << "expected matrix A to be shaped (" << mTile * kTile
                         << " x " << numElementA << ")";

  // verify shape of bVector
  if (!((bShape[0] == kTile * nTile) && (bShape[1] == numElementB)))
    return emitOpError() << "expected matrix B to be shaped (" << kTile * nTile
                         << " x " << numElementB << ")";

  // verify shape of cVector
  if (!((cShape[0] == mTile * nTile) && (cShape[1] == numElementC)))
    return emitOpError() << "expected matrix C to be shaped (" << mTile * nTile
                         << " x " << numElementC << ")";

  return success();
}

//===----------------------------------------------------------------------===//
// NVGPU_LdMatrixOp
//===----------------------------------------------------------------------===//

LogicalResult LdMatrixOp::verify() {
  // ldmatrix reads data from source in shared memory
  auto srcMemref = getSrcMemref().getType().cast<MemRefType>();

  // ldmatrix writes data to result/destination in vector registers
  auto resVector = getRes().getType().cast<VectorType>();

  // vector register shape, element type, and bitwidth
  ArrayRef<int64_t> resShape = resVector.getShape();
  Type resType = resVector.getElementType();
  int64_t elementBitWidth = resType.getIntOrFloatBitWidth();

  // ldmatrix loads 32 bits into vector registers per 8-by-8 tile per thread
  int64_t numElementsPer32b = 32 / elementBitWidth;

  // number of 8-by-8 tiles
  int64_t numTiles = getNumTiles();

  // transpose elements in vector registers at 16b granularity when true
  bool isTranspose = getTranspose();

  // address space id for shared memory
  unsigned smemAddressSpace = gpu::GPUDialect::getWorkgroupAddressSpace();
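
  // Illustrative example (not from the source): with numTiles = 4 and an f16
  // result element type, numElementsPer32b = 32/16 = 2, so the checks below
  // accept a result of vector<4x2xf16> loaded from a memref in workgroup
  // (shared) memory.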

  //
  // verification
  //

  if (!(srcMemref.getMemorySpaceAsInt() == smemAddressSpace))
    return emitError()
           << "expected nvgpu.ldmatrix srcMemref must have memory space "
           << smemAddressSpace;
  if (elementBitWidth > 32)
    return emitError() << "nvgpu.ldmatrix works for 32b or lower";
  if (isTranspose && !(elementBitWidth == 16))
    return emitError()
           << "nvgpu.ldmatrix transpose works only at 16b granularity";
  if (!(resShape[1] == numElementsPer32b))
    return emitError() << "expected vector register shape[1] = "
                       << numElementsPer32b;
  if (!(resShape[0] == numTiles))
    return emitError()
           << "expected vector register shape[0] and numTiles to match";

  return success();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc"