Files
Hocky Yudhiono f1822ca735 [mlir][xevm] Fix greedy rewriter crash when HandleVectorExtractPattern matches shuffles on block arguments (#192213)
`HandleVectorExtractPattern` could report `success()` without rewriting
the IR when `llvm.shufflevector` extracted a contiguous slice from a
**block argument** (no defining op). The greedy rewriter’s expensive
checks then aborted with *“pattern returned success but IR did not
change”*.

The pattern only performs work when the shuffle’s operand is defined by
another op (`FPExt`, `FPTrunc`, `bitcast`, nested `shufflevector`, or
`load`). For operands like function arguments, `getDefiningOp()` is
null, so nothing is rewritten; the function still fell through to
`return success()` without changing the IR and would crash when
`MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS` is on. `mlir-opt
--convert-xevm-to-llvm --split-input-file
mlir/test/Conversion/XeVMToLLVM/xevm_mx-to-llvm.mlir` no longer hits the
fatal error.

Assisted-by: Cursor (Composer 2)
2026-04-28 21:22:32 -07:00

1568 lines
59 KiB
C++
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//===-- XeVMToLLVM.cpp - XeVM to LLVM dialect conversion --------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Types.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
using namespace mlir;
using namespace xevm;
namespace {
/// Bundle of LLVM function attributes applied to generated device-function
/// declarations (see createDeviceFunctionCall).
struct LLVMFuncAttributeOptions {
  bool isConvergent = false;
  bool isNoUnwind = false;
  bool isWillReturn = false;
  LLVM::MemoryEffectsAttr memEffectsAttr{};
};
// Common attribute combinations reused by the conversion patterns below.
static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
    false, true, false, {}};
static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
    false, true, true, {}};
static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
    true, true, true, {}};
