Files
Nishant Patel b3ca423a78 [MLIR][Vector] Enhance vector.multi_reduction unrolling to handle scalar result (#188633)
Previously, UnrollMultiReductionPattern bailed out when all the
dimensions were reduced to a scalar. This PR adds support for this case
by tiling the source vector and chaining partial reductions through the
accumulator operand.
2026-04-01 14:59:08 -07:00

1418 lines
58 KiB
C++

//===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to do vector unrolling and vector distribution.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#include <optional>
#define DEBUG_TYPE "vector-unroll"
using namespace mlir;
using namespace mlir::vector;
/// Compute the indices of the slice `index` for a transfer op.
/// Compute the indices of the slice `index` for a transfer op.
static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
                                               ArrayRef<Value> indices,
                                               AffineMap permutationMap,
                                               Location loc,
                                               OpBuilder &builder) {
  MLIRContext *ctx = builder.getContext();
  // A constant-0 result in the permutation map encodes a broadcast dimension;
  // it consumes no index and therefore needs no offsetting.
  auto isBroadcastDim = [](AffineExpr expr) {
    auto cst = dyn_cast<AffineConstantExpr>(expr);
    return cst && cst.getValue() == 0;
  };
  // Start from the original indices and shift each index referenced by a
  // non-broadcast result of the permutation map by the tile's element offset.
  SmallVector<Value> newIndices(indices);
  for (auto [resultIdx, resultExpr] :
       llvm::enumerate(permutationMap.getResults())) {
    if (isBroadcastDim(resultExpr))
      continue;
    unsigned dimPos = cast<AffineDimExpr>(resultExpr).getPosition();
    AffineExpr offsetExpr =
        getAffineDimExpr(0, ctx) +
        getAffineConstantExpr(elementOffsets[resultIdx], ctx);
    AffineMap offsetMap =
        AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, offsetExpr);
    newIndices[dimPos] =
        affine::AffineApplyOp::create(builder, loc, offsetMap, indices[dimPos]);
  }
  return newIndices;
}
// Compute the new indices by adding `offsets` to `originalIndices`.
// If m < n (m = offsets.size(), n = originalIndices.size()),
// then only the trailing m values in `originalIndices` are updated.
static SmallVector<Value> sliceLoadStoreIndices(PatternRewriter &rewriter,
                                                Location loc,
                                                OperandRange originalIndices,
                                                ArrayRef<int64_t> offsets) {
  assert(offsets.size() <= originalIndices.size() &&
         "Offsets should not exceed the number of original indices");
  SmallVector<Value> newIndices(originalIndices);
  size_t firstShifted = newIndices.size() - offsets.size();
  for (auto [idx, offset] : llvm::enumerate(offsets)) {
    // A zero offset leaves the original index untouched and emits no IR.
    if (offset == 0)
      continue;
    Value offsetVal = arith::ConstantIndexOp::create(rewriter, loc, offset);
    newIndices[firstShifted + idx] = arith::AddIOp::create(
        rewriter, loc, originalIndices[firstShifted + idx], offsetVal);
  }
  return newIndices;
}
// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`, preserving the original op's attributes.
static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
                                              Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  OperationName opName = op->getName();
  return builder.create(loc, opName.getIdentifier(), operands, resultTypes,
                        op->getAttrs());
}
/// Return the target shape for unrolling for the given `op`. Return
/// std::nullopt if the op shouldn't be or cannot be unrolled.
///
/// An op is unrollable when it passes the user filter (if any), implements
/// VectorUnrollOpInterface, reports a shape for unroll, the user options
/// provide a target (native) shape for it, and the op shape is an integral
/// multiple of that target shape. When the ratio is all-ones the op already
/// matches the native shape and no unrolling is needed.
static std::optional<SmallVector<int64_t>>
getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
  LDBG() << "Get unroll shape for op " << op->getName().getStringRef();
  // Respect the user-provided filter: ops it rejects are never unrolled.
  if (options.filterConstraint && failed(options.filterConstraint(op))) {
    LDBG() << "--did not pass filter constraint -> BAIL";
    return std::nullopt;
  }
  assert(options.nativeShape &&
         "vector unrolling expects the native shape or native "
         "shape call back function to be set");
  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
  if (!unrollableVectorOp) {
    LDBG() << "--not an unrollable op -> BAIL";
    return std::nullopt;
  }
  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
  if (!maybeUnrollShape) {
    LDBG() << "--could not get shape of op " << *op << " -> BAIL";
    return std::nullopt;
  }
  LDBG() << "--vector op shape: " << llvm::interleaved(*maybeUnrollShape);

  std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);
  if (!targetShape) {
    LDBG() << "--no unrolling target shape defined " << *op << "-> SKIP";
    return std::nullopt;
  }
  LDBG() << "--target shape: " << llvm::interleaved(*targetShape);

  // The op shape must be an integral multiple of the target shape.
  auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape);
  if (!maybeShapeRatio) {
    LDBG() << "--could not compute integral shape ratio -> BAIL";
    return std::nullopt;
  }
  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
    LDBG() << "--no unrolling needed -> SKIP";
    return std::nullopt;
  }
  LDBG() << "--found an integral shape ratio to unroll to -> SUCCESS";
  return targetShape;
}
/// Return the loop traversal order to use when unrolling `op`: the identity
/// order [0, numLoops) unless the user supplied a traversal-order callback
/// that returns an override for this op.
static SmallVector<int64_t>
getUnrollOrder(unsigned numLoops, Operation *op,
               const vector::UnrollVectorOptions &options) {
  SmallVector<int64_t> order =
      llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
  if (!options.traversalOrderCallback)
    return order;
  if (std::optional<SmallVector<int64_t>> userOrder =
          options.traversalOrderCallback(op))
    order = std::move(*userOrder);
  return order;
}
namespace {
/// Unrolls a vector.transfer_read into multiple reads of the target shape and
/// reassembles the full vector with insert_strided_slice ops.
struct UnrollTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  UnrollTransferReadPattern(MLIRContext *context,
                            const vector::UnrollVectorOptions &options,
                            PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();
    // Masked transfers are not supported by this unrolling.
    if (readOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, readOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = readOp.getVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = readOp.getLoc();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();

    // Prepare the result vector: start from zeros and fill in one tile per
    // iteration below.
    Value result =
        arith::ConstantOp::create(rewriter, loc, sourceVectorType,
                                  rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());
    SmallVector<Value> originalIndices(readOp.getIndices().begin(),
                                       readOp.getIndices().end());
    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), readOp, options);
    // For each tile, emit a small transfer_read at shifted indices and insert
    // its result at the tile's offset in the aggregate vector.
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               readOp.getPermutationMap(), loc, rewriter);
      auto slicedRead = vector::TransferReadOp::create(
          rewriter, loc, targetType, readOp.getBase(), indices,
          readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
          readOp.getInBoundsAttr());

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, slicedRead, result, elementOffsets, strides);
    }
    rewriter.replaceOp(readOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
/// Unrolls a vector.transfer_write into multiple writes of the target shape,
/// extracting each tile from the source vector. For tensor destinations the
/// result of each sliced write is chained as the destination of the next one.
struct UnrollTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  UnrollTransferWritePattern(MLIRContext *context,
                             const vector::UnrollVectorOptions &options,
                             PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (writeOp.getTransferRank() == 0)
      return failure();
    // Masked transfers are not supported by this unrolling.
    if (writeOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, writeOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = writeOp.getVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = writeOp.getLoc();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();

    // Bail-out if rank(source) != rank(target). The main limitation here is the
    // fact that `ExtractStridedSlice` requires the rank for the input and
    // output to match. If needed, we can relax this later.
    if (originalSize.size() != targetShape->size())
      return rewriter.notifyMatchFailure(
          writeOp,
          "expected source input vector rank to match target shape rank");

    SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
                                       writeOp.getIndices().end());
    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), writeOp, options);
    // Non-null only for tensor destinations: tracks the latest write result.
    Value resultTensor;
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      // Extract the tile to write, then shift the destination indices by the
      // tile's offsets.
      Value slicedVector = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               writeOp.getPermutationMap(), loc, rewriter);
      Operation *slicedWrite = vector::TransferWriteOp::create(
          rewriter, loc, slicedVector,
          resultTensor ? resultTensor : writeOp.getBase(), indices,
          writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
      // For the tensor case update the destination for the next transfer write.
      if (!slicedWrite->getResults().empty())
        resultTensor = slicedWrite->getResult(0);
    }
    if (resultTensor)
      rewriter.replaceOp(writeOp, resultTensor);
    else
      rewriter.eraseOp(writeOp);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
