Follow-up to [#189948](https://github.com/llvm/llvm-project/pull/189948#discussion_r3027394937). Addresses review feedback. Co-authored-by: padivedi <padivedi@amd.com>
243 lines
8.0 KiB
C++
243 lines
8.0 KiB
C++
//===- UniformityAnalysis.cpp ---------------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "llvm/Analysis/UniformityAnalysis.h"
|
|
#include "llvm/ADT/GenericUniformityImpl.h"
|
|
#include "llvm/ADT/SmallBitVector.h"
|
|
#include "llvm/Analysis/CycleAnalysis.h"
|
|
#include "llvm/Analysis/TargetTransformInfo.h"
|
|
#include "llvm/IR/Dominators.h"
|
|
#include "llvm/IR/InstIterator.h"
|
|
#include "llvm/IR/Instructions.h"
|
|
#include "llvm/InitializePasses.h"
|
|
|
|
using namespace llvm;
|
|
|
|
template <>
|
|
bool llvm::GenericUniformityAnalysisImpl<SSAContext>::hasDivergentDefs(
|
|
const Instruction &I) const {
|
|
return isDivergent((const Value *)&I);
|
|
}
|
|
|
|
template <>
|
|
bool llvm::GenericUniformityAnalysisImpl<SSAContext>::markDefsDivergent(
|
|
const Instruction &Instr) {
|
|
return markDivergent(cast<Value>(&Instr));
|
|
}
|
|
|
|
template <>
|
|
void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
|
|
const Value *V) {
|
|
for (const auto *User : V->users()) {
|
|
if (const auto *UserInstr = dyn_cast<const Instruction>(User)) {
|
|
markDivergent(*UserInstr);
|
|
}
|
|
}
|
|
}
|
|
|
|
template <>
|
|
void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
|
|
const Instruction &Instr) {
|
|
assert(!isAlwaysUniform(Instr));
|
|
if (Instr.isTerminator())
|
|
return;
|
|
pushUsers(cast<Value>(&Instr));
|
|
}
|
|
|
|
template <>
|
|
bool llvm::GenericUniformityAnalysisImpl<SSAContext>::printDivergentArgs(
|
|
raw_ostream &OS) const {
|
|
bool HaveDivergentArgs = false;
|
|
for (const auto &Arg : F.args()) {
|
|
if (isDivergent(&Arg)) {
|
|
if (!HaveDivergentArgs) {
|
|
OS << "DIVERGENT ARGUMENTS:\n";
|
|
HaveDivergentArgs = true;
|
|
}
|
|
OS << " DIVERGENT: " << Context.print(&Arg) << '\n';
|
|
}
|
|
}
|
|
return HaveDivergentArgs;
|
|
}
|
|
|
|
/// Seeds the analysis from the target's per-value uniformity classification.
///
/// Everything that is not classified NeverUniform is first recorded in
/// UniformValues. NeverUniform values are deliberately left out of that set:
/// not being in UniformValues is what makes isDivergent() report them as
/// divergent.
template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::initialize() {
  // Arguments: remember the NeverUniform ones for later, record the rest as
  // uniform.
  SmallVector<const Value *, 4> NeverUniformArgs;
  for (auto &Arg : F.args()) {
    if (TTI->getValueUniformity(&Arg) != ValueUniformity::NeverUniform)
      UniformValues.insert(&Arg);
    else
      NeverUniformArgs.push_back(&Arg);
  }

  // Instructions: classify each one. Nothing follows the switch inside the
  // loop body, so every case simply ends the current iteration.
  for (auto &Inst : instructions(F)) {
    switch (TTI->getValueUniformity(&Inst)) {
    case ValueUniformity::AlwaysUniform:
      UniformValues.insert(&Inst);
      addUniformOverride(Inst);
      break;
    case ValueUniformity::NeverUniform:
      // Not inserted into UniformValues -- divergent by definition. Queue it
      // on the Worklist directly so compute() propagates divergence to its
      // users.
      if (Inst.isTerminator())
        DivergentTermBlocks.insert(Inst.getParent());
      Worklist.push_back(&Inst);
      break;
    case ValueUniformity::Custom:
      UniformValues.insert(&Inst);
      addCustomUniformityCandidate(&Inst);
      break;
    case ValueUniformity::Default:
      UniformValues.insert(&Inst);
      break;
    }
  }

  // Arguments are not instructions and cannot be placed on the Worklist, so
  // their divergence is pushed to users explicitly. This has to run after the
  // loop above: markDivergent (reached through pushUsers) erases user
  // instructions from UniformValues, which only works once they are in it.
  for (const Value *Arg : NeverUniformArgs)
    pushUsers(Arg);
}
|
|
|
|
template <>
|
|
bool llvm::GenericUniformityAnalysisImpl<SSAContext>::usesValueFromCycle(
|
|
const Instruction &I, const Cycle &DefCycle) const {
|
|
assert(!isAlwaysUniform(I));
|
|
for (const Use &U : I.operands()) {
|
|
if (auto *I = dyn_cast<Instruction>(&U)) {
|
|
if (DefCycle.contains(I->getParent()))
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
template <>
|
|
void llvm::GenericUniformityAnalysisImpl<
|
|
SSAContext>::propagateTemporalDivergence(const Instruction &I,
|
|
const Cycle &DefCycle) {
|
|
for (auto *User : I.users()) {
|
|
auto *UserInstr = cast<Instruction>(User);
|
|
if (DefCycle.contains(UserInstr->getParent()))
|
|
continue;
|
|
markDivergent(*UserInstr);
|
|
recordTemporalDivergence(&I, UserInstr, &DefCycle);
|
|
}
|
|
}
|
|
|
|
template <>
|
|
bool llvm::GenericUniformityAnalysisImpl<SSAContext>::isDivergentUse(
|
|
const Use &U) const {
|
|
const auto *V = U.get();
|
|
if (isDivergent(V))
|
|
return true;
|
|
if (const auto *DefInstr = dyn_cast<Instruction>(V)) {
|
|
const auto *UseInstr = cast<Instruction>(U.getUser());
|
|
return isTemporalDivergent(*UseInstr->getParent(), *DefInstr);
|
|
}
|
|
return false;
|
|
}
|
|
|
|
template <>
|
|
bool GenericUniformityAnalysisImpl<SSAContext>::isCustomUniform(
|
|
const Instruction &I) const {
|
|
SmallBitVector UniformArgs(I.getNumOperands());
|
|
for (auto [Idx, Use] : enumerate(I.operands()))
|
|
UniformArgs[Idx] = !isDivergentUse(Use);
|
|
return TTI->isUniform(&I, UniformArgs);
|
|
}
|
|
|
|
// This ensures explicit instantiation of
|
|
// GenericUniformityAnalysisImpl::ImplDeleter::operator()
|
|
template class llvm::GenericUniformityInfo<SSAContext>;
|
|
template struct llvm::GenericUniformityAnalysisImplDeleter<
|
|
llvm::GenericUniformityAnalysisImpl<SSAContext>>;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// UniformityInfoAnalysis and related pass implementations
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// New pass manager entry point: computes uniformity information for \p F.
///
/// When the target reports no branch divergence for \p F, a
/// default-constructed UniformityInfo is returned and neither the dominator
/// tree nor the cycle info is requested.
llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F,
                                                 FunctionAnalysisManager &FAM) {
  TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
  if (TTI.hasBranchDivergence(&F)) {
    DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
    CycleInfo &CI = FAM.getResult<CycleAnalysis>(F);
    UniformityInfo Result{DT, CI, &TTI};
    Result.compute();
    return Result;
  }
  return UniformityInfo{};
}
|
|
|
|
// Out-of-line definition of the static AnalysisKey member of
// UniformityInfoAnalysis.
AnalysisKey UniformityInfoAnalysis::Key;
|
|
|
|
/// Creates a printer pass that writes to \p OS. Only a reference to the
/// stream is stored.
UniformityInfoPrinterPass::UniformityInfoPrinterPass(raw_ostream &OS) : OS(OS) {
}
|
|
|
|
/// Prints a header naming \p F followed by the function's uniformity info.
/// The pass changes nothing, so all analyses are reported as preserved.
PreservedAnalyses UniformityInfoPrinterPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  // The header goes out before the analysis result is requested.
  OS << "UniformityInfo for function '" << F.getName() << "':\n";
  auto &Info = AM.getResult<UniformityInfoAnalysis>(F);
  Info.print(OS);
  return PreservedAnalyses::all();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// UniformityInfoWrapperPass Implementation
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Static pass ID; it is handed to the FunctionPass base class by the
// constructor of UniformityInfoWrapperPass.
char UniformityInfoWrapperPass::ID = 0;
|
|
|
|
// Default constructor: only forwards the static pass ID to FunctionPass.
UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID) {}
|
|
|
|
// Legacy pass registration for UniformityInfoWrapperPass ("uniformity"),
// listing the wrapper passes it depends on: dominator tree, cycle info and
// target transform info.
INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniformity",
                      "Uniformity Analysis", false, true)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(CycleInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity",
                    "Uniformity Analysis", false, true)
