//===- UniformityAnalysis.cpp ---------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "llvm/Analysis/UniformityAnalysis.h" #include "llvm/ADT/GenericUniformityImpl.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/Analysis/CycleAnalysis.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" #include "llvm/InitializePasses.h" using namespace llvm; template <> bool llvm::GenericUniformityAnalysisImpl::hasDivergentDefs( const Instruction &I) const { return isDivergent((const Value *)&I); } template <> bool llvm::GenericUniformityAnalysisImpl::markDefsDivergent( const Instruction &Instr) { return markDivergent(cast(&Instr)); } template <> void llvm::GenericUniformityAnalysisImpl::pushUsers( const Value *V) { for (const auto *User : V->users()) { if (const auto *UserInstr = dyn_cast(User)) { markDivergent(*UserInstr); } } } template <> void llvm::GenericUniformityAnalysisImpl::pushUsers( const Instruction &Instr) { assert(!isAlwaysUniform(Instr)); if (Instr.isTerminator()) return; pushUsers(cast(&Instr)); } template <> bool llvm::GenericUniformityAnalysisImpl::printDivergentArgs( raw_ostream &OS) const { bool HaveDivergentArgs = false; for (const auto &Arg : F.args()) { if (isDivergent(&Arg)) { if (!HaveDivergentArgs) { OS << "DIVERGENT ARGUMENTS:\n"; HaveDivergentArgs = true; } OS << " DIVERGENT: " << Context.print(&Arg) << '\n'; } } return HaveDivergentArgs; } template <> void llvm::GenericUniformityAnalysisImpl::initialize() { // Pre-populate UniformValues with uniform values, then seed divergence. // NeverUniform values are not inserted -- they are divergent by definition // and will be reported as such by isDivergent() (not in UniformValues). SmallVector DivergentArgs; for (auto &Arg : F.args()) { if (TTI->getValueUniformity(&Arg) == ValueUniformity::NeverUniform) DivergentArgs.push_back(&Arg); else UniformValues.insert(&Arg); } for (auto &I : instructions(F)) { ValueUniformity IU = TTI->getValueUniformity(&I); switch (IU) { case ValueUniformity::AlwaysUniform: UniformValues.insert(&I); addUniformOverride(I); continue; case ValueUniformity::NeverUniform: // Skip inserting -- divergent by definition. Add to Worklist directly // so compute() propagates divergence to users. if (I.isTerminator()) DivergentTermBlocks.insert(I.getParent()); Worklist.push_back(&I); continue; case ValueUniformity::Custom: UniformValues.insert(&I); addCustomUniformityCandidate(&I); continue; case ValueUniformity::Default: UniformValues.insert(&I); break; } } // Arguments are not instructions and cannot go on the Worklist, so we // propagate their divergence to users explicitly here. This must happen // after all instructions are in UniformValues so markDivergent (called // inside pushUsers) can successfully erase user instructions from the set. for (const Value *Arg : DivergentArgs) pushUsers(Arg); } template <> bool llvm::GenericUniformityAnalysisImpl::usesValueFromCycle( const Instruction &I, const Cycle &DefCycle) const { assert(!isAlwaysUniform(I)); for (const Use &U : I.operands()) { if (auto *I = dyn_cast(&U)) { if (DefCycle.contains(I->getParent())) return true; } } return false; } template <> void llvm::GenericUniformityAnalysisImpl< SSAContext>::propagateTemporalDivergence(const Instruction &I, const Cycle &DefCycle) { for (auto *User : I.users()) { auto *UserInstr = cast(User); if (DefCycle.contains(UserInstr->getParent())) continue; markDivergent(*UserInstr); recordTemporalDivergence(&I, UserInstr, &DefCycle); } } template <> bool llvm::GenericUniformityAnalysisImpl::isDivergentUse( const Use &U) const { const auto *V = U.get(); if (isDivergent(V)) return true; if (const auto *DefInstr = dyn_cast(V)) { const auto *UseInstr = cast(U.getUser()); return isTemporalDivergent(*UseInstr->getParent(), *DefInstr); } return false; } template <> bool GenericUniformityAnalysisImpl::isCustomUniform( const Instruction &I) const { SmallBitVector UniformArgs(I.getNumOperands()); for (auto [Idx, Use] : enumerate(I.operands())) UniformArgs[Idx] = !isDivergentUse(Use); return TTI->isUniform(&I, UniformArgs); } // This ensures explicit instantiation of // GenericUniformityAnalysisImpl::ImplDeleter::operator() template class llvm::GenericUniformityInfo; template struct llvm::GenericUniformityAnalysisImplDeleter< llvm::GenericUniformityAnalysisImpl>; //===----------------------------------------------------------------------===// // UniformityInfoAnalysis and related pass implementations //===----------------------------------------------------------------------===// llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F, FunctionAnalysisManager &FAM) { TargetTransformInfo &TTI = FAM.getResult(F); if (!TTI.hasBranchDivergence(&F)) return UniformityInfo{}; DominatorTree &DT = FAM.getResult(F); CycleInfo &CI = FAM.getResult(F); UniformityInfo UI{DT, CI, &TTI}; UI.compute(); return UI; } AnalysisKey UniformityInfoAnalysis::Key; UniformityInfoPrinterPass::UniformityInfoPrinterPass(raw_ostream &OS) : OS(OS) {} PreservedAnalyses UniformityInfoPrinterPass::run(Function &F, FunctionAnalysisManager &AM) { OS << "UniformityInfo for function '" << F.getName() << "':\n"; AM.getResult(F).print(OS); return PreservedAnalyses::all(); } //===----------------------------------------------------------------------===// // UniformityInfoWrapperPass Implementation //===----------------------------------------------------------------------===// char UniformityInfoWrapperPass::ID = 0; UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID) {} INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniformity", "Uniformity Analysis", false, true) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(CycleInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity", "Uniformity Analysis", false, true) void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); AU.addRequired(); AU.addRequiredTransitive(); AU.addRequired(); } bool UniformityInfoWrapperPass::runOnFunction(Function &F) { TargetTransformInfo &TTI = getAnalysis().getTTI(F); Fn = &F; if (!TTI.hasBranchDivergence(Fn)) { UI = UniformityInfo{}; return false; } CycleInfo &CI = getAnalysis().getResult(); DominatorTree &DT = getAnalysis().getDomTree(); UI = UniformityInfo{DT, CI, &TTI}; UI.compute(); return false; } void UniformityInfoWrapperPass::print(raw_ostream &OS, const Module *) const { OS << "UniformityInfo for function '" << Fn->getName() << "':\n"; UI.print(OS); } void UniformityInfoWrapperPass::releaseMemory() { UI = UniformityInfo{}; Fn = nullptr; }