[mlir][linalg] Fuse transform op - variadic tile sizes (#194657)
Extends the 'structured.fuse' op to accept packed handle containing variable number of tile sizes. Use of packed handles allows for runtime tiling decisions for improved transform schedule flexibility and reusability. The extension's design follows the existing approach of transform 'structured.tile_using_forall' op to more closely align their usage. In case of tiling using nested loops, all created loops are packed into a single return handle. For each target op, corresponding loops are appended to the result handle. Assisted-by: Claude
This commit is contained in:
@@ -423,6 +423,12 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
|
||||
and loop interchange permutation can be provided as either static
|
||||
attributes or dynamic values (transform parameters or payload handles).
|
||||
|
||||
Additionally, tile sizes can also be provided as a single handle containing
|
||||
variadic number of values. In that case, the number of loops generated is
|
||||
determined at runtime from the number of values in the packed handle.
|
||||
For each target, created loops are appended to the single return handle in
|
||||
the same order as the target operations.
|
||||
|
||||
If `apply_cleanup` is true then slice canonicalization is applied between
|
||||
fusion steps. If `use_forall` is true then tiling method generates a
|
||||
`scf.forall` loop instead of `scf.for` loops.
|
||||
@@ -432,6 +438,7 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
|
||||
(ins TransformHandleTypeInterface:$target,
|
||||
Variadic<TransformAnyParamTypeOrAnyHandle> : $tile_sizes,
|
||||
Variadic<TransformAnyParamTypeOrAnyHandle> : $tile_interchange,
|
||||
Optional<TransformAnyParamTypeOrAnyHandle> : $packed_tile_sizes,
|
||||
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_tile_sizes,
|
||||
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_tile_interchange,
|
||||
UnitAttr:$apply_cleanup,
|
||||
@@ -465,7 +472,9 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
|
||||
|
||||
let assemblyFormat = [{
|
||||
$target oilist(
|
||||
`tile_sizes` custom<DynamicIndexList>($tile_sizes, $static_tile_sizes) |
|
||||
`tile_sizes` custom<PackedOrDynamicIndexList>($packed_tile_sizes,
|
||||
$tile_sizes,
|
||||
$static_tile_sizes) |
|
||||
`interchange` custom<DynamicIndexList>($tile_interchange, $static_tile_interchange)
|
||||
)
|
||||
attr-dict `:` functional-type(operands, results)
|
||||
|
||||
@@ -654,6 +654,7 @@ void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
|
||||
/*target=*/target,
|
||||
/*tile_sizes=*/dynamicTileSizes,
|
||||
/*tile_interchange=*/dynamicTileInterchange,
|
||||
/*packed_tile_sizes=*/Value(),
|
||||
/*static_tile_sizes=*/staticTileSizesAttr,
|
||||
/*static_tile_interchange=*/staticTileInterchangeAttr,
|
||||
/*apply_cleanup=*/applyCleanup,
|
||||
@@ -666,10 +667,12 @@ template <typename Range>
|
||||
static LogicalResult applyTilingToAll(
|
||||
RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
|
||||
unsigned numLoops, transform::TransformResults &transformResults,
|
||||
bool packedResults,
|
||||
function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
|
||||
applyFn) {
|
||||
SmallVector<Operation *> tiledLinalgOps;
|
||||
SmallVector<SmallVector<Operation *>> loopOps(numLoops);
|
||||
size_t numTargets = llvm::range_size(payloadOps);
|
||||
|
||||
for (Operation *target : payloadOps) {
|
||||
auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
|
||||
@@ -704,8 +707,22 @@ static LogicalResult applyTilingToAll(
|
||||
}
|
||||
|
||||
transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
|
||||
for (unsigned int i = 0; i < numLoops; ++i)
|
||||
transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
|
||||
if (packedResults) {
|
||||
// In case of packed results, all created loops are assigned to a single
|
||||
// handle. Loops are returned in order of targets such as:
|
||||
// %loops_handle = {
|
||||
// target0:loop0, ..., target0:loopN,
|
||||
// target1:loop0, ..., target1:loopN,
|
||||
// ... }
|
||||
SmallVector<Operation *> flattenedLoopOps;
|
||||
for (unsigned int idx = 0; idx < numTargets; ++idx)
|
||||
for (unsigned int i = 0; i < numLoops; ++i)
|
||||
flattenedLoopOps.push_back(loopOps[i][idx]);
|
||||
transformResults.set(transformOp->getOpResult(1), flattenedLoopOps);
|
||||
} else {
|
||||
for (unsigned int i = 0; i < numLoops; ++i)
|
||||
transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
@@ -716,9 +733,13 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
|
||||
mlir::transform::TransformState &state) {
|
||||
auto transformOp = cast<TransformOpInterface>(getOperation());
|
||||
|
||||
SmallVector<int64_t> tileSizes;
|
||||
DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
|
||||
state, transformOp, getMixedTileSizes(), tileSizes);
|
||||
SmallVector<OpFoldResult> mixedTileSizes;
|
||||
DiagnosedSilenceableFailure status =
|
||||
getPackedTileSizes()
|
||||
? unpackSingleIndexResultPayloadOperations(
|
||||
state, transformOp, mixedTileSizes, getPackedTileSizes())
|
||||
: unpackSingleIndexResultPayloadOperations(
|
||||
state, transformOp, mixedTileSizes, getMixedTileSizes());
|
||||
if (!status.succeeded())
|
||||
return status;
|
||||
SmallVector<int64_t> tileInterchange;
|
||||
@@ -733,9 +754,7 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
|
||||
tilingOptions.setLoopType(useForall
|
||||
? scf::SCFTilingOptions::LoopType::ForallOp
|
||||
: scf::SCFTilingOptions::LoopType::ForOp);
|
||||
SmallVector<OpFoldResult> tileSizesOfr =
|
||||
getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
|
||||
tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
|
||||
tilingOptions = tilingOptions.setTileSizes(mixedTileSizes);
|
||||
scf::SCFTileAndFuseOptions tileAndFuseOptions;
|
||||
tileAndFuseOptions.tilingOptions = tilingOptions;
|
||||
|
||||
@@ -748,11 +767,20 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
|
||||
tileAndFuseOptions.cleanupPatterns = std::move(patterns);
|
||||
}
|
||||
|
||||
size_t numLoops =
|
||||
useForall ? 1 : tileSizes.size() - llvm::count(tileSizes, 0);
|
||||
size_t numLoops;
|
||||
if (useForall) {
|
||||
numLoops = 1;
|
||||
} else {
|
||||
numLoops = llvm::count_if(mixedTileSizes, [](OpFoldResult ofr) {
|
||||
auto attr = dyn_cast<Attribute>(ofr);
|
||||
if (!attr)
|
||||
return true;
|
||||
return cast<IntegerAttr>(attr).getInt() != 0;
|
||||
});
|
||||
}
|
||||
LogicalResult result = applyTilingToAll(
|
||||
rewriter, getOperation(), state.getPayloadOps(getTarget()), numLoops,
|
||||
transformResults,
|
||||
transformResults, /*packedResults=*/getPackedTileSizes() != nullptr,
|
||||
[&](TilingInterface tilingInterfaceOp)
|
||||
-> FailureOr<scf::SCFTileAndFuseResult> {
|
||||
return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
|
||||
@@ -763,6 +791,11 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
|
||||
}
|
||||
|
||||
LogicalResult transform::FuseOp::verify() {
|
||||
bool hasPackedTiles = getPackedTileSizes() != nullptr;
|
||||
if (!getMixedTileSizes().empty() && hasPackedTiles)
|
||||
return emitOpError(
|
||||
"tile_sizes and packed_tile_sizes are mutually exclusive");
|
||||
|
||||
auto iterspace_rank = getStaticTileSizes().size();
|
||||
ArrayRef<int64_t> permutation = getStaticTileInterchange();
|
||||
if (permutation.size() > iterspace_rank)
|
||||
@@ -782,8 +815,9 @@ LogicalResult transform::FuseOp::verify() {
|
||||
}
|
||||
|
||||
ArrayRef<int64_t> sizes = getStaticTileSizes();
|
||||
size_t numExpectedLoops =
|
||||
getUseForall() ? 1 : sizes.size() - llvm::count(sizes, 0);
|
||||
size_t numExpectedLoops = getUseForall() || hasPackedTiles
|
||||
? 1
|
||||
: sizes.size() - llvm::count(sizes, 0);
|
||||
if (numExpectedLoops != getNumResults() - 1)
|
||||
return emitOpError() << "expects " << numExpectedLoops << " loop results";
|
||||
|
||||
@@ -803,6 +837,7 @@ void transform::FuseOp::getEffects(
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
consumesHandle(getTargetMutable(), effects);
|
||||
onlyReadsHandle(getTileSizesMutable(), effects);
|
||||
onlyReadsHandle(getPackedTileSizesMutable(), effects);
|
||||
onlyReadsHandle(getTileInterchangeMutable(), effects);
|
||||
producesHandle(getOperation()->getOpResults(), effects);
|
||||
modifiesPayload(effects);
|
||||
|
||||
@@ -240,6 +240,8 @@ def _dispatch_mixed_values(
|
||||
for size in values or []:
|
||||
if isinstance(size, int):
|
||||
static_values.append(size)
|
||||
elif isinstance(size, IntegerAttr):
|
||||
static_values.append(size.value)
|
||||
else:
|
||||
static_values.append(ShapedType.get_dynamic_size())
|
||||
dynamic_values.append(size)
|
||||
|
||||
@@ -183,15 +183,19 @@ class FuseOp(FuseOp):
|
||||
tile_interchange = tile_interchange if tile_interchange else []
|
||||
(
|
||||
dynamic_tile_sizes,
|
||||
packed_tile_sizes,
|
||||
static_tile_sizes,
|
||||
_,
|
||||
) = _dispatch_dynamic_index_list(tile_sizes)
|
||||
) = _dispatch_mixed_values(tile_sizes)
|
||||
(
|
||||
dynamic_tile_interchange,
|
||||
static_tile_interchange,
|
||||
_,
|
||||
) = _dispatch_dynamic_index_list(tile_interchange)
|
||||
num_loops = 1 if use_forall else sum(1 for v in static_tile_sizes if v != 0)
|
||||
num_loops = (
|
||||
1
|
||||
if use_forall or packed_tile_sizes is not None
|
||||
else sum(1 for v in static_tile_sizes if v != 0)
|
||||
)
|
||||
|
||||
if isinstance(loop_types_or_target, (Operation, Value, OpView)):
|
||||
loop_types = [transform.AnyOpType.get()] * num_loops
|
||||
@@ -210,6 +214,7 @@ class FuseOp(FuseOp):
|
||||
target,
|
||||
tile_sizes=dynamic_tile_sizes,
|
||||
tile_interchange=dynamic_tile_interchange,
|
||||
packed_tile_sizes=packed_tile_sizes,
|
||||
static_tile_sizes=static_tile_sizes,
|
||||
static_tile_interchange=static_tile_interchange,
|
||||
apply_cleanup=apply_cleanup,
|
||||
|
||||
@@ -112,6 +112,131 @@ module attributes {transform.with_named_sequence} {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @fuse_unary_packed_tile_sizes
|
||||
func.func @fuse_unary_packed_tile_sizes(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
|
||||
// CHECK: %[[RES:.*]] = scf.for
|
||||
// CHECK: scf.for
|
||||
// CHECK: linalg.exp
|
||||
// CHECK: linalg.add
|
||||
// CHECK: return %[[RES]]
|
||||
%0 = linalg.exp ins(%arg0 : tensor<?x?xf32>)
|
||||
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%1 = linalg.add ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
|
||||
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
return %1 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||
%0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||
%c32 = transform.param.constant 32 : i64 -> !transform.any_param
|
||||
%c64 = transform.param.constant 64 : i64 -> !transform.any_param
|
||||
%tiles = transform.merge_handles %c32, %c64 : !transform.any_param
|
||||
%1, %loops = transform.structured.fuse %0 tile_sizes *(%tiles)
|
||||
: (!transform.any_op, !transform.any_param) -> (!transform.any_op, !transform.any_op)
|
||||
// Verify that correct number of loops is present in packed result.
|
||||
%loop:2 = transform.split_handle %loops : (!transform.any_op)
|
||||
-> (!transform.any_op, !transform.any_op)
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @fuse_unary_packed_tile_sizes_forall
|
||||
func.func @fuse_unary_packed_tile_sizes_forall(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
|
||||
// CHECK: %[[RES:.*]] = scf.forall
|
||||
// CHECK: linalg.exp
|
||||
// CHECK: linalg.add
|
||||
// CHECK: return %[[RES]]
|
||||
%0 = linalg.exp ins(%arg0 : tensor<?x?xf32>)
|
||||
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%1 = linalg.add ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
|
||||
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
return %1 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||
%0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||
%c32 = transform.param.constant 32 : i64 -> !transform.any_param
|
||||
%c64 = transform.param.constant 64 : i64 -> !transform.any_param
|
||||
%tiles = transform.merge_handles %c32, %c64 : !transform.any_param
|
||||
%1, %loops = transform.structured.fuse %0 tile_sizes *(%tiles) {use_forall}
|
||||
: (!transform.any_op, !transform.any_param) -> (!transform.any_op, !transform.any_op)
|
||||
// Verify that correct number of loops is present in packed result.
|
||||
%loop:1 = transform.split_handle %loops : (!transform.any_op)
|
||||
-> (!transform.any_op)
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @fuse_unary_packed_tile_sizes_multiple_targets
|
||||
func.func @fuse_unary_packed_tile_sizes_multiple_targets(
|
||||
%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
|
||||
// CHECK: scf.for
|
||||
// CHECK: scf.for
|
||||
// CHECK: linalg.add
|
||||
// CHECK: %[[RES:.*]] = scf.for
|
||||
// CHECK: scf.for
|
||||
// CHECK: linalg.add
|
||||
// CHECK: return %[[RES]]
|
||||
%0 = linalg.add ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
|
||||
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%1 = linalg.add ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
|
||||
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
return %1 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||
%0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||
%c32 = transform.param.constant 32 : i64 -> !transform.any_param
|
||||
%c64 = transform.param.constant 64 : i64 -> !transform.any_param
|
||||
%tiles = transform.merge_handles %c32, %c64 : !transform.any_param
|
||||
%1, %loops = transform.structured.fuse %0 tile_sizes *(%tiles)
|
||||
: (!transform.any_op, !transform.any_param) -> (!transform.any_op, !transform.any_op)
|
||||
// Verify that correct number of loops is present in packed result.
|
||||
%loop:4 = transform.split_handle %loops : (!transform.any_op)
|
||||
-> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @fuse_no_tiling_packed_tile_sizes
|
||||
func.func @fuse_no_tiling_packed_tile_sizes(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
|
||||
// CHECK-NOT: scf.for
|
||||
// CHECK: linalg.exp
|
||||
// CHECK: %[[RES:.*]] = linalg.add
|
||||
// CHECK: return %[[RES]]
|
||||
%0 = linalg.exp ins(%arg0 : tensor<?x?xf32>)
|
||||
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%1 = linalg.add ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
|
||||
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
return %1 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||
%0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||
%c0 = transform.param.constant 0 : i64 -> !transform.any_param
|
||||
%tiles = transform.merge_handles %c0, %c0 : !transform.any_param
|
||||
%1, %loops = transform.structured.fuse %0 tile_sizes *(%tiles)
|
||||
: (!transform.any_op, !transform.any_param) -> (!transform.any_op, !transform.any_op)
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @interchange_reduction
|
||||
// CHECK-SAME: (%[[INPUT:.+]]: tensor<12x7x25xf32>)
|
||||
func.func @interchange_reduction(%input: tensor<12x7x25xf32>) -> tensor<12x25xf32> {
|
||||
|
||||
@@ -191,6 +191,33 @@ def testFuseOpAttributes(target):
|
||||
# CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
|
||||
|
||||
|
||||
@run
|
||||
@create_sequence
|
||||
def testFuseOpPackedTileSizes(target):
|
||||
tiles = structured.MatchOp.match_op_names(target, ["arith.constant"])
|
||||
structured.FuseOp(target, tile_sizes=tiles)
|
||||
# CHECK-LABEL: TEST: testFuseOpPackedTileSizes
|
||||
# CHECK: transform.sequence
|
||||
# CHECK: %[[T:.*]] = transform.structured.match
|
||||
# CHECK: %{{.+}}, %{{.+}} = transform.structured.fuse
|
||||
# CHECK-SAME: tile_sizes *(%[[T]])
|
||||
# CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||
|
||||
|
||||
@run
|
||||
@create_sequence
|
||||
def testFuseOpPackedTileSizesForall(target):
|
||||
tiles = structured.MatchOp.match_op_names(target, ["arith.constant"])
|
||||
structured.FuseOp(target, tile_sizes=tiles, use_forall=True)
|
||||
# CHECK-LABEL: TEST: testFuseOpPackedTileSizesForall
|
||||
# CHECK: transform.sequence
|
||||
# CHECK: %[[T:.*]] = transform.structured.match
|
||||
# CHECK: %{{.+}}, %{{.+}} = transform.structured.fuse
|
||||
# CHECK-SAME: tile_sizes *(%[[T]])
|
||||
# CHECK-SAME: {use_forall}
|
||||
# CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||
|
||||
|
||||
@run
|
||||
@create_sequence
|
||||
def testGeneralize(target):
|
||||
|
||||
Reference in New Issue
Block a user