[regalloc][LiveRegMatrix][AMDGPU] Fix LiveInterval dangling pointers in LiveRegMatrix. (#168556)

This patch correctly removes segments from LiveRegMatrix that reference LiveIntervals removed after spilling. Added validity check that LiveRegMatrix doesn't contain invalid references to intervals.
This commit is contained in:
Valery Pykhtin
2026-01-27 16:08:58 +01:00
committed by GitHub
parent e7063e8206
commit f32c00f6d8
7 changed files with 101 additions and 13 deletions

View File

@@ -93,6 +93,11 @@ public:
// Remove a live virtual register's segments from this union.
void extract(const LiveInterval &VirtReg, const LiveRange &Range);
// Remove all segments referencing VirtRegLI. This may be used if the register
// isn't used anymore. The interval should have valid register number but
// can have empty live ranges.
void clearAllSegmentsReferencing(const LiveInterval &VirtRegLI);
// Remove all inserted virtual registers.
void clear() { Segments.clear(); ++Tag; }

View File

@@ -133,7 +133,14 @@ public:
/// Unassign VirtReg from its PhysReg.
/// Assuming that VirtReg was previously assigned to a PhysReg, this undoes
/// the assignment and updates VirtRegMap accordingly.
void unassign(const LiveInterval &VirtReg);
/// ClearAllReferencingSegments changes the way segments are removed from
/// the matrix:
/// - If false (default), only segments that exactly match VirtReg's live
/// range are removed.
/// - If true, all segments that reference VirtReg are removed. This is
/// useful when VirtReg's live range(s) is already empty.
void unassign(const LiveInterval &VirtReg,
bool ClearAllReferencingSegments = false);
/// Returns true if the given \p PhysReg has any live intervals assigned.
bool isPhysRegUsed(MCRegister PhysReg) const;
@@ -170,6 +177,12 @@ public:
}
Register getOneVReg(unsigned PhysReg) const;
#ifndef NDEBUG
/// This checks that each LiveInterval referenced in LiveIntervalUnion
/// actually exists in LiveIntervals and is not a dangling pointer.
bool isValid() const;
#endif
};
class LiveRegMatrixWrapperLegacy : public MachineFunctionPass {

View File

@@ -86,6 +86,7 @@ class HoistSpillHelper : private LiveRangeEdit::Delegate {
const TargetInstrInfo &TII;
const TargetRegisterInfo &TRI;
const MachineBlockFrequencyInfo &MBFI;
LiveRegMatrix *Matrix;
InsertPointAnalysis IPA;
@@ -129,16 +130,17 @@ class HoistSpillHelper : private LiveRangeEdit::Delegate {
public:
HoistSpillHelper(const Spiller::RequiredAnalyses &Analyses,
MachineFunction &mf, VirtRegMap &vrm)
MachineFunction &mf, VirtRegMap &vrm, LiveRegMatrix *matrix)
: MF(mf), LIS(Analyses.LIS), LSS(Analyses.LSS), MDT(Analyses.MDT),
VRM(vrm), MRI(mf.getRegInfo()), TII(*mf.getSubtarget().getInstrInfo()),
TRI(*mf.getSubtarget().getRegisterInfo()), MBFI(Analyses.MBFI),
IPA(LIS, mf.getNumBlockIDs()) {}
Matrix(matrix), IPA(LIS, mf.getNumBlockIDs()) {}
void addToMergeableSpills(MachineInstr &Spill, int StackSlot,
Register Original);
bool rmFromMergeableSpills(MachineInstr &Spill, int StackSlot);
void hoistAllSpills();
bool LRE_CanEraseVirtReg(Register) override;
void LRE_DidCloneVirtReg(Register, Register) override;
};
@@ -191,7 +193,7 @@ public:
: MF(MF), LIS(Analyses.LIS), LSS(Analyses.LSS), VRM(VRM),
MRI(MF.getRegInfo()), TII(*MF.getSubtarget().getInstrInfo()),
TRI(*MF.getSubtarget().getRegisterInfo()), Matrix(Matrix),
HSpiller(Analyses, MF, VRM), VRAI(VRAI) {}
HSpiller(Analyses, MF, VRM, Matrix), VRAI(VRAI) {}
void spill(LiveRangeEdit &, AllocationOrder *Order = nullptr) override;
ArrayRef<Register> getSpilledRegs() override { return RegsToSpill; }
@@ -1750,6 +1752,17 @@ void HoistSpillHelper::hoistAllSpills() {
}
}
/// Called before a virtual register is erased from LiveIntervals.
/// Forcibly remove the register from LiveRegMatrix before it's deleted,
/// preventing dangling pointers.
bool HoistSpillHelper::LRE_CanEraseVirtReg(Register VirtReg) {
if (Matrix && VRM.hasPhys(VirtReg)) {
const LiveInterval &LI = LIS.getInterval(VirtReg);
Matrix->unassign(LI, /*ClearAllReferencingSegments=*/true);
}
return true; // Allow deletion to proceed
}
/// For VirtReg clone, the \p New register should have the same physreg or
/// stackslot as the \p old register.
void HoistSpillHelper::LRE_DidCloneVirtReg(Register New, Register Old) {

View File

@@ -79,6 +79,19 @@ void LiveIntervalUnion::extract(const LiveInterval &VirtReg,
}
}
void LiveIntervalUnion::clearAllSegmentsReferencing(
const LiveInterval &VirtRegLI) {
++Tag;
// Remove all segments referencing VirtReg.
for (SegmentIter SegPos = Segments.begin(); SegPos.valid();) {
if (SegPos.value()->reg() == VirtRegLI.reg())
SegPos.erase();
else
++SegPos;
}
}
void
LiveIntervalUnion::print(raw_ostream &OS, const TargetRegisterInfo *TRI) const {
if (empty()) {

View File

@@ -12,11 +12,13 @@
#include "llvm/CodeGen/LiveRegMatrix.h"
#include "RegisterCoalescer.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/CodeGen/LiveInterval.h"
#include "llvm/CodeGen/LiveIntervalUnion.h"
#include "llvm/CodeGen/LiveIntervals.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/CodeGen/VirtRegMap.h"
@@ -125,18 +127,25 @@ void LiveRegMatrix::assign(const LiveInterval &VirtReg, MCRegister PhysReg) {
LLVM_DEBUG(dbgs() << '\n');
}
void LiveRegMatrix::unassign(const LiveInterval &VirtReg) {
void LiveRegMatrix::unassign(const LiveInterval &VirtReg,
bool ClearAllReferencingSegments) {
Register PhysReg = VRM->getPhys(VirtReg.reg());
LLVM_DEBUG(dbgs() << "unassigning " << printReg(VirtReg.reg(), TRI)
<< " from " << printReg(PhysReg, TRI) << ':');
VRM->clearVirt(VirtReg.reg());
foreachUnit(TRI, VirtReg, PhysReg,
[&](MCRegUnit Unit, const LiveRange &Range) {
LLVM_DEBUG(dbgs() << ' ' << printRegUnit(Unit, TRI));
Matrix[Unit].extract(VirtReg, Range);
return false;
});
if (!ClearAllReferencingSegments) {
foreachUnit(TRI, VirtReg, PhysReg,
[&](MCRegUnit Unit, const LiveRange &Range) {
LLVM_DEBUG(dbgs() << ' ' << printRegUnit(Unit, TRI));
Matrix[Unit].extract(VirtReg, Range);
return false;
});
} else {
for (MCRegUnit Unit : TRI->regunits(PhysReg)) {
Matrix[Unit].clearAllSegmentsReferencing(VirtReg);
}
}
++NumUnassigned;
LLVM_DEBUG(dbgs() << '\n');
@@ -290,6 +299,35 @@ Register LiveRegMatrix::getOneVReg(unsigned PhysReg) const {
return MCRegister::NoRegister;
}
#ifndef NDEBUG
bool LiveRegMatrix::isValid() const {
// Build set of all valid LiveInterval pointers from LiveIntervals.
DenseSet<const LiveInterval *> ValidIntervals;
for (unsigned RegIdx = 0, NumRegs = VRM->getRegInfo().getNumVirtRegs();
RegIdx < NumRegs; ++RegIdx) {
Register VReg = Register::index2VirtReg(RegIdx);
// Only track assigned registers since unassigned ones won't be in Matrix
if (VRM->hasPhys(VReg) && LIS->hasInterval(VReg))
ValidIntervals.insert(&LIS->getInterval(VReg));
}
// Now scan all LiveIntervalUnions in the matrix and verify each pointer
unsigned NumDanglingPointers = 0;
for (unsigned I = 0, Size = Matrix.size(); I < Size; ++I) {
MCRegUnit Unit = static_cast<MCRegUnit>(I);
for (const LiveInterval *LI : Matrix[Unit]) {
if (!ValidIntervals.contains(LI)) {
++NumDanglingPointers;
dbgs() << "ERROR: LiveInterval pointer is not found in LiveIntervals:\n"
<< " Register Unit: " << printRegUnit(Unit, TRI) << '\n'
<< " LiveInterval pointer: " << LI << '\n';
}
}
}
return NumDanglingPointers == 0;
}
#endif
AnalysisKey LiveRegMatrixAnalysis::Key;
LiveRegMatrix LiveRegMatrixAnalysis::run(MachineFunction &MF,

View File

@@ -155,6 +155,10 @@ void RegAllocBase::allocatePhysRegs() {
void RegAllocBase::postOptimization() {
spiller().postOptimization();
// Verify LiveRegMatrix after spilling (no dangling pointers).
assert(Matrix->isValid() && "LiveRegMatrix validation failed");
for (auto *DeadInst : DeadRemats) {
LIS->RemoveMachineInstrFromMaps(*DeadInst);
DeadInst->eraseFromParent();

View File

@@ -153,11 +153,13 @@ void SIPreAllocateWWMRegs::rewriteRegs(MachineFunction &MF) {
SIMachineFunctionInfo *MFI = MF.getInfo<SIMachineFunctionInfo>();
for (unsigned Reg : RegsToRewrite) {
LIS->removeInterval(Reg);
const Register PhysReg = VRM->getPhys(Reg);
assert(PhysReg != 0);
LiveInterval &LI = LIS->getInterval(Reg);
Matrix->unassign(LI, /*ClearAllReferencingSegments=*/true);
LIS->removeInterval(Reg);
MFI->reserveWWMRegister(PhysReg);
}