//===- VectorToXeGPU.cpp - Convert vector to XeGPU dialect ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to XeGPU dialect ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

#include <algorithm>
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

// Return true if value represents a zero constant.
static bool isZeroConstant(Value val) {
  auto constant = val.getDefiningOp<arith::ConstantOp>();
  if (!constant)
    return false;

  return TypeSwitch<Attribute, bool>(constant.getValue())
      .Case([](FloatAttr floatAttr) { return floatAttr.getValue().isZero(); })
      .Case([](IntegerAttr intAttr) { return intAttr.getValue().isZero(); })
      .Default(false);
}

static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
                                            Operation *op, VectorType vecTy,
                                            MemRefType memTy) {
  // Validate mainly the vector type, as the basic vector store and load ops
  // already guarantee an XeGPU-compatible memref source.
  unsigned vecRank = vecTy.getRank();
  if (!(vecRank == 1 || vecRank == 2))
    return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");

  if (!vecTy.getElementType().isIntOrFloat())
    return rewriter.notifyMatchFailure(
        op, "Expected scalar type with known bitwidth");

  // XeGPU requires the memref to have a scalar integer or float element type.
  // Memrefs with vector element types (e.g. memref<?xvector<4xf32>>) are not
  // supported because createNdDescriptor computes byte offsets using
  // getElementTypeBitWidth(), which asserts on non-integer/float types.
  if (!memTy.getElementType().isIntOrFloat())
    return rewriter.notifyMatchFailure(
        op, "Unsupported memref element type: expected integer or float");

  return success();
}
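
// Illustrative note (shapes are made-up examples, not requirements from this
// file): with the checks above, a load/store of vector<16xf32> or
// vector<8x16xf32> backed by a memref with an integer or float element type
// is accepted, while e.g. vector<2x4x8xf32> (rank 3) or a source such as
// memref<?xvector<4xf32>> (non-scalar element type) is rejected.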

static LogicalResult transferPreconditions(PatternRewriter &rewriter,
                                           VectorTransferOpInterface xferOp) {
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp,
                                       "Masked transfer is not supported");

  auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!srcTy)
    return rewriter.notifyMatchFailure(xferOp, "Expects memref source");

  // Validate further transfer op semantics.
  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(srcTy.getStridesAndOffset(strides, offset)) ||
      strides.back() != 1)
    return rewriter.notifyMatchFailure(
        xferOp, "Buffer must be contiguous in the innermost dimension");

  VectorType vecTy = xferOp.getVectorType();
  unsigned vecRank = vecTy.getRank();
  if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
    return rewriter.notifyMatchFailure(
        xferOp, "Boundary check is available only for block instructions.");

  AffineMap map = xferOp.getPermutationMap();
  if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
    return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
  unsigned numInputDims = map.getNumInputs();
  for (AffineExpr expr : map.getResults().take_back(vecRank)) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    if (dim.getPosition() < (numInputDims - vecRank))
      return rewriter.notifyMatchFailure(
          xferOp, "Only the innermost dimensions can be accessed");
  }

  return success();
}
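
// Illustrative summary of the checks above (shapes are made-up examples):
// rejected are masked transfers, sources whose innermost stride is not 1,
// out-of-bounds accesses on 1D vectors, and permutation maps that broadcast
// or touch anything but the innermost memref dimensions; a plain
// transfer_read of vector<8x16xf32> from a contiguous memref<32x32xf32> with
// a minor identity map passes.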

static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
                                                Location loc,
                                                xegpu::TensorDescType descType,
                                                TypedValue<MemRefType> src) {
  MemRefType srcTy = src.getType();
  assert(srcTy.isStrided() && "Expected strided memref type");
  auto [strides, offset] = srcTy.getStridesAndOffset();
  bool isStatic = true;

  // The memref is dynamic if any of its shape, offset, or strides is dynamic.
  if (!srcTy.hasStaticShape())
    isStatic = false;

  if (!ShapedType::isStatic(offset))
    isStatic = false;

  for (auto stride : strides) {
    if (!ShapedType::isStatic(stride)) {
      isStatic = false;
      break;
    }
  }

  xegpu::CreateNdDescOp ndDesc;
  if (isStatic) {
    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
  } else {
    // For a ranked dynamic memref, instead of passing the memref itself, the
    // i64 base address and the source's offset, shape, and strides have to be
    // provided explicitly.
    auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
    auto baseAddrIndex = memref::ExtractAlignedPointerAsIndexOp::create(
        rewriter, loc, meta.getBaseBuffer());
    auto offset = meta.getOffset();
    auto elemByteSize = srcTy.getElementTypeBitWidth() / 8;
    auto offsetInBytes = arith::MulIOp::create(
        rewriter, loc, offset,
        arith::ConstantIndexOp::create(rewriter, loc, elemByteSize));
    auto adjustedBaseAddr = arith::AddIOp::create(
        rewriter, loc, baseAddrIndex.getResult(), offsetInBytes);
    auto adjustedAddrI64 = arith::IndexCastOp::create(
        rewriter, loc, rewriter.getI64Type(), adjustedBaseAddr);
    ndDesc = xegpu::CreateNdDescOp::create(
        rewriter, loc, descType, adjustedAddrI64,
        meta.getConstifiedMixedSizes(), meta.getConstifiedMixedStrides());
  }

  return ndDesc;
}
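
// For the dynamic case above, the emitted IR is roughly the following sketch
// (value names and the memref type are illustrative only):
//   %base, %offset, %sizes:2, %strides:2 =
//       memref.extract_strided_metadata %src : memref<?x?xf32> -> ...
//   %ptr      = memref.extract_aligned_pointer_as_index %base
//   %byte_off = arith.muli %offset, %c4   // element size in bytes
//   %addr     = arith.addi %ptr, %byte_off
//   %addr_i64 = arith.index_cast %addr : index to i64
//   %desc     = xegpu.create_nd_tdesc %addr_i64 with explicit sizes/strides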

// Adjusts the strides of a memref according to a given permutation map for
// vector operations.
//
// This function updates the innermost strides in the `strides` array to
// reflect the permutation specified by `permMap`. The permutation is computed
// using the inverse and broadcasting-aware version of the permutation map,
// and is applied to the relevant strides. This ensures that memory accesses
// are consistent with the logical permutation of vector elements.
//
// Example:
//   Suppose we have a memref of rank 4 with strides `[s0, s1, s2, s3]`.
//   If the permutation map swaps the last two dimensions (e.g., [0, 1] ->
//   [1, 0]), then after calling this function, the last two strides will be
//   swapped:
//     Original strides:  [s0, s1, s2, s3]
//     After permutation: [s0, s1, s3, s2]
//
static void adjustStridesForPermutation(AffineMap permMap,
                                        SmallVectorImpl<Value> &strides) {
  AffineMap invMap = inverseAndBroadcastProjectedPermutation(permMap);
  SmallVector<unsigned> perms;
  invMap.isPermutationOfMinorIdentityWithBroadcasting(perms);
  SmallVector<int64_t> perms64(perms.begin(), perms.end());
  strides = applyPermutation(strides, perms64);
}

// Computes memory strides and a memref offset for vector transfer operations,
// handling both static and dynamic memrefs while applying permutation
// transformations for XeGPU lowering.
template <
    typename OpType,
    typename = std::enable_if_t<llvm::is_one_of<
        std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
        vector::GatherOp, vector::ScatterOp>::value>>
static std::pair<SmallVector<Value>, Value>
computeMemrefMeta(OpType xferOp, PatternRewriter &rewriter) {
  SmallVector<Value> strides;
  Value baseMemref = xferOp.getBase();
  MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());

  Location loc = xferOp.getLoc();
  Value offsetVal = nullptr;
  if (memrefType.hasStaticShape()) {
    int64_t offset;
    SmallVector<int64_t> intStrides;
    if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
      return {{}, offsetVal};
    bool hasDynamicStrides = llvm::any_of(intStrides, [](int64_t strideVal) {
      return ShapedType::isDynamic(strideVal);
    });

    if (!hasDynamicStrides)
      for (int64_t s : intStrides)
        strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));

    if (!ShapedType::isDynamic(offset))
      offsetVal = arith::ConstantIndexOp::create(rewriter, loc, offset);
  }

  if (strides.empty() || !offsetVal) {
    // For a dynamic-shape memref, use memref.extract_strided_metadata to get
    // the stride values.
    unsigned rank = memrefType.getRank();
    Type indexType = rewriter.getIndexType();

    // Result types: [base_memref, offset, size0, ..., sizeN-1,
    //                stride0, ..., strideN-1]
    SmallVector<Type> resultTypes;
    resultTypes.push_back(MemRefType::get(
        {}, memrefType.getElementType())); // rank-0 base memref
    resultTypes.push_back(indexType);      // offset

    for (unsigned i = 0; i < rank; ++i)
      resultTypes.push_back(indexType); // sizes

    for (unsigned i = 0; i < rank; ++i)
      resultTypes.push_back(indexType); // strides

    auto meta = memref::ExtractStridedMetadataOp::create(
        rewriter, loc, resultTypes, baseMemref);

    if (strides.empty())
      strides.append(meta.getStrides().begin(), meta.getStrides().end());

    if (!offsetVal)
      offsetVal = meta.getOffset();
  }

  if constexpr (llvm::is_one_of<std::decay_t<OpType>, vector::TransferReadOp,
                                vector::TransferWriteOp>::value) {
    AffineMap permMap = xferOp.getPermutationMap();
    // Adjust strides according to the permutation map (e.g., for transpose).
    adjustStridesForPermutation(permMap, strides);
  }

  return {strides, offsetVal};
}
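
// Illustrative example (types and values are made up): for a statically
// shaped, contiguous memref<8x16xf32> the helper above materializes the
// strides as constants [%c16, %c1] and the offset as %c0; for a dynamically
// shaped or strided source the missing values are taken from the results of
// memref.extract_strided_metadata instead.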

// This function computes the vectors of localOffsets for scattered
// load/stores. It is used in the lowering of vector.transfer_read/write to
// load_gather/store_scatter. Example:
//   %0 = vector.transfer_read %expand_shape[%block_id_y, %c0, %c0, %c0, %c0],
//        %cst {in_bounds = [true, true, true, true]} :
//        memref<8x4x2x6x32xbf16>, vector<4x2x6x32xbf16>
//
//   %6 = vector.step : vector<4xindex>
//   %7 = vector.step : vector<2xindex>
//   %8 = vector.step : vector<6xindex>
//   %9 = vector.step : vector<32xindex>
//   %10 = arith.mul %6, 384
//   %11 = arith.mul %7, 192
//   %12 = arith.mul %8, 32
//   %13 = arith.mul %9, 1
//   %14 = vector.shape_cast %10 : vector<4xindex> -> vector<4x1x1x1xindex>
//   %15 = vector.shape_cast %11 : vector<2xindex> -> vector<1x2x1x1xindex>
//   %16 = vector.shape_cast %12 : vector<6xindex> -> vector<1x1x6x1xindex>
//   %17 = vector.shape_cast %13 : vector<32xindex> -> vector<1x1x1x32xindex>
//   %18 = vector.broadcast %14 : vector<4x1x1x1xindex> -> vector<4x2x6x32xindex>
//   %19 = vector.broadcast %15 : vector<1x2x1x1xindex> -> vector<4x2x6x32xindex>
//   %20 = vector.broadcast %16 : vector<1x1x6x1xindex> -> vector<4x2x6x32xindex>
//   %21 = vector.broadcast %17 : vector<1x1x1x32xindex> -> vector<4x2x6x32xindex>
//   %22 = arith.add %18, %19
//   %23 = arith.add %20, %21
//   %local_offsets = arith.add %22, %23
//   %orig_offset = %block_id_y * 4x2x6x32 // consider using an affine map
//   %offsets = memref_offset + orig_offset + local_offsets
static Value computeOffsets(VectorTransferOpInterface xferOp,
                            PatternRewriter &rewriter, ArrayRef<Value> strides,
                            Value baseOffset) {
  Location loc = xferOp.getLoc();
  VectorType vectorType = xferOp.getVectorType();
  SmallVector<Value> indices(xferOp.getIndices().begin(),
                             xferOp.getIndices().end());
  ArrayRef<int64_t> vectorShape = vectorType.getShape();

  // Create vector.step operations for each dimension.
  SmallVector<Value> stepVectors;
  llvm::map_to_vector(vectorShape, [&](int64_t dim) {
    auto stepType = VectorType::get({dim}, rewriter.getIndexType());
    auto stepOp = vector::StepOp::create(rewriter, loc, stepType);
    stepVectors.push_back(stepOp);
    return stepOp;
  });

  // Multiply step vectors by the corresponding strides.
  size_t memrefRank = strides.size();
  size_t vectorRank = vectorShape.size();
  SmallVector<Value> strideMultiplied;
  for (size_t i = 0; i < vectorRank; ++i) {
    size_t memrefDim = memrefRank - vectorRank + i;
    Value strideValue = strides[memrefDim];
    auto mulType = dyn_cast<VectorType>(stepVectors[i].getType());
    auto bcastOp =
        vector::BroadcastOp::create(rewriter, loc, mulType, strideValue);
    auto mulOp = arith::MulIOp::create(rewriter, loc, stepVectors[i], bcastOp);
    strideMultiplied.push_back(mulOp);
  }

  // Shape cast each multiplied vector to add singleton dimensions.
  SmallVector<Value> shapeCasted;
  for (size_t i = 0; i < vectorRank; ++i) {
    SmallVector<int64_t> newShape(vectorRank, 1);
    newShape[i] = vectorShape[i];
    auto newType = VectorType::get(newShape, rewriter.getIndexType());
    auto castOp = vector::ShapeCastOp::create(rewriter, loc, newType,
                                              strideMultiplied[i]);
    shapeCasted.push_back(castOp);
  }

  // Broadcast each shape-casted vector to the full vector shape.
  SmallVector<Value> broadcasted;
  auto fullIndexVectorType =
      VectorType::get(vectorShape, rewriter.getIndexType());
  for (Value shapeCastVal : shapeCasted) {
    auto broadcastOp = vector::BroadcastOp::create(
        rewriter, loc, fullIndexVectorType, shapeCastVal);
    broadcasted.push_back(broadcastOp);
  }

  // Add all broadcasted vectors together to compute the local offsets.
  Value localOffsets = broadcasted[0];
  for (size_t i = 1; i < broadcasted.size(); ++i)
    localOffsets =
        arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);

  // Compute the base offset from the transfer op's indices.
  for (size_t i = 0; i < indices.size(); ++i) {
    Value strideVal = strides[i];
    Value offsetContrib =
        arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
    baseOffset =
        arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
  }
  // Broadcast the base offset to match the vector shape.
  Value bcastBase = vector::BroadcastOp::create(
      rewriter, loc, fullIndexVectorType, baseOffset);
  localOffsets = arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
  return localOffsets;
}

// Compute the element-wise offsets for vector.gather or vector.scatter ops.
//
// This function linearizes the base offsets of the gather/scatter operation
// and combines them with the per-element indices to produce a final vector of
// memory offsets.
template <
    typename OpType,
    typename = std::enable_if_t<llvm::is_one_of<
        std::decay_t<OpType>, vector::GatherOp, vector::ScatterOp>::value>>
static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
                            ArrayRef<Value> strides, Value baseOffset) {
  Location loc = gatScatOp.getLoc();
  SmallVector<Value> offsets = gatScatOp.getOffsets();
  for (size_t i = 0; i < offsets.size(); ++i) {
    Value offsetContrib =
        arith::MulIOp::create(rewriter, loc, offsets[i], strides[i]);
    baseOffset =
        arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
  }
  Value indices = gatScatOp.getIndices();
  VectorType vecType = cast<VectorType>(indices.getType());

  Value strideVector =
      vector::BroadcastOp::create(rewriter, loc, vecType, strides.back())
          .getResult();
  Value stridedIndices =
      arith::MulIOp::create(rewriter, loc, strideVector, indices).getResult();

  Value baseVector =
      vector::BroadcastOp::create(
          rewriter, loc,
          VectorType::get(vecType.getShape(), rewriter.getIndexType()),
          baseOffset)
          .getResult();
  return arith::AddIOp::create(rewriter, loc, baseVector, stridedIndices)
      .getResult();
}
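
// Illustrative example (names and the memref type are made up): for a
// vector.gather from memref<?x?xf32> at base offsets [%i, %j] with strides
// [%s0, %s1] and per-element indices %idx, the helper above produces
//   %lin  = memref_offset + %i * %s0 + %j * %s1
//   %offs = broadcast(%lin) + %idx * broadcast(%s1)
// i.e. a vector of linearized element offsets that the scattered XeGPU ops
// consume directly.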

// Collapses the shape of an nD memref to the target rank while applying
// offsets for the collapsed dimensions. Returns the new memref value and the
// remaining offsets for the last targetRank dimensions. For example:
//   input: %memref = memref<2x4x8x32xf32>, offsets=[%i0, %i1, %i2, %i3],
//   output: %memref[%i0, %i1, 0, 0] -> memref<8x32xf32>, offsets: [%i2, %i3]
static std::pair<Value, SmallVector<OpFoldResult>>
convertMemrefAndOffsetsToTargetRank(PatternRewriter &rewriter, Location loc,
                                    Value memref,
                                    SmallVector<OpFoldResult> offsets,
                                    int64_t targetRank) {
  auto memrefType = cast<MemRefType>(memref.getType());
  unsigned rank = memrefType.getRank();

  if (rank <= targetRank)
    return {memref, offsets};

  int64_t numCombinedDims = rank - targetRank;
  SmallVector<OpFoldResult> subviewOffsets;
  SmallVector<OpFoldResult> subviewSizes;
  SmallVector<OpFoldResult> subviewStrides;

  // For the combined dimensions: use the provided offsets, size=1, stride=1.
  for (unsigned i = 0; i < numCombinedDims; ++i) {
    subviewOffsets.push_back(offsets[i]);
    subviewSizes.push_back(rewriter.getI64IntegerAttr(1));
    subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
  }

  // For the last targetRank dimensions: offset=0, use the full size, stride=1.
  SmallVector<int64_t> resultShape;
  auto originalShape = memrefType.getShape();
  auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, memref);
  for (unsigned i = numCombinedDims; i < rank; ++i) {
    subviewOffsets.push_back(rewriter.getI64IntegerAttr(0));
    if (ShapedType::isDynamic(originalShape[i])) {
      subviewSizes.push_back(meta.getSizes()[i]);
      resultShape.push_back(ShapedType::kDynamic);
    } else {
      subviewSizes.push_back(rewriter.getI64IntegerAttr(originalShape[i]));
      resultShape.push_back(originalShape[i]);
    }
    subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
  }

  auto resultType = memref::SubViewOp::inferRankReducedResultType(
      resultShape, memrefType, subviewOffsets, subviewSizes, subviewStrides);
  auto subviewOp =
      memref::SubViewOp::create(rewriter, loc, resultType, memref,
                                subviewOffsets, subviewSizes, subviewStrides);

  // Return the remaining offsets for the last targetRank dimensions.
  SmallVector<OpFoldResult> newOffsets(offsets.begin() + numCombinedDims,
                                       offsets.end());
  return {subviewOp.getResult(), newOffsets};
}

// Convert memref to i64 base pointer.
template <
    typename OpType,
    typename = std::enable_if_t<llvm::is_one_of<
        std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
        vector::GatherOp, vector::ScatterOp>::value>>
static Value memrefToIndexPtr(OpType xferOp, PatternRewriter &rewriter) {
  Location loc = xferOp.getLoc();
  auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create(
                      rewriter, loc, xferOp.getBase())
                      .getResult();
  return arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(),
                                    indexPtr)
      .getResult();
}

static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
                                            PatternRewriter &rewriter) {
  Location loc = readOp.getLoc();
  VectorType vectorType = readOp.getVectorType();
  ArrayRef<int64_t> vectorShape = vectorType.getShape();
  auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType());
  if (!memrefType)
    return rewriter.notifyMatchFailure(readOp, "Expected memref source");

  auto meta = computeMemrefMeta(readOp, rewriter);
  if (meta.first.empty())
    return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");

  Value localOffsets =
      computeOffsets(readOp, rewriter, meta.first, meta.second);

  Value flatMemref = memrefToIndexPtr(readOp, rewriter);

  Value mask = vector::ConstantMaskOp::create(
      rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
      vectorShape);
  auto gatherOp = xegpu::LoadGatherOp::create(
      rewriter, loc, vectorType, flatMemref, localOffsets, mask,
      /*chunk_size=*/IntegerAttr{},
      /*l1_hint=*/xegpu::CachePolicyAttr{},
      /*l2_hint=*/xegpu::CachePolicyAttr{},
      /*l3_hint=*/xegpu::CachePolicyAttr{},
      /*layout=*/nullptr);

  rewriter.replaceOp(readOp, gatherOp.getResult());
  return success();
}
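
// For reference, the scattered-load path above boils down to (an illustrative
// sketch, not verbatim output): a vector of linearized offsets from
// computeOffsets, the base pointer as i64 from memrefToIndexPtr, an all-true
// vector.constant_mask, and a single xegpu::LoadGatherOp that replaces the
// original vector.transfer_read.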

static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
                                             PatternRewriter &rewriter) {
  Location loc = writeOp.getLoc();
  VectorType vectorType = writeOp.getVectorType();
  ArrayRef<int64_t> vectorShape = vectorType.getShape();

  auto memrefType = dyn_cast<MemRefType>(writeOp.getShapedType());
  if (!memrefType)
    return rewriter.notifyMatchFailure(writeOp, "Expected memref source");

  auto meta = computeMemrefMeta(writeOp, rewriter);
  if (meta.first.empty())
    return rewriter.notifyMatchFailure(writeOp, "Failed to compute strides");

  Value localOffsets =
      computeOffsets(writeOp, rewriter, meta.first, meta.second);

  Value flatMemref = memrefToIndexPtr(writeOp, rewriter);

  Value mask = vector::ConstantMaskOp::create(
      rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
      vectorShape);
  xegpu::StoreScatterOp::create(rewriter, loc, writeOp.getVector(), flatMemref,
                                localOffsets, mask,
                                /*chunk_size=*/IntegerAttr{},
                                /*l1_hint=*/xegpu::CachePolicyAttr{},
                                /*l2_hint=*/xegpu::CachePolicyAttr{},
                                /*l3_hint=*/xegpu::CachePolicyAttr{},
                                /*layout=*/nullptr);
  rewriter.eraseOp(writeOp);
  return success();
}

struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    Location loc = readOp.getLoc();

    if (failed(transferPreconditions(rewriter, readOp)))
      return failure();
    auto readMemTy = cast<MemRefType>(readOp.getShapedType());
    VectorType loadedVecTy = readOp.getVectorType();
    bool isOutOfBounds = readOp.hasOutOfBoundsDim();
    // Check if the memref has address space 3 (shared local memory).
    bool isSharedMemory = xegpu::XeGPUDialect::isSharedMemory(readMemTy);
    // Handle the SLM case.
    if (isSharedMemory) {
      // For SLM sources, only the 2D case is supported for now.
      if (loadedVecTy.getRank() != 2)
        return rewriter.notifyMatchFailure(
            readOp, "Only 2D vector loads are supported for SLM");
      AffineMap readMap = readOp.getPermutationMap();
      if (!readMap.isMinorIdentity())
        return rewriter.notifyMatchFailure(
            readOp, "Transpose not supported for SLM loads");
      // The out-of-bounds case is not supported for SLM loads.
      if (isOutOfBounds)
        return rewriter.notifyMatchFailure(
            readOp, "Out-of-bounds access is not supported for SLM loads");

      // Create a mem_desc for SLM.
      auto memDescType =
          xegpu::MemDescType::get(rewriter.getContext(), readMemTy.getShape(),
                                  readMemTy.getElementType(),
                                  /*mem_layout=*/nullptr);
      auto createMemDescOp = xegpu::CreateMemDescOp::create(
          rewriter, loc, memDescType, readOp.getBase());
      // Convert indices to OpFoldResult for LoadMatrixOp.
      SmallVector<OpFoldResult> indices =
          getAsOpFoldResult(readOp.getIndices());
      auto loadMatrixOp = xegpu::LoadMatrixOp::create(
          rewriter, loc, loadedVecTy, createMemDescOp.getResult(), indices,
          /*layout=*/nullptr);

      rewriter.replaceOp(readOp, loadMatrixOp.getResult());
      return success();
    }

    // TODO: This check needs to be replaced with a proper uArch capability
    // check.
    auto chip = xegpu::getChipStr(readOp);
    // Lower to a scattered load op if the target HW doesn't have 2d block load
    // support and the load is not from shared memory.
    if ((chip != "pvc" && chip != "bmg") ||
        readOp.getVectorType().getRank() > 2) {
      // TODO: add support for out-of-bounds access.
      if (isOutOfBounds)
        return failure();
      return lowerToScatteredLoadOp(readOp, rewriter);
    }

    // Handle the 1D non-SLM case using load_gather.
    if (loadedVecTy.getRank() == 1 && !isOutOfBounds)
      return lowerToScatteredLoadOp(readOp, rewriter);

    // Perform common data transfer checks.
    // TODO: Maybe too strict for the SLM case.
    if (failed(
            storeLoadPreconditions(rewriter, readOp, loadedVecTy, readMemTy)))
      return failure();

    if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported non-zero padded out-of-bounds read");

    AffineMap readMap = readOp.getPermutationMap();
    bool isTransposeLoad = !readMap.isMinorIdentity();
    auto elementType = loadedVecTy.getElementType();

    SmallVector<int64_t> descShape(loadedVecTy.getShape());
    if (isTransposeLoad) {
      // If the load is transposed, the shape of the source descriptor is the
      // transpose of the result shape. Apply the inverse permutation to get
      // the descriptor shape.
      auto inversedMap = inversePermutation(readMap);
      descShape = applyPermutationMap(inversedMap, loadedVecTy.getShape());
      loadedVecTy = VectorType::get(descShape, elementType);
    }
    auto descType = xegpu::TensorDescType::get(
        descShape, elementType, /*array_length=*/1,
        /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);
    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
        rewriter, loc, readOp.getBase(), getAsOpFoldResult(readOp.getIndices()),
        loadedVecTy.getRank());
    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));

    Operation *loadedOp =
        xegpu::LoadNdOp::create(rewriter, loc, loadedVecTy, ndDesc, indices,
                                /*packed=*/nullptr, /*transpose=*/nullptr,
                                /*l1_hint=*/hint,
                                /*l2_hint=*/hint, /*l3_hint=*/hint,
                                /*layout=*/nullptr);
    if (isTransposeLoad) {
      // Transpose the loaded vector with a separate vector.transpose
      // operation.
      auto range = llvm::seq<int64_t>(0, readMap.getResults().size());
      SmallVector<int64_t> perm(range.begin(), range.end());
      auto permApplied = applyPermutationMap<int64_t>(readMap, perm);
      loadedOp = vector::TransposeOp::create(
          rewriter, loc, loadedOp->getResult(0), permApplied);
    }
    rewriter.replaceOp(readOp, loadedOp);

    return success();
  }
};

struct TransferWriteLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = writeOp.getLoc();

    if (failed(transferPreconditions(rewriter, writeOp)))
      return failure();
    VectorType vecTy = writeOp.getVectorType();
    auto writeMemTy = cast<MemRefType>(writeOp.getShapedType());
    // Check if the memref has address space 3 (shared local memory).
    bool isSharedMemory = xegpu::XeGPUDialect::isSharedMemory(writeMemTy);

    // For shared local memory (address space 3), use create_mem_desc +
    // store_matrix.
    if (isSharedMemory) {
      // Only the 2D case is supported for now.
      if (vecTy.getRank() != 2)
        return rewriter.notifyMatchFailure(
            writeOp, "Only 2D vector stores are supported for SLM");
      // Create a mem_desc for SLM.
      auto memDescType =
          xegpu::MemDescType::get(rewriter.getContext(), writeMemTy.getShape(),
                                  writeMemTy.getElementType(),
                                  /*mem_layout=*/nullptr);

      auto createMemDescOp = xegpu::CreateMemDescOp::create(
          rewriter, loc, memDescType, writeOp.getBase());

      // Convert indices to OpFoldResult for StoreMatrixOp.
      SmallVector<OpFoldResult> indices =
          getAsOpFoldResult(writeOp.getIndices());

      xegpu::StoreMatrixOp::create(rewriter, loc, writeOp.getVector(),
                                   createMemDescOp.getResult(), indices,
                                   /*layout=*/nullptr);

      rewriter.eraseOp(writeOp);
      return success();
    }

    // TODO: This check needs to be replaced with a proper uArch capability
    // check.
    auto chip = xegpu::getChipStr(writeOp);
    // Lower to a scattered store op if the target HW doesn't have 2d block
    // store support and the memref is not SLM.
    if ((chip != "pvc" && chip != "bmg") ||
        writeOp.getVectorType().getRank() > 2) {
      // TODO: add support for out-of-bounds access.
      if (writeOp.hasOutOfBoundsDim())
        return failure();
      return lowerToScatteredStoreOp(writeOp, rewriter);
    }

    // Perform common data transfer checks.
    if (failed(storeLoadPreconditions(rewriter, writeOp, vecTy, writeMemTy)))
      return failure();

    AffineMap map = writeOp.getPermutationMap();
    if (!map.isMinorIdentity())
      return rewriter.notifyMatchFailure(writeOp, "Expects identity map");

    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
        rewriter, loc, writeOp.getBase(),
        getAsOpFoldResult(writeOp.getIndices()), vecTy.getRank());

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
        xegpu::MemorySpace::Global);
    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));

    auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(),
                                            ndDesc, indices,
                                            /*l1_hint=*/hint,
                                            /*l2_hint=*/hint, /*l3_hint=*/hint,
                                            /*layout=*/nullptr);
    rewriter.replaceOp(writeOp, storeOp);

    return success();
  }
};

struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
                                PatternRewriter &rewriter) const override {
    auto srcTy = dyn_cast<MemRefType>(gatherOp.getBase().getType());
    if (!srcTy)
      return rewriter.notifyMatchFailure(gatherOp, "Expects memref source");

    Location loc = gatherOp.getLoc();
    VectorType vectorType = gatherOp.getVectorType();

    auto meta = computeMemrefMeta(gatherOp, rewriter);
    if (meta.first.empty())
      return rewriter.notifyMatchFailure(gatherOp, "Failed to compute strides");

    Value localOffsets =
        computeOffsets(rewriter, gatherOp, meta.first, meta.second);
    Value flatMemref = memrefToIndexPtr(gatherOp, rewriter);

    auto xeGatherOp = xegpu::LoadGatherOp::create(
        rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
        /*chunk_size=*/IntegerAttr{},
        /*l1_hint=*/xegpu::CachePolicyAttr{},
        /*l2_hint=*/xegpu::CachePolicyAttr{},
        /*l3_hint=*/xegpu::CachePolicyAttr{},
        /*layout=*/nullptr);

    auto selectOp =
        arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
                                xeGatherOp.getResult(), gatherOp.getPassThru());
    rewriter.replaceOp(gatherOp, selectOp.getResult());
    return success();
  }
};

struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::ScatterOp scatterOp,
                                PatternRewriter &rewriter) const override {
    auto srcTy = dyn_cast<MemRefType>(scatterOp.getBase().getType());
    if (!srcTy)
      return rewriter.notifyMatchFailure(scatterOp, "Expects memref source");

    Location loc = scatterOp.getLoc();
    auto meta = computeMemrefMeta(scatterOp, rewriter);
    if (meta.first.empty())
      return rewriter.notifyMatchFailure(scatterOp,
                                         "Failed to compute strides");

    Value localOffsets =
        computeOffsets(rewriter, scatterOp, meta.first, meta.second);
    Value flatMemref = memrefToIndexPtr(scatterOp, rewriter);

    xegpu::StoreScatterOp::create(rewriter, loc, scatterOp.getValueToStore(),
                                  flatMemref, localOffsets, scatterOp.getMask(),
                                  /*chunk_size=*/IntegerAttr{},
                                  /*l1_hint=*/xegpu::CachePolicyAttr{},
                                  /*l2_hint=*/xegpu::CachePolicyAttr{},
                                  /*l3_hint=*/xegpu::CachePolicyAttr{},
                                  /*layout=*/nullptr);
    rewriter.eraseOp(scatterOp);
    return success();
  }
};

struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    Location loc = loadOp.getLoc();

    VectorType vecTy = loadOp.getResult().getType();
    MemRefType memTy = loadOp.getBase().getType();
    if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy, memTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;
    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;

    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
        rewriter, loc, loadOp.getBase(), getAsOpFoldResult(loadOp.getIndices()),
        vecTy.getRank());

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
        boundaryCheck, xegpu::MemorySpace::Global);

    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
    auto loadNdOp =
        xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
                                /*packed=*/nullptr, /*transpose=*/nullptr,
                                /*l1_hint=*/hint,
                                /*l2_hint=*/hint, /*l3_hint=*/hint,
                                /*layout=*/nullptr);
    rewriter.replaceOp(loadOp, loadNdOp);

    return success();
  }
};

struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = storeOp.getLoc();

    TypedValue<VectorType> vector = storeOp.getValueToStore();
    VectorType vecTy = vector.getType();
    MemRefType memTy = storeOp.getBase().getType();
    if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy, memTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
        rewriter, loc, storeOp.getBase(),
        getAsOpFoldResult(storeOp.getIndices()), vecTy.getRank());

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));

    auto storeNdOp =
        xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, indices,
                                 /*l1_hint=*/hint,
                                 /*l2_hint=*/hint, /*l3_hint=*/hint,
                                 /*layout=*/nullptr);

    rewriter.replaceOp(storeOp, storeNdOp);

    return success();
  }
};

struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    Location loc = contractOp.getLoc();

    if (contractOp.getKind() != vector::CombiningKind::ADD)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects add combining kind");

    TypedValue<Type> acc = contractOp.getAcc();
    VectorType accType = dyn_cast<VectorType>(acc.getType());
    if (!accType || accType.getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");

    // Accept only plain 2D data layout.
    // VNNI packing is applied to DPAS as a separate lowering step.
    TypedValue<VectorType> lhs = contractOp.getLhs();
    TypedValue<VectorType> rhs = contractOp.getRhs();
    if (lhs.getType().getRank() != 2 || rhs.getType().getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects lhs and rhs 2D vectors");

    if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr()))
      return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");

    auto dpasOp = xegpu::DpasOp::create(rewriter, loc,
                                        TypeRange{contractOp.getResultType()},
                                        ValueRange{lhs, rhs, acc});
    rewriter.replaceOp(contractOp, dpasOp);

    return success();
  }
};
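
// Illustrative shape of this rewrite (types are example values only): a
// row-major matmul contraction such as
//   %d = vector.contract %a, %b, %c
//        : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf32>
// is replaced by a single DPAS op, roughly
//   %d = xegpu.dpas %a, %b, %c
//        : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32>
//        -> vector<8x16xf32>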

struct ConvertVectorToXeGPUPass
    : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorToXeGPUConversionPatterns(patterns);
    populatePrepareVectorToMMAPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::populateVectorToXeGPUConversionPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<TransferReadLowering, TransferWriteLowering, LoadLowering,
           ScatterLowering, GatherLowering, StoreLowering, ContractionLowering>(
          patterns.getContext());
}
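
// Note on usage (an assumption, not guaranteed by this file alone): the pass
// defined above is registered from the generated ConvertVectorToXeGPU pass
// declaration, so it would typically be exercised with something like
// `mlir-opt --convert-vector-to-xegpu input.mlir`, where the exact flag name
// comes from the pass's tablegen definition.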