//===- VectorToXeGPU.cpp - Convert vector to XeGPU dialect ------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements lowering of vector operations to XeGPU dialect ops. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h" #include "mlir/Conversion/VectorToGPU/VectorToGPU.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/TypeSwitch.h" #include #include namespace mlir { #define GEN_PASS_DEF_CONVERTVECTORTOXEGPU #include "mlir/Conversion/Passes.h.inc" } // namespace mlir using namespace mlir; namespace { // Return true if value represents a zero constant. static bool isZeroConstant(Value val) { auto constant = val.getDefiningOp(); if (!constant) return false; return TypeSwitch(constant.getValue()) .Case([](FloatAttr floatAttr) { return floatAttr.getValue().isZero(); }) .Case([](IntegerAttr intAttr) { return intAttr.getValue().isZero(); }) .Default(false); } static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter, Operation *op, VectorType vecTy, MemRefType memTy) { // Validate only vector as the basic vector store and load ops guarantee // XeGPU-compatible memref source. unsigned vecRank = vecTy.getRank(); if (!(vecRank == 1 || vecRank == 2)) return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector"); if (!vecTy.getElementType().isIntOrFloat()) return rewriter.notifyMatchFailure( op, "Expected scalar type with known bitwidth"); // XeGPU requires the memref to have a scalar integer or float element type. // Memrefs with vector element types (e.g. memref>) are not // supported because createNdDescriptor computes byte offsets using // getElementTypeBitWidth(), which asserts on non-integer/float types. if (!memTy.getElementType().isIntOrFloat()) return rewriter.notifyMatchFailure( op, "Unsupported memref element type: expected integer or float"); return success(); } static LogicalResult transferPreconditions(PatternRewriter &rewriter, VectorTransferOpInterface xferOp) { if (xferOp.getMask()) return rewriter.notifyMatchFailure(xferOp, "Masked transfer is not supported"); auto srcTy = dyn_cast(xferOp.getShapedType()); if (!srcTy) return rewriter.notifyMatchFailure(xferOp, "Expects memref source"); // Validate further transfer op semantics. SmallVector strides; int64_t offset; if (failed(srcTy.getStridesAndOffset(strides, offset)) || strides.back() != 1) return rewriter.notifyMatchFailure( xferOp, "Buffer must be contiguous in the innermost dimension"); VectorType vecTy = xferOp.getVectorType(); unsigned vecRank = vecTy.getRank(); if (xferOp.hasOutOfBoundsDim() && vecRank < 2) return rewriter.notifyMatchFailure( xferOp, "Boundary check is available only for block instructions."); AffineMap map = xferOp.getPermutationMap(); if (!map.isProjectedPermutation(/*allowZeroInResults=*/false)) return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map"); unsigned numInputDims = map.getNumInputs(); for (AffineExpr expr : map.getResults().take_back(vecRank)) { auto dim = dyn_cast(expr); if (dim.getPosition() < (numInputDims - vecRank)) return rewriter.notifyMatchFailure( xferOp, "Only the innermost dimensions can be accessed"); } return success(); } static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter, Location loc, xegpu::TensorDescType descType, TypedValue src) { MemRefType srcTy = src.getType(); assert(srcTy.isStrided() && "Expected strided memref type"); auto [strides, offset] = srcTy.getStridesAndOffset(); bool isStatic = true; // Memref is dynamic if any of its shape, offset or strides is dynamic. if (!srcTy.hasStaticShape()) isStatic = false; if (!ShapedType::isStatic(offset)) isStatic = false; for (auto stride : strides) { if (!ShapedType::isStatic(stride)) { isStatic = false; break; } } xegpu::CreateNdDescOp ndDesc; if (isStatic) { ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src); } else { // In case of ranked dynamic memref, instead of passing on the memref, // i64 base address, source's offset, shape and strides have to be // explicitly provided. auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src); auto baseAddrIndex = memref::ExtractAlignedPointerAsIndexOp::create( rewriter, loc, meta.getBaseBuffer()); auto offset = meta.getOffset(); auto elemByteSize = srcTy.getElementTypeBitWidth() / 8; auto offsetInBytes = arith::MulIOp::create( rewriter, loc, offset, arith::ConstantIndexOp::create(rewriter, loc, elemByteSize)); auto adjustedBaseAddr = arith::AddIOp::create( rewriter, loc, baseAddrIndex.getResult(), offsetInBytes); auto adjustedAddrI64 = arith::IndexCastOp::create( rewriter, loc, rewriter.getI64Type(), adjustedBaseAddr); ndDesc = xegpu::CreateNdDescOp::create( rewriter, loc, descType, adjustedAddrI64, meta.getConstifiedMixedSizes(), meta.getConstifiedMixedStrides()); } return ndDesc; } // Adjusts the strides of a memref according to a given permutation map for // vector operations. // // This function updates the innermost strides in the `strides` array to // reflect the permutation specified by `permMap`. The permutation is computed // using the inverse and broadcasting-aware version of the permutation map, // and is applied to the relevant strides. This ensures that memory accesses // are consistent with the logical permutation of vector elements. // // Example: // Suppose we have a memref of rank 4 with strides `[s0, s1, s2, s3]`. // If the permutation map swaps the last two dimensions (e.g., [0, 1] -> [1, // 0]), then after calling this function, the last two strides will be // swapped: // Original strides: [s0, s1, s2, s3] // After permutation: [s0, s1, s3, s2] // static void adjustStridesForPermutation(AffineMap permMap, SmallVectorImpl &strides) { AffineMap invMap = inverseAndBroadcastProjectedPermutation(permMap); SmallVector perms; invMap.isPermutationOfMinorIdentityWithBroadcasting(perms); SmallVector perms64(perms.begin(), perms.end()); strides = applyPermutation(strides, perms64); } // Computes memory strides and a memref offset for vector transfer operations, // handling both static and dynamic memrefs while applying permutation // transformations for XeGPU lowering. template < typename OpType, typename = std::enable_if_t, vector::TransferReadOp, vector::TransferWriteOp, vector::GatherOp, vector::ScatterOp>::value>> static std::pair, Value> computeMemrefMeta(OpType xferOp, PatternRewriter &rewriter) { SmallVector strides; Value baseMemref = xferOp.getBase(); MemRefType memrefType = dyn_cast(baseMemref.getType()); Location loc = xferOp.getLoc(); Value offsetVal = nullptr; if (memrefType.hasStaticShape()) { int64_t offset; SmallVector intStrides; if (failed(memrefType.getStridesAndOffset(intStrides, offset))) return {{}, offsetVal}; bool hasDynamicStrides = llvm::any_of(intStrides, [](int64_t strideVal) { return ShapedType::isDynamic(strideVal); }); if (!hasDynamicStrides) for (int64_t s : intStrides) strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s)); if (!ShapedType::isDynamic(offset)) offsetVal = arith::ConstantIndexOp::create(rewriter, loc, offset); } if (strides.empty() || !offsetVal) { // For dynamic shape memref, use memref.extract_strided_metadata to get // stride values unsigned rank = memrefType.getRank(); Type indexType = rewriter.getIndexType(); // Result types: [base_memref, offset, stride0, stride1, ..., strideN-1, // size0, size1, ..., sizeN-1] SmallVector resultTypes; resultTypes.push_back(MemRefType::get( {}, memrefType.getElementType())); // base memref (unranked) resultTypes.push_back(indexType); // offset for (unsigned i = 0; i < rank; ++i) resultTypes.push_back(indexType); // strides for (unsigned i = 0; i < rank; ++i) resultTypes.push_back(indexType); // sizes auto meta = memref::ExtractStridedMetadataOp::create( rewriter, loc, resultTypes, baseMemref); if (strides.empty()) strides.append(meta.getStrides().begin(), meta.getStrides().end()); if (!offsetVal) offsetVal = meta.getOffset(); } if constexpr (llvm::is_one_of, vector::TransferReadOp, vector::TransferWriteOp>::value) { AffineMap permMap = xferOp.getPermutationMap(); // Adjust strides according to the permutation map (e.g., for transpose) adjustStridesForPermutation(permMap, strides); } return {strides, offsetVal}; } // This function compute the vectors of localOffsets for scattered load/stores. // It is used in the lowering of vector.transfer_read/write to // load_gather/store_scatter Example: // %0 = vector.transfer_read %expand_shape[%block_id_y, %c0, %c0, %c0, %c0], // %cst {in_bounds = [true, true, true, true]}>} : // memref<8x4x2x6x32xbf16>, vector<4x2x6x32xbf16> // // %6 = vector.step: vector<4xindex> // %7 = vector.step: vector<2xindex> // %8 = vector.step: vector<6xindex> // %9 = vector.step: vector<32xindex> // %10 = arith.mul %6, 384 // %11 = arith.mul %7, 192 // %12 = arith.mul %8, 32 // %13 = arith.mul %9, 1 // %14 = vector.shape_cast %10: vector<4xindex> -> vector<4x1x1x1xbf16> // %15 = vector.shape_cast %11: vector<2xindex> -> vector<1x2x1x1xbf16> // %16 = vector.shape_cast %12: vector<6xindex> -> vector<1x1x6x1xbf16> // %17 = vector.shape_cast %13: vector<32xindex> -> vector<1x1x1x32xbf16> // %18 = vector.broadcast %14: vector<4x1x1x1xbf16> -> vector<4x2x6x32xindex> // %19 = vector.broadcast %15: vector<1x2x1x1xbf16> -> vector<4x2x6x32xindex> // %20 = vector.broadcast %16: vector<1x1x6x1xbf16> -> vector<4x2x6x32xindex> // %21 = vector.broadcast %17: vector<1x1x1x32xbf16> -> vector<4x2x6x32xindex> // %22 = arith.add %18, %19 // %23 = arith.add %20, %21 // %local_offsets = arith.add %22, %23 // %orig_offset = %block_id_y * 4x2x6x32 // consider using affine map // %offsets = memref_offset + orig_offset + local_offsets static Value computeOffsets(VectorTransferOpInterface xferOp, PatternRewriter &rewriter, ArrayRef strides, Value baseOffset) { Location loc = xferOp.getLoc(); VectorType vectorType = xferOp.getVectorType(); SmallVector indices(xferOp.getIndices().begin(), xferOp.getIndices().end()); ArrayRef vectorShape = vectorType.getShape(); // Create vector.step operations for each dimension SmallVector stepVectors; llvm::map_to_vector(vectorShape, [&](int64_t dim) { auto stepType = VectorType::get({dim}, rewriter.getIndexType()); auto stepOp = vector::StepOp::create(rewriter, loc, stepType); stepVectors.push_back(stepOp); return stepOp; }); // Multiply step vectors by corresponding strides size_t memrefRank = strides.size(); size_t vectorRank = vectorShape.size(); SmallVector strideMultiplied; for (size_t i = 0; i < vectorRank; ++i) { size_t memrefDim = memrefRank - vectorRank + i; Value strideValue = strides[memrefDim]; auto mulType = dyn_cast(stepVectors[i].getType()); auto bcastOp = vector::BroadcastOp::create(rewriter, loc, mulType, strideValue); auto mulOp = arith::MulIOp::create(rewriter, loc, stepVectors[i], bcastOp); strideMultiplied.push_back(mulOp); } // Shape cast each multiplied vector to add singleton dimensions SmallVector shapeCasted; for (size_t i = 0; i < vectorRank; ++i) { SmallVector newShape(vectorRank, 1); newShape[i] = vectorShape[i]; auto newType = VectorType::get(newShape, rewriter.getIndexType()); auto castOp = vector::ShapeCastOp::create(rewriter, loc, newType, strideMultiplied[i]); shapeCasted.push_back(castOp); } // Broadcast each shape-casted vector to full vector shape SmallVector broadcasted; auto fullIndexVectorType = VectorType::get(vectorShape, rewriter.getIndexType()); for (Value shapeCastVal : shapeCasted) { auto broadcastOp = vector::BroadcastOp::create( rewriter, loc, fullIndexVectorType, shapeCastVal); broadcasted.push_back(broadcastOp); } // Add all broadcasted vectors together to compute local offsets Value localOffsets = broadcasted[0]; for (size_t i = 1; i < broadcasted.size(); ++i) localOffsets = arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]); // Compute base offset from transfer read indices for (size_t i = 0; i < indices.size(); ++i) { Value strideVal = strides[i]; Value offsetContrib = arith::MulIOp::create(rewriter, loc, indices[i], strideVal); baseOffset = arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib); } // Broadcast base offset to match vector shape Value bcastBase = vector::BroadcastOp::create( rewriter, loc, fullIndexVectorType, baseOffset); localOffsets = arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets); return localOffsets; } // Compute the element-wise offsets for vector.gather or vector.scatter ops. // // This function linearizes the base offsets of the gather/scatter operation // and combines them with the per-element indices to produce a final vector of // memory offsets. template < typename OpType, typename = std::enable_if_t, vector::GatherOp, vector::ScatterOp>::value>> static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp, ArrayRef strides, Value baseOffset) { Location loc = gatScatOp.getLoc(); SmallVector offsets = gatScatOp.getOffsets(); for (size_t i = 0; i < offsets.size(); ++i) { Value offsetContrib = arith::MulIOp::create(rewriter, loc, offsets[i], strides[i]); baseOffset = arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib); } Value indices = gatScatOp.getIndices(); VectorType vecType = cast(indices.getType()); Value strideVector = vector::BroadcastOp::create(rewriter, loc, vecType, strides.back()) .getResult(); Value stridedIndices = arith::MulIOp::create(rewriter, loc, strideVector, indices).getResult(); Value baseVector = vector::BroadcastOp::create( rewriter, loc, VectorType::get(vecType.getShape(), rewriter.getIndexType()), baseOffset) .getResult(); return arith::AddIOp::create(rewriter, loc, baseVector, stridedIndices) .getResult(); } // Collapses shapes of a nD memref to the target rank while applying offsets for // the collapsed dimensions. Returns the new memref value and the remaining // offsets for the last targetRank dimensions. For example: // input: %memref = memref<2x4x8x32xf32>, offsets=[%i0, %i1, %i2, %i3], // output: %memref[%i0, %i1, 0, 0] -> memref<8x32xf32>, offsets: [%i2, %i3] static std::pair> convertMemrefAndOffsetsToTargetRank(PatternRewriter &rewriter, Location loc, Value memref, SmallVector offsets, int64_t targetRank) { auto memrefType = cast(memref.getType()); unsigned rank = memrefType.getRank(); if (rank <= targetRank) return {memref, offsets}; int64_t numCombinedDims = rank - targetRank; SmallVector subviewOffsets; SmallVector subviewSizes; SmallVector subviewStrides; // For the combined dimensions: use the provided offsets, size=1, stride=1 for (unsigned i = 0; i < numCombinedDims; ++i) { subviewOffsets.push_back(offsets[i]); subviewSizes.push_back(rewriter.getI64IntegerAttr(1)); subviewStrides.push_back(rewriter.getI64IntegerAttr(1)); } // For the last targetRank dimensions: offset=0, use full size, stride=1 SmallVector resultShape; auto originalShape = memrefType.getShape(); auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, memref); for (unsigned i = numCombinedDims; i < rank; ++i) { subviewOffsets.push_back(rewriter.getI64IntegerAttr(0)); if (ShapedType::isDynamic(originalShape[i])) { subviewSizes.push_back(meta.getSizes()[i]); resultShape.push_back(ShapedType::kDynamic); } else { subviewSizes.push_back(rewriter.getI64IntegerAttr(originalShape[i])); resultShape.push_back(originalShape[i]); } subviewStrides.push_back(rewriter.getI64IntegerAttr(1)); } auto resultType = memref::SubViewOp::inferRankReducedResultType( resultShape, memrefType, subviewOffsets, subviewSizes, subviewStrides); auto subviewOp = memref::SubViewOp::create(rewriter, loc, resultType, memref, subviewOffsets, subviewSizes, subviewStrides); // Return the remaining offsets for the last targetRank dimensions SmallVector newOffsets(offsets.begin() + numCombinedDims, offsets.end()); return {subviewOp.getResult(), newOffsets}; } template < typename OpType, typename = std::enable_if_t, vector::TransferReadOp, vector::TransferWriteOp, vector::GatherOp, vector::ScatterOp>::value>> // Convert memref to i64 base pointer static Value memrefToIndexPtr(OpType xferOp, PatternRewriter &rewriter) { Location loc = xferOp.getLoc(); auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create( rewriter, loc, xferOp.getBase()) .getResult(); return arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(), indexPtr) .getResult(); } static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp, PatternRewriter &rewriter) { Location loc = readOp.getLoc(); VectorType vectorType = readOp.getVectorType(); ArrayRef vectorShape = vectorType.getShape(); auto memrefType = dyn_cast(readOp.getShapedType()); if (!memrefType) return rewriter.notifyMatchFailure(readOp, "Expected memref source"); auto meta = computeMemrefMeta(readOp, rewriter); if (meta.first.empty()) return rewriter.notifyMatchFailure(readOp, "Failed to compute strides"); Value localOffsets = computeOffsets(readOp, rewriter, meta.first, meta.second); Value flatMemref = memrefToIndexPtr(readOp, rewriter); Value mask = vector::ConstantMaskOp::create( rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()), vectorShape); auto gatherOp = xegpu::LoadGatherOp::create( rewriter, loc, vectorType, flatMemref, localOffsets, mask, /*chunk_size=*/IntegerAttr{}, /*l1_hint=*/xegpu::CachePolicyAttr{}, /*l2_hint=*/xegpu::CachePolicyAttr{}, /*l3_hint=*/xegpu::CachePolicyAttr{}, /*layout=*/nullptr); rewriter.replaceOp(readOp, gatherOp.getResult()); return success(); } static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp, PatternRewriter &rewriter) { Location loc = writeOp.getLoc(); VectorType vectorType = writeOp.getVectorType(); ArrayRef vectorShape = vectorType.getShape(); auto memrefType = dyn_cast(writeOp.getShapedType()); if (!memrefType) return rewriter.notifyMatchFailure(writeOp, "Expected memref source"); auto meta = computeMemrefMeta(writeOp, rewriter); if (meta.first.empty()) return rewriter.notifyMatchFailure(writeOp, "Failed to compute strides"); Value localOffsets = computeOffsets(writeOp, rewriter, meta.first, meta.second); Value flatMemref = memrefToIndexPtr(writeOp, rewriter); Value mask = vector::ConstantMaskOp::create( rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()), vectorShape); xegpu::StoreScatterOp::create(rewriter, loc, writeOp.getVector(), flatMemref, localOffsets, mask, /*chunk_size=*/IntegerAttr{}, /*l1_hint=*/xegpu::CachePolicyAttr{}, /*l2_hint=*/xegpu::CachePolicyAttr{}, /*l3_hint=*/xegpu::CachePolicyAttr{}, /*layout=*/nullptr); rewriter.eraseOp(writeOp); return success(); } struct TransferReadLowering : public OpRewritePattern { using Base::Base; LogicalResult matchAndRewrite(vector::TransferReadOp readOp, PatternRewriter &rewriter) const override { Location loc = readOp.getLoc(); if (failed(transferPreconditions(rewriter, readOp))) return failure(); auto readMemTy = cast(readOp.getShapedType()); VectorType loadedVecTy = readOp.getVectorType(); bool isOutOfBounds = readOp.hasOutOfBoundsDim(); // Check if the memref has address space 3 (shared local memory) bool isSharedMemory = xegpu::XeGPUDialect::isSharedMemory(readMemTy); // Handle the SLM case. if (isSharedMemory) { // If the memref is SLM only support 2D case for now. if (loadedVecTy.getRank() != 2) return rewriter.notifyMatchFailure( readOp, "Only 2D vector loads are supported for SLM"); AffineMap readMap = readOp.getPermutationMap(); if (!readMap.isMinorIdentity()) return rewriter.notifyMatchFailure( readOp, "Transpose not supported for SLM loads"); // Out of bounds case is not supported for SLM loads. if (isOutOfBounds) return rewriter.notifyMatchFailure( readOp, "Out-of-bounds access is not supported for SLM loads"); // Create mem_desc for SLM auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), readMemTy.getShape(), readMemTy.getElementType(), /*mem_layout=*/nullptr); auto createMemDescOp = xegpu::CreateMemDescOp::create( rewriter, loc, memDescType, readOp.getBase()); // Convert indices to OpFoldResult for LoadMatrixOp SmallVector indices = getAsOpFoldResult(readOp.getIndices()); auto loadMatrixOp = xegpu::LoadMatrixOp::create( rewriter, loc, loadedVecTy, createMemDescOp.getResult(), indices, /*layout=*/nullptr); rewriter.replaceOp(readOp, loadMatrixOp.getResult()); return success(); } // TODO:This check needs to be replaced with proper uArch capability check auto chip = xegpu::getChipStr(readOp); // Lower to scattered load Op if the target HW doesn't have 2d block load // support and the load is not from shared memory. if ((chip != "pvc" && chip != "bmg") || readOp.getVectorType().getRank() > 2) { // TODO: add support for OutOfBound access if (isOutOfBounds) return failure(); return lowerToScatteredLoadOp(readOp, rewriter); } // Handle the 1D non-SLM case using load.gather. if (loadedVecTy.getRank() == 1 && !isOutOfBounds) return lowerToScatteredLoadOp(readOp, rewriter); // Perform common data transfer checks. // TODO: Maybe too strict for SLM case. if (failed( storeLoadPreconditions(rewriter, readOp, loadedVecTy, readMemTy))) return failure(); if (isOutOfBounds && !isZeroConstant(readOp.getPadding())) return rewriter.notifyMatchFailure( readOp, "Unsupported non-zero padded out-of-bounds read"); AffineMap readMap = readOp.getPermutationMap(); bool isTransposeLoad = !readMap.isMinorIdentity(); auto elementType = loadedVecTy.getElementType(); SmallVector descShape(loadedVecTy.getShape()); if (isTransposeLoad) { // If load is transposed, then the shape of the source-descriptor // is the opposite from the result-shape. Applying the permutation // to get the reversive shape. auto inversedMap = inversePermutation(readMap); descShape = applyPermutationMap(inversedMap, loadedVecTy.getShape()); loadedVecTy = VectorType::get(descShape, elementType); } auto descType = xegpu::TensorDescType::get( descShape, elementType, /*array_length=*/1, /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global); auto [src, indices] = convertMemrefAndOffsetsToTargetRank( rewriter, loc, readOp.getBase(), getAsOpFoldResult(readOp.getIndices()), loadedVecTy.getRank()); // By default, no specific caching policy is assigned. xegpu::CachePolicyAttr hint = nullptr; xegpu::CreateNdDescOp ndDesc = createNdDescriptor( rewriter, loc, descType, dyn_cast>(src)); Operation *loadedOp = xegpu::LoadNdOp::create(rewriter, loc, loadedVecTy, ndDesc, indices, /*packed=*/nullptr, /*transpose=*/nullptr, /*l1_hint=*/hint, /*l2_hint=*/hint, /*l3_hint=*/hint, /*layout=*/nullptr); if (isTransposeLoad) { // Transposing the loaded vector with a separate vector.transpose // operation auto range = llvm::seq(0, readMap.getResults().size()); SmallVector perm(range.begin(), range.end()); auto permApplied = applyPermutationMap(readMap, perm); loadedOp = vector::TransposeOp::create( rewriter, loc, loadedOp->getResult(0), permApplied); } rewriter.replaceOp(readOp, loadedOp); return success(); } }; struct TransferWriteLowering : public OpRewritePattern { using Base::Base; LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, PatternRewriter &rewriter) const override { Location loc = writeOp.getLoc(); if (failed(transferPreconditions(rewriter, writeOp))) return failure(); // Perform common data transfer checks. VectorType vecTy = writeOp.getVectorType(); auto writeMemTy = cast(writeOp.getShapedType()); // Check if the memref has address space 3 (shared local memory) bool isSharedMemory = xegpu::XeGPUDialect::isSharedMemory(writeMemTy); // For shared local memory (address space 3), use create_mem_desc + // store_matrix if (isSharedMemory) { // Only support 2D case for now. if (vecTy.getRank() != 2) return rewriter.notifyMatchFailure( writeOp, "Only 2D vector stores are supported for SLM"); // Create mem_desc for SLM auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), writeMemTy.getShape(), writeMemTy.getElementType(), /*mem_layout=*/nullptr); auto createMemDescOp = xegpu::CreateMemDescOp::create( rewriter, loc, memDescType, writeOp.getBase()); // Convert indices to OpFoldResult for StoreMatrixOp SmallVector indices = getAsOpFoldResult(writeOp.getIndices()); xegpu::StoreMatrixOp::create(rewriter, loc, writeOp.getVector(), createMemDescOp.getResult(), indices, /*layout=*/nullptr); rewriter.eraseOp(writeOp); return success(); } // TODO:This check needs to be replaced with proper uArch capability check auto chip = xegpu::getChipStr(writeOp); // Lower to scattered store Op if the target HW doesn't have 2d block // store support and the memref is not SLM. if ((chip != "pvc" && chip != "bmg") || writeOp.getVectorType().getRank() > 2) { // TODO: add support for OutOfBound access if (writeOp.hasOutOfBoundsDim()) return failure(); return lowerToScatteredStoreOp(writeOp, rewriter); } if (failed(storeLoadPreconditions(rewriter, writeOp, vecTy, writeMemTy))) return failure(); AffineMap map = writeOp.getPermutationMap(); if (!map.isMinorIdentity()) return rewriter.notifyMatchFailure(writeOp, "Expects identity map"); auto [src, indices] = convertMemrefAndOffsetsToTargetRank( rewriter, loc, writeOp.getBase(), getAsOpFoldResult(writeOp.getIndices()), vecTy.getRank()); auto descType = xegpu::TensorDescType::get( vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(), xegpu::MemorySpace::Global); // By default, no specific caching policy is assigned. xegpu::CachePolicyAttr hint = nullptr; xegpu::CreateNdDescOp ndDesc = createNdDescriptor( rewriter, loc, descType, dyn_cast>(src)); auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc, indices, /*l1_hint=*/hint, /*l2_hint=*/hint, /*l3_hint=*/hint, /*layout=*/nullptr); rewriter.replaceOp(writeOp, storeOp); return success(); } }; struct GatherLowering : public OpRewritePattern { using Base::Base; LogicalResult matchAndRewrite(vector::GatherOp gatherOp, PatternRewriter &rewriter) const override { auto srcTy = dyn_cast(gatherOp.getBase().getType()); if (!srcTy) return rewriter.notifyMatchFailure(gatherOp, "Expects memref source"); Location loc = gatherOp.getLoc(); VectorType vectorType = gatherOp.getVectorType(); auto meta = computeMemrefMeta(gatherOp, rewriter); if (meta.first.empty()) return rewriter.notifyMatchFailure(gatherOp, "Failed to compute strides"); Value localOffsets = computeOffsets(rewriter, gatherOp, meta.first, meta.second); Value flatMemref = memrefToIndexPtr(gatherOp, rewriter); auto xeGatherOp = xegpu::LoadGatherOp::create( rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(), /*chunk_size=*/IntegerAttr{}, /*l1_hint=*/xegpu::CachePolicyAttr{}, /*l2_hint=*/xegpu::CachePolicyAttr{}, /*l3_hint=*/xegpu::CachePolicyAttr{}, /*layout=*/nullptr); auto selectOp = arith::SelectOp::create(rewriter, loc, gatherOp.getMask(), xeGatherOp.getResult(), gatherOp.getPassThru()); rewriter.replaceOp(gatherOp, selectOp.getResult()); return success(); } }; struct ScatterLowering : public OpRewritePattern { using Base::Base; LogicalResult matchAndRewrite(vector::ScatterOp scatterOp, PatternRewriter &rewriter) const override { auto srcTy = dyn_cast(scatterOp.getBase().getType()); if (!srcTy) return rewriter.notifyMatchFailure(scatterOp, "Expects memref source"); Location loc = scatterOp.getLoc(); auto meta = computeMemrefMeta(scatterOp, rewriter); if (meta.first.empty()) return rewriter.notifyMatchFailure(scatterOp, "Failed to compute strides"); Value localOffsets = computeOffsets(rewriter, scatterOp, meta.first, meta.second); Value flatMemref = memrefToIndexPtr(scatterOp, rewriter); xegpu::StoreScatterOp::create(rewriter, loc, scatterOp.getValueToStore(), flatMemref, localOffsets, scatterOp.getMask(), /*chunk_size=*/IntegerAttr{}, /*l1_hint=*/xegpu::CachePolicyAttr{}, /*l2_hint=*/xegpu::CachePolicyAttr{}, /*l3_hint=*/xegpu::CachePolicyAttr{}, /*layout=*/nullptr); rewriter.eraseOp(scatterOp); return success(); } }; struct LoadLowering : public OpRewritePattern { using Base::Base; LogicalResult matchAndRewrite(vector::LoadOp loadOp, PatternRewriter &rewriter) const override { Location loc = loadOp.getLoc(); VectorType vecTy = loadOp.getResult().getType(); MemRefType memTy = loadOp.getBase().getType(); if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy, memTy))) return failure(); // Boundary check is available only for block instructions. bool boundaryCheck = vecTy.getRank() > 1; // By default, no specific caching policy is assigned. xegpu::CachePolicyAttr hint = nullptr; auto [src, indices] = convertMemrefAndOffsetsToTargetRank( rewriter, loc, loadOp.getBase(), getAsOpFoldResult(loadOp.getIndices()), vecTy.getRank()); auto descType = xegpu::TensorDescType::get( vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global); xegpu::CreateNdDescOp ndDesc = createNdDescriptor( rewriter, loc, descType, dyn_cast>(src)); auto loadNdOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices, /*packed=*/nullptr, /*transpose=*/nullptr, /*l1_hint=*/hint, /*l2_hint=*/hint, /*l3_hint=*/hint, /*layout=*/nullptr); rewriter.replaceOp(loadOp, loadNdOp); return success(); } }; struct StoreLowering : public OpRewritePattern { using Base::Base; LogicalResult matchAndRewrite(vector::StoreOp storeOp, PatternRewriter &rewriter) const override { Location loc = storeOp.getLoc(); TypedValue vector = storeOp.getValueToStore(); VectorType vecTy = vector.getType(); MemRefType memTy = storeOp.getBase().getType(); if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy, memTy))) return failure(); // Boundary check is available only for block instructions. bool boundaryCheck = vecTy.getRank() > 1; auto [src, indices] = convertMemrefAndOffsetsToTargetRank( rewriter, loc, storeOp.getBase(), getAsOpFoldResult(storeOp.getIndices()), vecTy.getRank()); auto descType = xegpu::TensorDescType::get( vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global); // By default, no specific caching policy is assigned. xegpu::CachePolicyAttr hint = nullptr; xegpu::CreateNdDescOp ndDesc = createNdDescriptor( rewriter, loc, descType, dyn_cast>(src)); auto storeNdOp = xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, indices, /*l1_hint=*/hint, /*l2_hint=*/hint, /*l3_hint=*/hint, /*layout=*/nullptr); rewriter.replaceOp(storeOp, storeNdOp); return success(); } }; struct ContractionLowering : public OpRewritePattern { using Base::Base; LogicalResult matchAndRewrite(vector::ContractionOp contractOp, PatternRewriter &rewriter) const override { Location loc = contractOp.getLoc(); if (contractOp.getKind() != vector::CombiningKind::ADD) return rewriter.notifyMatchFailure(contractOp, "Expects add combining kind"); TypedValue acc = contractOp.getAcc(); VectorType accType = dyn_cast(acc.getType()); if (!accType || accType.getRank() != 2) return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector"); // Accept only plain 2D data layout. // VNNI packing is applied to DPAS as a separate lowering step. TypedValue lhs = contractOp.getLhs(); TypedValue rhs = contractOp.getRhs(); if (lhs.getType().getRank() != 2 || rhs.getType().getRank() != 2) return rewriter.notifyMatchFailure(contractOp, "Expects lhs and rhs 2D vectors"); if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr())) return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps"); auto dpasOp = xegpu::DpasOp::create(rewriter, loc, TypeRange{contractOp.getResultType()}, ValueRange{lhs, rhs, acc}); rewriter.replaceOp(contractOp, dpasOp); return success(); } }; struct ConvertVectorToXeGPUPass : public impl::ConvertVectorToXeGPUBase { void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateVectorToXeGPUConversionPatterns(patterns); populatePrepareVectorToMMAPatterns(patterns); if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } }; } // namespace void mlir::populateVectorToXeGPUConversionPatterns( RewritePatternSet &patterns) { patterns .add( patterns.getContext()); }