//===- MachineUniformityAnalysis.cpp --------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "llvm/CodeGen/MachineUniformityAnalysis.h" #include "llvm/ADT/GenericUniformityImpl.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/CodeGen/MachineCycleAnalysis.h" #include "llvm/CodeGen/MachineDominators.h" #include "llvm/CodeGen/MachineRegisterInfo.h" #include "llvm/CodeGen/MachineSSAContext.h" #include "llvm/CodeGen/TargetInstrInfo.h" #include "llvm/InitializePasses.h" using namespace llvm; template <> bool llvm::GenericUniformityAnalysisImpl::hasDivergentDefs( const MachineInstr &I) const { for (auto &Op : I.all_defs()) { if (isDivergent(Op.getReg())) return true; } return false; } template <> bool llvm::GenericUniformityAnalysisImpl::markDefsDivergent( const MachineInstr &Instr) { bool InsertedDivergent = false; const auto &MRI = F.getRegInfo(); const auto &RBI = *F.getSubtarget().getRegBankInfo(); const auto &TRI = *MRI.getTargetRegisterInfo(); for (auto &Op : Instr.all_defs()) { if (!Op.getReg().isVirtual()) continue; assert(!Op.getSubReg()); if (TRI.isUniformReg(MRI, RBI, Op.getReg())) continue; InsertedDivergent |= markDivergent(Op.getReg()); } return InsertedDivergent; } template <> void llvm::GenericUniformityAnalysisImpl::initialize() { // Pre-populate UniformValues with all register defs. Physical register defs // are included because they are never analyzed for divergence (initialize // and markDefsDivergent skip them), so they must be in UniformValues to // avoid being falsely reported as divergent. for (const MachineBasicBlock &BB : F) { for (const MachineInstr &MI : BB.instrs()) { for (const MachineOperand &Op : MI.all_defs()) { Register Reg = Op.getReg(); if (Reg) UniformValues.insert(Reg); } } } const auto &InstrInfo = *F.getSubtarget().getInstrInfo(); for (const MachineBasicBlock &MBB : F) { for (const MachineInstr &MI : MBB) { ValueUniformity VU = InstrInfo.getValueUniformity(MI); switch (VU) { case ValueUniformity::AlwaysUniform: addUniformOverride(MI); break; case ValueUniformity::NeverUniform: markDivergent(MI); break; case ValueUniformity::Custom: break; case ValueUniformity::Default: break; } } } } template <> void llvm::GenericUniformityAnalysisImpl::pushUsers( Register Reg) { assert(isDivergent(Reg)); const auto &RegInfo = F.getRegInfo(); for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) { markDivergent(UserInstr); } } template <> void llvm::GenericUniformityAnalysisImpl::pushUsers( const MachineInstr &Instr) { assert(!isAlwaysUniform(Instr)); if (Instr.isTerminator()) return; for (const MachineOperand &Op : Instr.all_defs()) { auto Reg = Op.getReg(); if (isDivergent(Reg)) pushUsers(Reg); } } template <> bool llvm::GenericUniformityAnalysisImpl::usesValueFromCycle( const MachineInstr &I, const MachineCycle &DefCycle) const { assert(!isAlwaysUniform(I)); for (auto &Op : I.operands()) { if (!Op.isReg() || !Op.readsReg()) continue; auto Reg = Op.getReg(); // FIXME: Physical registers need to be properly checked instead of always // returning true if (Reg.isPhysical()) return true; auto *Def = F.getRegInfo().getVRegDef(Reg); if (DefCycle.contains(Def->getParent())) return true; } return false; } template <> void llvm::GenericUniformityAnalysisImpl:: propagateTemporalDivergence(const MachineInstr &I, const MachineCycle &DefCycle) { const auto &RegInfo = F.getRegInfo(); for (auto &Op : I.all_defs()) { if (!Op.getReg().isVirtual()) continue; auto Reg = Op.getReg(); for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) { if (DefCycle.contains(UserInstr.getParent())) continue; markDivergent(UserInstr); recordTemporalDivergence(Reg, &UserInstr, &DefCycle); } } } template <> bool llvm::GenericUniformityAnalysisImpl::isDivergentUse( const MachineOperand &U) const { if (!U.isReg()) return false; auto Reg = U.getReg(); if (isDivergent(Reg)) return true; const auto &RegInfo = F.getRegInfo(); auto *Def = RegInfo.getOneDef(Reg); if (!Def) return true; auto *DefInstr = Def->getParent(); auto *UseInstr = U.getParent(); return isTemporalDivergent(*UseInstr->getParent(), *DefInstr); } template <> bool GenericUniformityAnalysisImpl::isCustomUniform( const MachineInstr &MI) const { llvm_unreachable("no MIR instructions use Custom uniformity yet"); } // This ensures explicit instantiation of // GenericUniformityAnalysisImpl::ImplDeleter::operator() template class llvm::GenericUniformityInfo; template struct llvm::GenericUniformityAnalysisImplDeleter< llvm::GenericUniformityAnalysisImpl>; MachineUniformityInfo llvm::computeMachineUniformityInfo( MachineFunction &F, const MachineCycleInfo &CI, const MachineDominatorTree &DT, bool HasBranchDivergence) { assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!"); MachineUniformityInfo UI(DT, CI); if (HasBranchDivergence) UI.compute(); return UI; } namespace { class MachineUniformityInfoPrinterPass : public MachineFunctionPass { public: static char ID; MachineUniformityInfoPrinterPass(); bool runOnMachineFunction(MachineFunction &F) override; void getAnalysisUsage(AnalysisUsage &AU) const override; }; } // namespace AnalysisKey MachineUniformityAnalysis::Key; MachineUniformityAnalysis::Result MachineUniformityAnalysis::run(MachineFunction &MF, MachineFunctionAnalysisManager &MFAM) { MachineDominatorTree &DT = MFAM.getResult(MF); MachineCycleInfo &CI = MFAM.getResult(MF); FunctionAnalysisManager &FAM = MFAM.getResult(MF) .getManager(); Function &F = MF.getFunction(); TargetTransformInfo &TTI = FAM.getResult(F); return computeMachineUniformityInfo(MF, CI, DT, TTI.hasBranchDivergence(&F)); } PreservedAnalyses MachineUniformityPrinterPass::run(MachineFunction &MF, MachineFunctionAnalysisManager &MFAM) { MachineUniformityInfo &MUI = MFAM.getResult(MF); OS << "MachineUniformityInfo for function: "; MF.getFunction().printAsOperand(OS, /*PrintType=*/false); OS << '\n'; MUI.print(OS); return PreservedAnalyses::all(); } char MachineUniformityAnalysisPass::ID = 0; MachineUniformityAnalysisPass::MachineUniformityAnalysisPass() : MachineFunctionPass(ID) {} INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity", "Machine Uniformity Info Analysis", false, true) INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass) INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity", "Machine Uniformity Info Analysis", false, true) void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); AU.addRequiredTransitive(); AU.addRequired(); MachineFunctionPass::getAnalysisUsage(AU); } bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) { MachineDominatorTree &DT = getAnalysis().getDomTree(); MachineCycleInfo &CI = getAnalysis().getCycleInfo(); // FIXME: Query TTI::hasBranchDivergence. -run-pass seems to end up with a // default NoTTI UI = computeMachineUniformityInfo(MF, CI, DT, true); return false; } void MachineUniformityAnalysisPass::print(raw_ostream &OS, const Module *) const { OS << "MachineUniformityInfo for function: "; UI.getFunction().getFunction().printAsOperand(OS, /*PrintType=*/false); OS << '\n'; UI.print(OS); } char MachineUniformityInfoPrinterPass::ID = 0; MachineUniformityInfoPrinterPass::MachineUniformityInfoPrinterPass() : MachineFunctionPass(ID) {} INITIALIZE_PASS_BEGIN(MachineUniformityInfoPrinterPass, "print-machine-uniformity", "Print Machine Uniformity Info Analysis", true, true) INITIALIZE_PASS_DEPENDENCY(MachineUniformityAnalysisPass) INITIALIZE_PASS_END(MachineUniformityInfoPrinterPass, "print-machine-uniformity", "Print Machine Uniformity Info Analysis", true, true) void MachineUniformityInfoPrinterPass::getAnalysisUsage( AnalysisUsage &AU) const { AU.setPreservesAll(); AU.addRequired(); MachineFunctionPass::getAnalysisUsage(AU); } bool MachineUniformityInfoPrinterPass::runOnMachineFunction( MachineFunction &F) { MachineUniformityAnalysisPass &UI = getAnalysis(); UI.print(errs()); return false; }