/// DenseMapInfo implementation that allows using SmallVector<int64_t> offset
/// vectors as map keys (used below to cache per-offset accumulator slices).
/// The sentinel keys use negative values, which cannot collide with real
/// offsets since tile offsets are non-negative.
struct OffsetMapInfo {
  static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }

  static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }

  static unsigned getHashValue(const SmallVector<int64_t> &v) {
    return static_cast<unsigned>(llvm::hash_combine_range(v));
  }

  static bool isEqual(const SmallVector<int64_t> &lhs,
                      const SmallVector<int64_t> &rhs) {
    return lhs == rhs;
  }
};
/// Unrolls vector.contract into smaller contractions of the target shape.
/// Partial results that map to the same destination offset (i.e. tiles that
/// only differ in reduction dimensions) are chained through a cache of
/// accumulator slices keyed by the destination offset.
struct UnrollContractionPattern
    : public OpRewritePattern<vector::ContractionOp> {
  UnrollContractionPattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options,
                           PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, contractOp);
    if (!targetShape)
      return failure();
    auto dstVecType = cast<VectorType>(contractOp.getResultType());
    SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();
    Location loc = contractOp.getLoc();

    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
    // Maps a destination offset to the most recent partial accumulator value
    // computed for it. MapVector keeps insertion order for the final assembly.
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;

    SmallVector<int64_t> loopOrder = getUnrollOrder(
        contractOp.getIteratorTypes().size(), contractOp, options);

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      SmallVector<Value> slicesOperands(contractOp.getNumOperands());

      // Helper to compute the new shape of each operand and extract the slice.
      auto extractOperand = [&](unsigned index, Value operand,
                                AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffets) {
        SmallVector<int64_t> operandShape = applyPermutationMap(
            permutationMap, ArrayRef<int64_t>(*targetShape));
        SmallVector<int64_t> operandStrides(operandOffets.size(), 1);
        slicesOperands[index] =
            rewriter.createOrFold<vector::ExtractStridedSliceOp>(
                loc, operand, operandOffets, operandShape, operandStrides);
      };

      // Extract the new lhs operand.
      AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
      SmallVector<int64_t> lhsOffets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);

      // Extract the new rhs operand.
      AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
      SmallVector<int64_t> rhsOffets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);

      AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
      SmallVector<int64_t> accOffets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      // If a version of the accumulator has already been computed, use it
      // otherwise extract the first version from the original operand.
      auto *accIt = accCache.find(accOffets);
      if (accIt != accCache.end())
        slicesOperands[2] = accIt->second;
      else
        extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);

      SmallVector<int64_t> dstShape =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
      auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, contractOp, slicesOperands, targetType);

      SmallVector<int64_t> dstOffets =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
      // Save the accumulated value until all the loops are unrolled since
      // reduction loop keeps updating the accumulator.
      accCache[dstOffets] = newOp->getResult(0);
    }
    // Assemble back the accumulator into a single vector.
    Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
                                             rewriter.getZeroAttr(dstVecType));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(contractOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
/// Unrolls vector.multi_reduction according to the target shape. Two cases:
/// - Scalar result (all dimensions reduced): each source tile is reduced to a
///   scalar, and the partial results are chained through the accumulator
///   operand of successive reductions.
/// - Vector result: partial reductions of tiles that map to the same
///   destination offset are chained through a cache of accumulator slices
///   keyed by that offset.
struct UnrollMultiReductionPattern
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  UnrollMultiReductionPattern(MLIRContext *context,
                              const vector::UnrollVectorOptions &options,
                              PatternBenefit benefit = 1)
      : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
    Location loc = reductionOp.getLoc();
    auto resultType = reductionOp->getResult(0).getType();

    // Handle scalar result case: all dimensions are reduced.
    // Each source tile is reduced to a scalar, and partial results are
    // chained through the accumulator operand.
    if (resultType.isIntOrFloat()) {
      Value accumulator = reductionOp.getAcc();
      for (SmallVector<int64_t> offsets :
           StaticTileOffsetRange(originalSize, *targetShape)) {
        SmallVector<int64_t> operandStrides(offsets.size(), 1);
        Value slicedOperand =
            rewriter.createOrFold<vector::ExtractStridedSliceOp>(
                loc, reductionOp.getSource(), offsets, *targetShape,
                operandStrides);
        // The previous partial result feeds the next reduction's accumulator.
        Operation *newOp = cloneOpWithOperandsAndTypes(
            rewriter, loc, reductionOp, {slicedOperand, accumulator},
            resultType);
        accumulator = newOp->getResult(0);
      }
      rewriter.replaceOp(reductionOp, accumulator);
      return success();
    }

    // Vector result case.
    // Maps a destination offset (the non-reduced dims of a tile's offsets) to
    // the most recent partial accumulator computed for it.
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;
    // Iterate over every source tile; tiles differing only in reduced dims
    // share a destination offset and are chained via accCache.
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<Value> operands;
      SmallVector<int64_t> operandStrides(offsets.size(), 1);
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, reductionOp.getSource(), offsets, *targetShape,
              operandStrides);
      operands.push_back(slicedOperand);
      // Project the tile's offsets/shape onto the non-reduced dimensions to
      // get the destination offset and shape.
      SmallVector<int64_t> dstShape;
      SmallVector<int64_t> destOffset;
      for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
        if (!reductionOp.isReducedDim(i)) {
          destOffset.push_back(offsets[i]);
          dstShape.push_back((*targetShape)[i]);
        }
      }
      Value acc;
      SmallVector<int64_t> accStrides(destOffset.size(), 1);
      // If a version of the accumulator has already been computed, use it
      // otherwise extract the first version from the original operand.
      auto *accIt = accCache.find(destOffset);
      if (accIt != accCache.end())
        acc = accIt->second;
      else
        acc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
      operands.push_back(acc);
      auto targetType = VectorType::get(
          dstShape, reductionOp.getSourceVectorType().getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
                                                     operands, targetType);
      Value result = newOp->getResult(0);
      accCache[destOffset] = result;
    }
    // Assemble back the accumulator into a single vector.
    Value result = arith::ConstantOp::create(
        rewriter, loc, reductionOp.getDestType(),
        rewriter.getZeroAttr(reductionOp.getDestType()));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
