Use wrappers around `std::accumulate` to make the code more concise and less bug-prone: https://github.com/llvm/llvm-project/pull/162129. With `std::accumulate`, it's the initial value that determines the accumulator type. `llvm::sum_of` and `llvm::product_of` pick the right accumulator type based on the range element type. Found some funny bugs like a local accumulate helper that calculated a sum with initial value of 1 -- we didn't hit the bug because the code was actually dead...
477 lines
19 KiB
C++
477 lines
19 KiB
C++
//===- LowerVectorShapeCast.cpp - Lower 'vector.shape_cast' operation -----===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements target-independent rewrites and utilities to lower the
|
|
// 'vector.shape_cast' operation.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/UB//IR/UBOps.h"
|
|
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
|
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
|
|
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
|
|
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/Location.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/IR/TypeUtilities.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include <numeric>
|
|
|
|
#define DEBUG_TYPE "vector-shape-cast-lowering"
|
|
|
|
using namespace mlir;
|
|
|
|
/// Perform the inplace update
|
|
/// rhs <- lhs + rhs
|
|
///
|
|
/// where `rhs` is a number expressed in mixed base `base` with most significant
|
|
/// dimensions on the left. For example if `rhs` is {a,b,c} and `base` is
|
|
/// {5,3,2} then `rhs` has value a*3*2 + b*2 + c.
|
|
///
|
|
/// Some examples where `base` is {5,3,2}:
|
|
/// rhs = {0,0,0}, lhs = 1 --> rhs = {0,0,1}
|
|
/// rhs = {0,0,1}, lhs = 1 --> rhs = {0,1,0}
|
|
/// rhs = {0,0,0}, lhs = 25 --> rhs = {4,0,1}
|
|
///
|
|
/// Invalid:
|
|
/// rhs = {0,0,2}, lhs = 1 : rhs not in base {5,3,2}
|
|
///
|
|
/// Overflows not handled correctly:
|
|
/// rhs = {4,2,1}, lhs = 2 --> rhs = {0,0,0} (not {0,0,1})
|
|
static void inplaceAdd(int64_t lhs, ArrayRef<int64_t> base,
|
|
MutableArrayRef<int64_t> rhs) {
|
|
|
|
// For dimensions in [numIndices - 1, ..., 3, 2, 1, 0]:
|
|
for (int dim : llvm::reverse(llvm::seq<int>(0, rhs.size()))) {
|
|
int64_t dimBase = base[dim];
|
|
assert(rhs[dim] < dimBase && "rhs not in base");
|
|
|
|
int64_t incremented = rhs[dim] + lhs;
|
|
|
|
// If the incremented value excedes the dimension base, we must spill to the
|
|
// next most significant dimension and repeat (we might need to spill to
|
|
// more significant dimensions multiple times).
|
|
lhs = incremented / dimBase;
|
|
rhs[dim] = incremented % dimBase;
|
|
if (lhs == 0)
|
|
break;
|
|
}
|
|
}
|
|
|
|
namespace {
|
|
|
|
/// shape_cast is converted to a sequence of extract, extract_strided_slice,
|
|
/// insert_strided_slice, and insert operations. The running example will be:
|
|
///
|
|
/// %0 = vector.shape_cast %arg0 :
|
|
/// vector<2x2x3x4x7x11xi8> to vector<8x6x7x11xi8>
|
|
///
|
|
/// In this example the source and result shapes share a common suffix of 7x11.
|
|
/// This means we can always decompose the shape_cast into extract, insert, and
|
|
/// their strided equivalents, on vectors with shape suffix 7x11.
|
|
///
|
|
/// The greatest common divisor (gcd) of the first dimension preceding the
|
|
/// common suffix is gcd(4,6) = 2. The algorithm implemented here will operate
|
|
/// on vectors with shapes that are `multiples` of (what we define as) the
|
|
/// 'atomic shape', 2x7x11. The atomic shape is `gcd` x `common-suffix`.
|
|
///
|
|
///   vector<2x2x3x4x7x11xi8>  to
///       vector<8x6x7x11xi8>
///                | ||||
///                | ++++------------> common suffix of 7x11
///                +----------------->    gcd(4,6) is 2 | |
///                                                   | | |
///                                                   v v v
///                               atomic shape <----- 2x7x11
|
|
///
|
|
///
|
|
///
|
|
/// The decomposition implemented in this pattern consists of a sequence of
|
|
/// repeated steps:
|
|
///
|
|
/// (1) Extract vectors from the suffix of the source.
|
|
/// In our example this is 2x2x3x4x7x11 -> 4x7x11.
|
|
///
|
|
/// (2) Do extract_strided_slice down to the atomic shape.
|
|
/// In our example this is 4x7x11 -> 2x7x11.
|
|
///
|
|
/// (3) Do insert_strided_slice to the suffix of the result.
|
|
/// In our example this is 2x7x11 -> 6x7x11.
|
|
///
|
|
/// (4) insert these vectors into the result vector.
|
|
/// In our example this is 6x7x11 -> 8x6x7x11.
|
|
///
|
|
/// These steps occur with different periods. In this example
|
|
/// (1) occurs 12 times,
|
|
/// (2) and (3) occur 24 times, and
|
|
/// (4) occurs 8 times.
|
|
///
|
|
/// Two special cases are handled independently in this pattern
|
|
/// (i) A shape_cast that just does leading 1 insertion/removal
|
|
/// (ii) A shape_cast where the gcd is 1.
|
|
///
|
|
/// These 2 cases can have more compact IR generated by not using the generic
|
|
/// algorithm described above.
|
|
///
|
|
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {

  /// Case (i) of the description: lowers `shapeCast` to a single
  /// vector.extract / vector.insert pair.
  ///
  /// Assumes source and result shapes are identical up to some leading ones.
  /// The extract uses one zero index per extra leading source dimension and
  /// the insert uses one zero index per extra leading result dimension; at
  /// most one of those two index lists is non-empty.
  ///
  /// Always replaces `shapeCast` and returns success().
  static LogicalResult leadingOnesLowering(vector::ShapeCastOp shapeCast,
                                           PatternRewriter &rewriter) {

    const Location loc = shapeCast.getLoc();
    const VectorType sourceType = shapeCast.getSourceVectorType();
    const VectorType resultType = shapeCast.getResultVectorType();

    // A positive rank difference means the source has the extra leading
    // dimensions, a negative one means the result has them.
    const int64_t sourceRank = sourceType.getRank();
    const int64_t resultRank = resultType.getRank();
    const int64_t delta = sourceRank - resultRank;
    const int64_t sourceLeading = delta > 0 ? delta : 0;
    const int64_t resultLeading = delta > 0 ? 0 : -delta;

    const Value source = shapeCast.getSource();
    const Value poison = ub::PoisonOp::create(rewriter, loc, resultType);
    const Value extracted = vector::ExtractOp::create(
        rewriter, loc, source, SmallVector<int64_t>(sourceLeading, 0));
    const Value result =
        vector::InsertOp::create(rewriter, loc, extracted, poison,
                                 SmallVector<int64_t>(resultLeading, 0));

    rewriter.replaceOp(shapeCast, result);
    return success();
  }

  /// Case (ii) of the description: lowers `shapeCast` to a chain of
  /// vector.extract / vector.insert pairs, without any strided slice ops.
  ///
  /// Assumes a shape_cast where the suffix shape of the source starting at
  /// `sourceDim` and the suffix shape of the result starting at `resultDim`
  /// are identical. One pair is emitted for every index into the leading
  /// `sourceDim` dimensions of the source.
  ///
  /// Always replaces `shapeCast` and returns success().
  static LogicalResult noStridedSliceLowering(vector::ShapeCastOp shapeCast,
                                              int64_t sourceDim,
                                              int64_t resultDim,
                                              PatternRewriter &rewriter) {

    const Location loc = shapeCast.getLoc();

    const Value source = shapeCast.getSource();
    const ArrayRef<int64_t> sourceShape =
        shapeCast.getSourceVectorType().getShape();

    const VectorType resultType = shapeCast.getResultVectorType();
    const ArrayRef<int64_t> resultShape = resultType.getShape();

    // The leading dimensions that are iterated over; loop invariant.
    const ArrayRef<int64_t> sourcePrefixShape =
        sourceShape.take_front(sourceDim);
    const ArrayRef<int64_t> resultPrefixShape =
        resultShape.take_front(resultDim);

    const int64_t nSlices = llvm::product_of(sourcePrefixShape);
    SmallVector<int64_t> extractIndex(sourceDim, 0);
    SmallVector<int64_t> insertIndex(resultDim, 0);
    Value result = ub::PoisonOp::create(rewriter, loc, resultType);

    // The counter has the same type as `nSlices` so that the comparison is
    // not performed between a narrower and a wider integer.
    for (int64_t i = 0; i < nSlices; ++i) {
      Value extracted =
          vector::ExtractOp::create(rewriter, loc, source, extractIndex);

      result = vector::InsertOp::create(rewriter, loc, extracted, result,
                                        insertIndex);

      // Step both multi-dimensional indices to the next slice.
      inplaceAdd(1, sourcePrefixShape, extractIndex);
      inplaceAdd(1, resultPrefixShape, insertIndex);
    }
    rewriter.replaceOp(shapeCast, result);
    return success();
  }

public:
  using Base::Base;

  /// Lowers a fixed-size vector.shape_cast.
  ///
  /// Reports a match failure when the source or the result type is scalable.
  /// Otherwise the op is always replaced, using the leading-ones lowering,
  /// the no-strided-slice lowering (gcd of the first differing dimensions is
  /// 1), or the generic extract / extract_strided_slice /
  /// insert_strided_slice / insert sequence.
  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType sourceType = op.getSourceVectorType();
    VectorType resultType = op.getResultVectorType();

    if (sourceType.isScalable() || resultType.isScalable())
      return rewriter.notifyMatchFailure(
          op,
          "shape_cast where vectors are scalable not handled by this pattern");

    const ArrayRef<int64_t> sourceShape = sourceType.getShape();
    const ArrayRef<int64_t> resultShape = resultType.getShape();
    const int64_t sourceRank = sourceType.getRank();
    const int64_t resultRank = resultType.getRank();
    const int64_t numElms = sourceType.getNumElements();
    const Value source = op.getSource();

    // Set the first dimension (starting at the end) in the source and result
    // respectively where the dimension sizes differ. Using the running example:
    //
    //   dims:    0 1 2 3 4 5       0 1 2 3
    //   shapes: (2,2,3,4,7,11) -> (8,6,7,11)
    //                  ^             ^
    //   sourceSuffixStartDim is 3, resultSuffixStartDim is 1.
    int64_t sourceSuffixStartDim = sourceRank - 1;
    int64_t resultSuffixStartDim = resultRank - 1;
    while (sourceSuffixStartDim >= 0 && resultSuffixStartDim >= 0 &&
           (sourceType.getDimSize(sourceSuffixStartDim) ==
            resultType.getDimSize(resultSuffixStartDim))) {
      --sourceSuffixStartDim;
      --resultSuffixStartDim;
    }

    // This is the case (i) where there are just some leading ones to contend
    // with in the source or result. It can be handled with a single
    // extract/insert pair.
    if (resultSuffixStartDim < 0 || sourceSuffixStartDim < 0)
      return leadingOnesLowering(op, rewriter);

    const int64_t sourceSuffixStartDimSize =
        sourceType.getDimSize(sourceSuffixStartDim);
    const int64_t resultSuffixStartDimSize =
        resultType.getDimSize(resultSuffixStartDim);
    const int64_t greatestCommonDivisor =
        std::gcd(sourceSuffixStartDimSize, resultSuffixStartDimSize);
    const int64_t stridedSliceRank = sourceRank - sourceSuffixStartDim;
    // Number of atomic slices taken from one extracted vector, and number of
    // atomic slices needed to fill one vector that is inserted in the result.
    const size_t extractPeriod =
        sourceSuffixStartDimSize / greatestCommonDivisor;
    const size_t insertPeriod =
        resultSuffixStartDimSize / greatestCommonDivisor;

    // The atomic shape is the source suffix with its first dimension replaced
    // by the gcd.
    SmallVector<int64_t> atomicShape(sourceShape.begin() + sourceSuffixStartDim,
                                     sourceShape.end());
    atomicShape[0] = greatestCommonDivisor;

    // Use llvm::product_of rather than std::accumulate with a literal `1`:
    // with std::accumulate the initial value determines the accumulator type,
    // so an `int` initial value narrows every 64-bit partial product to `int`.
    const int64_t numAtomicElms = llvm::product_of(atomicShape);
    const size_t nAtomicSlices = numElms / numAtomicElms;

    // This is the case (ii) where the strided dimension size is 1. More compact
    // IR is generated in this case if we just extract and insert the elements
    // directly. In other words, we don't use extract_strided_slice and
    // insert_strided_slice.
    if (greatestCommonDivisor == 1)
      return noStridedSliceLowering(op, sourceSuffixStartDim + 1,
                                    resultSuffixStartDim + 1, rewriter);

    // The insert_strided_slice result's type
    const ArrayRef<int64_t> insertStridedShape =
        resultShape.drop_front(resultSuffixStartDim);
    const VectorType insertStridedType =
        VectorType::get(insertStridedShape, resultType.getElementType());

    // Leading dimensions indexed by vector.extract / vector.insert; loop
    // invariant.
    const ArrayRef<int64_t> sourcePrefixShape =
        sourceShape.take_front(sourceSuffixStartDim);
    const ArrayRef<int64_t> resultPrefixShape =
        resultShape.take_front(resultSuffixStartDim);

    SmallVector<int64_t> extractIndex(sourceSuffixStartDim, 0);
    SmallVector<int64_t> insertIndex(resultSuffixStartDim, 0);
    SmallVector<int64_t> extractOffsets(stridedSliceRank, 0);
    SmallVector<int64_t> insertOffsets(stridedSliceRank, 0);
    // Passed as the last operand list of both strided slice ops below.
    const SmallVector<int64_t> sizes(stridedSliceRank, 1);

    Value extracted = {};
    Value extractedStrided = {};
    Value insertedSlice = {};
    Value result = ub::PoisonOp::create(rewriter, loc, resultType);
    // Starting value of every partially assembled result slice.
    const Value partResult =
        ub::PoisonOp::create(rewriter, loc, insertStridedType);

    for (size_t i = 0; i < nAtomicSlices; ++i) {

      const size_t extractStridedPhase = i % extractPeriod;
      const size_t insertStridedPhase = i % insertPeriod;

      // vector.extract: a new source vector is needed once every
      // `extractPeriod` iterations.
      if (extractStridedPhase == 0) {
        extracted =
            vector::ExtractOp::create(rewriter, loc, source, extractIndex);
        inplaceAdd(1, sourcePrefixShape, extractIndex);
      }

      // vector.extract_strided_slice
      extractOffsets[0] = extractStridedPhase * greatestCommonDivisor;
      extractedStrided = vector::ExtractStridedSliceOp::create(
          rewriter, loc, extracted, extractOffsets, atomicShape, sizes);

      // vector.insert_strided_slice: restart from the poison value at the
      // beginning of every insert period.
      if (insertStridedPhase == 0) {
        insertedSlice = partResult;
      }
      insertOffsets[0] = insertStridedPhase * greatestCommonDivisor;
      insertedSlice = vector::InsertStridedSliceOp::create(
          rewriter, loc, extractedStrided, insertedSlice, insertOffsets, sizes);

      // vector.insert: the assembled slice is complete at the end of every
      // insert period.
      if (insertStridedPhase + 1 == insertPeriod) {
        result = vector::InsertOp::create(rewriter, loc, insertedSlice, result,
                                          insertIndex);
        inplaceAdd(1, resultPrefixShape, insertIndex);
      }
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
|
|
|
|
/// A shape_cast lowering for scalable vectors with a single trailing scalable
|
|
/// dimension. This is similar to the general shape_cast lowering but makes use
|
|
/// of vector.scalable.insert and vector.scalable.extract to move elements a
|
|
/// subvector at a time.
|
|
///
|
|
/// E.g.:
|
|
/// ```
|
|
/// // Flatten scalable vector
|
|
/// %0 = vector.shape_cast %arg0 : vector<2x1x[4]xi32> to vector<[8]xi32>
|
|
/// ```
|
|
/// is rewritten to:
|
|
/// ```
|
|
/// // Flatten scalable vector
|
|
/// %c = ub.poison : vector<[8]xi32>
|
|
/// %0 = vector.extract %arg0[0, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
|
|
/// %1 = vector.scalable.insert %0, %c[0] : vector<[4]xi32> into vector<[8]xi32>
|
|
/// %2 = vector.extract %arg0[1, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
|
|
/// %3 = vector.scalable.insert %2, %1[4] : vector<[4]xi32> into vector<[8]xi32>
|
|
/// ```
|
|
/// or:
|
|
/// ```
|
|
/// // Un-flatten scalable vector
|
|
/// %0 = vector.shape_cast %arg0 : vector<[8]xi32> to vector<2x1x[4]xi32>
|
|
/// ```
|
|
/// is rewritten to:
|
|
/// ```
|
|
/// // Un-flatten scalable vector
|
|
/// %c = ub.poison : vector<2x1x[4]xi32>
|
|
/// %0 = vector.scalable.extract %arg0[0] : vector<[4]xi32> from vector<[8]xi32>
|
|
/// %1 = vector.insert %0, %c [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
|
|
/// %2 = vector.scalable.extract %arg0[4] : vector<[4]xi32> from vector<[8]xi32>
|
|
/// %3 = vector.insert %2, %1 [1, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
|
|
/// ```
|
|
class ScalableShapeCastOpRewritePattern
|
|
: public OpRewritePattern<vector::ShapeCastOp> {
|
|
public:
|
|
using Base::Base;
|
|
|
|
LogicalResult matchAndRewrite(vector::ShapeCastOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
|
auto sourceVectorType = op.getSourceVectorType();
|
|
auto resultVectorType = op.getResultVectorType();
|
|
auto srcRank = sourceVectorType.getRank();
|
|
auto resRank = resultVectorType.getRank();
|
|
|
|
// This can only lower shape_casts where both the source and result types
|
|
// have a single trailing scalable dimension. This is because there are no
|
|
// legal representation of other scalable types in LLVM (and likely won't be
|
|
// soon). There are also (currently) no operations that can index or extract
|
|
// from >= 2-D scalable vectors or scalable vectors of fixed vectors.
|
|
if (!isTrailingDimScalable(sourceVectorType) ||
|
|
!isTrailingDimScalable(resultVectorType)) {
|
|
return rewriter.notifyMatchFailure(
|
|
op, "trailing dims are not scalable, not handled by this pattern");
|
|
}
|
|
|
|
// The sizes of the trailing dimension of the source and result vectors, the
|
|
// size of subvector to move, and the number of elements in the vectors.
|
|
// These are "min" sizes as they are the size when vscale == 1.
|
|
auto minSourceTrailingSize = sourceVectorType.getShape().back();
|
|
auto minResultTrailingSize = resultVectorType.getShape().back();
|
|
auto minExtractionSize =
|
|
std::min(minSourceTrailingSize, minResultTrailingSize);
|
|
int64_t minNumElts = 1;
|
|
for (auto size : sourceVectorType.getShape())
|
|
minNumElts *= size;
|
|
|
|
// The subvector type to move from the source to the result. Note that this
|
|
// is a scalable vector. This rewrite will generate code in terms of the
|
|
// "min" size (vscale == 1 case), that scales to any vscale.
|
|
auto extractionVectorType = VectorType::get(
|
|
{minExtractionSize}, sourceVectorType.getElementType(), {true});
|
|
|
|
Value result = ub::PoisonOp::create(rewriter, loc, resultVectorType);
|
|
SmallVector<int64_t> srcIdx(srcRank, 0);
|
|
SmallVector<int64_t> resIdx(resRank, 0);
|
|
|
|
// TODO: Try rewriting this with StaticTileOffsetRange (from IndexingUtils)
|
|
// once D150000 lands.
|
|
Value currentResultScalableVector;
|
|
Value currentSourceScalableVector;
|
|
for (int64_t i = 0; i < minNumElts; i += minExtractionSize) {
|
|
// 1. Extract a scalable subvector from the source vector.
|
|
if (!currentSourceScalableVector) {
|
|
if (srcRank != 1) {
|
|
currentSourceScalableVector =
|
|
vector::ExtractOp::create(rewriter, loc, op.getSource(),
|
|
llvm::ArrayRef(srcIdx).drop_back());
|
|
} else {
|
|
currentSourceScalableVector = op.getSource();
|
|
}
|
|
}
|
|
Value sourceSubVector = currentSourceScalableVector;
|
|
if (minExtractionSize < minSourceTrailingSize) {
|
|
sourceSubVector = vector::ScalableExtractOp::create(
|
|
rewriter, loc, extractionVectorType, sourceSubVector,
|
|
srcIdx.back());
|
|
}
|
|
|
|
// 2. Insert the scalable subvector into the result vector.
|
|
if (!currentResultScalableVector) {
|
|
if (minExtractionSize == minResultTrailingSize) {
|
|
currentResultScalableVector = sourceSubVector;
|
|
} else if (resRank != 1) {
|
|
currentResultScalableVector = vector::ExtractOp::create(
|
|
rewriter, loc, result, llvm::ArrayRef(resIdx).drop_back());
|
|
} else {
|
|
currentResultScalableVector = result;
|
|
}
|
|
}
|
|
if (minExtractionSize < minResultTrailingSize) {
|
|
currentResultScalableVector = vector::ScalableInsertOp::create(
|
|
rewriter, loc, sourceSubVector, currentResultScalableVector,
|
|
resIdx.back());
|
|
}
|
|
|
|
// 3. Update the source and result scalable vectors if needed.
|
|
if (resIdx.back() + minExtractionSize >= minResultTrailingSize &&
|
|
currentResultScalableVector != result) {
|
|
// Finished row of result. Insert complete scalable vector into result
|
|
// (n-D) vector.
|
|
result = vector::InsertOp::create(rewriter, loc,
|
|
currentResultScalableVector, result,
|
|
llvm::ArrayRef(resIdx).drop_back());
|
|
currentResultScalableVector = {};
|
|
}
|
|
if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) {
|
|
// Finished row of source.
|
|
currentSourceScalableVector = {};
|
|
}
|
|
|
|
// 4. Increment the insert/extract indices, stepping by minExtractionSize
|
|
// for the trailing dimensions.
|
|
inplaceAdd(minExtractionSize, sourceVectorType.getShape(), srcIdx);
|
|
inplaceAdd(minExtractionSize, resultVectorType.getShape(), resIdx);
|
|
}
|
|
|
|
rewriter.replaceOp(op, result);
|
|
return success();
|
|
}
|
|
|
|
static bool isTrailingDimScalable(VectorType type) {
|
|
return type.getRank() >= 1 && type.getScalableDims().back() &&
|
|
!llvm::is_contained(type.getScalableDims().drop_back(), true);
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
/// Adds the fixed-size and the scalable vector.shape_cast lowering patterns
/// defined in this file to `patterns`, both with the given `benefit`.
void mlir::vector::populateVectorShapeCastLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  MLIRContext *context = patterns.getContext();
  patterns.add<ShapeCastOpRewritePattern>(context, benefit);
  patterns.add<ScalableShapeCastOpRewritePattern>(context, benefit);
}
|