diff --git a/flang/docs/OpenMP-declare-target.md b/flang/docs/OpenMP-declare-target.md index ac7fec5e8349..e43a623cc9fb 100644 --- a/flang/docs/OpenMP-declare-target.md +++ b/flang/docs/OpenMP-declare-target.md @@ -177,10 +177,6 @@ Host functions with `target` regions are marked with a `declare target host` attribute so they will be removed after outlining the target regions contained inside. -While this infrastructure could be generally applicable to more than just Flang, -it is only utilised in the Flang frontend, so it resides there rather than in -the OpenMP dialect codebase. - ## Declare Target OpenMP Dialect To LLVM-IR Lowering The OpenMP dialect lowering of `declare target` is done through the diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td index 1b7da0da3721..9ec159e1ba1e 100644 --- a/flang/include/flang/Optimizer/OpenMP/Passes.td +++ b/flang/include/flang/Optimizer/OpenMP/Passes.td @@ -35,12 +35,6 @@ def MapsForPrivatizedSymbolsPass let dependentDialects = ["mlir::omp::OpenMPDialect"]; } -def MarkDeclareTargetPass - : Pass<"omp-mark-declare-target", "mlir::ModuleOp"> { - let summary = "Marks all functions called by an OpenMP declare target function as declare target"; - let dependentDialects = ["mlir::omp::OpenMPDialect"]; -} - def DeleteUnreachableTargetsPass : Pass<"omp-delete-unreachable-targets", "mlir::ModuleOp"> { let summary = "Deletes OpenMP target operations in unreachable code"; diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt index eb4930fb2f6a..db29e93b71da 100644 --- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt +++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt @@ -7,7 +7,6 @@ add_flang_library(FlangOpenMPTransforms GenericLoopConversion.cpp MapsForPrivatizedSymbols.cpp MapInfoFinalization.cpp - MarkDeclareTarget.cpp DeleteUnreachableTargets.cpp LowerWorkdistribute.cpp LowerWorkshare.cpp @@ -36,6 +35,7 @@ add_flang_library(FlangOpenMPTransforms MLIR_LIBS MLIRFuncDialect MLIROpenMPDialect + MLIROpenMPTransforms MLIRIR MLIRPass MLIRTransformUtils diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp index 9b73d587ee7b..5927fff96027 100644 --- a/flang/lib/Optimizer/Passes/Pipelines.cpp +++ b/flang/lib/Optimizer/Passes/Pipelines.cpp @@ -347,12 +347,11 @@ void createOpenMPFIRPassPipeline(mlir::PassManager &pm, pm.addPass(flangomp::createMapsForPrivatizedSymbolsPass()); pm.addPass(flangomp::createAutomapToTargetDataPass()); pm.addPass(flangomp::createMapInfoFinalizationPass()); - pm.addPass(flangomp::createMarkDeclareTargetPass()); + pm.addPass(mlir::omp::createMarkDeclareTargetPass()); // Delete unreachable target operations before FunctionFilteringPass // extracts them. pm.addPass(flangomp::createDeleteUnreachableTargetsPass()); - pm.addPass(flangomp::createGenericLoopConversionPass()); if (opts.isTargetDevice) pm.addPass(flangomp::createFunctionFilteringPass()); diff --git a/mlir/include/mlir/Dialect/OpenMP/Transforms/Passes.td b/mlir/include/mlir/Dialect/OpenMP/Transforms/Passes.td index 1fde7e08ab43..43d84b7fa4bf 100644 --- a/mlir/include/mlir/Dialect/OpenMP/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/OpenMP/Transforms/Passes.td @@ -11,6 +11,17 @@ include "mlir/Pass/PassBase.td" +def MarkDeclareTargetPass : Pass<"omp-mark-declare-target", "ModuleOp"> { + let summary = "Marks all functions called by an OpenMP declare target " + "function as declare target"; + let description = [{ + Marks functions contained within the module as declare target if they are + called from within an explicitly marked declare target function or a target + region (omp.target). + }]; + let dependentDialects = ["mlir::omp::OpenMPDialect"]; +} + def PrepareForOMPOffloadPrivatizationPass : Pass<"omp-offload-privatization-prepare", "ModuleOp"> { let summary = "Prepare OpenMP maps for privatization for deferred target tasks"; let description = [{ diff --git a/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt index b9b8eda9ed51..a46924cd9878 100644 --- a/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_dialect_library(MLIROpenMPTransforms + MarkDeclareTarget.cpp OpenMPOffloadPrivatizationPrepare.cpp DEPENDS diff --git a/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp b/mlir/lib/Dialect/OpenMP/Transforms/MarkDeclareTarget.cpp similarity index 63% rename from flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp rename to mlir/lib/Dialect/OpenMP/Transforms/MarkDeclareTarget.cpp index 5fa77fb2080d..18a36f73edaf 100644 --- a/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp +++ b/mlir/lib/Dialect/OpenMP/Transforms/MarkDeclareTarget.cpp @@ -1,4 +1,4 @@ -//===- MarkDeclareTarget.cpp -------------------------------------------===// +//===- MarkDeclareTarget.cpp ----------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -10,12 +10,8 @@ // //===----------------------------------------------------------------------===// -#include "flang/Optimizer/OpenMP/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" -#include "mlir/IR/BuiltinDialect.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Operation.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Pass/Pass.h" @@ -23,27 +19,32 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/TypeSwitch.h" -namespace flangomp { -#define GEN_PASS_DEF_MARKDECLARETARGETPASS -#include "flang/Optimizer/OpenMP/Passes.h.inc" -} // namespace flangomp +namespace mlir { +namespace omp { +#define GEN_PASS_DEF_MARKDECLARETARGETPASS +#include "mlir/Dialect/OpenMP/Transforms/Passes.h.inc" + +} // namespace omp +} // namespace mlir + +using namespace mlir; namespace { + class MarkDeclareTargetPass - : public flangomp::impl::MarkDeclareTargetPassBase { + : public omp::impl::MarkDeclareTargetPassBase { struct ParentInfo { - mlir::omp::DeclareTargetDeviceType devTy; - mlir::omp::DeclareTargetCaptureClause capClause; + omp::DeclareTargetDeviceType devTy; + omp::DeclareTargetCaptureClause capClause; bool automap; }; - void processSymbolRef(mlir::SymbolRefAttr symRef, ParentInfo parentInfo, - llvm::SmallPtrSet visited) { - if (auto currFOp = - getOperation().lookupSymbol(symRef)) { - auto current = llvm::dyn_cast( - currFOp.getOperation()); + void processSymbolRef(SymbolRefAttr symRef, ParentInfo parentInfo, + llvm::SmallPtrSet visited) { + if (auto currFOp = getOperation().lookupSymbol(symRef)) { + auto current = + llvm::dyn_cast(currFOp.getOperation()); if (current.isDeclareTarget()) { auto currentDt = current.getDeclareTargetDeviceType(); @@ -51,8 +52,8 @@ class MarkDeclareTargetPass // Found the same function twice, with different device_types, // mark as Any as it belongs to both if (currentDt != parentInfo.devTy && - currentDt != mlir::omp::DeclareTargetDeviceType::any) { - current.setDeclareTarget(mlir::omp::DeclareTargetDeviceType::any, + currentDt != omp::DeclareTargetDeviceType::any) { + current.setDeclareTarget(omp::DeclareTargetDeviceType::any, current.getDeclareTargetCaptureClause(), current.getDeclareTargetAutomap()); } @@ -67,65 +68,63 @@ class MarkDeclareTargetPass void processReductionRefs(std::optional symRefs, ParentInfo parentInfo, - llvm::SmallPtrSet visited) { + llvm::SmallPtrSet visited) { if (!symRefs) return; for (auto symRef : symRefs->getAsRange()) { if (auto declareReductionOp = - getOperation().lookupSymbol( - symRef)) { + getOperation().lookupSymbol(symRef)) { markNestedFuncs(parentInfo, declareReductionOp, visited); } } } - void - processReductionClauses(mlir::Operation *op, ParentInfo parentInfo, - llvm::SmallPtrSet visited) { - llvm::TypeSwitch(*op) - .Case([&](mlir::omp::LoopOp op) { + void processReductionClauses(Operation *op, ParentInfo parentInfo, + llvm::SmallPtrSet visited) { + llvm::TypeSwitch(*op) + .Case([&](omp::LoopOp op) { processReductionRefs(op.getReductionSyms(), parentInfo, visited); }) - .Case([&](mlir::omp::ParallelOp op) { + .Case([&](omp::ParallelOp op) { processReductionRefs(op.getReductionSyms(), parentInfo, visited); }) - .Case([&](mlir::omp::SectionsOp op) { + .Case([&](omp::SectionsOp op) { processReductionRefs(op.getReductionSyms(), parentInfo, visited); }) - .Case([&](mlir::omp::SimdOp op) { + .Case([&](omp::SimdOp op) { processReductionRefs(op.getReductionSyms(), parentInfo, visited); }) - .Case([&](mlir::omp::TargetOp op) { + .Case([&](omp::TargetOp op) { processReductionRefs(op.getInReductionSyms(), parentInfo, visited); }) - .Case([&](mlir::omp::TaskgroupOp op) { + .Case([&](omp::TaskgroupOp op) { processReductionRefs(op.getTaskReductionSyms(), parentInfo, visited); }) - .Case([&](mlir::omp::TaskloopOp op) { + .Case([&](omp::TaskloopOp op) { processReductionRefs(op.getReductionSyms(), parentInfo, visited); processReductionRefs(op.getInReductionSyms(), parentInfo, visited); }) - .Case([&](mlir::omp::TaskOp op) { + .Case([&](omp::TaskOp op) { processReductionRefs(op.getInReductionSyms(), parentInfo, visited); }) - .Case([&](mlir::omp::TeamsOp op) { + .Case([&](omp::TeamsOp op) { processReductionRefs(op.getReductionSyms(), parentInfo, visited); }) - .Case([&](mlir::omp::WsloopOp op) { + .Case([&](omp::WsloopOp op) { processReductionRefs(op.getReductionSyms(), parentInfo, visited); }) - .Default([](mlir::Operation &) {}); + .Default([](Operation &) {}); } - void markNestedFuncs(ParentInfo parentInfo, mlir::Operation *currOp, - llvm::SmallPtrSet visited) { + void markNestedFuncs(ParentInfo parentInfo, Operation *currOp, + llvm::SmallPtrSet visited) { if (visited.contains(currOp)) return; visited.insert(currOp); - currOp->walk([&, this](mlir::Operation *op) { - if (auto callOp = llvm::dyn_cast(op)) { + currOp->walk([&, this](Operation *op) { + if (auto callOp = llvm::dyn_cast(op)) { if (auto symRef = llvm::dyn_cast_if_present( callOp.getCallableForCallee())) { processSymbolRef(symRef, parentInfo, visited); @@ -139,11 +138,11 @@ class MarkDeclareTargetPass // as implicitly declare target if they are called from within an explicitly // marked declare target function or a target region (TargetOp) void runOnOperation() override { - for (auto functionOp : getOperation().getOps()) { - auto declareTargetOp = llvm::dyn_cast( + for (auto functionOp : getOperation().getOps()) { + auto declareTargetOp = llvm::dyn_cast( functionOp.getOperation()); if (declareTargetOp.isDeclareTarget()) { - llvm::SmallPtrSet visited; + llvm::SmallPtrSet visited; ParentInfo parentInfo{declareTargetOp.getDeclareTargetDeviceType(), declareTargetOp.getDeclareTargetCaptureClause(), declareTargetOp.getDeclareTargetAutomap()}; @@ -156,11 +155,11 @@ class MarkDeclareTargetPass // when it's lowering has been implemented and change the // DeclareTargetDeviceType argument from nohost to host depending on // the contents of the device clause - getOperation()->walk([&](mlir::omp::TargetOp tarOp) { - llvm::SmallPtrSet visited; + getOperation()->walk([&](omp::TargetOp tarOp) { + llvm::SmallPtrSet visited; ParentInfo parentInfo = { - /*devTy=*/mlir::omp::DeclareTargetDeviceType::nohost, - /*capClause=*/mlir::omp::DeclareTargetCaptureClause::to, + /*devTy=*/omp::DeclareTargetDeviceType::nohost, + /*capClause=*/omp::DeclareTargetCaptureClause::to, /*automap=*/false, }; markNestedFuncs(parentInfo, tarOp, visited);