/// Unrolls any elementwise-mappable op with a single vector result into
/// smaller ops of the target shape. When the target shape has lower rank than
/// the op's shape, leading unit dimensions are added for slicing and removed
/// again with shape_cast around each unrolled op.
struct UnrollElementwisePattern : public RewritePattern {
  UnrollElementwisePattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options,
                           PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
        options(options) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto targetShape = getTargetShape(options, op);
    if (!targetShape)
      return failure();
    int64_t targetShapeRank = targetShape->size();
    auto dstVecType = cast<VectorType>(op->getResult(0).getType());
    SmallVector<int64_t> originalSize =
        *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
    int64_t originalShapeRank = originalSize.size();
    Location loc = op->getLoc();

    // Handle rank mismatch by adding leading unit dimensions to targetShape
    // so extract_strided_slice sees matching ranks.
    SmallVector<int64_t> adjustedTargetShape(originalShapeRank);
    int64_t rankDiff = originalShapeRank - targetShapeRank;
    std::fill(adjustedTargetShape.begin(),
              adjustedTargetShape.begin() + rankDiff, 1);
    std::copy(targetShape->begin(), targetShape->end(),
              adjustedTargetShape.begin() + rankDiff);
    int64_t adjustedTargetShapeRank = adjustedTargetShape.size();
    // Prepare the result vector.
    Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
                                             rewriter.getZeroAttr(dstVecType));
    SmallVector<int64_t> strides(adjustedTargetShapeRank, 1);
    VectorType unrolledVecType =
        VectorType::get(*targetShape, dstVecType.getElementType());

    // Create the unrolled computation.
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, adjustedTargetShape)) {
      SmallVector<Value> extractOperands;
      for (OpOperand &operand : op->getOpOperands()) {
        auto vecType = dyn_cast<VectorType>(operand.get().getType());
        if (!vecType) {
          // Scalar operands are forwarded unchanged to every unrolled op.
          extractOperands.push_back(operand.get());
          continue;
        }
        Value extracted = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, operand.get(), offsets, adjustedTargetShape, strides);
        // Reshape to remove leading unit dims if needed
        if (adjustedTargetShapeRank > targetShapeRank) {
          extracted = rewriter.createOrFold<vector::ShapeCastOp>(
              loc, VectorType::get(*targetShape, vecType.getElementType()),
              extracted);
        }
        extractOperands.push_back(extracted);
      }
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, op, extractOperands, unrolledVecType);
      Value computeResult = newOp->getResult(0);

      // Use strides sized to targetShape for proper insertion
      SmallVector<int64_t> insertStrides =
          (adjustedTargetShapeRank > targetShapeRank)
              ? SmallVector<int64_t>(targetShapeRank, 1)
              : strides;

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, computeResult, result, offsets, insertStrides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
/// Unrolls vector.reduction into reductions of the target shape, combining
/// the per-tile scalars with the reduction's combining kind.
struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
  UnrollReductionPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ReductionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();

    // Create unrolled vector reduction.
    Location loc = reductionOp.getLoc();
    Value accumulator = nullptr;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<int64_t> strides(offsets.size(), 1);
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, reductionOp.getVector(), offsets, *targetShape, strides);
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
      Value result = newOp->getResult(0);

      if (!accumulator) {
        // This is the first reduction.
        accumulator = result;
      } else {
        // On subsequent reduction, combine with the accumulator.
        accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                         accumulator, result);
      }
    }

    rewriter.replaceOp(reductionOp, accumulator);
    return success();
  }

