[CIR][MLIR][OpenMP] Enable the MarkDeclareTarget pass for ClangIR (#189420)
This patch enables the MarkDeclareTarget for CIR by adding the pass to the lowerings and attaching the declare target interface to the cir::FuncOp. The MarkDeclareTarget is also generalized to work on the FunctionOpInterface instead of func::Op since it needs to be able to handle cir::FuncOp as well. Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,22 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef CLANG_CIR_DIALECT_OPENMP_REGISTEROPENMPEXTENSIONS_H
|
||||
#define CLANG_CIR_DIALECT_OPENMP_REGISTEROPENMPEXTENSIONS_H
|
||||
|
||||
namespace mlir {
|
||||
class DialectRegistry;
|
||||
} // namespace mlir
|
||||
|
||||
namespace cir::omp {
|
||||
|
||||
void registerOpenMPExtensions(mlir::DialectRegistry ®istry);
|
||||
|
||||
} // namespace cir::omp
|
||||
|
||||
#endif // CLANG_CIR_DIALECT_OPENMP_REGISTEROPENMPEXTENSIONS_H
|
||||
@@ -21,6 +21,7 @@
|
||||
#include "clang/CIR/CIRGenerator.h"
|
||||
#include "clang/CIR/Dialect/IR/CIRDialect.h"
|
||||
#include "clang/CIR/Dialect/OpenACC/RegisterOpenACCExtensions.h"
|
||||
#include "clang/CIR/Dialect/OpenMP/RegisterOpenMPExtensions.h"
|
||||
#include "llvm/IR/DataLayout.h"
|
||||
|
||||
using namespace cir;
|
||||
@@ -56,9 +57,10 @@ void CIRGenerator::Initialize(ASTContext &astContext) {
|
||||
mlirContext->getOrLoadDialect<mlir::acc::OpenACCDialect>();
|
||||
mlirContext->getOrLoadDialect<mlir::omp::OpenMPDialect>();
|
||||
|
||||
// Register extensions to integrate CIR types with OpenACC.
|
||||
// Register extensions to integrate CIR types with OpenACC and OpenMP.
|
||||
mlir::DialectRegistry registry;
|
||||
cir::acc::registerOpenACCExtensions(registry);
|
||||
cir::omp::registerOpenMPExtensions(registry);
|
||||
mlirContext->appendDialectRegistry(registry);
|
||||
|
||||
cgm = std::make_unique<clang::CIRGen::CIRGenModule>(
|
||||
|
||||
@@ -65,6 +65,7 @@ add_clang_library(clangCIR
|
||||
clangLex
|
||||
${dialect_libs}
|
||||
CIROpenACCSupport
|
||||
CIROpenMPSupport
|
||||
MLIRCIR
|
||||
MLIRCIRInterfaces
|
||||
MLIRTargetLLVMIRImport
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(OpenACC)
|
||||
add_subdirectory(OpenMP)
|
||||
add_subdirectory(Transforms)
|
||||
|
||||
11
clang/lib/CIR/Dialect/OpenMP/CMakeLists.txt
Normal file
11
clang/lib/CIR/Dialect/OpenMP/CMakeLists.txt
Normal file
@@ -0,0 +1,11 @@
|
||||
add_clang_library(CIROpenMPSupport
|
||||
RegisterOpenMPExtensions.cpp
|
||||
|
||||
DEPENDS
|
||||
MLIRCIROpsIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRCIR
|
||||
MLIROpenMPDialect
|
||||
)
|
||||
26
clang/lib/CIR/Dialect/OpenMP/RegisterOpenMPExtensions.cpp
Normal file
26
clang/lib/CIR/Dialect/OpenMP/RegisterOpenMPExtensions.cpp
Normal file
@@ -0,0 +1,26 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Registration for OpenMP extensions as applied to CIR dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "clang/CIR/Dialect/OpenMP/RegisterOpenMPExtensions.h"
|
||||
#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
|
||||
#include "clang/CIR/Dialect/IR/CIRDialect.h"
|
||||
|
||||
namespace cir::omp {
|
||||
|
||||
void registerOpenMPExtensions(mlir::DialectRegistry ®istry) {
|
||||
registry.addExtension(+[](mlir::MLIRContext *ctx, cir::CIRDialect *dialect) {
|
||||
cir::FuncOp::attachInterface<
|
||||
mlir::omp::DeclareTargetDefaultModel<cir::FuncOp>>(*ctx);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace cir::omp
|
||||
@@ -22,6 +22,7 @@ add_clang_library(clangCIRLoweringDirectToLLVM
|
||||
MLIRBuiltinToLLVMIRTranslation
|
||||
MLIRLLVMToLLVMIRTranslation
|
||||
MLIROpenMPToLLVMIRTranslation
|
||||
MLIROpenMPTransforms
|
||||
MLIRIR
|
||||
)
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
|
||||
#include "mlir/Dialect/OpenMP/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinDialect.h"
|
||||
@@ -4936,6 +4937,7 @@ std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
|
||||
|
||||
void populateCIRToLLVMPasses(mlir::OpPassManager &pm) {
|
||||
mlir::populateCIRPreLoweringPasses(pm);
|
||||
pm.addPass(mlir::omp::createMarkDeclareTargetPass());
|
||||
pm.addPass(createConvertCIRToLLVMPass());
|
||||
}
|
||||
|
||||
|
||||
53
clang/test/CIR/Transforms/omp-mark-declare-target.cir
Normal file
53
clang/test/CIR/Transforms/omp-mark-declare-target.cir
Normal file
@@ -0,0 +1,53 @@
|
||||
// RUN: cir-opt --omp-mark-declare-target %s -o - | FileCheck %s
|
||||
|
||||
// Test that the MarkDeclareTarget pass propagates the declare_target
|
||||
// attribute from explicitly marked functions to functions they call,
|
||||
// and from omp.target regions to functions called within them.
|
||||
|
||||
!s32i = !cir.int<s, 32>
|
||||
|
||||
module {
|
||||
// A helper function with no declare_target attribute initially.
|
||||
// After the pass, it should be marked because @caller calls it.
|
||||
// CHECK-LABEL: cir.func @helper
|
||||
// CHECK-SAME: omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (to)
|
||||
cir.func @helper() {
|
||||
cir.return
|
||||
}
|
||||
|
||||
// Explicitly marked as declare_target; calls @helper.
|
||||
// CHECK-LABEL: cir.func @caller
|
||||
// CHECK-SAME: omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (to)>
|
||||
cir.func @caller() attributes {omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (to)>} {
|
||||
cir.call @helper() : () -> ()
|
||||
cir.return
|
||||
}
|
||||
|
||||
// Called from within an omp.target region; should be marked as nohost.
|
||||
// CHECK-LABEL: cir.func @device_helper
|
||||
// CHECK-SAME: omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to)
|
||||
cir.func @device_helper() {
|
||||
cir.return
|
||||
}
|
||||
|
||||
// Contains an omp.target region that calls @device_helper.
|
||||
// The function itself should NOT be marked as declare_target.
|
||||
// CHECK-LABEL: cir.func @target_caller
|
||||
// CHECK-NOT: omp.declare_target
|
||||
// CHECK-SAME: {
|
||||
cir.func @target_caller() {
|
||||
omp.target {
|
||||
cir.call @device_helper() : () -> ()
|
||||
omp.terminator
|
||||
}
|
||||
cir.return
|
||||
}
|
||||
|
||||
// Not called by any declare_target function or target region.
|
||||
// CHECK-LABEL: cir.func @unrelated
|
||||
// CHECK-NOT: omp.declare_target
|
||||
// CHECK-SAME: {
|
||||
cir.func @unrelated() {
|
||||
cir.return
|
||||
}
|
||||
}
|
||||
@@ -23,6 +23,7 @@ clang_target_link_libraries(cir-opt
|
||||
PRIVATE
|
||||
clangCIR
|
||||
clangCIRLoweringDirectToLLVM
|
||||
CIROpenMPSupport
|
||||
MLIRCIR
|
||||
MLIRCIRTransforms
|
||||
)
|
||||
@@ -35,6 +36,7 @@ target_link_libraries(cir-opt
|
||||
MLIRDialect
|
||||
MLIRIR
|
||||
MLIRMemRefDialect
|
||||
MLIROpenMPTransforms
|
||||
MLIROptLib
|
||||
MLIRParser
|
||||
MLIRPass
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
|
||||
#include "mlir/Dialect/OpenMP/Transforms/Passes.h"
|
||||
#include "mlir/IR/BuiltinDialect.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Pass/PassOptions.h"
|
||||
@@ -25,6 +26,7 @@
|
||||
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "clang/CIR/Dialect/IR/CIRDialect.h"
|
||||
#include "clang/CIR/Dialect/OpenMP/RegisterOpenMPExtensions.h"
|
||||
#include "clang/CIR/Dialect/Passes.h"
|
||||
#include "clang/CIR/Passes.h"
|
||||
|
||||
@@ -37,6 +39,7 @@ int main(int argc, char **argv) {
|
||||
registry.insert<mlir::BuiltinDialect, cir::CIRDialect,
|
||||
mlir::memref::MemRefDialect, mlir::LLVM::LLVMDialect,
|
||||
mlir::DLTIDialect, mlir::omp::OpenMPDialect>();
|
||||
cir::omp::registerOpenMPExtensions(registry);
|
||||
|
||||
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
|
||||
return mlir::createCIRCanonicalizePass();
|
||||
@@ -71,6 +74,7 @@ int main(int argc, char **argv) {
|
||||
return mlir::createCXXABILoweringPass();
|
||||
});
|
||||
|
||||
mlir::omp::registerOpenMPPasses();
|
||||
mlir::registerTransformsPasses();
|
||||
|
||||
return mlir::asMainReturnCode(MlirOptMain(
|
||||
|
||||
@@ -13,6 +13,7 @@ clang_target_link_libraries(cir-translate
|
||||
PRIVATE
|
||||
clangCIR
|
||||
clangCIRLoweringDirectToLLVM
|
||||
CIROpenMPSupport
|
||||
MLIRCIR
|
||||
MLIRCIRTransforms
|
||||
)
|
||||
|
||||
@@ -31,6 +31,7 @@
|
||||
#include "clang/Basic/DiagnosticOptions.h"
|
||||
#include "clang/Basic/TargetInfo.h"
|
||||
#include "clang/CIR/Dialect/IR/CIRDialect.h"
|
||||
#include "clang/CIR/Dialect/OpenMP/RegisterOpenMPExtensions.h"
|
||||
#include "clang/CIR/Dialect/Passes.h"
|
||||
#include "clang/CIR/LowerToLLVM.h"
|
||||
#include "clang/CIR/MissingFeatures.h"
|
||||
@@ -169,6 +170,7 @@ void registerToLLVMTranslation() {
|
||||
registry.insert<mlir::DLTIDialect, mlir::func::FuncDialect>();
|
||||
mlir::registerAllToLLVMIRTranslations(registry);
|
||||
cir::direct::registerCIRDialectTranslation(registry);
|
||||
cir::omp::registerOpenMPExtensions(registry);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@ add_mlir_dialect_library(MLIROpenMPTransforms
|
||||
MLIROpenMPPassIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRFunctionInterfaces
|
||||
MLIRIR
|
||||
MLIRFuncDialect
|
||||
MLIRLLVMDialect
|
||||
MLIROpenMPDialect
|
||||
MLIRPass
|
||||
|
||||
@@ -10,10 +10,10 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/Interfaces/FunctionInterfaces.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
@@ -42,28 +42,30 @@ class MarkDeclareTargetPass
|
||||
|
||||
void processSymbolRef(SymbolRefAttr symRef, ParentInfo parentInfo,
|
||||
llvm::SmallPtrSet<Operation *, 16> visited) {
|
||||
if (auto currFOp = getOperation().lookupSymbol<func::FuncOp>(symRef)) {
|
||||
auto current =
|
||||
llvm::dyn_cast<omp::DeclareTargetInterface>(currFOp.getOperation());
|
||||
Operation *symOp = getOperation().lookupSymbol(symRef);
|
||||
if (!symOp)
|
||||
return;
|
||||
auto current = llvm::dyn_cast<omp::DeclareTargetInterface>(symOp);
|
||||
if (!current)
|
||||
return;
|
||||
|
||||
if (current.isDeclareTarget()) {
|
||||
auto currentDt = current.getDeclareTargetDeviceType();
|
||||
if (current.isDeclareTarget()) {
|
||||
auto currentDt = current.getDeclareTargetDeviceType();
|
||||
|
||||
// Found the same function twice, with different device_types,
|
||||
// mark as Any as it belongs to both
|
||||
if (currentDt != parentInfo.devTy &&
|
||||
currentDt != omp::DeclareTargetDeviceType::any) {
|
||||
current.setDeclareTarget(omp::DeclareTargetDeviceType::any,
|
||||
current.getDeclareTargetCaptureClause(),
|
||||
current.getDeclareTargetAutomap());
|
||||
}
|
||||
} else {
|
||||
current.setDeclareTarget(parentInfo.devTy, parentInfo.capClause,
|
||||
parentInfo.automap);
|
||||
// Found the same function twice, with different device_types,
|
||||
// mark as Any as it belongs to both
|
||||
if (currentDt != parentInfo.devTy &&
|
||||
currentDt != omp::DeclareTargetDeviceType::any) {
|
||||
current.setDeclareTarget(omp::DeclareTargetDeviceType::any,
|
||||
current.getDeclareTargetCaptureClause(),
|
||||
current.getDeclareTargetAutomap());
|
||||
}
|
||||
|
||||
markNestedFuncs(parentInfo, currFOp, visited);
|
||||
} else {
|
||||
current.setDeclareTarget(parentInfo.devTy, parentInfo.capClause,
|
||||
parentInfo.automap);
|
||||
}
|
||||
|
||||
markNestedFuncs(parentInfo, symOp, visited);
|
||||
}
|
||||
|
||||
void processReductionRefs(std::optional<mlir::ArrayAttr> symRefs,
|
||||
@@ -138,16 +140,16 @@ class MarkDeclareTargetPass
|
||||
// as implicitly declare target if they are called from within an explicitly
|
||||
// marked declare target function or a target region (TargetOp)
|
||||
void runOnOperation() override {
|
||||
for (auto functionOp : getOperation().getOps<func::FuncOp>()) {
|
||||
auto declareTargetOp = llvm::dyn_cast<omp::DeclareTargetInterface>(
|
||||
functionOp.getOperation());
|
||||
if (declareTargetOp.isDeclareTarget()) {
|
||||
llvm::SmallPtrSet<Operation *, 16> visited;
|
||||
ParentInfo parentInfo{declareTargetOp.getDeclareTargetDeviceType(),
|
||||
declareTargetOp.getDeclareTargetCaptureClause(),
|
||||
declareTargetOp.getDeclareTargetAutomap()};
|
||||
markNestedFuncs(parentInfo, functionOp, visited);
|
||||
}
|
||||
for (auto funcOp : getOperation().getOps<FunctionOpInterface>()) {
|
||||
auto declareTargetOp =
|
||||
llvm::dyn_cast<omp::DeclareTargetInterface>(funcOp.getOperation());
|
||||
if (!declareTargetOp || !declareTargetOp.isDeclareTarget())
|
||||
continue;
|
||||
llvm::SmallPtrSet<Operation *, 16> visited;
|
||||
ParentInfo parentInfo{declareTargetOp.getDeclareTargetDeviceType(),
|
||||
declareTargetOp.getDeclareTargetCaptureClause(),
|
||||
declareTargetOp.getDeclareTargetAutomap()};
|
||||
markNestedFuncs(parentInfo, funcOp, visited);
|
||||
}
|
||||
|
||||
// TODO: Extend to work with reverse-offloading, this shouldn't
|
||||
|
||||
Reference in New Issue
Block a user