Follow-up to #[189948](https://github.com/llvm/llvm-project/pull/189948#discussion_r3027394937). Addresses review feedback Co-authored-by: padivedi <padivedi@amd.com>
297 lines
9.5 KiB
C++
297 lines
9.5 KiB
C++
//===- MachineUniformityAnalysis.cpp --------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "llvm/CodeGen/MachineUniformityAnalysis.h"
|
|
#include "llvm/ADT/GenericUniformityImpl.h"
|
|
#include "llvm/Analysis/TargetTransformInfo.h"
|
|
#include "llvm/CodeGen/MachineCycleAnalysis.h"
|
|
#include "llvm/CodeGen/MachineDominators.h"
|
|
#include "llvm/CodeGen/MachineRegisterInfo.h"
|
|
#include "llvm/CodeGen/MachineSSAContext.h"
|
|
#include "llvm/CodeGen/TargetInstrInfo.h"
|
|
#include "llvm/InitializePasses.h"
|
|
|
|
using namespace llvm;
|
|
|
|
template <>
|
|
bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::hasDivergentDefs(
|
|
const MachineInstr &I) const {
|
|
for (auto &Op : I.all_defs()) {
|
|
if (isDivergent(Op.getReg()))
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
template <>
|
|
bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::markDefsDivergent(
|
|
const MachineInstr &Instr) {
|
|
bool InsertedDivergent = false;
|
|
const auto &MRI = F.getRegInfo();
|
|
const auto &RBI = *F.getSubtarget().getRegBankInfo();
|
|
const auto &TRI = *MRI.getTargetRegisterInfo();
|
|
for (auto &Op : Instr.all_defs()) {
|
|
if (!Op.getReg().isVirtual())
|
|
continue;
|
|
assert(!Op.getSubReg());
|
|
if (TRI.isUniformReg(MRI, RBI, Op.getReg()))
|
|
continue;
|
|
InsertedDivergent |= markDivergent(Op.getReg());
|
|
}
|
|
return InsertedDivergent;
|
|
}
|
|
|
|
template <>
|
|
void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::initialize() {
|
|
// Pre-populate UniformValues with all register defs. Physical register defs
|
|
// are included because they are never analyzed for divergence (initialize
|
|
// and markDefsDivergent skip them), so they must be in UniformValues to
|
|
// avoid being falsely reported as divergent.
|
|
for (const MachineBasicBlock &BB : F) {
|
|
for (const MachineInstr &MI : BB.instrs()) {
|
|
for (const MachineOperand &Op : MI.all_defs()) {
|
|
Register Reg = Op.getReg();
|
|
if (Reg)
|
|
UniformValues.insert(Reg);
|
|
}
|
|
}
|
|
}
|
|
|
|
const auto &InstrInfo = *F.getSubtarget().getInstrInfo();
|
|
|
|
for (const MachineBasicBlock &MBB : F) {
|
|
for (const MachineInstr &MI : MBB) {
|
|
ValueUniformity VU = InstrInfo.getValueUniformity(MI);
|
|
|
|
switch (VU) {
|
|
case ValueUniformity::AlwaysUniform:
|
|
addUniformOverride(MI);
|
|
break;
|
|
case ValueUniformity::NeverUniform:
|
|
markDivergent(MI);
|
|
break;
|
|
case ValueUniformity::Custom:
|
|
break;
|
|
case ValueUniformity::Default:
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
template <>
|
|
void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
|
|
Register Reg) {
|
|
assert(isDivergent(Reg));
|
|
const auto &RegInfo = F.getRegInfo();
|
|
for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
|
|
markDivergent(UserInstr);
|
|
}
|
|
}
|
|
|
|
template <>
|
|
void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
|
|
const MachineInstr &Instr) {
|
|
assert(!isAlwaysUniform(Instr));
|
|
if (Instr.isTerminator())
|
|
return;
|
|
for (const MachineOperand &Op : Instr.all_defs()) {
|
|
auto Reg = Op.getReg();
|
|
if (isDivergent(Reg))
|
|
pushUsers(Reg);
|
|
}
|
|
}
|
|
|
|
template <>
|
|
bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::usesValueFromCycle(
|
|
const MachineInstr &I, const MachineCycle &DefCycle) const {
|
|
assert(!isAlwaysUniform(I));
|
|
for (auto &Op : I.operands()) {
|
|
if (!Op.isReg() || !Op.readsReg())
|
|
continue;
|
|
auto Reg = Op.getReg();
|
|
|
|
// FIXME: Physical registers need to be properly checked instead of always
|
|
// returning true
|
|
if (Reg.isPhysical())
|
|
return true;
|
|
|
|
auto *Def = F.getRegInfo().getVRegDef(Reg);
|
|
if (DefCycle.contains(Def->getParent()))
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
template <>
|
|
void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::
|
|
propagateTemporalDivergence(const MachineInstr &I,
|
|
const MachineCycle &DefCycle) {
|
|
const auto &RegInfo = F.getRegInfo();
|
|
for (auto &Op : I.all_defs()) {
|
|
if (!Op.getReg().isVirtual())
|
|
continue;
|
|
auto Reg = Op.getReg();
|
|
for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
|
|
if (DefCycle.contains(UserInstr.getParent()))
|
|
continue;
|
|
markDivergent(UserInstr);
|
|
|
|
recordTemporalDivergence(Reg, &UserInstr, &DefCycle);
|
|
}
|
|
}
|
|
}
|
|
|
|
template <>
|
|
bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::isDivergentUse(
|
|
const MachineOperand &U) const {
|
|
if (!U.isReg())
|
|
return false;
|
|
|
|
auto Reg = U.getReg();
|
|
if (isDivergent(Reg))
|
|
return true;
|
|
|
|
const auto &RegInfo = F.getRegInfo();
|
|
auto *Def = RegInfo.getOneDef(Reg);
|
|
if (!Def)
|
|
return true;
|
|
|
|
auto *DefInstr = Def->getParent();
|
|
auto *UseInstr = U.getParent();
|
|
return isTemporalDivergent(*UseInstr->getParent(), *DefInstr);
|
|
}
|
|
|
|
template <>
|
|
bool GenericUniformityAnalysisImpl<MachineSSAContext>::isCustomUniform(
|
|
const MachineInstr &MI) const {
|
|
llvm_unreachable("no MIR instructions use Custom uniformity yet");
|
|
}
|
|
|
|
// This ensures explicit instantiation of
|
|
// GenericUniformityAnalysisImpl::ImplDeleter::operator()
|
|
template class llvm::GenericUniformityInfo<MachineSSAContext>;
|
|
template struct llvm::GenericUniformityAnalysisImplDeleter<
|
|
llvm::GenericUniformityAnalysisImpl<MachineSSAContext>>;
|
|
|
|
/// Builds uniformity information for the SSA machine function \p F.
/// The divergence computation only runs when \p HasBranchDivergence is set;
/// otherwise the freshly constructed info is returned without compute().
MachineUniformityInfo llvm::computeMachineUniformityInfo(
    MachineFunction &F, const MachineCycleInfo &CI,
    const MachineDominatorTree &DT, bool HasBranchDivergence) {
  assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!");
  MachineUniformityInfo Result(DT, CI);
  if (!HasBranchDivergence)
    return Result;
  Result.compute();
  return Result;
}
|
|
|
|
namespace {
|
|
|
|
class MachineUniformityInfoPrinterPass : public MachineFunctionPass {
|
|
public:
|
|
static char ID;
|
|
|
|
MachineUniformityInfoPrinterPass();
|
|
|
|
bool runOnMachineFunction(MachineFunction &F) override;
|
|
void getAnalysisUsage(AnalysisUsage &AU) const override;
|
|
};
|
|
|
|
} // namespace
|
|
|
|
AnalysisKey MachineUniformityAnalysis::Key;
|
|
|
|
MachineUniformityAnalysis::Result
|
|
MachineUniformityAnalysis::run(MachineFunction &MF,
|
|
MachineFunctionAnalysisManager &MFAM) {
|
|
MachineDominatorTree &DT = MFAM.getResult<MachineDominatorTreeAnalysis>(MF);
|
|
MachineCycleInfo &CI = MFAM.getResult<MachineCycleAnalysis>(MF);
|
|
FunctionAnalysisManager &FAM =
|
|
MFAM.getResult<FunctionAnalysisManagerMachineFunctionProxy>(MF)
|
|
.getManager();
|
|
Function &F = MF.getFunction();
|
|
TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
|
|
return computeMachineUniformityInfo(MF, CI, DT, TTI.hasBranchDivergence(&F));
|
|
}
|
|
|
|
/// New-pass-manager printer: writes a header naming the function followed by
/// its uniformity info to OS. Modifies nothing, so all analyses are preserved.
PreservedAnalyses
MachineUniformityPrinterPass::run(MachineFunction &MF,
                                  MachineFunctionAnalysisManager &MFAM) {
  auto &Info = MFAM.getResult<MachineUniformityAnalysis>(MF);
  OS << "MachineUniformityInfo for function: ";
  MF.getFunction().printAsOperand(OS, /*PrintType=*/false);
  OS << '\n';
  Info.print(OS);
  return PreservedAnalyses::all();
}
|
|
|
|
char MachineUniformityAnalysisPass::ID = 0;
|
|
|
|
MachineUniformityAnalysisPass::MachineUniformityAnalysisPass()
|
|
: MachineFunctionPass(ID) {}
|
|
|
|
INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity",
|
|
"Machine Uniformity Info Analysis", false, true)
|
|
INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass)
|
|
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
|
|
INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity",
|
|
"Machine Uniformity Info Analysis", false, true)
|
|
|
|
void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const {
|
|
AU.setPreservesAll();
|
|
AU.addRequiredTransitive<MachineCycleInfoWrapperPass>();
|
|
AU.addRequired<MachineDominatorTreeWrapperPass>();
|
|
MachineFunctionPass::getAnalysisUsage(AU);
|
|
}
|
|
|
|
/// Legacy-PM entry point: computes and stores uniformity info for \p MF.
/// Never modifies the function, so it always returns false.
bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) {
  auto &CycleInfo = getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo();
  auto &DomTree = getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
  // FIXME: Query TTI::hasBranchDivergence. -run-pass seems to end up with a
  // default NoTTI
  UI = computeMachineUniformityInfo(MF, CycleInfo, DomTree,
                                    /*HasBranchDivergence=*/true);
  return false;
}
|
|
|
|
/// Prints a header naming the analysed IR function, then the stored
/// uniformity info. The Module argument is unused.
void MachineUniformityAnalysisPass::print(raw_ostream &OS,
                                          const Module *) const {
  const auto &IRFunc = UI.getFunction().getFunction();
  OS << "MachineUniformityInfo for function: ";
  IRFunc.printAsOperand(OS, /*PrintType=*/false);
  OS << '\n';
  UI.print(OS);
}
|
|
|
|
char MachineUniformityInfoPrinterPass::ID = 0;
|
|
|
|
MachineUniformityInfoPrinterPass::MachineUniformityInfoPrinterPass()
|
|
: MachineFunctionPass(ID) {}
|
|
|
|
INITIALIZE_PASS_BEGIN(MachineUniformityInfoPrinterPass,
|
|
"print-machine-uniformity",
|
|
"Print Machine Uniformity Info Analysis", true, true)
|
|
INITIALIZE_PASS_DEPENDENCY(MachineUniformityAnalysisPass)
|
|
INITIALIZE_PASS_END(MachineUniformityInfoPrinterPass,
|
|
"print-machine-uniformity",
|
|
"Print Machine Uniformity Info Analysis", true, true)
|
|
|
|
void MachineUniformityInfoPrinterPass::getAnalysisUsage(
|
|
AnalysisUsage &AU) const {
|
|
AU.setPreservesAll();
|
|
AU.addRequired<MachineUniformityAnalysisPass>();
|
|
MachineFunctionPass::getAnalysisUsage(AU);
|
|
}
|
|
|
|
/// Dumps the uniformity analysis result to errs(); the function is left
/// untouched, so this always returns false.
bool MachineUniformityInfoPrinterPass::runOnMachineFunction(
    MachineFunction &F) {
  auto &Analysis = getAnalysis<MachineUniformityAnalysisPass>();
  Analysis.print(errs());
  return false;
}
|