//===- LowerVectorShapeCast.cpp - Lower 'vector.shape_cast' operation -----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements target-independent rewrites and utilities to lower the // 'vector.shape_cast' operation. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/UB//IR/UBOps.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Location.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "llvm/ADT/STLExtras.h" #include #define DEBUG_TYPE "vector-shape-cast-lowering" using namespace mlir; /// Perform the inplace update /// rhs <- lhs + rhs /// /// where `rhs` is a number expressed in mixed base `base` with most signficant /// dimensions on the left. For example if `rhs` is {a,b,c} and `base` is /// {5,3,2} then `rhs` has value a*3*2 + b*2 + c. /// /// Some examples where `base` is {5,3,2}: /// rhs = {0,0,0}, lhs = 1 --> rhs = {0,0,1} /// rhs = {0,0,1}, lhs = 1 --> rhs = {0,1,0} /// rhs = {0,0,0}, lhs = 25 --> rhs = {4,0,1} /// /// Invalid: /// rhs = {0,0,2}, lhs = 1 : rhs not in base {5,3,2} /// /// Overflows not handled correctly: /// rhs = {4,2,1}, lhs = 2 --> rhs = {0,0,0} (not {0,0,1}) static void inplaceAdd(int64_t lhs, ArrayRef base, MutableArrayRef rhs) { // For dimensions in [numIndices - 1, ..., 3, 2, 1, 0]: for (int dim : llvm::reverse(llvm::seq(0, rhs.size()))) { int64_t dimBase = base[dim]; assert(rhs[dim] < dimBase && "rhs not in base"); int64_t incremented = rhs[dim] + lhs; // If the incremented value excedes the dimension base, we must spill to the // next most significant dimension and repeat (we might need to spill to // more significant dimensions multiple times). lhs = incremented / dimBase; rhs[dim] = incremented % dimBase; if (lhs == 0) break; } } namespace { /// shape_cast is converted to a sequence of extract, extract_strided_slice, /// insert_strided_slice, and insert operations. The running example will be: /// /// %0 = vector.shape_cast %arg0 : /// vector<2x2x3x4x7x11xi8> to vector<8x6x7x11xi8> /// /// In this example the source and result shapes share a common suffix of 7x11. /// This means we can always decompose the shape_cast into extract, insert, and /// their strided equivalents, on vectors with shape suffix 7x11. /// /// The greatest common divisor (gcd) of the first dimension preceding the /// common suffix is gcd(4,6) = 2. The algorithm implemented here will operate /// on vectors with shapes that are `multiples` of (what we define as) the /// 'atomic shape', 2x7x11. The atomic shape is `gcd` x `common-suffix`. /// /// vector<2x2x3x4x7x11xi8> to /// vector<8x6x7x11xi8> /// | |||| /// | ++++------------> common suffix of 7x11 /// +-----------------> gcd(4,6) is 2 | | /// | | | /// v v v /// atomic shape <----- 2x7x11 /// /// /// /// The decomposition implemented in this pattern consists of a sequence of /// repeated steps: /// /// (1) Extract vectors from the suffix of the source. /// In our example this is 2x2x3x4x7x11 -> 4x7x11. /// /// (2) Do extract_strided_slice down to the atomic shape. /// In our example this is 4x7x11 -> 2x7x11. /// /// (3) Do insert_strided_slice to the suffix of the result. /// In our example this is 2x7x11 -> 6x7x11. /// /// (4) insert these vectors into the result vector. /// In our example this is 6x7x11 -> 8x6x7x11. /// /// These steps occur with different periods. In this example /// (1) occurs 12 times, /// (2) and (3) occur 24 times, and /// (4) occurs 8 times. /// /// Two special cases are handled independently in this pattern /// (i) A shape_cast that just does leading 1 insertion/removal /// (ii) A shape_cast where the gcd is 1. /// /// These 2 cases can have more compact IR generated by not using the generic /// algorithm described above. /// class ShapeCastOpRewritePattern : public OpRewritePattern { // Case (i) of description. // Assumes source and result shapes are identical up to some leading ones. static LogicalResult leadingOnesLowering(vector::ShapeCastOp shapeCast, PatternRewriter &rewriter) { const Location loc = shapeCast.getLoc(); const VectorType sourceType = shapeCast.getSourceVectorType(); const VectorType resultType = shapeCast.getResultVectorType(); const int64_t sourceRank = sourceType.getRank(); const int64_t resultRank = resultType.getRank(); const int64_t delta = sourceRank - resultRank; const int64_t sourceLeading = delta > 0 ? delta : 0; const int64_t resultLeading = delta > 0 ? 0 : -delta; const Value source = shapeCast.getSource(); const Value poison = ub::PoisonOp::create(rewriter, loc, resultType); const Value extracted = vector::ExtractOp::create( rewriter, loc, source, SmallVector(sourceLeading, 0)); const Value result = vector::InsertOp::create(rewriter, loc, extracted, poison, SmallVector(resultLeading, 0)); rewriter.replaceOp(shapeCast, result); return success(); } // Case (ii) of description. // Assumes a shape_cast where the suffix shape of the source starting at // `sourceDim` and the suffix shape of the result starting at `resultDim` are // identical. static LogicalResult noStridedSliceLowering(vector::ShapeCastOp shapeCast, int64_t sourceDim, int64_t resultDim, PatternRewriter &rewriter) { const Location loc = shapeCast.getLoc(); const Value source = shapeCast.getSource(); const ArrayRef sourceShape = shapeCast.getSourceVectorType().getShape(); const VectorType resultType = shapeCast.getResultVectorType(); const ArrayRef resultShape = resultType.getShape(); const int64_t nSlices = llvm::product_of(sourceShape.take_front(sourceDim)); SmallVector extractIndex(sourceDim, 0); SmallVector insertIndex(resultDim, 0); Value result = ub::PoisonOp::create(rewriter, loc, resultType); for (int i = 0; i < nSlices; ++i) { Value extracted = vector::ExtractOp::create(rewriter, loc, source, extractIndex); result = vector::InsertOp::create(rewriter, loc, extracted, result, insertIndex); inplaceAdd(1, sourceShape.take_front(sourceDim), extractIndex); inplaceAdd(1, resultShape.take_front(resultDim), insertIndex); } rewriter.replaceOp(shapeCast, result); return success(); } public: using Base::Base; LogicalResult matchAndRewrite(vector::ShapeCastOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); VectorType sourceType = op.getSourceVectorType(); VectorType resultType = op.getResultVectorType(); if (sourceType.isScalable() || resultType.isScalable()) return rewriter.notifyMatchFailure( op, "shape_cast where vectors are scalable not handled by this pattern"); const ArrayRef sourceShape = sourceType.getShape(); const ArrayRef resultShape = resultType.getShape(); const int64_t sourceRank = sourceType.getRank(); const int64_t resultRank = resultType.getRank(); const int64_t numElms = sourceType.getNumElements(); const Value source = op.getSource(); // Set the first dimension (starting at the end) in the source and result // respectively where the dimension sizes differ. Using the running example: // // dimensions: [0 1 2 3 4 5 ] [0 1 2 3 ] // shapes: (2,2,3,4,7,11) -> (8,6,7,11) // ^ ^ // | | // sourceSuffixStartDim is 3 | // | // resultSuffixStartDim is 1 int64_t sourceSuffixStartDim = sourceRank - 1; int64_t resultSuffixStartDim = resultRank - 1; while (sourceSuffixStartDim >= 0 && resultSuffixStartDim >= 0 && (sourceType.getDimSize(sourceSuffixStartDim) == resultType.getDimSize(resultSuffixStartDim))) { --sourceSuffixStartDim; --resultSuffixStartDim; } // This is the case (i) where there are just some leading ones to contend // with in the source or result. It can be handled with a single // extract/insert pair. if (resultSuffixStartDim < 0 || sourceSuffixStartDim < 0) return leadingOnesLowering(op, rewriter); const int64_t sourceSuffixStartDimSize = sourceType.getDimSize(sourceSuffixStartDim); const int64_t resultSuffixStartDimSize = resultType.getDimSize(resultSuffixStartDim); const int64_t greatestCommonDivisor = std::gcd(sourceSuffixStartDimSize, resultSuffixStartDimSize); const int64_t stridedSliceRank = sourceRank - sourceSuffixStartDim; const size_t extractPeriod = sourceSuffixStartDimSize / greatestCommonDivisor; const size_t insertPeriod = resultSuffixStartDimSize / greatestCommonDivisor; SmallVector atomicShape(sourceShape.begin() + sourceSuffixStartDim, sourceShape.end()); atomicShape[0] = greatestCommonDivisor; const int64_t numAtomicElms = std::accumulate( atomicShape.begin(), atomicShape.end(), 1, std::multiplies()); const size_t nAtomicSlices = numElms / numAtomicElms; // This is the case (ii) where the strided dimension size is 1. More compact // IR is generated in this case if we just extract and insert the elements // directly. In other words, we don't use extract_strided_slice and // insert_strided_slice. if (greatestCommonDivisor == 1) return noStridedSliceLowering(op, sourceSuffixStartDim + 1, resultSuffixStartDim + 1, rewriter); // The insert_strided_slice result's type const ArrayRef insertStridedShape = resultShape.drop_front(resultSuffixStartDim); const VectorType insertStridedType = VectorType::get(insertStridedShape, resultType.getElementType()); SmallVector extractIndex(sourceSuffixStartDim, 0); SmallVector insertIndex(resultSuffixStartDim, 0); SmallVector extractOffsets(stridedSliceRank, 0); SmallVector insertOffsets(stridedSliceRank, 0); const SmallVector sizes(stridedSliceRank, 1); Value extracted = {}; Value extractedStrided = {}; Value insertedSlice = {}; Value result = ub::PoisonOp::create(rewriter, loc, resultType); const Value partResult = ub::PoisonOp::create(rewriter, loc, insertStridedType); for (size_t i = 0; i < nAtomicSlices; ++i) { const size_t extractStridedPhase = i % extractPeriod; const size_t insertStridedPhase = i % insertPeriod; // vector.extract if (extractStridedPhase == 0) { extracted = vector::ExtractOp::create(rewriter, loc, source, extractIndex); inplaceAdd(1, sourceShape.take_front(sourceSuffixStartDim), extractIndex); } // vector.extract_strided_slice extractOffsets[0] = extractStridedPhase * greatestCommonDivisor; extractedStrided = vector::ExtractStridedSliceOp::create( rewriter, loc, extracted, extractOffsets, atomicShape, sizes); // vector.insert_strided_slice if (insertStridedPhase == 0) { insertedSlice = partResult; } insertOffsets[0] = insertStridedPhase * greatestCommonDivisor; insertedSlice = vector::InsertStridedSliceOp::create( rewriter, loc, extractedStrided, insertedSlice, insertOffsets, sizes); // vector.insert if (insertStridedPhase + 1 == insertPeriod) { result = vector::InsertOp::create(rewriter, loc, insertedSlice, result, insertIndex); inplaceAdd(1, resultType.getShape().take_front(resultSuffixStartDim), insertIndex); } } rewriter.replaceOp(op, result); return success(); } }; /// A shape_cast lowering for scalable vectors with a single trailing scalable /// dimension. This is similar to the general shape_cast lowering but makes use /// of vector.scalable.insert and vector.scalable.extract to move elements a /// subvector at a time. /// /// E.g.: /// ``` /// // Flatten scalable vector /// %0 = vector.shape_cast %arg0 : vector<2x1x[4]xi32> to vector<[8]xi32> /// ``` /// is rewritten to: /// ``` /// // Flatten scalable vector /// %c = arith.constant dense<0> : vector<[8]xi32> /// %0 = vector.extract %arg0[0, 0] : vector<[4]xi32> from vector<2x1x[4]xi32> /// %1 = vector.scalable.insert %0, %c[0] : vector<[4]xi32> into vector<[8]xi32> /// %2 = vector.extract %arg0[1, 0] : vector<[4]xi32> from vector<2x1x[4]xi32> /// %3 = vector.scalable.insert %2, %1[4] : vector<[4]xi32> into vector<[8]xi32> /// ``` /// or: /// ``` /// // Un-flatten scalable vector /// %0 = vector.shape_cast %arg0 : vector<[8]xi32> to vector<2x1x[4]xi32> /// ``` /// is rewritten to: /// ``` /// // Un-flatten scalable vector /// %c = arith.constant dense<0> : vector<2x1x[4]xi32> /// %0 = vector.scalable.extract %arg0[0] : vector<[4]xi32> from vector<[8]xi32> /// %1 = vector.insert %0, %c [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32> /// %2 = vector.scalable.extract %arg0[4] : vector<[4]xi32> from vector<[8]xi32> /// %3 = vector.insert %2, %1 [1, 0] : vector<[4]xi32> into vector<2x1x[4]xi32> /// ``` class ScalableShapeCastOpRewritePattern : public OpRewritePattern { public: using Base::Base; LogicalResult matchAndRewrite(vector::ShapeCastOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); auto sourceVectorType = op.getSourceVectorType(); auto resultVectorType = op.getResultVectorType(); auto srcRank = sourceVectorType.getRank(); auto resRank = resultVectorType.getRank(); // This can only lower shape_casts where both the source and result types // have a single trailing scalable dimension. This is because there are no // legal representation of other scalable types in LLVM (and likely won't be // soon). There are also (currently) no operations that can index or extract // from >= 2-D scalable vectors or scalable vectors of fixed vectors. if (!isTrailingDimScalable(sourceVectorType) || !isTrailingDimScalable(resultVectorType)) { return rewriter.notifyMatchFailure( op, "trailing dims are not scalable, not handled by this pattern"); } // The sizes of the trailing dimension of the source and result vectors, the // size of subvector to move, and the number of elements in the vectors. // These are "min" sizes as they are the size when vscale == 1. auto minSourceTrailingSize = sourceVectorType.getShape().back(); auto minResultTrailingSize = resultVectorType.getShape().back(); auto minExtractionSize = std::min(minSourceTrailingSize, minResultTrailingSize); int64_t minNumElts = 1; for (auto size : sourceVectorType.getShape()) minNumElts *= size; // The subvector type to move from the source to the result. Note that this // is a scalable vector. This rewrite will generate code in terms of the // "min" size (vscale == 1 case), that scales to any vscale. auto extractionVectorType = VectorType::get( {minExtractionSize}, sourceVectorType.getElementType(), {true}); Value result = ub::PoisonOp::create(rewriter, loc, resultVectorType); SmallVector srcIdx(srcRank, 0); SmallVector resIdx(resRank, 0); // TODO: Try rewriting this with StaticTileOffsetRange (from IndexingUtils) // once D150000 lands. Value currentResultScalableVector; Value currentSourceScalableVector; for (int64_t i = 0; i < minNumElts; i += minExtractionSize) { // 1. Extract a scalable subvector from the source vector. if (!currentSourceScalableVector) { if (srcRank != 1) { currentSourceScalableVector = vector::ExtractOp::create(rewriter, loc, op.getSource(), llvm::ArrayRef(srcIdx).drop_back()); } else { currentSourceScalableVector = op.getSource(); } } Value sourceSubVector = currentSourceScalableVector; if (minExtractionSize < minSourceTrailingSize) { sourceSubVector = vector::ScalableExtractOp::create( rewriter, loc, extractionVectorType, sourceSubVector, srcIdx.back()); } // 2. Insert the scalable subvector into the result vector. if (!currentResultScalableVector) { if (minExtractionSize == minResultTrailingSize) { currentResultScalableVector = sourceSubVector; } else if (resRank != 1) { currentResultScalableVector = vector::ExtractOp::create( rewriter, loc, result, llvm::ArrayRef(resIdx).drop_back()); } else { currentResultScalableVector = result; } } if (minExtractionSize < minResultTrailingSize) { currentResultScalableVector = vector::ScalableInsertOp::create( rewriter, loc, sourceSubVector, currentResultScalableVector, resIdx.back()); } // 3. Update the source and result scalable vectors if needed. if (resIdx.back() + minExtractionSize >= minResultTrailingSize && currentResultScalableVector != result) { // Finished row of result. Insert complete scalable vector into result // (n-D) vector. result = vector::InsertOp::create(rewriter, loc, currentResultScalableVector, result, llvm::ArrayRef(resIdx).drop_back()); currentResultScalableVector = {}; } if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) { // Finished row of source. currentSourceScalableVector = {}; } // 4. Increment the insert/extract indices, stepping by minExtractionSize // for the trailing dimensions. inplaceAdd(minExtractionSize, sourceVectorType.getShape(), srcIdx); inplaceAdd(minExtractionSize, resultVectorType.getShape(), resIdx); } rewriter.replaceOp(op, result); return success(); } static bool isTrailingDimScalable(VectorType type) { return type.getRank() >= 1 && type.getScalableDims().back() && !llvm::is_contained(type.getScalableDims().drop_back(), true); } }; } // namespace void mlir::vector::populateVectorShapeCastLoweringPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add( patterns.getContext(), benefit); }