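`MatchFinder::printMatch` crashed whenever an operation's location was not (and did not contain) a `FileLineColLoc`: `findInstanceOf<FileLineColLoc>()` then returned a null `fileLoc`, which was used without being checked.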
For example, running the query
`match getUsersByPredicate(hasOpName("memref.alloc"), hasOpName("memref.dealloc"), true)`
on the IR below crashed the program:
```mlir
func.func @slicing_linalg_op(%arg0 : index, %arg1 : index, %arg2 : index) {
%a = memref.alloc(%arg0, %arg2) : memref<?x?xf32>
%b = memref.alloc(%arg2, %arg1) : memref<?x?xf32>
%c = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
%d = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
linalg.matmul ins(%a, %b : memref<?x?xf32>, memref<?x?xf32>)
outs(%c : memref<?x?xf32>)
linalg.matmul ins(%a, %b : memref<?x?xf32>, memref<?x?xf32>)
outs(%d : memref<?x?xf32>)
memref.dealloc %c : memref<?x?xf32>
memref.dealloc %b : memref<?x?xf32>
memref.dealloc %a : memref<?x?xf32>
memref.dealloc %d : memref<?x?xf32>
return
}
```
MatchFinder.cpp (71 lines, 2.7 KiB, C++):
//===- MatchFinder.cpp - --------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the method definitions for the `MatchFinder` class.
//
//===----------------------------------------------------------------------===//

#include "mlir/Query/Matcher/MatchFinder.h"
namespace mlir::query::matcher {
|
|
|
|
MatchFinder::MatchResult::MatchResult(Operation *rootOp,
|
|
std::vector<Operation *> matchedOps)
|
|
: rootOp(rootOp), matchedOps(std::move(matchedOps)) {}
|
|
|
|
std::vector<MatchFinder::MatchResult>
|
|
MatchFinder::collectMatches(Operation *root, DynMatcher matcher) const {
|
|
std::vector<MatchResult> results;
|
|
llvm::SetVector<Operation *> tempStorage;
|
|
root->walk([&](Operation *subOp) {
|
|
if (matcher.match(subOp)) {
|
|
MatchResult match;
|
|
match.rootOp = subOp;
|
|
match.matchedOps.push_back(subOp);
|
|
results.push_back(std::move(match));
|
|
} else if (matcher.match(subOp, tempStorage)) {
|
|
results.emplace_back(subOp, std::vector<Operation *>(tempStorage.begin(),
|
|
tempStorage.end()));
|
|
}
|
|
tempStorage.clear();
|
|
});
|
|
return results;
|
|
}
|
|
|
|
void MatchFinder::printMatch(llvm::raw_ostream &os, QuerySession &qs,
|
|
Operation *op) const {
|
|
if (auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>()) {
|
|
SMLoc smloc = qs.getSourceManager().FindLocForLineAndColumn(
|
|
qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
|
|
llvm::SMDiagnostic diag =
|
|
qs.getSourceManager().GetMessage(smloc, llvm::SourceMgr::DK_Note, "");
|
|
diag.print("", os, true, false, true);
|
|
}
|
|
}
|
|
|
|
void MatchFinder::printMatch(llvm::raw_ostream &os, QuerySession &qs,
|
|
Operation *op, const std::string &binding) const {
|
|
if (auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>()) {
|
|
auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
|
|
qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
|
|
qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
|
|
"\"" + binding + "\" binds here");
|
|
}
|
|
}
|
|
|
|
std::vector<Operation *>
|
|
MatchFinder::flattenMatchedOps(std::vector<MatchResult> &matches) const {
|
|
std::vector<Operation *> newVector;
|
|
for (auto &result : matches) {
|
|
newVector.insert(newVector.end(), result.matchedOps.begin(),
|
|
result.matchedOps.end());
|
|
}
|
|
return newVector;
|
|
}
|
|
|
|
} // namespace mlir::query::matcher
|