[mlir][linalg] Fuse transform op - variadic tile sizes (#194657)

Extends the 'structured.fuse' op to accept a packed handle containing a
variable number of tile sizes.

Use of packed handles allows for runtime tiling decisions for improved
transform schedule flexibility and reusability.
The extension's design follows the existing approach of the
'structured.tile_using_forall' transform op to more closely align their usage.

In case of tiling using nested loops, all created loops are packed into
a single return handle. For each target op, corresponding loops are
appended to the result handle.

Assisted-by: Claude
This commit is contained in:
Adam Siemieniuk
2026-04-30 10:06:09 +02:00
committed by GitHub
parent 272812b9b4
commit c10f33e893
6 changed files with 220 additions and 17 deletions

View File

@@ -423,6 +423,12 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
and loop interchange permutation can be provided as either static
attributes or dynamic values (transform parameters or payload handles).
Alternatively, tile sizes can be provided as a single packed handle containing
a variable number of values. In that case, the number of loops generated is
determined at runtime from the number of values in the packed handle.
For each target, created loops are appended to the single return handle in
the same order as the target operations.
If `apply_cleanup` is true then slice canonicalization is applied between
fusion steps. If `use_forall` is true then tiling method generates a
`scf.forall` loop instead of `scf.for` loops.
@@ -432,6 +438,7 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
(ins TransformHandleTypeInterface:$target,
Variadic<TransformAnyParamTypeOrAnyHandle> : $tile_sizes,
Variadic<TransformAnyParamTypeOrAnyHandle> : $tile_interchange,
Optional<TransformAnyParamTypeOrAnyHandle> : $packed_tile_sizes,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_tile_sizes,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_tile_interchange,
UnitAttr:$apply_cleanup,
@@ -465,7 +472,9 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
let assemblyFormat = [{
$target oilist(
`tile_sizes` custom<DynamicIndexList>($tile_sizes, $static_tile_sizes) |
`tile_sizes` custom<PackedOrDynamicIndexList>($packed_tile_sizes,
$tile_sizes,
$static_tile_sizes) |
`interchange` custom<DynamicIndexList>($tile_interchange, $static_tile_interchange)
)
attr-dict `:` functional-type(operands, results)

View File

@@ -654,6 +654,7 @@ void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
/*target=*/target,
/*tile_sizes=*/dynamicTileSizes,
/*tile_interchange=*/dynamicTileInterchange,
/*packed_tile_sizes=*/Value(),
/*static_tile_sizes=*/staticTileSizesAttr,
/*static_tile_interchange=*/staticTileInterchangeAttr,
/*apply_cleanup=*/applyCleanup,
@@ -666,10 +667,12 @@ template <typename Range>
static LogicalResult applyTilingToAll(
RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
unsigned numLoops, transform::TransformResults &transformResults,
bool packedResults,
function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
applyFn) {
SmallVector<Operation *> tiledLinalgOps;
SmallVector<SmallVector<Operation *>> loopOps(numLoops);
size_t numTargets = llvm::range_size(payloadOps);
for (Operation *target : payloadOps) {
auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
@@ -704,8 +707,22 @@ static LogicalResult applyTilingToAll(
}
transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
for (unsigned int i = 0; i < numLoops; ++i)
transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
if (packedResults) {
// In case of packed results, all created loops are assigned to a single
// handle. Loops are returned in order of targets such as:
// %loops_handle = {
// target0:loop0, ..., target0:loopN,
// target1:loop0, ..., target1:loopN,
// ... }
SmallVector<Operation *> flattenedLoopOps;
for (unsigned int idx = 0; idx < numTargets; ++idx)
for (unsigned int i = 0; i < numLoops; ++i)
flattenedLoopOps.push_back(loopOps[i][idx]);
transformResults.set(transformOp->getOpResult(1), flattenedLoopOps);
} else {
for (unsigned int i = 0; i < numLoops; ++i)
transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
}
return success();
}
@@ -716,9 +733,13 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
mlir::transform::TransformState &state) {
auto transformOp = cast<TransformOpInterface>(getOperation());
SmallVector<int64_t> tileSizes;
DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
state, transformOp, getMixedTileSizes(), tileSizes);
SmallVector<OpFoldResult> mixedTileSizes;
DiagnosedSilenceableFailure status =
getPackedTileSizes()
? unpackSingleIndexResultPayloadOperations(
state, transformOp, mixedTileSizes, getPackedTileSizes())
: unpackSingleIndexResultPayloadOperations(
state, transformOp, mixedTileSizes, getMixedTileSizes());
if (!status.succeeded())
return status;
SmallVector<int64_t> tileInterchange;
@@ -733,9 +754,7 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
tilingOptions.setLoopType(useForall
? scf::SCFTilingOptions::LoopType::ForallOp
: scf::SCFTilingOptions::LoopType::ForOp);
SmallVector<OpFoldResult> tileSizesOfr =
getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
tilingOptions = tilingOptions.setTileSizes(mixedTileSizes);
scf::SCFTileAndFuseOptions tileAndFuseOptions;
tileAndFuseOptions.tilingOptions = tilingOptions;
@@ -748,11 +767,20 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
tileAndFuseOptions.cleanupPatterns = std::move(patterns);
}
size_t numLoops =
useForall ? 1 : tileSizes.size() - llvm::count(tileSizes, 0);
size_t numLoops;
if (useForall) {
numLoops = 1;
} else {
numLoops = llvm::count_if(mixedTileSizes, [](OpFoldResult ofr) {
auto attr = dyn_cast<Attribute>(ofr);
if (!attr)
return true;
return cast<IntegerAttr>(attr).getInt() != 0;
});
}
LogicalResult result = applyTilingToAll(
rewriter, getOperation(), state.getPayloadOps(getTarget()), numLoops,
transformResults,
transformResults, /*packedResults=*/getPackedTileSizes() != nullptr,
[&](TilingInterface tilingInterfaceOp)
-> FailureOr<scf::SCFTileAndFuseResult> {
return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
@@ -763,6 +791,11 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
}
LogicalResult transform::FuseOp::verify() {
bool hasPackedTiles = getPackedTileSizes() != nullptr;
if (!getMixedTileSizes().empty() && hasPackedTiles)
return emitOpError(
"tile_sizes and packed_tile_sizes are mutually exclusive");
auto iterspace_rank = getStaticTileSizes().size();
ArrayRef<int64_t> permutation = getStaticTileInterchange();
if (permutation.size() > iterspace_rank)
@@ -782,8 +815,9 @@ LogicalResult transform::FuseOp::verify() {
}
ArrayRef<int64_t> sizes = getStaticTileSizes();
size_t numExpectedLoops =
getUseForall() ? 1 : sizes.size() - llvm::count(sizes, 0);
size_t numExpectedLoops = getUseForall() || hasPackedTiles
? 1
: sizes.size() - llvm::count(sizes, 0);
if (numExpectedLoops != getNumResults() - 1)
return emitOpError() << "expects " << numExpectedLoops << " loop results";
@@ -803,6 +837,7 @@ void transform::FuseOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
consumesHandle(getTargetMutable(), effects);
onlyReadsHandle(getTileSizesMutable(), effects);
onlyReadsHandle(getPackedTileSizesMutable(), effects);
onlyReadsHandle(getTileInterchangeMutable(), effects);
producesHandle(getOperation()->getOpResults(), effects);
modifiesPayload(effects);

View File

@@ -240,6 +240,8 @@ def _dispatch_mixed_values(
for size in values or []:
if isinstance(size, int):
static_values.append(size)
elif isinstance(size, IntegerAttr):
static_values.append(size.value)
else:
static_values.append(ShapedType.get_dynamic_size())
dynamic_values.append(size)

View File

@@ -183,15 +183,19 @@ class FuseOp(FuseOp):
tile_interchange = tile_interchange if tile_interchange else []
(
dynamic_tile_sizes,
packed_tile_sizes,
static_tile_sizes,
_,
) = _dispatch_dynamic_index_list(tile_sizes)
) = _dispatch_mixed_values(tile_sizes)
(
dynamic_tile_interchange,
static_tile_interchange,
_,
) = _dispatch_dynamic_index_list(tile_interchange)
num_loops = 1 if use_forall else sum(1 for v in static_tile_sizes if v != 0)
num_loops = (
1
if use_forall or packed_tile_sizes is not None
else sum(1 for v in static_tile_sizes if v != 0)
)
if isinstance(loop_types_or_target, (Operation, Value, OpView)):
loop_types = [transform.AnyOpType.get()] * num_loops
@@ -210,6 +214,7 @@ class FuseOp(FuseOp):
target,
tile_sizes=dynamic_tile_sizes,
tile_interchange=dynamic_tile_interchange,
packed_tile_sizes=packed_tile_sizes,
static_tile_sizes=static_tile_sizes,
static_tile_interchange=static_tile_interchange,
apply_cleanup=apply_cleanup,

View File

@@ -112,6 +112,131 @@ module attributes {transform.with_named_sequence} {
// -----
// Fuses linalg.exp into a tiled linalg.add where the tile sizes are supplied
// through a single packed parameter handle; two nested scf.for loops are
// expected in the output.
// CHECK-LABEL: func.func @fuse_unary_packed_tile_sizes
func.func @fuse_unary_packed_tile_sizes(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: %[[RES:.*]] = scf.for
// CHECK: scf.for
// CHECK: linalg.exp
// CHECK: linalg.add
// CHECK: return %[[RES]]
%0 = linalg.exp ins(%arg0 : tensor<?x?xf32>)
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = linalg.add ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%c32 = transform.param.constant 32 : i64 -> !transform.any_param
%c64 = transform.param.constant 64 : i64 -> !transform.any_param
// Merge the two constants into one handle that is passed as packed tile sizes.
%tiles = transform.merge_handles %c32, %c64 : !transform.any_param
%1, %loops = transform.structured.fuse %0 tile_sizes *(%tiles)
: (!transform.any_op, !transform.any_param) -> (!transform.any_op, !transform.any_op)
// Verify that the correct number of loops is present in the packed result:
// splitting the single loops handle into two handles must succeed.
%loop:2 = transform.split_handle %loops : (!transform.any_op)
-> (!transform.any_op, !transform.any_op)
transform.yield
}
}
// -----
// Same payload as the scf.for variant, but with `use_forall` set: a single
// scf.forall is expected and the loops handle is split into one handle.
// CHECK-LABEL: func.func @fuse_unary_packed_tile_sizes_forall
func.func @fuse_unary_packed_tile_sizes_forall(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: %[[RES:.*]] = scf.forall
// CHECK: linalg.exp
// CHECK: linalg.add
// CHECK: return %[[RES]]
%0 = linalg.exp ins(%arg0 : tensor<?x?xf32>)
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = linalg.add ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%c32 = transform.param.constant 32 : i64 -> !transform.any_param
%c64 = transform.param.constant 64 : i64 -> !transform.any_param
// Merge the two constants into one handle that is passed as packed tile sizes.
%tiles = transform.merge_handles %c32, %c64 : !transform.any_param
%1, %loops = transform.structured.fuse %0 tile_sizes *(%tiles) {use_forall}
: (!transform.any_op, !transform.any_param) -> (!transform.any_op, !transform.any_op)
// Verify that the correct number of loops is present in the packed result:
// with `use_forall`, the loops handle is split into exactly one handle.
%loop:1 = transform.split_handle %loops : (!transform.any_op)
-> (!transform.any_op)
transform.yield
}
}
// -----
// The match below selects both linalg.add ops in the function, so the fuse op
// runs on two targets; the single loops handle is then split into four
// handles (two loops for each of the two targets).
// CHECK-LABEL: func.func @fuse_unary_packed_tile_sizes_multiple_targets
func.func @fuse_unary_packed_tile_sizes_multiple_targets(
%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: scf.for
// CHECK: scf.for
// CHECK: linalg.add
// CHECK: %[[RES:.*]] = scf.for
// CHECK: scf.for
// CHECK: linalg.add
// CHECK: return %[[RES]]
%0 = linalg.add ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = linalg.add ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%c32 = transform.param.constant 32 : i64 -> !transform.any_param
%c64 = transform.param.constant 64 : i64 -> !transform.any_param
// Merge the two constants into one handle that is passed as packed tile sizes.
%tiles = transform.merge_handles %c32, %c64 : !transform.any_param
%1, %loops = transform.structured.fuse %0 tile_sizes *(%tiles)
: (!transform.any_op, !transform.any_param) -> (!transform.any_op, !transform.any_op)
// Verify that the correct number of loops is present in the packed result:
// the loops handle must split into four handles.
%loop:4 = transform.split_handle %loops : (!transform.any_op)
-> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
}
// -----
// Packed tile sizes that are all zero: no scf.for is expected in the output
// and the linalg ops remain at function level.
// CHECK-LABEL: func.func @fuse_no_tiling_packed_tile_sizes
func.func @fuse_no_tiling_packed_tile_sizes(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK-NOT: scf.for
// CHECK: linalg.exp
// CHECK: %[[RES:.*]] = linalg.add
// CHECK: return %[[RES]]
%0 = linalg.exp ins(%arg0 : tensor<?x?xf32>)
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = linalg.add ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%c0 = transform.param.constant 0 : i64 -> !transform.any_param
// The same zero constant is merged twice, giving a packed handle of two zeros.
%tiles = transform.merge_handles %c0, %c0 : !transform.any_param
%1, %loops = transform.structured.fuse %0 tile_sizes *(%tiles)
: (!transform.any_op, !transform.any_param) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
// -----
// CHECK-LABEL: func.func @interchange_reduction
// CHECK-SAME: (%[[INPUT:.+]]: tensor<12x7x25xf32>)
func.func @interchange_reduction(%input: tensor<12x7x25xf32>) -> tensor<12x25xf32> {

View File

@@ -191,6 +191,33 @@ def testFuseOpAttributes(target):
# CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
# Builds a FuseOp whose `tile_sizes` is a single handle (the result of a match
# op); the expected output below prints it in the packed `tile_sizes *(...)`
# form with one tiled-op result and one loops result.
@run
@create_sequence
def testFuseOpPackedTileSizes(target):
tiles = structured.MatchOp.match_op_names(target, ["arith.constant"])
structured.FuseOp(target, tile_sizes=tiles)
# CHECK-LABEL: TEST: testFuseOpPackedTileSizes
# CHECK: transform.sequence
# CHECK: %[[T:.*]] = transform.structured.match
# CHECK: %{{.+}}, %{{.+}} = transform.structured.fuse
# CHECK-SAME: tile_sizes *(%[[T]])
# CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
# Same as the packed tile sizes test, but with `use_forall=True`; the expected
# output additionally carries the `{use_forall}` attribute.
@run
@create_sequence
def testFuseOpPackedTileSizesForall(target):
tiles = structured.MatchOp.match_op_names(target, ["arith.constant"])
structured.FuseOp(target, tile_sizes=tiles, use_forall=True)
# CHECK-LABEL: TEST: testFuseOpPackedTileSizesForall
# CHECK: transform.sequence
# CHECK: %[[T:.*]] = transform.structured.match
# CHECK: %{{.+}}, %{{.+}} = transform.structured.fuse
# CHECK-SAME: tile_sizes *(%[[T]])
# CHECK-SAME: {use_forall}
# CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
@run
@create_sequence
def testGeneralize(target):