The `OneDimMultiReductionToTwoDim` pattern had some issues. For the
input program:
```mlir
func.func @rank1_multi_reduction(%arg0: vector<8xf32>, %acc: f32) -> f32 {
%0 = vector.multi_reduction <add>, %arg0, %acc [0] : vector<8xf32> to f32
return %0 : f32
}
```
* when lowering using the inner-parallel strategy, the compiler would
essentially produce scalar code:
```mlir
func.func @rank1_multi_reduction(%arg0: vector<8xf32>, %arg1: f32) -> f32 {
%0 = vector.shape_cast %arg0 : vector<8xf32> to vector<1x8xf32>
%1 = vector.broadcast %arg1 : f32 to vector<1xf32>
%2 = vector.transpose %0, [1, 0] : vector<1x8xf32> to vector<8x1xf32>
%3 = vector.extract %2[0] : vector<1xf32> from vector<8x1xf32>
%4 = arith.addf %3, %1 : vector<1xf32>
%5 = vector.extract %2[1] : vector<1xf32> from vector<8x1xf32>
%6 = arith.addf %5, %4 : vector<1xf32>
... (repeats for all 8 elements) ...
%17 = vector.extract %2[7] : vector<1xf32> from vector<8x1xf32>
%18 = arith.addf %17, %16 : vector<1xf32>
%19 = vector.extract %18[0] : f32 from vector<1xf32>
return %19 : f32
}
```
* when lowering using the inner-reduction strategy, the compiler would
first unnecessarily transform it into a 2-D multi_reduction operation on
vector<1x8xf32>, then extract a vector<8xf32> and apply a reduction to it.
Canonicalization and folding would then lead to the following final result:
```mlir
func.func @rank1_multi_reduction(%arg0: vector<8xf32>, %arg1: f32) -> f32 {
%0 = vector.reduction <add>, %arg0, %arg1 : vector<8xf32> into f32
return %0 : f32
}
```
Now, after this change:
* when lowering, the compiler now produces the following for both
strategies in a single step:
```mlir
func.func @rank1_multi_reduction(%arg0: vector<8xf32>, %arg1: f32) -> f32 {
%0 = vector.reduction <add>, %arg0, %arg1 : vector<8xf32> into f32
return %0 : f32
}
```
This change is also useful for an ongoing refactoring of the
multi_reduction patterns. `OneDimMultiReductionToTwoDim` is the only
pattern that increases the rank of a multi_reduction, so it would lead to
an infinite loop when attempting to reach a fixed point once we generalize
the other unrolling patterns.
Assisted-by: Claude
166 lines
7.1 KiB
Python
166 lines
7.1 KiB
Python
# RUN: %PYTHON %s | FileCheck %s
|
|
|
|
from mlir.ir import *
|
|
from mlir.dialects import transform
|
|
from mlir.dialects.transform import vector
|
|
|
|
|
|
def run_apply_patterns(f):
    """Decorator: run `f` inside the patterns region of an ApplyPatternsOp.

    A fresh module is created containing a `transform.SequenceOp`; inside its
    body a `transform.ApplyPatternsOp` is built on the sequence's body target,
    and `f` is called with the insertion point set to that op's `patterns`
    region. Afterwards a "TEST: <function name>" banner and the module are
    printed to stdout. The function `f` itself is returned unchanged.
    """
    with Context(), Location.unknown():
        test_module = Module.create()
        with InsertionPoint(test_module.body):
            failure_mode = transform.FailurePropagationMode.Propagate
            root_type = transform.AnyOpType.get()
            seq_op = transform.SequenceOp(failure_mode, [], root_type)
            with InsertionPoint(seq_op.body):
                patterns_op = transform.ApplyPatternsOp(seq_op.bodyTarget)
                with InsertionPoint(patterns_op.patterns):
                    f()
                # Emitted after the apply-patterns op, still inside the
                # sequence body.
                transform.YieldOp()
        print("\nTEST:", f.__name__)
        print(test_module)
    return f
|
|
|
|
|
|
@run_apply_patterns
def non_configurable_patterns():
    """Build every vector pattern op that is constructed without arguments.

    The constructors are listed in the order their ops must appear in the
    printed module; each entry is preceded by the directive that matches it.
    """
    # CHECK-LABEL: TEST: non_configurable_patterns
    # CHECK: apply_patterns
    pattern_op_ctors = (
        # CHECK: transform.apply_patterns.vector.cast_away_vector_leading_one_dim
        vector.ApplyCastAwayVectorLeadingOneDimPatternsOp,
        # CHECK: transform.apply_patterns.vector.rank_reducing_subview_patterns
        vector.ApplyRankReducingSubviewPatternsOp,
        # CHECK: transform.apply_patterns.vector.transfer_permutation_patterns
        vector.ApplyTransferPermutationPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_broadcast
        vector.ApplyLowerBroadcastPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_masks
        vector.ApplyLowerMasksPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_masked_transfers
        vector.ApplyLowerMaskedTransfersPatternsOp,
        # CHECK: transform.apply_patterns.vector.materialize_masks
        vector.ApplyMaterializeMasksPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_outerproduct
        vector.ApplyLowerOuterProductPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_gather
        vector.ApplyLowerGatherPatternsOp,
        # CHECK: transform.apply_patterns.vector.unroll_from_elements
        vector.ApplyUnrollFromElementsPatternsOp,
        # CHECK: transform.apply_patterns.vector.unroll_to_elements
        vector.ApplyUnrollToElementsPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_scan
        vector.ApplyLowerScanPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_shape_cast
        vector.ApplyLowerShapeCastPatternsOp,
    )
    for make_op in pattern_op_ctors:
        make_op()
