//===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
// SPIR-V Cooperative Matrix ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"

#include <cassert>

namespace mlir {
//===----------------------------------------------------------------------===//
// Patterns and helpers.
//===----------------------------------------------------------------------===//

/// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
/// when the elementwise op directly supports the cooperative matrix type.
/// Returns false if it cannot.
///
/// See SPV_KHR_cooperative_matrix for supported elementwise ops.
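///
/// A minimal illustrative example (assuming the default type conversion to
/// KHR cooperative matrix types):
///
///   %0 = gpu.subgroup_mma_elementwise addf %a, %b :
///          (!gpu.mma_matrix<16x16xf16, "COp">,
///           !gpu.mma_matrix<16x16xf16, "COp">)
///          -> !gpu.mma_matrix<16x16xf16, "COp">
///
/// becomes roughly:
///
///   %0 = spirv.FAdd %a, %b : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>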
static bool createElementwiseOp(ConversionPatternRewriter &builder,
                                gpu::SubgroupMmaElementwiseOp op, Type coopType,
                                ValueRange operands) {
  assert(isa<spirv::CooperativeMatrixType>(coopType));

  switch (op.getOpType()) {
  case gpu::MMAElementwiseOp::ADDF:
    builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::ADDI:
    builder.replaceOpWithNewOp<spirv::IAddOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::SUBF:
    builder.replaceOpWithNewOp<spirv::FSubOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::SUBI:
    builder.replaceOpWithNewOp<spirv::ISubOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::DIVF:
    builder.replaceOpWithNewOp<spirv::FDivOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::DIVS:
    builder.replaceOpWithNewOp<spirv::SDivOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::DIVU:
    builder.replaceOpWithNewOp<spirv::UDivOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::NEGATEF:
    builder.replaceOpWithNewOp<spirv::FNegateOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::NEGATES:
    builder.replaceOpWithNewOp<spirv::SNegateOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::EXTF:
    builder.replaceOpWithNewOp<spirv::FConvertOp>(op, coopType, operands);
    return true;
  default:
    break;
  }
  return false;
}

/// Returns true when all operands have the same cooperative matrix type.
static bool allOperandsHaveSameCoopMatrixType(ValueRange operands) {
  assert(!operands.empty());
  if (!llvm::all_equal(
          llvm::map_range(operands, [](Value v) { return v.getType(); })))
    return false;

  return isa<spirv::CooperativeMatrixType>(operands.front().getType());
}

namespace {
/// Converts GPU MMA ConstantMatrixOp to a constant SPIR-V cooperative matrix
/// op.
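///
/// A sketch of the rewrite (illustrative types):
///
///   %0 = gpu.subgroup_mma_constant_matrix %cst
///          : !gpu.mma_matrix<16x16xf16, "COp">
///
/// becomes roughly:
///
///   %0 = spirv.CompositeConstruct %cst
///          : (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>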
struct WmmaConstantOpToSPIRVLowering final
    : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    assert(adaptor.getOperands().size() == 1);
    Value cst = adaptor.getOperands().front();
    auto coopType = getTypeConverter()->convertType(op.getType());
    if (!coopType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, coopType, cst);
    return success();
  }
};

/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
/// the default case.
struct WmmaElementwiseOpToSPIRVDefaultLowering final
    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // All operands should be of cooperative matrix types.
    if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
      return rewriter.notifyMatchFailure(op,
                                         "not all operands are coop matrices");
    }

    auto coopType = getTypeConverter()->convertType(op.getType());
    if (!coopType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    return success(
        createElementwiseOp(rewriter, op, coopType, adaptor.getOperands()));
  }
};

/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
/// the matrix-times-scalar case.
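///
/// When one `mulf` operand is produced by `gpu.subgroup_mma_constant_matrix`
/// (i.e., a splat), the multiplication lowers to a single scalar
/// multiplication. A sketch (illustrative types):
///
///   %0 = spirv.MatrixTimesScalar %matrix, %scalar
///          : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, f16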
struct WmmaElementwiseOpToSPIRVScalarMulLowering final
    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (adaptor.getOperands().size() != 2)
      return failure();

    // All operands should be of cooperative matrix types.
    if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
      return rewriter.notifyMatchFailure(op,
                                         "not all operands are coop matrices");
    }

    if (op.getOpType() != gpu::MMAElementwiseOp::MULF)
      return failure();

    // Use the original operands to check whether one of the operands is a
    // splat scalar value.
    Value lhs = op.getOperands().front();
    Value rhs = op.getOperands().back();
    Value splat = nullptr;
    Value matrix = nullptr;
    if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
      splat = adaptor.getOperands().front();
      matrix = adaptor.getOperands().back();
    } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
      matrix = adaptor.getOperands().front();
      splat = adaptor.getOperands().back();
    }
    if (!splat || !matrix)
      return rewriter.notifyMatchFailure(op, "no splat operand");

    // Constant MMA matrix ops are converted to `spirv.CompositeConstruct` ops.
    auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
    if (!cc) {
      return rewriter.notifyMatchFailure(op,
                                         "splat is not a composite construct");
    }

    assert(cc.getConstituents().size() == 1);
    Value scalar = cc.getConstituents().front();

    auto coopType = getTypeConverter()->convertType(op.getType());
    if (!coopType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");
    rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
        op, coopType, ValueRange{matrix, scalar});
    return success();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// SPV_KHR_cooperative_matrix
//===----------------------------------------------------------------------===//

namespace khr {
namespace {

/// Converts the GPU MMA LoadOp to KHRCooperativeMatrixLoad op in the SPIR-V
/// dialect.
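///
/// A sketch of the rewrite (illustrative; the pointer type and stride are
/// derived from the source memref and the `leadDimension` attribute):
///
///   %0 = gpu.subgroup_mma_load_matrix %src[%i, %j]
///          {leadDimension = 32 : index}
///          : memref<32x32xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
///
/// becomes roughly:
///
///   %ptr = ... // spirv.AccessChain into the converted memref
///   %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor>
///          : !spirv.ptr<f16, StorageBuffer>, i32
///            -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>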
struct WmmaLoadOpToSPIRVLowering final
    : OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
    Location loc = op->getLoc();

    auto retType = cast<gpu::MMAMatrixType>(op.getRes().getType());
    MemRefType memrefType = op.getSrcMemref().getType();
    Value bufferPtr =
        spirv::getElementPtr(typeConverter, memrefType, adaptor.getSrcMemref(),
                             adaptor.getIndices(), loc, rewriter);

    auto coopType =
        typeConverter.convertType<spirv::CooperativeMatrixType>(retType);
    if (!coopType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    int64_t stride = op.getLeadDimension().getSExtValue();
    IntegerType i32Type = rewriter.getI32Type();
    auto strideValue = rewriter.create<spirv::ConstantOp>(
        loc, i32Type, IntegerAttr::get(i32Type, stride));

    bool isColMajor = op.getTranspose().value_or(false);
    auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
                             : spirv::CooperativeMatrixLayoutKHR::RowMajor;

    rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixLoadOp>(
        op, coopType, bufferPtr, strideValue, layout);
    return success();
  }
};

/// Converts the GPU MMA StoreOp to KHRCooperativeMatrixStore op in the SPIR-V
/// dialect.
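///
/// A sketch of the rewrite (illustrative types):
///
///   gpu.subgroup_mma_store_matrix %val, %dst[%i, %j]
///     {leadDimension = 32 : index}
///     : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16>
///
/// becomes roughly:
///
///   spirv.KHR.CooperativeMatrixStore %ptr, %val, %stride, <RowMajor>
///     : !spirv.ptr<f16, StorageBuffer>,
///       !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, i32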
struct WmmaStoreOpToSPIRVLowering final
    : OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
    Location loc = op->getLoc();

    auto memrefType = cast<MemRefType>(op.getDstMemref().getType());
    Value bufferPtr =
        spirv::getElementPtr(typeConverter, memrefType, adaptor.getDstMemref(),
                             adaptor.getIndices(), loc, rewriter);

    int64_t stride = op.getLeadDimension().getSExtValue();
    IntegerType i32Type = rewriter.getI32Type();
    auto strideValue = rewriter.create<spirv::ConstantOp>(
        loc, i32Type, IntegerAttr::get(i32Type, stride));

    bool isColMajor = op.getTranspose().value_or(false);
    auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
                             : spirv::CooperativeMatrixLayoutKHR::RowMajor;

    rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixStoreOp>(
        op, bufferPtr, adaptor.getSrc(), strideValue, layout);
    return success();
  }
};

/// Converts the GPU MMA ComputeOp to the KHRCooperativeMatrixMulAdd op in the
/// SPIR-V dialect.
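///
/// A sketch of the rewrite (illustrative types):
///
///   %d = gpu.subgroup_mma_compute %a, %b, %c
///          : !gpu.mma_matrix<16x16xf16, "AOp">,
///            !gpu.mma_matrix<16x16xf16, "BOp">
///          -> !gpu.mma_matrix<16x16xf16, "COp">
///
/// becomes roughly:
///
///   %d = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c
///          : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>,
///            !spirv.coopmatrix<16x16xf16, Subgroup, MatrixB>
///            -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>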
struct WmmaMmaOpToSPIRVLowering final
    : OpConversionPattern<gpu::SubgroupMmaComputeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixMulAddOp>(
        subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(),
        adaptor.getOpC());
    return success();
  }
};

} // namespace
} // namespace khr
} // namespace mlir

void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
    SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
  using namespace mlir;
  MLIRContext *context = patterns.getContext();
  patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
               khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
               WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
  // Give the following patterns higher benefit to prevail over the default
  // one.
  patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
                                                          /*benefit=*/2);
}

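/// Maps `gpu.mma_matrix` types to SPIR-V KHR cooperative matrix types, e.g.
/// (illustrative):
///
///   !gpu.mma_matrix<16x16xf16, "AOp">
///     -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>
///   !gpu.mma_matrix<16x16xf16, "COp">
///     -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>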
void mlir::populateMMAToSPIRVCoopMatrixTypeConversion(
    mlir::SPIRVTypeConverter &typeConverter) {
  typeConverter.addConversion([](gpu::MMAMatrixType type) {
    ArrayRef<int64_t> retTypeShape = type.getShape();
    Type elementType = type.getElementType();
    auto use =
        llvm::StringSwitch<spirv::CooperativeMatrixUseKHR>(type.getOperand())
            .Case("AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
            .Case("BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
            .Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);

    return spirv::CooperativeMatrixType::get(elementType, retTypeShape[0],
                                             retTypeShape[1],
                                             spirv::Scope::Subgroup, use);
  });
}