//===- ProfDataUtils.cpp - Utility functions for MD_prof Metadata ---------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements utilities for working with Profiling Metadata. // //===----------------------------------------------------------------------===// #include "llvm/IR/ProfDataUtils.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Metadata.h" #include "llvm/Support/CommandLine.h" using namespace llvm; namespace llvm { extern cl::opt ProfcheckDisableMetadataFixes; } // MD_prof nodes have the following layout // // In general: // { String name, Array of i32 } // // In terms of Types: // { MDString, [i32, i32, ...]} // // Concretely for Branch Weights // { "branch_weights", [i32 1, i32 10000]} // // We maintain some constants here to ensure that we access the branch weights // correctly, and can change the behavior in the future if the layout changes // the minimum number of operands for MD_prof nodes with branch weights static constexpr unsigned MinBWOps = 3; // the minimum number of operands for MD_prof nodes with value profiles static constexpr unsigned MinVPOps = 5; // We may want to add support for other MD_prof types, so provide an abstraction // for checking the metadata type. static bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) { // TODO: This routine may be simplified if MD_prof used an enum instead of a // string to differentiate the types of MD_prof nodes. if (!ProfData || !Name || MinOps < 2) return false; unsigned NOps = ProfData->getNumOperands(); if (NOps < MinOps) return false; auto *ProfDataName = dyn_cast(ProfData->getOperand(0)); if (!ProfDataName) return false; return ProfDataName->getString() == Name; } template >> static void extractFromBranchWeightMD(const MDNode *ProfileData, SmallVectorImpl &Weights) { assert(isBranchWeightMD(ProfileData) && "wrong metadata"); unsigned NOps = ProfileData->getNumOperands(); unsigned WeightsIdx = getBranchWeightOffset(ProfileData); assert(WeightsIdx < NOps && "Weights Index must be less than NOps."); Weights.resize(NOps - WeightsIdx); for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) { ConstantInt *Weight = mdconst::dyn_extract(ProfileData->getOperand(Idx)); assert(Weight && "Malformed branch_weight in MD_prof node"); assert(Weight->getValue().getActiveBits() <= (sizeof(T) * 8) && "Too many bits for MD_prof branch_weight"); Weights[Idx - WeightsIdx] = Weight->getZExtValue(); } } /// Push the weights right to fit in uint32_t. SmallVector llvm::fitWeights(ArrayRef Weights) { SmallVector Ret; Ret.reserve(Weights.size()); uint64_t Max = *llvm::max_element(Weights); if (Max > UINT_MAX) { unsigned Offset = 32 - llvm::countl_zero(Max); for (const uint64_t &Value : Weights) Ret.push_back(static_cast(Value >> Offset)); } else { append_range(Ret, Weights); } return Ret; } static cl::opt ElideAllZeroBranchWeights("elide-all-zero-branch-weights", #if defined(LLVM_ENABLE_PROFCHECK) cl::init(false) #else cl::init(true) #endif ); const char *MDProfLabels::BranchWeights = "branch_weights"; const char *MDProfLabels::ExpectedBranchWeights = "expected"; const char *MDProfLabels::ValueProfile = "VP"; const char *MDProfLabels::FunctionEntryCount = "function_entry_count"; const char *MDProfLabels::SyntheticFunctionEntryCount = "synthetic_function_entry_count"; const char *MDProfLabels::UnknownBranchWeightsMarker = "unknown"; const char *llvm::LLVMLoopEstimatedTripCount = "llvm.loop.estimated_trip_count"; bool llvm::hasProfMD(const Instruction &I) { return I.hasMetadata(LLVMContext::MD_prof); } bool llvm::isBranchWeightMD(const MDNode *ProfileData) { return isTargetMD(ProfileData, MDProfLabels::BranchWeights, MinBWOps); } bool llvm::isValueProfileMD(const MDNode *ProfileData) { return isTargetMD(ProfileData, MDProfLabels::ValueProfile, MinVPOps); } bool llvm::hasBranchWeightMD(const Instruction &I) { auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); return isBranchWeightMD(ProfileData); } static bool hasCountTypeMD(const Instruction &I) { auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); // Value profiles record count-type information. if (isValueProfileMD(ProfileData)) return true; // Conservatively assume non CallBase instruction only get taken/not-taken // branch probability, so not interpret them as count. return isa(I) && !isBranchWeightMD(ProfileData); } bool llvm::hasValidBranchWeightMD(const Instruction &I) { return getValidBranchWeightMDNode(I); } bool llvm::hasBranchWeightOrigin(const Instruction &I) { auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); return hasBranchWeightOrigin(ProfileData); } bool llvm::hasBranchWeightOrigin(const MDNode *ProfileData) { if (!isBranchWeightMD(ProfileData)) return false; auto *ProfDataName = dyn_cast(ProfileData->getOperand(1)); // NOTE: if we ever have more types of branch weight provenance, // we need to check the string value is "expected". For now, we // supply a more generic API, and avoid the spurious comparisons. assert(ProfDataName == nullptr || ProfDataName->getString() == MDProfLabels::ExpectedBranchWeights); return ProfDataName != nullptr; } unsigned llvm::getBranchWeightOffset(const MDNode *ProfileData) { return hasBranchWeightOrigin(ProfileData) ? 2 : 1; } unsigned llvm::getNumBranchWeights(const MDNode &ProfileData) { return ProfileData.getNumOperands() - getBranchWeightOffset(&ProfileData); } MDNode *llvm::getBranchWeightMDNode(const Instruction &I) { auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); if (!isBranchWeightMD(ProfileData)) return nullptr; return ProfileData; } MDNode *llvm::getValidBranchWeightMDNode(const Instruction &I) { auto *ProfileData = getBranchWeightMDNode(I); if (ProfileData && getNumBranchWeights(*ProfileData) == I.getNumSuccessors()) return ProfileData; return nullptr; } void llvm::extractFromBranchWeightMD32(const MDNode *ProfileData, SmallVectorImpl &Weights) { extractFromBranchWeightMD(ProfileData, Weights); } void llvm::extractFromBranchWeightMD64(const MDNode *ProfileData, SmallVectorImpl &Weights) { extractFromBranchWeightMD(ProfileData, Weights); } bool llvm::extractBranchWeights(const MDNode *ProfileData, SmallVectorImpl &Weights) { if (!isBranchWeightMD(ProfileData)) return false; extractFromBranchWeightMD(ProfileData, Weights); return true; } bool llvm::extractBranchWeights(const Instruction &I, SmallVectorImpl &Weights) { auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); return extractBranchWeights(ProfileData, Weights); } bool llvm::extractBranchWeights(const Instruction &I, uint64_t &TrueVal, uint64_t &FalseVal) { assert((isa(I)) && "Looking for branch weights on something besides CondBr or Select"); SmallVector Weights; auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); if (!extractBranchWeights(ProfileData, Weights)) return false; if (Weights.size() > 2) return false; TrueVal = Weights[0]; FalseVal = Weights[1]; return true; } bool llvm::extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) { TotalVal = 0; if (!ProfileData) return false; auto *ProfDataName = dyn_cast(ProfileData->getOperand(0)); if (!ProfDataName) return false; if (ProfDataName->getString() == MDProfLabels::BranchWeights) { unsigned Offset = getBranchWeightOffset(ProfileData); for (unsigned Idx = Offset; Idx < ProfileData->getNumOperands(); ++Idx) { auto *V = mdconst::extract(ProfileData->getOperand(Idx)); TotalVal += V->getValue().getZExtValue(); } return true; } if (ProfDataName->getString() == MDProfLabels::ValueProfile && ProfileData->getNumOperands() > 3) { TotalVal = mdconst::dyn_extract(ProfileData->getOperand(2)) ->getValue() .getZExtValue(); return true; } return false; } bool llvm::extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) { return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal); } void llvm::setExplicitlyUnknownBranchWeights(Instruction &I, StringRef PassName) { MDBuilder MDB(I.getContext()); I.setMetadata( LLVMContext::MD_prof, MDNode::get(I.getContext(), {MDB.createString(MDProfLabels::UnknownBranchWeightsMarker), MDB.createString(PassName)})); } void llvm::setExplicitlyUnknownBranchWeightsIfProfiled(Instruction &I, StringRef PassName, const Function *F) { F = F ? F : I.getFunction(); assert(F && "Either pass a instruction attached to a Function, or explicitly " "pass the Function that it will be attached to"); if (std::optional EC = F->getEntryCount(); EC && EC->getCount() > 0) setExplicitlyUnknownBranchWeights(I, PassName); } MDNode *llvm::getExplicitlyUnknownBranchWeightsIfProfiled(Function &F, StringRef PassName) { if (std::optional EC = F.getEntryCount(); !EC || EC->getCount() == 0) return nullptr; MDBuilder MDB(F.getContext()); return MDNode::get( F.getContext(), {MDB.createString(MDProfLabels::UnknownBranchWeightsMarker), MDB.createString(PassName)}); } void llvm::setExplicitlyUnknownFunctionEntryCount(Function &F, StringRef PassName) { MDBuilder MDB(F.getContext()); F.setMetadata( LLVMContext::MD_prof, MDNode::get(F.getContext(), {MDB.createString(MDProfLabels::UnknownBranchWeightsMarker), MDB.createString(PassName)})); } bool llvm::isExplicitlyUnknownProfileMetadata(const MDNode &MD) { if (MD.getNumOperands() != 2) return false; return MD.getOperand(0).equalsStr(MDProfLabels::UnknownBranchWeightsMarker); } bool llvm::hasExplicitlyUnknownBranchWeights(const Instruction &I) { auto *MD = I.getMetadata(LLVMContext::MD_prof); if (!MD) return false; return isExplicitlyUnknownProfileMetadata(*MD); } void llvm::setBranchWeights(Instruction &I, ArrayRef Weights, bool IsExpected, bool ElideAllZero) { if ((ElideAllZeroBranchWeights && ElideAllZero) && llvm::all_of(Weights, equal_to(0))) { I.setMetadata(LLVMContext::MD_prof, nullptr); return; } MDBuilder MDB(I.getContext()); MDNode *BranchWeights = MDB.createBranchWeights(Weights, IsExpected); I.setMetadata(LLVMContext::MD_prof, BranchWeights); } void llvm::setFittedBranchWeights(Instruction &I, ArrayRef Weights, bool IsExpected, bool ElideAllZero) { setBranchWeights(I, fitWeights(Weights), IsExpected, ElideAllZero); } SmallVector llvm::downscaleWeights(ArrayRef Weights, std::optional KnownMaxCount) { uint64_t MaxCount = KnownMaxCount.has_value() ? KnownMaxCount.value() : *llvm::max_element(Weights); assert(MaxCount > 0 && "Bad max count"); uint64_t Scale = calculateCountScale(MaxCount); SmallVector DownscaledWeights; for (const auto &ECI : Weights) DownscaledWeights.push_back(scaleBranchCount(ECI, Scale)); return DownscaledWeights; } void llvm::scaleProfData(Instruction &I, uint64_t S, uint64_t T) { assert(T != 0 && "Caller should guarantee"); auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); if (ProfileData == nullptr) return; auto *ProfDataName = dyn_cast(ProfileData->getOperand(0)); if (!ProfDataName || (ProfDataName->getString() != MDProfLabels::BranchWeights && ProfDataName->getString() != MDProfLabels::ValueProfile)) return; if (!hasCountTypeMD(I)) return; LLVMContext &C = I.getContext(); MDBuilder MDB(C); SmallVector Vals; Vals.push_back(ProfileData->getOperand(0)); APInt APS(128, S), APT(128, T); if (ProfDataName->getString() == MDProfLabels::BranchWeights && ProfileData->getNumOperands() > 0) { // Using APInt::div may be expensive, but most cases should fit 64 bits. APInt Val(128, mdconst::dyn_extract( ProfileData->getOperand(getBranchWeightOffset(ProfileData))) ->getValue() .getZExtValue()); Val *= APS; Vals.push_back(MDB.createConstant(ConstantInt::get( Type::getInt32Ty(C), Val.udiv(APT).getLimitedValue(UINT32_MAX)))); } else if (ProfDataName->getString() == MDProfLabels::ValueProfile) for (unsigned Idx = 1; Idx < ProfileData->getNumOperands(); Idx += 2) { // The first value is the key of the value profile, which will not change. Vals.push_back(ProfileData->getOperand(Idx)); uint64_t Count = mdconst::dyn_extract(ProfileData->getOperand(Idx + 1)) ->getValue() .getZExtValue(); // Don't scale the magic number. if (Count == NOMORE_ICP_MAGICNUM) { Vals.push_back(ProfileData->getOperand(Idx + 1)); continue; } // Using APInt::div may be expensive, but most cases should fit 64 bits. APInt Val(128, Count); Val *= APS; Vals.push_back(MDB.createConstant(ConstantInt::get( Type::getInt64Ty(C), Val.udiv(APT).getLimitedValue()))); } I.setMetadata(LLVMContext::MD_prof, MDNode::get(C, Vals)); } void llvm::applyProfMetadataIfEnabled( Value *V, llvm::function_ref setMetadataCallback) { if (!ProfcheckDisableMetadataFixes) { if (Instruction *Inst = dyn_cast(V)) { setMetadataCallback(Inst); } } }