[mlir][IR] Add variadic getParentOfType overloads (#184071)
Add `getParentOfType` overloads that work with multiple types.
This commit is contained in:
committed by
GitHub
parent
e68f696fda
commit
3f1d968db9
@@ -242,6 +242,14 @@ public:
|
||||
return parentOp;
|
||||
return OpTy();
|
||||
}
|
||||
template <typename... OpTy>
|
||||
std::enable_if_t<(sizeof...(OpTy) > 1), Operation *> getParentOfType() {
|
||||
auto *op = this;
|
||||
while ((op = op->getParentOp()))
|
||||
if (isa<OpTy...>(op))
|
||||
return op;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// Returns the closest surrounding parent operation with trait `Trait`.
|
||||
template <template <typename T> class Trait>
|
||||
|
||||
@@ -210,6 +210,17 @@ public:
|
||||
} while ((region = region->getParentRegion()));
|
||||
return ParentT();
|
||||
}
|
||||
template <typename... ParentT>
|
||||
std::enable_if_t<(sizeof...(ParentT) > 1), Operation *> getParentOfType() {
|
||||
auto *region = this;
|
||||
do {
|
||||
if (!region->container)
|
||||
return nullptr;
|
||||
if (isa<ParentT...>(region->container))
|
||||
return region->container;
|
||||
} while ((region = region->getParentRegion()));
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// Return the number of this region in the parent operation.
|
||||
unsigned getRegionNumber();
|
||||
|
||||
@@ -1216,13 +1216,7 @@ bool acc::CacheOp::isCacheReadonly() {
|
||||
// It is quite alike acc::getEnclosingComputeOp() utility,
|
||||
// but we cannot use it here.
|
||||
static bool isEnclosedIntoComputeOp(mlir::Operation *op) {
|
||||
mlir::Operation *parentOp = op->getParentOp();
|
||||
while (parentOp) {
|
||||
if (mlir::isa<ACC_COMPUTE_CONSTRUCT_OPS>(parentOp))
|
||||
return true;
|
||||
parentOp = parentOp->getParentOp();
|
||||
}
|
||||
return false;
|
||||
return op->getParentOfType<ACC_COMPUTE_CONSTRUCT_OPS>();
|
||||
}
|
||||
|
||||
/// Helper to add an effect on an operand, referenced by its mutable range.
|
||||
@@ -1476,10 +1470,6 @@ static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
|
||||
return success();
|
||||
}
|
||||
|
||||
static bool isComputeOperation(Operation *op) {
|
||||
return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// Pattern to remove operation without region that have constant false `ifCond`
|
||||
/// and remove the condition from the operation if the `ifCond` is a true
|
||||
@@ -4824,10 +4814,8 @@ void RoutineOp::addBindIDName(MLIRContext *context,
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult acc::InitOp::verify() {
|
||||
Operation *currOp = *this;
|
||||
while ((currOp = currOp->getParentOp()))
|
||||
if (isComputeOperation(currOp))
|
||||
return emitOpError("cannot be nested in a compute operation");
|
||||
if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
|
||||
return emitOpError("cannot be nested in a compute operation");
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -4846,10 +4834,8 @@ void acc::InitOp::addDeviceType(MLIRContext *context,
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult acc::ShutdownOp::verify() {
|
||||
Operation *currOp = *this;
|
||||
while ((currOp = currOp->getParentOp()))
|
||||
if (isComputeOperation(currOp))
|
||||
return emitOpError("cannot be nested in a compute operation");
|
||||
if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
|
||||
return emitOpError("cannot be nested in a compute operation");
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -4868,10 +4854,8 @@ void acc::ShutdownOp::addDeviceType(MLIRContext *context,
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult acc::SetOp::verify() {
|
||||
Operation *currOp = *this;
|
||||
while ((currOp = currOp->getParentOp()))
|
||||
if (isComputeOperation(currOp))
|
||||
return emitOpError("cannot be nested in a compute operation");
|
||||
if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
|
||||
return emitOpError("cannot be nested in a compute operation");
|
||||
if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
|
||||
return emitOpError("at least one default_async, device_num, or device_type "
|
||||
"operand must appear");
|
||||
|
||||
@@ -27,14 +27,7 @@ using namespace mlir;
|
||||
namespace {
|
||||
|
||||
static bool insideAccComputeRegion(mlir::Operation *op) {
|
||||
mlir::Operation *parent{op->getParentOp()};
|
||||
while (parent) {
|
||||
if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(parent)) {
|
||||
return true;
|
||||
}
|
||||
parent = parent->getParentOp();
|
||||
}
|
||||
return false;
|
||||
return op->getParentOfType<ACC_COMPUTE_CONSTRUCT_OPS>();
|
||||
}
|
||||
|
||||
static void collectVars(mlir::ValueRange operands,
|
||||
|
||||
@@ -21,13 +21,7 @@
|
||||
#include "llvm/Support/Casting.h"
|
||||
|
||||
mlir::Operation *mlir::acc::getEnclosingComputeOp(mlir::Region ®ion) {
|
||||
mlir::Operation *parentOp = region.getParentOp();
|
||||
while (parentOp) {
|
||||
if (mlir::isa<ACC_COMPUTE_CONSTRUCT_OPS>(parentOp))
|
||||
return parentOp;
|
||||
parentOp = parentOp->getParentOp();
|
||||
}
|
||||
return nullptr;
|
||||
return region.getParentOfType<ACC_COMPUTE_CONSTRUCT_OPS>();
|
||||
}
|
||||
|
||||
template <typename OpTy>
|
||||
|
||||
@@ -430,10 +430,8 @@ void mlir::sparse_tensor::sizesFromSrc(OpBuilder &builder,
|
||||
}
|
||||
|
||||
Operation *mlir::sparse_tensor::getTop(Operation *op) {
|
||||
for (; isa<scf::ForOp>(op->getParentOp()) ||
|
||||
isa<scf::WhileOp>(op->getParentOp()) ||
|
||||
isa<scf::ParallelOp>(op->getParentOp()) ||
|
||||
isa<scf::IfOp>(op->getParentOp());
|
||||
for (; isa<scf::ForOp, scf::WhileOp, scf::ParallelOp, scf::IfOp>(
|
||||
op->getParentOp());
|
||||
op = op->getParentOp())
|
||||
;
|
||||
return op;
|
||||
|
||||
Reference in New Issue
Block a user