[CIR][MLIR][OpenMP] Enable the MarkDeclareTarget pass for ClangIR (#189420)

This patch enables the MarkDeclareTarget for CIR by adding the pass to
the lowerings and attaching the declare target interface to the
cir::FuncOp. The MarkDeclareTarget is also generalized to work on the
FunctionOpInterface instead of func::Op since it needs to be able to
handle cir::FuncOp as well.

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Jan Leyonberg
2026-04-01 12:50:09 -04:00
committed by GitHub
parent 44979bedf0
commit 91adaeceb1
15 changed files with 161 additions and 31 deletions

View File

@@ -0,0 +1,22 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef CLANG_CIR_DIALECT_OPENMP_REGISTEROPENMPEXTENSIONS_H
#define CLANG_CIR_DIALECT_OPENMP_REGISTEROPENMPEXTENSIONS_H
namespace mlir {
class DialectRegistry;
} // namespace mlir
namespace cir::omp {
void registerOpenMPExtensions(mlir::DialectRegistry &registry);
} // namespace cir::omp
#endif // CLANG_CIR_DIALECT_OPENMP_REGISTEROPENMPEXTENSIONS_H

View File

@@ -21,6 +21,7 @@
#include "clang/CIR/CIRGenerator.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/OpenACC/RegisterOpenACCExtensions.h"
#include "clang/CIR/Dialect/OpenMP/RegisterOpenMPExtensions.h"
#include "llvm/IR/DataLayout.h"
using namespace cir;
@@ -56,9 +57,10 @@ void CIRGenerator::Initialize(ASTContext &astContext) {
mlirContext->getOrLoadDialect<mlir::acc::OpenACCDialect>();
mlirContext->getOrLoadDialect<mlir::omp::OpenMPDialect>();
// Register extensions to integrate CIR types with OpenACC.
// Register extensions to integrate CIR types with OpenACC and OpenMP.
mlir::DialectRegistry registry;
cir::acc::registerOpenACCExtensions(registry);
cir::omp::registerOpenMPExtensions(registry);
mlirContext->appendDialectRegistry(registry);
cgm = std::make_unique<clang::CIRGen::CIRGenModule>(

View File

@@ -65,6 +65,7 @@ add_clang_library(clangCIR
clangLex
${dialect_libs}
CIROpenACCSupport
CIROpenMPSupport
MLIRCIR
MLIRCIRInterfaces
MLIRTargetLLVMIRImport

View File

@@ -1,3 +1,4 @@
add_subdirectory(IR)
add_subdirectory(OpenACC)
add_subdirectory(OpenMP)
add_subdirectory(Transforms)

View File

@@ -0,0 +1,11 @@
add_clang_library(CIROpenMPSupport
RegisterOpenMPExtensions.cpp
DEPENDS
MLIRCIROpsIncGen
LINK_LIBS PUBLIC
MLIRIR
MLIRCIR
MLIROpenMPDialect
)

View File

@@ -0,0 +1,26 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Registration for OpenMP extensions as applied to CIR dialect.
//
//===----------------------------------------------------------------------===//
#include "clang/CIR/Dialect/OpenMP/RegisterOpenMPExtensions.h"
#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
namespace cir::omp {
void registerOpenMPExtensions(mlir::DialectRegistry &registry) {
registry.addExtension(+[](mlir::MLIRContext *ctx, cir::CIRDialect *dialect) {
cir::FuncOp::attachInterface<
mlir::omp::DeclareTargetDefaultModel<cir::FuncOp>>(*ctx);
});
}
} // namespace cir::omp

View File

@@ -22,6 +22,7 @@ add_clang_library(clangCIRLoweringDirectToLLVM
MLIRBuiltinToLLVMIRTranslation
MLIRLLVMToLLVMIRTranslation
MLIROpenMPToLLVMIRTranslation
MLIROpenMPTransforms
MLIRIR
)

View File

@@ -21,6 +21,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/OpenMP/Transforms/Passes.h"
#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
@@ -4936,6 +4937,7 @@ std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
void populateCIRToLLVMPasses(mlir::OpPassManager &pm) {
mlir::populateCIRPreLoweringPasses(pm);
pm.addPass(mlir::omp::createMarkDeclareTargetPass());
pm.addPass(createConvertCIRToLLVMPass());
}

View File

@@ -0,0 +1,53 @@
// RUN: cir-opt --omp-mark-declare-target %s -o - | FileCheck %s
// Test that the MarkDeclareTarget pass propagates the declare_target
// attribute from explicitly marked functions to functions they call,
// and from omp.target regions to functions called within them.
!s32i = !cir.int<s, 32>
module {
// A helper function with no declare_target attribute initially.
// After the pass, it should be marked because @caller calls it.
// CHECK-LABEL: cir.func @helper
// CHECK-SAME: omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (to)
cir.func @helper() {
cir.return
}
// Explicitly marked as declare_target; calls @helper.
// CHECK-LABEL: cir.func @caller
// CHECK-SAME: omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (to)>
cir.func @caller() attributes {omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (to)>} {
cir.call @helper() : () -> ()
cir.return
}
// Called from within an omp.target region; should be marked as nohost.
// CHECK-LABEL: cir.func @device_helper
// CHECK-SAME: omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to)
cir.func @device_helper() {
cir.return
}
// Contains an omp.target region that calls @device_helper.
// The function itself should NOT be marked as declare_target.
// CHECK-LABEL: cir.func @target_caller
// CHECK-NOT: omp.declare_target
// CHECK-SAME: {
cir.func @target_caller() {
omp.target {
cir.call @device_helper() : () -> ()
omp.terminator
}
cir.return
}
// Not called by any declare_target function or target region.
// CHECK-LABEL: cir.func @unrelated
// CHECK-NOT: omp.declare_target
// CHECK-SAME: {
cir.func @unrelated() {
cir.return
}
}

View File

@@ -23,6 +23,7 @@ clang_target_link_libraries(cir-opt
PRIVATE
clangCIR
clangCIRLoweringDirectToLLVM
CIROpenMPSupport
MLIRCIR
MLIRCIRTransforms
)
@@ -35,6 +36,7 @@ target_link_libraries(cir-opt
MLIRDialect
MLIRIR
MLIRMemRefDialect
MLIROpenMPTransforms
MLIROptLib
MLIRParser
MLIRPass

View File

@@ -18,6 +18,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/OpenMP/Transforms/Passes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassOptions.h"
@@ -25,6 +26,7 @@
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
#include "mlir/Transforms/Passes.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/OpenMP/RegisterOpenMPExtensions.h"
#include "clang/CIR/Dialect/Passes.h"
#include "clang/CIR/Passes.h"
@@ -37,6 +39,7 @@ int main(int argc, char **argv) {
registry.insert<mlir::BuiltinDialect, cir::CIRDialect,
mlir::memref::MemRefDialect, mlir::LLVM::LLVMDialect,
mlir::DLTIDialect, mlir::omp::OpenMPDialect>();
cir::omp::registerOpenMPExtensions(registry);
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
return mlir::createCIRCanonicalizePass();
@@ -71,6 +74,7 @@ int main(int argc, char **argv) {
return mlir::createCXXABILoweringPass();
});
mlir::omp::registerOpenMPPasses();
mlir::registerTransformsPasses();
return mlir::asMainReturnCode(MlirOptMain(

View File

@@ -13,6 +13,7 @@ clang_target_link_libraries(cir-translate
PRIVATE
clangCIR
clangCIRLoweringDirectToLLVM
CIROpenMPSupport
MLIRCIR
MLIRCIRTransforms
)

View File

@@ -31,6 +31,7 @@
#include "clang/Basic/DiagnosticOptions.h"
#include "clang/Basic/TargetInfo.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/OpenMP/RegisterOpenMPExtensions.h"
#include "clang/CIR/Dialect/Passes.h"
#include "clang/CIR/LowerToLLVM.h"
#include "clang/CIR/MissingFeatures.h"
@@ -169,6 +170,7 @@ void registerToLLVMTranslation() {
registry.insert<mlir::DLTIDialect, mlir::func::FuncDialect>();
mlir::registerAllToLLVMIRTranslations(registry);
cir::direct::registerCIRDialectTranslation(registry);
cir::omp::registerOpenMPExtensions(registry);
});
}

View File

@@ -6,8 +6,8 @@ add_mlir_dialect_library(MLIROpenMPTransforms
MLIROpenMPPassIncGen
LINK_LIBS PUBLIC
MLIRFunctionInterfaces
MLIRIR
MLIRFuncDialect
MLIRLLVMDialect
MLIROpenMPDialect
MLIRPass

View File

@@ -10,10 +10,10 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallPtrSet.h"
@@ -42,28 +42,30 @@ class MarkDeclareTargetPass
void processSymbolRef(SymbolRefAttr symRef, ParentInfo parentInfo,
llvm::SmallPtrSet<Operation *, 16> visited) {
if (auto currFOp = getOperation().lookupSymbol<func::FuncOp>(symRef)) {
auto current =
llvm::dyn_cast<omp::DeclareTargetInterface>(currFOp.getOperation());
Operation *symOp = getOperation().lookupSymbol(symRef);
if (!symOp)
return;
auto current = llvm::dyn_cast<omp::DeclareTargetInterface>(symOp);
if (!current)
return;
if (current.isDeclareTarget()) {
auto currentDt = current.getDeclareTargetDeviceType();
if (current.isDeclareTarget()) {
auto currentDt = current.getDeclareTargetDeviceType();
// Found the same function twice, with different device_types,
// mark as Any as it belongs to both
if (currentDt != parentInfo.devTy &&
currentDt != omp::DeclareTargetDeviceType::any) {
current.setDeclareTarget(omp::DeclareTargetDeviceType::any,
current.getDeclareTargetCaptureClause(),
current.getDeclareTargetAutomap());
}
} else {
current.setDeclareTarget(parentInfo.devTy, parentInfo.capClause,
parentInfo.automap);
// Found the same function twice, with different device_types,
// mark as Any as it belongs to both
if (currentDt != parentInfo.devTy &&
currentDt != omp::DeclareTargetDeviceType::any) {
current.setDeclareTarget(omp::DeclareTargetDeviceType::any,
current.getDeclareTargetCaptureClause(),
current.getDeclareTargetAutomap());
}
markNestedFuncs(parentInfo, currFOp, visited);
} else {
current.setDeclareTarget(parentInfo.devTy, parentInfo.capClause,
parentInfo.automap);
}
markNestedFuncs(parentInfo, symOp, visited);
}
void processReductionRefs(std::optional<mlir::ArrayAttr> symRefs,
@@ -138,16 +140,16 @@ class MarkDeclareTargetPass
// as implicitly declare target if they are called from within an explicitly
// marked declare target function or a target region (TargetOp)
void runOnOperation() override {
for (auto functionOp : getOperation().getOps<func::FuncOp>()) {
auto declareTargetOp = llvm::dyn_cast<omp::DeclareTargetInterface>(
functionOp.getOperation());
if (declareTargetOp.isDeclareTarget()) {
llvm::SmallPtrSet<Operation *, 16> visited;
ParentInfo parentInfo{declareTargetOp.getDeclareTargetDeviceType(),
declareTargetOp.getDeclareTargetCaptureClause(),
declareTargetOp.getDeclareTargetAutomap()};
markNestedFuncs(parentInfo, functionOp, visited);
}
for (auto funcOp : getOperation().getOps<FunctionOpInterface>()) {
auto declareTargetOp =
llvm::dyn_cast<omp::DeclareTargetInterface>(funcOp.getOperation());
if (!declareTargetOp || !declareTargetOp.isDeclareTarget())
continue;
llvm::SmallPtrSet<Operation *, 16> visited;
ParentInfo parentInfo{declareTargetOp.getDeclareTargetDeviceType(),
declareTargetOp.getDeclareTargetCaptureClause(),
declareTargetOp.getDeclareTargetAutomap()};
markNestedFuncs(parentInfo, funcOp, visited);
}
// TODO: Extend to work with reverse-offloading, this shouldn't