The functions instructionsWithoutDebug/sizeWithoutDebug serve two purposes: skipping debug intrinsics and skipping pseudo instructions. However, they are expensive because the filtering is done out of line through a std::function. Ideally the filter would be inlined, but that would require including IntrinsicInst.h in BasicBlock.h. Since debug intrinsics are no longer used, the first purpose (parameter false) is no longer needed. The second purpose is still needed in some places, but in many cases it is easier to filter out PseudoProbe instructions directly at the call sites. Therefore, remove instructionsWithoutDebug/sizeWithoutDebug. Compile-time impact (c-t-t, stage2-O3): -0.21%.
718 lines
26 KiB
C++
718 lines
26 KiB
C++
//===- IR2Vec.cpp - Implementation of IR2Vec -----------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
|
|
// Exceptions. See the LICENSE file for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
///
|
|
/// \file
|
|
/// This file implements the IR2Vec algorithm.
|
|
///
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "llvm/Analysis/IR2Vec.h"
|
|
|
|
#include "llvm/ADT/DepthFirstIterator.h"
|
|
#include "llvm/ADT/Sequence.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
#include "llvm/ADT/Statistic.h"
|
|
#include "llvm/IR/CFG.h"
|
|
#include "llvm/IR/Module.h"
|
|
#include "llvm/IR/PassManager.h"
|
|
#include "llvm/Support/Debug.h"
|
|
#include "llvm/Support/Errc.h"
|
|
#include "llvm/Support/Error.h"
|
|
#include "llvm/Support/ErrorHandling.h"
|
|
#include "llvm/Support/Format.h"
|
|
#include "llvm/Support/MemoryBuffer.h"
|
|
|
|
using namespace llvm;
|
|
using namespace ir2vec;
|
|
|
|
#define DEBUG_TYPE "ir2vec"
|
|
|
|
STATISTIC(VocabMissCounter,
|
|
"Number of lookups to entities not present in the vocabulary");
|
|
|
|
namespace llvm {
|
|
namespace ir2vec {
|
|
cl::OptionCategory IR2VecCategory("IR2Vec Options");
|
|
|
|
// FIXME: Use a default vocab when not specified
|
|
cl::opt<std::string>
|
|
VocabFile("ir2vec-vocab-path", cl::Optional,
|
|
cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(""),
|
|
cl::cat(IR2VecCategory));
|
|
cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional, cl::init(1.0),
|
|
cl::desc("Weight for opcode embeddings"),
|
|
cl::cat(IR2VecCategory));
|
|
cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional, cl::init(0.5),
|
|
cl::desc("Weight for type embeddings"),
|
|
cl::cat(IR2VecCategory));
|
|
cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional, cl::init(0.2),
|
|
cl::desc("Weight for argument embeddings"),
|
|
cl::cat(IR2VecCategory));
|
|
cl::opt<IR2VecKind> IR2VecEmbeddingKind(
|
|
"ir2vec-kind", cl::Optional,
|
|
cl::values(clEnumValN(IR2VecKind::Symbolic, "symbolic",
|
|
"Generate symbolic embeddings"),
|
|
clEnumValN(IR2VecKind::FlowAware, "flow-aware",
|
|
"Generate flow-aware embeddings")),
|
|
cl::init(IR2VecKind::Symbolic), cl::desc("IR2Vec embedding kind"),
|
|
cl::cat(IR2VecCategory));
|
|
|
|
} // namespace ir2vec
|
|
} // namespace llvm
|
|
|
|
// Key object the analysis manager uses to identify IR2VecVocabAnalysis.
AnalysisKey IR2VecVocabAnalysis::Key;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Local helper functions
|
|
//===----------------------------------------------------------------------===//
|
|
namespace llvm::json {
|
|
inline bool fromJSON(const llvm::json::Value &E, Embedding &Out,
|
|
llvm::json::Path P) {
|
|
std::vector<double> TempOut;
|
|
if (!llvm::json::fromJSON(E, TempOut, P))
|
|
return false;
|
|
Out = Embedding(std::move(TempOut));
|
|
return true;
|
|
}
|
|
} // namespace llvm::json
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Embedding
|
|
//===----------------------------------------------------------------------===//
|
|
/// Adds \p RHS to this embedding element by element. Both operands must have
/// the same number of elements.
Embedding &Embedding::operator+=(const Embedding &RHS) {
  assert(this->size() == RHS.size() && "Vectors must have the same dimension");
  for (size_t Idx = 0, End = this->size(); Idx != End; ++Idx)
    (*this)[Idx] += RHS[Idx];
  return *this;
}
|
|
|
|
/// Returns the element-wise sum of this embedding and \p RHS, leaving both
/// operands unchanged.
Embedding Embedding::operator+(const Embedding &RHS) const {
  Embedding Sum = *this;
  Sum += RHS;
  return Sum;
}
|
|
|
|
/// Subtracts \p RHS from this embedding element by element. Both operands
/// must have the same number of elements.
Embedding &Embedding::operator-=(const Embedding &RHS) {
  assert(this->size() == RHS.size() && "Vectors must have the same dimension");
  for (size_t Idx = 0, End = this->size(); Idx != End; ++Idx)
    (*this)[Idx] -= RHS[Idx];
  return *this;
}
|
|
|
|
/// Returns the element-wise difference `*this - RHS`, leaving both operands
/// unchanged.
Embedding Embedding::operator-(const Embedding &RHS) const {
  Embedding Difference = *this;
  Difference -= RHS;
  return Difference;
}
|
|
|
|
/// Multiplies every element of this embedding by \p Factor in place.
Embedding &Embedding::operator*=(double Factor) {
  for (auto &Elem : *this)
    Elem *= Factor;
  return *this;
}
|
|
|
|
/// Returns a copy of this embedding with every element multiplied by
/// \p Factor.
Embedding Embedding::operator*(double Factor) const {
  Embedding Scaled = *this;
  Scaled *= Factor;
  return Scaled;
}
|
|
|
|
/// Accumulates `Src * Factor` into this embedding element by element
/// (an axpy-style update). Both embeddings must have the same size.
Embedding &Embedding::scaleAndAdd(const Embedding &Src, float Factor) {
  assert(this->size() == Src.size() && "Vectors must have the same dimension");
  std::transform(this->begin(), this->end(), Src.begin(), this->begin(),
                 [Factor](double Acc, double Value) {
                   return Acc + Value * Factor;
                 });
  return *this;
}
|
|
|
|
bool Embedding::approximatelyEquals(const Embedding &RHS,
|
|
double Tolerance) const {
|
|
assert(this->size() == RHS.size() && "Vectors must have the same dimension");
|
|
for (size_t Itr = 0; Itr < this->size(); ++Itr)
|
|
if (std::abs((*this)[Itr] - RHS[Itr]) > Tolerance) {
|
|
LLVM_DEBUG(errs() << "Embedding mismatch at index " << Itr << ": "
|
|
<< (*this)[Itr] << " vs " << RHS[Itr]
|
|
<< "; Tolerance: " << Tolerance << "\n");
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
/// Prints the embedding as ` [ v0  v1  ... ]` followed by a newline, with
/// each element formatted to two decimal places.
void Embedding::print(raw_ostream &OS) const {
  OS << " [";
  for (const auto &Value : Data)
    OS << ' ' << format("%.2f", Value) << ' ';
  OS << "]\n";
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Embedder and its subclasses
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Factory for the concrete Embedder matching \p Mode. Returns nullptr for a
/// mode that is not handled below.
std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
                                           const Vocabulary &Vocab) {
  std::unique_ptr<Embedder> Result;
  // No default label, so the compiler can flag unhandled IR2VecKind values.
  switch (Mode) {
  case IR2VecKind::Symbolic:
    Result = std::make_unique<SymbolicEmbedder>(F, Vocab);
    break;
  case IR2VecKind::FlowAware:
    Result = std::make_unique<FlowAwareEmbedder>(F, Vocab);
    break;
  }
  return Result;
}
|
|
|
|
/// Computes the function-level embedding as the sum of the embeddings of all
/// basic blocks reachable from the entry block. Declarations (no body) yield
/// the zero vector.
Embedding Embedder::computeEmbeddings() const {
  Embedding Total(Dimension, 0.0);
  if (!F.isDeclaration()) {
    // Depth-first traversal from the entry visits only reachable blocks.
    for (const BasicBlock *Block : depth_first(&F))
      Total += computeEmbeddings(*Block);
  }
  return Total;
}
|
|
|
|
/// Computes a basic block's embedding as the sum of its instructions'
/// embeddings, ignoring debug and pseudo instructions.
Embedding Embedder::computeEmbeddings(const BasicBlock &BB) const {
  Embedding Sum(Dimension, 0);
  for (const auto &Inst : BB) {
    if (Inst.isDebugOrPseudoInst())
      continue;
    Sum += computeEmbeddings(Inst);
  }
  return Sum;
}
|
|
|
|
/// Symbolic instruction embedding:
///   opcode + result type + sum(operand vocabulary entries)
/// plus the predicate embedding for comparison instructions.
Embedding SymbolicEmbedder::computeEmbeddings(const Instruction &I) const {
  // Symbolic embeddings are recomputed on every call; doing so is cheaper
  // than caching the vectors.
  Embedding OperandSum(Dimension, 0);
  for (const auto &Operand : I.operands())
    OperandSum += Vocab[*Operand];

  Embedding Result =
      Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + OperandSum;
  if (const auto *Cmp = dyn_cast<CmpInst>(&I))
    Result += Vocab[Cmp->getPredicate()];
  return Result;
}
|
|
|
|
/// Flow-aware instruction embedding. Like the symbolic embedding, except that
/// an operand defined by an already-embedded instruction contributes that
/// instruction's embedding instead of its vocabulary entry. Results are
/// memoized in InstVecMap.
Embedding FlowAwareEmbedder::computeEmbeddings(const Instruction &I) const {
  if (auto Cached = InstVecMap.find(&I); Cached != InstVecMap.end())
    return Cached->second;

  // TODO: Handle call instructions differently.
  // For now, we treat them like other instructions.
  Embedding OperandSum(Dimension, 0);
  for (const auto &Op : I.operands()) {
    const auto *DefInst = dyn_cast<Instruction>(Op);
    if (!DefInst) {
      // Operands not produced by an instruction (e.g. constants or function
      // arguments) always use their vocabulary embedding.
      LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: "
                        << *Op << "=" << Vocab[*Op][0] << "\n");
      OperandSum += Vocab[*Op];
      continue;
    }

    // Fixme (#159171): Ideally we should never miss an instruction embedding
    // here. But with cyclic dependencies (e.g., phi nodes) the defining
    // instruction may not have been embedded yet, so we fall back to the
    // vocabulary embedding. This can be fixed by iterating to a fixed-point,
    // or by using a simple solver for the set of simultaneous equations.
    // A miss can also happen when the operand is defined in a basic block
    // that has not been processed yet; processing the blocks in topological
    // order would fix that.
    auto DefEmb = InstVecMap.find(DefInst);
    if (DefEmb == InstVecMap.end())
      OperandSum += Vocab[*Op];
    else
      OperandSum += DefEmb->second;
  }

  // opcode + result type + operands, plus the predicate for comparisons.
  Embedding Result =
      Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + OperandSum;
  if (const auto *Cmp = dyn_cast<CmpInst>(&I))
    Result += Vocab[Cmp->getPredicate()];

  InstVecMap[&I] = Result;
  return Result;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// VocabStorage
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
VocabStorage::VocabStorage(std::vector<std::vector<Embedding>> &&SectionData)
|
|
: Sections(std::move(SectionData)), TotalSize([&] {
|
|
assert(!Sections.empty() && "Vocabulary has no sections");
|
|
// Compute total size across all sections
|
|
size_t Size = 0;
|
|
for (const auto &Section : Sections) {
|
|
assert(!Section.empty() && "Vocabulary section is empty");
|
|
Size += Section.size();
|
|
}
|
|
return Size;
|
|
}()),
|
|
Dimension([&] {
|
|
// Get dimension from the first embedding in the first section - all
|
|
// embeddings must have the same dimension
|
|
assert(!Sections.empty() && "Vocabulary has no sections");
|
|
assert(!Sections[0].empty() && "First section of vocabulary is empty");
|
|
unsigned ExpectedDim = static_cast<unsigned>(Sections[0][0].size());
|
|
|
|
// Verify that all embeddings across all sections have the same
|
|
// dimension
|
|
[[maybe_unused]] auto allSameDim =
|
|
[ExpectedDim](const std::vector<Embedding> &Section) {
|
|
return std::all_of(Section.begin(), Section.end(),
|
|
[ExpectedDim](const Embedding &Emb) {
|
|
return Emb.size() == ExpectedDim;
|
|
});
|
|
};
|
|
assert(std::all_of(Sections.begin(), Sections.end(), allSameDim) &&
|
|
"All embeddings must have the same dimension");
|
|
|
|
return ExpectedDim;
|
|
}()) {}
|
|
|
|
/// Returns the embedding at (SectionId, LocalIndex) in the underlying storage.
const Embedding &VocabStorage::const_iterator::operator*() const {
  assert(SectionId < Storage->Sections.size() && "Invalid section ID");
  const auto &CurrentSection = Storage->Sections[SectionId];
  assert(LocalIndex < CurrentSection.size() && "Local index out of range");
  return CurrentSection[LocalIndex];
}
|
|
|
|
/// Advances to the next embedding. After the last entry of a section, the
/// iterator moves to index 0 of the following section; past the final
/// section it becomes (getNumSections(), 0).
VocabStorage::const_iterator &VocabStorage::const_iterator::operator++() {
  ++LocalIndex;
  const bool InSection = SectionId < Storage->getNumSections();
  if (!InSection || LocalIndex < Storage->Sections[SectionId].size())
    return *this;

  assert(LocalIndex == Storage->Sections[SectionId].size() &&
         "Local index should be at the end of the current section");
  LocalIndex = 0;
  ++SectionId;
  return *this;
}
|
|
|
|
bool VocabStorage::const_iterator::operator==(
|
|
const const_iterator &Other) const {
|
|
return Storage == Other.Storage && SectionId == Other.SectionId &&
|
|
LocalIndex == Other.LocalIndex;
|
|
}
|
|
|
|
bool VocabStorage::const_iterator::operator!=(
|
|
const const_iterator &Other) const {
|
|
return !(*this == Other);
|
|
}
|
|
|
|
/// Parses the object stored under \p Key in the vocabulary JSON root into
/// \p TargetVocab, and sets \p Dim to the common embedding dimension.
///
/// Returns an error when:
///   - the root is not a JSON object,
///   - the \p Key section is missing or cannot be decoded,
///   - the section has no entries,
///   - the embeddings have dimension zero, or
///   - the embeddings do not all share one dimension.
Error VocabStorage::parseVocabSection(StringRef Key,
                                      const json::Value &ParsedVocabValue,
                                      VocabMap &TargetVocab, unsigned &Dim) {
  json::Path::Root Path("");
  const json::Object *RootObj = ParsedVocabValue.getAsObject();
  if (!RootObj)
    return createStringError(errc::invalid_argument,
                             "JSON root is not an object");

  const json::Value *SectionValue = RootObj->get(Key);
  if (!SectionValue)
    return createStringError(errc::invalid_argument,
                             "Missing '" + std::string(Key) +
                                 "' section in vocabulary file");
  if (!json::fromJSON(*SectionValue, TargetVocab, Path))
    return createStringError(errc::illegal_byte_sequence,
                             "Unable to parse '" + std::string(Key) +
                                 "' section from vocabulary");

  // The dimension is taken from the first entry, so an empty section (e.g.
  // `"Opcodes": {}`) must be rejected before dereferencing begin(); on an
  // empty map begin() == end() and dereferencing it is undefined behavior.
  if (TargetVocab.empty())
    return createStringError(errc::illegal_byte_sequence,
                             "'" + std::string(Key) +
                                 "' section of the vocabulary is empty");

  Dim = TargetVocab.begin()->second.size();
  if (Dim == 0)
    return createStringError(errc::illegal_byte_sequence,
                             "Dimension of '" + std::string(Key) +
                                 "' section of the vocabulary is zero");

  // Bind each map entry by reference. A `std::pair<StringRef, Embedding>`
  // parameter would materialize a converted temporary and copy every
  // embedding vector just to read its size.
  if (!std::all_of(TargetVocab.begin(), TargetVocab.end(),
                   [Dim](const auto &Entry) {
                     return Entry.second.size() == Dim;
                   }))
    return createStringError(
        errc::illegal_byte_sequence,
        "All vectors in the '" + std::string(Key) +
            "' section of the vocabulary are not of the same dimension");

  return Error::success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Vocabulary
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Returns the vocabulary key (the opcode's name as spelled in
/// Instruction.def) for a 1-based \p Opcode, or "UnknownOpcode" when no
/// instruction uses that number.
StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
  assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
  switch (Opcode) {
#define HANDLE_INST(NUM, OPCODE, CLASS)                                        \
  case NUM:                                                                    \
    return #OPCODE;
#include "llvm/IR/Instruction.def"
#undef HANDLE_INST
  default:
    break;
  }
  return "UnknownOpcode";
}
|
|
|
|
// Helper function to classify an operand into OperandKind
|
|
Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
|
|
if (isa<Function>(Op))
|
|
return OperandKind::FunctionID;
|
|
if (isa<PointerType>(Op->getType()))
|
|
return OperandKind::PointerID;
|
|
if (isa<Constant>(Op))
|
|
return OperandKind::ConstantID;
|
|
return OperandKind::VariableID;
|
|
}
|
|
|
|
/// Maps a compare predicate to its index within the predicate section.
/// FCmp predicates take the first indices; ICmp predicates follow directly
/// after them.
unsigned Vocabulary::getPredicateLocalIndex(CmpInst::Predicate P) {
  const unsigned NumFCmpPreds =
      CmpInst::LAST_FCMP_PREDICATE - CmpInst::FIRST_FCMP_PREDICATE + 1;
  const bool IsFCmp =
      P >= CmpInst::FIRST_FCMP_PREDICATE && P <= CmpInst::LAST_FCMP_PREDICATE;
  if (IsFCmp)
    return P - CmpInst::FIRST_FCMP_PREDICATE;
  return P - CmpInst::FIRST_ICMP_PREDICATE + NumFCmpPreds;
}
|
|
|
|
/// Inverse of getPredicateLocalIndex. Indices below the FCmp count map to
/// FCmp predicates; the remaining indices map to ICmp predicates.
CmpInst::Predicate Vocabulary::getPredicateFromLocalIndex(unsigned LocalIndex) {
  const unsigned NumFCmpPreds =
      CmpInst::LAST_FCMP_PREDICATE - CmpInst::FIRST_FCMP_PREDICATE + 1;
  const unsigned RawPredicate =
      LocalIndex < NumFCmpPreds
          ? CmpInst::FIRST_FCMP_PREDICATE + LocalIndex
          : CmpInst::FIRST_ICMP_PREDICATE + LocalIndex - NumFCmpPreds;
  return static_cast<CmpInst::Predicate>(RawPredicate);
}
|
|
|
|
/// Returns the vocabulary key for a compare predicate: "FCMP_" or "ICMP_"
/// followed by CmpInst::getPredicateName(Pred).
///
/// The returned StringRef points into storage that is built once and never
/// modified afterwards. It therefore stays valid for the rest of the program
/// and is not overwritten by later calls. The previous implementation
/// returned a view into a single shared static buffer: every call clobbered
/// the text of previously returned keys, and concurrent callers raced on it.
StringRef Vocabulary::getVocabKeyForPredicate(CmpInst::Predicate Pred) {
  // Indexed by the raw predicate value. Function-local static initialization
  // is thread-safe, and the table is read-only afterwards.
  static const std::vector<std::string> KeysByPredicate = [] {
    std::vector<std::string> Keys;
    Keys.reserve(CmpInst::LAST_ICMP_PREDICATE + 1);
    for (unsigned Raw = 0; Raw <= CmpInst::LAST_ICMP_PREDICATE; ++Raw) {
      const auto P = static_cast<CmpInst::Predicate>(Raw);
      std::string Key = P < CmpInst::FIRST_ICMP_PREDICATE ? "FCMP_" : "ICMP_";
      Key += CmpInst::getPredicateName(P).str();
      Keys.push_back(std::move(Key));
    }
    return Keys;
  }();

  if (static_cast<unsigned>(Pred) < KeysByPredicate.size())
    return KeysByPredicate[Pred];
  // Values past the last ICmp predicate are not named by
  // CmpInst::getPredicateName, which spells them "unknown"; keep the same
  // prefixed spelling.
  return Pred < CmpInst::FIRST_ICMP_PREDICATE ? "FCMP_unknown"
                                              : "ICMP_unknown";
}
|
|
|
|
/// Returns the textual vocabulary key for a flat position.
///
/// Positions are laid out section by section: opcodes, canonical types,
/// operand kinds, then compare predicates. Each branch below turns the flat
/// position into an index local to its section.
StringRef Vocabulary::getStringKey(unsigned Pos) {
  assert(Pos < NumCanonicalEntries && "Position out of bounds in vocabulary");
  if (Pos < MaxOpcodes)
    return getVocabKeyForOpcode(Pos + 1); // Opcodes are numbered from 1.

  if (Pos < OperandBaseOffset) {
    const auto TypeID = static_cast<CanonicalTypeID>(Pos - MaxOpcodes);
    return getVocabKeyForCanonicalTypeID(TypeID);
  }

  if (Pos < PredicateBaseOffset) {
    const auto Kind = static_cast<OperandKind>(Pos - OperandBaseOffset);
    return getVocabKeyForOperandKind(Kind);
  }

  const unsigned PredIndex = Pos - PredicateBaseOffset;
  return getVocabKeyForPredicate(getPredicate(PredIndex));
}
|
|
|
|
// For now, assume vocabulary is stable unless explicitly invalidated.
|
|
bool Vocabulary::invalidate(Module &M, const PreservedAnalyses &PA,
|
|
ModuleAnalysisManager::Invalidator &Inv) const {
|
|
auto PAC = PA.getChecker<IR2VecVocabAnalysis>();
|
|
return !(PAC.preservedWhenStateless());
|
|
}
|
|
|
|
/// Builds a synthetic vocabulary for tests. Each entry is a vector of \p Dim
/// copies of one value. The values are 0.1, 0.2, 0.3, ... (accumulated in
/// float), assigned in storage order across all four sections.
VocabStorage Vocabulary::createDummyVocabForTest(unsigned Dim) {
  float NextVal = 0.1f;
  // Produces Count consecutive entries, continuing the shared value sequence.
  auto makeSection = [&NextVal, Dim](unsigned Count) {
    std::vector<Embedding> Entries;
    Entries.reserve(Count);
    while (Entries.size() < Count) {
      Entries.emplace_back(Dim, NextVal);
      NextVal += 0.1f;
    }
    return Entries;
  };

  // Order must match Vocabulary::Section enum.
  std::vector<std::vector<Embedding>> Sections;
  Sections.reserve(4);
  Sections.push_back(makeSection(MaxOpcodes));
  Sections.push_back(makeSection(MaxCanonicalTypeIDs));
  Sections.push_back(makeSection(MaxOperandKinds));
  Sections.push_back(makeSection(MaxPredicateKinds));
  return VocabStorage(std::move(Sections));
}
|
|
|
|
namespace {
|
|
using VocabMap = std::map<std::string, Embedding>;
|
|
|
|
/// Read vocabulary JSON file and populate the section maps.
|
|
Error readVocabularyFromFile(StringRef VocabFilePath, VocabMap &OpcVocab,
|
|
VocabMap &TypeVocab, VocabMap &ArgVocab) {
|
|
auto BufOrError =
|
|
MemoryBuffer::getFileOrSTDIN(VocabFilePath, /*IsText=*/true);
|
|
if (!BufOrError)
|
|
return createFileError(VocabFilePath, BufOrError.getError());
|
|
|
|
auto Content = BufOrError.get()->getBuffer();
|
|
|
|
Expected<json::Value> ParsedVocabValue = json::parse(Content);
|
|
if (!ParsedVocabValue)
|
|
return ParsedVocabValue.takeError();
|
|
|
|
unsigned OpcodeDim = 0, TypeDim = 0, ArgDim = 0;
|
|
if (auto Err = VocabStorage::parseVocabSection("Opcodes", *ParsedVocabValue,
|
|
OpcVocab, OpcodeDim))
|
|
return Err;
|
|
|
|
if (auto Err = VocabStorage::parseVocabSection("Types", *ParsedVocabValue,
|
|
TypeVocab, TypeDim))
|
|
return Err;
|
|
|
|
if (auto Err = VocabStorage::parseVocabSection("Arguments", *ParsedVocabValue,
|
|
ArgVocab, ArgDim))
|
|
return Err;
|
|
|
|
if (!(OpcodeDim == TypeDim && TypeDim == ArgDim))
|
|
return createStringError(errc::illegal_byte_sequence,
|
|
"Vocabulary sections have different dimensions");
|
|
|
|
return Error::success();
|
|
}
|
|
} // anonymous namespace
|
|
|
|
/// Generate VocabStorage from vocabulary maps.
|
|
VocabStorage Vocabulary::buildVocabStorage(const VocabMap &OpcVocab,
|
|
const VocabMap &TypeVocab,
|
|
const VocabMap &ArgVocab) {
|
|
|
|
// Helper for handling missing entities in the vocabulary.
|
|
// Currently, we use a zero vector. In the future, we will throw an error to
|
|
// ensure that *all* known entities are present in the vocabulary.
|
|
auto handleMissingEntity = [](const std::string &Val) {
|
|
LLVM_DEBUG(errs() << Val
|
|
<< " is not in vocabulary, using zero vector; This "
|
|
"would result in an error in future.\n");
|
|
++VocabMissCounter;
|
|
};
|
|
|
|
unsigned Dim = OpcVocab.begin()->second.size();
|
|
assert(Dim > 0 && "Vocabulary dimension must be greater than zero");
|
|
|
|
// Handle Opcodes
|
|
std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes,
|
|
Embedding(Dim));
|
|
for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) {
|
|
StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1);
|
|
auto It = OpcVocab.find(VocabKey.str());
|
|
if (It != OpcVocab.end())
|
|
NumericOpcodeEmbeddings[Opcode] = It->second;
|
|
else
|
|
handleMissingEntity(VocabKey.str());
|
|
}
|
|
|
|
// Handle Types - only canonical types are present in vocabulary
|
|
std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxCanonicalTypeIDs,
|
|
Embedding(Dim));
|
|
for (unsigned CTypeID : seq(0u, Vocabulary::MaxCanonicalTypeIDs)) {
|
|
StringRef VocabKey = Vocabulary::getVocabKeyForCanonicalTypeID(
|
|
static_cast<Vocabulary::CanonicalTypeID>(CTypeID));
|
|
if (auto It = TypeVocab.find(VocabKey.str()); It != TypeVocab.end()) {
|
|
NumericTypeEmbeddings[CTypeID] = It->second;
|
|
continue;
|
|
}
|
|
handleMissingEntity(VocabKey.str());
|
|
}
|
|
|
|
// Handle Arguments/Operands
|
|
std::vector<Embedding> NumericArgEmbeddings(Vocabulary::MaxOperandKinds,
|
|
Embedding(Dim));
|
|
for (unsigned OpKind : seq(0u, Vocabulary::MaxOperandKinds)) {
|
|
Vocabulary::OperandKind Kind = static_cast<Vocabulary::OperandKind>(OpKind);
|
|
StringRef VocabKey = Vocabulary::getVocabKeyForOperandKind(Kind);
|
|
auto It = ArgVocab.find(VocabKey.str());
|
|
if (It != ArgVocab.end()) {
|
|
NumericArgEmbeddings[OpKind] = It->second;
|
|
continue;
|
|
}
|
|
handleMissingEntity(VocabKey.str());
|
|
}
|
|
|
|
// Handle Predicates: part of Operands section. We look up predicate keys
|
|
// in ArgVocab.
|
|
std::vector<Embedding> NumericPredEmbeddings(Vocabulary::MaxPredicateKinds,
|
|
Embedding(Dim, 0));
|
|
for (unsigned PK : seq(0u, Vocabulary::MaxPredicateKinds)) {
|
|
StringRef VocabKey =
|
|
Vocabulary::getVocabKeyForPredicate(Vocabulary::getPredicate(PK));
|
|
auto It = ArgVocab.find(VocabKey.str());
|
|
if (It != ArgVocab.end()) {
|
|
NumericPredEmbeddings[PK] = It->second;
|
|
continue;
|
|
}
|
|
handleMissingEntity(VocabKey.str());
|
|
}
|
|
|
|
// Create section-based storage instead of flat vocabulary
|
|
// Order must match Vocabulary::Section enum
|
|
std::vector<std::vector<Embedding>> Sections(4);
|
|
Sections[static_cast<unsigned>(Section::Opcodes)] =
|
|
std::move(NumericOpcodeEmbeddings); // Section::Opcodes
|
|
Sections[static_cast<unsigned>(Section::CanonicalTypes)] =
|
|
std::move(NumericTypeEmbeddings); // Section::CanonicalTypes
|
|
Sections[static_cast<unsigned>(Section::Operands)] =
|
|
std::move(NumericArgEmbeddings); // Section::Operands
|
|
Sections[static_cast<unsigned>(Section::Predicates)] =
|
|
std::move(NumericPredEmbeddings); // Section::Predicates
|
|
|
|
// Create VocabStorage from organized sections
|
|
return VocabStorage(std::move(Sections));
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Vocabulary
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Loads a vocabulary from the JSON file at \p VocabFilePath. Every entry of
/// the opcode, type and argument sections is multiplied by \p OpcWeight,
/// \p TypeWeight and \p ArgWeight respectively before the numeric storage is
/// built. Read or parse failures are returned as an Error.
Expected<Vocabulary> Vocabulary::fromFile(StringRef VocabFilePath,
                                          float OpcWeight, float TypeWeight,
                                          float ArgWeight) {
  VocabMap OpcVocab, TypeVocab, ArgVocab;
  if (auto Err =
          readVocabularyFromFile(VocabFilePath, OpcVocab, TypeVocab, ArgVocab))
    return std::move(Err);

  auto applyWeight = [](VocabMap &SectionMap, float Weight) {
    for (auto &[Key, Emb] : SectionMap)
      Emb *= Weight;
  };
  applyWeight(OpcVocab, OpcWeight);
  applyWeight(TypeVocab, TypeWeight);
  applyWeight(ArgVocab, ArgWeight);

  return Vocabulary(buildVocabStorage(OpcVocab, TypeVocab, ArgVocab));
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// IR2VecVocabAnalysis
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Consumes \p Err and reports each contained error through \p Ctx, prefixed
/// with "Error reading vocabulary: ".
void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
  handleAllErrors(std::move(Err), [&Ctx](const ErrorInfoBase &Info) {
    Ctx.emitError("Error reading vocabulary: " + Info.message());
  });
}
|
|
|
|
/// Produces the vocabulary, in order of preference:
///   1. the vocabulary supplied when this analysis was constructed;
///   2. the file named by --ir2vec-vocab-path.
/// On failure an error is emitted on the module's context and an invalid
/// (default-constructed) Vocabulary is returned.
IR2VecVocabAnalysis::Result
IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
  LLVMContext &Ctx = M.getContext();

  // NOTE(review): the stored vocabulary is moved out while Vocab stays
  // engaged, so a second run() would build from moved-from storage. Confirm
  // this analysis is only computed once per analysis object.
  if (Vocab)
    return Vocabulary(std::move(*Vocab));

  if (VocabFile.empty()) {
    // FIXME: Use default vocabulary
    Ctx.emitError("IR2Vec vocabulary file path not specified; You may need to "
                  "set it using --ir2vec-vocab-path");
    return Vocabulary(); // Return invalid result
  }

  auto VocabOrErr =
      Vocabulary::fromFile(VocabFile, OpcWeight, TypeWeight, ArgWeight);
  if (VocabOrErr)
    return std::move(*VocabOrErr);

  emitError(VocabOrErr.takeError(), Ctx);
  return Vocabulary();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Printer Passes
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Prints, for every function in \p M, its function vector, then every basic
/// block vector, then every instruction vector, using the embedding kind
/// selected by --ir2vec-kind. The IR is not modified.
PreservedAnalyses IR2VecPrinterPass::run(Module &M,
                                         ModuleAnalysisManager &MAM) {
  auto &VocabResult = MAM.getResult<IR2VecVocabAnalysis>(M);
  assert(VocabResult.isValid() && "IR2Vec Vocabulary is invalid");

  for (Function &F : M) {
    std::unique_ptr<Embedder> Emb =
        Embedder::create(IR2VecEmbeddingKind, F, VocabResult);
    if (!Emb) {
      OS << "Error creating IR2Vec embeddings \n";
      continue;
    }

    OS << "IR2Vec embeddings for function " << F.getName() << ":\n";
    OS << "Function vector: ";
    Emb->getFunctionVector().print(OS);

    OS << "Basic block vectors:\n";
    for (const BasicBlock &Block : F) {
      OS << "Basic block: " << Block.getName() << ":\n";
      Emb->getBBVector(Block).print(OS);
    }

    OS << "Instruction vectors:\n";
    for (const BasicBlock &Block : F)
      for (const Instruction &Inst : Block) {
        OS << "Instruction: ";
        Inst.print(OS);
        Emb->getInstVector(Inst).print(OS);
      }
  }
  return PreservedAnalyses::all();
}
|
|
|
|
/// Prints every vocabulary entry as "Key: <string key>: <embedding>". Pos
/// counts entries in iteration order and is used as the flat position passed
/// to getStringKey. The IR is not modified.
PreservedAnalyses IR2VecVocabPrinterPass::run(Module &M,
                                              ModuleAnalysisManager &MAM) {
  auto &VocabResult = MAM.getResult<IR2VecVocabAnalysis>(M);
  assert(VocabResult.isValid() && "IR2Vec Vocabulary is invalid");

  unsigned Pos = 0;
  for (const auto &Entry : VocabResult) {
    OS << "Key: " << VocabResult.getStringKey(Pos) << ": ";
    Entry.print(OS);
    ++Pos;
  }
  return PreservedAnalyses::all();
}
|