/// Return the Itanium mangling code for \p ty. When \p isUnsigned is set,
/// integer types are mangled as their unsigned variants.
std::string getTypeMangling(Type ty, bool isUnsigned = false) {
  return TypeSwitch<Type, std::string>(ty)
      .Case([isUnsigned](VectorType vecTy) -> std::string {
        // Vectors mangle as Dv<N>_<element>.
        return "Dv" + std::to_string(vecTy.getNumElements()) + "_" +
               getTypeMangling(vecTy.getElementType(), isUnsigned);
      })
      .Case([](Float16Type) -> std::string { return "Dh"; })
      .Case([](Float32Type) -> std::string { return "f"; })
      .Case([](Float64Type) -> std::string { return "d"; })
      .Case([isUnsigned](IntegerType intTy) -> std::string {
        const unsigned width = intTy.getWidth();
        if (width == 8)
          return isUnsigned ? "h" : "c";
        if (width == 16)
          return isUnsigned ? "t" : "s";
        if (width == 32)
          return isUnsigned ? "j" : "i";
        if (width == 64)
          return isUnsigned ? "m" : "l";
        llvm_unreachable("unhandled integer type");
      })
      .DefaultUnreachable("unhandled type for mangling");
}
/// Produce the Itanium-mangled name `_Z<len><baseName><params...>` for a
/// function taking \p types. \p isUnsigned, when non-empty, gives per-type
/// signedness for integer mangling. Repeated non-scalar types use the
/// standard substitution encoding (S_, S0_, ...).
std::string mangle(StringRef baseName, ArrayRef<Type> types,
                   ArrayRef<bool> isUnsigned = {}) {
  assert((isUnsigned.empty() || isUnsigned.size() == types.size()) &&
         "Signedness info doesn't match");
  std::string mangled;
  llvm::raw_string_ostream os(mangled);
  // Maps a previously-emitted (non-scalar) type to its substitution index.
  llvm::SmallDenseMap<Type, unsigned> seenTypes;
  os << "_Z" << baseName.size() << baseName;
  for (auto [idx, type] : llvm::enumerate(types)) {
    auto found = seenTypes.find(type);
    if (found == seenTypes.end()) {
      // First occurrence: scalar types are never substituted, so only
      // remember compound types.
      if (!type.isIntOrFloat())
        seenTypes[type] = seenTypes.size();
      os << getTypeMangling(type, !isUnsigned.empty() && isUnsigned[idx]);
      continue;
    }
    // First substitution is `S_`, second is `S0_`, and so on.
    os << "S";
    if (unsigned firstIdx = found->getSecond())
      os << firstIdx - 1;
    os << "_";
  }
  return mangled;
}
/// Map an XeVM element type to the abbreviated tag used in OCL builtin
/// names; types without a special abbreviation use the enum spelling.
std::string builtinElemType(ElemType elemType) {
  if (elemType == ElemType::BF8)
    return "bf8";
  if (elemType == ElemType::F8)
    return "hf8";
  if (elemType == ElemType::BF16)
    return "bf";
  if (elemType == ElemType::F16)
    return "hf";
  if (elemType == ElemType::F32)
    return "f";
  return stringifyElemType(elemType).str();
}
/// Decode the L1 component of a combined load cache-control enum.
/// Returns -1 for USE_DEFAULT, 0 for uncached/unknown, 1 cached,
/// 2 streaming, 3 invalidate-after-read.
static int32_t getL1CacheControl(LoadCacheControl cc) {
  switch (cc) {
  case LoadCacheControl::USE_DEFAULT:
    return -1;
  case LoadCacheControl::L1C_L2UC_L3UC:
  case LoadCacheControl::L1C_L2UC_L3C:
  case LoadCacheControl::L1C_L2C_L3UC:
  case LoadCacheControl::L1C_L2C_L3C:
    return 1;
  case LoadCacheControl::L1S_L2UC_L3UC:
  case LoadCacheControl::L1S_L2UC_L3C:
  case LoadCacheControl::L1S_L2C_L3UC:
  case LoadCacheControl::L1S_L2C_L3C:
    return 2;
  case LoadCacheControl::INVALIDATE_READ:
    return 3;
  default:
    return 0;
  }
}
/// Decode the L1 component of a combined store cache-control enum.
/// Returns -1 for USE_DEFAULT, 0 for uncached/unknown, 1 write-through,
/// 2 write-back, 3 streaming.
static int32_t getL1CacheControl(StoreCacheControl cc) {
  switch (cc) {
  case StoreCacheControl::USE_DEFAULT:
    return -1;
  case StoreCacheControl::L1WT_L2UC_L3UC:
  case StoreCacheControl::L1WT_L2UC_L3WB:
  case StoreCacheControl::L1WT_L2WB_L3UC:
  case StoreCacheControl::L1WT_L2WB_L3WB:
    return 1;
  case StoreCacheControl::L1WB_L2UC_L3UC:
  case StoreCacheControl::L1WB_L2WB_L3UC:
  case StoreCacheControl::L1WB_L2UC_L3WB:
    return 2;
  case StoreCacheControl::L1S_L2UC_L3UC:
  case StoreCacheControl::L1S_L2UC_L3WB:
  case StoreCacheControl::L1S_L2WB_L3UC:
  case StoreCacheControl::L1S_L2WB_L3WB:
    return 3;
  default:
    return 0;
  }
}
/// Decode the L3 component of a combined load cache-control enum.
/// Returns -1 for USE_DEFAULT, 0 for uncached/unknown, 1 cached,
/// 3 invalidate-after-read.
static int32_t getL3CacheControl(LoadCacheControl cc) {
  switch (cc) {
  case LoadCacheControl::USE_DEFAULT:
    return -1;
  case LoadCacheControl::L1UC_L2UC_L3C:
  case LoadCacheControl::L1UC_L2C_L3C:
  case LoadCacheControl::L1C_L2UC_L3C:
  case LoadCacheControl::L1C_L2C_L3C:
  case LoadCacheControl::L1S_L2UC_L3C:
  case LoadCacheControl::L1S_L2C_L3C:
    return 1;
  case LoadCacheControl::INVALIDATE_READ:
    return 3;
  default:
    return 0;
  }
}
/// Decode the L3 component of a combined store cache-control enum.
/// Returns -1 for USE_DEFAULT, 0 for uncached/unknown, 2 write-back.
static int32_t getL3CacheControl(StoreCacheControl cc) {
  switch (cc) {
  case StoreCacheControl::USE_DEFAULT:
    return -1;
  case StoreCacheControl::L1UC_L2UC_L3WB:
  case StoreCacheControl::L1UC_L2WB_L3WB:
  case StoreCacheControl::L1WT_L2UC_L3WB:
  case StoreCacheControl::L1WT_L2WB_L3WB:
  case StoreCacheControl::L1S_L2UC_L3WB:
  case StoreCacheControl::L1S_L2WB_L3WB:
  case StoreCacheControl::L1WB_L2UC_L3WB:
    return 2;
  default:
    return 0;
  }
}
// Per-op accessors for the optional cache-control attribute. They give the
// generic helpers (getL1CacheControl<OpType>, getL3CacheControl<OpType>,
// getCacheControlMetadata) a uniform interface across op types.
static std::optional<LoadCacheControl> getCacheControl(PrefetchOp op) {
  return op.getCacheControl();
}
static std::optional<LoadCacheControl> getCacheControl(BlockLoad2dOp op) {
  return op.getCacheControl();
}
static std::optional<LoadCacheControl> getCacheControl(BlockLoadOp op) {
  return op.getCacheControl();
}
static std::optional<LoadCacheControl> getCacheControl(BlockPrefetch2dOp op) {
  return op.getCacheControl();
}
static std::optional<StoreCacheControl> getCacheControl(BlockStore2dOp op) {
  return op.getCacheControl();
}
static std::optional<StoreCacheControl> getCacheControl(BlockStoreOp op) {
  return op.getCacheControl();
}
/// Extract the XeVM cache-control attribute from a plain LLVM load.
/// Returns std::nullopt when the attribute is absent or is not a
/// LoadCacheControlAttr.
static std::optional<LoadCacheControl> getCacheControl(LLVM::LoadOp op) {
  // getAttrOfType returns null both when the attribute is missing and when
  // it has the wrong type, so a separate hasAttr check is redundant.
  if (auto attr =
          op->getAttrOfType<xevm::LoadCacheControlAttr>("cache_control"))
    return attr.getValue();
  return std::nullopt;
}
/// Extract the XeVM cache-control attribute from a plain LLVM store.
/// Returns std::nullopt when the attribute is absent or is not a
/// StoreCacheControlAttr.
static std::optional<StoreCacheControl> getCacheControl(LLVM::StoreOp op) {
  // getAttrOfType returns null both when the attribute is missing and when
  // it has the wrong type, so a separate hasAttr check is redundant.
  if (auto attr =
          op->getAttrOfType<xevm::StoreCacheControlAttr>("cache_control"))
    return attr.getValue();
  return std::nullopt;
}
/// L1 control value for \p op. Unconditionally dereferences the optional:
/// callers must have checked getCacheControl(op) first (as
/// getCacheControlMetadata does).
template <typename OpType>
int32_t getL1CacheControl(OpType op) {
  return getL1CacheControl(*getCacheControl(op));
}
/// L3 control value for \p op. Same precondition as getL1CacheControl.
template <typename OpType>
int32_t getL3CacheControl(OpType op) {
  return getL3CacheControl(*getCacheControl(op));
}
/// Build the SPIR-V cache-control decoration metadata for \p op as an
/// ArrayAttr of two [decoration-key, cache-level, control-value] i32
/// triples (one for L1, one for L3).
///
/// Returns an empty optional when the op carries no cache control, or when
/// the control is USE_DEFAULT for both levels (nothing to emit).
template <typename OpType>
static std::optional<ArrayAttr>
getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
  if (!getCacheControl(op))
    return {};
  constexpr int32_t decorationCacheControlArity{3};
  // SPIR-V decoration keys for load vs. store cache controls.
  constexpr int32_t loadCacheControlKey{6442};
  constexpr int32_t storeCacheControlKey{6443};
  constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> ||
                          std::is_same_v<OpType, BlockPrefetch2dOp> ||
                          std::is_same_v<OpType, LLVM::LoadOp> ||
                          std::is_same_v<OpType, BlockLoadOp> ||
                          std::is_same_v<OpType, PrefetchOp>;
  // Compute each level's control once; the original code re-derived these
  // (re-reading the attribute) on every use below.
  const int32_t l1Control = getL1CacheControl<OpType>(op);
  const int32_t l3Control = getL3CacheControl<OpType>(op);
  // If the cache control is USE_DEFAULT, then we don't emit any metadata.
  // Assert that if one of the L1 or L3 cache control values is USE_DEFAULT
  // (represented as -1), then both must be USE_DEFAULT; otherwise there is a
  // bug.
  assert((l1Control == -1) == (l3Control == -1) &&
         "If one of L1 or L3 cache control is USE_DEFAULT, both must be "
         "USE_DEFAULT");
  if (l1Control == -1 && l3Control == -1)
    return {};
  const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
  SmallVector<int32_t, decorationCacheControlArity> decorationsL1{
      controlKey, 0, l1Control};
  SmallVector<int32_t, decorationCacheControlArity> decorationsL3{
      controlKey, 1, l3Control};
  auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
  auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
  SmallVector<Attribute, 2> combinedAttrs = {arrayAttrL1, arrayAttrL3};
  return rewriter.getArrayAttr(combinedAttrs);
}
//===----------------------------------------------------------------------===//
// Cache control annotation utilities
//
// Instead of attaching cache control as MLIR attributes and handling them
// during LLVM translation, we directly emit llvm.intr.ptr.annotation op in
// MLIR.
//===----------------------------------------------------------------------===//
/// Build one cache-control payload string per attribute.
///
/// Each Attribute is expected to be an ArrayAttr of 3 IntegerAttr values:
/// [SPIR-V decoration token, cache level, cache control value]
///
/// A single entry produces a string like: {6442:"0,1"}
/// where the quote characters (0x22) will appear as \22 in LLVM IR textual
/// form.
/// Build one cache-control payload string per attribute, deduplicated.
///
/// Each Attribute must be an ArrayAttr of three IntegerAttrs:
/// [SPIR-V decoration token, cache level, cache control value].
/// A single entry renders as e.g. {6442:"0,1"}; the embedded quote (0x22)
/// later prints as \22 in LLVM IR textual form. Malformed entries are
/// silently skipped.
static SmallVector<std::string>
buildCacheControlPayloads(ArrayRef<Attribute> attrs) {
  SmallVector<std::string> result;
  llvm::StringMap<bool> emitted;
  for (Attribute attr : attrs) {
    auto triple = dyn_cast<ArrayAttr>(attr);
    if (!triple)
      continue;
    ArrayRef<Attribute> elems = triple.getValue();
    assert(elems.size() == 3 &&
           "Expected exactly 3 integer values (Token, CacheLevel, "
           "ControlValue) in cache control attribute.");
    auto token = dyn_cast<IntegerAttr>(elems[0]);
    auto level = dyn_cast<IntegerAttr>(elems[1]);
    auto control = dyn_cast<IntegerAttr>(elems[2]);
    if (!token || !level || !control)
      continue;
    // Produce: {SPIR-V decoration token:"L1 cache control,L3 cache control"}
    std::string payload =
        llvm::formatv("{{{0}:\"{1},{2}\"}", token.getValue().getZExtValue(),
                      level.getValue().getZExtValue(),
                      control.getValue().getZExtValue());
    // Emit each distinct annotation only once.
    if (emitted.insert({payload, true}).second)
      result.push_back(std::move(payload));
  }
  return result;
}
/// Counter for generating unique global variable names. Atomic fetch_add
/// keeps the generated names unique even under concurrent increments.
static std::atomic<uint64_t> globalNameCounter{0};
/// Get or create a global metadata string and return a !llvm.ptr<1> value
/// pointing to it. The AddressOfOp is created at the current rewriter
/// insertion point; the GlobalOp is created at the module start.
///
/// \param rewriter The pattern rewriter.
/// \param moduleOp The enclosing module whose first block holds the globals.
/// \param loc Source location for created ops.
/// \param value The string content (a trailing NUL is appended here).
/// \param nameHint Prefix for the generated global symbol name.
static Value createMetadataStringPtr(ConversionPatternRewriter &rewriter,
                                     Operation *moduleOp, Location loc,
                                     StringRef value, StringRef nameHint) {
  // Build null-terminated string.
  std::string strWithNull = value.str();
  strWithNull.push_back('\0');
  StringRef strRef(strWithNull.data(), strWithNull.size());
  auto as1PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
  // Search for an existing global with the same content so identical
  // payloads share one global instead of duplicating it.
  for (auto &op : moduleOp->getRegion(0).front()) {
    if (auto existingGlobal = dyn_cast<LLVM::GlobalOp>(&op)) {
      // Only reuse globals this helper created (llvm.metadata section).
      if (!existingGlobal.getSection() ||
          *existingGlobal.getSection() != "llvm.metadata")
        continue;
      if (auto strAttr =
              dyn_cast_or_null<StringAttr>(existingGlobal.getValueOrNull())) {
        if (strAttr.getValue() == strRef) {
          return LLVM::AddressOfOp::create(rewriter, loc, as1PtrTy,
                                           existingGlobal.getSymName());
        }
      }
    }
  }
  // Create new global at module start.
  auto i8Type = rewriter.getI8Type();
  auto arrayType = LLVM::LLVMArrayType::get(i8Type, strWithNull.size());
  // globalNameCounter makes each generated symbol name unique.
  std::string globalName =
      llvm::formatv("{0}.{1}", nameHint,
                    globalNameCounter.fetch_add(1, std::memory_order_relaxed))
          .str();
  {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&moduleOp->getRegion(0).front());
    auto globalOp =
        LLVM::GlobalOp::create(rewriter, loc, arrayType,
                               /*isConstant=*/true, LLVM::Linkage::Private,
                               globalName, rewriter.getStringAttr(strRef));
    globalOp.setSection(StringRef("llvm.metadata"));
    globalOp.setUnnamedAddr(LLVM::UnnamedAddr::Global);
    globalOp.setAlignment(1);
    globalOp.setAddrSpace(1);
  }
  // InsertionGuard restores the original insertion point here.
  return LLVM::AddressOfOp::create(rewriter, loc, as1PtrTy, globalName);
}
/// Attach cache-control metadata to \p ptr by emitting chained
/// `llvm.intr.ptr.annotation` ops (LLVM::PtrAnnotation).
///
/// This is the MLIR-level counterpart of handleDecorationCacheControl()
/// from the LLVM translation layer. Each payload string produced from
/// \p cacheControls becomes one annotation op:
///
///   %ann = llvm.intr.ptr.annotation %ptr, @".str.cachecontrol.N",
///          @".str.file.N", 0, null : !llvm.ptr<AS>
///
/// and the result of each annotation feeds the next one as its pointer
/// operand, forming a chain.
///
/// \param rewriter The pattern rewriter.
/// \param loc Source location for created ops.
/// \param ptr The pointer value to annotate.
/// \param cacheControls The cache control ArrayAttr (from
/// getCacheControlMetadata).
/// \param moduleOp The enclosing module (for creating globals).
/// \returns The annotated pointer (or \p ptr unchanged when there is
/// nothing to annotate).
static Value annotatePtrWithCacheControl(ConversionPatternRewriter &rewriter,
                                         Location loc, Value ptr,
                                         ArrayAttr cacheControls,
                                         Operation *moduleOp) {
  SmallVector<std::string> annotations =
      buildCacheControlPayloads(cacheControls.getValue());
  if (annotations.empty())
    return ptr;
  auto annotatedPtrTy = cast<LLVM::LLVMPointerType>(ptr.getType());
  auto metadataPtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
  auto i32Ty = rewriter.getI32Type();
  // Constants shared by every annotation in the chain.
  Value fileStr =
      createMetadataStringPtr(rewriter, moduleOp, loc, "", ".str.file");
  Value lineNo = LLVM::ConstantOp::create(rewriter, loc, i32Ty, 0);
  Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, metadataPtrTy);
  Value chained = ptr;
  for (const std::string &annotation : annotations) {
    Value annStr = createMetadataStringPtr(rewriter, moduleOp, loc, annotation,
                                           ".str.cachecontrol");
    chained = LLVM::PtrAnnotation::create(rewriter, loc, annotatedPtrTy,
                                          chained, annStr, fileStr, lineNo,
                                          nullPtr)
                  .getResult();
  }
  return chained;
}
/// Helper to apply cache control annotation on a pointer operand of a call.
/// Replaces the pointer argument of the call with an annotated version.
///
/// For operations that produce a call (like block load/store/prefetch), the
/// pointer is typically the first argument. This function:
/// 1. Builds the annotation chain on the pointer.
/// 2. Replaces the pointer operand in the provided args list.
///
/// \param rewriter The pattern rewriter.
/// \param loc Source location.
/// \param op The op whose cache-control metadata is queried.
/// \param args The argument list (modified in place: args[ptrIdx] is
/// replaced).
/// \param moduleOp The enclosing module.
/// \param ptrIdx Index of the pointer in the args list (default 0).
template <typename OpType>
static void
applyCacheControlAnnotation(ConversionPatternRewriter &rewriter, Location loc,
                            OpType op, SmallVectorImpl<Value> &args,
                            Operation *moduleOp, unsigned ptrIdx = 0) {
  std::optional<ArrayAttr> optCacheControls =
      getCacheControlMetadata(rewriter, op);
  // No cache control (or USE_DEFAULT): leave args untouched.
  if (!optCacheControls)
    return;
  Value annotatedPtr = annotatePtrWithCacheControl(rewriter, loc, args[ptrIdx],
                                                   *optCacheControls, moduleOp);
  args[ptrIdx] = annotatedPtr;
}
//===----------------------------------------------------------------------===//
// End cache control annotation utilities
//===----------------------------------------------------------------------===//
/// Look up or declare the SPIR_FUNC device function \p funcName and emit a
/// call to it with \p args at \p op's location.
///
/// Function-level attributes from \p funcAttributeOptions and the unit
/// parameter attributes in \p paramAttrs (index, attribute-name pairs) are
/// set on the declaration; the call site then mirrors the function's
/// attributes.
static LLVM::CallOp createDeviceFunctionCall(
    ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
    ArrayRef<Type> argTypes, ArrayRef<Value> args,
    mlir::ArrayRef<std::pair<unsigned, mlir::StringRef>> paramAttrs,
    LLVMFuncAttributeOptions funcAttributeOptions, Operation *op) {
  auto *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
  assert(moduleOp && "Expecting module");
  Location loc = op->getLoc();
  auto funcOpRes =
      LLVM::lookupOrCreateFn(rewriter, moduleOp, funcName, argTypes, retType);
  assert(!failed(funcOpRes));
  LLVM::LLVMFuncOp funcOp = funcOpRes.value();
  funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
  funcOp.setConvergent(funcAttributeOptions.isConvergent);
  funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind);
  funcOp.setWillReturn(funcAttributeOptions.isWillReturn);
  if (funcAttributeOptions.memEffectsAttr)
    funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr);
  // Unit attributes (e.g. nonnull, readonly) on individual parameters.
  for (auto [idx, attrName] : paramAttrs)
    funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr());
  auto callOp = LLVM::CallOp::create(rewriter, loc, funcOp, args);
  // Propagate the declaration's attributes onto the call site.
  callOp->setAttrs(funcOp->getAttrs());
  return callOp;
}
/// Number of matrix elements packed into one 32-bit dword for a given
/// element type (i.e. 32 / element-bit-width).
static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
  switch (pTy) {
  case xevm::ElemType::F32:
  case xevm::ElemType::TF32:
    return 1;
  case xevm::ElemType::BF16:
  case xevm::ElemType::F16:
    return 2;
  case xevm::ElemType::U8:
  case xevm::ElemType::S8:
  case xevm::ElemType::BF8:
  case xevm::ElemType::F8:
    return 4;
  case xevm::ElemType::E2M1:
  case xevm::ElemType::U4:
  case xevm::ElemType::S4:
    return 8;
  default:
    llvm_unreachable("unsupported xevm::ElemType");
  }
}
/// Lowers xevm.mma to a call to the corresponding
/// `intel_sub_group_<a>_<b>_matrix_mad_k<N>` OCL builtin.
class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The OCL builtin always takes an accumulator, so C is mandatory.
    if (!op.getC()) {
      return rewriter.notifyMatchFailure(op, "OCL requires C operand");
    }
    auto precisionA = op.getTypes().getA();
    auto precisionB = op.getTypes().getB();
    auto precisionC = op.getTypes().getC();
    auto precisionD = op.getTypes().getD();
    if (precisionC != precisionD) {
      return rewriter.notifyMatchFailure(op, "type of C and D need to match");
    }
    if (precisionC != xevm::ElemType::S32 &&
        precisionC != xevm::ElemType::F32 &&
        precisionC != xevm::ElemType::F16 &&
        precisionC != xevm::ElemType::BF16) {
      return rewriter.notifyMatchFailure(
          op, "type of C and D must be S32, F32, F16 or BF16");
    }
    if (precisionA == xevm::ElemType::S32 ||
        precisionA == xevm::ElemType::F32) {
      return rewriter.notifyMatchFailure(op, "type of A cannot be S32 or F32");
    }
    if (precisionB == xevm::ElemType::S32 ||
        precisionB == xevm::ElemType::F32) {
      return rewriter.notifyMatchFailure(op, "type of B cannot be S32 or F32");
    }
    // A operands are repacked into 16-bit lanes and B operands into 32-bit
    // lanes below (except TF32, which stays f32).
    constexpr uint32_t bitWidthPackedA{16};
    constexpr uint32_t bitWidthPackedB{32};
    auto loc = op.getLoc();
    // Bitcast a vector to the equivalent vector over the packed element
    // type if the types differ; total bit size is preserved.
    auto castIfNeeded = [&](Value val, Type packedType) -> Value {
      VectorType origTy = cast<VectorType>(val.getType());
      const uint32_t vecBitSize =
          origTy.getNumElements() *
          origTy.getElementType().getIntOrFloatBitWidth();
      VectorType newTy = VectorType::get(
          vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
      if (origTy != newTy)
        val = LLVM::BitcastOp::create(rewriter, loc, newTy, val);
      return val;
    };
    Value a = op.getA();
    Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
                           ? cast<Type>(rewriter.getF32Type())
                           : rewriter.getIntegerType(bitWidthPackedA);
    a = castIfNeeded(a, packedAType);
    Value b = op.getB();
    Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
                           ? cast<Type>(rewriter.getF32Type())
                           : rewriter.getIntegerType(bitWidthPackedB);
    b = castIfNeeded(b, packedBType);
    Value c = op.getC();
    VectorType cOrigTy = cast<VectorType>(c.getType());
    VectorType resOrigTy = cast<VectorType>(op->getResultTypes()[0]);
    assert(cOrigTy == resOrigTy && "Accumulator and result type mismatch");
    // OCL builtins encode bfloat16 as int16
    VectorType cTy =
        cOrigTy.getElementType().isBF16()
            ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
            : cOrigTy;
    VectorType resTy = cTy;
    if (cOrigTy != cTy)
      c = LLVM::BitcastOp::create(rewriter, loc, cTy, c);
    // k = systolic depth * elements per dword of the A operand type.
    constexpr int32_t systolicDepth{8};
    std::string fnName =
        llvm::formatv("intel_sub_group_{0}_{1}_matrix_mad_k{2}",
                      stringifyElemType(op.getTypes().getA()).str(),
                      stringifyElemType(op.getTypes().getB()).str(),
                      systolicDepth *
                          getNumOperandsPerDword(op.getTypes().getA()))
            .str();
    SmallVector<Type> argTypes{a.getType(), b.getType(), cTy};
    fnName = mangle(fnName, argTypes);
    SmallVector<Value> args{a, b, c};
    // The builtin is a pure computation: no memory effects whatsoever.
    auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
        /*other=*/LLVM::ModRefInfo::NoModRef,
        /*argMem=*/LLVM::ModRefInfo::NoModRef,
        /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
        /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
        /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
        /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
    auto funcAttrs = convergentNoUnwindWillReturnAttrs;
    funcAttrs.memEffectsAttr = memAttr;
    Value result =
        createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {},
                                 funcAttrs, op.getOperation())
            ->getResult(0);
    // Undo the bf16 -> i16 accumulator bitcast on the result if applied.
    if (resOrigTy != resTy)
      result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy, result);
    rewriter.replaceOp(op, result);
    return success();
  }
};
/// Lowers xevm.prefetch to a call to the OCL `prefetch` builtin, with the
/// pointer optionally annotated with cache-control metadata.
class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(PrefetchOp op, PrefetchOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
    // Pre-mangled name: prefetch(const char AS1 *, size_t).
    const std::string fnName{"_Z8prefetchPU3AS1Kcm"};
    // Second builtin argument (count) is the constant 1.
    Value one =
        LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), 1);
    SmallVector<Value> args{op.getPtr(), one};
    // Annotate pointer with cache control before passing to the call.
    applyCacheControlAnnotation(rewriter, loc, op, args, moduleOp,
                                /*ptrIdx=*/0);
    SmallVector<Type> argTypes;
    for (auto arg : args)
      argTypes.push_back(arg.getType());
    auto funcAttr = noUnwindAttrs;
    // The builtin only reads through its pointer argument.
    auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
        /*other=*/LLVM::ModRefInfo::NoModRef,
        /*argMem=*/LLVM::ModRefInfo::Ref,
        /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
        /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
        /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
        /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
    funcAttr.memEffectsAttr = memAttr;
    createDeviceFunctionCall(rewriter, fnName,
                             LLVM::LLVMVoidType::get(rewriter.getContext()),
                             argTypes, args, {}, funcAttr, op.getOperation());
    rewriter.eraseOp(op);
    return success();
  }
};
/// Lowers xevm.memfence to the OCL `atomic_work_item_fence` builtin with
/// acquire-release ordering. Only global/shared address spaces and
/// workgroup/device scopes are representable in OpenCL.
class MemfenceToOCLPattern : public OpConversionPattern<MemfenceOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(MemfenceOp op, MemfenceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    const std::string fnName{"atomic_work_item_fence"};
    // Map the XeVM address space onto the OpenCL memory-fence flag.
    int addrSpace;
    if (op.getAddrspace() == xevm::AddrSpace::SHARED) {
      addrSpace = 1; // CLK_LOCAL_MEM_FENCE
    } else if (op.getAddrspace() == xevm::AddrSpace::GLOBAL) {
      addrSpace = 2; // CLK_GLOBAL_MEM_FENCE
    } else {
      // GENERIC is not supported in OpenCL
      return rewriter.notifyMatchFailure(
          op, "Fence only supports global and shared address spaces.");
    }
    // Map the XeVM memory scope onto the OpenCL scope value.
    int memScope;
    if (op.getScope() == xevm::MemScope::WORKGROUP) {
      memScope = 1;
    } else if (op.getScope() == xevm::MemScope::DEVICE) {
      memScope = 2;
    } else {
      // CLUSTER and SYSTEM are not supported in OpenCL
      return rewriter.notifyMatchFailure(
          op, "Fence only supports workgroup and device memory scopes.");
    }
    Type i32Type = rewriter.getI32Type();
    Value acqRel = LLVM::ConstantOp::create(rewriter, loc, i32Type, 4);
    Value scopeVal = LLVM::ConstantOp::create(rewriter, loc, i32Type, memScope);
    Value spaceVal =
        LLVM::ConstantOp::create(rewriter, loc, i32Type, addrSpace);
    SmallVector<Value> args{spaceVal, acqRel, scopeVal};
    SmallVector<Type> argTypes{3, i32Type};
    createDeviceFunctionCall(rewriter, mangle(fnName, argTypes),
                             LLVM::LLVMVoidType::get(rewriter.getContext()),
                             argTypes, args, {}, noUnwindAttrs,
                             op.getOperation());
    rewriter.eraseOp(op);
    return success();
  }
};
/// Lowers 2D block load/store/prefetch ops to the matching
/// `intel_sub_group_2d_block_{read,write,prefetch}` OCL builtin call.
/// Loads and stores go through a stack slot because the builtins take a
/// destination/source pointer rather than returning/consuming a vector.
template <typename OpType>
class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp>;
    constexpr bool isPrefetch = std::is_same_v<OpType, BlockPrefetch2dOp>;
    auto loc = op.getLoc();
    auto *moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
    VectorType vecType;
    bool packReg = false;
    bool transpose = false;
    if constexpr (isLoad) {
      vecType = op.getRes().getType();
      packReg = op.getPackRegister();
      transpose = op.getTranspose();
    } else if constexpr (!isPrefetch) {
      vecType = op.getStoredVal().getType();
    }
    auto i32Type = rewriter.getI32Type();
    // Pack the (x, y) coordinates into a <2 x i32> builtin argument.
    Value byteCoord =
        LLVM::UndefOp::create(rewriter, loc, VectorType::get(2, i32Type));
    Value zero = LLVM::ConstantOp::create(rewriter, loc, i32Type, 0);
    Value one = LLVM::ConstantOp::create(rewriter, loc, i32Type, 1);
    byteCoord = LLVM::InsertElementOp::create(
        rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
    byteCoord = LLVM::InsertElementOp::create(
        rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
    SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
                            op.getBasePitch(), byteCoord};
    // Annotate pointer (args[0]) with cache control before the call.
    applyCacheControlAnnotation(rewriter, loc, op, args, moduleOp,
                                /*ptrIdx=*/0);
    SmallVector<Type> retTypes;
    Value spvLoadDstPtr;
    std::string funcName{"intel_sub_group_2d_block_"};
    std::string bitWidthId;
    LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
    SmallVector<std::pair<unsigned, StringRef>, 4> paramAttrs;
    if constexpr (isPrefetch) { // Prefetch
      funcName += "prefetch";
      paramAttrs = {std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName())};
      // Prefetch only reads through its pointer argument.
      auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
          /*other=*/LLVM::ModRefInfo::NoModRef,
          /*argMem=*/LLVM::ModRefInfo::Ref,
          /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
          /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
          /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
          /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
      funcAttr = noUnwindAttrs;
      funcAttr.memEffectsAttr = memAttr;
    } else {
      auto vecElemType = vecType.getElementType();
      auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth();
      Value numElems = LLVM::ConstantOp::create(rewriter, loc, i32Type,
                                                vecType.getNumElements());
      // Stack slot that the builtin writes to (load) or reads from (store).
      auto dstOrSrcPtr = LLVM::AllocaOp::create(
          rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()),
          vecElemType, numElems);
      args.push_back(dstOrSrcPtr);
      if constexpr (isLoad) { // Load
        funcName += "read";
        bitWidthId = getTypeMangling(vecElemType, /*isUnsigned=*/true);
        if (packReg)
          funcName += "_transform";
        else if (transpose)
          funcName += "_transpose";
        spvLoadDstPtr = dstOrSrcPtr;
        retTypes.push_back(vecType);
        paramAttrs = {
            std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
            std::make_pair(0, LLVM::LLVMDialect::getReadonlyAttrName()),
            std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
            std::make_pair(5, LLVM::LLVMDialect::getWriteOnlyAttrName()),
        };
      } else { // Store
        funcName += "write";
        bitWidthId = (vecElemBitWidth == 32)
                         ? "j"
                         : ((vecElemBitWidth == 16) ? "t" : "h");
        LLVM::StoreOp::create(rewriter, loc, op.getStoredVal(), dstOrSrcPtr);
        paramAttrs = {
            std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
            std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()),
            std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
            std::make_pair(5, LLVM::LLVMDialect::getReadonlyAttrName()),
        };
      }
    }
    // Encode tile geometry into the builtin name, then hand-mangle.
    funcName =
        llvm::formatv("{0}_{1}b_{2}r{3}x{4}c", funcName, op.getElemSizeInBits(),
                      op.getTileHeight(), op.getTileWidth(), op.getVBlocks())
            .str();
    std::string prefetchCode("");
    if (!isPrefetch)
      prefetchCode += "P";
    funcName = llvm::formatv("_Z{0}{1}PU3AS1viiiDv2_i{2}{3}", funcName.size(),
                             funcName, prefetchCode, bitWidthId)
                   .str();
    SmallVector<Type> argTypes;
    for (auto arg : args) {
      argTypes.push_back(arg.getType());
    }
    createDeviceFunctionCall(
        rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
        argTypes, args, paramAttrs, funcAttr, op.getOperation());
    // For loads, the result is read back from the stack slot.
    if constexpr (isLoad)
      rewriter.replaceOp(
          op, LLVM::LoadOp::create(rewriter, loc, vecType, spvLoadDstPtr));
    else
      rewriter.eraseOp(op);
    return success();
  }
};
/// Lowers 1D block load/store ops to the OCL
/// `intel_sub_group_block_{read,write}_u*` builtins
/// (cl_intel_subgroup_local_block_io).
template <typename OpType>
class BlockLoadStore1DToOCLPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    constexpr bool isStore = std::is_same_v<OpType, xevm::BlockStoreOp>;
    auto loc = op.getLoc();
    auto *moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
    // Get OpenCL function name
    // https://registry.khronos.org/OpenCL/extensions/
    // intel/cl_intel_subgroup_local_block_io.html
    std::string funcName{"intel_sub_group_block_"};
    // Value or Result type can be vector or scalar
    Type valOrResTy;
    if constexpr (isStore) {
      funcName += "write_u";
      valOrResTy = op.getVal().getType();
    } else {
      funcName += "read_u";
      valOrResTy = op.getType();
    }
    // Get element type of the vector/scalar
    VectorType vecTy = dyn_cast<VectorType>(valOrResTy);
    Type elemType = vecTy ? vecTy.getElementType() : valOrResTy;
    funcName += getTypeMangling(elemType);
    if (vecTy)
      funcName += std::to_string(vecTy.getNumElements());
    SmallVector<Type, 2> argTypes{};
    // XeVM BlockLoad/StoreOp always use signless integer types
    // but OpenCL builtins expect unsigned types
    // use unsigned types for mangling
    SmallVector<bool, 2> isUnsigned{};
    // arg0: pointer to the src/dst address
    // arg1 - only if store : vector to store
    // Prepare arguments
    SmallVector<Value, 2> args{};
    args.push_back(op.getPtr());
    argTypes.push_back(op.getPtr().getType());
    isUnsigned.push_back(true);
    // Annotate pointer (args[0]) with cache control.
    applyCacheControlAnnotation(rewriter, loc, op, args, moduleOp,
                                /*ptrIdx=*/0);
    // Update argTypes[0] in case the pointer type changed (it shouldn't
    // change type, but the value is now the annotated pointer).
    argTypes[0] = args[0].getType();
    Type retType;
    if constexpr (isStore) {
      args.push_back(op.getVal());
      argTypes.push_back(op.getVal().getType());
      isUnsigned.push_back(true);
      retType = LLVM::LLVMVoidType::get(rewriter.getContext());
    } else {
      retType = valOrResTy;
    }
    // Hand-assemble the mangled name:
    // _Z<len><name>PU3AS<addrspace><elem>[<value-type>].
    funcName = std::string("_Z") + std::to_string(funcName.size()) + funcName +
               "PU3AS" +
               std::to_string(op.getPtr().getType().getAddressSpace());
    funcName += getTypeMangling(elemType, /*isUnsigned=*/true);
    if constexpr (isStore)
      funcName += getTypeMangling(valOrResTy, /*isUnsigned=*/true);
    LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
    LLVM::CallOp call =
        createDeviceFunctionCall(rewriter, funcName, retType, argTypes, args,
                                 {}, funcAttr, op.getOperation());
    if constexpr (isStore)
      rewriter.eraseOp(op);
    else
      rewriter.replaceOp(op, call->getResult(0));
    return success();
  }
};
/// Consumes the temporary "cache_control" attribute on plain LLVM load/store
/// ops: when usable cache-control metadata is present, the pointer operand is
/// replaced by an annotated pointer; in all cases the attribute is removed so
/// the op becomes legal for the conversion target.
template <typename OpType>
class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only ops tagged with the temporary attribute are matched.
    if (!op->hasAttr("cache_control"))
      return failure();
    auto *moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
    std::optional<ArrayAttr> optCacheControls =
        getCacheControlMetadata(rewriter, op);
    // No usable metadata: just drop the attribute and report success (the
    // op was still modified in place).
    if (!optCacheControls) {
      rewriter.modifyOpInPlace(op, [&]() { op->removeAttr("cache_control"); });
      return success();
    }
    // Determine which operand is the pointer.
    constexpr bool isStore = std::is_same_v<OpType, LLVM::StoreOp>;
    unsigned ptrIdx = isStore ? 1 : 0;
    Value ptr = op->getOperand(ptrIdx);
    // Emit annotation intrinsic calls on the pointer.
    Value annotatedPtr = annotatePtrWithCacheControl(
        rewriter, op->getLoc(), ptr, *optCacheControls, moduleOp);
    // Replace the pointer operand with the annotated one and drop the
    // now-consumed attribute in one in-place modification.
    rewriter.modifyOpInPlace(op, [&]() {
      op->setOperand(ptrIdx, annotatedPtr);
      op->removeAttr("cache_control");
    });
    return success();
  }
};
//===----------------------------------------------------------------------===//
// GPU index id operations
//===----------------------------------------------------------------------===//
/*
// Launch Config ops
// dimidx - x, y, z - is fixed to i32
// return type is set by XeVM type converter
// get_local_id
xevm::WorkitemIdXOp;
xevm::WorkitemIdYOp;
xevm::WorkitemIdZOp;
// get_local_size
xevm::WorkgroupDimXOp;
xevm::WorkgroupDimYOp;
xevm::WorkgroupDimZOp;
// get_group_id
xevm::WorkgroupIdXOp;
xevm::WorkgroupIdYOp;
xevm::WorkgroupIdZOp;
// get_num_groups
xevm::GridDimXOp;
xevm::GridDimYOp;
xevm::GridDimZOp;
// get_global_id : to be added if needed
*/
// Helpers to get the OpenCL function name and dimension argument for each op.
// get_local_id: workitem id within the workgroup; dims 0/1/2 map to x/y/z.
static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdXOp) {
  return {"get_local_id", 0};
}
static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdYOp) {
  return {"get_local_id", 1};
}
static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdZOp) {
  return {"get_local_id", 2};
}
// get_local_size: workgroup size per dimension; dims 0/1/2 map to x/y/z.
static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimXOp) {
  return {"get_local_size", 0};
}
static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimYOp) {
  return {"get_local_size", 1};
}
static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimZOp) {
  return {"get_local_size", 2};
}
// get_group_id: workgroup id within the grid; dims 0/1/2 map to x/y/z.
static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdXOp) {
  return {"get_group_id", 0};
}
static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdYOp) {
  return {"get_group_id", 1};
}
static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdZOp) {
  return {"get_group_id", 2};
}
// get_num_groups: number of workgroups per dimension; dims 0/1/2 map to x/y/z.
static std::pair<StringRef, int64_t> getConfig(xevm::GridDimXOp) {
  return {"get_num_groups", 0};
}
static std::pair<StringRef, int64_t> getConfig(xevm::GridDimYOp) {
  return {"get_num_groups", 1};
}
static std::pair<StringRef, int64_t> getConfig(xevm::GridDimZOp) {
  return {"get_num_groups", 2};
}
/// Replaces an XeVM launch-config query op with an `llvm.call` to the
/// corresponding OpenCL builtin, passing the queried dimension (x=0, y=1,
/// z=2) as a constant i32 argument.
template <typename OpType>
class LaunchConfigOpToOCLPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto [baseName, dim] = getConfig(op);
    Location loc = op->getLoc();
    Type i32Ty = rewriter.getI32Type();
    // The OpenCL builtins take the dimension as a runtime argument.
    Value dimArg = LLVM::ConstantOp::create(rewriter, loc, i32Ty,
                                            static_cast<int64_t>(dim));
    std::string funcName = mangle(baseName, {i32Ty}, {true});
    auto call =
        createDeviceFunctionCall(rewriter, funcName, op.getType(), {i32Ty},
                                 {dimArg}, {}, noUnwindWillReturnAttrs,
                                 op.getOperation());
    // These queries touch no memory; mark the call side-effect free so it
    // can be CSE'd/hoisted later.
    constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
    call.setMemoryEffectsAttr(rewriter.getAttr<LLVM::MemoryEffectsAttr>(
        /*other=*/noModRef, /*argMem=*/noModRef,
        /*inaccessibleMem=*/noModRef, /*errnoMem=*/noModRef,
        /*targetMem0=*/noModRef, /*targetMem1=*/noModRef));
    rewriter.replaceOp(op, call);
    return success();
  }
};
/*
// Subgroup ops
// get_sub_group_local_id
xevm::LaneIdOp;
// get_sub_group_id
xevm::SubgroupIdOp;
// get_sub_group_size
xevm::SubgroupSizeOp;
// get_num_sub_groups : to be added if needed
*/
// Helpers to get the OpenCL function name for each subgroup query op; these
// builtins take no arguments (unlike the launch-config builtins above).
static StringRef getConfig(xevm::LaneIdOp) { return "get_sub_group_local_id"; }
static StringRef getConfig(xevm::SubgroupIdOp) { return "get_sub_group_id"; }
static StringRef getConfig(xevm::SubgroupSizeOp) {
  return "get_sub_group_size";
}
/// Replaces an XeVM subgroup query op (lane id, subgroup id, subgroup size)
/// with an `llvm.call` to the corresponding zero-argument OpenCL builtin.
template <typename OpType>
class SubgroupOpWorkitemOpToOCLPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    std::string funcName = mangle(getConfig(op).str(), {});
    auto call = createDeviceFunctionCall(rewriter, funcName, op.getType(), {},
                                         {}, {}, noUnwindWillReturnAttrs,
                                         op.getOperation());
    // These queries touch no memory; mark the call side-effect free.
    constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
    call.setMemoryEffectsAttr(rewriter.getAttr<LLVM::MemoryEffectsAttr>(
        /*other=*/noModRef, /*argMem=*/noModRef,
        /*inaccessibleMem=*/noModRef, /*errnoMem=*/noModRef,
        /*targetMem0=*/noModRef, /*targetMem1=*/noModRef));
    rewriter.replaceOp(op, call);
    return success();
  }
};
/// Lowers xevm.truncf. Currently only the f16 -> bf8 truncation of
/// 16-element vectors is supported, implemented purely with shuffles and
/// bitcasts (no OpenCL builtin call is needed for this case).
class TruncfToOCLPattern : public OpConversionPattern<TruncfOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(TruncfOp op, TruncfOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Supported source and result types are restricted for now.
    auto srcEtype = op.getSrcEtype().getEtype();
    auto dstEtype = op.getDstEtype().getEtype();
    if (auto vecSrcTy = dyn_cast<VectorType>(op.getSrc().getType())) {
      if (vecSrcTy.getNumElements() != 16)
        return rewriter.notifyMatchFailure(
            op, "Only vector src of 16 elements is supported");
    } else {
      return rewriter.notifyMatchFailure(op, "Scalar src is not supported.");
    }
    if (auto vecDstTy = dyn_cast<VectorType>(op.getDst().getType())) {
      if (vecDstTy.getNumElements() != 16)
        return rewriter.notifyMatchFailure(
            op, "Only vector dst of 16 elements is supported");
    } else {
      return rewriter.notifyMatchFailure(op, "Scalar dst is not supported.");
    }
    if (srcEtype == TruncfSrcElemTypes::F16 &&
        dstEtype == TruncfDstElemTypes::BF8) {
      // BF8 is just F16 with the lower 8 bits of the mantissa discarded:
      //      Sign bit  Exponent  Mantissa
      // BF8      1        5         2
      // F16      1        5        10
      // Xe arch is little endian, so a BF8 value is just the second byte of
      // the two-byte representation used for F16.
      // Split the 16-element source into two 8-element halves.
      auto firstHalf =
          LLVM::ShuffleVectorOp::create(rewriter, op.getLoc(), op.getSrc(),
                                        op.getSrc(), {0, 1, 2, 3, 4, 5, 6, 7});
      auto secondHalf = LLVM::ShuffleVectorOp::create(
          rewriter, op.getLoc(), op.getSrc(), op.getSrc(),
          {8, 9, 10, 11, 12, 13, 14, 15});
      // Reinterpret each 8 x f16 half as 16 bytes.
      auto firstHalfCasted = LLVM::BitcastOp::create(
          rewriter, op.getLoc(), VectorType::get(16, rewriter.getI8Type()),
          firstHalf);
      auto secondHalfCasted = LLVM::BitcastOp::create(
          rewriter, op.getLoc(), VectorType::get(16, rewriter.getI8Type()),
          secondHalf);
      // Gather just the second byte from every two-byte F16 value.
      auto resFirstHalf = LLVM::ShuffleVectorOp::create(
          rewriter, op.getLoc(), firstHalfCasted, firstHalfCasted,
          {1, 3, 5, 7, 9, 11, 13, 15});
      auto resSecondHalf = LLVM::ShuffleVectorOp::create(
          rewriter, op.getLoc(), secondHalfCasted, secondHalfCasted,
          {1, 3, 5, 7, 9, 11, 13, 15});
      // Concatenate the two 8-byte halves into the 16-element result.
      auto res = LLVM::ShuffleVectorOp::create(
          rewriter, op.getLoc(), resFirstHalf, resSecondHalf,
          {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
      rewriter.replaceOp(op, res);
    } else {
      return rewriter.notifyMatchFailure(
          op, "Unsupported src, dst element type pair.");
    }
    return success();
  }
};
/// Lowers xevm.mma.mx (block-scaled matrix multiply-add) to a call of the
/// matching `__builtin_IB_sub_group16_bdpas_*` intrinsic. The accumulator C
/// is mandatory and its element type must match the result D.
class MMAMxToOCLPattern : public OpConversionPattern<MMAMxOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(MMAMxOp op, MMAMxOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getC()) {
      return rewriter.notifyMatchFailure(op, "OCL requires C operand");
    }
    auto precisionC = op.getTypes().getC();
    auto precisionD = op.getTypes().getD();
    if (precisionC != precisionD) {
      return rewriter.notifyMatchFailure(op, "type of C and D need to match");
    }
    // The builtin expects A packed into 16-bit lanes and B into 32-bit
    // lanes, except TF32 operands which stay f32.
    constexpr uint32_t bitWidthPackedA{16};
    constexpr uint32_t bitWidthPackedB{32};
    auto loc = op.getLoc();
    // Bitcast `val` to a same-total-bit-width vector of `packedType`, if it
    // is not already of that type.
    auto castIfNeeded = [&](Value val, Type packedType) -> Value {
      VectorType origTy = cast<VectorType>(val.getType());
      const uint32_t vecBitSize =
          origTy.getNumElements() *
          origTy.getElementType().getIntOrFloatBitWidth();
      VectorType newTy = VectorType::get(
          vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
      if (origTy != newTy)
        val = LLVM::BitcastOp::create(rewriter, loc, newTy, val);
      return val;
    };
    Value a = op.getA();
    Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
                           ? cast<Type>(rewriter.getF32Type())
                           : rewriter.getIntegerType(bitWidthPackedA);
    a = castIfNeeded(a, packedAType);
    Value b = op.getB();
    Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
                           ? cast<Type>(rewriter.getF32Type())
                           : rewriter.getIntegerType(bitWidthPackedB);
    b = castIfNeeded(b, packedBType);
    Value c = op.getC();
    VectorType cOrigTy = cast<VectorType>(c.getType());
    VectorType resOrigTy = cast<VectorType>(op->getResultTypes()[0]);
    assert(cOrigTy == resOrigTy && "Accumulator and result type mismatch");
    // OCL builtins encode bfloat16 as int16
    VectorType cTy =
        cOrigTy.getElementType().isBF16()
            ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
            : cOrigTy;
    VectorType resTy = cTy;
    if (cOrigTy != cTy)
      c = LLVM::BitcastOp::create(rewriter, loc, cTy, c);
    // Builtin name encodes D, C, A and B element types; shape is fixed 8x8.
    std::string fnName =
        llvm::formatv("__builtin_IB_sub_group16_bdpas_{0}_{1}_{2}_{3}_8_8",
                      builtinElemType(op.getTypes().getD()),
                      builtinElemType(op.getTypes().getC()),
                      builtinElemType(op.getTypes().getA()),
                      builtinElemType(op.getTypes().getB()))
            .str();
    auto scaleA = op.getScaleA();
    auto scaleB = op.getScaleB();
    SmallVector<Type> argTypes{cTy, a.getType(), b.getType(), scaleA.getType(),
                               scaleB.getType()};
    SmallVector<Value> args{c, a, b, scaleA, scaleB};
    // The builtin is convergent but neither reads nor writes memory.
    auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
        /*other=*/LLVM::ModRefInfo::NoModRef,
        /*argMem=*/LLVM::ModRefInfo::NoModRef,
        /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
        /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
        /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
        /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
    auto funcAttrs = convergentNoUnwindWillReturnAttrs;
    funcAttrs.memEffectsAttr = memAttr;
    Value result =
        createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {},
                                 funcAttrs, op.getOperation())
            ->getResult(0);
    // Cast back if the accumulator was re-encoded (bf16 <-> i16).
    if (resOrigTy != resTy)
      result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy, result);
    rewriter.replaceOp(op, result);
    return success();
  }
};
/// Rewrites `llvm.alloca` in address space 3 (OpenCL Workgroup memory) into
/// an internal `llvm.mlir.global` at module scope plus an
/// `llvm.mlir.addressof`, since the SPIR-V backend does not handle such
/// allocas. Only constant array sizes are supported.
class AllocaToGlobalPattern : public OpConversionPattern<LLVM::AllocaOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LLVM::AllocaOp op, LLVM::AllocaOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto ptrType = cast<LLVM::LLVMPointerType>(op.getType());
    auto addrSpace = ptrType.getAddressSpace();
    // Only workgroup (addrspace 3) allocas need this rewrite.
    if (addrSpace != 3)
      return failure();
    // The global is created in the nearest enclosing symbol table, which
    // must be a builtin.module or a gpu.module.
    auto symTable = op->getParentWithTrait<OpTrait::SymbolTable>();
    if (!symTable)
      return failure();
    Block *moduleBody;
    if (ModuleOp mod = dyn_cast<ModuleOp>(*symTable)) {
      moduleBody = mod.getBody();
    } else if (gpu::GPUModuleOp gpuMod =
                   dyn_cast<gpu::GPUModuleOp>(*symTable)) {
      moduleBody = gpuMod.getBody();
    } else {
      return failure();
    }
    // The array size must be a compile-time constant to size the global.
    auto val = op.getArraySize();
    APInt cst;
    if (!matchPattern(val, m_ConstantInt(&cst)))
      return failure();
    auto loc = op.getLoc();
    auto globalType = LLVM::LLVMArrayType::get(
        rewriter.getContext(), op.getElemType(), cst.getZExtValue());
    LLVM::GlobalOp globalVar;
    {
      // Temporarily move the insertion point to module scope to create the
      // global; the guard restores it afterwards.
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(moduleBody);
      auto alignment = op.getAlignment();
      globalVar = LLVM::GlobalOp::create(
          rewriter, loc, globalType, /*isConstant=*/false,
          /*linkage=*/LLVM::Linkage::Internal,
          /*name=*/std::string("__global_alloca_") +
              std::to_string(getNextGlobalIdx()),
          /*value=*/Attribute(),
          /*alignment=*/alignment ? *alignment : 0, /*addrSpace=*/addrSpace);
    }
    rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, globalVar);
    return success();
  }
