Files
llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
Mehdi Amini ee8259dcca [mlir][sparse] Fix use-after-free crash in SparseSpaceCollapsePass (#184001)
The `SparseSpaceCollapsePass::runOnOperation()` was calling
`collapseSparseSpace()` inside the `func->walk(...)` callback. The
collapse erases the outer `IterateOp` (and all its nested ops including
inner `ExtractIterSpaceOp`s), while the walk iterator still holds a
reference to one of those nested ops as its current position. The walk
then tries to advance to the next op, which has already been freed,
causing a segfault.

Fix this by collecting all collapsable groups during the walk without
mutating the IR, then processing them after the walk completes.

Fixes #130021
2026-03-03 12:43:04 +01:00

208 lines
6.6 KiB
C++

//===--------- SparseSpaceCollapse.cpp - Collapse Sparse Space Pass -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
namespace mlir {
#define GEN_PASS_DEF_SPARSESPACECOLLAPSE
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
} // namespace mlir
#define DEBUG_TYPE "sparse-space-collapse"
using namespace mlir;
using namespace sparse_tensor;
namespace {
struct CollapseSpaceInfo {
ExtractIterSpaceOp space;
IterateOp loop;
};
bool isCollapsableLoops(LoopLikeOpInterface parent, LoopLikeOpInterface node) {
auto pIterArgs = parent.getRegionIterArgs();
auto nInitArgs = node.getInits();
if (pIterArgs.size() != nInitArgs.size())
return false;
// Two loops are collapsable if they are perfectly nested.
auto pYields = parent.getYieldedValues();
auto nResult = node.getLoopResults().value();
bool yieldEq =
llvm::all_of(llvm::zip_equal(pYields, nResult), [](auto zipped) {
return std::get<0>(zipped) == std::get<1>(zipped);
});
// Parent iter_args should be passed directly to the node's init_args.
bool iterArgEq =
llvm::all_of(llvm::zip_equal(pIterArgs, nInitArgs), [](auto zipped) {
return std::get<0>(zipped) == std::get<1>(zipped);
});
return yieldEq && iterArgEq;
}
bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse,
ExtractIterSpaceOp curSpace) {
auto getIterateOpOverSpace = [](ExtractIterSpaceOp space) -> IterateOp {
Value spaceVal = space.getExtractedSpace();
if (spaceVal.hasOneUse())
return llvm::dyn_cast<IterateOp>(*spaceVal.getUsers().begin());
return nullptr;
};
if (toCollapse.empty()) {
// Collapse root.
if (auto itOp = getIterateOpOverSpace(curSpace)) {
CollapseSpaceInfo &info = toCollapse.emplace_back();
info.space = curSpace;
info.loop = itOp;
return true;
}
return false;
}
auto parent = toCollapse.back().space;
auto pItOp = toCollapse.back().loop;
auto nItOp = getIterateOpOverSpace(curSpace);
// Can only collapse spaces extracted from the same tensor.
if (parent.getTensor() != curSpace.getTensor()) {
LLVM_DEBUG({
llvm::dbgs()
<< "failed to collpase spaces extracted from different tensors.";
});
return false;
}
// Can only collapse consecutive simple iteration on one tensor (i.e., no
// coiteration).
if (!nItOp || nItOp->getBlock() != curSpace->getBlock() ||
pItOp.getIterator() != curSpace.getParentIter() ||
curSpace->getParentOp() != pItOp.getOperation()) {
LLVM_DEBUG(
{ llvm::dbgs() << "failed to collapse non-consecutive IterateOps."; });
return false;
}
if (pItOp && !isCollapsableLoops(pItOp, nItOp)) {
LLVM_DEBUG({
llvm::dbgs()
<< "failed to collapse IterateOps that are not perfectly nested.";
});
return false;
}
CollapseSpaceInfo &info = toCollapse.emplace_back();
info.space = curSpace;
info.loop = nItOp;
return true;
}
void collapseSparseSpace(MutableArrayRef<CollapseSpaceInfo> toCollapse) {
if (toCollapse.size() < 2)
return;
ExtractIterSpaceOp root = toCollapse.front().space;
ExtractIterSpaceOp leaf = toCollapse.back().space;
Location loc = root.getLoc();
assert(root->hasOneUse() && leaf->hasOneUse());
// Insert collapsed operation at the same scope as root operation.
OpBuilder builder(root);
// Construct the collapsed iteration space.
auto collapsedSpace = ExtractIterSpaceOp::create(
builder, loc, root.getTensor(), root.getParentIter(), root.getLoLvl(),
leaf.getHiLvl());
auto rItOp = llvm::cast<IterateOp>(*root->getUsers().begin());
auto innermost = toCollapse.back().loop;
IRMapping mapper;
mapper.map(leaf, collapsedSpace.getExtractedSpace());
for (auto z : llvm::zip_equal(innermost.getInitArgs(), rItOp.getInitArgs()))
mapper.map(std::get<0>(z), std::get<1>(z));
auto cloned = llvm::cast<IterateOp>(builder.clone(*innermost, mapper));
builder.setInsertionPointToStart(cloned.getBody());
I64BitSet crdUsedLvls;
unsigned shift = 0, argIdx = 1;
for (auto info : toCollapse.drop_back()) {
I64BitSet set = info.loop.getCrdUsedLvls();
crdUsedLvls |= set.lshift(shift);
shift += info.loop.getSpaceDim();
for (BlockArgument crd : info.loop.getCrds()) {
BlockArgument collapsedCrd = cloned.getBody()->insertArgument(
argIdx++, builder.getIndexType(), crd.getLoc());
crd.replaceAllUsesWith(collapsedCrd);
}
}
crdUsedLvls |= innermost.getCrdUsedLvls().lshift(shift);
cloned.getIterator().setType(collapsedSpace.getType().getIteratorType());
cloned.setCrdUsedLvls(crdUsedLvls);
rItOp.replaceAllUsesWith(cloned.getResults());
// Erase collapsed loops.
rItOp.erase();
root.erase();
}
struct SparseSpaceCollapsePass
: public impl::SparseSpaceCollapseBase<SparseSpaceCollapsePass> {
SparseSpaceCollapsePass() = default;
void runOnOperation() override {
func::FuncOp func = getOperation();
// A naive (experimental) implementation to collapse consecutive sparse
// spaces. It does NOT handle complex cases where multiple spaces are
// extracted in the same basic block. E.g.,
//
// %space1 = extract_space %t1 ...
// %space2 = extract_space %t2 ...
// sparse_tensor.iterate(%sp1) ...
//
// Collect all groups to collapse before performing any IR mutations.
// Mutating (erasing) ops during the walk would invalidate the walk's
// internal iterator and cause use-after-free crashes.
SmallVector<SmallVector<CollapseSpaceInfo>> groups;
SmallVector<CollapseSpaceInfo> toCollapse;
func->walk([&](ExtractIterSpaceOp op) {
if (!legalToCollapse(toCollapse, op)) {
// Save the current group and start a new one.
groups.push_back(std::move(toCollapse));
toCollapse.clear();
// Try to start a new group with the current op.
legalToCollapse(toCollapse, op);
}
});
groups.push_back(std::move(toCollapse));
// Apply all collapse transformations after the walk is complete.
for (auto &group : groups)
collapseSparseSpace(group);
}
};
} // namespace
std::unique_ptr<Pass> mlir::createSparseSpaceCollapsePass() {
return std::make_unique<SparseSpaceCollapsePass>();
}