[mlir][IR] Add variadic getParentOfType overloads (#184071)

Add `getParentOfType` overloads that work with multiple types.
This commit is contained in:
Matthias Springer
2026-03-03 18:22:48 +02:00
committed by GitHub
parent e68f696fda
commit 3f1d968db9
6 changed files with 30 additions and 42 deletions

View File

@@ -242,6 +242,14 @@ public:
return parentOp;
return OpTy();
}
template <typename... OpTy>
std::enable_if_t<(sizeof...(OpTy) > 1), Operation *> getParentOfType() {
auto *op = this;
while ((op = op->getParentOp()))
if (isa<OpTy...>(op))
return op;
return nullptr;
}
/// Returns the closest surrounding parent operation with trait `Trait`.
template <template <typename T> class Trait>

View File

@@ -210,6 +210,17 @@ public:
} while ((region = region->getParentRegion()));
return ParentT();
}
template <typename... ParentT>
std::enable_if_t<(sizeof...(ParentT) > 1), Operation *> getParentOfType() {
auto *region = this;
do {
if (!region->container)
return nullptr;
if (isa<ParentT...>(region->container))
return region->container;
} while ((region = region->getParentRegion()));
return nullptr;
}
/// Return the number of this region in the parent operation.
unsigned getRegionNumber();

View File

@@ -1216,13 +1216,7 @@ bool acc::CacheOp::isCacheReadonly() {
// It is quite alike acc::getEnclosingComputeOp() utility,
// but we cannot use it here.
static bool isEnclosedIntoComputeOp(mlir::Operation *op) {
mlir::Operation *parentOp = op->getParentOp();
while (parentOp) {
if (mlir::isa<ACC_COMPUTE_CONSTRUCT_OPS>(parentOp))
return true;
parentOp = parentOp->getParentOp();
}
return false;
return op->getParentOfType<ACC_COMPUTE_CONSTRUCT_OPS>();
}
/// Helper to add an effect on an operand, referenced by its mutable range.
@@ -1476,10 +1470,6 @@ static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
return success();
}
static bool isComputeOperation(Operation *op) {
return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
}
namespace {
/// Pattern to remove operation without region that have constant false `ifCond`
/// and remove the condition from the operation if the `ifCond` is a true
@@ -4824,10 +4814,8 @@ void RoutineOp::addBindIDName(MLIRContext *context,
//===----------------------------------------------------------------------===//
LogicalResult acc::InitOp::verify() {
Operation *currOp = *this;
while ((currOp = currOp->getParentOp()))
if (isComputeOperation(currOp))
return emitOpError("cannot be nested in a compute operation");
if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
return emitOpError("cannot be nested in a compute operation");
return success();
}
@@ -4846,10 +4834,8 @@ void acc::InitOp::addDeviceType(MLIRContext *context,
//===----------------------------------------------------------------------===//
LogicalResult acc::ShutdownOp::verify() {
Operation *currOp = *this;
while ((currOp = currOp->getParentOp()))
if (isComputeOperation(currOp))
return emitOpError("cannot be nested in a compute operation");
if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
return emitOpError("cannot be nested in a compute operation");
return success();
}
@@ -4868,10 +4854,8 @@ void acc::ShutdownOp::addDeviceType(MLIRContext *context,
//===----------------------------------------------------------------------===//
LogicalResult acc::SetOp::verify() {
Operation *currOp = *this;
while ((currOp = currOp->getParentOp()))
if (isComputeOperation(currOp))
return emitOpError("cannot be nested in a compute operation");
if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
return emitOpError("cannot be nested in a compute operation");
if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
return emitOpError("at least one default_async, device_num, or device_type "
"operand must appear");

View File

@@ -27,14 +27,7 @@ using namespace mlir;
namespace {
static bool insideAccComputeRegion(mlir::Operation *op) {
mlir::Operation *parent{op->getParentOp()};
while (parent) {
if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(parent)) {
return true;
}
parent = parent->getParentOp();
}
return false;
return op->getParentOfType<ACC_COMPUTE_CONSTRUCT_OPS>();
}
static void collectVars(mlir::ValueRange operands,

View File

@@ -21,13 +21,7 @@
#include "llvm/Support/Casting.h"
mlir::Operation *mlir::acc::getEnclosingComputeOp(mlir::Region &region) {
mlir::Operation *parentOp = region.getParentOp();
while (parentOp) {
if (mlir::isa<ACC_COMPUTE_CONSTRUCT_OPS>(parentOp))
return parentOp;
parentOp = parentOp->getParentOp();
}
return nullptr;
return region.getParentOfType<ACC_COMPUTE_CONSTRUCT_OPS>();
}
template <typename OpTy>

View File

@@ -430,10 +430,8 @@ void mlir::sparse_tensor::sizesFromSrc(OpBuilder &builder,
}
Operation *mlir::sparse_tensor::getTop(Operation *op) {
for (; isa<scf::ForOp>(op->getParentOp()) ||
isa<scf::WhileOp>(op->getParentOp()) ||
isa<scf::ParallelOp>(op->getParentOp()) ||
isa<scf::IfOp>(op->getParentOp());
for (; isa<scf::ForOp, scf::WhileOp, scf::ParallelOp, scf::IfOp>(
op->getParentOp());
op = op->getParentOp())
;
return op;