|
|
|
|
|
|
@run_apply_patterns
def configurable_patterns():
    """Build vector pattern ops that take integer/boolean options.

    Each table entry pairs an op constructor with the keyword options it is
    built with; the directives above the entry match the printed attributes.
    """
    # CHECK-LABEL: TEST: configurable_patterns
    # CHECK: apply_patterns
    configured_ops = (
        # CHECK: transform.apply_patterns.vector.lower_transfer
        # CHECK-SAME: max_transfer_rank = 4
        (vector.ApplyLowerTransferPatternsOp, {"max_transfer_rank": 4}),
        # CHECK: transform.apply_patterns.vector.transfer_to_scf
        # CHECK-SAME: max_transfer_rank = 3
        # CHECK-SAME: full_unroll = true
        (
            vector.ApplyTransferToScfPatternsOp,
            {"max_transfer_rank": 3, "full_unroll": True},
        ),
        # CHECK: transform.apply_patterns.vector.flatten_vector_transfer_ops
        # CHECK-SAME: target_vector_bitwidth = 1
        (
            vector.ApplyFlattenVectorTransferOpsPatternsOp,
            {"target_vector_bitwidth": 1},
        ),
    )
    for make_op, options in configured_ops:
        make_op(**options)
|
|
|
|
|
|
@run_apply_patterns
def enum_configurable_patterns():
    """Build vector pattern ops whose options are enum-valued.

    Every op is built once with its defaults and then once per explicitly
    selected enum value. Where the selected value is the default, the
    attribute is not printed, so only the op name is matched.
    """
    # Anchor the directives below to this test's own output section, matching
    # the sibling tests in this file.
    # CHECK-LABEL: TEST: enum_configurable_patterns
    # CHECK: apply_patterns
    # CHECK: transform.apply_patterns.vector.lower_contraction
    vector.ApplyLowerContractionPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_contraction
    # CHECK-SAME: lowering_strategy = llvmintr
    vector.ApplyLowerContractionPatternsOp(
        lowering_strategy=vector.VectorContractLowering.LLVMIntr
    )
    # CHECK: transform.apply_patterns.vector.lower_contraction
    # CHECK-SAME: lowering_strategy = parallelarith
    vector.ApplyLowerContractionPatternsOp(
        lowering_strategy=vector.VectorContractLowering.ParallelArith
    )

    # CHECK: transform.apply_patterns.vector.reorder_multi_reduction_dims
    vector.ApplyReorderMultiReductionPatternsOp()
    # CHECK: transform.apply_patterns.vector.reorder_multi_reduction_dims
    # CHECK-SAME: lowering_strategy = innerreduction
    vector.ApplyReorderMultiReductionPatternsOp(
        lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
    )

    # CHECK: transform.apply_patterns.vector.multi_reduction_flattening
    vector.ApplyMultiReductionFlatteningPatternsOp()
    # CHECK: transform.apply_patterns.vector.multi_reduction_flattening
    # CHECK-SAME: lowering_strategy = innerreduction
    vector.ApplyMultiReductionFlatteningPatternsOp(
        lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
    )

    # CHECK: transform.apply_patterns.vector.multi_reduction_unrolling
    vector.ApplyMultiReductionUnrollingPatternsOp()
    # CHECK: transform.apply_patterns.vector.multi_reduction_unrolling
    # CHECK-SAME: lowering_strategy = innerreduction
    vector.ApplyMultiReductionUnrollingPatternsOp(
        lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
    )

    # CHECK: transform.apply_patterns.vector.lower_transpose
    vector.ApplyLowerTransposePatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # This is the default strategy, not printed.
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.EltWise
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = llvmintr
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.LLVMIntr
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = shuffle_1d
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.Shuffle1D
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = shuffle_16x16
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.Shuffle16x16
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = llvmintr
    # CHECK-SAME: avx2_lowering_strategy = true
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.LLVMIntr,
        avx2_lowering_strategy=True,
    )

    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    vector.ApplySplitTransferFullPartialPatternsOp()
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-SAME: split_transfer_strategy = none
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.None_
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-SAME: split_transfer_strategy = "vector-transfer"
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.VectorTransfer
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # This is the default mode, not printed.
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.LinalgCopy
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-SAME: split_transfer_strategy = "force-in-bounds"
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.ForceInBounds
    )
|