//===- VectorLinearize.cpp - vector linearization transforms --------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements patterns and pass for linearizing ND vectors into 1D. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/ArrayRef.h" #include #include #include using namespace mlir; static FailureOr linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter, VectorType resType, Attribute value) { if (auto dstElementsAttr = dyn_cast(value)) { if (resType.isScalable() && !isa(value)) return rewriter.notifyMatchFailure( loc, "Cannot linearize a constant scalable vector that's not a splat"); return dstElementsAttr.reshape(resType); } if (auto poisonAttr = dyn_cast(value)) return poisonAttr; return rewriter.notifyMatchFailure(loc, "unsupported attr type"); } namespace { struct LinearizeConstantLike final : OpTraitConversionPattern { using OpTraitConversionPattern::OpTraitConversionPattern; LinearizeConstantLike(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpTraitConversionPattern(typeConverter, context, benefit) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); if (op->getNumResults() != 1) return rewriter.notifyMatchFailure(loc, "expected 1 result"); const TypeConverter &typeConverter = *getTypeConverter(); auto resType = typeConverter.convertType(op->getResult(0).getType()); assert(resType && "expected 1-D vector type"); StringAttr attrName = rewriter.getStringAttr("value"); Attribute value = op->getAttr(attrName); if (!value) return rewriter.notifyMatchFailure(loc, "no 'value' attr"); FailureOr newValue = linearizeConstAttr(loc, rewriter, resType, value); if (failed(newValue)) return failure(); FailureOr convertResult = convertOpResultTypes(op, /*operands=*/{}, typeConverter, rewriter); if (failed(convertResult)) return failure(); Operation *newOp = *convertResult; newOp->setAttr(attrName, *newValue); rewriter.replaceOp(op, newOp); return success(); } }; struct LinearizeVectorizable final : OpTraitConversionPattern { using OpTraitConversionPattern::OpTraitConversionPattern; public: LinearizeVectorizable(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpTraitConversionPattern(typeConverter, context, benefit) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { FailureOr newOp = convertOpResultTypes(op, operands, *getTypeConverter(), rewriter); if (failed(newOp)) return failure(); rewriter.replaceOp(op, (*newOp)->getResults()); return success(); } }; template static bool stridesAllOne(TOp op) { static_assert( std::is_same_v || std::is_same_v, "expected vector.extract_strided_slice or vector.insert_strided_slice"); ArrayAttr strides = op.getStrides(); return llvm::all_of(strides, isOneInteger); } /// Convert an array of attributes into a vector of integers, if possible. static FailureOr> intsFromArrayAttr(ArrayAttr attrs) { if (!attrs) return failure(); SmallVector ints; ints.reserve(attrs.size()); for (auto attr : attrs) { if (auto intAttr = dyn_cast(attr)) { ints.push_back(intAttr.getInt()); } else { return failure(); } } return ints; } /// Consider inserting a vector of shape `small` into a vector of shape `large`, /// at position `offsets`: this function enumeratates all the indices in `large` /// that are written to. The enumeration is with row-major ordering. /// /// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2 /// positions written to are (1,3) and (1,4), which have linearized indices 8 /// and 9. So [8,9] is returned. /// /// The length of the returned vector is equal to the number of elements in /// the shape `small` (i.e. the product of dimensions of `small`). SmallVector static getStridedSliceInsertionIndices( ArrayRef small, ArrayRef large, ArrayRef offsets) { // Example of alignment between, `large`, `small` and `offsets`: // large = 4, 5, 6, 7, 8 // small = 1, 6, 7, 8 // offsets = 2, 3, 0 // // `offsets` has implicit trailing 0s, `small` has implicit leading 1s. assert((large.size() >= small.size()) && "rank of 'large' cannot be lower than rank of 'small'"); assert((large.size() >= offsets.size()) && "rank of 'large' cannot be lower than the number of offsets"); unsigned delta = large.size() - small.size(); unsigned nOffsets = offsets.size(); auto getSmall = [&](int64_t i) -> int64_t { return i >= delta ? small[i - delta] : 1; }; auto getOffset = [&](int64_t i) -> int64_t { return i < nOffsets ? offsets[i] : 0; }; // Using 2 vectors of indices, at each iteration populate the updated set of // indices based on the old set of indices, and the size of the small vector // in the current iteration. SmallVector indices{0}; int64_t stride = 1; for (int i = large.size() - 1; i >= 0; --i) { int64_t currentSize = indices.size(); int64_t smallSize = getSmall(i); int64_t nextSize = currentSize * smallSize; SmallVector nextIndices(nextSize); int64_t *base = nextIndices.begin(); int64_t offset = getOffset(i) * stride; for (int j = 0; j < smallSize; ++j) { for (int k = 0; k < currentSize; ++k) { base[k] = indices[k] + offset; } offset += stride; base += currentSize; } stride *= large[i]; indices = std::move(nextIndices); } return indices; } /// This pattern converts a vector.extract_strided_slice operation into a /// vector.shuffle operation that has a rank-1 (linearized) operand and result. /// /// For example, the following: /// /// ``` /// vector.extract_strided_slice %source /// { offsets = [..], strides = [..], sizes = [..] } /// ``` /// /// is converted to : /// ``` /// %source_1d = vector.shape_cast %source /// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ] /// %out_nd = vector.shape_cast %out_1d /// ``` /// /// `shuffle_indices_1d` is computed using the offsets and sizes of the original /// vector.extract_strided_slice operation. struct LinearizeVectorExtractStridedSlice final : public mlir::OpConversionPattern { using Base::Base; LinearizeVectorExtractStridedSlice(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractStridedSliceOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { VectorType flatOutputType = getTypeConverter()->convertType( extractStridedSliceOp.getType()); assert(flatOutputType && "vector type expected"); // Expect a legalization failure if the strides are not all 1 (if ever the // verifier for extract_strided_slice allows non-1 strides). if (!stridesAllOne(extractStridedSliceOp)) { return rewriter.notifyMatchFailure( extractStridedSliceOp, "extract_strided_slice with strides != 1 not supported"); } FailureOr> offsets = intsFromArrayAttr(extractStridedSliceOp.getOffsets()); if (failed(offsets)) { return rewriter.notifyMatchFailure(extractStridedSliceOp, "failed to get integer offsets"); } ArrayRef inputShape = extractStridedSliceOp.getSourceVectorType().getShape(); ArrayRef outputShape = extractStridedSliceOp.getType().getShape(); SmallVector indices = getStridedSliceInsertionIndices( outputShape, inputShape, offsets.value()); Value srcVector = adaptor.getSource(); rewriter.replaceOpWithNewOp( extractStridedSliceOp, flatOutputType, srcVector, srcVector, indices); return success(); } }; /// This pattern converts a vector.insert_strided_slice operation into a /// vector.shuffle operation that has rank-1 (linearized) operands and result. /// /// For example, the following: /// ``` /// %0 = vector.insert_strided_slice %to_store, %into /// {offsets = [1, 0, 0, 0], strides = [1, 1]} /// : vector<2x2xi8> into vector<2x1x3x2xi8> /// ``` /// /// is converted to /// ``` /// %to_store_1d /// = vector.shape_cast %to_store : vector<2x2xi8> to vector<4xi8> /// %into_1d = vector.shape_cast %into : vector<2x1x3x2xi8> to vector<12xi8> /// %out_1d = vector.shuffle %into_1d, %to_store_1d [ shuffle_indices_1d ] /// %out_nd = vector.shape_cast %out_1d : vector<12xi8> to vector<2x1x3x2xi8> /// ``` /// /// where shuffle_indices_1d in this case is /// [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11]. /// ^^^^^^^^^^^^^^ /// to_store_1d /// struct LinearizeVectorInsertStridedSlice final : public mlir::OpConversionPattern { using Base::Base; LinearizeVectorInsertStridedSlice(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertStridedSliceOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Expect a legalization failure if the strides are not all 1 (if ever the // verifier for insert_strided_slice allows non-1 strides). if (!stridesAllOne(insertStridedSliceOp)) { return rewriter.notifyMatchFailure( insertStridedSliceOp, "insert_strided_slice with strides != 1 not supported"); } VectorType inputType = insertStridedSliceOp.getValueToStore().getType(); ArrayRef inputShape = inputType.getShape(); VectorType outputType = insertStridedSliceOp.getType(); ArrayRef outputShape = outputType.getShape(); int64_t nOutputElements = outputType.getNumElements(); FailureOr> offsets = intsFromArrayAttr(insertStridedSliceOp.getOffsets()); if (failed(offsets)) { return rewriter.notifyMatchFailure(insertStridedSliceOp, "failed to get integer offsets"); } SmallVector sliceIndices = getStridedSliceInsertionIndices( inputShape, outputShape, offsets.value()); SmallVector indices(nOutputElements); std::iota(indices.begin(), indices.end(), 0); for (auto [index, sliceIndex] : llvm::enumerate(sliceIndices)) { indices[sliceIndex] = index + nOutputElements; } Value flatToStore = adaptor.getValueToStore(); Value flatDest = adaptor.getDest(); rewriter.replaceOpWithNewOp(insertStridedSliceOp, flatDest.getType(), flatDest, flatToStore, indices); return success(); } }; /// This pattern converts the ShuffleOp that works on nD (n > 1) /// vectors to a ShuffleOp that works on linearized vectors. /// Following, /// vector.shuffle %v1, %v2 [ shuffle_indices ] /// is converted to : /// %v1_1d = vector.shape_cast %v1 /// %v2_1d = vector.shape_cast %v2 /// %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ] /// %out_nd = vector.shape_cast %out_1d // `shuffle_indices_1d` is computed using the sizes and `shuffle_indices` /// of the original shuffle operation. struct LinearizeVectorShuffle final : public OpConversionPattern { using Base::Base; LinearizeVectorShuffle(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} LogicalResult matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { VectorType dstType = getTypeConverter()->convertType(shuffleOp.getType()); assert(dstType && "vector type destination expected."); Value vec1 = adaptor.getV1(); Value vec2 = adaptor.getV2(); int shuffleSliceLen = 1; int rank = shuffleOp.getV1().getType().getRank(); // If rank > 1, we need to do the shuffle in the granularity of slices // instead of scalars. Size of the slice is equal to the rank-1 innermost // dims. Mask of the shuffle op specifies which slice to take from the // outermost dim. if (rank > 1) { llvm::ArrayRef shape = shuffleOp.getV1().getType().getShape(); for (unsigned i = 1; i < shape.size(); ++i) { shuffleSliceLen *= shape[i]; } } // For each value in the mask, we generate the indices of the source vectors // that need to be shuffled to the destination vector. If shuffleSliceLen > // 1 we need to shuffle the slices (consecutive shuffleSliceLen number of // elements) instead of scalars. ArrayRef mask = shuffleOp.getMask(); int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen; llvm::SmallVector indices(totalSizeOfShuffledElmnts); for (auto [i, value] : llvm::enumerate(mask)) { std::iota(indices.begin() + shuffleSliceLen * i, indices.begin() + shuffleSliceLen * (i + 1), shuffleSliceLen * value); } rewriter.replaceOpWithNewOp(shuffleOp, dstType, vec1, vec2, indices); return success(); } }; /// This pattern linearizes `vector.extract` operations. It generates a 1-D /// version of the `vector.extract` operation when extracting a scalar from a /// vector. It generates a 1-D `vector.shuffle` operation when extracting a /// subvector from a larger vector. /// /// Example #1: /// /// %0 = vector.extract %arg0[1]: vector<8x2xf32> from vector<2x8x2xf32> /// /// is converted to: /// /// %0 = vector.shape_cast %arg0 : vector<2x8x2xf32> to vector<32xf32> /// %1 = vector.shuffle %0, %0 [16, 17, 18, 19, 20, 21, 22, 23, /// 24, 25, 26, 27, 28, 29, 30, 31] : /// vector<32xf32>, vector<32xf32> /// %2 = vector.shape_cast %1 : vector<16xf32> to vector<8x2xf32> /// /// Example #2: /// /// %0 = vector.extract %arg0[1, 2] : i32 from vector<2x4xi32> /// /// is converted to: /// /// %0 = vector.shape_cast %arg0 : vector<2x4xi32> to vector<8xi32> /// %1 = vector.extract %0[6] : i32 from vector<8xi32> /// struct LinearizeVectorExtract final : public OpConversionPattern { using Base::Base; LinearizeVectorExtract(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} LogicalResult matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type dstTy = getTypeConverter()->convertType(extractOp.getType()); assert(dstTy && "expected 1-D vector type"); // Dynamic position is not supported. if (extractOp.hasDynamicPosition()) return rewriter.notifyMatchFailure(extractOp, "dynamic position is not supported."); llvm::ArrayRef shape = extractOp.getSource().getType().getShape(); int64_t size = extractOp.getSource().getType().getNumElements(); // Compute linearized offset. int64_t linearizedOffset = 0; llvm::ArrayRef offsets = extractOp.getStaticPosition(); for (auto [i, off] : llvm::enumerate(offsets)) { size /= shape[i]; linearizedOffset += offsets[i] * size; } Value srcVector = adaptor.getSource(); if (!isa(extractOp.getType())) { // Scalar case: generate a 1-D extract. Value result = rewriter.createOrFold( extractOp.getLoc(), srcVector, linearizedOffset); rewriter.replaceOp(extractOp, result); return success(); } // Vector case: generate a shuffle. llvm::SmallVector indices(size); std::iota(indices.begin(), indices.end(), linearizedOffset); rewriter.replaceOpWithNewOp(extractOp, dstTy, srcVector, srcVector, indices); return success(); } }; /// This pattern linearizes `vector.insert` operations. It generates a 1-D /// version of the `vector.insert` operation when inserting a scalar into a /// vector. It generates a 1-D `vector.shuffle` operation when inserting a /// vector into another vector. /// /// Example #1: /// /// %0 = vector.insert %source, %destination[0] : /// vector<2x4xf32> into vector<2x2x4xf32> /// /// is converted to: /// /// %0 = vector.shape_cast %source : vector<2x4xf32> to vector<8xf32> /// %1 = vector.shape_cast %destination : /// vector<2x2x4xf32> to vector<16xf32> /// %2 = vector.shuffle %1, %0 [16, 17, 18, 19, 20, 21, 22, 23 /// 8, 9, 10, 11, 12, 13, 14, 15] : /// vector<16xf32>, vector<8xf32> /// %3 = vector.shape_cast %2 : vector<16xf32> to vector<2x2x4xf32> /// /// Example #2: /// /// %0 = vector.insert %source, %destination[1, 2]: f32 into vector<2x4xf32> /// /// is converted to: /// /// %0 = vector.shape_cast %destination : vector<2x4xf32> to vector<8xf32> /// %1 = vector.insert %source, %0[6]: f32 into vector<8xf32> /// %2 = vector.shape_cast %1 : vector<8xf32> to vector<2x4xf32> /// struct LinearizeVectorInsert final : public OpConversionPattern { using Base::Base; LinearizeVectorInsert(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} LogicalResult matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { VectorType dstTy = getTypeConverter()->convertType( insertOp.getDestVectorType()); assert(dstTy && "vector type destination expected."); // Dynamic position is not supported. if (insertOp.hasDynamicPosition()) return rewriter.notifyMatchFailure(insertOp, "dynamic position is not supported."); auto srcTy = insertOp.getValueToStoreType(); auto srcAsVec = dyn_cast(srcTy); uint64_t srcSize = srcAsVec ? srcAsVec.getNumElements() : 1; auto dstShape = insertOp.getDestVectorType().getShape(); const auto dstSize = insertOp.getDestVectorType().getNumElements(); auto dstSizeForOffsets = dstSize; // Compute linearized offset. int64_t linearizedOffset = 0; auto offsetsNd = insertOp.getStaticPosition(); for (auto [dim, offset] : llvm::enumerate(offsetsNd)) { dstSizeForOffsets /= dstShape[dim]; linearizedOffset += offset * dstSizeForOffsets; } Location loc = insertOp.getLoc(); Value valueToStore = adaptor.getValueToStore(); if (!isa(valueToStore.getType())) { // Scalar case: generate a 1-D insert. Value result = rewriter.createOrFold( loc, valueToStore, adaptor.getDest(), linearizedOffset); rewriter.replaceOp(insertOp, result); return success(); } // Vector case: generate a shuffle. llvm::SmallVector indices(dstSize); auto *origValsUntil = indices.begin(); std::advance(origValsUntil, linearizedOffset); // Original values that remain [0, offset). std::iota(indices.begin(), origValsUntil, 0); auto *newValsUntil = origValsUntil; std::advance(newValsUntil, srcSize); // New values [offset, offset+srcNumElements). std::iota(origValsUntil, newValsUntil, dstSize); // The rest of original values [offset+srcNumElements, end); std::iota(newValsUntil, indices.end(), linearizedOffset + srcSize); Value result = rewriter.createOrFold( loc, dstTy, adaptor.getDest(), valueToStore, indices); rewriter.replaceOp(insertOp, result); return success(); } }; /// This pattern converts the BitCastOp that works on nD (n > 1) /// vectors to a BitCastOp that works on linearized vectors. /// Following, /// vector.bitcast %v1: vector<4x2xf32> to vector<4x4xf16> /// is converted to : /// %v1_1d = vector.shape_cast %v1: vector<4x2xf32> to vector<8xf32> /// %out_1d = vector.bitcast %v1_1d: vector<8xf32> to vector<16xf16> /// %out_nd = vector.shape_cast %out_1d: vector<16xf16> to vector<4x4xf16> struct LinearizeVectorBitCast final : public OpConversionPattern { using Base::Base; LinearizeVectorBitCast(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} LogicalResult matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto resType = getTypeConverter()->convertType(castOp.getType()); assert(resType && "expected 1-D vector type"); rewriter.replaceOpWithNewOp(castOp, resType, adaptor.getSource()); return mlir::success(); } }; /// This pattern converts the CreateMaskOp to work on a linearized vector. /// It currently supports only 2D masks with a unit outer dimension. /// Following, /// vector.create_mask %arg0, %arg1 : vector<1x4xi1> /// is converted to: /// %zero = arith.constant 0 : index /// %cmpi = arith.cmpi sgt, %arg0, %zero : index /// %index = arith.index_cast %cmpi : i1 to index /// %mul = arith.andi %index, %arg1 : index /// %mask = vector.create_mask %mul : vector<4xi1> /// %shape_cast = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1> struct LinearizeVectorCreateMask final : OpConversionPattern { using Base::Base; LinearizeVectorCreateMask(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} LogicalResult matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = createMaskOp.getLoc(); VectorType srcTy = createMaskOp.getType(); auto srcShape = srcTy.getShape(); if (srcShape.size() != 2) return rewriter.notifyMatchFailure(createMaskOp, "only 2D mask is supported."); if (srcShape[0] != 1) return rewriter.notifyMatchFailure( createMaskOp, "only unit outer dimension is supported."); auto dstTy = getTypeConverter()->convertType(srcTy); if (!dstTy) return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type."); // Compare the first operand with 0. If it is greater than 0, the // corresponding mask element is set to true, otherwise false. // The result of the comparison is then multiplied with // the second operand of create_mask to get the 1D mask. auto firstOperand = adaptor.getOperands().front(); auto zero = mlir::arith::ConstantIndexOp::create(rewriter, loc, 0); auto isNonZero = rewriter.createOrFold( loc, mlir::arith::CmpIPredicate::sgt, firstOperand, zero); auto isNonZeroIndex = rewriter.createOrFold( loc, rewriter.getIndexType(), isNonZero); auto secondOperand = adaptor.getOperands().back(); auto maskSize = rewriter.createOrFold( loc, rewriter.getIndexType(), isNonZeroIndex, secondOperand); auto newMask = mlir::vector::CreateMaskOp::create(rewriter, loc, dstTy, maskSize); rewriter.replaceOp(createMaskOp, newMask); return success(); } }; /// This pattern linearizes vector.load from vector<1x1x...xN> to vector /// It currently supports linearization where all but the last dimension are 1 /// The following, /// vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32> /// is converted to: /// vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<4xf32> /// vector.shape_cast %load_result : vector<4xf32> to vector<1x4xf32> /// For generic cases, the vector unroll pass should be used to unroll the load /// to vector<1x1x...xN> form and then linearized struct LinearizeVectorLoad final : public OpConversionPattern { using Base::Base; LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} LogicalResult matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { VectorType vecTy = loadOp.getType(); if (!vecTy) return rewriter.notifyMatchFailure(loadOp, "expected vector type"); auto shape = vecTy.getShape(); auto scalableDims = vecTy.getScalableDims(); // All but the last dim must be 1, and only the last dim may be scalable (if // any). if (!llvm::all_of(shape.drop_back(1), [](auto d) { return d == 1; })) return rewriter.notifyMatchFailure(loadOp, "only vector<1x1x...xN> supported"); if (llvm::any_of(scalableDims.drop_back(1), [](bool s) { return s; })) return rewriter.notifyMatchFailure(loadOp, "only innermost dim may be scalable"); auto linearTy = typeConverter->convertType(vecTy); auto newLoad = vector::LoadOp::create(rewriter, loadOp.getLoc(), linearTy, adaptor.getBase(), adaptor.getIndices()); rewriter.replaceOp(loadOp, newLoad.getResult()); return success(); } }; /// This pattern linearizes vector.store from vector<1x1x...xN> to vector /// It currently supports linearization where all but the last dimension are 1 /// The following, /// vector.store %arg0, %arg1[%c0, %c0]s /// : vector<1x4xf32>, memref<1x4xf32> /// is converted to: /// vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32> /// vector.store %arg0, %arg1[%c0, %c0] /// : vector<4xf32>, memref<1x4xf32> /// For generic cases, the vector unroll pass should be used to unroll the store /// to vector<1x1x...xN> form and then linearized struct LinearizeVectorStore final : public OpConversionPattern { using Base::Base; LinearizeVectorStore(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} LogicalResult matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { VectorType vecTy = storeOp.getValueToStore().getType(); if (!vecTy) return rewriter.notifyMatchFailure(storeOp, "expected vector type"); auto shape = vecTy.getShape(); auto scalableDims = vecTy.getScalableDims(); // All but the last dim must be 1, and only the last dim may be scalable (if // any). if (!llvm::all_of(shape.drop_back(1), [](auto d) { return d == 1; })) return rewriter.notifyMatchFailure(storeOp, "only vector<1x1x...xN> supported"); if (llvm::any_of(scalableDims.drop_back(1), [](bool s) { return s; })) return rewriter.notifyMatchFailure(storeOp, "only innermost dim may be scalable"); rewriter.replaceOpWithNewOp( storeOp, adaptor.getValueToStore(), adaptor.getBase(), adaptor.getIndices()); return success(); } }; /// This pattern linearizes `vector.from_elements` operations by converting /// the result type to a 1-D vector while preserving all element values. /// The transformation creates a linearized `vector.from_elements` followed by /// a `vector.shape_cast` to restore the original multidimensional shape. /// /// Example: /// /// %0 = vector.from_elements %a, %b, %c, %d : vector<2x2xf32> /// /// is converted to: /// /// %0 = vector.from_elements %a, %b, %c, %d : vector<4xf32> /// %1 = vector.shape_cast %0 : vector<4xf32> to vector<2x2xf32> /// struct LinearizeVectorFromElements final : public OpConversionPattern { using Base::Base; LinearizeVectorFromElements(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} LogicalResult matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { VectorType dstTy = getTypeConverter()->convertType(fromElementsOp.getType()); assert(dstTy && "vector type destination expected."); OperandRange elements = fromElementsOp.getElements(); assert(elements.size() == static_cast(dstTy.getNumElements()) && "expected same number of elements"); rewriter.replaceOpWithNewOp(fromElementsOp, dstTy, elements); return success(); } }; /// This pattern linearizes the operand in `vector.to_elements` operations /// by converting the source type to a 1-D vector while preserving all element /// values. The transformation creates a linearized `vector.shape_cast` /// followed by a `vector.to_elements`. /// /// Example: /// /// %0:4 = vector.to_elements %v : vector<2x2xf32> /// /// is converted to: /// /// %vector_cast = vector.shape_cast %v : vector<2x2xf32> to vector<4xf32> /// %0:4 = vector.to_elements %vector_cast : vector<4xf32> /// struct LinearizeVectorToElements final : public OpConversionPattern { using Base::Base; LinearizeVectorToElements(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} LogicalResult matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { VectorType vecType = toElementsOp.getSource().getType(); if (vecType.getRank() <= 1) return rewriter.notifyMatchFailure( toElementsOp, "the rank is already less than or equal to 1"); assert(vecType.getNumScalableDims() == 0 && "to_elements does not support scalable vectors"); auto vec1DType = VectorType::get({vecType.getNumElements()}, vecType.getElementType()); Value shapeCast = vector::ShapeCastOp::create( rewriter, toElementsOp.getLoc(), vec1DType, toElementsOp.getSource()); auto newToElementsOp = vector::ToElementsOp::create(rewriter, toElementsOp.getLoc(), toElementsOp.getResultTypes(), shapeCast); rewriter.replaceOp(toElementsOp, newToElementsOp); return success(); } }; /// Convert broadcasts from scalars or 1-element vectors, such as /// /// ```mlir /// vector.broadcast %value : f32 to vector<4x4xf32> /// ``` /// /// to broadcasts to rank-1 vectors, with shape_casts before/after as needed. /// The above becomes, /// /// ```mlir /// %out_1d = vector.broadcast %value : f32 to vector<16xf32> /// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32> /// ``` struct LinearizeVectorBroadcast final : public OpConversionPattern { using Base::Base; LinearizeVectorBroadcast(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { int numElements = 1; Type sourceType = broadcastOp.getSourceType(); if (auto vecType = dyn_cast(sourceType)) { numElements = vecType.getNumElements(); } if (numElements != 1) { return rewriter.notifyMatchFailure( broadcastOp, "only broadcasts of single elements can be linearized."); } auto dstTy = getTypeConverter()->convertType(broadcastOp.getType()); rewriter.replaceOpWithNewOp(broadcastOp, dstTy, adaptor.getSource()); return success(); } }; } // namespace /// This method defines the set of operations that are linearizable, and hence /// that are considered illegal for the conversion target. static bool isLinearizable(Operation *op) { // Only ops that are in the vector dialect, are ConstantLike, or // are Vectorizable might be linearized currently. StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace(); StringRef opDialect = op->getDialect()->getNamespace(); bool supported = (opDialect == vectorDialect) || op->hasTrait() || op->hasTrait(); if (!supported) return false; return TypeSwitch(op) // As type legalization is done with vector.shape_cast, shape_cast // itself cannot be linearized (will create new shape_casts to linearize // ad infinitum). .Case([&](vector::ShapeCastOp) { return false; }) // The operations // - vector.extract_strided_slice // - vector.extract // - vector.insert_strided_slice // - vector.insert // are linearized to a rank-1 vector.shuffle by the current patterns. // vector.shuffle only supports fixed size vectors, so it is impossible to // use this approach to linearize these ops if they operate on scalable // vectors. .Case([&](vector::ExtractStridedSliceOp extractOp) { return !extractOp.getType().isScalable(); }) .Case([&](vector::InsertStridedSliceOp insertOp) { return !insertOp.getType().isScalable(); }) .Case([&](vector::InsertOp insertOp) { return !insertOp.getType().isScalable(); }) .Case([&](vector::ExtractOp extractOp) { return !extractOp.getSourceVectorType().isScalable(); }) .Default([&](auto) { return true; }); } void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter, ConversionTarget &target) { auto convertType = [](Type type) -> std::optional { VectorType vectorType = dyn_cast(type); if (!vectorType || !isLinearizableVector(vectorType)) return type; VectorType linearizedType = VectorType::get(vectorType.getNumElements(), vectorType.getElementType(), vectorType.isScalable()); return linearizedType; }; typeConverter.addConversion(convertType); auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) -> Value { if (inputs.size() != 1) return nullptr; Value value = inputs.front(); if (!isa(type) || !isa(value.getType())) return nullptr; return vector::ShapeCastOp::create(builder, loc, type, value); }; typeConverter.addSourceMaterialization(materializeCast); typeConverter.addTargetMaterialization(materializeCast); target.markUnknownOpDynamicallyLegal( [=](Operation *op) -> std::optional { if (!isLinearizable(op)) return true; // This will return true if, for all operand and result types `t`, // convertType(t) = t. This is true if there are no rank>=2 vectors. return typeConverter.isLegal(op); }); } void mlir::vector::populateVectorLinearizeBasePatterns( const TypeConverter &typeConverter, const ConversionTarget &target, RewritePatternSet &patterns) { patterns .add(typeConverter, patterns.getContext()); } void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( const TypeConverter &typeConverter, const ConversionTarget &target, RewritePatternSet &patterns) { patterns.add(typeConverter, patterns.getContext()); }