[mlir] Fix crash in dropRedundantArguments with produced operands. (#172759)
dropRedundantArguments was incorrectly indexing into forwardedOperands using the block argument index directly. This crashes when the block has produced operands (generated by the terminator, not forwarded from predecessors) because forwardedOperands doesn't include them. The fix checks isOperandProduced() to skip produced arguments and uses SuccessorOperands::operator[] which handles the offset correctly.
This commit is contained in:
@@ -975,12 +975,22 @@ static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
|
||||
}
|
||||
unsigned succIndex = predIt.getSuccessorIndex();
|
||||
SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
|
||||
auto branchOperands = succOperands.getForwardedOperands();
|
||||
|
||||
// Produced operands are generated by the terminator operation itself
|
||||
// (e.g., results of an async call) and cannot be forwarded or dropped.
|
||||
if (succOperands.isOperandProduced(argIdx)) {
|
||||
sameArg = false;
|
||||
break;
|
||||
}
|
||||
|
||||
// Get the forwarded operand value using operator[] which correctly
|
||||
// adjusts for the produced operand offset.
|
||||
Value operandValue = succOperands[argIdx];
|
||||
if (!commonValue) {
|
||||
commonValue = branchOperands[argIdx];
|
||||
commonValue = operandValue;
|
||||
continue;
|
||||
}
|
||||
if (branchOperands[argIdx] != commonValue) {
|
||||
if (operandValue != commonValue) {
|
||||
sameArg = false;
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -318,3 +318,36 @@ func.func @nested_loop(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i3
|
||||
^EXIT: // pred: ^Loop_header
|
||||
return
|
||||
}
|
||||
|
||||
// Test that dropRedundantArguments correctly handles produced successor operands.
|
||||
|
||||
// CHECK-LABEL: func.func @produced_operand_not_dropped
|
||||
// CHECK-SAME: (%{{.*}}: i1) -> i32 {
|
||||
func.func @produced_operand_not_dropped(%cond: i1) -> i32 {
|
||||
// CHECK-DAG: %[[C42:.*]] = arith.constant 42 : i32
|
||||
%c42 = arith.constant 42 : i32
|
||||
// CHECK: "test.internal_br"()[^[[UNUSED:.*]], ^[[TARGET:.*]]]
|
||||
cf.cond_br %cond, ^bb1, ^bb2
|
||||
^bb1:
|
||||
// Branches via error path (successor 1) which has 1 produced + 1 forwarded arg.
|
||||
"test.internal_br"(%c42) [^bb_unused, ^bb_target] {
|
||||
operandSegmentSizes = array<i32: 0, 1>
|
||||
} : (i32) -> ()
|
||||
^bb2:
|
||||
// Also branches via error path with the same forwarded value.
|
||||
"test.internal_br"(%c42) [^bb_unused, ^bb_target] {
|
||||
operandSegmentSizes = array<i32: 0, 1>
|
||||
} : (i32) -> ()
|
||||
// CHECK: ^[[UNUSED]]:
|
||||
^bb_unused:
|
||||
// CHECK: "test.terminator"
|
||||
"test.terminator"() : () -> ()
|
||||
// arg0: produced by test.internal_br (kept)
|
||||
// arg1: forwarded %c42 (dropped - same from both preds, replaced with %c42 directly)
|
||||
// CHECK: ^[[TARGET]](%[[PRODUCED:.*]]: i32):
|
||||
^bb_target(%arg0: i32, %arg1: i32):
|
||||
// CHECK: %[[RESULT:.*]] = arith.addi %[[PRODUCED]], %[[C42]] : i32
|
||||
%result = arith.addi %arg0, %arg1 : i32
|
||||
// CHECK: return %[[RESULT]] : i32
|
||||
return %result : i32
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user