Files
llvm-project/mlir/lib/Query/Matcher/MatchFinder.cpp
Denzel-Brian Budii e89bd48c56 [mlir] Avoid crash in mlir-query's MatchFinder class (#145049)
It crashed whenever an operation's location was not a FileLineColLoc, because `fileLoc` was then a null pointer.

If the following query was run:

`match getUsersByPredicate(hasOpName("memref.alloc"),
hasOpName("memref.dealloc"), true)`

on the IR shown below, the program crashed.

``` mlir
func.func @slicing_linalg_op(%arg0 : index, %arg1 : index, %arg2 : index) {
  %a = memref.alloc(%arg0, %arg2) : memref<?x?xf32>
  %b = memref.alloc(%arg2, %arg1) : memref<?x?xf32>
  %c = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
  %d = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
  linalg.matmul ins(%a, %b : memref<?x?xf32>, memref<?x?xf32>)
               outs(%c : memref<?x?xf32>)
  linalg.matmul ins(%a, %b : memref<?x?xf32>, memref<?x?xf32>)
               outs(%d : memref<?x?xf32>)
  memref.dealloc %c : memref<?x?xf32>
  memref.dealloc %b : memref<?x?xf32>
  memref.dealloc %a : memref<?x?xf32>
  memref.dealloc %d : memref<?x?xf32>
  return
}
```
2025-12-30 06:46:11 +02:00

71 lines
2.7 KiB
C++

//===- MatchFinder.cpp - --------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the method definitions for the `MatchFinder` class
//
//===----------------------------------------------------------------------===//
#include "mlir/Query/Matcher/MatchFinder.h"

#include <optional>
namespace mlir::query::matcher {
MatchFinder::MatchResult::MatchResult(Operation *rootOp,
std::vector<Operation *> matchedOps)
: rootOp(rootOp), matchedOps(std::move(matchedOps)) {}
std::vector<MatchFinder::MatchResult>
MatchFinder::collectMatches(Operation *root, DynMatcher matcher) const {
std::vector<MatchResult> results;
llvm::SetVector<Operation *> tempStorage;
root->walk([&](Operation *subOp) {
if (matcher.match(subOp)) {
MatchResult match;
match.rootOp = subOp;
match.matchedOps.push_back(subOp);
results.push_back(std::move(match));
} else if (matcher.match(subOp, tempStorage)) {
results.emplace_back(subOp, std::vector<Operation *>(tempStorage.begin(),
tempStorage.end()));
}
tempStorage.clear();
});
return results;
}
void MatchFinder::printMatch(llvm::raw_ostream &os, QuerySession &qs,
Operation *op) const {
if (auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>()) {
SMLoc smloc = qs.getSourceManager().FindLocForLineAndColumn(
qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
llvm::SMDiagnostic diag =
qs.getSourceManager().GetMessage(smloc, llvm::SourceMgr::DK_Note, "");
diag.print("", os, true, false, true);
}
}
void MatchFinder::printMatch(llvm::raw_ostream &os, QuerySession &qs,
Operation *op, const std::string &binding) const {
if (auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>()) {
auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
"\"" + binding + "\" binds here");
}
}
std::vector<Operation *>
MatchFinder::flattenMatchedOps(std::vector<MatchResult> &matches) const {
std::vector<Operation *> newVector;
for (auto &result : matches) {
newVector.insert(newVector.end(), result.matchedOps.begin(),
result.matchedOps.end());
}
return newVector;
}
} // namespace mlir::query::matcher