//===- LowerVectorMultiReduction.cpp - Lower `vector.multi_reduction` op --===// // /// Part of the LLVM Project, under the Apache License v2.0 with LLVM /// Exceptions. See https://llvm.org/LICENSE.txt for license information. /// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements target-independent rewrites and utilities to lower the // 'vector.multi_reduction' operation. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Transforms/Passes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { namespace vector { #define GEN_PASS_DEF_LOWERVECTORMULTIREDUCTION #include "mlir/Dialect/Vector/Transforms/Passes.h.inc" } // namespace vector } // namespace mlir #define DEBUG_TYPE "vector-multi-reduction" using namespace mlir; namespace { /// This file implements the following transformations as composable atomic /// patterns. /// Converts vector.multi_reduction into inner-most/outer-most reduction form /// by using vector.transpose class InnerOuterDimReductionConversion : public OpRewritePattern { public: using Base::Base; explicit InnerOuterDimReductionConversion( MLIRContext *context, vector::VectorMultiReductionLowering options, PatternBenefit benefit = 1) : mlir::OpRewritePattern(context, benefit), useInnerDimsForReduction( options == vector::VectorMultiReductionLowering::InnerReduction) {} LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override { // Vector mask setup. OpBuilder::InsertionGuard guard(rewriter); auto maskableOp = cast(multiReductionOp.getOperation()); Operation *rootOp; if (maskableOp.isMasked()) { rewriter.setInsertionPoint(maskableOp.getMaskingOp()); rootOp = maskableOp.getMaskingOp(); } else { rootOp = multiReductionOp; } auto src = multiReductionOp.getSource(); auto loc = multiReductionOp.getLoc(); auto srcRank = multiReductionOp.getSourceVectorType().getRank(); // Separate reduction and parallel dims ArrayRef reductionDims = multiReductionOp.getReductionDims(); llvm::SmallDenseSet reductionDimsSet(reductionDims.begin(), reductionDims.end()); int64_t reductionSize = reductionDims.size(); SmallVector parallelDims; for (int64_t i = 0; i < srcRank; ++i) if (!reductionDimsSet.contains(i)) parallelDims.push_back(i); // Add transpose only if inner-most/outer-most dimensions are not parallel // and there are parallel dims. if (parallelDims.empty()) return failure(); if (useInnerDimsForReduction && (parallelDims == llvm::to_vector<4>(llvm::seq(0, parallelDims.size())))) return failure(); if (!useInnerDimsForReduction && (parallelDims == llvm::to_vector<4>(llvm::seq( reductionDims.size(), parallelDims.size() + reductionDims.size())))) return failure(); SmallVector indices; if (useInnerDimsForReduction) { indices.append(parallelDims.begin(), parallelDims.end()); indices.append(reductionDims.begin(), reductionDims.end()); } else { indices.append(reductionDims.begin(), reductionDims.end()); indices.append(parallelDims.begin(), parallelDims.end()); } // If masked, transpose the original mask. Value transposedMask; if (maskableOp.isMasked()) { transposedMask = vector::TransposeOp::create( rewriter, loc, maskableOp.getMaskingOp().getMask(), indices); } // Transpose reduction source. auto transposeOp = vector::TransposeOp::create(rewriter, loc, src, indices); SmallVector reductionMask(srcRank, false); for (int i = 0; i < reductionSize; ++i) { if (useInnerDimsForReduction) reductionMask[srcRank - i - 1] = true; else reductionMask[i] = true; } Operation *newMultiRedOp = vector::MultiDimReductionOp::create( rewriter, multiReductionOp.getLoc(), transposeOp.getResult(), multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind()); newMultiRedOp = mlir::vector::maskOperation(rewriter, newMultiRedOp, transposedMask); rewriter.replaceOp(rootOp, newMultiRedOp->getResult(0)); return success(); } private: const bool useInnerDimsForReduction; }; /// Flattens vector.multi_reduction to 2D /// /// Given all reduction dimensions are either inner most or outer most, /// flattens all reduction and parallel dimensions so that there are only 2Ds. /// /// BEFORE /// vector.multi_reduction , %vec, %acc [2, 3] : vector<2x3x4x5xi32> to /// vector<2x3xi32> /// AFTER /// %vec_sc = vector.shape_cast %vec /// %acc_sc = vector.shape_cast %acc /// %res = vector.multi_reduction , %vec_sc, %acc_cs [1] : /// vector<6x20xi32> to vector<6xi32> %res_sc = vector.shape_cast %res class FlattenMultiReduction : public OpRewritePattern { public: using Base::Base; explicit FlattenMultiReduction(MLIRContext *context, vector::VectorMultiReductionLowering options, PatternBenefit benefit = 1) : mlir::OpRewritePattern(context, benefit), useInnerDimsForReduction( options == vector::VectorMultiReductionLowering::InnerReduction) {} LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override { // Vector mask setup. OpBuilder::InsertionGuard guard(rewriter); auto maskableOp = cast(multiReductionOp.getOperation()); Operation *rootOp; if (maskableOp.isMasked()) { rewriter.setInsertionPoint(maskableOp.getMaskingOp()); rootOp = maskableOp.getMaskingOp(); } else { rootOp = multiReductionOp; } auto srcRank = multiReductionOp.getSourceVectorType().getRank(); auto srcShape = multiReductionOp.getSourceVectorType().getShape(); auto srcScalableDims = multiReductionOp.getSourceVectorType().getScalableDims(); auto loc = multiReductionOp.getLoc(); // If rank less than 2, nothing to do. if (srcRank < 2) return failure(); // Allow only 1 scalable dimensions. Otherwise we could end-up with e.g. // `vscale * vscale` that's currently not modelled. if (llvm::count(srcScalableDims, true) > 1) return failure(); // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"] bail. SmallVector reductionMask = multiReductionOp.getReductionMask(); if (srcRank == 2 && reductionMask.front() != reductionMask.back()) return failure(); // 1. Separate reduction and parallel dims. SmallVector parallelDims, parallelShapes; SmallVector parallelScalableDims; SmallVector reductionDims, reductionShapes; bool isReductionDimScalable = false; for (const auto &it : llvm::enumerate(reductionMask)) { int64_t i = it.index(); bool isReduction = it.value(); if (isReduction) { reductionDims.push_back(i); reductionShapes.push_back(srcShape[i]); isReductionDimScalable |= srcScalableDims[i]; } else { parallelDims.push_back(i); parallelShapes.push_back(srcShape[i]); parallelScalableDims.push_back(srcScalableDims[i]); } } // 2. Compute flattened parallel and reduction sizes. int flattenedParallelDim = 0; int flattenedReductionDim = 0; if (!parallelShapes.empty()) { flattenedParallelDim = 1; for (auto d : parallelShapes) flattenedParallelDim *= d; } if (!reductionShapes.empty()) { flattenedReductionDim = 1; for (auto d : reductionShapes) flattenedReductionDim *= d; } // We must at least have some parallel or some reduction. assert((flattenedParallelDim || flattenedReductionDim) && "expected at least one parallel or reduction dim"); // 3. Fail if reduction/parallel dims are not contiguous. // Check parallelDims are exactly [0 .. size). int64_t counter = 0; if (useInnerDimsForReduction && llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; })) return failure(); // Check parallelDims are exactly {reductionDims.size()} + [0 .. size). counter = reductionDims.size(); if (!useInnerDimsForReduction && llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; })) return failure(); // 4. Shape cast to collapse consecutive parallel (resp. reduction dim) into // a single parallel (resp. reduction) dim. SmallVector mask; SmallVector scalableDims; SmallVector vectorShape; bool isParallelDimScalable = llvm::is_contained(parallelScalableDims, true); if (flattenedParallelDim) { mask.push_back(false); vectorShape.push_back(flattenedParallelDim); scalableDims.push_back(isParallelDimScalable); } if (flattenedReductionDim) { mask.push_back(true); vectorShape.push_back(flattenedReductionDim); scalableDims.push_back(isReductionDimScalable); } if (!useInnerDimsForReduction && vectorShape.size() == 2) { std::swap(mask.front(), mask.back()); std::swap(vectorShape.front(), vectorShape.back()); std::swap(scalableDims.front(), scalableDims.back()); } Value newVectorMask; if (maskableOp.isMasked()) { Value vectorMask = maskableOp.getMaskingOp().getMask(); auto maskCastedType = VectorType::get( vectorShape, llvm::cast(vectorMask.getType()).getElementType()); newVectorMask = vector::ShapeCastOp::create(rewriter, loc, maskCastedType, vectorMask); } auto castedType = VectorType::get( vectorShape, multiReductionOp.getSourceVectorType().getElementType(), scalableDims); Value cast = vector::ShapeCastOp::create(rewriter, loc, castedType, multiReductionOp.getSource()); Value acc = multiReductionOp.getAcc(); if (flattenedParallelDim) { auto accType = VectorType::get( {flattenedParallelDim}, multiReductionOp.getSourceVectorType().getElementType(), /*scalableDims=*/{isParallelDimScalable}); acc = vector::ShapeCastOp::create(rewriter, loc, accType, acc); } // 6. Creates the flattened form of vector.multi_reduction with inner/outer // most dim as reduction. Operation *newMultiDimRedOp = vector::MultiDimReductionOp::create( rewriter, loc, cast, acc, mask, multiReductionOp.getKind()); newMultiDimRedOp = mlir::vector::maskOperation(rewriter, newMultiDimRedOp, newVectorMask); // 7. If there are no parallel shapes, the result is a scalar. // TODO: support 0-d vectors when available. if (parallelShapes.empty()) { rewriter.replaceOp(rootOp, newMultiDimRedOp->getResult(0)); return success(); } // 8. Shape cast the flattened result back to the original n-D parallel // shape. VectorType outputCastedType = VectorType::get( parallelShapes, multiReductionOp.getSourceVectorType().getElementType(), parallelScalableDims); rewriter.replaceOpWithNewOp( rootOp, outputCastedType, newMultiDimRedOp->getResult(0)); return success(); } private: const bool useInnerDimsForReduction; }; /// Lowers 2D vector.multi_reduction to a squence of Arith Ops /// /// The reduction dimension must be the outer-most dimension. /// /// BEFORE: /// /// %1 = vector.multi_reduction , %src, %acc [0] : vector<4x2xf32> to /// vector<2xf32> /// /// AFTER: /// /// // Prod 1. /// %vec_0 = vector.extract %src[0] : vector<2xf32> from vector<4x2xf32> /// %mul_0 = arith.mulf %vec_0, %acc : vector<2xf32> /// /// // Prod 2. /// %vec_1 = vector.extract %src[1] : vector<2xf32> from vector<4x2xf32> /// %mul_2 = arith.mulf %vec_1, %mul_0 : vector<2xf32> /// /// // Prod 3. /// %vec_3 = vector.extract %src[2] : vector<2xf32> from vector<4x2xf32> /// %mul_3 = arith.mulf %vec_3, %mul_2 : vector<2xf32> /// /// // Prod 4. /// %vec_4 = vector.extract %src[3] : vector<2xf32> from vector<4x2xf32> /// %res = arith.mulf %vec_4, %mul_3 : vector<2xf32> struct TwoDimMultiReductionToElementWise : public vector::MaskableOpRewritePattern { using MaskableOpRewritePattern::MaskableOpRewritePattern; FailureOr matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp, vector::MaskingOpInterface maskingOp, PatternRewriter &rewriter) const override { auto srcRank = multiReductionOp.getSourceVectorType().getRank(); // Rank-2 ["parallel", "reduce"] or bail. if (srcRank != 2) return failure(); if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0)) return failure(); Value mask = maskingOp ? maskingOp.getMask() : Value(); auto loc = multiReductionOp.getLoc(); Value source = multiReductionOp.getSource(); ArrayRef srcShape = multiReductionOp.getSourceVectorType().getShape(); int outerDim = srcShape[0]; Value result = multiReductionOp.getAcc(); for (int64_t i = 0; i < outerDim; i++) { auto v = vector::ExtractOp::create(rewriter, loc, source, i); Value m = mask ? Value(vector::ExtractOp::create(rewriter, loc, mask, i)) : nullptr; result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(), v, result, /*fastmath=*/nullptr, m); } return result; } }; /// Lowers 2D vector.multi_reduction to a sequence of vector.reduction Ops. /// /// The reduction dimension must be the inner-most dimension. /// /// BEFORE: /// vector.multi_reduction , %src, %acc [1] : vector<2x4xf32> to /// vector<2xf32> /// /// AFTER: /// // 1st reduction /// %v_0 = vector.extract %src[0] : vector<4xf32> from vector<2x4xf32> /// %a_0 = vector.extract %acc[0] : f32 from vector<2xf32> /// %red_1 = vector.reduction , %v_0, %a_1 : vector<4xf32> into f32 /// %res_tmp = vector.insert %red_1, %res [0] : f32 into vector<2xf32> /// /// // 2nd reduction /// %v_1 = vector.extract %src[1] : vector<4xf32> from vector<2x4xf32> /// %a_1 = vector.extract %acc[1] : f32 from vector<2xf32> /// %red_2 = vector.reduction , %v_1, %a_1 : vector<4xf32> into f32 /// %res_final = vector.insert %red_2, %res_tmp [1] : f32 into vector<2xf32> struct TwoDimMultiReductionToReduction : public vector::MaskableOpRewritePattern { using MaskableOpRewritePattern::MaskableOpRewritePattern; FailureOr matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp, vector::MaskingOpInterface maskingOp, PatternRewriter &rewriter) const override { auto srcRank = multiReductionOp.getSourceVectorType().getRank(); // Rank-2 ["reduce", "parallel"] or bail. if (srcRank != 2) return failure(); if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1)) return failure(); Value mask = maskingOp ? maskingOp.getMask() : nullptr; auto loc = multiReductionOp.getLoc(); Value source = multiReductionOp.getSource(); Value acc = multiReductionOp.getAcc(); int outerDim = multiReductionOp.getSourceVectorType().getShape()[0]; Value result = arith::ConstantOp::create( rewriter, loc, multiReductionOp.getDestType(), rewriter.getZeroAttr(multiReductionOp.getDestType())); SmallVector vectors(outerDim); for (int64_t i = 0; i < outerDim; ++i) { Value v = vector::ExtractOp::create(rewriter, loc, source, i); Value a = vector::ExtractOp::create(rewriter, loc, acc, i); Operation *reductionOp = vector::ReductionOp::create( rewriter, loc, multiReductionOp.getKind(), v, a); if (mask) { Value m = vector::ExtractOp::create(rewriter, loc, mask, i); reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, m); } result = vector::InsertOp::create(rewriter, loc, reductionOp->getResult(0), result, i); } return result; } }; /// Converts 1D vector.multi_reduction directly to vector.reduction. /// /// Example: /// ```mlir /// // Before /// %r = vector.multi_reduction , %v, %acc [0] : vector to f32 /// /// // After /// %r = vector.reduction , %v, %acc : vector into f32 /// ``` struct OneDimMultiReductionToReduction : public vector::MaskableOpRewritePattern { using MaskableOpRewritePattern::MaskableOpRewritePattern; FailureOr matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp, vector::MaskingOpInterface maskingOp, PatternRewriter &rewriter) const override { auto srcRank = multiReductionOp.getSourceVectorType().getRank(); if (srcRank != 1) return failure(); if (!multiReductionOp.isReducedDim(0)) return failure(); auto loc = multiReductionOp.getLoc(); Value mask = maskingOp ? maskingOp.getMask() : Value(); Operation *reductionOp = vector::ReductionOp::create( rewriter, loc, multiReductionOp.getKind(), multiReductionOp.getSource(), multiReductionOp.getAcc()); if (mask) reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask); return reductionOp->getResult(0); } }; struct LowerVectorMultiReductionPass : public vector::impl::LowerVectorMultiReductionBase< LowerVectorMultiReductionPass> { LowerVectorMultiReductionPass(vector::VectorMultiReductionLowering option) { this->loweringStrategy = option; } void runOnOperation() override { Operation *op = getOperation(); MLIRContext *context = op->getContext(); RewritePatternSet patterns(context); mlir::vector::populateVectorMultiReductionReorderPatterns( patterns, this->loweringStrategy); if (failed(applyPatternsGreedily(op, std::move(patterns)))) signalPassFailure(); RewritePatternSet flatteningPatterns(context); mlir::vector::populateVectorMultiReductionFlatteningPatterns( flatteningPatterns, this->loweringStrategy); if (failed(applyPatternsGreedily(op, std::move(flatteningPatterns)))) signalPassFailure(); RewritePatternSet unrollingPatterns(context); mlir::vector::populateVectorMultiReductionUnrollingPatterns( unrollingPatterns, this->loweringStrategy); if (failed(applyPatternsGreedily(op, std::move(unrollingPatterns)))) signalPassFailure(); } void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } }; } // namespace void mlir::vector::populateVectorMultiReductionReorderPatterns( RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit) { patterns.add(patterns.getContext(), options, benefit); } void mlir::vector::populateVectorMultiReductionFlatteningPatterns( RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit) { patterns.add(patterns.getContext(), options, benefit); } void mlir::vector::populateVectorMultiReductionUnrollingPatterns( RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit) { patterns.add(patterns.getContext(), benefit); if (options == VectorMultiReductionLowering ::InnerReduction) patterns.add(patterns.getContext(), benefit); else patterns.add(patterns.getContext(), benefit); } std::unique_ptr vector::createLowerVectorMultiReductionPass( vector::VectorMultiReductionLowering option) { return std::make_unique(option); }