[mlir][Vector] Make createWriteOrMaskedWrite utility (#190967)
Analogous to https://github.com/llvm/llvm-project/pull/89119, make `createWriteOrMaskedWrite` a vector utility, exposing it for re-use by downstream users. This PR mostly moves code and updates documentation, but it also addresses a `TODO` by re-using `isMaskTriviallyFoldable` in `createReadOrMaskedRead` as well. No new tests were added, because the functionality is covered by existing tests.

---------

Signed-off-by: Lukas Sommer <lukas.sommer@amd.com>
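For downstream users, the read and write utilities compose naturally. A minimal sketch (the helper `copyViaVector` and its arguments are hypothetical, and `createReadOrMaskedRead`'s `inputVectorSizes`/`padValue` parameters are assumed from its existing signature):

#include "mlir/Dialect/Vector/Utils/VectorUtils.h"

using namespace mlir;

// Hypothetical downstream helper: read `source` into a vector of shape
// `vectorSizes`, then write it into `dest`. Both utilities emit a mask only
// when the vector shape is not covered by the static source/dest shape, and
// they fold away masks that would be trivially "all true".
static Operation *copyViaVector(OpBuilder &builder, Location loc, Value source,
                                Value dest, ArrayRef<int64_t> vectorSizes,
                                Value padValue) {
  // Possibly masked read (exposed by PR #89119).
  Value read = vector::createReadOrMaskedRead(builder, loc, source,
                                              vectorSizes, padValue);
  // Possibly masked write (exposed by this PR). Empty `writeIndices` means
  // all indices default to 0.
  return vector::createWriteOrMaskedWrite(builder, loc, read, dest);
}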
@@ -236,6 +236,19 @@ Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
                              bool useInBoundsInsteadOfMasking = false,
                              ArrayRef<bool> inputScalableVecDims = {});
 
+/// Creates a TransferWriteOp to write `vecToStore` into `dest`.
+///
+/// If the shape of the vector to write differs from the destination shape,
+/// masking is used to avoid out-of-bounds accesses. Set
+/// `useInBoundsInsteadOfMasking` to `true` to use the "in_bounds" attribute
+/// instead of explicit masks.
+/// `writeIndices` specifies the offsets to use. If empty, all indices are set
+/// to 0.
+Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
+                                    Value vecToStore, Value dest,
+                                    SmallVector<Value> writeIndices = {},
+                                    bool useInBoundsInsteadOfMasking = false);
+
 /// Returns success if `inputVectorSizes` is a valid masking configuration for
 /// given `shape`, i.e., it meets:
 /// 1. The numbers of elements in both arrays are equal.
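A sketch of how a caller might exercise the declaration's modes (the shapes in the comments are illustrative, not mandated by the API):

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"

using namespace mlir;

// Illustrative only: assume `vecToStore` is vector<8x1xi32> and `dest` is
// tensor<5x1xi32>.
static void demoWriteModes(OpBuilder &builder, Location loc, Value vecToStore,
                           Value dest) {
  // Default (masking) mode, indices all 0: the vector overhangs the static
  // destination shape, so a vector.create_mask + vector.mask is generated.
  Operation *masked =
      vector::createWriteOrMaskedWrite(builder, loc, vecToStore, dest);

  // in_bounds mode: no mask; per-dimension in_bounds flags are derived from
  // the static destination shape instead.
  Operation *viaInBounds = vector::createWriteOrMaskedWrite(
      builder, loc, vecToStore, dest, /*writeIndices=*/{},
      /*useInBoundsInsteadOfMasking=*/true);

  // Explicit offsets: one index Value per destination dimension.
  Value c2 = arith::ConstantIndexOp::create(builder, loc, 2);
  Value c0 = arith::ConstantIndexOp::create(builder, loc, 0);
  Operation *offset = vector::createWriteOrMaskedWrite(
      builder, loc, vecToStore, dest, /*writeIndices=*/{c2, c0});
  (void)masked;
  (void)viaInBounds;
  (void)offset;
}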
@@ -1574,211 +1574,6 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
   return success();
 }
 
-/// Determines whether a mask for xfer_write is trivially "all true"
-///
-/// Given all the inputs required to generate a mask (mask sizes and shapes),
-/// and an xfer_write operation (write indices and the destination tensor
-/// shape), determines whether the corresponding mask would be trivially
-/// foldable (i.e., trivially "all true").
-///
-/// Use this method to avoid generating spurious masks and relaying on
-/// vectorization post-processing to remove them.
-///
-/// Pre-conditions for a mask to be trivially foldable:
-///   * All involved shapes (mask + destination tensor) are static.
-///   * All write indices are constant.
-///   * All mask sizes are constant (including `arith.constant`).
-///
-/// If the pre-conditions are met, the method checks for each destination
-/// dimension `d`:
-///   (1) destDimSize[rankDiff + d] <= maskShape[d]
-///   (2) destDimSize[rankDiff + d] <= writeIndex[d] + maskSize[d]
-///
-/// rankDiff = rank(dest) - rank(mask).
-///
-/// This method takes a conservative view: it may return false even if the mask
-/// is technically foldable.
-///
-/// EXAMPLE 1 (trivially foldable, all shapes match, mask sizes match the shape
-/// of the dest tensor):
-///   %c0 = arith.constant 0 : index
-///   %mask = vector.create_mask 5, 1
-///   vector.mask %mask {
-///     vector.transfer_write %vecToStore_1, %dest[%c0, %c0]
-///       {in_bounds = [true, true]}
-///       : vector<5x1xi32>, tensor<5x1xi32>
-///   }
-///
-/// EXAMPLE 2 (not trivially foldable - vector shape exceeds the tensor shape,
-/// mask is required to avoid out-of-bounds write):
-///   %c0 = arith.constant 0 : index
-///   %mask = vector.create_mask 5, 1
-///   vector.mask %mask {
-///     vector.transfer_write %vecToStore_2, %dest[%c0, %c0]
-///       {in_bounds = [true, true]}
-///       : vector<8x1xi32>, tensor<5x1xi32>
-///   }
-///
-/// TODO: Re-use in createReadOrMaskedRead
-static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
-                                    SmallVector<Value> &writeIdxs,
-                                    ArrayRef<int64_t> destShape,
-                                    ArrayRef<int64_t> maskShape) {
-  // Masking is unavoidable in the case of dynamic tensors.
-  if (ShapedType::isDynamicShape(destShape))
-    return false;
-
-  // Collect all constant mask sizes.
-  SmallVector<int64_t, 4> cstMaskSizes;
-  for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
-    if (auto intSize = getConstantIntValue(dimSize)) {
-      cstMaskSizes.push_back(*intSize);
-    }
-  }
-
-  // If any of the mask sizes is non-constant, bail out.
-  if (cstMaskSizes.size() != maskShape.size())
-    return false;
-
-  // Collect all constant write indices.
-  SmallVector<int64_t, 4> cstWriteIdxs;
-  for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
-    APSInt intVal;
-    if (matchPattern(idx, m_ConstantInt(&intVal))) {
-      cstWriteIdxs.push_back(intVal.getSExtValue());
-    }
-  }
-
-  // If any of the write indices is non-constant, bail out.
-  if (cstWriteIdxs.size() != destShape.size())
-    return false;
-
-  // Go over all destination dims and check (1) and (2). Take into account that:
-  //  * The number of mask sizes will match the rank of the vector to store.
-  //    This could be lower than the rank of the destination tensor.
-  //  * Mask sizes could be larger than the corresponding mask shape (hence
-  //    `clamp`).
-  // TODO: The 2nd item should be rejected by the verifier.
-  int64_t rankDiff = destShape.size() - cstMaskSizes.size();
-  for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
-    if (/*(1)*/ maskShape[i] > destShape[rankDiff + i] ||
-        /*(2)*/ destShape[rankDiff + i] <
-                    (std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
-                     cstWriteIdxs[i]))
-      return false;
-  }
-
-  return true;
-}
-
-/// Creates an optionally masked TransferWriteOp
-///
-/// Generates the following operation:
-///   %res = vector.transfer_write %vecToStore into %dest
-///
-/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
-///
-///   %mask = vector.create_mask(%destShape) : %vecToStoreShape
-///   %res = vector.mask %mask {
-///     vector.transfer_write %vecToStore into %dest
-///   }
-///
-/// The mask shape is identical to `vecToStore` (with the element type ==
-/// i1), and the mask values are based on the shape of the `dest` tensor.
-///
-/// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
-/// is used instead of masking:
-///
-///   %write = vector.transfer_write %vecToStore into %dest
-///   in_bounds_flags = (...)
-///   %res = vector.transfer_write %input into %dest
-///       {in_bounds = in_bounds_flags}
-///
-/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
-/// are set to 0.
-static Operation *
-createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
-                         Value dest, SmallVector<Value> writeIndices = {},
-                         bool useInBoundsInsteadOfMasking = false) {
-
-  ShapedType destType = cast<ShapedType>(dest.getType());
-  int64_t destRank = destType.getRank();
-  auto destShape = destType.getShape();
-
-  VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
-  int64_t vecToStoreRank = vecToStoreType.getRank();
-  auto vecToStoreShape = vecToStoreType.getShape();
-
-  // Compute the in_bounds attribute
-  SmallVector<bool> inBoundsVal(vecToStoreRank, true);
-  if (useInBoundsInsteadOfMasking) {
-    // Update the inBounds attribute.
-    // FIXME: This computation is too weak - it ignores the write indices.
-    for (unsigned i = 0; i < vecToStoreRank; i++)
-      inBoundsVal[i] =
-          (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
-          ShapedType::isStatic(destShape[destRank - vecToStoreRank + i]);
-  }
-
-  // If missing, initialize the write indices to 0.
-  bool useDefaultWriteIdxs = writeIndices.empty();
-  assert((useDefaultWriteIdxs ||
-          writeIndices.size() == static_cast<size_t>(destRank)) &&
-         "Invalid number of write indices!");
-  if (writeIndices.empty()) {
-    auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
-    writeIndices.assign(destRank, zero);
-  }
-
-  // Generate the xfer_write Op
-  Operation *write = vector::TransferWriteOp::create(builder, loc,
-                                                     /*vector=*/vecToStore,
-                                                     /*source=*/dest,
-                                                     /*indices=*/writeIndices,
-                                                     /*inBounds=*/inBoundsVal);
-
-  // If masking is disabled, exit.
-  if (useInBoundsInsteadOfMasking)
-    return write;
-
-  // Check if masking is needed. If not, exit.
-  if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
-    return write;
-
-  // Compute the mask and mask the write Op.
-  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
-                                       vecToStoreType.getScalableDims());
-
-  SmallVector<OpFoldResult> destSizes =
-      isa<MemRefType>(dest.getType())
-          ? memref::getMixedSizes(builder, loc, dest)
-          : tensor::getMixedSizes(builder, loc, dest);
-
-  // Compute sizes for write-mask
-  SmallVector<OpFoldResult> maskSizes;
-  if (useDefaultWriteIdxs) {
-    maskSizes = SmallVector<OpFoldResult>(destSizes.end() - vecToStoreRank,
-                                          destSizes.end());
-  } else {
-    size_t diff = destShape.size() - vecToStoreRank;
-    for (int64_t idx = 0; idx < vecToStoreRank; idx++) {
-      auto value =
-          getValueOrCreateConstantIndexOp(builder, loc, destSizes[diff + idx]);
-      auto neg =
-          builder.createOrFold<arith::SubIOp>(loc, value, writeIndices[idx]);
-      maskSizes.push_back(OpFoldResult(neg));
-    }
-  }
-
-  if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
-                              vecToStoreShape))
-    return write;
-
-  Value maskForWrite =
-      builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
-  return mlir::vector::maskOperation(builder, write, maskForWrite);
-}
-
 /// Given the re-associations, "collapses" the input Vector type
 ///
 /// This is similar to CollapseShapeOp::inferCollapsedType with two notable
@@ -1929,7 +1724,7 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
       rewriter, loc, shapeCastOp.getResult(), destPermutation);
 
   // Create TransferWriteOp.
-  Operation *write = createWriteOrMaskedWrite(
+  Operation *write = vector::createWriteOrMaskedWrite(
       rewriter, loc, transposeOp.getResult(), packOp.getDest());
   newResults.push_back(write->getResult(0));
   return success();
@@ -2025,7 +1820,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
       rewriter, loc, collapsedVecType, transposeOp->getResult(0));
 
   // -- Generate the write operation --
-  Operation *write = createWriteOrMaskedWrite(
+  Operation *write = vector::createWriteOrMaskedWrite(
       rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
      /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
@@ -2061,7 +1856,8 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   // Create Xfer write Op
   Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
                                        padOp.getResultType().getElementType());
-  Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
+  Operation *write =
+      vector::createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -2279,7 +2075,7 @@ vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
   contractOp = state.maskOperation(rewriter, contractOp, linalgOp);
 
   // Store result.
-  Operation *write = createWriteOrMaskedWrite(
+  Operation *write = vector::createWriteOrMaskedWrite(
       rewriter, loc, contractOp->getResult(0), outOperand->get());
 
   // Finalize.
@@ -3207,8 +3003,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   auto writeIndices =
       getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
   Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
-                               writeIndices, inputVectorSizes.empty());
+      vector::createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
+                                       writeIndices, inputVectorSizes.empty());
 
   // 4. Finalize
   newResults.push_back(write->getResult(0));
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -22,6 +23,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Support/LLVM.h"
@@ -309,6 +311,101 @@ bool vector::isLinearizableVector(VectorType type) {
   return (type.getRank() > 1) && (type.getNumScalableDims() <= 1);
 }
 
+/// Determines whether a mask for xfer_read/write is trivially "all true"
+///
+/// Given all the inputs required to generate a mask (mask sizes and shapes),
+/// and an xfer_read/write operation (indices and the source/destination tensor
+/// shape), determines whether the corresponding mask would be trivially
+/// foldable (i.e., trivially "all true").
+///
+/// Use this method to avoid generating spurious masks and relying on
+/// vectorization post-processing to remove them.
+///
+/// Pre-conditions for a mask to be trivially foldable:
+///   * All involved shapes (mask + destination tensor) are static.
+///   * All indices are constant.
+///   * All mask sizes are constant (including `arith.constant`).
+///
+/// If the pre-conditions are met, the method checks for each destination
+/// dimension `d`:
+///   (1) destDimSize[rankDiff + d] <= maskShape[d]
+///   (2) destDimSize[rankDiff + d] <= index[d] + maskSize[d]
+///
+/// rankDiff = rank(dest) - rank(mask).
+///
+/// This method takes a conservative view: it may return false even if the mask
+/// is technically foldable.
+///
+/// EXAMPLE 1 (trivially foldable, all shapes match, mask sizes match the shape
+/// of the dest tensor):
+///   %c0 = arith.constant 0 : index
+///   %mask = vector.create_mask 5, 1
+///   vector.mask %mask {
+///     vector.transfer_write %vecToStore_1, %dest[%c0, %c0]
+///       {in_bounds = [true, true]}
+///       : vector<5x1xi32>, tensor<5x1xi32>
+///   }
+///
+/// EXAMPLE 2 (not trivially foldable - vector shape exceeds the tensor shape,
+/// mask is required to avoid out-of-bounds write):
+///   %c0 = arith.constant 0 : index
+///   %mask = vector.create_mask 5, 1
+///   vector.mask %mask {
+///     vector.transfer_write %vecToStore_2, %dest[%c0, %c0]
+///       {in_bounds = [true, true]}
+///       : vector<8x1xi32>, tensor<5x1xi32>
+///   }
+static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
+                                    SmallVector<Value> &indices,
+                                    ArrayRef<int64_t> baseShape,
+                                    ArrayRef<int64_t> maskShape) {
+  // Masking is unavoidable in the case of dynamic tensors.
+  if (ShapedType::isDynamicShape(baseShape))
+    return false;
+
+  // Collect all constant mask sizes.
+  SmallVector<int64_t, 4> cstMaskSizes;
+  for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
+    if (auto intSize = getConstantIntValue(dimSize)) {
+      cstMaskSizes.push_back(*intSize);
+    }
+  }
+
+  // If any of the mask sizes is non-constant, bail out.
+  if (cstMaskSizes.size() != maskShape.size())
+    return false;
+
+  // Collect all constant indices.
+  SmallVector<int64_t, 4> cstIndices;
+  for (auto [i, idx] : llvm::enumerate(indices)) {
+    APSInt intVal;
+    if (matchPattern(idx, m_ConstantInt(&intVal))) {
+      cstIndices.push_back(intVal.getSExtValue());
+    }
+  }
+
+  // If any of the indices is non-constant, bail out.
+  if (cstIndices.size() != baseShape.size())
+    return false;
+
+  // Go over all destination dims and check (1) and (2). Take into account that:
+  //  * The number of mask sizes will match the rank of the vector to
+  //    load/store. This could be lower than the rank of the destination tensor.
+  //  * Mask sizes could be larger than the corresponding mask shape (hence
+  //    `clamp`).
+  // TODO: The 2nd item should be rejected by the verifier.
+  int64_t rankDiff = baseShape.size() - cstMaskSizes.size();
+  for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
+    if (/*(1)*/ maskShape[i] > baseShape[rankDiff + i] ||
+        /*(2)*/ baseShape[rankDiff + i] <
+                    (std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
+                     cstIndices[i]))
+      return false;
+  }
+
+  return true;
+}
+
 Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
                                      Value source,
                                      ArrayRef<int64_t> inputVectorSizes,
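To make conditions (1) and (2) concrete, here is a standalone, MLIR-free re-statement of the per-dimension check, run on the numbers from the two examples above (the helper name is ours; it assumes all shapes, sizes, and indices have already been proven constant):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Simplified mirror of the per-dimension checks (1) and (2) above.
static bool triviallyFoldable(const std::vector<int64_t> &maskSizes,
                              const std::vector<int64_t> &indices,
                              const std::vector<int64_t> &baseShape,
                              const std::vector<int64_t> &maskShape) {
  size_t rankDiff = baseShape.size() - maskSizes.size();
  for (size_t i = 0; i < maskSizes.size(); ++i) {
    if (/*(1)*/ maskShape[i] > baseShape[rankDiff + i] ||
        /*(2)*/ baseShape[rankDiff + i] <
                    std::clamp(maskSizes[i], int64_t(0), maskShape[i]) +
                        indices[i])
      return false;
  }
  return true;
}

int main() {
  // EXAMPLE 1: vector<5x1> into tensor<5x1>, mask sizes 5, 1, indices 0, 0.
  assert(triviallyFoldable({5, 1}, {0, 0}, {5, 1}, {5, 1}));
  // EXAMPLE 2: vector<8x1> into tensor<5x1> -- condition (1) fails for d = 0
  // (maskShape[0] = 8 > destDimSize[0] = 5), so the mask must stay.
  assert(!triviallyFoldable({5, 1}, {0, 0}, {5, 1}, {8, 1}));
}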
@@ -353,22 +450,27 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
     inBoundsVal[i] = (sourceShape[i] == vecToReadShape[i]) &&
                      ShapedType::isStatic(sourceShape[i]);
   }
-  auto transferReadOp = vector::TransferReadOp::create(
-      builder, loc,
-      /*vectorType=*/vecToReadTy,
-      /*source=*/source,
-      /*indices=*/SmallVector<Value>(vecToReadRank, zero),
-      /*padding=*/padValue,
-      /*inBounds=*/inBoundsVal);
+  SmallVector<Value> indices(vecToReadRank, zero);
+  auto transferReadOp =
+      vector::TransferReadOp::create(builder, loc,
+                                     /*vectorType=*/vecToReadTy,
+                                     /*source=*/source,
+                                     /*indices=*/indices,
+                                     /*padding=*/padValue,
+                                     /*inBounds=*/inBoundsVal);
 
-  if (llvm::equal(vecToReadTy.getShape(), sourceShape) ||
-      useInBoundsInsteadOfMasking)
+  if (useInBoundsInsteadOfMasking)
     return transferReadOp;
 
   SmallVector<OpFoldResult> mixedSourceDims =
       isa<MemRefType>(source.getType())
           ? memref::getMixedSizes(builder, loc, source)
           : tensor::getMixedSizes(builder, loc, source);
 
+  if (isMaskTriviallyFoldable(mixedSourceDims, indices, sourceShape,
+                              vecToReadShape))
+    return transferReadOp;
+
   auto maskType = vecToReadTy.cloneWith(/*shape=*/{}, builder.getI1Type());
   Value mask =
       vector::CreateMaskOp::create(builder, loc, maskType, mixedSourceDims);
@@ -376,6 +478,89 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
       ->getResult(0);
 }
 
+Operation *vector::createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
+                                            Value vecToStore, Value dest,
+                                            SmallVector<Value> writeIndices,
+                                            bool useInBoundsInsteadOfMasking) {
+
+  ShapedType destType = cast<ShapedType>(dest.getType());
+  int64_t destRank = destType.getRank();
+  auto destShape = destType.getShape();
+
+  VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
+  int64_t vecToStoreRank = vecToStoreType.getRank();
+  auto vecToStoreShape = vecToStoreType.getShape();
+
+  // Compute the in_bounds attribute
+  SmallVector<bool> inBoundsVal(vecToStoreRank, true);
+  if (useInBoundsInsteadOfMasking) {
+    // Update the inBounds attribute.
+    // FIXME: This computation is too weak - it ignores the write indices.
+    for (unsigned i = 0; i < vecToStoreRank; i++)
+      inBoundsVal[i] =
+          (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
+          ShapedType::isStatic(destShape[destRank - vecToStoreRank + i]);
+  }
+
+  // If missing, initialize the write indices to 0.
+  bool useDefaultWriteIdxs = writeIndices.empty();
+  assert((useDefaultWriteIdxs ||
+          writeIndices.size() == static_cast<size_t>(destRank)) &&
+         "Invalid number of write indices!");
+  if (useDefaultWriteIdxs) {
+    auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
+    writeIndices.assign(destRank, zero);
+  }
+
+  // Generate the xfer_write Op
+  Operation *write = vector::TransferWriteOp::create(builder, loc,
+                                                     /*vector=*/vecToStore,
+                                                     /*dest=*/dest,
+                                                     /*indices=*/writeIndices,
+                                                     /*inBounds=*/inBoundsVal);
+
+  // If masking is disabled, exit.
+  if (useInBoundsInsteadOfMasking)
+    return write;
+
+  // Check if masking is needed. If not, exit.
+  if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
+    return write;
+
+  // Compute the mask and mask the write Op.
+  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
+                                       vecToStoreType.getScalableDims());
+
+  SmallVector<OpFoldResult> destSizes =
+      isa<MemRefType>(dest.getType())
+          ? memref::getMixedSizes(builder, loc, dest)
+          : tensor::getMixedSizes(builder, loc, dest);
+
+  // Compute sizes for write-mask
+  SmallVector<OpFoldResult> maskSizes;
+  if (useDefaultWriteIdxs) {
+    maskSizes = SmallVector<OpFoldResult>(destSizes.end() - vecToStoreRank,
+                                          destSizes.end());
+  } else {
+    size_t diff = destShape.size() - vecToStoreRank;
+    for (int64_t idx = 0; idx < vecToStoreRank; idx++) {
+      auto value =
+          getValueOrCreateConstantIndexOp(builder, loc, destSizes[diff + idx]);
+      auto neg =
+          builder.createOrFold<arith::SubIOp>(loc, value, writeIndices[idx]);
+      maskSizes.push_back(OpFoldResult(neg));
+    }
+  }
+
+  if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
+                              vecToStoreShape))
+    return write;
+
+  Value maskForWrite =
+      builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
+  return mlir::vector::maskOperation(builder, write, maskForWrite);
+}
+
 LogicalResult
 vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
                                  ArrayRef<int64_t> inputVectorSizes) {
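One detail of the implementation above worth spelling out: with explicit write indices, the mask size for each trailing dimension is computed as destSize - writeIndex (the `arith.subi` per dimension), so the mask covers exactly what is still in bounds. A small numeric sketch of that arithmetic (shapes and names are ours, for illustration only):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Assume dest is tensor<10x8xi32>, vecToStore is vector<4x8xi32>, and the
  // write starts at indices (7, 0).
  std::vector<int64_t> destSizes = {10, 8};
  std::vector<int64_t> writeIndices = {7, 0};
  // Mirrors the per-dimension subtraction in createWriteOrMaskedWrite:
  // maskSize[d] = destSize[d] - writeIndex[d].
  for (size_t d = 0; d < destSizes.size(); ++d)
    std::cout << "maskSize[" << d << "] = " << destSizes[d] - writeIndices[d]
              << "\n"; // prints 3 and 8: only 3 of the 4 rows are still in
                       // bounds, so the last row of the vector is masked off.
}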