//===- LowerAllowCheckPass.cpp ----------------------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "llvm/Transforms/Instrumentation/LowerAllowCheckPass.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/Support/Debug.h" #include "llvm/Support/RandomNumberGenerator.h" #include #include #include using namespace llvm; #define DEBUG_TYPE "lower-allow-check" static cl::opt HotPercentileCutoff("lower-allow-check-percentile-cutoff-hot", cl::desc("Hot percentile cutoff.")); static cl::opt RandomRate("lower-allow-check-random-rate", cl::desc("Probability value in the range [0.0, 1.0] of " "unconditional pseudo-random checks.")); STATISTIC(NumChecksTotal, "Number of checks"); STATISTIC(NumChecksRemoved, "Number of removed checks"); struct RemarkInfo { ore::NV Kind; ore::NV F; ore::NV BB; explicit RemarkInfo(IntrinsicInst *II) : Kind("Kind", II->getArgOperand(0)), F("Function", II->getParent()->getParent()), BB("Block", II->getParent()->getName()) {} }; static void emitRemark(IntrinsicInst *II, OptimizationRemarkEmitter &ORE, bool Removed) { if (Removed) { ORE.emit([&]() { RemarkInfo Info(II); return OptimizationRemark(DEBUG_TYPE, "Removed", II) << "Removed check: Kind=" << Info.Kind << " F=" << Info.F << " BB=" << Info.BB; }); } else { ORE.emit([&]() { RemarkInfo Info(II); return OptimizationRemarkMissed(DEBUG_TYPE, "Allowed", II) << "Allowed check: Kind=" << Info.Kind << " F=" << Info.F << " BB=" << Info.BB; }); } } static bool lowerAllowChecks(Function &F, FunctionAnalysisManager &AM, const LowerAllowCheckPass::Options &Opts) { // Lazy analysis getters. auto GetBFI = [&AM, &F, BFI = (BlockFrequencyInfo *)nullptr]() mutable -> const BlockFrequencyInfo & { if (!BFI) BFI = &AM.getResult(F); return *BFI; }; auto GetPSI = [&AM, &F, PSI = std::optional()]() mutable -> const ProfileSummaryInfo * { if (!PSI.has_value()) { auto &MAMProxy = AM.getResult(F); PSI = MAMProxy.getCachedResult(*F.getParent()); } return *PSI; }; auto GetORE = [&AM, &F, ORE = (OptimizationRemarkEmitter *)nullptr]() mutable -> OptimizationRemarkEmitter & { if (!ORE) ORE = &AM.getResult(F); return *ORE; }; // List of intrinsics and the constant value they should be lowered to. SmallVector, 16> ReplaceWithValue; std::unique_ptr Rng; auto GetRng = [&]() -> RandomNumberGenerator & { if (!Rng) Rng = F.getParent()->createRNG(F.getName()); return *Rng; }; auto GetCutoff = [&](const IntrinsicInst *II) -> unsigned { if (HotPercentileCutoff.getNumOccurrences()) return HotPercentileCutoff; else if (II->getIntrinsicID() == Intrinsic::allow_ubsan_check) { auto *Kind = cast(II->getArgOperand(0)); if (Kind->getZExtValue() < Opts.cutoffs.size()) return Opts.cutoffs[Kind->getZExtValue()]; } else if (II->getIntrinsicID() == Intrinsic::allow_runtime_check) { return Opts.runtime_check; } return 0; }; auto ShouldRemoveHot = [&](const BasicBlock &BB, unsigned int cutoff) { if (cutoff == 1000000) return true; const ProfileSummaryInfo *PSI = GetPSI(); return PSI && PSI->isHotCountNthPercentile( cutoff, GetBFI().getBlockProfileCount(&BB).value_or(0)); }; auto ShouldRemoveRandom = [&]() { return RandomRate.getNumOccurrences() && !std::bernoulli_distribution(RandomRate)(GetRng()); }; auto ShouldRemove = [&](const IntrinsicInst *II) { unsigned int cutoff = GetCutoff(II); return ShouldRemoveRandom() || ShouldRemoveHot(*(II->getParent()), cutoff); }; for (Instruction &I : instructions(F)) { IntrinsicInst *II = dyn_cast(&I); if (!II) continue; auto ID = II->getIntrinsicID(); switch (ID) { case Intrinsic::allow_ubsan_check: case Intrinsic::allow_runtime_check: { bool ToRemove = ShouldRemove(II); ReplaceWithValue.push_back({ II, !ToRemove, }); emitRemark(II, GetORE(), ToRemove); break; } case Intrinsic::allow_sanitize_address: ReplaceWithValue.push_back( {II, F.hasFnAttribute(Attribute::SanitizeAddress)}); break; case Intrinsic::allow_sanitize_thread: ReplaceWithValue.push_back( {II, F.hasFnAttribute(Attribute::SanitizeThread)}); break; case Intrinsic::allow_sanitize_memory: ReplaceWithValue.push_back( {II, F.hasFnAttribute(Attribute::SanitizeMemory)}); break; case Intrinsic::allow_sanitize_hwaddress: ReplaceWithValue.push_back( {II, F.hasFnAttribute(Attribute::SanitizeHWAddress)}); break; default: break; } } for (auto [I, V] : ReplaceWithValue) { ++NumChecksTotal; if (!V) // If the final value is false, the check is considered removed. ++NumChecksRemoved; I->replaceAllUsesWith(ConstantInt::getBool(I->getType(), V)); I->eraseFromParent(); } return !ReplaceWithValue.empty(); } PreservedAnalyses LowerAllowCheckPass::run(Function &F, FunctionAnalysisManager &AM) { if (F.isDeclaration()) return PreservedAnalyses::all(); return lowerAllowChecks(F, AM, Opts) // We do not change the CFG, we only replace the intrinsics with // true or false. ? PreservedAnalyses::none().preserveSet() : PreservedAnalyses::all(); } bool LowerAllowCheckPass::IsRequested() { return RandomRate.getNumOccurrences() || HotPercentileCutoff.getNumOccurrences(); } void LowerAllowCheckPass::printPipeline( raw_ostream &OS, function_ref MapClassName2PassName) { static_cast *>(this)->printPipeline( OS, MapClassName2PassName); OS << "<"; // Format is // but it's equally valid to specify // cutoffs[0]=70000;cutoffs[1]=70000;cutoffs[2]=70000;cutoffs[5]=90000;... // and that's what we do here. It is verbose but valid and easy to verify // correctness. // TODO: print shorter output by combining adjacent runs, etc. int i = 0; ListSeparator LS(";"); for (unsigned int cutoff : Opts.cutoffs) { if (cutoff > 0) OS << LS << "cutoffs[" << i << "]=" << cutoff; i++; } if (Opts.runtime_check) OS << LS << "runtime_check=" << Opts.runtime_check; OS << '>'; }