|
|
|
|
/// Declares what this pass needs and keeps: it preserves every analysis and
/// requires the dominator tree, cycle info and target transform info wrapper
/// passes.
void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesAll();

  AU.addRequired<DominatorTreeWrapperPass>();
  // NOTE(review): cycle info is the only transitive requirement --
  // presumably because the computed UniformityInfo keeps referring to it
  // after this pass has run; confirm.
  AU.addRequiredTransitive<CycleInfoWrapperPass>();
  AU.addRequired<TargetTransformInfoWrapperPass>();
}
|
|
|
|
/// Legacy pass manager entry point: (re)computes the uniformity info stored
/// in this pass for \p F and remembers the function for print().
///
/// If the target reports no branch divergence for \p F, the stored info is
/// reset to a default-constructed UniformityInfo instead.
///
/// \returns false in every case.
bool UniformityInfoWrapperPass::runOnFunction(Function &F) {
  Fn = &F;
  TargetTransformInfo &TTI =
      getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);

  if (TTI.hasBranchDivergence(&F)) {
    CycleInfo &CI = getAnalysis<CycleInfoWrapperPass>().getResult();
    DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    UI = UniformityInfo{DT, CI, &TTI};
    UI.compute();
  } else {
    UI = UniformityInfo{};
  }
  return false;
}
|
|
|
|
/// Prints a header naming the function remembered by runOnFunction(),
/// followed by the stored uniformity info. The Module argument is unused.
///
/// Fn is dereferenced unconditionally, so this must not be called while Fn is
/// null (for example after releaseMemory()).
void UniformityInfoWrapperPass::print(raw_ostream &OS, const Module *) const {
  const auto FnName = Fn->getName();
  OS << "UniformityInfo for function '" << FnName << "':\n";
  UI.print(OS);
}
|
|
|
|
/// Drops the per-function state: forgets the remembered function and replaces
/// the stored info with a default-constructed UniformityInfo.
void UniformityInfoWrapperPass::releaseMemory() {
  Fn = nullptr;
  UI = UniformityInfo{};
}
|