private:
  // Monotonically increasing suffix keeping generated global names unique.
  // NOTE(review): plain function-local static (not atomic) — this assumes
  // the pattern never runs concurrently from multiple threads; confirm.
  static unsigned getNextGlobalIdx() {
    static unsigned globalIdx = 0;
    return globalIdx++;
  }
};
// Checks if a shufflevector is used as a way to extract a contiguous slice
// from a vector:
// - source vectors V1 and V2 are the same vector,
// - the mask is not larger than the source vector,
// - the mask values are a sequence of consecutive increasing indices that
//   all stay in bounds of the source vector when used for indexing. This
//   also rejects undef lanes (encoded as -1) and indices addressing the
//   second source copy.
static bool isExtractingContiguousSlice(LLVM::ShuffleVectorOp op) {
  if (op.getV1() != op.getV2())
    return false;
  auto maskAttr = op.getMask();
  int64_t maskSize = static_cast<int64_t>(maskAttr.size());
  int64_t sourceSize = op.getV1().getType().getNumElements();
  if (maskSize > sourceSize)
    return false;
  int64_t firstIndex = maskAttr[0];
  // Bounds-check the first index explicitly: the loop below only checks
  // indices from position 1 onwards, so without this a single-element mask
  // of {-1} (undef lane) or an index >= sourceSize would be accepted and
  // later be used as a GEP offset / array index downstream.
  if (firstIndex < 0 || firstIndex >= sourceSize)
    return false;
  for (int64_t i = 1; i < maskSize; ++i) {
    int64_t index = maskAttr[i];
    if (index != firstIndex + i)
      return false;
    if (index >= sourceSize)
      return false;
  }
  return true;
}
// Input vector of a shuffle vector op extracting a contiguous slice is an
// illegal vector in a SPIR-V kernel if the vector size is > 16 elements.
// To legalize this case, keep applying the following transformations until no
// more match:
// 1. keep hoisting the shuffle vector op past unary element-wise operations;
//    fpext, fptrunc and bitcast are handled for now.
// 2. merge with another shuffle vector op
// 3. merge with load as a smaller load
class HandleVectorExtractPattern
    : public OpRewritePattern<LLVM::ShuffleVectorOp> {
  using OpRewritePattern<LLVM::ShuffleVectorOp>::OpRewritePattern;
  // The pattern re-creates shuffle ops it may match again; declare the
  // recursion as bounded so the greedy driver accepts it.
  void initialize() { setHasBoundedRewriteRecursion(); }
  LogicalResult matchAndRewrite(LLVM::ShuffleVectorOp op,
                                PatternRewriter &rewriter) const override {
    if (!isExtractingContiguousSlice(op))
      return failure();
    auto mask = op.getMask();
    auto loc = op.getLoc();
    auto ty = op.getType();
    // Check source operand to determine rewrite pattern.
    auto src = op.getV1();
    // 1. Hoist past unary element-wise operations.
    if (auto srcOp = src.getDefiningOp()) {
      if (isa<LLVM::FPExtOp>(srcOp) || isa<LLVM::FPTruncOp>(srcOp)) {
        Value srcInput = srcOp->getOperand(0);
        // Create new shuffle vector op with the unary op's input as source.
        auto srcVecTy = dyn_cast<VectorType>(srcInput.getType());
        auto newShuffleVecTy =
            VectorType::get(mask.size(), srcVecTy.getElementType());
        auto newShuffle = LLVM::ShuffleVectorOp::create(
            rewriter, loc, newShuffleVecTy, srcInput, srcInput, mask);
        // Create new unary op with the new shuffle as input.
        Value newUnaryOp;
        if (isa<LLVM::FPExtOp>(srcOp)) {
          newUnaryOp = LLVM::FPExtOp::create(rewriter, loc, ty, newShuffle);
        } else {
          newUnaryOp = LLVM::FPTruncOp::create(rewriter, loc, ty, newShuffle);
        }
        rewriter.replaceOp(op, newUnaryOp);
      } else if (isa<LLVM::BitcastOp>(srcOp)) {
        Value srcInput = srcOp->getOperand(0);
        // Hoist past the bitcast: extract from the bitcast's input, then
        // bitcast the (smaller) extracted slice.
        auto srcInputVecTy = dyn_cast<VectorType>(srcInput.getType());
        auto srcInputSize = srcInputVecTy.getNumElements();
        auto srcResVecTy = dyn_cast<VectorType>(srcOp->getResult(0).getType());
        auto srcResSize = srcResVecTy.getNumElements();
        auto maskSize = static_cast<int32_t>(mask.size());
        // Only bitcasts to the same or a larger, evenly divisible element
        // count (same or narrower element type) are supported.
        if (srcInputSize > srcResSize) {
          return failure();
        }
        if (srcResSize % srcInputSize != 0) {
          return failure();
        }
        auto maskScale = srcResSize / srcInputSize;
        // Storage for the rescaled mask. Declared in this scope (not inside
        // the `if` below) so that `mask`, which is an ArrayRef re-bound to
        // it, stays valid when the new shuffle op is created afterwards.
        SmallVector<int32_t> newMask;
        if (maskScale != 1) {
          // The slice must start on a boundary of the wider input elements.
          if (mask[0] % maskScale != 0) {
            return failure();
          }
          // Create a new mask that maps to the source vector.
          int32_t newMaskSize = maskSize / maskScale;
          int32_t maskStart = mask[0] / maskScale;
          for (int32_t i = 0; i < newMaskSize; ++i) {
            newMask.push_back(maskStart + i);
          }
          mask = newMask;
        }
        // The shuffle result type must have exactly as many elements as the
        // mask has entries (using the input size here would build invalid
        // IR whenever the extracted slice is smaller than the input).
        auto newShuffleVecTy =
            VectorType::get(mask.size(), srcInputVecTy.getElementType());
        auto newShuffle = LLVM::ShuffleVectorOp::create(
            rewriter, loc, newShuffleVecTy, srcInput, srcInput, mask);
        // Bitcast the extracted slice to the original result type.
        auto newBitcast =
            LLVM::BitcastOp::create(rewriter, loc, ty, newShuffle);
        rewriter.replaceOp(op, newBitcast);
      } else if (isa<LLVM::ShuffleVectorOp>(srcOp)) {
        // 2. Merge with the source shuffle vector op if the source op is
        // also extracting a contiguous slice, and create a new
        // shuffle vector op directly from the source of
        // the first shuffle.
        auto srcShuffle = cast<LLVM::ShuffleVectorOp>(srcOp);
        if (!isExtractingContiguousSlice(srcShuffle))
          return failure();
        auto srcMask = srcShuffle.getMask();
        SmallVector<int32_t> combinedMask;
        for (auto index : mask) {
          combinedMask.push_back(srcMask[index]);
        }
        auto newShuffle = LLVM::ShuffleVectorOp::create(
            rewriter, loc, ty, srcShuffle.getV1(), srcShuffle.getV1(),
            DenseI32ArrayAttr::get(rewriter.getContext(), combinedMask));
        rewriter.replaceOp(op, newShuffle);
      } else if (isa<LLVM::LoadOp>(srcOp)) {
        // 3. Merge with the load, loading only the extracted slice.
        auto loadOp = cast<LLVM::LoadOp>(srcOp);
        auto loadPtr = loadOp.getAddr();
        auto loadAddrSpace = loadPtr.getType().getAddressSpace();
        // Restricted to address space 0 for now.
        if (loadAddrSpace != 0)
          return failure();
        auto loadTy = dyn_cast<VectorType>(loadOp.getType());
        auto elemTy = loadTy.getElementType();
        auto firstIndex = mask[0];
        auto newVecTy = VectorType::get(mask.size(), elemTy);
        // A GEP is only needed if the slice does not start at element 0.
        if (firstIndex) {
          auto newPtr = LLVM::GEPOp::create(
              rewriter, loc,
              LLVM::LLVMPointerType::get(rewriter.getContext(), loadAddrSpace),
              elemTy, loadPtr, ArrayRef<LLVM::GEPArg>{firstIndex});
          auto newLoad = LLVM::LoadOp::create(rewriter, loc, newVecTy, newPtr);
          rewriter.replaceOp(op, newLoad);
        } else {
          auto newLoad = LLVM::LoadOp::create(rewriter, loc, newVecTy, loadPtr);
          rewriter.replaceOp(op, newLoad);
        }
      } else {
        return failure();
      }
    } else {
      // No defining op (e.g. block/function argument): nothing to hoist or
      // merge with, so report failure rather than claim a rewrite happened.
      return failure();
    }
    return success();
  }
};
//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//
/// Pass driver: first runs the XeVM-to-LLVM dialect conversion, then a
/// greedy in-dialect cleanup that legalizes oversized vector extracts.
struct ConvertXeVMToLLVMPass
    : public impl::ConvertXeVMToLLVMPassBase<ConvertXeVMToLLVMPass> {
  using Base::Base;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<LLVM::LLVMDialect, XeVMDialect>();
  }

  void runOnOperation() override {
    MLIRContext &ctx = getContext();
    // Stage 1: dialect conversion of XeVM ops (and tagged LLVM ops).
    ConversionTarget target(ctx);
    RewritePatternSet patterns(&ctx);
    populateXeVMToLLVMConversionPatterns(target, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
    // Stage 2: apply in-dialect lowerings to handle illegal vectors.
    RewritePatternSet vectorPatterns(&ctx);
    vectorPatterns.add<HandleVectorExtractPattern>(&ctx);
    GreedyRewriteConfig config{};
    // Folding can remove ops carrying the temporary attributes used to
    // represent LLVM metadata, so disable it: effectively only the single
    // pattern above runs, without any op folders from dialects.
    config.enableFolding(false);
    (void)applyPatternsGreedily(getOperation(), std::move(vectorPatterns),
                                config);
  }
};
} // namespace
//===----------------------------------------------------------------------===//
// Pattern Population
//===----------------------------------------------------------------------===//
void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target,
                                                  RewritePatternSet &patterns) {
  // A few LLVM operations also need to be converted.
  target.addDynamicallyLegalDialect<LLVM::LLVMDialect>([](Operation *op) {
    // llvm.alloca with addrspace 3 (OpenCL Workgroup) is not handled
    // properly by the SPIR-V backend; it needs to be rewritten as a
    // sequence with an llvm global.
    if (auto allocaOp = dyn_cast<LLVM::AllocaOp>(op)) {
      auto ptrTy = cast<LLVM::LLVMPointerType>(allocaOp.getType());
      return ptrTy.getAddressSpace() != 3;
    }
    // Ops still carrying a cache_control attribute must be converted.
    return !op->hasAttr("cache_control");
  });
  target.addIllegalDialect<XeVMDialect>();
  patterns.add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
               LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
               LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
               MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern,
               LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
               LLVMLoadStoreToOCLPattern<LLVM::StoreOp>,
               BlockLoadStore1DToOCLPattern<BlockLoadOp>,
               BlockLoadStore1DToOCLPattern<BlockStoreOp>,
               LaunchConfigOpToOCLPattern<WorkitemIdXOp>,
               LaunchConfigOpToOCLPattern<WorkitemIdYOp>,
               LaunchConfigOpToOCLPattern<WorkitemIdZOp>,
               LaunchConfigOpToOCLPattern<WorkgroupDimXOp>,
               LaunchConfigOpToOCLPattern<WorkgroupDimYOp>,
               LaunchConfigOpToOCLPattern<WorkgroupDimZOp>,
               LaunchConfigOpToOCLPattern<WorkgroupIdXOp>,
               LaunchConfigOpToOCLPattern<WorkgroupIdYOp>,
               LaunchConfigOpToOCLPattern<WorkgroupIdZOp>,
               LaunchConfigOpToOCLPattern<GridDimXOp>,
               LaunchConfigOpToOCLPattern<GridDimYOp>,
               LaunchConfigOpToOCLPattern<GridDimZOp>,
               SubgroupOpWorkitemOpToOCLPattern<LaneIdOp>,
               SubgroupOpWorkitemOpToOCLPattern<SubgroupIdOp>,
               SubgroupOpWorkitemOpToOCLPattern<SubgroupSizeOp>,
               TruncfToOCLPattern, MMAMxToOCLPattern, AllocaToGlobalPattern>(
      patterns.getContext());
}