[flang][MLIR][NFC] Move MarkDeclareTarget pass form flang to MLIR (#181205)

This patch moves the MarkDeclareTarget pass from flang to MLIR since it
will be used in ClangIR as well.
This commit is contained in:
Jan Leyonberg
2026-02-23 10:36:37 -05:00
committed by GitHub
parent 3215645b8d
commit f16162c03c
7 changed files with 62 additions and 62 deletions

View File

@@ -177,10 +177,6 @@ Host functions with `target` regions are marked with a `declare target host`
attribute so they will be removed after outlining the target regions contained
inside.
While this infrastructure could be generally applicable to more than just Flang,
it is only utilised in the Flang frontend, so it resides there rather than in
the OpenMP dialect codebase.
## Declare Target OpenMP Dialect To LLVM-IR Lowering
The OpenMP dialect lowering of `declare target` is done through the

View File

@@ -35,12 +35,6 @@ def MapsForPrivatizedSymbolsPass
let dependentDialects = ["mlir::omp::OpenMPDialect"];
}
def MarkDeclareTargetPass
: Pass<"omp-mark-declare-target", "mlir::ModuleOp"> {
let summary = "Marks all functions called by an OpenMP declare target function as declare target";
let dependentDialects = ["mlir::omp::OpenMPDialect"];
}
def DeleteUnreachableTargetsPass
: Pass<"omp-delete-unreachable-targets", "mlir::ModuleOp"> {
let summary = "Deletes OpenMP target operations in unreachable code";

View File

@@ -7,7 +7,6 @@ add_flang_library(FlangOpenMPTransforms
GenericLoopConversion.cpp
MapsForPrivatizedSymbols.cpp
MapInfoFinalization.cpp
MarkDeclareTarget.cpp
DeleteUnreachableTargets.cpp
LowerWorkdistribute.cpp
LowerWorkshare.cpp
@@ -36,6 +35,7 @@ add_flang_library(FlangOpenMPTransforms
MLIR_LIBS
MLIRFuncDialect
MLIROpenMPDialect
MLIROpenMPTransforms
MLIRIR
MLIRPass
MLIRTransformUtils

View File

@@ -347,12 +347,11 @@ void createOpenMPFIRPassPipeline(mlir::PassManager &pm,
pm.addPass(flangomp::createMapsForPrivatizedSymbolsPass());
pm.addPass(flangomp::createAutomapToTargetDataPass());
pm.addPass(flangomp::createMapInfoFinalizationPass());
pm.addPass(flangomp::createMarkDeclareTargetPass());
pm.addPass(mlir::omp::createMarkDeclareTargetPass());
// Delete unreachable target operations before FunctionFilteringPass
// extracts them.
pm.addPass(flangomp::createDeleteUnreachableTargetsPass());
pm.addPass(flangomp::createGenericLoopConversionPass());
if (opts.isTargetDevice)
pm.addPass(flangomp::createFunctionFilteringPass());

View File

@@ -11,6 +11,17 @@
include "mlir/Pass/PassBase.td"
def MarkDeclareTargetPass : Pass<"omp-mark-declare-target", "ModuleOp"> {
let summary = "Marks all functions called by an OpenMP declare target "
"function as declare target";
let description = [{
Marks functions contained within the module as declare target if they are
called from within an explicitly marked declare target function or a target
region (omp.target).
}];
let dependentDialects = ["mlir::omp::OpenMPDialect"];
}
def PrepareForOMPOffloadPrivatizationPass : Pass<"omp-offload-privatization-prepare", "ModuleOp"> {
let summary = "Prepare OpenMP maps for privatization for deferred target tasks";
let description = [{

View File

@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIROpenMPTransforms
MarkDeclareTarget.cpp
OpenMPOffloadPrivatizationPrepare.cpp
DEPENDS

View File

@@ -1,4 +1,4 @@
//===- MarkDeclareTarget.cpp -------------------------------------------===//
//===- MarkDeclareTarget.cpp ----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -10,12 +10,8 @@
//
//===----------------------------------------------------------------------===//
#include "flang/Optimizer/OpenMP/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
@@ -23,27 +19,32 @@
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"
namespace flangomp {
#define GEN_PASS_DEF_MARKDECLARETARGETPASS
#include "flang/Optimizer/OpenMP/Passes.h.inc"
} // namespace flangomp
namespace mlir {
namespace omp {
#define GEN_PASS_DEF_MARKDECLARETARGETPASS
#include "mlir/Dialect/OpenMP/Transforms/Passes.h.inc"
} // namespace omp
} // namespace mlir
using namespace mlir;
namespace {
class MarkDeclareTargetPass
: public flangomp::impl::MarkDeclareTargetPassBase<MarkDeclareTargetPass> {
: public omp::impl::MarkDeclareTargetPassBase<MarkDeclareTargetPass> {
struct ParentInfo {
mlir::omp::DeclareTargetDeviceType devTy;
mlir::omp::DeclareTargetCaptureClause capClause;
omp::DeclareTargetDeviceType devTy;
omp::DeclareTargetCaptureClause capClause;
bool automap;
};
void processSymbolRef(mlir::SymbolRefAttr symRef, ParentInfo parentInfo,
llvm::SmallPtrSet<mlir::Operation *, 16> visited) {
if (auto currFOp =
getOperation().lookupSymbol<mlir::func::FuncOp>(symRef)) {
auto current = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
currFOp.getOperation());
void processSymbolRef(SymbolRefAttr symRef, ParentInfo parentInfo,
llvm::SmallPtrSet<Operation *, 16> visited) {
if (auto currFOp = getOperation().lookupSymbol<func::FuncOp>(symRef)) {
auto current =
llvm::dyn_cast<omp::DeclareTargetInterface>(currFOp.getOperation());
if (current.isDeclareTarget()) {
auto currentDt = current.getDeclareTargetDeviceType();
@@ -51,8 +52,8 @@ class MarkDeclareTargetPass
// Found the same function twice, with different device_types,
// mark as Any as it belongs to both
if (currentDt != parentInfo.devTy &&
currentDt != mlir::omp::DeclareTargetDeviceType::any) {
current.setDeclareTarget(mlir::omp::DeclareTargetDeviceType::any,
currentDt != omp::DeclareTargetDeviceType::any) {
current.setDeclareTarget(omp::DeclareTargetDeviceType::any,
current.getDeclareTargetCaptureClause(),
current.getDeclareTargetAutomap());
}
@@ -67,65 +68,63 @@ class MarkDeclareTargetPass
void processReductionRefs(std::optional<mlir::ArrayAttr> symRefs,
ParentInfo parentInfo,
llvm::SmallPtrSet<mlir::Operation *, 16> visited) {
llvm::SmallPtrSet<Operation *, 16> visited) {
if (!symRefs)
return;
for (auto symRef : symRefs->getAsRange<mlir::SymbolRefAttr>()) {
if (auto declareReductionOp =
getOperation().lookupSymbol<mlir::omp::DeclareReductionOp>(
symRef)) {
getOperation().lookupSymbol<omp::DeclareReductionOp>(symRef)) {
markNestedFuncs(parentInfo, declareReductionOp, visited);
}
}
}
void
processReductionClauses(mlir::Operation *op, ParentInfo parentInfo,
llvm::SmallPtrSet<mlir::Operation *, 16> visited) {
llvm::TypeSwitch<mlir::Operation &>(*op)
.Case([&](mlir::omp::LoopOp op) {
void processReductionClauses(Operation *op, ParentInfo parentInfo,
llvm::SmallPtrSet<Operation *, 16> visited) {
llvm::TypeSwitch<Operation &>(*op)
.Case([&](omp::LoopOp op) {
processReductionRefs(op.getReductionSyms(), parentInfo, visited);
})
.Case([&](mlir::omp::ParallelOp op) {
.Case([&](omp::ParallelOp op) {
processReductionRefs(op.getReductionSyms(), parentInfo, visited);
})
.Case([&](mlir::omp::SectionsOp op) {
.Case([&](omp::SectionsOp op) {
processReductionRefs(op.getReductionSyms(), parentInfo, visited);
})
.Case([&](mlir::omp::SimdOp op) {
.Case([&](omp::SimdOp op) {
processReductionRefs(op.getReductionSyms(), parentInfo, visited);
})
.Case([&](mlir::omp::TargetOp op) {
.Case([&](omp::TargetOp op) {
processReductionRefs(op.getInReductionSyms(), parentInfo, visited);
})
.Case([&](mlir::omp::TaskgroupOp op) {
.Case([&](omp::TaskgroupOp op) {
processReductionRefs(op.getTaskReductionSyms(), parentInfo, visited);
})
.Case([&](mlir::omp::TaskloopOp op) {
.Case([&](omp::TaskloopOp op) {
processReductionRefs(op.getReductionSyms(), parentInfo, visited);
processReductionRefs(op.getInReductionSyms(), parentInfo, visited);
})
.Case([&](mlir::omp::TaskOp op) {
.Case([&](omp::TaskOp op) {
processReductionRefs(op.getInReductionSyms(), parentInfo, visited);
})
.Case([&](mlir::omp::TeamsOp op) {
.Case([&](omp::TeamsOp op) {
processReductionRefs(op.getReductionSyms(), parentInfo, visited);
})
.Case([&](mlir::omp::WsloopOp op) {
.Case([&](omp::WsloopOp op) {
processReductionRefs(op.getReductionSyms(), parentInfo, visited);
})
.Default([](mlir::Operation &) {});
.Default([](Operation &) {});
}
void markNestedFuncs(ParentInfo parentInfo, mlir::Operation *currOp,
llvm::SmallPtrSet<mlir::Operation *, 16> visited) {
void markNestedFuncs(ParentInfo parentInfo, Operation *currOp,
llvm::SmallPtrSet<Operation *, 16> visited) {
if (visited.contains(currOp))
return;
visited.insert(currOp);
currOp->walk([&, this](mlir::Operation *op) {
if (auto callOp = llvm::dyn_cast<mlir::CallOpInterface>(op)) {
currOp->walk([&, this](Operation *op) {
if (auto callOp = llvm::dyn_cast<CallOpInterface>(op)) {
if (auto symRef = llvm::dyn_cast_if_present<mlir::SymbolRefAttr>(
callOp.getCallableForCallee())) {
processSymbolRef(symRef, parentInfo, visited);
@@ -139,11 +138,11 @@ class MarkDeclareTargetPass
// as implicitly declare target if they are called from within an explicitly
// marked declare target function or a target region (TargetOp)
void runOnOperation() override {
for (auto functionOp : getOperation().getOps<mlir::func::FuncOp>()) {
auto declareTargetOp = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
for (auto functionOp : getOperation().getOps<func::FuncOp>()) {
auto declareTargetOp = llvm::dyn_cast<omp::DeclareTargetInterface>(
functionOp.getOperation());
if (declareTargetOp.isDeclareTarget()) {
llvm::SmallPtrSet<mlir::Operation *, 16> visited;
llvm::SmallPtrSet<Operation *, 16> visited;
ParentInfo parentInfo{declareTargetOp.getDeclareTargetDeviceType(),
declareTargetOp.getDeclareTargetCaptureClause(),
declareTargetOp.getDeclareTargetAutomap()};
@@ -156,11 +155,11 @@ class MarkDeclareTargetPass
// when it's lowering has been implemented and change the
// DeclareTargetDeviceType argument from nohost to host depending on
// the contents of the device clause
getOperation()->walk([&](mlir::omp::TargetOp tarOp) {
llvm::SmallPtrSet<mlir::Operation *, 16> visited;
getOperation()->walk([&](omp::TargetOp tarOp) {
llvm::SmallPtrSet<Operation *, 16> visited;
ParentInfo parentInfo = {
/*devTy=*/mlir::omp::DeclareTargetDeviceType::nohost,
/*capClause=*/mlir::omp::DeclareTargetCaptureClause::to,
/*devTy=*/omp::DeclareTargetDeviceType::nohost,
/*capClause=*/omp::DeclareTargetCaptureClause::to,
/*automap=*/false,
};
markNestedFuncs(parentInfo, tarOp, visited);