Files
llvm-project/mlir/lib/Dialect/OpenMP/Transforms/StackToShared.cpp
Sergio Afonso c94db1af36 [MLIR][OpenMP] Unify device shared memory logic, NFCI (#182856)
This patch creates a utils library for the OpenMP dialect with functions
used by MLIR to LLVM IR translation as well as the stack-to-shared pass
to determine which allocations must use local stack memory or device
shared memory.
2026-04-27 13:15:55 +01:00

125 lines
5.0 KiB
C++

//===- StackToShared.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements transforms to swap stack allocations on the target
// device with device shared memory where applicable.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/OpenMP/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/OpenMP/Utils/Utils.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/STLExtras.h"
namespace mlir {
namespace omp {
#define GEN_PASS_DEF_STACKTOSHAREDPASS
#include "mlir/Dialect/OpenMP/Transforms/Passes.h.inc"
} // namespace omp
} // namespace mlir
using namespace mlir;
/// Decide whether a stack allocation operation should be rewritten into a
/// device shared memory allocation/deallocation pair.
///
/// Two conditions must both hold:
///  - `omp::opInSharedDeviceContext` reports that `op` is located in a shared
///    device context;
///  - for at least one result of `op`, `omp::allocaUsesRequireSharedMem`
///    reports that its uses require shared memory.
/// Results are checked in order and the scan stops at the first match.
static bool shouldReplaceAllocaWithDeviceSharedMem(Operation &op) {
  if (!omp::opInSharedDeviceContext(op))
    return false;

  for (Value allocResult : op.getResults()) {
    if (omp::allocaUsesRequireSharedMem(allocResult))
      return true;
  }
  return false;
}
/// Release the device shared memory allocation `allocVal` on every exit path
/// of the region that defines it.
///
/// An `omp.free_shared_mem` op is inserted immediately before the terminator of
/// each block that:
///  - belongs to the same region as `allocVal`'s defining block;
///  - has a terminator with no successor blocks, i.e. it exits the region;
///  - is dominated by the block where `allocVal` is defined.
///
/// \param builder   Builder used to create the free ops; its insertion point is
///                  modified by this function.
/// \param elemType  Element type forwarded to each `omp.free_shared_mem`.
/// \param arraySize Array size forwarded to each `omp.free_shared_mem`.
/// \param alignment Alignment forwarded to each `omp.free_shared_mem`.
/// \param allocVal  The shared memory allocation being released.
static void insertDeviceSharedMemDeallocation(OpBuilder &builder,
                                              TypeAttr elemType,
                                              Value arraySize,
                                              IntegerAttr alignment,
                                              Value allocVal) {
  Block *defBlock = allocVal.getParentBlock();
  Region *defRegion = allocVal.getParentRegion();
  Location freeLoc = allocVal.getLoc();
  DominanceInfo dominance;

  for (Block &candidate : *defRegion) {
    Operation *exitOp = candidate.getTerminator();
    // Terminators that branch to other blocks are not region exits.
    if (exitOp->hasSuccessors())
      continue;
    // Only exits reachable after the allocation get a matching free.
    if (!dominance.dominates(defBlock, &candidate))
      continue;

    builder.setInsertionPoint(exitOp);
    omp::FreeSharedMemOp::create(builder, freeLoc, elemType, arraySize,
                                 alignment, allocVal);
  }
}
namespace {
/// Pass that replaces `llvm.alloca` operations inside an `llvm.func` with
/// `omp.alloc_shared_mem` / `omp.free_shared_mem` pairs.
///
/// It runs only when the function is nested in an operation implementing
/// `omp::OffloadModuleInterface` that reports it is compiled for the target
/// device. The per-alloca decision is made by
/// shouldReplaceAllocaWithDeviceSharedMem().
class StackToSharedPass
    : public omp::impl::StackToSharedPassBase<StackToSharedPass> {
public:
  StackToSharedPass() = default;

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    OpBuilder builder(context);

    LLVM::LLVMFuncOp funcOp = getOperation();
    auto offloadIface = funcOp->getParentOfType<omp::OffloadModuleInterface>();
    if (!offloadIface || !offloadIface.getIsTargetDevice())
      return;

    // Ops whose erasure has to be postponed until the walk has finished.
    llvm::SmallVector<Operation *> toBeDeleted;

    funcOp->walk([&](LLVM::AllocaOp allocaOp) {
      if (!shouldReplaceAllocaWithDeviceSharedMem(*allocaOp))
        return;

      // Read every property needed below before the alloca can be erased.
      // The default-address-space path erases `allocaOp` right away, and
      // querying it afterwards would read freed memory.
      TypeAttr elemType = allocaOp.getElemTypeAttr();
      Value arraySize = allocaOp.getArraySize();
      IntegerAttr alignment = allocaOp.getAlignmentAttr();

      // TODO: The handling of non-default address spaces might need to be
      // improved. This currently only handles the case where an alloca to
      // non-default address space is only used by a single addrspacecast to
      // default address space.
      bool nonDefaultAddrSpace = false;
      if (auto llvmPtrType = dyn_cast<LLVM::LLVMPointerType>(
              allocaOp.getResult().getType()))
        nonDefaultAddrSpace = llvmPtrType.getAddressSpace() != 0;

      // Replace llvm.alloca with omp.alloc_shared_mem, which always yields a
      // pointer in the default address space.
      builder.setInsertionPoint(allocaOp);
      auto sharedAllocOp = omp::AllocSharedMemOp::create(
          builder, allocaOp->getLoc(), LLVM::LLVMPointerType::get(context),
          elemType, arraySize, alignment);

      if (nonDefaultAddrSpace) {
        assert(allocaOp->hasOneUse() && " unsupported non-default address "
                                        "space alloca with multiple uses");
        auto asCastOp =
            cast<LLVM::AddrSpaceCastOp>(*allocaOp->getUsers().begin());
        // The shared allocation already produces a default-address-space
        // pointer, so it stands in for the cast's result directly.
        asCastOp.replaceAllUsesWith(sharedAllocOp.getOperation());
        // Delete later because we can't delete the cast op before the top-level
        // iteration visits it. Also, the alloca can't be deleted before because
        // it's used by it. The cast is queued first so it is erased before the
        // alloca it uses.
        toBeDeleted.push_back(asCastOp);
        toBeDeleted.push_back(allocaOp);
      } else {
        allocaOp.replaceAllUsesWith(sharedAllocOp.getOperation());
        allocaOp.erase();
      }

      // Create a new omp.free_shared_mem for the allocated buffer prior to
      // exiting the region. Only the values captured above are used here,
      // never `allocaOp`, which may no longer exist.
      insertDeviceSharedMemDeallocation(builder, elemType, arraySize,
                                        alignment, sharedAllocOp.getResult());
    });

    for (Operation *op : toBeDeleted)
      op->erase();
  }
};
} // namespace