//=====-- Rematerializer.cpp - MIR rematerialization support ----*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //==-----------------------------------------------------------------------===// // /// \file /// Implements helpers for target-independent rematerialization at the MIR /// level. // //===----------------------------------------------------------------------===// #include "llvm/CodeGen/Rematerializer.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/CodeGen/LiveIntervals.h" #include "llvm/CodeGen/MachineBasicBlock.h" #include "llvm/CodeGen/MachineOperand.h" #include "llvm/CodeGen/MachineRegisterInfo.h" #include "llvm/CodeGen/Register.h" #include "llvm/CodeGen/TargetRegisterInfo.h" #include "llvm/Support/Debug.h" #include #define DEBUG_TYPE "rematerializer" using namespace llvm; using RegisterIdx = Rematerializer::RegisterIdx; // Pin the vtable to this file. void Rematerializer::Listener::anchor() {} /// Checks whether the value in \p LI at \p UseIdx is identical to \p OVNI (this /// implies it is also live there). When \p LI has sub-ranges, checks that /// all sub-ranges intersecting with \p Mask are also live at \p UseIdx. static bool isIdenticalAtUse(const VNInfo &OVNI, LaneBitmask Mask, SlotIndex UseIdx, const LiveInterval &LI) { if (&OVNI != LI.getVNInfoAt(UseIdx)) return false; if (LI.hasSubRanges()) { // Check that intersecting subranges are live at user. for (const LiveInterval::SubRange &SR : LI.subranges()) { if ((SR.LaneMask & Mask).none()) continue; if (!SR.liveAt(UseIdx)) return false; // Early exit if all used lanes are checked. No need to continue. Mask &= ~SR.LaneMask; if (Mask.none()) break; } } return true; } /// If \p MO is a virtual read register, returns it. Otherwise returns the /// sentinel register. static Register getRegDependency(const MachineOperand &MO) { if (!MO.isReg() || !MO.readsReg()) return Register(); Register Reg = MO.getReg(); if (Reg.isPhysical()) { // By the requirements on trivially rematerializable instructions, a // physical register use is either constant or ignorable. return Register(); } return Reg; } RegisterIdx Rematerializer::rematerializeToRegion(RegisterIdx RootIdx, unsigned UseRegion, DependencyReuseInfo &DRI) { MachineInstr *FirstMI = getReg(RootIdx).getRegionUseBounds(UseRegion, LIS).first; // If there are no users in the region, rematerialize the register at the very // end of the region. MachineBasicBlock::iterator InsertPos = FirstMI ? FirstMI : Regions[UseRegion].second; RegisterIdx NewRegIdx = rematerializeToPos(RootIdx, UseRegion, InsertPos, DRI); transferRegionUsers(RootIdx, NewRegIdx, UseRegion); return NewRegIdx; } RegisterIdx Rematerializer::rematerializeToPos(RegisterIdx RootIdx, unsigned UseRegion, MachineBasicBlock::iterator InsertPos, DependencyReuseInfo &DRI) { assert(!DRI.DependencyMap.contains(RootIdx)); LLVM_DEBUG(dbgs() << "Rematerializing " << printID(RootIdx) << '\n'); SmallVector NewDeps; // Copy all dependencies because recursive rematerialization of dependencies // may invalidate references to the backing vector of registers. SmallVector OldDeps(getReg(RootIdx).Dependencies); for (const Reg::Dependency &Dep : OldDeps) { // Recursively rematerialize required dependencies at the same position as // the root. Registers form a DAG so the recursion is guaranteed to // terminate. auto RematIdx = DRI.DependencyMap.find(Dep.RegIdx); RegisterIdx NewDepRegIdx; if (RematIdx == DRI.DependencyMap.end()) NewDepRegIdx = rematerializeToPos(Dep.RegIdx, UseRegion, InsertPos, DRI); else NewDepRegIdx = RematIdx->second; NewDeps.emplace_back(Dep.MOIdx, NewDepRegIdx); } RegisterIdx NewIdx = rematerializeReg(RootIdx, UseRegion, InsertPos, std::move(NewDeps)); DRI.DependencyMap.insert({RootIdx, NewIdx}); return NewIdx; } void Rematerializer::transferUser(RegisterIdx FromRegIdx, RegisterIdx ToRegIdx, unsigned UserRegion, MachineInstr &UserMI) { transferUserImpl(FromRegIdx, ToRegIdx, UserMI); Regs[FromRegIdx].eraseUser(&UserMI, UserRegion); Regs[ToRegIdx].addUser(&UserMI, UserRegion); deleteRegIfUnused(FromRegIdx); } void Rematerializer::transferRegionUsers(RegisterIdx FromRegIdx, RegisterIdx ToRegIdx, unsigned UseRegion) { auto &FromRegUsers = Regs[FromRegIdx].Uses; auto UsesIt = FromRegUsers.find(UseRegion); if (UsesIt == FromRegUsers.end()) return; const SmallDenseSet &RegionUsers = UsesIt->getSecond(); for (MachineInstr *UserMI : RegionUsers) transferUserImpl(FromRegIdx, ToRegIdx, *UserMI); Regs[ToRegIdx].addUsers(RegionUsers, UseRegion); FromRegUsers.erase(UseRegion); deleteRegIfUnused(FromRegIdx); } void Rematerializer::transferAllUsers(RegisterIdx FromRegIdx, RegisterIdx ToRegIdx) { Reg &FromReg = Regs[FromRegIdx], &ToReg = Regs[ToRegIdx]; for (const auto &[UseRegion, RegionUsers] : FromReg.Uses) { for (MachineInstr *UserMI : RegionUsers) transferUserImpl(FromRegIdx, ToRegIdx, *UserMI); ToReg.addUsers(RegionUsers, UseRegion); } FromReg.Uses.clear(); deleteRegIfUnused(FromRegIdx); } void Rematerializer::transferUserImpl(RegisterIdx FromRegIdx, RegisterIdx ToRegIdx, MachineInstr &UserMI) { assert(FromRegIdx != ToRegIdx && "identical registers"); assert(getOriginOrSelf(FromRegIdx) == getOriginOrSelf(ToRegIdx) && "unrelated registers"); LLVM_DEBUG(dbgs() << "User transfer from " << printID(FromRegIdx) << " to " << printID(ToRegIdx) << ": " << printUser(&UserMI) << '\n'); UserMI.substituteRegister(getReg(FromRegIdx).getDefReg(), getReg(ToRegIdx).getDefReg(), 0, TRI); LISUpdates.insert(FromRegIdx); LISUpdates.insert(ToRegIdx); // If the user is rematerializable, we must change its dependency to the // new register. if (RegisterIdx UserRegIdx = getDefRegIdx(UserMI); UserRegIdx != NoReg) { // Look for the user's dependency that matches the register. for (Reg::Dependency &Dep : Regs[UserRegIdx].Dependencies) { if (Dep.RegIdx == FromRegIdx) { Dep.RegIdx = ToRegIdx; return; } } llvm_unreachable("broken dependency"); } } void Rematerializer::updateLiveIntervals() { DenseSet SeenUnrematRegs; for (RegisterIdx RegIdx : LISUpdates) { const Reg &UpdateReg = getReg(RegIdx); assert(UpdateReg.isAlive() && "dead register"); Register DefReg = UpdateReg.getDefReg(); if (LIS.hasInterval(DefReg)) LIS.removeInterval(DefReg); // Rematerializable registers have a single definition by construction so // re-creating their interval cannot yield a live interval with multiple // connected components. LIS.createAndComputeVirtRegInterval(DefReg); LLVM_DEBUG({ dbgs() << "Re-computed interval for " << printID(RegIdx) << ": "; LIS.getInterval(DefReg).print(dbgs()); dbgs() << '\n' << printRegUsers(RegIdx); }); // Update intervals for unrematerializable operands. for (unsigned MOIdx : getUnrematableOprds(RegIdx)) { Register UnrematReg = UpdateReg.DefMI->getOperand(MOIdx).getReg(); if (!SeenUnrematRegs.insert(UnrematReg).second) continue; LIS.removeInterval(UnrematReg); bool NeedSplit = false; // Unrematerializable registers may end up with multiple connected // components in their live interval after it is re-created. It needs to // be split in such cases. We don't track unrematerializable registers by // their actual register index (just by operand index) so we do not need // to update any state in the rematerializer. LiveInterval &LI = LIS.createAndComputeVirtRegInterval(UnrematReg, NeedSplit); if (NeedSplit) { SmallVector SplitLIs; LIS.splitSeparateComponents(LI, SplitLIs); } LLVM_DEBUG( dbgs() << " Re-computed interval for register " << printReg(UnrematReg, &TRI, UpdateReg.DefMI->getOperand(MOIdx).getSubReg(), &MRI) << '\n'); } } LISUpdates.clear(); } bool Rematerializer::isMOIdenticalAtUses(MachineOperand &MO, ArrayRef Uses) const { if (Uses.empty()) return true; Register Reg = MO.getReg(); unsigned SubIdx = MO.getSubReg(); LaneBitmask Mask = SubIdx ? TRI.getSubRegIndexLaneMask(SubIdx) : MRI.getMaxLaneMaskForVReg(Reg); const LiveInterval &LI = LIS.getInterval(Reg); const VNInfo *DefVN = LI.getVNInfoAt(LIS.getInstructionIndex(*MO.getParent()).getRegSlot(true)); for (SlotIndex Use : Uses) { if (!isIdenticalAtUse(*DefVN, Mask, Use, LI)) return false; } return true; } RegisterIdx Rematerializer::findRematInRegion(RegisterIdx RegIdx, unsigned Region, SlotIndex Before) const { auto It = Rematerializations.find(getOriginOrSelf(RegIdx)); if (It == Rematerializations.end()) return NoReg; const RematsOf &Remats = It->getSecond(); SlotIndex BestSlot; RegisterIdx BestRegIdx = NoReg; for (RegisterIdx RematRegIdx : Remats) { const Reg &RematReg = getReg(RematRegIdx); if (RematReg.DefRegion != Region || RematReg.Uses.empty()) continue; SlotIndex RematRegSlot = LIS.getInstructionIndex(*RematReg.DefMI).getRegSlot(); if (RematRegSlot < Before && (BestRegIdx == NoReg || RematRegSlot > BestSlot)) { BestSlot = RematRegSlot; BestRegIdx = RematRegIdx; } } return BestRegIdx; } void Rematerializer::deleteRegIfUnused(RegisterIdx RootIdx) { if (!getReg(RootIdx).Uses.empty()) return; // Traverse the root's dependency DAG depth-first to find the set of registers // we can delete and a legal order to delete them in. SmallVector DepDAG{RootIdx}; SmallSetVector DeleteOrder; DeleteOrder.insert(RootIdx); do { // A deleted register's dependencies may be deletable too. const Reg &DeleteReg = getReg(DepDAG.pop_back_val()); for (const Reg::Dependency &Dep : DeleteReg.Dependencies) { // All dependencies loose a user (the deleted register). Reg &DepReg = Regs[Dep.RegIdx]; DepReg.eraseUser(DeleteReg.DefMI, DeleteReg.DefRegion); if (DepReg.Uses.empty()) { DeleteOrder.insert(Dep.RegIdx); DepDAG.push_back(Dep.RegIdx); } } } while (!DepDAG.empty()); for (RegisterIdx RegIdx : reverse(DeleteOrder)) { Reg &DeleteReg = Regs[RegIdx]; // It is possible that the defined register we are deleting doesn't have an // interval yet if the LIS hasn't been updated since it was created. Register DefReg = DeleteReg.getDefReg(); if (LIS.hasInterval(DefReg)) LIS.removeInterval(DefReg); LISUpdates.erase(RegIdx); deleteReg(RegIdx); if (isRematerializedRegister(RegIdx)) { // Delete rematerialized register from its origin's rematerializations. const RegisterIdx OriginIdx = getOriginOf(RegIdx); RematsOf &OriginRemats = Rematerializations.at(OriginIdx); assert(OriginRemats.contains(RegIdx) && "broken remat<->origin link"); OriginRemats.erase(RegIdx); if (OriginRemats.empty()) Rematerializations.erase(OriginIdx); } LLVM_DEBUG(dbgs() << "** Deleted " << printID(RegIdx) << "\n"); } } void Rematerializer::deleteReg(RegisterIdx RegIdx) { noteRegDeleted(RegIdx); Reg &DeleteReg = Regs[RegIdx]; assert(DeleteReg.DefMI && "register was already deleted"); // It is not possible for the deleted instruction to be the upper region // boundary since we don't ever consider them rematerializable. MachineBasicBlock::iterator &RegionBegin = Regions[DeleteReg.DefRegion].first; if (RegionBegin == DeleteReg.DefMI) RegionBegin = std::next(MachineBasicBlock::iterator(DeleteReg.DefMI)); LIS.RemoveMachineInstrFromMaps(*DeleteReg.DefMI); DeleteReg.DefMI->eraseFromParent(); DeleteReg.DefMI = nullptr; } Rematerializer::Rematerializer(MachineFunction &MF, SmallVectorImpl &Regions, LiveIntervals &LIS) : Regions(Regions), MRI(MF.getRegInfo()), LIS(LIS), TII(*MF.getSubtarget().getInstrInfo()), TRI(TII.getRegisterInfo()) { #ifdef EXPENSIVE_CHECKS // Check that regions are valid. DenseSet SeenMIs; for (const auto &[RegionBegin, RegionEnd] : Regions) { assert(RegionBegin != RegionEnd && "empty region"); for (auto MI = RegionBegin; MI != RegionEnd; ++MI) { bool IsNewMI = SeenMIs.insert(&*MI).second; assert(IsNewMI && "overlapping regions"); assert(!MI->isTerminator() && "terminator in region"); } if (RegionEnd != RegionBegin->getParent()->end()) { bool IsNewMI = SeenMIs.insert(&*RegionEnd).second; assert(IsNewMI && "overlapping regions (upper bound)"); } } #endif } bool Rematerializer::analyze() { Regs.clear(); UnrematableOprds.clear(); Origins.clear(); Rematerializations.clear(); RegionMBB.clear(); RegToIdx.clear(); LISUpdates.clear(); if (Regions.empty()) return false; /// Maps all MIs to their parent region. Region terminators are considered /// part of the region they terminate. DenseMap MIRegion; // Initialize MI to containing region mapping. RegionMBB.reserve(Regions.size()); for (unsigned I = 0, E = Regions.size(); I < E; ++I) { RegionBoundaries Region = Regions[I]; assert(Region.first != Region.second && "empty cannot be region"); for (auto MI = Region.first; MI != Region.second; ++MI) { assert(!MIRegion.contains(&*MI) && "regions should not intersect"); MIRegion.insert({&*MI, I}); } MachineBasicBlock &MBB = *Region.first->getParent(); RegionMBB.push_back(&MBB); // A terminator instruction is considered part of the region it terminates. if (Region.second != MBB.end()) { MachineInstr *RegionTerm = &*Region.second; assert(!MIRegion.contains(RegionTerm) && "regions should not intersect"); MIRegion.insert({RegionTerm, I}); } } const unsigned NumVirtRegs = MRI.getNumVirtRegs(); BitVector SeenRegs(NumVirtRegs); for (unsigned I = 0, E = NumVirtRegs; I != E; ++I) { if (!SeenRegs[I]) addRegIfRematerializable(I, MIRegion, SeenRegs); } assert(Regs.size() == UnrematableOprds.size()); LLVM_DEBUG({ for (RegisterIdx I = 0, E = getNumRegs(); I < E; ++I) dbgs() << printDependencyDAG(I) << '\n'; }); return !Regs.empty(); } void Rematerializer::addRegIfRematerializable( unsigned VirtRegIdx, const DenseMap &MIRegion, BitVector &SeenRegs) { assert(!SeenRegs[VirtRegIdx] && "register already seen"); Register DefReg = Register::index2VirtReg(VirtRegIdx); SeenRegs.set(VirtRegIdx); MachineOperand *MO = MRI.getOneDef(DefReg); if (!MO) return; MachineInstr &DefMI = *MO->getParent(); if (!isMIRematerializable(DefMI)) return; auto DefRegion = MIRegion.find(&DefMI); if (DefRegion == MIRegion.end()) return; Reg RematReg; RematReg.DefMI = &DefMI; RematReg.DefRegion = DefRegion->second; unsigned SubIdx = DefMI.getOperand(0).getSubReg(); RematReg.Mask = SubIdx ? TRI.getSubRegIndexLaneMask(SubIdx) : MRI.getMaxLaneMaskForVReg(DefReg); // Collect the candidate's direct users, both rematerializable and // unrematerializable. MIs outside provided regions cannot be tracked so the // registers they use are not safely rematerializable. for (MachineInstr &UseMI : MRI.use_nodbg_instructions(DefReg)) { if (auto UseRegion = MIRegion.find(&UseMI); UseRegion != MIRegion.end()) RematReg.addUser(&UseMI, UseRegion->second); else return; } if (RematReg.Uses.empty()) return; // Collect the candidate's dependencies. If the same register is used // multiple times we just need to consider it once. SmallDenseSet AllDepRegs; SmallVector UnrematDeps; for (const auto &[MOIdx, MO] : enumerate(RematReg.DefMI->operands())) { Register DepReg = getRegDependency(MO); if (!DepReg || !AllDepRegs.insert(DepReg).second) continue; unsigned DepRegIdx = DepReg.virtRegIndex(); if (!SeenRegs[DepRegIdx]) addRegIfRematerializable(DepRegIdx, MIRegion, SeenRegs); if (auto DepIt = RegToIdx.find(DepReg); DepIt != RegToIdx.end()) RematReg.Dependencies.push_back(Reg::Dependency(MOIdx, DepIt->second)); else UnrematDeps.push_back(MOIdx); } // The register is rematerializable. RegToIdx.insert({DefReg, Regs.size()}); Regs.push_back(RematReg); UnrematableOprds.push_back(UnrematDeps); } bool Rematerializer::isMIRematerializable(const MachineInstr &MI) const { if (!TII.isReMaterializable(MI)) return false; assert(MI.getOperand(0).getReg().isVirtual() && "should be virtual"); assert(MRI.hasOneDef(MI.getOperand(0).getReg()) && "should have single def"); for (const MachineOperand &MO : MI.all_uses()) { // We can't remat physreg uses, unless it is a constant or an ignorable // use (e.g. implicit exec use on VALU instructions) if (MO.getReg().isPhysical()) { if (MRI.isConstantPhysReg(MO.getReg()) || TII.isIgnorableUse(MO)) continue; return false; } } return true; } RegisterIdx Rematerializer::getDefRegIdx(const MachineInstr &MI) const { if (!MI.getNumOperands() || !MI.getOperand(0).isReg() || MI.getOperand(0).readsReg()) return NoReg; Register Reg = MI.getOperand(0).getReg(); auto UserRegIt = RegToIdx.find(Reg); if (UserRegIt == RegToIdx.end()) return NoReg; return UserRegIt->second; } RegisterIdx Rematerializer::rematerializeReg( RegisterIdx RegIdx, unsigned UseRegion, MachineBasicBlock::iterator InsertPos, SmallVectorImpl &&Dependencies) { RegisterIdx NewRegIdx = Regs.size(); Reg &NewReg = Regs.emplace_back(); Reg &FromReg = Regs[RegIdx]; NewReg.Mask = FromReg.Mask; NewReg.DefRegion = UseRegion; NewReg.Dependencies = std::move(Dependencies); // Track rematerialization link between registers. Origins are always // registers that existed originally, and rematerializations are always // attached to them. const RegisterIdx OriginIdx = getOriginOrSelf(RegIdx); Origins.push_back(OriginIdx); Rematerializations[OriginIdx].insert(NewRegIdx); // Use the TII to rematerialize the defining instruction with a new defined // register. Register NewDefReg = MRI.cloneVirtualRegister(FromReg.getDefReg()); TII.reMaterialize(*RegionMBB[UseRegion], InsertPos, NewDefReg, 0, *FromReg.DefMI); NewReg.DefMI = &*std::prev(InsertPos); RegToIdx.insert({NewDefReg, NewRegIdx}); postRematerialization(RegIdx, NewRegIdx, InsertPos); noteRegCreated(NewRegIdx); LLVM_DEBUG(dbgs() << "** Rematerialized " << printID(RegIdx) << " as " << printRematReg(NewRegIdx) << '\n'); return NewRegIdx; } void Rematerializer::recreateReg( RegisterIdx RegIdx, unsigned DefRegion, MachineBasicBlock::iterator InsertPos, Register DefReg, SmallVectorImpl &&Dependencies) { assert(RegToIdx.contains(DefReg) && "unknown defined register"); assert(RegToIdx.at(DefReg) == RegIdx && "incorrect defined register"); assert(!getReg(RegIdx).DefMI && "register is still alive"); Reg &OriginReg = Regs[RegIdx]; OriginReg.DefRegion = DefRegion; OriginReg.Dependencies = std::move(Dependencies); // Re-establish the link between origin and rematerialization if necessary. const bool RecreateOriginalReg = isOriginalRegister(RegIdx); if (!RecreateOriginalReg) Rematerializations[getOriginOf(RegIdx)].insert(RegIdx); // Rematerialize from one of the existing rematerializations or from the // origin. We expect at least one to exist, otherwise it would mean the value // held by the original register is no longer available anywhere in the MF. RegisterIdx ModelRegIdx; if (RecreateOriginalReg) { assert(Rematerializations.contains(RegIdx) && "expected remats"); ModelRegIdx = *Rematerializations.at(RegIdx).begin(); } else { assert(getReg(getOriginOf(RegIdx)).DefMI && "expected alive origin"); ModelRegIdx = getOriginOf(RegIdx); } const MachineInstr &ModelDefMI = *getReg(ModelRegIdx).DefMI; TII.reMaterialize(*RegionMBB[DefRegion], InsertPos, DefReg, 0, ModelDefMI); OriginReg.DefMI = &*std::prev(InsertPos); postRematerialization(ModelRegIdx, RegIdx, InsertPos); LLVM_DEBUG(dbgs() << "** Recreated " << printID(RegIdx) << " as " << printRematReg(RegIdx) << '\n'); } void Rematerializer::postRematerialization( RegisterIdx ModelRegIdx, RegisterIdx RematRegIdx, MachineBasicBlock::iterator InsertPos) { // The start of the new register's region may have changed. Reg &ModelReg = Regs[ModelRegIdx], &RematReg = Regs[RematRegIdx]; LIS.InsertMachineInstrInMaps(*RematReg.DefMI); MachineBasicBlock::iterator &RegionBegin = Regions[RematReg.DefRegion].first; if (RegionBegin == std::next(MachineBasicBlock::iterator(RematReg.DefMI))) RegionBegin = RematReg.DefMI; // Replace dependencies as needed in the rematerialized MI. All dependencies // of the latter gain a new user. auto ZipedDeps = zip_equal(ModelReg.Dependencies, RematReg.Dependencies); for (const auto &[OldDep, NewDep] : ZipedDeps) { assert(OldDep.MOIdx == NewDep.MOIdx && "operand mismatch"); LLVM_DEBUG(dbgs() << " Operand #" << OldDep.MOIdx << ": " << printID(OldDep.RegIdx) << " -> " << printID(NewDep.RegIdx) << '\n'); Reg &NewDepReg = Regs[NewDep.RegIdx]; if (OldDep.RegIdx != NewDep.RegIdx) { Register OldDefReg = ModelReg.DefMI->getOperand(OldDep.MOIdx).getReg(); RematReg.DefMI->substituteRegister(OldDefReg, NewDepReg.getDefReg(), 0, TRI); LISUpdates.insert(OldDep.RegIdx); } NewDepReg.addUser(RematReg.DefMI, RematReg.DefRegion); LISUpdates.insert(NewDep.RegIdx); } } std::pair Rematerializer::Reg::getRegionUseBounds(unsigned UseRegion, const LiveIntervals &LIS) const { auto It = Uses.find(UseRegion); if (It == Uses.end()) return {nullptr, nullptr}; const RegionUsers &RegionUsers = It->getSecond(); assert(!RegionUsers.empty() && "empty userset in region"); auto User = RegionUsers.begin(), UserEnd = RegionUsers.end(); MachineInstr *FirstMI = *User, *LastMI = FirstMI; SlotIndex FirstIndex = LIS.getInstructionIndex(*FirstMI), LastIndex = FirstIndex; while (++User != UserEnd) { SlotIndex UserIndex = LIS.getInstructionIndex(**User); if (UserIndex < FirstIndex) { FirstIndex = UserIndex; FirstMI = *User; } else if (UserIndex > LastIndex) { LastIndex = UserIndex; LastMI = *User; } } return {FirstMI, LastMI}; } void Rematerializer::Reg::addUser(MachineInstr *MI, unsigned Region) { Uses[Region].insert(MI); } void Rematerializer::Reg::addUsers(const RegionUsers &NewUsers, unsigned Region) { Uses[Region].insert_range(NewUsers); } void Rematerializer::Reg::eraseUser(MachineInstr *MI, unsigned Region) { RegionUsers &RUsers = Uses.at(Region); assert(RUsers.contains(MI) && "user not in region"); if (RUsers.size() == 1) Uses.erase(Region); else RUsers.erase(MI); } Printable Rematerializer::printDependencyDAG(RegisterIdx RootIdx) const { return Printable([&, RootIdx](raw_ostream &OS) { DenseMap RegDepths; std::function WalkTree = [&](RegisterIdx RegIdx, unsigned Depth) -> void { unsigned MaxDepth = std::max(RegDepths.lookup_or(RegIdx, Depth), Depth); RegDepths.emplace_or_assign(RegIdx, MaxDepth); for (const Reg::Dependency &Dep : getReg(RegIdx).Dependencies) WalkTree(Dep.RegIdx, Depth + 1); }; WalkTree(RootIdx, 0); // Sort in decreasing depth order to print root at the bottom. SmallVector> Regs(RegDepths.begin(), RegDepths.end()); sort(Regs, [](const auto &LHS, const auto &RHS) { return LHS.second > RHS.second; }); OS << printID(RootIdx) << " has " << Regs.size() - 1 << " dependencies\n"; for (const auto &[RegIdx, Depth] : Regs) { OS << indent(Depth, 2) << (Depth ? '|' : '*') << ' ' << printRematReg(RegIdx, /*SkipRegions=*/Depth) << '\n'; } OS << printRegUsers(RootIdx); }); } Printable Rematerializer::printID(RegisterIdx RegIdx) const { return Printable([&, RegIdx](raw_ostream &OS) { const Reg &PrintReg = getReg(RegIdx); OS << '(' << RegIdx << '/'; if (!PrintReg.DefMI) { OS << ""; } else { OS << printReg(PrintReg.getDefReg(), &TRI, PrintReg.DefMI->getOperand(0).getSubReg(), &MRI); } OS << ")[" << PrintReg.DefRegion << "]"; }); } Printable Rematerializer::printRematReg(RegisterIdx RegIdx, bool SkipRegions) const { return Printable([&, RegIdx, SkipRegions](raw_ostream &OS) { const Reg &PrintReg = getReg(RegIdx); if (!SkipRegions) { OS << printID(RegIdx) << " [" << PrintReg.DefRegion; if (!PrintReg.Uses.empty()) { assert(PrintReg.DefMI && "dead register cannot have uses"); const LiveInterval &LI = LIS.getInterval(PrintReg.getDefReg()); // First display all regions in which the register is live-through and // not used. bool First = true; for (const auto [I, Bounds] : enumerate(Regions)) { if (Bounds.first == Bounds.second) continue; if (!PrintReg.Uses.contains(I) && LI.liveAt(LIS.getInstructionIndex(*Bounds.first)) && LI.liveAt(LIS.getInstructionIndex(*std::prev(Bounds.second)) .getRegSlot())) { OS << (First ? " - " : ",") << I; First = false; } } OS << (First ? " --> " : " -> "); // Then display regions in which the register is used. auto It = PrintReg.Uses.begin(); OS << It->first; while (++It != PrintReg.Uses.end()) OS << "," << It->first; } OS << "] "; } OS << printID(RegIdx) << ' '; PrintReg.DefMI->print(OS, /*IsStandalone=*/true, /*SkipOpers=*/false, /*SkipDebugLoc=*/false, /*AddNewLine=*/false); OS << " @ "; LIS.getInstructionIndex(*PrintReg.DefMI).print(OS); }); } Printable Rematerializer::printRegUsers(RegisterIdx RegIdx) const { return Printable([&, RegIdx](raw_ostream &OS) { for (const auto &[UseRegion, Users] : getReg(RegIdx).Uses) { for (MachineInstr *MI : Users) OS << " User " << printUser(MI, UseRegion) << '\n'; } }); } Printable Rematerializer::printUser(const MachineInstr *MI, std::optional UseRegion) const { return Printable([&, MI, UseRegion](raw_ostream &OS) { RegisterIdx RegIdx = getDefRegIdx(*MI); if (RegIdx != NoReg) { OS << printID(RegIdx); } else { OS << "(-/-)["; if (UseRegion) OS << *UseRegion; else OS << '?'; OS << ']'; } OS << ' '; MI->print(OS, /*IsStandalone=*/true, /*SkipOpers=*/false, /*SkipDebugLoc=*/false, /*AddNewLine=*/false); OS << " @ "; LIS.getInstructionIndex(*MI).print(OS); }); } Rollbacker::RollbackInfo::RollbackInfo(const Rematerializer &Remater, RegisterIdx RegIdx) { const Rematerializer::Reg &Reg = Remater.getReg(RegIdx); DefReg = Reg.getDefReg(); DefRegion = Reg.DefRegion; Dependencies = Reg.Dependencies; InsertPos = std::next(Reg.DefMI->getIterator()); if (InsertPos != Reg.DefMI->getParent()->end()) NextRegIdx = Remater.getDefRegIdx(*InsertPos); else NextRegIdx = Rematerializer::NoReg; } void Rollbacker::rematerializerNoteRegCreated(const Rematerializer &Remater, RegisterIdx RegIdx) { if (RollingBack) return; Rematerializations[Remater.getOriginOf(RegIdx)].insert(RegIdx); } void Rollbacker::rematerializerNoteRegDeleted(const Rematerializer &Remater, RegisterIdx RegIdx) { if (RollingBack || Remater.isRematerializedRegister(RegIdx)) return; DeadRegs.try_emplace(RegIdx, Remater, RegIdx); } void Rollbacker::rollback(Rematerializer &Remater) { RollingBack = true; // Re-create deleted registers. for (auto &[RegIdx, Info] : DeadRegs) { assert(!Remater.getReg(RegIdx).isAlive() && "register should be dead"); // The MI that was originally just after the MI defining the register we // are trying to re-create may have been deleted. In such cases, we can // re-create at that MI's own insert position (and apply the same logic // recursively). MachineBasicBlock::iterator InsertPos = Info.InsertPos; RegisterIdx NextRegIdx = Info.NextRegIdx; while (NextRegIdx != Rematerializer::NoReg) { const auto *NextRegRollback = DeadRegs.find(NextRegIdx); if (NextRegRollback == DeadRegs.end()) break; InsertPos = NextRegRollback->second.InsertPos; NextRegIdx = NextRegRollback->second.NextRegIdx; } Remater.recreateReg(RegIdx, Info.DefRegion, InsertPos, Info.DefReg, std::move(Info.Dependencies)); } // Rollback rematerializations. for (const auto &[RegIdx, RematsOf] : Rematerializations) { for (RegisterIdx RematRegIdx : RematsOf) { // It is possible that rematerializations were deleted. Their users would // have been transfered to some other rematerialization so we can safely // ignore them. Original registers that were deleted were just re-created // so we do not need to check for that. if (Remater.getReg(RematRegIdx).isAlive()) Remater.transferAllUsers(RematRegIdx, RegIdx); } } Remater.updateLiveIntervals(); DeadRegs.clear(); Rematerializations.clear(); RollingBack = false; }