Files
llvm-project/mlir/test/python/dialects/transform_vector_ext.py
Erick Ochoa Lopez 613a5c555e [mlir][vector] Replace OneDimMultiReductionToTwoDim with OneDimMultiReductionToReduction (#184241)
The `OneDimMultiReductionToTwoDim` pattern had some issues. For the
input program:

```mlir
func.func @rank1_multi_reduction(%arg0: vector<8xf32>, %acc: f32) -> f32 {
    %0 = vector.multi_reduction <add>, %arg0, %acc [0] : vector<8xf32> to f32
    return %0 : f32
}
```

* when lowering using the inner-parallel strategy, the compiler would
essentially produce scalar code:
```mlir
func.func @rank1_multi_reduction(%arg0: vector<8xf32>, %arg1: f32) -> f32 {
    %0 = vector.shape_cast %arg0 : vector<8xf32> to vector<1x8xf32>
    %1 = vector.broadcast %arg1 : f32 to vector<1xf32>
    %2 = vector.transpose %0, [1, 0] : vector<1x8xf32> to vector<8x1xf32>
    %3 = vector.extract %2[0] : vector<1xf32> from vector<8x1xf32>
    %4 = arith.addf %3, %1 : vector<1xf32>
    %5 = vector.extract %2[1] : vector<1xf32> from vector<8x1xf32>
    %6 = arith.addf %5, %4 : vector<1xf32>
    ... (repeats for all 8 elements) ...
    %17 = vector.extract %2[7] : vector<1xf32> from vector<8x1xf32>
    %18 = arith.addf %17, %16 : vector<1xf32>
    %19 = vector.extract %18[0] : f32 from vector<1xf32>
    return %19 : f32
}
```
* when lowering using the inner-reduction strategy, the compiler would
first unnecessarily transform it into a 2-D multi_reduction operation
<1x8xf32> and then extract an <8xf32> vector and apply reduction. The
canonicalization and folding would lead to the following final result:
```mlir
func.func @rank1_multi_reduction(%arg0: vector<8xf32>, %arg1: f32) -> f32 {
    %0 = vector.reduction <add>, %arg0, %arg1 : vector<8xf32> into f32
    return %0 : f32
}
```

Now, after this change:
* when lowering, the compiler now produces the following result directly, in
a single step, for both strategies:
```
func.func @rank1_multi_reduction(%arg0: vector<8xf32>, %arg1: f32) -> f32 {
    %0 = vector.reduction <add>, %arg0, %arg1 : vector<8xf32> into f32
    return %0 : f32
}
```

This pattern is also useful for an ongoing refactoring that is happening
in the multi_reduction patterns. It is the only pattern that increases
multi_reduction in rank and would lead to an infinite loop when
attempting to reach a fixed point once we generalize other unrolling
patterns.

Assisted-by: Claude
2026-03-04 16:13:11 +00:00

166 lines
7.1 KiB
Python

# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects.transform import vector
def run_apply_patterns(f):
    """Decorator that runs `f` inside a fresh `transform.apply_patterns` region.

    A new module is created holding a transform sequence whose body contains
    one apply-patterns op; `f` is called with the insertion point set to that
    op's pattern region, so any pattern ops `f` constructs land there. The
    test name and the resulting module are then printed for FileCheck to
    match against. `f` is invoked once, at decoration time, and is returned
    unchanged.
    """
    with Context():
        with Location.unknown():
            mod = Module.create()
            with InsertionPoint(mod.body):
                seq_op = transform.SequenceOp(
                    transform.FailurePropagationMode.Propagate,
                    [],
                    transform.AnyOpType.get(),
                )
                with InsertionPoint(seq_op.body):
                    patterns_op = transform.ApplyPatternsOp(seq_op.bodyTarget)
                    with InsertionPoint(patterns_op.patterns):
                        f()
                    # Terminate the sequence body after the apply-patterns op.
                    transform.YieldOp()
            print("\nTEST:", f.__name__)
            print(mod)
    return f
@run_apply_patterns
def non_configurable_patterns():
    """Build every vector pattern op that is constructed without arguments.

    Each constructor call is preceded by the FileCheck directive for the op
    mnemonic expected in the printed module; the directives are matched in
    order, so calls and directives must stay paired and in sequence.
    """
    # CHECK-LABEL: TEST: non_configurable_patterns
    # CHECK: apply_patterns
    # CHECK: transform.apply_patterns.vector.cast_away_vector_leading_one_dim
    vector.ApplyCastAwayVectorLeadingOneDimPatternsOp()
    # CHECK: transform.apply_patterns.vector.rank_reducing_subview_patterns
    vector.ApplyRankReducingSubviewPatternsOp()
    # CHECK: transform.apply_patterns.vector.transfer_permutation_patterns
    vector.ApplyTransferPermutationPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_broadcast
    vector.ApplyLowerBroadcastPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_masks
    vector.ApplyLowerMasksPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_masked_transfers
    vector.ApplyLowerMaskedTransfersPatternsOp()
    # CHECK: transform.apply_patterns.vector.materialize_masks
    vector.ApplyMaterializeMasksPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_outerproduct
    vector.ApplyLowerOuterProductPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_gather
    vector.ApplyLowerGatherPatternsOp()
    # CHECK: transform.apply_patterns.vector.unroll_from_elements
    vector.ApplyUnrollFromElementsPatternsOp()
    # CHECK: transform.apply_patterns.vector.unroll_to_elements
    vector.ApplyUnrollToElementsPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_scan
    vector.ApplyLowerScanPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_shape_cast
    vector.ApplyLowerShapeCastPatternsOp()
@run_apply_patterns
def configurable_patterns():
    """Build vector pattern ops that take integer or boolean keyword options.

    The FileCheck directives ahead of each call expect the op mnemonic and,
    on the same output line, the option values passed to the constructor.
    """
    # CHECK-LABEL: TEST: configurable_patterns
    # CHECK: apply_patterns
    # CHECK: transform.apply_patterns.vector.lower_transfer
    # CHECK-SAME: max_transfer_rank = 4
    vector.ApplyLowerTransferPatternsOp(max_transfer_rank=4)
    # CHECK: transform.apply_patterns.vector.transfer_to_scf
    # CHECK-SAME: max_transfer_rank = 3
    # CHECK-SAME: full_unroll = true
    vector.ApplyTransferToScfPatternsOp(max_transfer_rank=3, full_unroll=True)
    # CHECK: transform.apply_patterns.vector.flatten_vector_transfer_ops
    # CHECK-SAME: target_vector_bitwidth = 1
    vector.ApplyFlattenVectorTransferOpsPatternsOp(target_vector_bitwidth=1)
@run_apply_patterns
def enum_configurable_patterns():
    """Build vector pattern ops whose options are enum-valued attributes.

    Each op is constructed once with no arguments and then with explicit enum
    values. The FileCheck directives ahead of each call expect the op
    mnemonic and, where one is listed, the enum's textual form on the same
    output line. Calls and directives must stay paired and in order.
    """
    # Anchor the directives below to this test's own section of the output,
    # as the sibling tests in this file do.
    # CHECK-LABEL: TEST: enum_configurable_patterns
    # CHECK: apply_patterns
    # CHECK: transform.apply_patterns.vector.lower_contraction
    vector.ApplyLowerContractionPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_contraction
    # CHECK-SAME: lowering_strategy = llvmintr
    vector.ApplyLowerContractionPatternsOp(
        lowering_strategy=vector.VectorContractLowering.LLVMIntr
    )
    # CHECK: transform.apply_patterns.vector.lower_contraction
    # CHECK-SAME: lowering_strategy = parallelarith
    vector.ApplyLowerContractionPatternsOp(
        lowering_strategy=vector.VectorContractLowering.ParallelArith
    )
    # CHECK: transform.apply_patterns.vector.reorder_multi_reduction_dims
    vector.ApplyReorderMultiReductionPatternsOp()
    # CHECK: transform.apply_patterns.vector.reorder_multi_reduction_dims
    # CHECK-SAME: lowering_strategy = innerreduction
    vector.ApplyReorderMultiReductionPatternsOp(
        lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
    )
    # CHECK: transform.apply_patterns.vector.multi_reduction_flattening
    vector.ApplyMultiReductionFlatteningPatternsOp()
    # CHECK: transform.apply_patterns.vector.multi_reduction_flattening
    # CHECK-SAME: lowering_strategy = innerreduction
    vector.ApplyMultiReductionFlatteningPatternsOp(
        lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
    )
    # CHECK: transform.apply_patterns.vector.multi_reduction_unrolling
    vector.ApplyMultiReductionUnrollingPatternsOp()
    # CHECK: transform.apply_patterns.vector.multi_reduction_unrolling
    # CHECK-SAME: lowering_strategy = innerreduction
    vector.ApplyMultiReductionUnrollingPatternsOp(
        lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    vector.ApplyLowerTransposePatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # This is the default strategy, not printed.
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.EltWise
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = llvmintr
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.LLVMIntr
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = shuffle_1d
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.Shuffle1D
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = shuffle_16x16
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.Shuffle16x16
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = llvmintr
    # CHECK-SAME: avx2_lowering_strategy = true
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.LLVMIntr,
        avx2_lowering_strategy=True,
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    vector.ApplySplitTransferFullPartialPatternsOp()
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-SAME: split_transfer_strategy = none
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.None_
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-SAME: split_transfer_strategy = "vector-transfer"
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.VectorTransfer
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # This is the default mode, not printed.
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.LinalgCopy
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-SAME: split_transfer_strategy = "force-in-bounds"
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.ForceInBounds
    )