[mlir][Vector] Make createWriteOrMaskedWrite utility (#190967)

Analogous to https://github.com/llvm/llvm-project/pull/89119, make
`createWriteOrMaskedWrite` a vector utility, exposing it for re-use by
downstream users.

This PR mostly moves code and updates documentation, but it also addresses a
`TODO` by reusing `isMaskTriviallyFoldable` in `createReadOrMaskedRead` as
well.

No new tests were added, because the functionality is covered by existing tests.
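
With the utility now exposed from the vector dialect, a downstream pattern can
call it directly. Below is a minimal sketch, assuming the usual VectorUtils.h
include and suitably typed values; every name except the utility itself is a
placeholder:

#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical downstream call site; only `vector::createWriteOrMaskedWrite`
// is the real API.
static Operation *emitWrite(RewriterBase &rewriter, Location loc,
                            Value vecToStore, Value dest) {
  return vector::createWriteOrMaskedWrite(
      rewriter, loc, vecToStore, dest,
      /*writeIndices=*/{},                    // empty -> all indices default to 0
      /*useInBoundsInsteadOfMasking=*/false); // mask when shapes differ
}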

---------

Signed-off-by: Lukas Sommer <lukas.sommer@amd.com>
Lukas Sommer
2026-04-09 12:52:04 +02:00
committed by GitHub
parent 01c5908761
commit 3529ce05e9
3 changed files with 214 additions and 220 deletions


@@ -236,6 +236,19 @@ Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
bool useInBoundsInsteadOfMasking = false,
ArrayRef<bool> inputScalableVecDims = {});
/// Create a TransferWriteOp of `vecToStore` into `dest`.
///
/// If the shape of the vector to write differs from the destination shape,
/// masking is used to avoid out-of-bounds accesses. Set
/// `useInBoundsInsteadOfMasking` to `true` to use the "in_bounds" attribute
/// instead of explicit masks.
/// `writeIndices` specifies the offsets to use. If empty, all indices are set
/// to 0.
Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
Value vecToStore, Value dest,
SmallVector<Value> writeIndices = {},
bool useInBoundsInsteadOfMasking = false);
/// Returns success if `inputVectorSizes` is a valid masking configuration for
/// the given `shape`, i.e., it meets:
/// 1. The number of elements in both arrays is equal.
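
Paired with `createReadOrMaskedRead` declared above, the two utilities cover the
read and write sides of masked vectorization. A hypothetical pairing with
explicit write offsets is sketched below; the helper, its arguments, and the
vector sizes are placeholders, while the two utilities and
`getValueOrCreateConstantIndexOp` are the existing API:

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Read a (possibly masked) vector from `source`, then write it into `dest`
// at `mixedOffsets`; masking is emitted on either side only when needed.
static Operation *readThenWrite(RewriterBase &rewriter, Location loc,
                                Value source, Value dest, Value padValue,
                                ArrayRef<OpFoldResult> mixedOffsets) {
  Value read = vector::createReadOrMaskedRead(
      rewriter, loc, source, /*inputVectorSizes=*/{8, 16}, padValue);
  SmallVector<Value> writeIndices =
      getValueOrCreateConstantIndexOp(rewriter, loc, mixedOffsets);
  return vector::createWriteOrMaskedWrite(
      rewriter, loc, read, dest, writeIndices,
      /*useInBoundsInsteadOfMasking=*/false);
}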


@@ -1574,211 +1574,6 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
return success();
}
/// Determines whether a mask for xfer_write is trivially "all true"
///
/// Given all the inputs required to generate a mask (mask sizes and shapes),
/// and an xfer_write operation (write indices and the destination tensor
/// shape), determines whether the corresponding mask would be trivially
/// foldable (i.e., trivially "all true").
///
/// Use this method to avoid generating spurious masks and relying on
/// vectorization post-processing to remove them.
///
/// Pre-conditions for a mask to be trivially foldable:
/// * All involved shapes (mask + destination tensor) are static.
/// * All write indices are constant.
/// * All mask sizes are constant (including `arith.constant`).
///
/// If the pre-conditions are met, the method checks for each destination
/// dimension `d`:
/// (1) maskShape[d] <= destDimSize[rankDiff + d]
/// (2) writeIndex[d] + maskSize[d] <= destDimSize[rankDiff + d]
///
/// rankDiff = rank(dest) - rank(mask).
///
/// This method takes a conservative view: it may return false even if the mask
/// is technically foldable.
///
/// EXAMPLE 1 (trivially foldable, all shapes match, mask sizes match the shape
/// of the dest tensor):
/// %c0 = arith.constant 0 : index
/// %mask = vector.create_mask 5, 1
/// vector.mask %mask {
/// vector.transfer_write %vecToStore_1, %dest[%c0, %c0]
/// {in_bounds = [true, true]}
/// : vector<5x1xi32>, tensor<5x1xi32>
/// }
///
/// EXAMPLE 2 (not trivially foldable - vector shape exceeds the tensor shape,
/// mask is required to avoid out-of-bounds write):
/// %c0 = arith.constant 0 : index
/// %mask = vector.create_mask 5, 1
/// vector.mask %mask {
/// vector.transfer_write %vecToStore_2, %dest[%c0, %c0]
/// {in_bounds = [true, true]}
/// : vector<8x1xi32>, tensor<5x1xi32>
/// }
///
/// TODO: Re-use in createReadOrMaskedRead
static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
SmallVector<Value> &writeIdxs,
ArrayRef<int64_t> destShape,
ArrayRef<int64_t> maskShape) {
// Masking is unavoidable in the case of dynamic tensors.
if (ShapedType::isDynamicShape(destShape))
return false;
// Collect all constant mask sizes.
SmallVector<int64_t, 4> cstMaskSizes;
for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
if (auto intSize = getConstantIntValue(dimSize)) {
cstMaskSizes.push_back(*intSize);
}
}
// If any of the mask sizes is non-constant, bail out.
if (cstMaskSizes.size() != maskShape.size())
return false;
// Collect all constant write indices.
SmallVector<int64_t, 4> cstWriteIdxs;
for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
APSInt intVal;
if (matchPattern(idx, m_ConstantInt(&intVal))) {
cstWriteIdxs.push_back(intVal.getSExtValue());
}
}
// If any of the write indices is non-constant, bail out.
if (cstWriteIdxs.size() != destShape.size())
return false;
// Go over all destination dims and check (1) and (2). Take into account that:
// * The number of mask sizes will match the rank of the vector to store.
// This could be lower than the rank of the destination tensor.
// * Mask sizes could be larger than the corresponding mask shape (hence
// `clamp`).
// TODO: The 2nd item should be rejected by the verifier.
int64_t rankDiff = destShape.size() - cstMaskSizes.size();
for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
if (/*(1)*/ maskShape[i] > destShape[rankDiff + i] ||
/*(2)*/ destShape[rankDiff + i] <
(std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
cstWriteIdxs[i]))
return false;
}
return true;
}
/// Creates an optionally masked TransferWriteOp
///
/// Generates the following operation:
/// %res = vector.transfer_write %vecToStore into %dest
///
/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
///
/// %mask = vector.create_mask(%destShape) : %vecToStoreShape
/// %res = vector.mask %mask {
/// vector.transfer_write %vecToStore into %dest
/// }
///
/// The mask shape is identical to `vecToStore` (with the element type ==
/// i1), and the mask values are based on the shape of the `dest` tensor.
///
/// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
/// is used instead of masking:
///
/// in_bounds_flags = (...)
/// %res = vector.transfer_write %vecToStore into %dest
/// {in_bounds = in_bounds_flags}
///
/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
/// are set to 0.
static Operation *
createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
Value dest, SmallVector<Value> writeIndices = {},
bool useInBoundsInsteadOfMasking = false) {
ShapedType destType = cast<ShapedType>(dest.getType());
int64_t destRank = destType.getRank();
auto destShape = destType.getShape();
VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
int64_t vecToStoreRank = vecToStoreType.getRank();
auto vecToStoreShape = vecToStoreType.getShape();
// Compute the in_bounds attribute
SmallVector<bool> inBoundsVal(vecToStoreRank, true);
if (useInBoundsInsteadOfMasking) {
// Update the inBounds attribute.
// FIXME: This computation is too weak - it ignores the write indices.
for (unsigned i = 0; i < vecToStoreRank; i++)
inBoundsVal[i] =
(destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
ShapedType::isStatic(destShape[destRank - vecToStoreRank + i]);
}
// If missing, initialize the write indices to 0.
bool useDefaultWriteIdxs = writeIndices.empty();
assert((useDefaultWriteIdxs ||
writeIndices.size() == static_cast<size_t>(destRank)) &&
"Invalid number of write indices!");
if (writeIndices.empty()) {
auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
writeIndices.assign(destRank, zero);
}
// Generate the xfer_write Op
Operation *write = vector::TransferWriteOp::create(builder, loc,
/*vector=*/vecToStore,
/*source=*/dest,
/*indices=*/writeIndices,
/*inBounds=*/inBoundsVal);
// If masking is disabled, exit.
if (useInBoundsInsteadOfMasking)
return write;
// Check if masking is needed. If not, exit.
if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
return write;
// Compute the mask and mask the write Op.
auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
vecToStoreType.getScalableDims());
SmallVector<OpFoldResult> destSizes =
isa<MemRefType>(dest.getType())
? memref::getMixedSizes(builder, loc, dest)
: tensor::getMixedSizes(builder, loc, dest);
// Compute sizes for write-mask
SmallVector<OpFoldResult> maskSizes;
if (useDefaultWriteIdxs) {
maskSizes = SmallVector<OpFoldResult>(destSizes.end() - vecToStoreRank,
destSizes.end());
} else {
size_t diff = destShape.size() - vecToStoreRank;
for (int64_t idx = 0; idx < vecToStoreRank; idx++) {
auto value =
getValueOrCreateConstantIndexOp(builder, loc, destSizes[diff + idx]);
auto neg =
builder.createOrFold<arith::SubIOp>(loc, value, writeIndices[idx]);
maskSizes.push_back(OpFoldResult(neg));
}
}
if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
vecToStoreShape))
return write;
Value maskForWrite =
builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
return mlir::vector::maskOperation(builder, write, maskForWrite);
}
/// Given the re-associations, "collapses" the input Vector type
///
/// This is similar to CollapseShapeOp::inferCollapsedType with two notable
@@ -1929,7 +1724,7 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
rewriter, loc, shapeCastOp.getResult(), destPermutation);
// Create TransferWriteOp.
Operation *write = createWriteOrMaskedWrite(
Operation *write = vector::createWriteOrMaskedWrite(
rewriter, loc, transposeOp.getResult(), packOp.getDest());
newResults.push_back(write->getResult(0));
return success();
@@ -2025,7 +1820,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
rewriter, loc, collapsedVecType, transposeOp->getResult(0));
// -- Generate the write operation --
Operation *write = createWriteOrMaskedWrite(
Operation *write = vector::createWriteOrMaskedWrite(
rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
/*writeIndices=*/{}, useInBoundsInsteadOfMasking);
@@ -2061,7 +1856,8 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
// Create Xfer write Op
Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
padOp.getResultType().getElementType());
Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
Operation *write =
vector::createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
newResults.push_back(write->getResult(0));
return success();
}
@@ -2279,7 +2075,7 @@ vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
contractOp = state.maskOperation(rewriter, contractOp, linalgOp);
// Store result.
Operation *write = createWriteOrMaskedWrite(
Operation *write = vector::createWriteOrMaskedWrite(
rewriter, loc, contractOp->getResult(0), outOperand->get());
// Finalize.
@@ -3207,8 +3003,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
auto writeIndices =
getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
Operation *write =
createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
writeIndices, inputVectorSizes.empty());
vector::createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
writeIndices, inputVectorSizes.empty());
// 4. Finalize
newResults.push_back(write->getResult(0));


@@ -15,6 +15,7 @@
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -22,6 +23,7 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
@@ -309,6 +311,101 @@ bool vector::isLinearizableVector(VectorType type) {
return (type.getRank() > 1) && (type.getNumScalableDims() <= 1);
}
/// Determines whether a mask for xfer_read/write is trivially "all true"
///
/// Given all the inputs required to generate a mask (mask sizes and shapes),
/// and an xfer_read/write operation (indices and the source/destination tensor
/// shape), determines whether the corresponding mask would be trivially
/// foldable (i.e., trivially "all true").
///
/// Use this method to avoid generating spurious masks and relying on
/// vectorization post-processing to remove them.
///
/// Pre-conditions for a mask to be trivially foldable:
/// * All involved shapes (mask + destination tensor) are static.
/// * All indices are constant.
/// * All mask sizes are constant (including `arith.constant`).
///
/// If the pre-conditions are met, the method checks for each destination
/// dimension `d`:
/// (1) maskShape[d] <= destDimSize[rankDiff + d]
/// (2) index[d] + maskSize[d] <= destDimSize[rankDiff + d]
///
/// rankDiff = rank(dest) - rank(mask).
///
/// This method takes a conservative view: it may return false even if the mask
/// is technically foldable.
///
/// EXAMPLE 1 (trivially foldable, all shapes match, mask sizes match the shape
/// of the dest tensor):
/// %c0 = arith.constant 0 : index
/// %mask = vector.create_mask 5, 1
/// vector.mask %mask {
/// vector.transfer_write %vecToStore_1, %dest[%c0, %c0]
/// {in_bounds = [true, true]}
/// : vector<5x1xi32>, tensor<5x1xi32>
/// }
///
/// EXAMPLE 2 (not trivially foldable - vector shape exceeds the tensor shape,
/// mask is required to avoid out-of-bounds write):
/// %c0 = arith.constant 0 : index
/// %mask = vector.create_mask 5, 1
/// vector.mask %mask {
/// vector.transfer_write %vecToStore_2, %dest[%c0, %c0]
/// {in_bounds = [true, true]}
/// : vector<8x1xi32>, tensor<5x1xi32>
/// }
static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
SmallVector<Value> &indices,
ArrayRef<int64_t> baseShape,
ArrayRef<int64_t> maskShape) {
// Masking is unavoidable in the case of dynamic tensors.
if (ShapedType::isDynamicShape(baseShape))
return false;
// Collect all constant mask sizes.
SmallVector<int64_t, 4> cstMaskSizes;
for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
if (auto intSize = getConstantIntValue(dimSize)) {
cstMaskSizes.push_back(*intSize);
}
}
// If any of the mask sizes is non-constant, bail out.
if (cstMaskSizes.size() != maskShape.size())
return false;
// Collect all constant indices.
SmallVector<int64_t, 4> cstIndices;
for (auto [i, idx] : llvm::enumerate(indices)) {
APSInt intVal;
if (matchPattern(idx, m_ConstantInt(&intVal))) {
cstIndices.push_back(intVal.getSExtValue());
}
}
// If any of the indices is non-constant, bail out.
if (cstIndices.size() != baseShape.size())
return false;
// Go over all destination dims and check (1) and (2). Take into account that:
// * The number of mask sizes will match the rank of the vector to
// load/store. This could be lower than the rank of the destination tensor.
// * Mask sizes could be larger than the corresponding mask shape (hence
// `clamp`).
// TODO: The 2nd item should be rejected by the verifier.
int64_t rankDiff = baseShape.size() - cstMaskSizes.size();
for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
if (/*(1)*/ maskShape[i] > baseShape[rankDiff + i] ||
/*(2)*/ baseShape[rankDiff + i] <
(std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
cstIndices[i]))
return false;
}
return true;
}
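
As a rough, self-contained illustration of conditions (1) and (2), the sketch
below replays the check on the shapes from EXAMPLE 1 and EXAMPLE 2 above; it
mirrors the logic with plain standard-library types and placeholder names, and
is not the MLIR helper itself:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Placeholder re-implementation of the foldability check, for illustration.
static bool foldableSketch(const std::vector<int64_t> &maskSizes,
                           const std::vector<int64_t> &indices,
                           const std::vector<int64_t> &baseShape,
                           const std::vector<int64_t> &maskShape) {
  size_t rankDiff = baseShape.size() - maskShape.size();
  for (size_t d = 0; d < maskShape.size(); ++d) {
    if (/*(1)*/ maskShape[d] > baseShape[rankDiff + d] ||
        /*(2)*/ baseShape[rankDiff + d] <
                    std::clamp(maskSizes[d], int64_t(0), maskShape[d]) +
                        indices[d])
      return false;
  }
  return true;
}

int main() {
  // EXAMPLE 1: vector<5x1> written at (0, 0) into tensor<5x1> -> mask folds.
  assert(foldableSketch({5, 1}, {0, 0}, {5, 1}, {5, 1}));
  // EXAMPLE 2: vector<8x1> written at (0, 0) into tensor<5x1> -> mask needed.
  assert(!foldableSketch({5, 1}, {0, 0}, {5, 1}, {8, 1}));
  return 0;
}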
Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
Value source,
ArrayRef<int64_t> inputVectorSizes,
@@ -353,22 +450,27 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
inBoundsVal[i] = (sourceShape[i] == vecToReadShape[i]) &&
ShapedType::isStatic(sourceShape[i]);
}
auto transferReadOp = vector::TransferReadOp::create(
builder, loc,
/*vectorType=*/vecToReadTy,
/*source=*/source,
/*indices=*/SmallVector<Value>(vecToReadRank, zero),
/*padding=*/padValue,
/*inBounds=*/inBoundsVal);
SmallVector<Value> indices(vecToReadRank, zero);
auto transferReadOp =
vector::TransferReadOp::create(builder, loc,
/*vectorType=*/vecToReadTy,
/*source=*/source,
/*indices=*/indices,
/*padding=*/padValue,
/*inBounds=*/inBoundsVal);
if (llvm::equal(vecToReadTy.getShape(), sourceShape) ||
useInBoundsInsteadOfMasking)
if (useInBoundsInsteadOfMasking)
return transferReadOp;
SmallVector<OpFoldResult> mixedSourceDims =
isa<MemRefType>(source.getType())
? memref::getMixedSizes(builder, loc, source)
: tensor::getMixedSizes(builder, loc, source);
if (isMaskTriviallyFoldable(mixedSourceDims, indices, sourceShape,
vecToReadShape))
return transferReadOp;
auto maskType = vecToReadTy.cloneWith(/*shape=*/{}, builder.getI1Type());
Value mask =
vector::CreateMaskOp::create(builder, loc, maskType, mixedSourceDims);
@@ -376,6 +478,89 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
->getResult(0);
}
Operation *vector::createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
Value vecToStore, Value dest,
SmallVector<Value> writeIndices,
bool useInBoundsInsteadOfMasking) {
ShapedType destType = cast<ShapedType>(dest.getType());
int64_t destRank = destType.getRank();
auto destShape = destType.getShape();
VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
int64_t vecToStoreRank = vecToStoreType.getRank();
auto vecToStoreShape = vecToStoreType.getShape();
// Compute the in_bounds attribute
SmallVector<bool> inBoundsVal(vecToStoreRank, true);
if (useInBoundsInsteadOfMasking) {
// Update the inBounds attribute.
// FIXME: This computation is too weak - it ignores the write indices.
for (unsigned i = 0; i < vecToStoreRank; i++)
inBoundsVal[i] =
(destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
ShapedType::isStatic(destShape[destRank - vecToStoreRank + i]);
}
// If missing, initialize the write indices to 0.
bool useDefaultWriteIdxs = writeIndices.empty();
assert((useDefaultWriteIdxs ||
writeIndices.size() == static_cast<size_t>(destRank)) &&
"Invalid number of write indices!");
if (useDefaultWriteIdxs) {
auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
writeIndices.assign(destRank, zero);
}
// Generate the xfer_write Op
Operation *write = vector::TransferWriteOp::create(builder, loc,
/*vector=*/vecToStore,
/*dest=*/dest,
/*indices=*/writeIndices,
/*inBounds=*/inBoundsVal);
// If masking is disabled, exit.
if (useInBoundsInsteadOfMasking)
return write;
// Check if masking is needed. If not, exit.
if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
return write;
// Compute the mask and mask the write Op.
auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
vecToStoreType.getScalableDims());
SmallVector<OpFoldResult> destSizes =
isa<MemRefType>(dest.getType())
? memref::getMixedSizes(builder, loc, dest)
: tensor::getMixedSizes(builder, loc, dest);
// Compute sizes for write-mask
SmallVector<OpFoldResult> maskSizes;
if (useDefaultWriteIdxs) {
maskSizes = SmallVector<OpFoldResult>(destSizes.end() - vecToStoreRank,
destSizes.end());
} else {
size_t diff = destShape.size() - vecToStoreRank;
for (int64_t idx = 0; idx < vecToStoreRank; idx++) {
auto value =
getValueOrCreateConstantIndexOp(builder, loc, destSizes[diff + idx]);
auto neg =
builder.createOrFold<arith::SubIOp>(loc, value, writeIndices[idx]);
maskSizes.push_back(OpFoldResult(neg));
}
}
if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
vecToStoreShape))
return write;
Value maskForWrite =
builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
return mlir::vector::maskOperation(builder, write, maskForWrite);
}
LogicalResult
vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
ArrayRef<int64_t> inputVectorSizes) {