[mlir][dataflow] Register dependency when const-prop fold returns non-operand (#194372)

Fixes #137509.

When `op->fold` returns a Value that is not one of `op`'s operands (e.g.
`unrealized_conversion_cast`'s fold returns the inner cast's operand),
`SparseConstantPropagation` read that value's lattice without
subscribing to it -- so the op was not revisited when the lattice
widened and its stale fold result was not updated.

Fix by using `getLatticeElementFor(getProgramPointAfter(op), v)` to
register the dependency. This matches a few places in
`SparseAnalysis.cpp` where the same strategy is used.

I'd love to use something even simpler than `unrealized_conversion_cast`
operation in the test, but this is what i got when minimizing the
reproduction from the original issue (#137509) and i wasn't able to find
any operation that would work for this reproduction.

Assisted-By: Claude Code
This commit is contained in:
Zmicier Prybysh
2026-04-30 15:38:35 +02:00
committed by GitHub
parent 3232d38a59
commit 6f1e6e47bd
2 changed files with 34 additions and 2 deletions

View File

@@ -103,9 +103,13 @@ LogicalResult SparseConstantPropagation::visitOperation(
propagateIfChanged(lattice,
lattice->join(ConstantValue(attr, op->getDialect())));
} else {
LDBG() << "Folded to value: " << cast<Value>(foldResult);
Value foldValue = cast<Value>(foldResult);
LDBG() << "Folded to value: " << foldValue;
// The folded value may not be an operand of `op`, so we need to use
// `getLatticeElementFor` (and not `getLatticeElement`) so that
// this operation is revisited if that value's lattice widens later.
AbstractSparseForwardDataFlowAnalysis::join(
lattice, *getLatticeElement(cast<Value>(foldResult)));
lattice, *getLatticeElementFor(getProgramPointAfter(op), foldValue));
}
}
return success();

View File

@@ -307,3 +307,31 @@ func.func @no_crash_emitc_switch_unsigned_condition() {
}
return
}
// -----
// Regression test for https://github.com/llvm/llvm-project/issues/137509
//
// %a (the ^bb3 block arg) joins ^bb1's Constant %c1 with ^bb2's non-foldable
// %v, so %a's lattice is overdefined. %cast2's fold collapses the round-trip
// cast chain to %a - a Value that is not one of %cast2's own operands - so
// constant propagation must read %a's lattice directly to compute %cast2's.
// %a being overdefined, the cast chain is preserved.
// CHECK-LABEL: func @fold_to_non_operand_value
func.func @fold_to_non_operand_value(%x: i64, %cond: i1) -> i64 {
%c1 = arith.constant 1 : i64
cf.cond_br %cond, ^bb1, ^bb2
^bb1:
cf.br ^bb3(%c1 : i64)
^bb2:
%v = arith.addi %x, %c1 : i64
cf.br ^bb3(%v : i64)
^bb3(%a: i64):
// CHECK: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i64 to index
// CHECK: %[[CAST2:.*]] = builtin.unrealized_conversion_cast %[[CAST1]] : index to i64
// CHECK: return %[[CAST2]] : i64
%cast1 = builtin.unrealized_conversion_cast %a : i64 to index
%cast2 = builtin.unrealized_conversion_cast %cast1 : index to i64
return %cast2 : i64
}