[MLIR][OpenMP] Unify device shared memory logic, NFCI (#182856)

This patch creates a utils library for the OpenMP dialect with functions
used by MLIR to LLVM IR translation as well as the stack-to-shared pass
to determine which allocations must use local stack memory or device
shared memory.
This commit is contained in:
Sergio Afonso
2026-04-27 13:15:55 +01:00
committed by GitHub
parent fad06a418b
commit c94db1af36
8 changed files with 193 additions and 178 deletions

View File

@@ -0,0 +1,53 @@
//===- Utils.h - OpenMP dialect utilities -----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines prototypes for various OpenMP utilities.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_OPENMP_UTILS_UTILS_H_
#define MLIR_DIALECT_OPENMP_UTILS_UTILS_H_
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
namespace mlir {
namespace omp {
/// Check whether the value representing an allocation, assumed to have been
/// defined in a shared device context, is used in a manner that would require
/// device shared memory for correctness.
///
/// When a use takes place inside an omp.parallel region and it's not as a
/// private clause argument, or when it is a reduction argument passed to
/// omp.parallel or a function call argument, then the defining allocation is
/// eligible for replacement with shared memory.
///
/// \see mlir::omp::opInSharedDeviceContext().
bool allocaUsesRequireSharedMem(Value alloc);
/// Check whether the given operation is located in a context where an
/// allocation to be used by multiple threads in a parallel region would have to
/// be placed in device shared memory to be accessible.
///
/// That means that it is inside of a target device module and it is either a
/// non-SPMD target region, nested inside of one, or located in a device
/// function, and it is not inside of a parallel region.
///
/// This represents a necessary but not sufficient set of conditions to use
/// device shared memory in place of regular allocas. For some variables, the
/// associated OpenMP construct or their uses might also need to be taken into
/// account.
///
/// \see mlir::omp::allocaUsesRequireSharedMem().
bool opInSharedDeviceContext(Operation &op);
} // namespace omp
} // namespace mlir
#endif // MLIR_DIALECT_OPENMP_UTILS_UTILS_H_

View File

@@ -1,2 +1,3 @@
add_subdirectory(IR) add_subdirectory(IR)
add_subdirectory(Transforms) add_subdirectory(Transforms)
add_subdirectory(Utils)

View File

@@ -19,6 +19,7 @@ add_mlir_dialect_library(MLIROpenMPTransforms
MLIRLLVMDialect MLIRLLVMDialect
MLIROpenACCMPCommon MLIROpenACCMPCommon
MLIROpenMPDialect MLIROpenMPDialect
MLIROpenMPUtils
MLIRPass MLIRPass
MLIRSupport MLIRSupport
MLIRTransforms MLIRTransforms

View File

@@ -15,7 +15,9 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/OpenMP/Utils/Utils.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "llvm/ADT/STLExtras.h"
namespace mlir { namespace mlir {
namespace omp { namespace omp {
@@ -26,94 +28,20 @@ namespace omp {
using namespace mlir; using namespace mlir;
/// When a use takes place inside an omp.parallel region and it's not as a /// Tell whether to replace an operation representing a stack allocation with a
/// private clause argument, or when it is a reduction argument passed to /// device shared memory allocation/deallocation pair based on the location of
/// omp.parallel or a function call argument, then the defining allocation is /// the allocation and its uses.
/// eligible for replacement with shared memory.
static bool allocaUseRequiresDeviceSharedMem(const OpOperand &use) {
Operation *owner = use.getOwner();
if (auto parallelOp = dyn_cast<omp::ParallelOp>(owner)) {
if (llvm::is_contained(parallelOp.getReductionVars(), use.get()))
return true;
} else if (auto callOp = dyn_cast<CallOpInterface>(owner)) {
if (llvm::is_contained(callOp.getArgOperands(), use.get()))
return true;
}
// If it is used directly inside of a parallel region, it has to be replaced
// unless the use is a private clause.
if (owner->getParentOfType<omp::ParallelOp>()) {
if (auto argIface = dyn_cast<omp::BlockArgOpenMPOpInterface>(owner)) {
if (auto privateSyms =
cast_or_null<ArrayAttr>(owner->getAttr("private_syms"))) {
for (auto [var, sym] :
llvm::zip_equal(argIface.getPrivateVars(), privateSyms)) {
if (var != use.get())
continue;
auto moduleOp = owner->getParentOfType<ModuleOp>();
auto privateOp = cast<omp::PrivateClauseOp>(
moduleOp.lookupSymbol(cast<SymbolRefAttr>(sym)));
return privateOp.getDataSharingType() !=
omp::DataSharingClauseType::Private;
}
}
}
return true;
}
return false;
}
static bool shouldReplaceAllocaWithUses(const Operation::use_range &uses) {
// Check direct uses and also follow hlfir.declare/fir.convert uses.
for (const OpOperand &use : uses) {
Operation *owner = use.getOwner();
if (llvm::isa<LLVM::AddrSpaceCastOp, LLVM::GEPOp>(owner)) {
if (shouldReplaceAllocaWithUses(owner->getUses()))
return true;
} else if (allocaUseRequiresDeviceSharedMem(use)) {
return true;
}
}
return false;
}
// TODO: Refactor the logic in `shouldReplaceAllocaWithDeviceSharedMem`,
// `shouldReplaceAllocaWithUses` and `allocaUseRequiresDeviceSharedMem` to
// be reusable by the MLIR to LLVM IR translation stage, as something very
// similar is also implemented there to choose between allocas and device
// shared memory allocations when processing OpenMP reductions, mapping and
// privatization.
static bool shouldReplaceAllocaWithDeviceSharedMem(Operation &op) { static bool shouldReplaceAllocaWithDeviceSharedMem(Operation &op) {
auto offloadIface = op.getParentOfType<omp::OffloadModuleInterface>(); return omp::opInSharedDeviceContext(op) &&
if (!offloadIface || !offloadIface.getIsTargetDevice()) llvm::any_of(op.getResults(), [&](Value result) {
return false; return omp::allocaUsesRequireSharedMem(result);
});
auto targetOp = op.getParentOfType<omp::TargetOp>();
// It must be inside of a generic omp.target or in a target device function,
// and not inside of omp.parallel.
if (auto parallelOp = op.getParentOfType<omp::ParallelOp>()) {
if (!targetOp || targetOp->isProperAncestor(parallelOp))
return false;
}
if (targetOp) {
if (targetOp.getKernelExecFlags(targetOp.getInnermostCapturedOmpOp()) !=
omp::TargetExecMode::generic)
return false;
} else {
auto declTargetIface = op.getParentOfType<omp::DeclareTargetInterface>();
if (!declTargetIface || !declTargetIface.isDeclareTarget() ||
declTargetIface.getDeclareTargetDeviceType() ==
omp::DeclareTargetDeviceType::host)
return false;
}
return shouldReplaceAllocaWithUses(op.getUses());
} }
/// Based on the location of the definition of the given value representing the
/// result of a device shared memory allocation, find the corresponding points
/// where its deallocation should be placed and introduce `omp.free_shared_mem`
/// ops at those points.
static void insertDeviceSharedMemDeallocation(OpBuilder &builder, static void insertDeviceSharedMemDeallocation(OpBuilder &builder,
TypeAttr elemType, TypeAttr elemType,
Value arraySize, Value arraySize,

View File

@@ -0,0 +1,13 @@
# Utility library shared by OpenMP dialect transforms and the MLIR to LLVM IR
# translation to decide between stack and device shared memory allocations.
add_mlir_dialect_library(MLIROpenMPUtils
  Utils.cpp

  ADDITIONAL_HEADER_DIRS
  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenMP

  LINK_LIBS PUBLIC
  MLIRIR
  MLIRLLVMDialect
  MLIROpenACCMPCommon
  MLIROpenMPDialect
  MLIRSupport
)

View File

@@ -0,0 +1,101 @@
//===- Utils.cpp - OpenMP dialect utilities -------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements various OpenMP dialect utilities.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/OpenMP/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
using namespace mlir;
/// Check whether this single use of an allocation forces that allocation into
/// device shared memory.
///
/// That is the case when the use is a reduction variable of an omp.parallel
/// op, an argument to a function call, or any use inside of an omp.parallel
/// region other than a plain `private` clause argument.
static bool allocaUseRequiresSharedMem(const OpOperand &use) {
  Operation *owner = use.getOwner();
  // Reduction variables of omp.parallel must be visible to all threads of the
  // parallel region.
  if (auto parallelOp = dyn_cast<omp::ParallelOp>(owner)) {
    if (llvm::is_contained(parallelOp.getReductionVars(), use.get()))
      return true;
  } else if (auto callOp = dyn_cast<CallOpInterface>(owner)) {
    // Passing the allocation as a call argument: conservatively assume the
    // callee may make it visible to other threads.
    if (llvm::is_contained(callOp.getArgOperands(), use.get()))
      return true;
  }
  // If it is used directly inside of a parallel region, it has to be replaced
  // unless the use is a private clause.
  if (owner->getParentOfType<omp::ParallelOp>()) {
    if (auto argIface = dyn_cast<omp::BlockArgOpenMPOpInterface>(owner)) {
      OperandRange privateVars = argIface.getPrivateVars();
      auto it = llvm::find(privateVars, use.get());
      if (it != privateVars.end()) {
        // Resolve the omp.private op for this clause argument. Only a plain
        // `private` data-sharing clause avoids shared memory; any other kind
        // (e.g. firstprivate) still requires it.
        auto privateSyms = owner->getAttrOfType<ArrayAttr>("private_syms");
        size_t idx = std::distance(privateVars.begin(), it);
        auto privateOp =
            SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
                owner, cast<SymbolRefAttr>(privateSyms[idx]));
        return privateOp.getDataSharingType() !=
               omp::DataSharingClauseType::Private;
      }
    }
    // Any other use from within a parallel region requires shared memory.
    return true;
  }
  return false;
}
/// Walk all uses of \p alloc, looking through pointer-preserving operations,
/// and report whether any of them requires the allocation to be placed in
/// device shared memory.
bool mlir::omp::allocaUsesRequireSharedMem(Value alloc) {
  return llvm::any_of(alloc.getUses(), [](OpOperand &use) {
    Operation *user = use.getOwner();
    // Address space casts and GEPs forward the same allocation, so recurse
    // into the uses of their results instead of inspecting the op itself.
    if (isa<LLVM::AddrSpaceCastOp, LLVM::GEPOp>(user)) {
      for (Value forwarded : user->getResults())
        if (allocaUsesRequireSharedMem(forwarded))
          return true;
      return false;
    }
    // Direct use: check whether this particular use demands shared memory.
    return allocaUseRequiresSharedMem(use);
  });
}
/// Check whether \p op is located in a context where multi-thread-visible
/// allocations would have to be placed in device shared memory.
///
/// \see the declaration in mlir/Dialect/OpenMP/Utils/Utils.h for the full set
/// of conditions.
bool mlir::omp::opInSharedDeviceContext(Operation &op) {
  // The omp.parallel operation itself is never in a shared device context.
  if (isa<omp::ParallelOp>(op))
    return false;
  // Device shared memory is only relevant when compiling for a target device.
  auto offloadIface = op.getParentOfType<omp::OffloadModuleInterface>();
  if (!offloadIface || !offloadIface.getIsTargetDevice())
    return false;
  auto targetOp = op.getParentOfType<omp::TargetOp>();
  // It must be inside of a generic omp.target or in a target device function,
  // and not inside of omp.parallel.
  if (auto parallelOp = op.getParentOfType<omp::ParallelOp>()) {
    // A parallel ancestor disqualifies the op unless that parallel region
    // itself encloses the target region (i.e. the target is nested inside of
    // the parallel, not the other way around).
    if (!targetOp || targetOp->isProperAncestor(parallelOp))
      return false;
  }
  // The omp.target operation itself is considered in a shared device context in
  // order to properly process its own allocation-defining entry block
  // arguments.
  if (!targetOp)
    targetOp = dyn_cast<omp::TargetOp>(op);
  if (targetOp) {
    // Only generic (non-SPMD) kernels are eligible for device shared memory.
    if (targetOp.getKernelExecFlags(targetOp.getInnermostCapturedOmpOp()) !=
        omp::TargetExecMode::generic)
      return false;
  } else {
    // Outside of any target region: the op must live in a function marked
    // declare target for a non-host device.
    auto declTargetIface = op.getParentOfType<omp::DeclareTargetInterface>();
    if (!declTargetIface || !declTargetIface.isDeclareTarget() ||
        declTargetIface.getDeclareTargetDeviceType() ==
            omp::DeclareTargetDeviceType::host)
      return false;
  }
  return true;
}

View File

@@ -8,6 +8,7 @@ add_mlir_translation_library(MLIROpenMPToLLVMIRTranslation
MLIRIR MLIRIR
MLIRLLVMDialect MLIRLLVMDialect
MLIROpenMPDialect MLIROpenMPDialect
MLIROpenMPUtils
MLIRSupport MLIRSupport
MLIRTargetLLVMIRExport MLIRTargetLLVMIRExport
MLIRTransformUtils MLIRTransformUtils

View File

@@ -16,6 +16,7 @@
#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h" #include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
#include "mlir/Dialect/OpenMP/Utils/Utils.h"
#include "mlir/IR/Operation.h" #include "mlir/IR/Operation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LLVM.h" #include "mlir/Support/LLVM.h"
@@ -1137,81 +1138,6 @@ struct DeferredStore {
}; };
} // namespace } // namespace
/// Check whether allocations for the given operation might potentially have to
/// be done in device shared memory. That means we're compiling for an
/// offloading target, the operation is neither an `omp::TargetOp` nor nested
/// inside of one, or it is and that target region represents a Generic
/// (non-SPMD) kernel.
///
/// This represents a necessary but not sufficient set of conditions to use
/// device shared memory in place of regular allocas. For some variables, the
/// associated OpenMP construct or their uses might also need to be taken into
/// account.
static bool
mightAllocInDeviceSharedMemory(Operation &op,
const llvm::OpenMPIRBuilder &ompBuilder) {
if (!ompBuilder.Config.isTargetDevice())
return false;
auto targetOp = dyn_cast<omp::TargetOp>(op);
if (!targetOp)
targetOp = op.getParentOfType<omp::TargetOp>();
return !targetOp ||
targetOp.getKernelExecFlags(targetOp.getInnermostCapturedOmpOp()) ==
omp::TargetExecMode::generic;
}
/// Check whether the entry block argument representing the private copy of a
/// variable in an OpenMP construct must be allocated in device shared memory,
/// based on what the uses of that copy are.
///
/// This must only be called if a previous call to
/// \c mightAllocInDeviceSharedMemory has already returned \c true for the
/// operation that owns the specified block argument.
static bool mustAllocPrivateVarInDeviceSharedMemory(BlockArgument value) {
Operation *parentOp = value.getOwner()->getParentOp();
auto moduleOp = parentOp->getParentOfType<ModuleOp>();
for (auto *user : value.getUsers()) {
if (auto parallelOp = dyn_cast<omp::ParallelOp>(user)) {
if (llvm::is_contained(parallelOp.getReductionVars(), value))
return true;
} else if (auto callOp = dyn_cast<CallOpInterface>(user)) {
if (llvm::is_contained(callOp.getArgOperands(), value))
return true;
}
if (auto parallelOp = user->getParentOfType<omp::ParallelOp>()) {
if (parentOp->isProperAncestor(parallelOp)) {
// If it is used directly inside of a parallel region, skip private
// clause uses.
bool isPrivateClauseUse = false;
if (auto argIface = dyn_cast<omp::BlockArgOpenMPOpInterface>(user)) {
if (auto privateSyms = llvm::cast_or_null<ArrayAttr>(
user->getAttr("private_syms"))) {
for (auto [var, sym] :
llvm::zip_equal(argIface.getPrivateVars(), privateSyms)) {
if (var != value)
continue;
auto privateOp = cast<omp::PrivateClauseOp>(
moduleOp.lookupSymbol(cast<SymbolRefAttr>(sym)));
if (privateOp.getCopyRegion().empty()) {
isPrivateClauseUse = true;
break;
}
}
}
}
if (!isPrivateClauseUse)
return true;
}
}
}
return false;
}
/// Allocate space for privatized reduction variables. /// Allocate space for privatized reduction variables.
/// `deferredStores` contains information to create store operations which needs /// `deferredStores` contains information to create store operations which needs
/// to be inserted after all allocas /// to be inserted after all allocas
@@ -1230,8 +1156,7 @@ allocReductionVars(T op, ArrayRef<BlockArgument> reductionArgs,
builder.SetInsertPoint(allocaIP.getBlock()->getTerminator()); builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
bool useDeviceSharedMem = isa<omp::TeamsOp>(*op) && bool useDeviceSharedMem = omp::opInSharedDeviceContext(*op);
mightAllocInDeviceSharedMemory(*op, *ompBuilder);
// delay creating stores until after all allocas // delay creating stores until after all allocas
deferredStores.reserve(op.getNumReductionVars()); deferredStores.reserve(op.getNumReductionVars());
@@ -1362,8 +1287,7 @@ initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
return success(); return success();
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
bool useDeviceSharedMem = isa<omp::TeamsOp>(*op) && bool useDeviceSharedMem = omp::opInSharedDeviceContext(*op);
mightAllocInDeviceSharedMemory(*op, *ompBuilder);
llvm::BasicBlock *initBlock = splitBB(builder, true, "omp.reduction.init"); llvm::BasicBlock *initBlock = splitBB(builder, true, "omp.reduction.init");
auto allocaIP = llvm::IRBuilderBase::InsertPoint( auto allocaIP = llvm::IRBuilderBase::InsertPoint(
@@ -1610,8 +1534,7 @@ static LogicalResult createReductionsAndCleanup(
reductionRegions, privateReductionVariables, moduleTranslation, builder, reductionRegions, privateReductionVariables, moduleTranslation, builder,
"omp.reduction.cleanup"); "omp.reduction.cleanup");
bool useDeviceSharedMem = isa<omp::TeamsOp>(*op) && bool useDeviceSharedMem = omp::opInSharedDeviceContext(*op);
mightAllocInDeviceSharedMemory(*op, *ompBuilder);
if (useDeviceSharedMem) { if (useDeviceSharedMem) {
for (auto [var, reductionDecl] : for (auto [var, reductionDecl] :
llvm::zip_equal(privateReductionVariables, reductionDecls)) llvm::zip_equal(privateReductionVariables, reductionDecls))
@@ -1802,9 +1725,7 @@ allocatePrivateVars(T op, llvm::IRBuilderBase &builder,
llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0); llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
bool mightUseDeviceSharedMem = bool mightUseDeviceSharedMem = omp::opInSharedDeviceContext(*op);
isa<omp::TargetOp, omp::TeamsOp, omp::DistributeOp>(*op) &&
mightAllocInDeviceSharedMemory(*op, *ompBuilder);
unsigned int allocaAS = unsigned int allocaAS =
moduleTranslation.getLLVMModule()->getDataLayout().getAllocaAddrSpace(); moduleTranslation.getLLVMModule()->getDataLayout().getAllocaAddrSpace();
unsigned int defaultAS = moduleTranslation.getLLVMModule() unsigned int defaultAS = moduleTranslation.getLLVMModule()
@@ -1818,8 +1739,7 @@ allocatePrivateVars(T op, llvm::IRBuilderBase &builder,
moduleTranslation.convertType(privDecl.getType()); moduleTranslation.convertType(privDecl.getType());
builder.SetInsertPoint(allocaIP.getBlock()->getTerminator()); builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
llvm::Value *llvmPrivateVar = nullptr; llvm::Value *llvmPrivateVar = nullptr;
if (mightUseDeviceSharedMem && if (mightUseDeviceSharedMem && omp::allocaUsesRequireSharedMem(blockArg)) {
mustAllocPrivateVarInDeviceSharedMemory(blockArg)) {
llvmPrivateVar = ompBuilder->createOMPAllocShared(builder, llvmAllocType); llvmPrivateVar = ompBuilder->createOMPAllocShared(builder, llvmAllocType);
} else { } else {
llvmPrivateVar = builder.CreateAlloca( llvmPrivateVar = builder.CreateAlloca(
@@ -1956,14 +1876,11 @@ cleanupPrivateVars(T op, llvm::IRBuilderBase &builder,
"`omp.private` op in"); "`omp.private` op in");
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
bool mightUseDeviceSharedMem = bool mightUseDeviceSharedMem = omp::opInSharedDeviceContext(*op);
isa<omp::TargetOp, omp::TeamsOp, omp::DistributeOp>(*op) &&
mightAllocInDeviceSharedMemory(*op, *ompBuilder);
for (auto [privDecl, llvmPrivVar, blockArg] : for (auto [privDecl, llvmPrivVar, blockArg] :
llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.llvmVars, llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.llvmVars,
privateVarsInfo.blockArgs)) { privateVarsInfo.blockArgs)) {
if (mightUseDeviceSharedMem && if (mightUseDeviceSharedMem && omp::allocaUsesRequireSharedMem(blockArg)) {
mustAllocPrivateVarInDeviceSharedMemory(blockArg)) {
ompBuilder->createOMPFreeShared( ompBuilder->createOMPFreeShared(
builder, llvmPrivVar, builder, llvmPrivVar,
moduleTranslation.convertType(privDecl.getType())); moduleTranslation.convertType(privDecl.getType()));
@@ -6826,8 +6743,8 @@ static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor(
// Create the allocation for the argument. // Create the allocation for the argument.
llvm::Value *v = nullptr; llvm::Value *v = nullptr;
if (mightAllocInDeviceSharedMemory(*targetOp, ompBuilder) && if (omp::opInSharedDeviceContext(*targetOp) &&
mustAllocPrivateVarInDeviceSharedMemory(mlirArg)) { omp::allocaUsesRequireSharedMem(mlirArg)) {
// Use the beginning of the codeGenIP rather than the usual allocation point // Use the beginning of the codeGenIP rather than the usual allocation point
// for shared memory allocations because otherwise these would be done prior // for shared memory allocations because otherwise these would be done prior
// to the target initialization call. Also, the exit block (where the // to the target initialization call. Also, the exit block (where the