[mlir][linalg] Make conv dim inference return pairing (outputImage, filterLoop) (#180859)

The original method sorts all the dimensions which loses the information
about pairing. It makes other transformation that works on generic op
form harder. The revision ensures the pairing, so callers have more
useful information when they work on transformations.

---------

Signed-off-by: hanhanW <hanhan0912@gmail.com>
This commit is contained in:
Han-Chung Wang
2026-02-11 13:18:18 -08:00
committed by GitHub
parent fc648683cd
commit 524ae2f347
5 changed files with 217 additions and 12 deletions

View File

@@ -105,8 +105,12 @@ struct ConvolutionDimensions {
/// 7. All dimensions appear only once in any given indexing map.
/// This allows e.g. detecting that some convolution is embedded within
/// `linalgOp` with some orthogonal heuristic.
/// When multiple dimension occurrences exist that match any classification
/// indices are returned in sorted order.
///
/// The `outputImage` and `filterLoop` arrays are ordered such that
/// `outputImage[i]` pairs with `filterLoop[i]` based on the convolution access
/// pattern in the input indexing map (e.g., `d0 + d2` pairs dimension 0 with
/// dimension 2). Other dimension sets are returned in sorted order.
///
/// Returns a failure if `output_image` (and implicitly `filter_loop`) is empty.
FailureOr<ConvolutionDimensions> inferConvolutionDims(LinalgOp linalgOp);

View File

@@ -732,11 +732,17 @@ getConstantsFromExprList(const SmallVector<AffineExpr, 2> &exprs) {
/// Classifies dimensions in the `linalgOp` used by a convolution
/// subcomputation, as captured by `inputExprWalker`. If
/// `allowEmptyConvolvedDims` is not set this this will fail if there is not
/// at least convolved dimension pair (output image + filter loop). Convolution
/// dimensions are specified in sorted order, and strides match the order of
/// the filter loop dimensions, while the dilations match the order of the
/// output image dimensions.
/// `allowEmptyConvolvedDims` is not set this will fail if there is not
/// at least one convolved dimension pair (output image + filter loop).
///
/// The returned dimensions are ordered as follows:
/// - `outputImage` is sorted by dimension index.
/// - `filterLoop` is ordered to match the pairing with `outputImage`, i.e.,
/// `outputImage[i]` and `filterLoop[i]` are paired dimensions from the
/// convolution access pattern (e.g., `oh + kh` pairs `oh` with `kh`).
/// - `strides[i]` corresponds to `outputImage[i]`.
/// - `dilations[i]` corresponds to `filterLoop[i]`.
/// - Other dimension sets (batch, outputChannel, etc.) are sorted by index.
static FailureOr<ConvolutionDimensions>
inferConvolutionDimsImpl(LinalgOp linalgOp,
ConvAccessExprWalker &inputExprWalker,
@@ -784,12 +790,13 @@ inferConvolutionDimsImpl(LinalgOp linalgOp,
if (oi.empty() && !allowEmptyConvolvedDims)
return failure();
// Return each set in sorted order.
// Return each set in sorted order, with outputImage and filterLoop
// ordered so that outputImage[i] pairs with filterLoop[i].
ConvolutionDimensions dimensions{
SmallVector<unsigned, 2>(batch.begin(), batch.end()),
SmallVector<unsigned, 2>(oi.begin(), oi.end()),
SmallVector<unsigned, 2>(oc.begin(), oc.end()),
SmallVector<unsigned, 2>(fl.begin(), fl.end()),
/*filterLoop=*/SmallVector<unsigned, 2>{},
SmallVector<unsigned, 2>(ic.begin(), ic.end()),
SmallVector<unsigned, 2>(depth.begin(), depth.end()),
/*strides=*/SmallVector<int64_t, 2>{},
@@ -797,9 +804,14 @@ inferConvolutionDimsImpl(LinalgOp linalgOp,
llvm::sort(dimensions.batch);
llvm::sort(dimensions.outputImage);
llvm::sort(dimensions.outputChannel);
llvm::sort(dimensions.filterLoop);
llvm::sort(dimensions.inputChannel);
llvm::sort(dimensions.depth);
// Order filterLoop to match the pairing with outputImage. Each outputImage
// dimension has a corresponding filterLoop dimension from the convolution
// access pattern (e.g., oh + kh). This ensures outputImage[i] pairs with
// filterLoop[i].
for (unsigned oiDim : dimensions.outputImage)
dimensions.filterLoop.push_back(inputExprWalker.convolvedDimMapping[oiDim]);
// Use the op carried strides/dilations attribute if present.
auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
@@ -846,8 +858,12 @@ inferConvolutionDimsImpl(LinalgOp linalgOp,
/// 7. All dimensions appear only once in any given indexing map.
/// This allows e.g. detecting that some convolution is embedded within
/// `linalgOp` with some orthogonal heuristic.
/// When multiple dimension occurrences exist that match any classification
/// indices are returned in sorted order.
///
/// The `outputImage` and `filterLoop` arrays are ordered such that
/// `outputImage[i]` pairs with `filterLoop[i]` based on the convolution access
/// pattern in the input indexing map (e.g., `d0 + d2` pairs dimension 0 with
/// dimension 2). Other dimension sets are returned in sorted order.
///
/// Returns a failure if `output_image` (and implicitly `filter_loop`) is empty.
FailureOr<ConvolutionDimensions>
mlir::linalg::inferConvolutionDims(LinalgOp linalgOp) {

View File

@@ -9,6 +9,7 @@ mlir_target_link_libraries(MLIRDialectTests
add_subdirectory(AMDGPU)
add_subdirectory(ArmSME)
add_subdirectory(Index)
add_subdirectory(Linalg)
add_subdirectory(LLVMIR)
add_subdirectory(MemRef)
add_subdirectory(OpenACC)

View File

@@ -0,0 +1,11 @@
add_mlir_unittest(MLIRLinalgTests
InferConvolutionDimsTest.cpp
)
mlir_target_link_libraries(MLIRLinalgTests
PRIVATE
MLIRIR
MLIRArithDialect
MLIRFuncDialect
MLIRLinalgDialect
MLIRTensorDialect
)

View File

@@ -0,0 +1,173 @@
//===- InferConvolutionDimsTest.cpp - inferConvolutionDims unit tests -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "gtest/gtest.h"
using namespace mlir;
using namespace mlir::linalg;
namespace {
class InferConvolutionDimsTest : public ::testing::Test {
protected:
void SetUp() override {
registry.insert<arith::ArithDialect, linalg::LinalgDialect,
tensor::TensorDialect, func::FuncDialect>();
ctx = std::make_unique<MLIRContext>(registry);
ctx->loadAllAvailableDialects();
}
DialectRegistry registry;
std::unique_ptr<MLIRContext> ctx;
};
/// Creates a Conv2DOp with loop order (d0, d1, d2, d3) where:
/// d0 = output height (oh), parallel
/// d1 = output width (ow), parallel
/// d2 = kernel height (kh), reduction
/// d3 = kernel width (kw), reduction
///
/// Indexing maps:
/// input: (d0 + d2, d1 + d3)
/// filter: (d2, d3)
/// output: (d0, d1)
///
/// Semantic pairing: d0 <-> d2, d1 <-> d3
static linalg::Conv2DOp createConv2DOp(OpBuilder &builder, int64_t oh,
int64_t ow, int64_t kh, int64_t kw) {
Location loc = builder.getUnknownLoc();
auto f32Type = builder.getF32Type();
int64_t ih = oh + kh - 1;
int64_t iw = ow + kw - 1;
auto inputType = RankedTensorType::get({ih, iw}, f32Type);
auto filterType = RankedTensorType::get({kh, kw}, f32Type);
auto outputType = RankedTensorType::get({oh, ow}, f32Type);
Value input = tensor::EmptyOp::create(builder, loc, inputType.getShape(),
inputType.getElementType());
Value filter = tensor::EmptyOp::create(builder, loc, filterType.getShape(),
filterType.getElementType());
Value output = tensor::EmptyOp::create(builder, loc, outputType.getShape(),
outputType.getElementType());
return linalg::Conv2DOp::create(
builder, loc, outputType, ValueRange{input, filter}, ValueRange{output});
}
/// Creates a linalg.generic equivalent to the given Conv2DOp but with filter
/// loop dimensions swapped. The resulting op has loop order (d0, d1, d2, d3):
/// d0 = output height (oh), parallel
/// d1 = output width (ow), parallel
/// d2 = kernel width (kw), reduction <-- swapped!
/// d3 = kernel height (kh), reduction <-- swapped!
///
/// Indexing maps:
/// input: (d0 + d3, d1 + d2)
/// filter: (d2, d3)
/// output: (d0, d1)
///
/// Semantic pairing: d0 <-> d3, d1 <-> d2
static linalg::GenericOp
createConv2DWithSwappedFilterLoops(OpBuilder &builder,
linalg::Conv2DOp conv2DOp) {
Location loc = conv2DOp.getLoc();
MLIRContext *ctx = builder.getContext();
// Extract dimensions from the Conv2DOp. Require static shapes for simplicity.
auto inputType = cast<RankedTensorType>(conv2DOp.getInputs()[0].getType());
auto filterType = cast<RankedTensorType>(conv2DOp.getInputs()[1].getType());
auto outputType = cast<RankedTensorType>(conv2DOp.getOutputs()[0].getType());
assert(inputType.hasStaticShape() && "expected static input shape");
assert(filterType.hasStaticShape() && "expected static filter shape");
assert(outputType.hasStaticShape() && "expected static output shape");
(void)outputType;
int64_t kh = filterType.getDimSize(0);
int64_t kw = filterType.getDimSize(1);
// Filter dimensions are swapped: (kw, kh) instead of (kh, kw).
auto f32Type = builder.getF32Type();
auto swappedFilterType = RankedTensorType::get({kw, kh}, f32Type);
Value input = tensor::EmptyOp::create(builder, loc, inputType.getShape(),
inputType.getElementType());
Value filter =
tensor::EmptyOp::create(builder, loc, swappedFilterType.getShape(),
swappedFilterType.getElementType());
Value output = tensor::EmptyOp::create(builder, loc, outputType.getShape(),
outputType.getElementType());
// Build indexing maps for swapped filter loop dimensions.
// Original Conv2DOp: (d0=oh, d1=ow, d2=kh, d3=kw)
// Swapped: (d0=oh, d1=ow, d2=kw, d3=kh)
AffineExpr d0, d1, d2, d3;
bindDims(ctx, d0, d1, d2, d3);
auto inputMap = AffineMap::get(4, 0, {d0 + d3, d1 + d2}, ctx);
auto filterMap = AffineMap::get(4, 0, {d2, d3}, ctx);
auto outputMap = AffineMap::get(4, 0, {d0, d1}, ctx);
SmallVector<AffineMap> indexingMaps = {inputMap, filterMap, outputMap};
SmallVector<utils::IteratorType> iterTypes = {
utils::IteratorType::parallel, utils::IteratorType::parallel,
utils::IteratorType::reduction, utils::IteratorType::reduction};
return linalg::GenericOp::create(
builder, loc, outputType, ValueRange{input, filter}, ValueRange{output},
indexingMaps, iterTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value mul = arith::MulFOp::create(b, loc, args[0], args[1]);
Value add = arith::AddFOp::create(b, loc, args[2], mul);
linalg::YieldOp::create(b, loc, add);
});
}
TEST_F(InferConvolutionDimsTest, Conv2DPairing) {
// Use non-square kernel to ensure dimension swapping is tested properly.
const int64_t oh = 6, ow = 12, kh = 3, kw = 5;
// Create Conv2DOp where the standard loop order is (oh, ow, kh, kw).
OpBuilder builder(ctx.get());
linalg::Conv2DOp conv2DOp = createConv2DOp(builder, oh, ow, kh, kw);
FailureOr<ConvolutionDimensions> origDims = inferConvolutionDims(conv2DOp);
ASSERT_TRUE(succeeded(origDims));
ASSERT_EQ(origDims->outputImage.size(), 2u);
ASSERT_EQ(origDims->filterLoop.size(), 2u);
// Standard pairing: outputImage=[0,1], filterLoop=[2,3]
// d0 <-> d2 (oh <-> kh), d1 <-> d3 (ow <-> kw)
EXPECT_EQ(origDims->outputImage[0], 0u);
EXPECT_EQ(origDims->outputImage[1], 1u);
EXPECT_EQ(origDims->filterLoop[0], 2u);
EXPECT_EQ(origDims->filterLoop[1], 3u);
// Create equivalent generic with swapped filter loop order: (oh, ow, kw, kh)
linalg::GenericOp swappedOp =
createConv2DWithSwappedFilterLoops(builder, conv2DOp);
FailureOr<ConvolutionDimensions> swappedDims =
inferConvolutionDims(swappedOp);
ASSERT_TRUE(succeeded(swappedDims));
ASSERT_EQ(swappedDims->outputImage.size(), 2u);
ASSERT_EQ(swappedDims->filterLoop.size(), 2u);
// outputImage should still be [0, 1] after sorting.
EXPECT_EQ(swappedDims->outputImage[0], 0u);
EXPECT_EQ(swappedDims->outputImage[1], 1u);
// In swapped version:
// Input map: (d0 + d3, d1 + d2) -> d0 <-> d3, d1 <-> d2
// So filterLoop should be [3, 2] to maintain
// outputImage[i] <-> filterLoop[i].
EXPECT_EQ(swappedDims->filterLoop[0], 3u)
<< "outputImage[0]=0 should pair with filterLoop[0]=3 (oh <-> kh)";
EXPECT_EQ(swappedDims->filterLoop[1], 2u)
<< "outputImage[1]=1 should pair with filterLoop[1]=2 (ow <-> kw)";
}
} // namespace