Files
llvm-project/llvm/lib/Analysis/IR2Vec.cpp
Alexis Engelke 080bc25728 [IR][NFCI] Remove *WithoutDebug (#187240)
The function instructionsWithoutDebug serves two uses: skipping debug
intrinsics and skipping pseudo instructions. Nonetheless, these
functions are expensive due to out-of-line filtering using
std::function. Ideally, the filter should be inlined, but that would
require including IntrinsicInst.h in BasicBlock.h.

We no longer use debug intrinsics, so the first use (parameter false) is
no longer needed. The second use is sometimes needed, but the
distinction between PseudoProbe instructions can be made at the call
sites more easily in many cases.

Therefore, remove instructionsWithoutDebug/sizeWithoutDebug.

c-t-t stage2-O3 -0.21%.
2026-03-18 15:08:41 +00:00

718 lines
26 KiB
C++

//===- IR2Vec.cpp - Implementation of IR2Vec -----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions. See the LICENSE file for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file implements the IR2Vec algorithm.
///
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/IR2Vec.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Errc.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/MemoryBuffer.h"
using namespace llvm;
using namespace ir2vec;
#define DEBUG_TYPE "ir2vec"
// Counts vocabulary lookups that found no matching entry; incremented when
// the vocabulary storage is built from the parsed maps.
STATISTIC(VocabMissCounter,
"Number of lookups to entities not present in the vocabulary");
namespace llvm {
namespace ir2vec {
// Category under which all IR2Vec command-line options below are grouped.
cl::OptionCategory IR2VecCategory("IR2Vec Options");
// Path to the vocabulary file; defaults to the empty string.
// FIXME: Use a default vocab when not specified
cl::opt<std::string>
VocabFile("ir2vec-vocab-path", cl::Optional,
cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(""),
cl::cat(IR2VecCategory));
// Weight for opcode embeddings (default 1.0).
cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional, cl::init(1.0),
cl::desc("Weight for opcode embeddings"),
cl::cat(IR2VecCategory));
// Weight for type embeddings (default 0.5).
cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional, cl::init(0.5),
cl::desc("Weight for type embeddings"),
cl::cat(IR2VecCategory));
// Weight for argument embeddings (default 0.2).
cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional, cl::init(0.2),
cl::desc("Weight for argument embeddings"),
cl::cat(IR2VecCategory));
// Selects which embedder kind is used ("symbolic" or "flow-aware");
// defaults to symbolic.
cl::opt<IR2VecKind> IR2VecEmbeddingKind(
"ir2vec-kind", cl::Optional,
cl::values(clEnumValN(IR2VecKind::Symbolic, "symbolic",
"Generate symbolic embeddings"),
clEnumValN(IR2VecKind::FlowAware, "flow-aware",
"Generate flow-aware embeddings")),
cl::init(IR2VecKind::Symbolic), cl::desc("IR2Vec embedding kind"),
cl::cat(IR2VecCategory));
} // namespace ir2vec
} // namespace llvm
// Definition of the analysis key for IR2VecVocabAnalysis.
AnalysisKey IR2VecVocabAnalysis::Key;
//===----------------------------------------------------------------------===//
// Local helper functions
//===----------------------------------------------------------------------===//
namespace llvm::json {
/// JSON deserialization hook for Embedding: parses \p E as an array of
/// doubles and, on success, stores an Embedding built from it in \p Out.
/// \returns false (leaving \p Out untouched) when the array cannot be parsed.
inline bool fromJSON(const llvm::json::Value &E, Embedding &Out,
                     llvm::json::Path P) {
  std::vector<double> Values;
  const bool Parsed = llvm::json::fromJSON(E, Values, P);
  if (Parsed)
    Out = Embedding(std::move(Values));
  return Parsed;
}
} // namespace llvm::json
//===----------------------------------------------------------------------===//
// Embedding
//===----------------------------------------------------------------------===//
/// Element-wise in-place addition of \p RHS. Both embeddings must have the
/// same dimension (checked by assertion only).
Embedding &Embedding::operator+=(const Embedding &RHS) {
  assert(this->size() == RHS.size() && "Vectors must have the same dimension");
  for (size_t Idx = 0, End = this->size(); Idx != End; ++Idx)
    (*this)[Idx] = (*this)[Idx] + RHS[Idx];
  return *this;
}
/// Returns the element-wise sum of this embedding and \p RHS as a new
/// Embedding; neither operand is modified.
Embedding Embedding::operator+(const Embedding &RHS) const {
  Embedding Sum(*this);
  Sum += RHS;
  return Sum;
}
/// Element-wise in-place subtraction of \p RHS. Both embeddings must have the
/// same dimension (checked by assertion only).
Embedding &Embedding::operator-=(const Embedding &RHS) {
  assert(this->size() == RHS.size() && "Vectors must have the same dimension");
  for (size_t Idx = 0, End = this->size(); Idx != End; ++Idx)
    (*this)[Idx] = (*this)[Idx] - RHS[Idx];
  return *this;
}
/// Returns the element-wise difference of this embedding and \p RHS as a new
/// Embedding; neither operand is modified.
Embedding Embedding::operator-(const Embedding &RHS) const {
  Embedding Difference(*this);
  Difference -= RHS;
  return Difference;
}
/// Multiplies every element of this embedding by \p Factor in place.
Embedding &Embedding::operator*=(double Factor) {
  for (auto It = this->begin(), End = this->end(); It != End; ++It)
    *It = *It * Factor;
  return *this;
}
/// Returns a copy of this embedding with every element multiplied by
/// \p Factor.
Embedding Embedding::operator*(double Factor) const {
  Embedding Scaled(*this);
  Scaled *= Factor;
  return Scaled;
}
/// Adds \p Src scaled by \p Factor to this embedding in place, i.e. performs
/// `this[i] += Src[i] * Factor` for every element. Both embeddings must have
/// the same dimension (checked by assertion only).
Embedding &Embedding::scaleAndAdd(const Embedding &Src, float Factor) {
  assert(this->size() == Src.size() && "Vectors must have the same dimension");
  const size_t NumElems = this->size();
  size_t Idx = 0;
  while (Idx < NumElems) {
    (*this)[Idx] += Src[Idx] * Factor;
    ++Idx;
  }
  return *this;
}
/// Returns true when no element of this embedding differs from the
/// corresponding element of \p RHS by more than \p Tolerance. The first
/// offending index is reported through LLVM_DEBUG. Both embeddings must have
/// the same dimension (checked by assertion only).
bool Embedding::approximatelyEquals(const Embedding &RHS,
                                    double Tolerance) const {
  assert(this->size() == RHS.size() && "Vectors must have the same dimension");
  for (size_t Itr = 0, End = this->size(); Itr < End; ++Itr) {
    // Keep the comparison as "difference > Tolerance" so that a NaN
    // difference is not reported as a mismatch.
    const bool Mismatch = std::abs((*this)[Itr] - RHS[Itr]) > Tolerance;
    if (!Mismatch)
      continue;
    LLVM_DEBUG(errs() << "Embedding mismatch at index " << Itr << ": "
                      << (*this)[Itr] << " vs " << RHS[Itr]
                      << "; Tolerance: " << Tolerance << "\n");
    return false;
  }
  return true;
}
/// Writes the embedding to \p OS as a bracketed, space-separated list with
/// each element formatted to two decimal places, followed by a newline.
void Embedding::print(raw_ostream &OS) const {
  OS << " [";
  for (auto It = Data.begin(), End = Data.end(); It != End; ++It) {
    OS << " ";
    OS << format("%.2f", *It);
    OS << " ";
  }
  OS << "]\n";
}
//===----------------------------------------------------------------------===//
// Embedder and its subclasses
//===----------------------------------------------------------------------===//
/// Factory for embedders: builds the Embedder subclass matching \p Mode for
/// function \p F using vocabulary \p Vocab.
/// \returns nullptr when \p Mode matches none of the handled kinds.
std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
                                           const Vocabulary &Vocab) {
  std::unique_ptr<Embedder> Result;
  switch (Mode) {
  case IR2VecKind::Symbolic:
    Result = std::make_unique<SymbolicEmbedder>(F, Vocab);
    break;
  case IR2VecKind::FlowAware:
    Result = std::make_unique<FlowAwareEmbedder>(F, Vocab);
    break;
  }
  return Result;
}
/// Computes the function-level embedding as the sum of the embeddings of the
/// basic blocks visited by a depth-first walk from the entry block. A
/// declaration yields an all-zero vector.
Embedding Embedder::computeEmbeddings() const {
  Embedding FuncVector(Dimension, 0.0);
  // Only blocks reachable from the entry contribute to the function vector.
  if (!F.isDeclaration())
    for (const BasicBlock *Block : depth_first(&F))
      FuncVector += computeEmbeddings(*Block);
  return FuncVector;
}
/// Computes the embedding of \p BB as the sum of the embeddings of its
/// instructions, skipping debug and pseudo instructions.
Embedding Embedder::computeEmbeddings(const BasicBlock &BB) const {
  Embedding BBVector(Dimension, 0);
  for (const auto &Inst : BB) {
    // Debug and pseudo instructions do not contribute to the embedding.
    if (Inst.isDebugOrPseudoInst())
      continue;
    BBVector += computeEmbeddings(Inst);
  }
  return BBVector;
}
/// Computes the symbolic embedding of \p I: the vocabulary vectors of its
/// opcode and result type plus the summed vocabulary vectors of its operands,
/// plus the predicate vector when \p I is a compare instruction.
Embedding SymbolicEmbedder::computeEmbeddings(const Instruction &I) const {
  // The symbolic embedder recomputes the vector on every request; this is
  // cheaper than caching the vectors.
  Embedding OperandSum(Dimension, 0);
  for (const auto &Op : I.operands())
    OperandSum += Vocab[*Op];

  // Combine as (opcode + type) + operands.
  Embedding InstVector =
      Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()];
  InstVector += OperandSum;

  if (const auto *Cmp = dyn_cast<CmpInst>(&I))
    InstVector += Vocab[Cmp->getPredicate()];
  return InstVector;
}
/// Computes (and caches in InstVecMap) the flow-aware embedding of \p I.
/// Operands defined by an instruction whose embedding is already cached
/// contribute that cached embedding; every other operand contributes its
/// vocabulary embedding.
Embedding FlowAwareEmbedder::computeEmbeddings(const Instruction &I) const {
  // Reuse the embedding if this instruction has been processed before.
  if (auto Cached = InstVecMap.find(&I); Cached != InstVecMap.end())
    return Cached->second;

  // TODO: Handle call instructions differently.
  // For now, we treat them like other instructions
  Embedding ArgEmb(Dimension, 0);
  for (const auto &Op : I.operands()) {
    const auto *DefInst = dyn_cast<Instruction>(Op);
    if (!DefInst) {
      // The operand is not defined by an instruction, so take its embedding
      // from the vocabulary.
      LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: "
                        << *Op << "=" << Vocab[*Op][0] << "\n");
      ArgEmb += Vocab[*Op];
      continue;
    }

    // The operand is defined elsewhere; prefer the embedding computed for
    // its defining instruction.
    // Fixme (#159171): Ideally we should never miss an instruction embedding
    // here. Two situations can cause a miss:
    //  * cyclic dependencies (e.g., phi nodes). This can be fixed by
    //    iterating to a fixed-point, or by using a simple solver for the set
    //    of simultaneous equations.
    //  * the defining instruction lives in a basic block that has not been
    //    processed yet. This can be fixed by processing the basic blocks in
    //    a topological order.
    // In both cases we fall back to the vocabulary embedding.
    auto DefIt = InstVecMap.find(DefInst);
    if (DefIt == InstVecMap.end())
      ArgEmb += Vocab[*Op];
    else
      ArgEmb += DefIt->second;
  }

  // Build the instruction vector as (opcode + type) + arguments, then add
  // the predicate embedding for compare instructions.
  Embedding InstVector =
      Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()];
  InstVector += ArgEmb;
  if (const auto *Cmp = dyn_cast<CmpInst>(&I))
    InstVector += Vocab[Cmp->getPredicate()];

  InstVecMap[&I] = InstVector;
  return InstVector;
}
//===----------------------------------------------------------------------===//
// VocabStorage
//===----------------------------------------------------------------------===//
/// Takes ownership of \p SectionData and derives the total number of
/// embeddings (TotalSize) and the common embedding dimension (Dimension)
/// from it. Every section must be non-empty and all embeddings must share
/// one dimension; both conditions are checked by assertions only.
// NOTE(review): the initializer lambdas read Sections, so this presumably
// relies on Sections being declared before TotalSize and Dimension in the
// class - confirm against the header.
VocabStorage::VocabStorage(std::vector<std::vector<Embedding>> &&SectionData)
: Sections(std::move(SectionData)), TotalSize([&] {
assert(!Sections.empty() && "Vocabulary has no sections");
// Compute total size across all sections
size_t Size = 0;
for (const auto &Section : Sections) {
assert(!Section.empty() && "Vocabulary section is empty");
Size += Section.size();
}
return Size;
}()),
Dimension([&] {
// Get dimension from the first embedding in the first section - all
// embeddings must have the same dimension
assert(!Sections.empty() && "Vocabulary has no sections");
assert(!Sections[0].empty() && "First section of vocabulary is empty");
unsigned ExpectedDim = static_cast<unsigned>(Sections[0][0].size());
// Verify that all embeddings across all sections have the same
// dimension
// The predicate is only referenced from the assert below, hence
// [[maybe_unused]] for builds where assertions are compiled out.
[[maybe_unused]] auto allSameDim =
[ExpectedDim](const std::vector<Embedding> &Section) {
return std::all_of(Section.begin(), Section.end(),
[ExpectedDim](const Embedding &Emb) {
return Emb.size() == ExpectedDim;
});
};
assert(std::all_of(Sections.begin(), Sections.end(), allSameDim) &&
"All embeddings must have the same dimension");
return ExpectedDim;
}()) {}
/// Returns the embedding the iterator currently points at, i.e. entry
/// LocalIndex of section SectionId. Both indices are range-checked by
/// assertions only.
const Embedding &VocabStorage::const_iterator::operator*() const {
  const auto &AllSections = Storage->Sections;
  assert(SectionId < AllSections.size() && "Invalid section ID");
  const auto &CurrentSection = AllSections[SectionId];
  assert(LocalIndex < CurrentSection.size() && "Local index out of range");
  return CurrentSection[LocalIndex];
}
/// Advances to the next embedding. When the end of the current section is
/// reached, the iterator moves to index 0 of the following section.
VocabStorage::const_iterator &VocabStorage::const_iterator::operator++() {
  ++LocalIndex;

  // Nothing to roll over once we are already past the last section.
  if (!(SectionId < Storage->getNumSections()))
    return *this;

  // Still inside the current section: no section switch needed.
  const auto SectionSize = Storage->Sections[SectionId].size();
  if (!(LocalIndex >= SectionSize))
    return *this;

  assert(LocalIndex == Storage->Sections[SectionId].size() &&
         "Local index should be at the end of the current section");
  LocalIndex = 0;
  ++SectionId;
  return *this;
}
/// Two iterators are equal when they refer to the same storage and point at
/// the same section and the same index within it.
bool VocabStorage::const_iterator::operator==(
    const const_iterator &Other) const {
  if (Storage != Other.Storage)
    return false;
  if (SectionId != Other.SectionId)
    return false;
  return LocalIndex == Other.LocalIndex;
}
/// Negation of operator==.
bool VocabStorage::const_iterator::operator!=(
    const const_iterator &Other) const {
  const bool Same = (*this == Other);
  return !Same;
}
/// Extracts the section named \p Key from the parsed vocabulary JSON
/// \p ParsedVocabValue into \p TargetVocab and reports the embedding
/// dimension of that section through \p Dim.
///
/// \returns Error::success() on success. An error is returned when the JSON
/// root is not an object, the section is missing, the section cannot be
/// parsed, the section contains no entries, the dimension is zero, or the
/// entries do not all share one dimension.
Error VocabStorage::parseVocabSection(StringRef Key,
                                      const json::Value &ParsedVocabValue,
                                      VocabMap &TargetVocab, unsigned &Dim) {
  json::Path::Root Path("");
  const json::Object *RootObj = ParsedVocabValue.getAsObject();
  if (!RootObj)
    return createStringError(errc::invalid_argument,
                             "JSON root is not an object");

  const json::Value *SectionValue = RootObj->get(Key);
  if (!SectionValue)
    return createStringError(errc::invalid_argument,
                             "Missing '" + std::string(Key) +
                                 "' section in vocabulary file");

  if (!json::fromJSON(*SectionValue, TargetVocab, Path))
    return createStringError(errc::illegal_byte_sequence,
                             "Unable to parse '" + std::string(Key) +
                                 "' section from vocabulary");

  // An empty section has no first entry to take the dimension from;
  // dereferencing begin() below would be undefined behaviour.
  if (TargetVocab.empty())
    return createStringError(errc::illegal_byte_sequence,
                             "'" + std::string(Key) +
                                 "' section of the vocabulary is empty");

  Dim = TargetVocab.begin()->second.size();
  if (Dim == 0)
    return createStringError(errc::illegal_byte_sequence,
                             "Dimension of '" + std::string(Key) +
                                 "' section of the vocabulary is zero");

  // Bind to the map's own value_type: naming a different pair type here would
  // materialise a converted temporary (including an Embedding copy) for every
  // entry.
  if (!std::all_of(TargetVocab.begin(), TargetVocab.end(),
                   [Dim](const auto &Entry) {
                     return Entry.second.size() == Dim;
                   }))
    return createStringError(
        errc::illegal_byte_sequence,
        "All vectors in the '" + std::string(Key) +
            "' section of the vocabulary are not of the same dimension");

  return Error::success();
}
//===----------------------------------------------------------------------===//
// Vocabulary
//===----------------------------------------------------------------------===//
/// Maps an instruction opcode number to the opcode name listed for it in
/// Instruction.def. \p Opcode must lie in [1, MaxOpcodes] (checked by
/// assertion only); "UnknownOpcode" is returned for any unlisted number.
StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
  assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
  switch (Opcode) {
#define HANDLE_INST(NUM, OPCODE, CLASS)                                        \
  case NUM:                                                                    \
    return #OPCODE;
#include "llvm/IR/Instruction.def"
#undef HANDLE_INST
  default:
    return "UnknownOpcode";
  }
}
// Classifies an operand into an OperandKind. The checks are ordered:
// functions first, then anything of pointer type, then remaining constants;
// everything else is treated as a variable.
Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
  if (isa<Function>(Op))
    return OperandKind::FunctionID;

  const bool HasPointerType = isa<PointerType>(Op->getType());
  if (HasPointerType)
    return OperandKind::PointerID;

  return isa<Constant>(Op) ? OperandKind::ConstantID : OperandKind::VariableID;
}
/// Maps a compare predicate to a dense index: FCMP predicates come first
/// (offset from FIRST_FCMP_PREDICATE), all other predicates follow (offset
/// from FIRST_ICMP_PREDICATE plus the number of FCMP predicates).
unsigned Vocabulary::getPredicateLocalIndex(CmpInst::Predicate P) {
  const bool IsFCmp =
      P >= CmpInst::FIRST_FCMP_PREDICATE && P <= CmpInst::LAST_FCMP_PREDICATE;
  if (IsFCmp)
    return P - CmpInst::FIRST_FCMP_PREDICATE;

  // Number of FCMP predicates that precede the ICMP ones in the dense range.
  const auto NumFCmpPredicates =
      CmpInst::LAST_FCMP_PREDICATE - CmpInst::FIRST_FCMP_PREDICATE + 1;
  return P - CmpInst::FIRST_ICMP_PREDICATE + NumFCmpPredicates;
}
/// Inverse of getPredicateLocalIndex: indices below the number of FCMP
/// predicates map into the FCMP range, all remaining indices map into the
/// ICMP range.
CmpInst::Predicate Vocabulary::getPredicateFromLocalIndex(unsigned LocalIndex) {
  const unsigned NumFCmpPredicates =
      CmpInst::LAST_FCMP_PREDICATE - CmpInst::FIRST_FCMP_PREDICATE + 1;
  const unsigned RawPredicate =
      LocalIndex < NumFCmpPredicates
          ? CmpInst::FIRST_FCMP_PREDICATE + LocalIndex
          : CmpInst::FIRST_ICMP_PREDICATE + LocalIndex - NumFCmpPredicates;
  return static_cast<CmpInst::Predicate>(RawPredicate);
}
/// Builds the vocabulary key for \p Pred: "FCMP_" or "ICMP_" followed by the
/// predicate name.
///
/// The returned StringRef refers to a function-local static buffer that is
/// overwritten on every call, so copy the result before calling this function
/// again.
StringRef Vocabulary::getVocabKeyForPredicate(CmpInst::Predicate Pred) {
  static SmallString<16> PredNameBuffer;
  const StringRef Prefix =
      Pred < CmpInst::FIRST_ICMP_PREDICATE ? "FCMP_" : "ICMP_";
  PredNameBuffer = Prefix;
  PredNameBuffer += CmpInst::getPredicateName(Pred);
  return PredNameBuffer;
}
/// Returns the textual vocabulary key for flat position \p Pos. Positions are
/// laid out as opcodes, then canonical types, then operand kinds, then
/// predicates. \p Pos must be below NumCanonicalEntries (checked by assertion
/// only).
StringRef Vocabulary::getStringKey(unsigned Pos) {
  assert(Pos < NumCanonicalEntries && "Position out of bounds in vocabulary");

  // Opcode range: opcode numbers are 1-based.
  if (Pos < MaxOpcodes)
    return getVocabKeyForOpcode(Pos + 1);

  // Canonical type range.
  if (Pos < OperandBaseOffset) {
    const auto TypeID = static_cast<CanonicalTypeID>(Pos - MaxOpcodes);
    return getVocabKeyForCanonicalTypeID(TypeID);
  }

  // Operand kind range.
  if (Pos < PredicateBaseOffset) {
    const auto Kind = static_cast<OperandKind>(Pos - OperandBaseOffset);
    return getVocabKeyForOperandKind(Kind);
  }

  // Everything else is a predicate.
  return getVocabKeyForPredicate(getPredicate(Pos - PredicateBaseOffset));
}
// For now, the vocabulary is assumed to be stable unless explicitly
// invalidated: the result is invalidated only when the preservation checker
// for IR2VecVocabAnalysis does not report it as preserved-when-stateless.
bool Vocabulary::invalidate(Module &M, const PreservedAnalyses &PA,
                            ModuleAnalysisManager::Invalidator &Inv) const {
  auto Checker = PA.getChecker<IR2VecVocabAnalysis>();
  const bool Preserved = Checker.preservedWhenStateless();
  return !Preserved;
}
/// Builds a dummy vocabulary with embeddings of dimension \p Dim. Every
/// embedding is filled with a single value that starts at 0.1f and grows by
/// 0.1f from one entry to the next, continuing across sections.
VocabStorage Vocabulary::createDummyVocabForTest(unsigned Dim) {
  float NextVal = 0.1f;

  // Produces one section holding Count embeddings, advancing the running
  // fill value after each entry.
  auto makeSection = [&NextVal, Dim](unsigned Count) {
    std::vector<Embedding> Sec;
    Sec.reserve(Count);
    for (unsigned I = 0; I < Count; ++I) {
      Sec.emplace_back(Dim, NextVal);
      NextVal += 0.1f;
    }
    return Sec;
  };

  // Sections for opcodes, types, operands, and predicates.
  // Order must match Vocabulary::Section enum
  std::vector<std::vector<Embedding>> Sections;
  Sections.reserve(4);
  Sections.push_back(makeSection(MaxOpcodes));          // Opcodes
  Sections.push_back(makeSection(MaxCanonicalTypeIDs)); // Types
  Sections.push_back(makeSection(MaxOperandKinds));     // Operands
  Sections.push_back(makeSection(MaxPredicateKinds));   // Predicates
  return VocabStorage(std::move(Sections));
}
namespace {
using VocabMap = std::map<std::string, Embedding>;

/// Reads the vocabulary JSON at \p VocabFilePath and fills \p OpcVocab,
/// \p TypeVocab and \p ArgVocab from its "Opcodes", "Types" and "Arguments"
/// sections. Fails when the file cannot be read or parsed, when a section
/// fails to parse, or when the sections disagree on the embedding dimension.
Error readVocabularyFromFile(StringRef VocabFilePath, VocabMap &OpcVocab,
                             VocabMap &TypeVocab, VocabMap &ArgVocab) {
  auto BufOrError =
      MemoryBuffer::getFileOrSTDIN(VocabFilePath, /*IsText=*/true);
  if (!BufOrError)
    return createFileError(VocabFilePath, BufOrError.getError());

  auto Content = BufOrError.get()->getBuffer();
  Expected<json::Value> ParsedVocabValue = json::parse(Content);
  if (!ParsedVocabValue)
    return ParsedVocabValue.takeError();

  // Parse the three sections in order, stopping at the first failure.
  unsigned OpcodeDim = 0;
  if (Error E = VocabStorage::parseVocabSection("Opcodes", *ParsedVocabValue,
                                                OpcVocab, OpcodeDim))
    return E;

  unsigned TypeDim = 0;
  if (Error E = VocabStorage::parseVocabSection("Types", *ParsedVocabValue,
                                                TypeVocab, TypeDim))
    return E;

  unsigned ArgDim = 0;
  if (Error E = VocabStorage::parseVocabSection("Arguments", *ParsedVocabValue,
                                                ArgVocab, ArgDim))
    return E;

  // All sections must agree on the embedding dimension.
  if (OpcodeDim != TypeDim || TypeDim != ArgDim)
    return createStringError(errc::illegal_byte_sequence,
                             "Vocabulary sections have different dimensions");
  return Error::success();
}
} // anonymous namespace
/// Generate VocabStorage from vocabulary maps.
///
/// Builds four sections (opcodes, canonical types, operand kinds and
/// predicates). Each entry is looked up by its string key; entries missing
/// from the maps keep a zero vector and bump VocabMissCounter. Predicate keys
/// are looked up in \p ArgVocab. \p OpcVocab must be non-empty because the
/// embedding dimension is taken from its first entry.
VocabStorage Vocabulary::buildVocabStorage(const VocabMap &OpcVocab,
                                           const VocabMap &TypeVocab,
                                           const VocabMap &ArgVocab) {
  // Helper for handling missing entities in the vocabulary.
  // Currently, we use a zero vector. In the future, we will throw an error to
  // ensure that *all* known entities are present in the vocabulary.
  auto handleMissingEntity = [](const std::string &Val) {
    LLVM_DEBUG(errs() << Val
                      << " is not in vocabulary, using zero vector; This "
                         "would result in an error in future.\n");
    ++VocabMissCounter;
  };

  // The dimension is read from the first opcode entry, so make the non-empty
  // precondition explicit before dereferencing begin().
  assert(!OpcVocab.empty() && "Opcode vocabulary must not be empty");
  unsigned Dim = OpcVocab.begin()->second.size();
  assert(Dim > 0 && "Vocabulary dimension must be greater than zero");

  // Handle Opcodes
  std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes,
                                                 Embedding(Dim));
  for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) {
    // Opcode numbers are 1-based. Materialise the key once and reuse it for
    // both the lookup and the miss report.
    const std::string Key = Vocabulary::getVocabKeyForOpcode(Opcode + 1).str();
    if (auto It = OpcVocab.find(Key); It != OpcVocab.end())
      NumericOpcodeEmbeddings[Opcode] = It->second;
    else
      handleMissingEntity(Key);
  }

  // Handle Types - only canonical types are present in vocabulary
  std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxCanonicalTypeIDs,
                                               Embedding(Dim));
  for (unsigned CTypeID : seq(0u, Vocabulary::MaxCanonicalTypeIDs)) {
    const std::string Key =
        Vocabulary::getVocabKeyForCanonicalTypeID(
            static_cast<Vocabulary::CanonicalTypeID>(CTypeID))
            .str();
    if (auto It = TypeVocab.find(Key); It != TypeVocab.end())
      NumericTypeEmbeddings[CTypeID] = It->second;
    else
      handleMissingEntity(Key);
  }

  // Handle Arguments/Operands
  std::vector<Embedding> NumericArgEmbeddings(Vocabulary::MaxOperandKinds,
                                              Embedding(Dim));
  for (unsigned OpKind : seq(0u, Vocabulary::MaxOperandKinds)) {
    Vocabulary::OperandKind Kind = static_cast<Vocabulary::OperandKind>(OpKind);
    const std::string Key = Vocabulary::getVocabKeyForOperandKind(Kind).str();
    if (auto It = ArgVocab.find(Key); It != ArgVocab.end())
      NumericArgEmbeddings[OpKind] = It->second;
    else
      handleMissingEntity(Key);
  }

  // Handle Predicates: part of Operands section. We look up predicate keys
  // in ArgVocab.
  std::vector<Embedding> NumericPredEmbeddings(Vocabulary::MaxPredicateKinds,
                                               Embedding(Dim, 0));
  for (unsigned PK : seq(0u, Vocabulary::MaxPredicateKinds)) {
    // The predicate key refers to a buffer that is reused between calls, so
    // copy it into an owned string right away.
    const std::string Key =
        Vocabulary::getVocabKeyForPredicate(Vocabulary::getPredicate(PK)).str();
    if (auto It = ArgVocab.find(Key); It != ArgVocab.end())
      NumericPredEmbeddings[PK] = It->second;
    else
      handleMissingEntity(Key);
  }

  // Create section-based storage instead of flat vocabulary
  // Order must match Vocabulary::Section enum
  std::vector<std::vector<Embedding>> Sections(4);
  Sections[static_cast<unsigned>(Section::Opcodes)] =
      std::move(NumericOpcodeEmbeddings); // Section::Opcodes
  Sections[static_cast<unsigned>(Section::CanonicalTypes)] =
      std::move(NumericTypeEmbeddings); // Section::CanonicalTypes
  Sections[static_cast<unsigned>(Section::Operands)] =
      std::move(NumericArgEmbeddings); // Section::Operands
  Sections[static_cast<unsigned>(Section::Predicates)] =
      std::move(NumericPredEmbeddings); // Section::Predicates

  // Create VocabStorage from organized sections
  return VocabStorage(std::move(Sections));
}
//===----------------------------------------------------------------------===//
// Vocabulary
//===----------------------------------------------------------------------===//
/// Loads a vocabulary from the JSON file at \p VocabFilePath, scales the
/// opcode, type and argument sections by \p OpcWeight, \p TypeWeight and
/// \p ArgWeight respectively, and builds the Vocabulary from the result.
/// \returns the error from reading/parsing the file on failure.
Expected<Vocabulary> Vocabulary::fromFile(StringRef VocabFilePath,
                                          float OpcWeight, float TypeWeight,
                                          float ArgWeight) {
  VocabMap OpcVocab, TypeVocab, ArgVocab;
  if (Error Err =
          readVocabularyFromFile(VocabFilePath, OpcVocab, TypeVocab, ArgVocab))
    return std::move(Err);

  // Scale each vocabulary section by its weight, in opcode/type/argument
  // order.
  const std::pair<VocabMap *, float> WeightedSections[] = {
      {&OpcVocab, OpcWeight}, {&TypeVocab, TypeWeight}, {&ArgVocab, ArgWeight}};
  for (const auto &Item : WeightedSections)
    for (auto &Entry : *Item.first)
      Entry.second *= Item.second;

  // Generate the numeric lookup vocabulary
  return Vocabulary(buildVocabStorage(OpcVocab, TypeVocab, ArgVocab));
}
//===----------------------------------------------------------------------===//
// IR2VecVocabAnalysis
//===----------------------------------------------------------------------===//
/// Consumes \p Err and reports each contained error through \p Ctx with an
/// "Error reading vocabulary: " prefix.
void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
  handleAllErrors(std::move(Err), [&Ctx](const ErrorInfoBase &Info) {
    Ctx.emitError("Error reading vocabulary: " + Info.message());
  });
}
/// Produces the vocabulary result for \p M. A vocabulary already held in
/// Vocab is used directly (it is moved out of the member); otherwise the
/// vocabulary is loaded from the file named by the ir2vec-vocab-path option.
/// On failure an error is emitted on the module's context and a
/// default-constructed Vocabulary is returned.
IR2VecVocabAnalysis::Result
IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
  LLVMContext &Ctx = M.getContext();

  // If vocabulary is already populated by the constructor, use it.
  if (Vocab.has_value())
    return Vocabulary(std::move(Vocab.value()));

  // Otherwise, try to read from the vocabulary file specified via CLI.
  if (VocabFile.empty()) {
    // FIXME: Use default vocabulary
    Ctx.emitError("IR2Vec vocabulary file path not specified; You may need to "
                  "set it using --ir2vec-vocab-path");
    return Vocabulary(); // Return invalid result
  }

  // Use the static factory method to load the vocabulary.
  auto VocabOrErr =
      Vocabulary::fromFile(VocabFile, OpcWeight, TypeWeight, ArgWeight);
  if (VocabOrErr)
    return std::move(*VocabOrErr);

  emitError(VocabOrErr.takeError(), Ctx);
  return Vocabulary();
}
//===----------------------------------------------------------------------===//
// Printer Passes
//===----------------------------------------------------------------------===//
/// Prints, for every function in \p M, its function vector, the vector of
/// each basic block and the vector of each instruction, using the embedder
/// kind selected by IR2VecEmbeddingKind. All analyses are preserved.
PreservedAnalyses IR2VecPrinterPass::run(Module &M,
                                         ModuleAnalysisManager &MAM) {
  auto &VocabResult = MAM.getResult<IR2VecVocabAnalysis>(M);
  assert(VocabResult.isValid() && "IR2Vec Vocabulary is invalid");

  for (Function &F : M) {
    auto Emb = Embedder::create(IR2VecEmbeddingKind, F, VocabResult);
    if (!Emb) {
      OS << "Error creating IR2Vec embeddings \n";
      continue;
    }

    // Function-level vector.
    OS << "IR2Vec embeddings for function " << F.getName() << ":\n";
    OS << "Function vector: ";
    Emb->getFunctionVector().print(OS);

    // Per-basic-block vectors.
    OS << "Basic block vectors:\n";
    for (const BasicBlock &Block : F) {
      OS << "Basic block: " << Block.getName() << ":\n";
      Emb->getBBVector(Block).print(OS);
    }

    // Per-instruction vectors, block by block.
    OS << "Instruction vectors:\n";
    for (const BasicBlock &Block : F)
      for (const Instruction &Inst : Block) {
        OS << "Instruction: ";
        Inst.print(OS);
        Emb->getInstVector(Inst).print(OS);
      }
  }
  return PreservedAnalyses::all();
}
/// Prints every vocabulary entry as "Key: <string key>: " followed by its
/// embedding, in iteration order. All analyses are preserved.
PreservedAnalyses IR2VecVocabPrinterPass::run(Module &M,
                                              ModuleAnalysisManager &MAM) {
  auto &IR2VecVocabulary = MAM.getResult<IR2VecVocabAnalysis>(M);
  assert(IR2VecVocabulary.isValid() && "IR2Vec Vocabulary is invalid");

  // The flat position of an entry is its ordinal in iteration order.
  unsigned Pos = 0;
  for (const auto &Entry : IR2VecVocabulary) {
    OS << "Key: " << IR2VecVocabulary.getStringKey(Pos) << ": ";
    Entry.print(OS);
    ++Pos;
  }
  return PreservedAnalyses::all();
}