From c10f33e8935de001a10f46bc370dc12aa67e5674 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Thu, 30 Apr 2026 10:06:09 +0200 Subject: [PATCH] [mlir][linalg] Fuse transform op - variadic tile sizes (#194657) Extends the 'structured.fuse' op to accept packed handle containing variable number of tile sizes. Use of packed handles allows for runtime tiling decisions for improved transform schedule flexibility and reusability. The extension's design follows the existing approach of transform 'structured.tile_using_forall' op to more closely align their usage. In case of tiling using nested loops, all created loops are packed into a single return handle. For each target op, corresponding loops are appended to the result handle. Assisted-by: Claude --- .../Linalg/TransformOps/LinalgTransformOps.td | 11 +- .../TransformOps/LinalgTransformOps.cpp | 61 +++++++-- mlir/python/mlir/dialects/_ods_common.py | 2 + .../mlir/dialects/transform/structured.py | 11 +- .../Dialect/Linalg/transform-op-fuse.mlir | 125 ++++++++++++++++++ .../dialects/transform_structured_ext.py | 27 ++++ 6 files changed, 220 insertions(+), 17 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index cb61177bc753..1a59f4c7d1ac 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -423,6 +423,12 @@ def FuseOp : Op : $tile_sizes, Variadic : $tile_interchange, + Optional : $packed_tile_sizes, DefaultValuedOptionalAttr:$static_tile_sizes, DefaultValuedOptionalAttr:$static_tile_interchange, UnitAttr:$apply_cleanup, @@ -465,7 +472,9 @@ def FuseOp : Op($tile_sizes, $static_tile_sizes) | + `tile_sizes` custom($packed_tile_sizes, + $tile_sizes, + $static_tile_sizes) | `interchange` custom($tile_interchange, $static_tile_interchange) ) attr-dict `:` functional-type(operands, results) diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index baa57f892009..f44693096b26 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -654,6 +654,7 @@ void transform::FuseOp::build(OpBuilder &builder, OperationState &result, /*target=*/target, /*tile_sizes=*/dynamicTileSizes, /*tile_interchange=*/dynamicTileInterchange, + /*packed_tile_sizes=*/Value(), /*static_tile_sizes=*/staticTileSizesAttr, /*static_tile_interchange=*/staticTileInterchangeAttr, /*apply_cleanup=*/applyCleanup, @@ -666,10 +667,12 @@ template static LogicalResult applyTilingToAll( RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, unsigned numLoops, transform::TransformResults &transformResults, + bool packedResults, function_ref(TilingInterface)> applyFn) { SmallVector tiledLinalgOps; SmallVector> loopOps(numLoops); + size_t numTargets = llvm::range_size(payloadOps); for (Operation *target : payloadOps) { auto tilingInterfaceOp = dyn_cast(target); @@ -704,8 +707,22 @@ static LogicalResult applyTilingToAll( } transformResults.set(transformOp->getOpResult(0), tiledLinalgOps); - for (unsigned int i = 0; i < numLoops; ++i) - transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]); + if (packedResults) { + // In case of packed results, all created loops are assigned to a single + // handle. Loops are returned in order of targets such as: + // %loops_handle = { + // target0:loop0, ..., target0:loopN, + // target1:loop0, ..., target1:loopN, + // ... } + SmallVector flattenedLoopOps; + for (unsigned int idx = 0; idx < numTargets; ++idx) + for (unsigned int i = 0; i < numLoops; ++i) + flattenedLoopOps.push_back(loopOps[i][idx]); + transformResults.set(transformOp->getOpResult(1), flattenedLoopOps); + } else { + for (unsigned int i = 0; i < numLoops; ++i) + transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]); + } return success(); } @@ -716,9 +733,13 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter, mlir::transform::TransformState &state) { auto transformOp = cast(getOperation()); - SmallVector tileSizes; - DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults( - state, transformOp, getMixedTileSizes(), tileSizes); + SmallVector mixedTileSizes; + DiagnosedSilenceableFailure status = + getPackedTileSizes() + ? unpackSingleIndexResultPayloadOperations( + state, transformOp, mixedTileSizes, getPackedTileSizes()) + : unpackSingleIndexResultPayloadOperations( + state, transformOp, mixedTileSizes, getMixedTileSizes()); if (!status.succeeded()) return status; SmallVector tileInterchange; @@ -733,9 +754,7 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter, tilingOptions.setLoopType(useForall ? scf::SCFTilingOptions::LoopType::ForallOp : scf::SCFTilingOptions::LoopType::ForOp); - SmallVector tileSizesOfr = - getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); - tilingOptions = tilingOptions.setTileSizes(tileSizesOfr); + tilingOptions = tilingOptions.setTileSizes(mixedTileSizes); scf::SCFTileAndFuseOptions tileAndFuseOptions; tileAndFuseOptions.tilingOptions = tilingOptions; @@ -748,11 +767,20 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter, tileAndFuseOptions.cleanupPatterns = std::move(patterns); } - size_t numLoops = - useForall ? 1 : tileSizes.size() - llvm::count(tileSizes, 0); + size_t numLoops; + if (useForall) { + numLoops = 1; + } else { + numLoops = llvm::count_if(mixedTileSizes, [](OpFoldResult ofr) { + auto attr = dyn_cast(ofr); + if (!attr) + return true; + return cast(attr).getInt() != 0; + }); + } LogicalResult result = applyTilingToAll( rewriter, getOperation(), state.getPayloadOps(getTarget()), numLoops, - transformResults, + transformResults, /*packedResults=*/getPackedTileSizes() != nullptr, [&](TilingInterface tilingInterfaceOp) -> FailureOr { return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp, @@ -763,6 +791,11 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter, } LogicalResult transform::FuseOp::verify() { + bool hasPackedTiles = getPackedTileSizes() != nullptr; + if (!getMixedTileSizes().empty() && hasPackedTiles) + return emitOpError( + "tile_sizes and packed_tile_sizes are mutually exclusive"); + auto iterspace_rank = getStaticTileSizes().size(); ArrayRef permutation = getStaticTileInterchange(); if (permutation.size() > iterspace_rank) @@ -782,8 +815,9 @@ LogicalResult transform::FuseOp::verify() { } ArrayRef sizes = getStaticTileSizes(); - size_t numExpectedLoops = - getUseForall() ? 1 : sizes.size() - llvm::count(sizes, 0); + size_t numExpectedLoops = getUseForall() || hasPackedTiles + ? 1 + : sizes.size() - llvm::count(sizes, 0); if (numExpectedLoops != getNumResults() - 1) return emitOpError() << "expects " << numExpectedLoops << " loop results"; @@ -803,6 +837,7 @@ void transform::FuseOp::getEffects( SmallVectorImpl &effects) { consumesHandle(getTargetMutable(), effects); onlyReadsHandle(getTileSizesMutable(), effects); + onlyReadsHandle(getPackedTileSizesMutable(), effects); onlyReadsHandle(getTileInterchangeMutable(), effects); producesHandle(getOperation()->getOpResults(), effects); modifiesPayload(effects); diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py index 10abd06ff266..7f1bd2183a0c 100644 --- a/mlir/python/mlir/dialects/_ods_common.py +++ b/mlir/python/mlir/dialects/_ods_common.py @@ -240,6 +240,8 @@ def _dispatch_mixed_values( for size in values or []: if isinstance(size, int): static_values.append(size) + elif isinstance(size, IntegerAttr): + static_values.append(size.value) else: static_values.append(ShapedType.get_dynamic_size()) dynamic_values.append(size) diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py index d9ab504f0de5..a3c3057ddb83 100644 --- a/mlir/python/mlir/dialects/transform/structured.py +++ b/mlir/python/mlir/dialects/transform/structured.py @@ -183,15 +183,19 @@ class FuseOp(FuseOp): tile_interchange = tile_interchange if tile_interchange else [] ( dynamic_tile_sizes, + packed_tile_sizes, static_tile_sizes, - _, - ) = _dispatch_dynamic_index_list(tile_sizes) + ) = _dispatch_mixed_values(tile_sizes) ( dynamic_tile_interchange, static_tile_interchange, _, ) = _dispatch_dynamic_index_list(tile_interchange) - num_loops = 1 if use_forall else sum(1 for v in static_tile_sizes if v != 0) + num_loops = ( + 1 + if use_forall or packed_tile_sizes is not None + else sum(1 for v in static_tile_sizes if v != 0) + ) if isinstance(loop_types_or_target, (Operation, Value, OpView)): loop_types = [transform.AnyOpType.get()] * num_loops @@ -210,6 +214,7 @@ class FuseOp(FuseOp): target, tile_sizes=dynamic_tile_sizes, tile_interchange=dynamic_tile_interchange, + packed_tile_sizes=packed_tile_sizes, static_tile_sizes=static_tile_sizes, static_tile_interchange=static_tile_interchange, apply_cleanup=apply_cleanup, diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir index b05dc1f295a4..dab849170810 100644 --- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir @@ -112,6 +112,131 @@ module attributes {transform.with_named_sequence} { // ----- +// CHECK-LABEL: func.func @fuse_unary_packed_tile_sizes +func.func @fuse_unary_packed_tile_sizes(%arg0: tensor, %arg1: tensor) -> tensor { + + // CHECK: %[[RES:.*]] = scf.for + // CHECK: scf.for + // CHECK: linalg.exp + // CHECK: linalg.add + // CHECK: return %[[RES]] + %0 = linalg.exp ins(%arg0 : tensor) + outs(%arg1: tensor) -> tensor + %1 = linalg.add ins(%0, %arg0 : tensor, tensor) + outs(%arg1: tensor) -> tensor + return %1 : tensor +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %c32 = transform.param.constant 32 : i64 -> !transform.any_param + %c64 = transform.param.constant 64 : i64 -> !transform.any_param + %tiles = transform.merge_handles %c32, %c64 : !transform.any_param + %1, %loops = transform.structured.fuse %0 tile_sizes *(%tiles) + : (!transform.any_op, !transform.any_param) -> (!transform.any_op, !transform.any_op) + // Verify that correct number of loops is present in packed result. + %loop:2 = transform.split_handle %loops : (!transform.any_op) + -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func.func @fuse_unary_packed_tile_sizes_forall +func.func @fuse_unary_packed_tile_sizes_forall(%arg0: tensor, %arg1: tensor) -> tensor { + + // CHECK: %[[RES:.*]] = scf.forall + // CHECK: linalg.exp + // CHECK: linalg.add + // CHECK: return %[[RES]] + %0 = linalg.exp ins(%arg0 : tensor) + outs(%arg1: tensor) -> tensor + %1 = linalg.add ins(%0, %arg0 : tensor, tensor) + outs(%arg1: tensor) -> tensor + return %1 : tensor +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %c32 = transform.param.constant 32 : i64 -> !transform.any_param + %c64 = transform.param.constant 64 : i64 -> !transform.any_param + %tiles = transform.merge_handles %c32, %c64 : !transform.any_param + %1, %loops = transform.structured.fuse %0 tile_sizes *(%tiles) {use_forall} + : (!transform.any_op, !transform.any_param) -> (!transform.any_op, !transform.any_op) + // Verify that correct number of loops is present in packed result. + %loop:1 = transform.split_handle %loops : (!transform.any_op) + -> (!transform.any_op) + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func.func @fuse_unary_packed_tile_sizes_multiple_targets +func.func @fuse_unary_packed_tile_sizes_multiple_targets( + %arg0: tensor, %arg1: tensor) -> tensor { + + // CHECK: scf.for + // CHECK: scf.for + // CHECK: linalg.add + // CHECK: %[[RES:.*]] = scf.for + // CHECK: scf.for + // CHECK: linalg.add + // CHECK: return %[[RES]] + %0 = linalg.add ins(%arg0, %arg1 : tensor, tensor) + outs(%arg1: tensor) -> tensor + %1 = linalg.add ins(%0, %arg0 : tensor, tensor) + outs(%arg1: tensor) -> tensor + return %1 : tensor +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %c32 = transform.param.constant 32 : i64 -> !transform.any_param + %c64 = transform.param.constant 64 : i64 -> !transform.any_param + %tiles = transform.merge_handles %c32, %c64 : !transform.any_param + %1, %loops = transform.structured.fuse %0 tile_sizes *(%tiles) + : (!transform.any_op, !transform.any_param) -> (!transform.any_op, !transform.any_op) + // Verify that correct number of loops is present in packed result. + %loop:4 = transform.split_handle %loops : (!transform.any_op) + -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func.func @fuse_no_tiling_packed_tile_sizes +func.func @fuse_no_tiling_packed_tile_sizes(%arg0: tensor, %arg1: tensor) -> tensor { + + // CHECK-NOT: scf.for + // CHECK: linalg.exp + // CHECK: %[[RES:.*]] = linalg.add + // CHECK: return %[[RES]] + %0 = linalg.exp ins(%arg0 : tensor) + outs(%arg1: tensor) -> tensor + %1 = linalg.add ins(%0, %arg0 : tensor, tensor) + outs(%arg1: tensor) -> tensor + return %1 : tensor +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %c0 = transform.param.constant 0 : i64 -> !transform.any_param + %tiles = transform.merge_handles %c0, %c0 : !transform.any_param + %1, %loops = transform.structured.fuse %0 tile_sizes *(%tiles) + : (!transform.any_op, !transform.any_param) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + // CHECK-LABEL: func.func @interchange_reduction // CHECK-SAME: (%[[INPUT:.+]]: tensor<12x7x25xf32>) func.func @interchange_reduction(%input: tensor<12x7x25xf32>) -> tensor<12x25xf32> { diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py index e58b7646316f..fcede61100e0 100644 --- a/mlir/test/python/dialects/transform_structured_ext.py +++ b/mlir/test/python/dialects/transform_structured_ext.py @@ -191,6 +191,33 @@ def testFuseOpAttributes(target): # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) +@run +@create_sequence +def testFuseOpPackedTileSizes(target): + tiles = structured.MatchOp.match_op_names(target, ["arith.constant"]) + structured.FuseOp(target, tile_sizes=tiles) + # CHECK-LABEL: TEST: testFuseOpPackedTileSizes + # CHECK: transform.sequence + # CHECK: %[[T:.*]] = transform.structured.match + # CHECK: %{{.+}}, %{{.+}} = transform.structured.fuse + # CHECK-SAME: tile_sizes *(%[[T]]) + # CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + +@run +@create_sequence +def testFuseOpPackedTileSizesForall(target): + tiles = structured.MatchOp.match_op_names(target, ["arith.constant"]) + structured.FuseOp(target, tile_sizes=tiles, use_forall=True) + # CHECK-LABEL: TEST: testFuseOpPackedTileSizesForall + # CHECK: transform.sequence + # CHECK: %[[T:.*]] = transform.structured.match + # CHECK: %{{.+}}, %{{.+}} = transform.structured.fuse + # CHECK-SAME: tile_sizes *(%[[T]]) + # CHECK-SAME: {use_forall} + # CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + @run @create_sequence def testGeneralize(target):