//===- SPIRVLowering.cpp - Standard to SPIR-V dialect conversion--===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities used to lower to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
#include "mlir/Dialect/SPIRV/LayoutUtils.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/Support/Debug.h"
#include <functional>
#define DEBUG_TYPE "mlir-spirv-lowering"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Type Conversion
//===----------------------------------------------------------------------===//
Type SPIRVTypeConverter::getIndexType(MLIRContext *context) {
  // `index` is lowered to a 32-bit integer for the time being; this may need
  // to become configurable later.
  // TODO(ravishankarm): 64-bit integers are probably the better choice. That
  // requires support for Conversion instructions in the SPIR-V dialect: the
  // Vulkan spec requires builtins such as GlobalInvocationID to be 32-bit
  // (unsigned) integers, which would have to be SExtended to 64-bit for index
  // computations.
  constexpr unsigned kIndexBitwidth = 32;
  return IntegerType::get(kIndexBitwidth, context);
}
// TODO(ravishankarm): This is a utility function that should probably be
// exposed by the SPIR-V dialect. Keeping it local till the use case arises.
/// Returns the size of `t` in bytes, or llvm::None when the size cannot be
/// determined.
///
/// Supported types:
///  - integer and float types whose bitwidth is a whole number of bytes;
///  - memrefs with a static shape and static strides/offset whose element type
///    is itself supported. The result covers the span from the start of the
///    underlying buffer up to the furthest element addressed by the layout.
static Optional<int64_t> getTypeNumBytes(Type t) {
  if (auto integerType = t.dyn_cast<IntegerType>()) {
    int64_t bitWidth = integerType.getWidth();
    // Widths that are not a whole number of bytes (e.g. i1) have no byte size
    // here; report "unknown" rather than truncating to 0, which callers would
    // otherwise end up dividing by.
    if (bitWidth % 8 != 0)
      return llvm::None;
    return bitWidth / 8;
  }
  if (auto floatType = t.dyn_cast<FloatType>()) {
    int64_t bitWidth = floatType.getWidth();
    if (bitWidth % 8 != 0)
      return llvm::None;
    return bitWidth / 8;
  }
  if (auto memRefType = t.dyn_cast<MemRefType>()) {
    // TODO: Layout should also be controlled by the ABI attributes. For now
    // using the layout from MemRef.
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    // hasStaticShape() already rules out dynamic dimensions, so only the
    // strides and the offset need a separate dynamic check below.
    if (!memRefType.hasStaticShape() ||
        failed(getStridesAndOffset(memRefType, strides, offset)))
      return llvm::None;
    if (offset == MemRefType::getDynamicStrideOrOffset() ||
        llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()))
      return llvm::None;

    auto elementSize = getTypeNumBytes(memRefType.getElementType());
    if (!elementSize)
      return llvm::None;

    // The extent of the memref in elements is max(stride * dimension-size)
    // over all dimensions. A rank-0 memref has no dimensions but still holds
    // exactly one element, so it must not be left at the -1 sentinel.
    auto dims = memRefType.getShape();
    int64_t memrefSize = dims.empty() ? 1 : -1;
    for (auto dim : llvm::enumerate(dims))
      memrefSize = std::max(memrefSize, dim.value() * strides[dim.index()]);
    return (offset + memrefSize) * elementSize.getValue();
  }
  // TODO: Add size computation for other types.
  return llvm::None;
}
/// Converts a standard type to the type used for it in the SPIR-V dialect.
///
///  - Types that are already valid SPIR-V types are returned unchanged.
///  - `index` becomes SPIRVTypeConverter::getIndexType().
///  - A statically shaped memref in the default memory space becomes a
///    StorageBuffer pointer to a struct wrapping an array of the converted
///    element type.
///
/// Returns a null Type when `type` cannot be converted.
static Type convertStdType(Type type) {
  // If the type is already valid in SPIR-V, directly return.
  if (spirv::SPIRVDialect::isValidType(type))
    return type;

  if (type.isa<IndexType>())
    return SPIRVTypeConverter::getIndexType(type.getContext());

  auto memRefType = type.dyn_cast<MemRefType>();
  if (!memRefType)
    return Type();

  // TODO(ravishankarm): For now only support default memory space. The memory
  // space description is not set in stone within MLIR, i.e. it depends on the
  // context it is being used. To map this to SPIR-V storage classes, we
  // should rely on the ABI attributes, and not on the memory space. This is
  // still evolving, and needs to be revisited when there is more clarity.
  if (memRefType.getMemorySpace())
    return Type();

  auto elementType = convertStdType(memRefType.getElementType());
  if (!elementType)
    return Type();

  // The element size is used both as the divisor for the element count and as
  // the array stride, so a zero (or otherwise non-positive) size cannot be
  // accepted: it would divide by zero below.
  auto elementSize = getTypeNumBytes(elementType);
  if (!elementSize || elementSize.getValue() <= 0)
    return Type();

  // TODO(ravishankarm) : Handle dynamic shapes.
  if (!memRefType.hasStaticShape())
    return Type();

  // A non-positive total size would wrap around when converted to the
  // unsigned element count, and SPIR-V arrays need at least one element.
  auto arraySize = getTypeNumBytes(memRefType);
  if (!arraySize || arraySize.getValue() <= 0)
    return Type();

  auto arrayType = spirv::ArrayType::get(
      elementType, arraySize.getValue() / elementSize.getValue(),
      elementSize.getValue());
  auto structType = spirv::StructType::get(arrayType, 0);
  // For now initialize the storage class to StorageBuffer. This will be
  // updated later based on what is passed in w.r.t. the ABI attributes.
  return spirv::PointerType::get(structType,
                                 spirv::StorageClass::StorageBuffer);
}
/// Converts `type` by delegating to the file-local convertStdType helper and
/// returns whatever that helper produces.
Type SPIRVTypeConverter::convertType(Type type) {
  Type converted = convertStdType(type);
  return converted;
}
//===----------------------------------------------------------------------===//
// Builtin Variables
//===----------------------------------------------------------------------===//
/// Scans the spv.globalVariable ops in the block of `moduleOp` for one whose
/// `builtin` string attribute symbolizes to `builtin`. Returns the first match,
/// or a null op when there is none.
static spirv::GlobalVariableOp getBuiltinVariable(spirv::ModuleOp &moduleOp,
                                                  spirv::BuiltIn builtin) {
  // The attribute name does not depend on the variable being inspected.
  const auto builtinAttrName =
      spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn);

  for (auto globalVar :
       moduleOp.getBlock().getOps<spirv::GlobalVariableOp>()) {
    auto attr = globalVar.getAttrOfType<StringAttr>(builtinAttrName);
    if (!attr)
      continue;

    auto parsedBuiltin = spirv::symbolizeBuiltIn(attr.getValue());
    if (!parsedBuiltin)
      continue;

    if (parsedBuiltin.getValue() == builtin)
      return globalVar;
  }
  return nullptr;
}
/// Builds the symbol name used for the global variable of `builtin`:
/// "__builtin_var_" followed by the stringified builtin and a trailing "__".
static std::string getBuiltinVarName(spirv::BuiltIn builtin) {
  std::string varName("__builtin_var_");
  varName += stringifyBuiltIn(builtin).str();
  varName += "__";
  return varName;
}
/// Returns the global variable for `builtin` in `moduleOp`, creating it at the
/// start of the module block when it does not exist yet. The insertion point
/// of `builder` is the same after the call as before it.
///
/// Only the workgroup/invocation builtins listed in the switch are supported;
/// for any other builtin an error is emitted at `loc` and a null op is
/// returned.
static spirv::GlobalVariableOp
getOrInsertBuiltinVariable(spirv::ModuleOp &moduleOp, Location loc,
                           spirv::BuiltIn builtin, OpBuilder &builder) {
  if (auto existingVar = getBuiltinVariable(moduleOp, builtin))
    return existingVar;

  // Restores the caller's insertion point on every exit path.
  OpBuilder::InsertionGuard insertionGuard(builder);
  builder.setInsertionPointToStart(&moduleOp.getBlock());

  switch (builtin) {
  case spirv::BuiltIn::NumWorkgroups:
  case spirv::BuiltIn::WorkgroupSize:
  case spirv::BuiltIn::WorkgroupId:
  case spirv::BuiltIn::LocalInvocationId:
  case spirv::BuiltIn::GlobalInvocationId: {
    // These builtins are declared as a 3-element vector of 32-bit integers in
    // the Input storage class.
    auto elementType = builder.getIntegerType(32);
    auto vectorType = VectorType::get({3}, elementType);
    auto ptrType =
        spirv::PointerType::get(vectorType, spirv::StorageClass::Input);
    auto varName = getBuiltinVarName(builtin);
    return builder.create<spirv::GlobalVariableOp>(loc, ptrType, varName,
                                                   builtin);
  }
  default:
    emitError(loc, "unimplemented builtin variable generation for ")
        << stringifyBuiltIn(builtin);
    return nullptr;
  }
}
/// Gets the value of the global variable associated with a builtin, adding the
/// variable to the enclosing SPIR-V module if it doesn't exist.
///
/// The value is produced by creating an spv address-of op on the variable
/// followed by a load, both at the current insertion point of `builder`.
///
/// Returns a null Value, after a diagnostic has been emitted, when `op` is not
/// nested in a SPIR-V module or when no variable can be created for `builtin`.
Value mlir::spirv::getBuiltinVariableValue(Operation *op,
                                           spirv::BuiltIn builtin,
                                           OpBuilder &builder) {
  auto moduleOp = op->getParentOfType<spirv::ModuleOp>();
  if (!moduleOp) {
    op->emitError("expected operation to be within a SPIR-V module");
    return nullptr;
  }
  spirv::GlobalVariableOp varOp =
      getOrInsertBuiltinVariable(moduleOp, op->getLoc(), builtin, builder);
  // getOrInsertBuiltinVariable returns a null op (after emitting an error) for
  // builtins it cannot generate; do not build ops that reference it.
  if (!varOp)
    return nullptr;
  Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
  return builder.create<spirv::LoadOp>(op->getLoc(), ptr,
                                       /*memory_access =*/nullptr,
                                       /*alignment =*/nullptr);
}
//===----------------------------------------------------------------------===//
// Set ABI attributes for lowering entry functions.
//===----------------------------------------------------------------------===//
/// Attaches the SPIR-V ABI attributes to an entry function.
///
/// `argABIInfo[i]` is set as the interface variable ABI attribute of argument
/// `i` of `funcOp`, and `entryPointInfo` is set as the entry point ABI
/// attribute of the function itself.
///
/// Returns failure, without modifying `funcOp`, when `argABIInfo` has fewer
/// entries than `funcOp` has arguments; returns success otherwise.
LogicalResult
mlir::spirv::setABIAttrs(FuncOp funcOp, spirv::EntryPointABIAttr entryPointInfo,
                         ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) {
  // Validate up front so the loop below never indexes past the end of
  // `argABIInfo` and the function is not left partially annotated.
  const unsigned numArgs = funcOp.getNumArguments();
  if (argABIInfo.size() < numArgs) {
    funcOp.emitError("expected an interface variable ABI attribute for each "
                     "of the ")
        << numArgs << " function argument(s), but got "
        << static_cast<unsigned>(argABIInfo.size());
    return failure();
  }

  // Set the attributes for argument and the function.
  StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
  for (auto argIndex : llvm::seq<unsigned>(0, numArgs))
    funcOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
  funcOp.setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);
  return success();
}
//===----------------------------------------------------------------------===//
// SPIR-V ConversionTarget
//===----------------------------------------------------------------------===//
/// Creates a conversion target for `targetEnv` in which ops of the SPIR-V
/// dialect are dynamically legal, with legality decided per op by
/// SPIRVConversionTarget::isLegalOp.
std::unique_ptr<spirv::SPIRVConversionTarget>
spirv::SPIRVConversionTarget::get(spirv::TargetEnvAttr targetEnv,
                                  MLIRContext *context) {
  // The constructor is private, so std::make_unique cannot be used here.
  std::unique_ptr<SPIRVConversionTarget> target(
      new SPIRVConversionTarget(targetEnv, context));

  // Capture the raw pointer rather than `target` itself: the local unique_ptr
  // is handed to the caller on return, while the address of the object it
  // owns stays the same.
  SPIRVConversionTarget *rawTarget = target.get();
  ConversionTarget::DynamicLegalityCallbackFn legalityCallback =
      [rawTarget](Operation *op) { return rawTarget->isLegalOp(op); };

  target->addDynamicallyLegalDialect<SPIRVDialect>(
      Optional<ConversionTarget::DynamicLegalityCallbackFn>(
          std::move(legalityCallback)));
  return target;
}
/// Records the version, extensions and capabilities listed in `targetEnv` so
/// that isLegalOp can check ops against them.
spirv::SPIRVConversionTarget::SPIRVConversionTarget(
    spirv::TargetEnvAttr targetEnv, MLIRContext *context)
    : ConversionTarget(*context),
      givenVersion(static_cast<spirv::Version>(targetEnv.version().getInt())) {
  // Extensions are stored as string attributes.
  // NOTE(review): the symbolized result is dereferenced without a check; this
  // presumably relies on `targetEnv` having been verified already - confirm.
  for (Attribute extensionAttr : targetEnv.extensions()) {
    StringRef extensionName = extensionAttr.cast<StringAttr>().getValue();
    givenExtensions.insert(*spirv::symbolizeExtension(extensionName));
  }

  // Capabilities are stored as integer attributes holding the enum value.
  for (Attribute capabilityAttr : targetEnv.capabilities()) {
    auto capabilityValue = capabilityAttr.cast<IntegerAttr>().getInt();
    givenCapabilities.insert(static_cast<spirv::Capability>(capabilityValue));
  }
}
/// Returns true if `op` can be used in the target environment this conversion
/// target was created with, i.e. if all of the following hold:
///  - the target version is not below the op's minimum version;
///  - the target version is not above the op's maximum version;
///  - every extension requirement of the op is met by a given extension;
///  - every capability requirement of the op is met by a given capability.
/// Each check only applies to ops implementing the matching query interface.
/// The reason for rejecting an op is logged under DEBUG_TYPE.
bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
  // Make sure this op is available at the given version. Ops not implementing
  // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
  // SPIR-V versions.
  if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op))
    if (minVersion.getMinVersion() > givenVersion) {
      LLVM_DEBUG(llvm::dbgs()
                 << op->getName() << " illegal: requiring min version "
                 << spirv::stringifyVersion(minVersion.getMinVersion())
                 << "\n");
      return false;
    }
  if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op))
    if (maxVersion.getMaxVersion() < givenVersion) {
      LLVM_DEBUG(llvm::dbgs()
                 << op->getName() << " illegal: requiring max version "
                 << spirv::stringifyVersion(maxVersion.getMaxVersion())
                 << "\n");
      return false;
    }

  // Make sure this op's required extensions are allowed to use. For each op,
  // we return a vector of vector for its extension requirements following
  // ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
  // convention. Ops not implementing QueryExtensionInterface do not require
  // extensions to be available.
  if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op)) {
    auto exts = extensions.getExtensions();
    for (const auto &ors : exts)
      // An OR-group is unsatisfied when none of its alternatives was given.
      if (llvm::all_of(ors, [this](spirv::Extension ext) {
            return this->givenExtensions.count(ext) == 0;
          })) {
        LLVM_DEBUG(llvm::dbgs() << op->getName()
                                << " illegal: missing required extension\n");
        return false;
      }
  }

  // Make sure this op's required capabilities are allowed to use. For each op,
  // we return a vector of vector for its capability requirements following
  // ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
  // convention. Ops not implementing QueryCapabilityInterface do not require
  // capabilities to be available.
  if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op)) {
    auto caps = capabilities.getCapabilities();
    for (const auto &ors : caps)
      // An OR-group is unsatisfied when none of its alternatives was given.
      if (llvm::all_of(ors, [this](spirv::Capability cap) {
            return this->givenCapabilities.count(cap) == 0;
          })) {
        LLVM_DEBUG(llvm::dbgs() << op->getName()
                                << " illegal: missing required capability\n");
        return false;
      }
  }
  return true;
}