We can end up with cases where the VP metadata contains only zero profile values, for example if all of the functions end up being external and uninstrumented. Only attach branch weights to the switch when at least one of them is non-zero. This fixes an assertion failure on the BOLT builder that came up the last time we tried to turn the pass on by default.
266 lines
10 KiB
C++
266 lines
10 KiB
C++
//===- JumpTableToSwitch.cpp ----------------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "llvm/Transforms/Scalar/JumpTableToSwitch.h"
|
|
#include "llvm/ADT/DenseSet.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
#include "llvm/ADT/Statistic.h"
|
|
#include "llvm/Analysis/ConstantFolding.h"
|
|
#include "llvm/Analysis/CtxProfAnalysis.h"
|
|
#include "llvm/Analysis/DomTreeUpdater.h"
|
|
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
|
|
#include "llvm/Analysis/PostDominators.h"
|
|
#include "llvm/IR/IRBuilder.h"
|
|
#include "llvm/IR/LLVMContext.h"
|
|
#include "llvm/IR/ProfDataUtils.h"
|
|
#include "llvm/ProfileData/InstrProf.h"
|
|
#include "llvm/Support/CommandLine.h"
|
|
#include "llvm/Support/Error.h"
|
|
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
|
|
#include <limits>
|
|
|
|
using namespace llvm;
|
|
|
|
// Upper bound on the number of entries (slots) a jump table may have for it to
// be expanded into a switch.
static cl::opt<unsigned>
    JumpTableSizeThreshold("jump-table-to-switch-size-threshold", cl::Hidden,
                           cl::desc("Only split jump tables with at most this "
                                    "many entries."),
                           cl::init(10));

// TODO: Consider adding a cost model for profitability analysis of this
// transformation. Currently we replace a jump table with a switch if all the
// functions in the jump table are smaller than the provided threshold.
// The size of a function is measured by its instruction count.
static cl::opt<unsigned> FunctionSizeThreshold(
    "jump-table-to-switch-function-size-threshold", cl::Hidden,
    cl::desc("Only split jump tables containing functions whose instruction "
             "counts are less than or equal to this threshold."),
    cl::init(50));