private:
  const vector::UnrollVectorOptions options;
};
/// Unrolls vector.transpose into transposes of the target shape, extracting
/// each source slice at permuted offsets and inserting the transposed slice
/// at the result offsets.
struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
  UnrollTransposePattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (transposeOp.getResultVectorType().getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, transposeOp);
    if (!targetShape)
      return failure();
    auto originalVectorType = transposeOp.getResultVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = transposeOp.getLoc();
    ArrayRef<int64_t> originalSize = originalVectorType.getShape();

    // Prepare the result vector: start from zeros and fill in one transposed
    // tile per iteration below.
    Value result =
        arith::ConstantOp::create(rewriter, loc, originalVectorType,
                                  rewriter.getZeroAttr(originalVectorType));
    ArrayRef<int64_t> permutation = transposeOp.getPermutation();

    // Unroll the computation.
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<int64_t> permutedOffsets(elementOffsets.size());
      SmallVector<int64_t> permutedShape(elementOffsets.size());
      // Compute the source offsets and shape by inverting the permutation on
      // the destination tile's offsets/shape.
      for (auto indices : llvm::enumerate(permutation)) {
        permutedOffsets[indices.value()] = elementOffsets[indices.index()];
        permutedShape[indices.value()] = (*targetShape)[indices.index()];
      }
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, transposeOp.getVector(), permutedOffsets, permutedShape,
              strides);
      Value transposedSlice = rewriter.createOrFold<vector::TransposeOp>(
          loc, slicedOperand, permutation);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, transposedSlice, result, elementOffsets, strides);
    }
    rewriter.replaceOp(transposeOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
/// Unrolls vector.gather into gathers of the target shape, slicing the index,
/// mask, and pass-through vectors identically for each tile.
struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
  UnrollGatherPattern(MLIRContext *context,
                      const vector::UnrollVectorOptions &options,
                      PatternBenefit benefit = 1)
      : OpRewritePattern<vector::GatherOp>(context, benefit), options(options) {
  }

  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
                                PatternRewriter &rewriter) const override {
    VectorType sourceVectorType = gatherOp.getVectorType();
    if (sourceVectorType.getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, gatherOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = gatherOp.getLoc();
    ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape();

    // Prepare the result vector: start from zeros and fill in one gathered
    // tile per iteration below.
    Value result =
        arith::ConstantOp::create(rewriter, loc, sourceVectorType,
                                  rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), gatherOp, options);
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      // To get the unrolled gather, extract the same slice based on the
      // decomposed shape from each of the index, mask, and pass-through
      // vectors.
      Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getIndices(), elementOffsets, *targetShape, strides);
      Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
      Value passThruSubVec =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
              strides);
      auto slicedGather = vector::GatherOp::create(
          rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getOffsets(),
          indexSubVec, maskSubVec, passThruSubVec);

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, slicedGather, result, elementOffsets, strides);
    }
    rewriter.replaceOp(gatherOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
/// Unrolls vector.load into loads of the target shape at shifted memory
/// indices, reassembling the full vector with insert_strided_slice ops.
struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
  UnrollLoadPattern(MLIRContext *context,
                    const vector::UnrollVectorOptions &options,
                    PatternBenefit benefit = 1)
      : OpRewritePattern<vector::LoadOp>(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    VectorType vecType = loadOp.getVectorType();
    auto targetShape = getTargetShape(options, loadOp);
    if (!targetShape)
      return failure();
    Location loc = loadOp.getLoc();
    ArrayRef<int64_t> originalShape = vecType.getShape();
    SmallVector<int64_t> strides(targetShape->size(), 1);

    // Start from zeros and fill in one loaded tile per iteration below.
    Value result = arith::ConstantOp::create(rewriter, loc, vecType,
                                             rewriter.getZeroAttr(vecType));

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalShape.size(), loadOp, options);

    auto targetVecType =
        VectorType::get(*targetShape, vecType.getElementType());

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
      SmallVector<Value> indices =
          sliceLoadStoreIndices(rewriter, loc, loadOp.getIndices(), offsets);
      Value slicedLoad = vector::LoadOp::create(rewriter, loc, targetVecType,
                                                loadOp.getBase(), indices);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, slicedLoad, result, offsets, strides);
    }
    rewriter.replaceOp(loadOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
/// Unrolls vector.store into stores of the target shape, extracting each tile
/// from the stored vector and writing it at shifted memory indices.
struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
  UnrollStorePattern(MLIRContext *context,
                     const vector::UnrollVectorOptions &options,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<vector::StoreOp>(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    VectorType vecType = storeOp.getVectorType();
    auto targetShape = getTargetShape(options, storeOp);
    if (!targetShape)
      return failure();
    Location loc = storeOp.getLoc();
    ArrayRef<int64_t> originalShape = vecType.getShape();
    SmallVector<int64_t> strides(targetShape->size(), 1);

    Value base = storeOp.getBase();
    Value vector = storeOp.getValueToStore();

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalShape.size(), storeOp, options);

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
      SmallVector<Value> indices =
          sliceLoadStoreIndices(rewriter, loc, storeOp.getIndices(), offsets);
      Value slice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, vector, offsets, *targetShape, strides);
      vector::StoreOp::create(rewriter, loc, slice, base, indices);
    }
    // The store has no results; the original op is simply removed.
    rewriter.eraseOp(storeOp);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