|
|
|
|
namespace llvm {
|
|
extern cl::opt<bool> ProfcheckDisableMetadataFixes;
|
|
} // end namespace llvm
|
|
|
|
#define DEBUG_TYPE "jump-table-to-switch"
|
|
|
|
STATISTIC(NumEligibleJumpTables, "The number of jump tables seen by the pass "
|
|
"that can be converted if deemed profitable.");
|
|
STATISTIC(NumJumpTablesConverted,
|
|
"The number of jump tables converted into switches.");
|
|
|
|
namespace {
|
|
struct JumpTableTy {
|
|
Value *Index;
|
|
SmallVector<Function *, 10> Funcs;
|
|
};
|
|
} // anonymous namespace
|
|
|
|
/// Try to recognize \p GEP as indexing into a constant table of functions.
///
/// The GEP must address a constant global variable with a definitive
/// initializer, using exactly one variable index and no constant offset. The
/// table may hold at most JumpTableSizeThreshold slots. Each slot, read as a
/// value of type \p PtrTy, must be a function defined in this module whose
/// instruction count is at most FunctionSizeThreshold.
///
/// \returns the index value and the table's functions in slot order, or
/// std::nullopt if the pattern does not match.
static std::optional<JumpTableTy> parseJumpTable(GetElementPtrInst *GEP,
                                                 PointerType *PtrTy) {
  // A GlobalVariable is a Constant, so a single dyn_cast covers both checks.
  auto *GV = dyn_cast<GlobalVariable>(GEP->getPointerOperand());
  if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
    return std::nullopt;

  Function &F = *GEP->getParent()->getParent();
  const DataLayout &DL = F.getDataLayout();
  const unsigned BitWidth =
      DL.getIndexSizeInBits(GEP->getPointerAddressSpace());
  SmallMapVector<Value *, APInt, 4> VariableOffsets;
  APInt ConstantOffset(BitWidth, 0);
  if (!GEP->collectOffset(DL, BitWidth, VariableOffsets, ConstantOffset))
    return std::nullopt;
  if (VariableOffsets.size() != 1)
    return std::nullopt;
  // TODO: consider supporting more general patterns
  if (!ConstantOffset.isZero())
    return std::nullopt;

  const APInt StrideBytes = VariableOffsets.front().second;
  // The stride is used as a divisor below; a zero stride cannot describe a
  // table, so bail out rather than divide by zero.
  if (StrideBytes.isZero())
    return std::nullopt;
  const uint64_t Stride = StrideBytes.getZExtValue();
  const uint64_t JumpTableSizeBytes = GV->getGlobalSize(DL);
  if (JumpTableSizeBytes % Stride != 0)
    return std::nullopt;
  ++NumEligibleJumpTables;
  const uint64_t N = JumpTableSizeBytes / Stride;
  if (N > JumpTableSizeThreshold)
    return std::nullopt;

  Value *Index = VariableOffsets.front().first;
  // Every slot becomes its own switch case of Index's type (see
  // expandToSwitch), so slot numbers 0..N-1 must be distinct in that type;
  // otherwise the switch would contain duplicate case values.
  const unsigned IndexBits = Index->getType()->getScalarSizeInBits();
  if (IndexBits < 64 && N > (uint64_t(1) << IndexBits))
    return std::nullopt;

  JumpTableTy JumpTable;
  JumpTable.Index = Index;
  JumpTable.Funcs.reserve(N);
  for (uint64_t Slot = 0; Slot < N; ++Slot) {
    // ConstantOffset is zero, so slot Slot starts at Slot * StrideBytes.
    APInt Offset = Slot * StrideBytes;
    Constant *C =
        ConstantFoldLoadFromConst(GV->getInitializer(), PtrTy, Offset, DL);
    auto *Func = dyn_cast_or_null<Function>(C);
    if (!Func || Func->isDeclaration() ||
        Func->getInstructionCount() > FunctionSizeThreshold)
      return std::nullopt;
    JumpTable.Funcs.push_back(Func);
  }
  return JumpTable;
}
|
|
|
|
/// Replace the indirect call \p CB, whose callee was loaded from the jump
/// table \p JT, with a switch on JT.Index over one direct call per table slot.
///
/// CB's block is split at CB. The head block ends in the new switch, whose
/// default destination is a fresh block containing only `unreachable`. Each
/// case block calls one table function and branches to the tail block. For a
/// non-void call, a PHI at the top of the tail merges the results and replaces
/// CB's uses. If CB carries value-profile metadata, the indirect-call target
/// counts, matched by \p GetGuidForFunction, become the switch's branch
/// weights. Otherwise, or when every weight is zero, or when
/// ProfcheckDisableMetadataFixes is set, the switch's branch weights are
/// marked explicitly unknown. Dominator-tree updates are applied through
/// \p DTU, and CB is erased.
///
/// \returns the tail block, which holds the instructions that followed CB.
static BasicBlock *
expandToSwitch(CallBase *CB, const JumpTableTy &JT, DomTreeUpdater &DTU,
               OptimizationRemarkEmitter &ORE,
               llvm::function_ref<GlobalValue::GUID(const Function &)>
                   GetGuidForFunction) {
  ++NumJumpTablesConverted;
  const bool IsVoid = CB->getType()->isVoidTy();

  SmallVector<DominatorTree::UpdateType, 8> DTUpdates;
  BasicBlock *BB = CB->getParent();
  // Split so that CB starts Tail. The branch from BB to Tail is then replaced
  // by the switch built below.
  BasicBlock *Tail = SplitBlock(BB, CB, &DTU, nullptr, nullptr,
                                BB->getName() + Twine(".tail"));
  DTUpdates.push_back({DominatorTree::Delete, BB, Tail});
  BB->getTerminator()->eraseFromParent();

  Function &F = *BB->getParent();
  // Index values without a matching case go to this unreachable block.
  BasicBlock *BBUnreachable = BasicBlock::Create(
      F.getContext(), "default.switch.case.unreachable", &F, Tail);
  IRBuilder<> BuilderUnreachable(BBUnreachable);
  BuilderUnreachable.CreateUnreachable();

  IRBuilder<> Builder(BB);
  SwitchInst *Switch = Builder.CreateSwitch(JT.Index, BBUnreachable);
  DTUpdates.push_back({DominatorTree::Insert, BB, BBUnreachable});

  // CB is the first instruction of Tail, so the PHI lands at Tail's top.
  IRBuilder<> BuilderTail(CB);
  PHINode *PHI =
      IsVoid ? nullptr : BuilderTail.CreatePHI(CB->getType(), JT.Funcs.size());
  const auto *ProfMD = CB->getMetadata(LLVMContext::MD_prof);

  SmallVector<uint64_t> BranchWeights;
  DenseMap<GlobalValue::GUID, uint64_t> GuidToCounter;
  const bool HadProfile = isValueProfileMD(ProfMD);
  if (HadProfile) {
    // The assumptions, coming in, are that the functions in JT.Funcs are
    // defined in this module (from parseJumpTable).
    assert(llvm::all_of(
        JT.Funcs, [](const Function *F) { return F && !F->isDeclaration(); }));
    BranchWeights.reserve(JT.Funcs.size() + 1);
    // The first is the default target, which is the unreachable block created
    // above.
    BranchWeights.push_back(0U);
    uint64_t TotalCount = 0;
    auto Targets = getValueProfDataFromInst(
        *CB, InstrProfValueKind::IPVK_IndirectCallTarget,
        std::numeric_limits<uint32_t>::max(), TotalCount);

    for (const auto &[G, C] : Targets) {
      [[maybe_unused]] auto It = GuidToCounter.insert({G, C});
      // TODO(boomanaiden154): Currently we do not assert on inserting
      // duplicate GUIDs because we might have multiple zeros when the profile
      // loader fails to map addresses to functions. Readd the assertion that
      // we did insert once this has been fixed.
    }
  }

  for (auto [Index, Func] : llvm::enumerate(JT.Funcs)) {
    BasicBlock *B = BasicBlock::Create(Func->getContext(),
                                       "call." + Twine(Index), &F, Tail);
    DTUpdates.push_back({DominatorTree::Insert, BB, B});
    DTUpdates.push_back({DominatorTree::Insert, B, Tail});

    CallBase *Call = cast<CallBase>(CB->clone());
    // The MD_prof metadata (VP kind), if it existed, can be dropped, it doesn't
    // make sense on a direct call. Note that the values are used for the branch
    // weights of the switch.
    Call->setMetadata(LLVMContext::MD_prof, nullptr);
    // setCalledFunction gives the clone Func's function type. This assumes
    // Func's signature matches CB's; the caller is expected to ensure that.
    Call->setCalledFunction(Func);
    Call->insertInto(B, B->end());
    Switch->addCase(
        cast<ConstantInt>(ConstantInt::get(JT.Index->getType(), Index)), B);
    if (HadProfile) {
      // Case weights follow the default weight, in case order. Computing a
      // GUID may involve building the function's PGO name (see FuncToGuid in
      // run), so do it only when the weights will actually be used.
      GlobalValue::GUID FctID = GetGuidForFunction(*Func);
      // It'd be OK to _not_ find target functions in GuidToCounter, e.g.
      // suppose just some of the jump targets are taken (for the given
      // profile).
      BranchWeights.push_back(FctID == 0U
                                  ? 0U
                                  : GuidToCounter.lookup_or(FctID, 0U));
    }
    UncondBrInst::Create(Tail, B);
    if (PHI)
      PHI->addIncoming(Call, B);
  }
  DTU.applyUpdates(DTUpdates);
  ORE.emit([&]() {
    return OptimizationRemark(DEBUG_TYPE, "ReplacedJumpTableWithSwitch", CB)
           << "expanded indirect call into switch";
  });
  // Only set branch weights on the switch if we have non-zero branch weights.
  // We can have no non-zero branch weights while having VP metadata if for
  // example, all of the functions are external and not instrumented.
  if (HadProfile && !ProfcheckDisableMetadataFixes &&
      llvm::any_of(BranchWeights, not_equal_to(0))) {
    setBranchWeights(*Switch, downscaleWeights(BranchWeights),
                     /*IsExpected=*/false);
  } else
    setExplicitlyUnknownBranchWeights(*Switch, DEBUG_TYPE);
  if (PHI)
    CB->replaceAllUsesWith(PHI);
  CB->eraseFromParent();
  return Tail;
}
|
|
|
|
/// Scan \p F for indirect calls whose callee is loaded from a constant jump
/// table (recognized by parseJumpTable), and expand each one into a switch of
/// direct calls (via expandToSwitch).
///
/// After each expansion, scanning resumes in the tail block that was split
/// off after the expanded call. Dominator and post-dominator trees are updated
/// only if they were already cached; in that case they are reported as
/// preserved.
PreservedAnalyses JumpTableToSwitchPass::run(Function &F,
                                             FunctionAnalysisManager &AM) {
  OptimizationRemarkEmitter &ORE =
      AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
  DominatorTree *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
  PostDominatorTree *PDT = AM.getCachedResult<PostDominatorTreeAnalysis>(F);
  DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy);
  bool Changed = false;
  // Prefer an explicitly assigned GUID; otherwise derive one from the
  // function's IR PGO name.
  auto FuncToGuid = [InLTO = this->InLTO](const Function &Fct) {
    if (Fct.getMetadata(AssignGUIDPass::GUIDMetadataName))
      return AssignGUIDPass::getGUID(Fct);

    return Function::getGUIDAssumingExternalLinkage(
        getIRPGOFuncName(Fct, InLTO));
  };

  for (BasicBlock &BB : make_early_inc_range(F)) {
    BasicBlock *CurrentBB = &BB;
    while (CurrentBB) {
      BasicBlock *SplittedOutTail = nullptr;
      for (Instruction &I : make_early_inc_range(*CurrentBB)) {
        auto *Call = dyn_cast<CallInst>(&I);
        if (!Call || Call->getCalledFunction() || Call->isMustTailCall())
          continue;
        auto *L = dyn_cast<LoadInst>(Call->getCalledOperand());
        // Skip atomic or volatile loads.
        if (!L || !L->isSimple())
          continue;
        auto *GEP = dyn_cast<GetElementPtrInst>(L->getPointerOperand());
        if (!GEP)
          continue;
        auto *PtrTy = dyn_cast<PointerType>(L->getType());
        assert(PtrTy && "call operand must be a pointer");
        std::optional<JumpTableTy> JumpTable = parseJumpTable(GEP, PtrTy);
        if (!JumpTable)
          continue;
        // expandToSwitch retargets clones of the call with setCalledFunction,
        // which switches each clone to the table function's type. Tables
        // reached through a mismatched function pointer (e.g. a cast in the
        // source) would yield direct calls whose arguments or result do not
        // match the callee, so leave those calls alone.
        const FunctionType *CallTy = Call->getFunctionType();
        if (llvm::any_of(JumpTable->Funcs, [CallTy](const Function *Fn) {
              return Fn->getFunctionType() != CallTy;
            }))
          continue;
        SplittedOutTail =
            expandToSwitch(Call, *JumpTable, DTU, ORE, FuncToGuid);
        Changed = true;
        break;
      }
      // Continue scanning in the split-off tail, if any; otherwise move on.
      CurrentBB = SplittedOutTail;
    }
  }

  if (!Changed)
    return PreservedAnalyses::all();

  PreservedAnalyses PA;
  if (DT)
    PA.preserve<DominatorTreeAnalysis>();
  if (PDT)
    PA.preserve<PostDominatorTreeAnalysis>();
  return PA;
}
|