/// Unrolls vector.broadcast into broadcasts of the target shape. For vector
/// sources, each unrolled broadcast reads the matching source slice, clamping
/// offsets/sizes to 0/1 for source dimensions of size 1 (which are themselves
/// broadcast).
struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
  UnrollBroadcastPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::BroadcastOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, broadcastOp);
    if (!targetShape)
      return failure();

    Location loc = broadcastOp.getLoc();
    // Null when the source is a scalar rather than a vector.
    VectorType srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    VectorType resType = broadcastOp.getResultVectorType();
    VectorType targetType =
        resType.cloneWith(*targetShape, resType.getElementType());
    Value result = arith::ConstantOp::create(rewriter, loc, resType,
                                             rewriter.getZeroAttr(resType));

    SmallVector<int64_t> originalShape = *broadcastOp.getShapeForUnroll();
    SmallVector<int64_t> strides(originalShape.size(), 1);

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, *targetShape)) {
      Value newSrc;
      if (!srcType) {
        // Scalar to vector broadcast.
        newSrc = broadcastOp.getSource();
      } else {
        // Vector to vector broadcast.
        // The source dims align with the trailing dims of the result, so take
        // the trailing `rank` entries of the tile's offsets/shape/strides.
        int64_t rank = srcType.getRank();
        SmallVector<int64_t> srcOffsets(offsets.end() - rank, offsets.end());
        SmallVector<int64_t> srcShape(targetShape->end() - rank,
                                      targetShape->end());
        SmallVector<int64_t> srcStrides(strides.end() - rank, strides.end());
        // adjust the offset and shape for src if the corresponding dim is 1.
        for (int64_t i = 0; i < rank; ++i) {
          if (srcType.getDimSize(i) == 1) {
            srcOffsets[i] = 0;
            srcShape[i] = 1;
          }
        }
        newSrc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);
      }

      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, broadcastOp,
                                                     newSrc, targetType);

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, newOp->getResult(0), result, offsets, strides);
    }

    rewriter.replaceOp(broadcastOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
/// Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the
/// outermost dimension of the operand. For example:
///
/// ```
/// %0:8 = vector.to_elements %v : vector<2x2x2xf32>
///
/// ==>
///
/// %v0 = vector.extract %v[0] : vector<2x2xf32> from vector<2x2x2xf32>
/// %v1 = vector.extract %v[1] : vector<2x2xf32> from vector<2x2x2xf32>
/// %0:4 = vector.to_elements %v0 : vector<2x2xf32>
/// %1:4 = vector.to_elements %v1 : vector<2x2xf32>
/// ```
///
/// When this pattern is applied until a fixed-point is reached,
/// this will produce a sequence of 1-d to_elements
/// ops.
struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
  UnrollToElements(MLIRContext *context,
                   const vector::UnrollVectorOptions &options,
                   PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ToElementsOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ToElementsOp op,
                                PatternRewriter &rewriter) const override {
    TypedValue<VectorType> source = op.getSource();
    // Split the source along its outermost dimension; fails e.g. for 1-d
    // sources, which terminates the fixed-point iteration.
    FailureOr<SmallVector<Value>> result =
        vector::unrollVectorValue(source, rewriter);
    if (failed(result)) {
      return failure();
    }
    SmallVector<Value> vectors = *result;

    // Emit one to_elements per sub-vector and concatenate all their results
    // to replace the original op's results in order.
    SmallVector<Value> results;
    for (Value vector : vectors) {
      auto subElements =
          vector::ToElementsOp::create(rewriter, op.getLoc(), vector);
      llvm::append_range(results, subElements.getResults());
    }
    rewriter.replaceOp(op, results);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
/// Unrolls `vector.step` operations according to the provided target unroll
/// shape. The large step vector is decomposed into smaller step segments; each
/// segment is computed as the base step vector plus the segment's starting
/// offset, and is inserted at the corresponding position of the result.
///
/// Scalable vectors are not supported and fail to match.
///
/// Example:
/// Given a step operation:
///   %0 = vector.step : vector<8xindex>
///
/// and a target unroll shape of <4>, the pattern produces:
///
///   %base = vector.step : vector<4xindex>
///   %zero = arith.constant dense<0> : vector<8xindex>
///   %result0 = vector.insert_strided_slice %base, %zero
///     {offsets = [0], strides = [1]} : vector<4xindex> into vector<8xindex>
///   %offset = arith.constant dense<4> : vector<4xindex>
///   %segment1 = arith.addi %base, %offset : vector<4xindex>
///   %result1 = vector.insert_strided_slice %segment1, %result0
///     {offsets = [4], strides = [1]} : vector<4xindex> into vector<8xindex>
///
struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
  UnrollStepPattern(MLIRContext *context,
                    const vector::UnrollVectorOptions &options,
                    PatternBenefit benefit = 1)
      : OpRewritePattern<vector::StepOp>(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(vector::StepOp stepOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, stepOp);
    if (!targetShape)
      return failure();

    VectorType resultVecType = stepOp.getType();
    if (resultVecType.isScalable()) {
      // Scalable vectors are not supported by this pattern.
      return failure();
    }

    int64_t totalSize = resultVecType.getShape()[0];
    Location loc = stepOp.getLoc();
    SmallVector<int64_t> unitStrides(1, 1);
    // Zero-initialized accumulator the segments are inserted into.
    Value acc = arith::ConstantOp::create(rewriter, loc, resultVecType,
                                          rewriter.getZeroAttr(resultVecType));
    auto segmentType =
        VectorType::get(*targetShape, resultVecType.getElementType());
    Value baseStep = vector::StepOp::create(rewriter, loc, segmentType);
    for (const SmallVector<int64_t> &offsets :
         StaticTileOffsetRange({totalSize}, *targetShape)) {
      // Each segment equals the base step shifted by its starting offset.
      Value offsetSplat = arith::ConstantOp::create(
          rewriter, loc, segmentType,
          DenseElementsAttr::get(
              segmentType,
              IntegerAttr::get(segmentType.getElementType(), offsets[0])));
      Value segment =
          arith::AddIOp::create(rewriter, loc, baseStep, offsetSplat);
      acc = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, segment, acc, offsets, unitStrides);
    }
    rewriter.replaceOp(stepOp, acc);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
/// Unrolls `vector.from_elements` ops of rank 2 or more by peeling off the
/// outermost dimension. For example:
/// ```
/// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32>
///
/// ==>
///
/// %0 = ub.poison : vector<2x3xf32>
/// %v0 = vector.from_elements %e0, %e1, %e2 : vector<3xf32>
/// %1 = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32>
/// %v1 = vector.from_elements %e3, %e4, %e5 : vector<3xf32>
/// %2 = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32>
/// ```
///
/// Applying this pattern until a fixed point is reached decomposes the op
/// into a sequence of 1-D `vector.from_elements` ops.
struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
  UnrollFromElements(MLIRContext *context,
                     const vector::UnrollVectorOptions &options,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<vector::FromElementsOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::FromElementsOp op,
                                PatternRewriter &rewriter) const override {
    ValueRange allElements = op.getElements();
    // Build one smaller from_elements op per contiguous slice of the scalar
    // operands; `index` selects which slice along the unrolled dimension.
    auto buildSubFromElements = [&](PatternRewriter &rewriter, Location loc,
                                    VectorType subTy, int64_t index) {
      size_t sliceLen = subTy.getNumElements();
      assert((index + 1) * sliceLen <= allElements.size() && "out of bounds");
      ValueRange slice = allElements.slice(index * sliceLen, sliceLen);
      return vector::FromElementsOp::create(rewriter, loc, subTy, slice);
    };
    return unrollVectorOp(op, rewriter, buildSubFromElements);
  }

private:
  vector::UnrollVectorOptions options;
};
/// This pattern unrolls `vector.create_mask` operations into smaller mask
/// operations based on the target unroll shape. Each unrolled slice computes
/// its local mask size in each dimension (d) as:
///   min(max(originalMaskSize[d] - offset[d], 0), unrolledDimSize[d]).
/// Example:
/// Given a create_mask operation:
///   %0 = vector.create_mask %c6, %c10 : vector<8x16xi1> // mask first 6x10
///   elements
///
/// and a target unroll shape of <4x8>, the pattern produces:
///
///   %false = arith.constant dense<false> : vector<8x16xi1>
///
///   Slice [0,0]:
///     mask size = min(max(6-0, 0), 4) x min(max(10-0, 0), 8) = 4x8
///     %mask00 = vector.create_mask %c4, %c8 : vector<4x8xi1>
///     %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1]
///       : vector<4x8xi1> into vector<8x16xi1>
///   Slice [0,8]:
///     mask size = min(max(6-0, 0), 4) x min(max(10-8, 0), 8) = 4x2
///     %mask01 = vector.create_mask %c4, %c2 : vector<4x8xi1>
///     %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1]
///       : vector<4x8xi1> into vector<8x16xi1>
///   Slice [4,0]:
///     mask size = min(max(6-4, 0), 4) x min(max(10-0, 0), 8) = 2x8
///     %mask10 = vector.create_mask %c2, %c8 : vector<4x8xi1>
///     %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1]
///       : vector<4x8xi1> into vector<8x16xi1>
///   Slice [4,8]:
///     mask size = min(max(6-4, 0), 4) x min(max(10-8, 0), 8) = 2x2
///     %mask11 = vector.create_mask %c2, %c2 : vector<4x8xi1>
///     %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1]
///       : vector<4x8xi1> into vector<8x16xi1>
struct UnrollCreateMaskPattern : public OpRewritePattern<vector::CreateMaskOp> {
  UnrollCreateMaskPattern(MLIRContext *context,
                          const vector::UnrollVectorOptions &options,
                          PatternBenefit benefit = 1)
      : OpRewritePattern<vector::CreateMaskOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::CreateMaskOp createMaskOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, createMaskOp);
    if (!targetShape)
      return failure();

    VectorType resultType = createMaskOp.getVectorType();
    SmallVector<int64_t> originalSize = *createMaskOp.getShapeForUnroll();
    Location loc = createMaskOp.getLoc();
    // All-false accumulator; each unrolled mask slice is inserted into it.
    Value result = arith::ConstantOp::create(rewriter, loc, resultType,
                                             rewriter.getZeroAttr(resultType));
    VectorType targetVectorType =
        VectorType::get(*targetShape, rewriter.getI1Type());
    SmallVector<int64_t> strides(targetShape->size(), 1);

    // Hoist loop-invariant constants: the zero used for clamping and the
    // per-dimension unrolled sizes do not depend on the tile offsets, so
    // create them once instead of once per tile per dimension.
    Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
    SmallVector<Value> unrolledDimSizes;
    unrolledDimSizes.reserve(targetShape->size());
    for (int64_t dimSize : *targetShape)
      unrolledDimSizes.push_back(
          arith::ConstantIndexOp::create(rewriter, loc, dimSize));

    // In each dimension (d), each unrolled vector computes its mask size as:
    // min(max(originalMaskOperands[d] - offset[d], 0), unrolledDimSize[d]).
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<Value> unrolledOperands;
      for (auto [i, originalMaskOperand] :
           llvm::enumerate(createMaskOp.getOperands())) {
        Value offsetVal =
            arith::ConstantIndexOp::create(rewriter, loc, offsets[i]);
        Value adjustedMaskSize = rewriter.createOrFold<arith::SubIOp>(
            loc, originalMaskOperand, offsetVal);
        // Clamp to [0, unrolledDimSize] so out-of-range slices get an empty
        // (or saturated) mask extent.
        Value nonNegative =
            rewriter.createOrFold<arith::MaxSIOp>(loc, adjustedMaskSize, zero);
        Value unrolledOperand = rewriter.createOrFold<arith::MinSIOp>(
            loc, nonNegative, unrolledDimSizes[i]);
        unrolledOperands.push_back(unrolledOperand);
      }
      auto unrolledMask = rewriter.createOrFold<vector::CreateMaskOp>(
          loc, targetVectorType, unrolledOperands);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, unrolledMask, result, offsets, strides);
    }
    rewriter.replaceOp(createMaskOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
/// Unrolls `vector.constant_mask` operations into smaller mask operations
/// based on the target unroll shape. Each unrolled slice statically computes
/// which of its elements fall inside the original mask bounds given the
/// slice's offset position.
///
/// Example:
/// Given a constant_mask operation:
///   %0 = vector.constant_mask [6, 10] : vector<8x16xi1>
///
/// and a target unroll shape of <4x8>, the pattern produces:
///
///   %false = arith.constant dense<false> : vector<8x16xi1>
///
///   Slice [0,0]: elements [0:4, 0:8] - fully within [6, 10] bounds
///     %mask00 = vector.constant_mask [4, 8] : vector<4x8xi1>
///     %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1]
///       : vector<4x8xi1> into vector<8x16xi1>
///
///   Slice [0,8]: elements [0:4, 8:16] - partially within bounds
///     %mask01 = vector.constant_mask [4, 2] : vector<4x8xi1>
///     %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1]
///       : vector<4x8xi1> into vector<8x16xi1>
///
///   Slice [4,0]: elements [4:8, 0:8] - partially within bounds
///     %mask10 = vector.constant_mask [2, 8] : vector<4x8xi1>
///     %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1]
///       : vector<4x8xi1> into vector<8x16xi1>
///
///   Slice [4,8]: elements [4:8, 8:16] - partially within bounds
///     %mask11 = vector.constant_mask [2, 2] : vector<4x8xi1>
///     %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1]
///       : vector<4x8xi1> into vector<8x16xi1>
struct UnrollConstantMaskPattern
    : public OpRewritePattern<vector::ConstantMaskOp> {
  UnrollConstantMaskPattern(MLIRContext *context,
                            const vector::UnrollVectorOptions &options,
                            PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ConstantMaskOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ConstantMaskOp constantMaskOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, constantMaskOp);
    if (!targetShape)
      return failure();

    VectorType resultType = constantMaskOp.getVectorType();
    SmallVector<int64_t> originalSize = *constantMaskOp.getShapeForUnroll();
    Location loc = constantMaskOp.getLoc();
    // All-false accumulator; unrolled mask slices are inserted into it.
    Value acc = arith::ConstantOp::create(rewriter, loc, resultType,
                                          rewriter.getZeroAttr(resultType));
    VectorType sliceType = VectorType::get(*targetShape, rewriter.getI1Type());
    SmallVector<int64_t> strides(targetShape->size(), 1);

    // For each slice, the per-dimension mask size is the original extent
    // clipped to the slice: min(max(maskDim[d] - offset[d], 0), sliceDim[d]).
    for (const SmallVector<int64_t> &offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<int64_t> sliceMaskDims;
      for (auto [dim, maskDim] :
           llvm::enumerate(constantMaskOp.getMaskDimSizes())) {
        // Number of masked elements of this slice in dimension `dim`.
        int64_t clipped = std::max<int64_t>(maskDim - offsets[dim], 0);
        sliceMaskDims.push_back(
            std::min<int64_t>(clipped, (*targetShape)[dim]));
      }
      auto sliceMask = rewriter.createOrFold<vector::ConstantMaskOp>(
          loc, sliceType, sliceMaskDims);
      acc = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, sliceMask, acc, offsets, strides);
    }
    rewriter.replaceOp(constantMaskOp, acc);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
/// Checks whether `extractShape` is a contiguous slice of `shape`.
/// For `extractShape` to be contiguous in `shape`:
/// 1) All but the leading dimension of `extractShape` and `shape` must match
///    exactly.
/// 2) The total number of elements in `shape` must be evenly divisible by
///    the total number of elements in `extractShape`.
/// Examples:
///   isContiguous([4, 4], [8, 4]) == true
///   isContiguous([2, 4], [8, 4]) == true
///   isContiguous([2, 2], [8, 4]) == false
/// Removes leading unit dimensions to handle cases like:
///   isContiguous([1, 16], [1, 32]) == true
static bool isContiguous(ArrayRef<int64_t> extractShape,
                         ArrayRef<int64_t> shape) {
  if (extractShape.empty() || shape.empty() ||
      extractShape.size() > shape.size())
    return false;
  // Leading unit dimensions do not affect contiguity; drop them from both
  // shapes before comparing.
  while (extractShape.size() > 1 && extractShape.front() == 1)
    extractShape = extractShape.drop_front();
  while (shape.size() > 1 && shape.front() == 1)
    shape = shape.drop_front();
  // Re-check ranks after stripping: `shape` may have lost more leading unit
  // dims than `extractShape` (e.g. [2, 3] vs [1, 1, 3]); without this guard
  // the unsigned subtraction below underflows and drop_front goes out of
  // bounds.
  if (extractShape.size() > shape.size())
    return false;
  // Except for its leading dimension, `extractShape` must match the
  // innermost dimensions of `shape` exactly.
  size_t rankDiff = shape.size() - extractShape.size();
  if (!llvm::equal(extractShape.drop_front(), shape.drop_front(rankDiff + 1)))
    return false;
  // The extracted slice must evenly tile the source.
  int64_t extractElements = ShapedType::getNumElements(extractShape);
  int64_t shapeElements = ShapedType::getNumElements(shape);
  return shapeElements % extractElements == 0;
}
/// Determines what shape to use with `vector.extract_strided_slice` to
/// extract a contiguous memory region from a source vector. The extraction
/// must be contiguous and contain exactly `targetElements` elements. Returns
/// std::nullopt if no such extraction shape exists.
///
/// EXAMPLE 1:
///   sourceShape = [16], targetElements = 8
///   Working right-to-left:
///   - Take min(8, 16) = 8 from only dim -> extractShape = [8],
///     remaining = 8/8 = 1
///   Result: [8]
///
/// EXAMPLE 2:
///   sourceShape = [4, 4], targetElements = 8
///   Working right-to-left:
///   - Take min(8, 4) = 4 from last dim -> extractShape = [4],
///     remaining = 8/4 = 2
///   - Take min(2, 4) = 2 from first dim -> extractShape = [2, 4],
///     remaining = 2/2 = 1
///   Result: [2, 4]
static std::optional<SmallVector<int64_t>>
calculateSourceExtractShape(ArrayRef<int64_t> sourceShape,
                            int64_t targetElements) {
  // Collect dimension extents innermost-first (reversed) so the extracted
  // region stays contiguous, then flip at the end.
  SmallVector<int64_t> reversedShape;
  int64_t remaining = targetElements;
  for (int64_t dimSize : llvm::reverse(sourceShape)) {
    if (remaining <= 1)
      break;
    int64_t take = std::min(remaining, dimSize);
    if (remaining % take != 0)
      return std::nullopt; // Not evenly divisible.
    reversedShape.push_back(take);
    remaining /= take;
  }
  // Pad any untouched outer dimensions with 1.
  reversedShape.append(sourceShape.size() - reversedShape.size(), 1);
  SmallVector<int64_t> extractShape(reversedShape.rbegin(),
                                    reversedShape.rend());
  // Reject shapes that do not account for exactly the requested elements
  // (e.g. targetElements exceeds the source's total element count).
  if (ShapedType::getNumElements(extractShape) != targetElements)
    return std::nullopt;
  return extractShape;
}
/// Maps offsets within the result vector to the corresponding offsets within
/// the source vector by round-tripping through the shared linearized element
/// position.
static SmallVector<int64_t>
calculateSourceOffsets(ArrayRef<int64_t> resultOffsets,
                       ArrayRef<int64_t> sourceShape,
                       ArrayRef<int64_t> resultShape) {
  // Flatten the result-space offsets into a single linear element index.
  SmallVector<int64_t> resultStrides = computeStrides(resultShape);
  int64_t linearPos = linearize(resultOffsets, resultStrides);
  // Re-expand that linear index in the source vector's coordinate space.
  return delinearize(linearPos, computeStrides(sourceShape));
}
/// Unrolls `vector.shape_cast` operations according to the provided target
/// unroll shape. A large shape cast is decomposed into smaller shape casts by
/// extracting contiguous slices from the source vector, casting each slice to
/// the target shape, and inserting each cast slice at the corresponding
/// offset of the result vector.
///
/// This pattern only applies when contiguous slices can be extracted from the
/// source vector and inserted into the result vector such that each slice
/// remains a valid vector (and does not decompose to scalars). In these
/// cases, the unrolling proceeds as:
///   vector.extract_strided_slice -> vector.shape_cast (on the slice) ->
///   vector.insert_strided_slice.
///
/// Example:
/// Given a shape cast operation:
///   %0 = vector.shape_cast %src : vector<8x2xf32> to vector<4x4xf32>
///
/// and a target unroll shape of <2x4>, the pattern produces:
///
///   %zero = arith.constant dense<0.0> : vector<4x4xf32>
///   %s0 = vector.extract_strided_slice %src [0, 0], [4, 2], [1, 1]
///     : vector<8x2xf32> to vector<4x2xf32>
///   %sc0 = vector.shape_cast %s0 : vector<4x2xf32> to vector<2x4xf32>
///   %i0 = vector.insert_strided_slice %sc0, %zero [0, 0], [1, 1]
///     : vector<2x4xf32> into vector<4x4xf32>
///   %s1 = vector.extract_strided_slice %src [4, 0], [4, 2], [1, 1]
///     : vector<8x2xf32> to vector<4x2xf32>
///   %sc1 = vector.shape_cast %s1 : vector<4x2xf32> to vector<2x4xf32>
///   %i1 = vector.insert_strided_slice %sc1, %i0 [2, 0], [1, 1]
///     : vector<2x4xf32> into vector<4x4xf32>
///
struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
  UnrollShapeCastPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ShapeCastOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, shapeCastOp);
    if (!targetShape)
      return failure();

    VectorType srcType = shapeCastOp.getSourceVectorType();
    VectorType dstType = shapeCastOp.getResultVectorType();
    ArrayRef<int64_t> srcShape = srcType.getShape();
    ArrayRef<int64_t> dstShape = dstType.getShape();

    if (!isContiguous(*targetShape, dstShape))
      return rewriter.notifyMatchFailure(
          shapeCastOp, "Only supports cases where target shape is "
                       "contiguous in result vector shape");

    int64_t numSliceElements = ShapedType::getNumElements(*targetShape);
    // Determine the shape of the contiguous chunk to pull out of the source.
    std::optional<SmallVector<int64_t>> extractShape =
        calculateSourceExtractShape(srcShape, numSliceElements);
    if (!extractShape)
      return rewriter.notifyMatchFailure(
          shapeCastOp,
          "cannot extract target number of elements contiguously from source");

    Location loc = shapeCastOp.getLoc();
    // Zero-initialized accumulator the cast slices are inserted into.
    Value accumulated = arith::ConstantOp::create(
        rewriter, loc, dstType, rewriter.getZeroAttr(dstType));
    VectorType sliceType =
        VectorType::get(*targetShape, srcType.getElementType());
    SmallVector<int64_t> extractStrides(extractShape->size(), 1);
    SmallVector<int64_t> insertStrides(targetShape->size(), 1);

    for (SmallVector<int64_t> dstOffsets :
         StaticTileOffsetRange(dstShape, *targetShape)) {
      // Find where this result tile lives in the source vector.
      SmallVector<int64_t> srcOffsets =
          calculateSourceOffsets(dstOffsets, srcShape, dstShape);
      Value chunk = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, shapeCastOp.getSource(), srcOffsets, *extractShape,
          extractStrides);
      Value castChunk =
          rewriter.createOrFold<vector::ShapeCastOp>(loc, sliceType, chunk);
      accumulated = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, castChunk, accumulated, dstOffsets, insertStrides);
    }
    rewriter.replaceOp(shapeCastOp, accumulated);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
} // namespace
void mlir::vector::populateVectorUnrollPatterns(
    RewritePatternSet &patterns, const UnrollVectorOptions &options,
    PatternBenefit benefit) {
  // Unrolling patterns for data-movement, compute, and shape ops.
  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
               UnrollContractionPattern, UnrollElementwisePattern,
               UnrollReductionPattern, UnrollMultiReductionPattern,
               UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
               UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
               UnrollToElements, UnrollStepPattern,
               UnrollShapeCastPattern>(patterns.getContext(), options, benefit);
  // Unrolling patterns for mask-creation ops.
  patterns.add<UnrollCreateMaskPattern, UnrollConstantMaskPattern>(
      patterns.getContext(), options, benefit);
}
void mlir::vector::populateVectorToElementsUnrollPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  // Default-constructed options are passed through to the pattern.
  UnrollVectorOptions defaultOptions;
  patterns.add<UnrollToElements>(patterns.getContext(), defaultOptions,
                                 benefit);
}
void mlir::vector::populateVectorFromElementsUnrollPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  // Default-constructed options are passed through to the pattern.
  UnrollVectorOptions defaultOptions;
  patterns.add<UnrollFromElements>(patterns.getContext(), defaultOptions,
                                   benefit);
}