[mlir][GPU] Refactor GPUOps lowering (#188905)

This change promotes `gpu.func` / `gpu.launch` metadata that was
previously carried as discardable attributes into proper inherent ODS
fields (`kernel`, `workgroup_attributions`), renames the block-argument
helpers to avoid clashing with generated getters, and routes `func.func`
and `gpu.func` lowering through a shared helper that maps discardable
`llvm.*` attributes into `llvm.func` properties.

Downstream producers (Flang CUDA device func transform, kernel
outlining, sparse GPU codegen, XeGPU) set kernels via `setKernel(true)`
instead of manually attaching `gpu.kernel`.

Fixes #185174

Assisted-by: CLion code completion, GPT 5.3 - Codex
This commit is contained in:
Hocky Yudhiono
2026-04-22 18:40:57 +08:00
committed by GitHub
parent b313bb7145
commit c1cff89bdc
12 changed files with 262 additions and 222 deletions

View File

@@ -68,8 +68,7 @@ class CUFDeviceFuncTransform
gpu::GPUFuncOp::create(builder, loc, funcOp.getName(), type,
mlir::TypeRange{}, mlir::TypeRange{});
if (isGlobal)
deviceFuncOp->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
builder.getUnitAttr());
deviceFuncOp.setKernel(true);
mlir::Region &deviceFuncBody = deviceFuncOp.getBody();
mlir::Block &entryBlock = deviceFuncBody.front();

View File

@@ -0,0 +1,40 @@
//===- LowerFunctionDiscardablesToLLVM.h - Func discardables to llvm - C++ -*-//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Shared helpers for lowering discardable attributes on any FunctionOpInterface
// (e.g. func.func, gpu.func) into llvm.func properties and discardables.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_LLVMCOMMON_LOWERFUNCTIONDISCARDABLESTOLLVM_H
#define MLIR_CONVERSION_LLVMCOMMON_LOWERFUNCTIONDISCARDABLESTOLLVM_H
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
namespace mlir {
/// Result of lowering discardable attributes from a `FunctionOpInterface` to
/// what `llvm.func` expects: typed inherent properties plus the remaining
/// discardable attributes that carry over unchanged.
struct LoweredLLVMFuncAttrs {
  /// Typed `llvm.func` properties (includes `sym_name` and `function_type`).
  LLVM::LLVMFuncOp::Properties properties;
  /// Discardable attributes to attach to the created `llvm.func` as-is.
  NamedAttrList discardableAttrs;
};
/// Partitions `funcOp`'s discardable attributes for `llvm.func` creation:
/// - `sym_name` and `function_type` are populated from `funcOp` and
///   `llvmFuncType`;
/// - attributes named `llvm.<odsName>`, where `<odsName>` is an inherent
///   `llvm.func` ODS attribute name, are converted into typed properties;
/// - attributes that collide with an ODS name but lack the `llvm.` prefix are
///   dropped (with a debug note);
/// - all other discardable attributes are kept in `discardableAttrs`.
/// Fails if the collected property values do not form a valid property set.
FailureOr<LoweredLLVMFuncAttrs>
lowerDiscardableAttrsForLLVMFunc(FunctionOpInterface funcOp, Type llvmFuncType);
} // namespace mlir
#endif // MLIR_CONVERSION_LLVMCOMMON_LOWERFUNCTIONDISCARDABLESTOLLVM_H

View File

@@ -426,7 +426,9 @@ def GPU_GPUFuncOp : GPU_Op<"func", [
OptionalAttr<DictArrayAttr>:$private_attrib_attrs,
GPU_OptionalDimSizeHintAttr:$known_block_size,
GPU_OptionalDimSizeHintAttr:$known_grid_size,
GPU_OptionalDimSizeHintAttr:$known_cluster_size);
GPU_OptionalDimSizeHintAttr:$known_cluster_size,
OptionalAttr<ConfinedAttr<I64Attr, [IntNonNegative]>>:$workgroup_attributions,
UnitAttr:$kernel);
let regions = (region AnyRegion:$body);
let skipDefaultBuilders = 1;
@@ -440,22 +442,24 @@ def GPU_GPUFuncOp : GPU_Op<"func", [
let extraClassDeclaration = [{
/// Returns `true` if the GPU function defined by this Op is a kernel, i.e.
/// it is intended to be launched from host.
/// it is intended to be launched from host. Also accepts legacy discardable
/// `gpu.kernel` for compatibility with older generic IR.
bool isKernel() {
if (getKernel())
return true;
return (*this)->getAttrOfType<UnitAttr>(
GPUDialect::getKernelFuncAttrName()) != nullptr;
}
/// Returns the number of buffers located in the workgroup memory.
unsigned getNumWorkgroupAttributions() {
auto attr = (*this)->getAttrOfType<IntegerAttr>(
getNumWorkgroupAttributionsAttrName());
if (!attr)
std::optional<int64_t> v = getWorkgroupAttributions();
if (!v)
return 0;
int64_t value = attr.getInt();
assert(value >= 0 && value < std::numeric_limits<uint32_t>::max() &&
int64_t value = *v;
assert(value < std::numeric_limits<uint32_t>::max() &&
"invalid workgroup attribution count");
return value;
return static_cast<unsigned>(value);
}
/// Return the index of the first workgroup attribution in the block argument
@@ -466,7 +470,7 @@ def GPU_GPUFuncOp : GPU_Op<"func", [
/// Returns a list of block arguments that correspond to buffers located in
/// the workgroup memory
ArrayRef<BlockArgument> getWorkgroupAttributions() {
ArrayRef<BlockArgument> getWorkgroupAttributionBBArgs() {
auto begin =
std::next(getBody().args_begin(), getFirstWorkgroupAttributionIndex());
auto end = std::next(begin, getNumWorkgroupAttributions());
@@ -548,12 +552,6 @@ def GPU_GPUFuncOp : GPU_Op<"func", [
return setPrivateAttributionAttr(index, StringAttr::get((*this)->getContext(), name), value);
}
/// Returns the name of the attribute containing the number of buffers
/// located in the workgroup memory.
static StringRef getNumWorkgroupAttributionsAttrName() {
return "workgroup_attributions";
}
/// Returns the argument types of this function.
ArrayRef<Type> getArgumentTypes() { return getFunctionType().getInputs(); }
@@ -806,7 +804,8 @@ def GPU_LaunchOp : GPU_Op<"launch", [
Optional<Index>:$clusterSizeZ,
Optional<I32>:$dynamicSharedMemorySize,
OptionalAttr<FlatSymbolRefAttr>:$module,
OptionalAttr<FlatSymbolRefAttr>:$function)>,
OptionalAttr<FlatSymbolRefAttr>:$function,
OptionalAttr<ConfinedAttr<I64Attr, [IntNonNegative]>>:$workgroup_attributions)>,
Results<(outs Optional<GPU_AsyncToken>:$asyncToken)> {
let summary = "GPU kernel launch operation";
@@ -1004,19 +1003,17 @@ def GPU_LaunchOp : GPU_Op<"launch", [
/// Returns the number of buffers located in the workgroup memory.
unsigned getNumWorkgroupAttributions() {
auto attr = (*this)->getAttrOfType<IntegerAttr>(
getNumWorkgroupAttributionsAttrName());
if (!attr)
std::optional<int64_t> v = getWorkgroupAttributions();
if (!v)
return 0;
int64_t value = attr.getInt();
assert(value >= 0 && value < std::numeric_limits<uint32_t>::max() &&
int64_t value = *v;
assert(value < std::numeric_limits<uint32_t>::max() &&
"invalid workgroup attribution count");
return value;
return static_cast<unsigned>(value);
}
/// Returns a list of block arguments that correspond to buffers located in
/// the workgroup memory
ArrayRef<BlockArgument> getWorkgroupAttributions() {
/// Block arguments for workgroup memory buffers
ArrayRef<BlockArgument> getWorkgroupAttributionBBArgs() {
auto begin =
std::next(getBody().args_begin(), getNumConfigRegionAttributes());
auto end = std::next(begin, getNumWorkgroupAttributions());
@@ -1047,12 +1044,6 @@ def GPU_LaunchOp : GPU_Op<"launch", [
/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument addPrivateAttribution(Type type, Location loc);
/// Returns the name of the attribute containing the number of buffers
/// located in the workgroup memory.
static StringRef getNumWorkgroupAttributionsAttrName() {
return "workgroup_attributions";
}
}];
let hasCanonicalizer = 1;

View File

@@ -19,6 +19,7 @@
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LowerFunctionDiscardablesToLLVM.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
@@ -77,65 +78,6 @@ static void filterFuncAttributes(FunctionOpInterface func,
}
}
/// Add custom lowered funcOp to llvm.func attributes here.
struct LoweredFuncAttrs {
LLVM::LLVMFuncOp::Properties properties;
NamedAttrList discardableAttrs;
};
/// Lower discardable function attributes on `func.func` to attributes expected
/// by `llvm.func`.
static FailureOr<LoweredFuncAttrs>
lowerFuncAttributes(FunctionOpInterface func) {
MLIRContext *ctx = func->getContext();
LoweredFuncAttrs lowered;
llvm::SmallDenseSet<StringRef> odsAttrNames(
LLVM::LLVMFuncOp::getAttributeNames().begin(),
LLVM::LLVMFuncOp::getAttributeNames().end());
NamedAttrList inherentAttrs;
for (const NamedAttribute &attr : func->getDiscardableAttrs()) {
StringRef attrName = attr.getName().strref();
if (odsAttrNames.contains(attrName)) {
LDBG() << "LLVM specific attributes: " << attrName
<< "should use llvm.* prefix, discarding it";
continue;
}
StringRef inherent = attrName;
if (inherent.consume_front("llvm.") && odsAttrNames.contains(inherent))
inherentAttrs.set(inherent, attr.getValue()); // collect inherent attrs
else
lowered.discardableAttrs.push_back(attr);
}
// Convert collected inherent attrs into typed properties.
if (!inherentAttrs.empty()) {
DictionaryAttr dict = inherentAttrs.getDictionary(ctx);
auto emitError = [&] {
return func.emitOpError("invalid llvm.func property");
};
if (failed(LLVM::LLVMFuncOp::setPropertiesFromAttr(lowered.properties, dict,
emitError))) {
return failure();
}
}
return lowered;
}
static void buildLLVMFuncProperties(PatternRewriter &rewriter,
FunctionOpInterface srcFunc,
Type llvmFuncType,
LLVM::LLVMFuncOp::Properties &props) {
MLIRContext *ctx = rewriter.getContext();
props.sym_name = rewriter.getStringAttr(srcFunc.getName());
props.function_type = TypeAttr::get(llvmFuncType);
props.setCConv(LLVM::CConvAttr::get(ctx, LLVM::CConv::C));
}
/// Propagate argument/results attributes.
static void propagateArgResAttrs(OpBuilder &builder, bool resultStructType,
FunctionOpInterface funcOp,
@@ -369,14 +311,15 @@ static FailureOr<LLVM::LLVMFunctionType> convertFuncSignature(
static LLVM::LLVMFuncOp createLLVMFuncOp(FunctionOpInterface funcOp,
ConversionPatternRewriter &rewriter,
LLVM::LLVMFunctionType llvmType,
LoweredFuncAttrs &loweredAttrs,
LoweredLLVMFuncAttrs &loweredAttrs,
SymbolTableCollection *symbolTables) {
Operation *symbolTableOp = funcOp->getParentWithTrait<OpTrait::SymbolTable>();
if (symbolTables && symbolTableOp) {
SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp);
symbolTable.remove(funcOp);
}
buildLLVMFuncProperties(rewriter, funcOp, llvmType, loweredAttrs.properties);
loweredAttrs.properties.setCConv(
LLVM::CConvAttr::get(rewriter.getContext(), LLVM::CConv::C));
auto newFuncOp = LLVM::LLVMFuncOp::create(rewriter, funcOp.getLoc(),
loweredAttrs.properties,
loweredAttrs.discardableAttrs);
@@ -491,27 +434,25 @@ static void wrapWithCInterface(FunctionOpInterface funcOp,
newFuncOp);
}
// Conversion steps
// 1. Validate function type
// 2. Convert signature
// 3. Validate C wrapper varargs constraint
// 4. Lower function attrs
// 5. Create llvm.func
// 6. Propagate arg/result attrs
// 7. Inline body + signature conversion
// 8. Restore byval/byref pointee types
// 9. C-wrapper handling
/// Conversion steps
/// - Validate function type
/// - Convert signature
/// - Validate C wrapper varargs constraint
/// - Lower function attrs
/// - Create llvm.func
/// - Propagate arg/result attrs
/// - Inline body + signature conversion
/// - Restore byval/byref pointee types
/// - C-wrapper handling
FailureOr<LLVM::LLVMFuncOp> mlir::convertFuncOpToLLVMFuncOp(
FunctionOpInterface funcOp, ConversionPatternRewriter &rewriter,
const LLVMTypeConverter &converter, SymbolTableCollection *symbolTables) {
// 1. Validate function type
// Check the funcOp has `FunctionType`.
auto funcTy = dyn_cast<FunctionType>(funcOp.getFunctionType());
if (!funcTy)
return rewriter.notifyMatchFailure(
funcOp, "Only support FunctionOpInterface with FunctionType");
// 2. Convert signature
bool useBarePtrCallConv = shouldUseBarePtrCallConv(funcOp, &converter);
// Convert the original function arguments. They are converted using the
// LLVMTypeConverter provided to this legalization pattern.
@@ -524,28 +465,29 @@ FailureOr<LLVM::LLVMFuncOp> mlir::convertFuncOpToLLVMFuncOp(
if (failed(llvmType))
return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
// 3. Validate C wrapper varargs constraint
// Validate C wrapper varargs constraint
bool emitCWrapper = funcOp->hasAttrOfType<UnitAttr>(
LLVM::LLVMDialect::getEmitCWrapperAttrName());
if (!useBarePtrCallConv && emitCWrapper && llvmType->isVarArg())
return funcOp.emitError("C interface for variadic functions is not "
"supported yet.");
// 4. Lower function attrs
FailureOr<LoweredFuncAttrs> loweredAttrs = lowerFuncAttributes(funcOp);
// Lower function attrs
FailureOr<LoweredLLVMFuncAttrs> loweredAttrs =
lowerDiscardableAttrsForLLVMFunc(funcOp, *llvmType);
if (failed(loweredAttrs))
return rewriter.notifyMatchFailure(funcOp,
"failed to lower func attributes");
// 5. Create llvm.func
// Create llvm.func
auto newFuncOp = createLLVMFuncOp(funcOp, rewriter, *llvmType, *loweredAttrs,
symbolTables);
// 6. Propagate arg/result attrs
// Propagate arg/result attrs
propagateFunctionArgResAttrs(funcOp, rewriter, converter, result, *llvmType,
newFuncOp);
// 7. Inline body + signature conversion
// Inline body + signature conversion
rewriter.inlineRegionBefore(funcOp.getFunctionBody(), newFuncOp.getBody(),
newFuncOp.end());
// Convert just the entry block. The remaining unstructured control flow is
@@ -554,14 +496,14 @@ FailureOr<LLVM::LLVMFuncOp> mlir::convertFuncOpToLLVMFuncOp(
rewriter.applySignatureConversion(&newFuncOp.getBody().front(), result,
&converter);
// 8. Restore byval/byref pointee types
// Restore byval/byref pointee types
// Fix the type mismatch between the materialized `llvm.ptr` and the expected
// pointee type in the function body when converting `llvm.byval`/`llvm.byref`
// function arguments.
restoreByValRefArgumentType(rewriter, converter, byValRefNonPtrAttrs,
newFuncOp);
// 9. C-wrapper handling
// C-wrapper handling
if (!useBarePtrCallConv && emitCWrapper)
wrapWithCInterface(funcOp, rewriter, converter, newFuncOp);

View File

@@ -14,10 +14,14 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/SymbolTable.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/FormatVariadic.h"
#define DEBUG_TYPE "gpu-lowering"
using namespace mlir;
LLVM::LLVMFuncOp mlir::getOrDefineFunction(Operation *moduleOp, Location loc,
@@ -74,6 +78,55 @@ LLVM::GlobalOp mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
name, attr, alignment, addrSpace);
}
FailureOr<LoweredLLVMFuncAttrs> GPUFuncOpLowering::buildLoweredGPULLVMFuncAttrs(
gpu::GPUFuncOp gpuFuncOp, Type llvmFuncType, OpBuilder &rewriter) const {
FailureOr<LoweredLLVMFuncAttrs> loweredAttrs =
lowerDiscardableAttrsForLLVMFunc(gpuFuncOp, llvmFuncType);
if (failed(loweredAttrs))
return failure();
MLIRContext *ctx = rewriter.getContext();
LLVM::LLVMFuncOp::Properties &props = loweredAttrs->properties;
props.sym_name = rewriter.getStringAttr(gpuFuncOp.getName());
props.function_type = TypeAttr::get(llvmFuncType);
const bool isKernelFunc = gpuFuncOp.isKernel();
props.setCConv(LLVM::CConvAttr::get(ctx, isKernelFunc
? kernelCallingConvention
: nonKernelCallingConvention));
NamedAttrList &discardable = loweredAttrs->discardableAttrs;
auto *gpuDialect = cast<gpu::GPUDialect>(gpuFuncOp->getDialect());
auto appendIfNameAndValue = [&](StringAttr name, Attribute value) {
if (name && value)
discardable.append(name, value);
};
DenseI32ArrayAttr knownBlockSize = gpuFuncOp.getKnownBlockSizeAttr();
DenseI32ArrayAttr knownGridSize = gpuFuncOp.getKnownGridSizeAttr();
DenseI32ArrayAttr knownClusterSize = gpuFuncOp.getKnownClusterSizeAttr();
appendIfNameAndValue(gpuDialect->getKnownBlockSizeAttrHelper().getName(),
knownBlockSize);
appendIfNameAndValue(gpuDialect->getKnownGridSizeAttrHelper().getName(),
knownGridSize);
appendIfNameAndValue(gpuDialect->getKnownClusterSizeAttrHelper().getName(),
knownClusterSize);
if (isKernelFunc) {
discardable.append(gpuDialect->getKernelFuncAttrName(),
rewriter.getUnitAttr());
// Add a dialect specific kernel attribute in addition to GPU kernel
// attribute. The former is necessary for further translation while the
// latter is expected by gpu.launch_func.
appendIfNameAndValue(kernelAttributeName, rewriter.getUnitAttr());
appendIfNameAndValue(kernelBlockSizeAttributeName, knownBlockSize);
appendIfNameAndValue(kernelClusterSizeAttributeName, knownClusterSize);
}
return loweredAttrs;
}
LogicalResult
GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@@ -85,7 +138,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
// workgroup attributions.
ArrayRef<BlockArgument> workgroupAttributions =
gpuFuncOp.getWorkgroupAttributions();
gpuFuncOp.getWorkgroupAttributionBBArgs();
size_t numAttributions = workgroupAttributions.size();
// Insert all arguments at the end.
@@ -136,7 +189,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
} else {
workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
for (auto [idx, attribution] :
llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
llvm::enumerate(gpuFuncOp.getWorkgroupAttributionBBArgs())) {
auto type = dyn_cast<MemRefType>(attribution.getType());
assert(type && type.hasStaticShape() && "unexpected type in attribution");
@@ -174,67 +227,17 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
});
}
// Create the new function operation. Only copy those attributes that are
// not specific to function modeling.
SmallVector<NamedAttribute, 4> attributes;
ArrayAttr argAttrs;
for (const auto &attr : gpuFuncOp->getAttrs()) {
if (attr.getName() == SymbolTable::getSymbolAttrName() ||
attr.getName() == gpuFuncOp.getFunctionTypeAttrName() ||
attr.getName() ==
gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName() ||
attr.getName() == gpuFuncOp.getWorkgroupAttribAttrsAttrName() ||
attr.getName() == gpuFuncOp.getPrivateAttribAttrsAttrName() ||
attr.getName() == gpuFuncOp.getKnownBlockSizeAttrName() ||
attr.getName() == gpuFuncOp.getKnownGridSizeAttrName() ||
attr.getName() == gpuFuncOp.getKnownClusterSizeAttrName())
continue;
if (attr.getName() == gpuFuncOp.getArgAttrsAttrName()) {
argAttrs = gpuFuncOp.getArgAttrsAttr();
continue;
}
attributes.push_back(attr);
}
ArrayAttr argAttrs = gpuFuncOp.getArgAttrsAttr();
DenseI32ArrayAttr knownBlockSize = gpuFuncOp.getKnownBlockSizeAttr();
DenseI32ArrayAttr knownGridSize = gpuFuncOp.getKnownGridSizeAttr();
DenseI32ArrayAttr knownClusterSize = gpuFuncOp.getKnownClusterSizeAttr();
// Ensure we don't lose information if the function is lowered before its
// surrounding context.
auto *gpuDialect = cast<gpu::GPUDialect>(gpuFuncOp->getDialect());
if (knownBlockSize)
attributes.emplace_back(gpuDialect->getKnownBlockSizeAttrHelper().getName(),
knownBlockSize);
if (knownGridSize)
attributes.emplace_back(gpuDialect->getKnownGridSizeAttrHelper().getName(),
knownGridSize);
if (knownClusterSize)
attributes.emplace_back(
gpuDialect->getKnownClusterSizeAttrHelper().getName(),
knownClusterSize);
FailureOr<LoweredLLVMFuncAttrs> loweredAttrs =
buildLoweredGPULLVMFuncAttrs(gpuFuncOp, funcType, rewriter);
if (failed(loweredAttrs))
return rewriter.notifyMatchFailure(gpuFuncOp,
"failed to lower func attributes");
// Add a dialect specific kernel attribute in addition to GPU kernel
// attribute. The former is necessary for further translation while the
// latter is expected by gpu.launch_func.
if (gpuFuncOp.isKernel()) {
if (kernelAttributeName)
attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr());
// Set the dialect-specific block size attribute if there is one.
if (kernelBlockSizeAttributeName && knownBlockSize) {
attributes.emplace_back(kernelBlockSizeAttributeName, knownBlockSize);
}
// Set the dialect-specific cluster size attribute if there is one.
if (kernelClusterSizeAttributeName && knownClusterSize) {
attributes.emplace_back(kernelClusterSizeAttributeName, knownClusterSize);
}
}
LLVM::CConv callingConvention = gpuFuncOp.isKernel()
? kernelCallingConvention
: nonKernelCallingConvention;
auto llvmFuncOp = LLVM::LLVMFuncOp::create(
rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
LLVM::Linkage::External, /*dsoLocal=*/false, callingConvention,
/*comdat=*/nullptr, attributes);
auto llvmFuncOp = LLVM::LLVMFuncOp::create(rewriter, gpuFuncOp.getLoc(),
loweredAttrs->properties,
loweredAttrs->discardableAttrs);
{
// Insert operations that correspond to converted workgroup and private
@@ -260,8 +263,9 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
ArrayRef<BlockArgument> attributionArguments =
gpuFuncOp.getArguments().slice(numProperArguments - numAttributions,
numAttributions);
for (auto [idx, vals] : llvm::enumerate(llvm::zip_equal(
gpuFuncOp.getWorkgroupAttributions(), attributionArguments))) {
for (auto [idx, vals] : llvm::enumerate(
llvm::zip_equal(gpuFuncOp.getWorkgroupAttributionBBArgs(),
attributionArguments))) {
auto [attribution, arg] = vals;
auto type = cast<MemRefType>(attribution.getType());
@@ -287,7 +291,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
// existing memref infrastructure. This may use more registers than
// otherwise necessary given that memref sizes are fixed, but we can try
// and canonicalize that away later.
Value attribution = gpuFuncOp.getWorkgroupAttributions()[idx];
Value attribution = gpuFuncOp.getWorkgroupAttributionBBArgs()[idx];
auto type = cast<MemRefType>(attribution.getType());
Value descr = MemRefDescriptor::fromStaticShape(
rewriter, loc, *getTypeConverter(), type, memory);

View File

@@ -8,6 +8,7 @@
#ifndef MLIR_CONVERSION_GPUCOMMON_GPUOPSLOWERING_H_
#define MLIR_CONVERSION_GPUCOMMON_GPUOPSLOWERING_H_
#include "mlir/Conversion/LLVMCommon/LowerFunctionDiscardablesToLLVM.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
@@ -106,6 +107,12 @@ struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
/// Lower discardable attrs like `func` lowering, then set `llvm.func`
/// properties and append GPU / target-specific discardable metadata.
FailureOr<LoweredLLVMFuncAttrs>
buildLoweredGPULLVMFuncAttrs(gpu::GPUFuncOp gpuFuncOp, Type llvmFuncType,
OpBuilder &rewriter) const;
private:
/// The address space to use for `alloca`s in private memory.
unsigned allocaAddrSpace;

View File

@@ -1,5 +1,6 @@
add_mlir_conversion_library(MLIRLLVMCommonConversion
ConversionTarget.cpp
LowerFunctionDiscardablesToLLVM.cpp
LoweringOptions.cpp
MemRefBuilder.cpp
Pattern.cpp

View File

@@ -0,0 +1,60 @@
//===- LowerFunctionDiscardablesToLLVM.cpp - Func discardables to llvm ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/LLVMCommon/LowerFunctionDiscardablesToLLVM.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/DebugLog.h"
using namespace mlir;
#define DEBUG_TYPE "lower-function-discardables-to-llvm"
/// Partitions `funcOp`'s discardable attributes into typed `llvm.func`
/// properties (for `llvm.`-prefixed ODS names) and pass-through discardables.
/// `sym_name` / `function_type` are seeded from `funcOp` and `llvmFuncType`.
/// Fails if the collected property values are invalid for `llvm.func`.
FailureOr<LoweredLLVMFuncAttrs>
mlir::lowerDiscardableAttrsForLLVMFunc(FunctionOpInterface funcOp,
                                       Type llvmFuncType) {
  MLIRContext *ctx = funcOp->getContext();
  LoweredLLVMFuncAttrs result;
  // `llvm.func` always needs a symbol name and a function type.
  result.properties.sym_name = StringAttr::get(ctx, funcOp.getName());
  result.properties.function_type = TypeAttr::get(llvmFuncType);
  // Names of all inherent (ODS-declared) attributes of `llvm.func`.
  llvm::SmallDenseSet<StringRef> odsAttrNames(
      LLVM::LLVMFuncOp::getAttributeNames().begin(),
      LLVM::LLVMFuncOp::getAttributeNames().end());
  NamedAttrList inherentAttrs;
  for (const NamedAttribute &attr : funcOp->getDiscardableAttrs()) {
    StringRef attrName = attr.getName().strref();
    // Unprefixed ODS names would clash with inherent attributes; producers
    // must spell them with the `llvm.` prefix, so drop such attrs here.
    if (odsAttrNames.contains(attrName)) {
      // NOTE: keep a space around `attrName` so the debug line stays readable.
      LDBG() << "LLVM-specific attribute " << attrName
             << " should use llvm.* prefix, discarding it";
      continue;
    }
    StringRef inherent = attrName;
    if (inherent.consume_front("llvm.") && odsAttrNames.contains(inherent))
      inherentAttrs.set(inherent, attr.getValue()); // collect inherent attrs
    else
      result.discardableAttrs.push_back(attr);
  }
  // Convert collected inherent attrs into typed properties; this validates
  // each value against the property's expected type.
  if (!inherentAttrs.empty()) {
    DictionaryAttr dict = inherentAttrs.getDictionary(ctx);
    auto emitError = [&] {
      return funcOp.emitOpError("invalid llvm.func property");
    };
    if (failed(LLVM::LLVMFuncOp::setPropertiesFromAttr(result.properties, dict,
                                                       emitError))) {
      return failure();
    }
  }
  return result;
}

View File

@@ -42,6 +42,7 @@
#include "llvm/Support/StringSaver.h"
#include <cassert>
#include <numeric>
#include <optional>
using namespace mlir;
using namespace mlir::gpu;
@@ -262,8 +263,10 @@ bool GPUDialect::hasConstantMemoryAddressSpace(MemRefType type) {
}
bool GPUDialect::isKernel(Operation *op) {
UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
return static_cast<bool>(isKernelAttr);
if (auto gpuFunc = dyn_cast<GPUFuncOp>(op))
return gpuFunc.isKernel();
return static_cast<bool>(
op->getAttrOfType<UnitAttr>(getKernelFuncAttrName()));
}
namespace {
@@ -713,10 +716,10 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result,
FlatSymbolRefAttr module, FlatSymbolRefAttr function) {
OpBuilder::InsertionGuard g(builder);
// Add a WorkGroup attribution attribute. This attribute is required to
// identify private attributions in the list of block arguments.
result.addAttribute(getNumWorkgroupAttributionsAttrName(),
builder.getI64IntegerAttr(workgroupAttributions.size()));
if (!workgroupAttributions.empty())
result.addAttribute(
getWorkgroupAttributionsAttrName(result.name),
builder.getI64IntegerAttr(workgroupAttributions.size()));
// Add Op operands.
result.addOperands(asyncDependencies);
@@ -842,7 +845,7 @@ LogicalResult LaunchOp::verifyRegions() {
}
// Verify Attributions Address Spaces.
if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
if (failed(verifyAttributions(getOperation(), getWorkgroupAttributionBBArgs(),
GPUDialect::getWorkgroupAddressSpace())) ||
failed(verifyAttributions(getOperation(), getPrivateAttributions(),
GPUDialect::getPrivateAddressSpace())))
@@ -921,7 +924,7 @@ void LaunchOp::print(OpAsmPrinter &p) {
p << ')';
}
printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributionBBArgs());
printAttributions(p, getPrivateKeyword(), getPrivateAttributions());
p << ' ';
@@ -929,7 +932,7 @@ void LaunchOp::print(OpAsmPrinter &p) {
p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
LaunchOp::getOperandSegmentSizeAttr(),
getNumWorkgroupAttributionsAttrName(),
getWorkgroupAttributionsAttrName(),
moduleAttrName, functionAttrName});
}
@@ -1072,12 +1075,9 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
return failure();
}
// Create the region arguments, it has kNumConfigRegionAttributes arguments
// that correspond to block/thread identifiers and grid/block sizes, all
// having `index` type, a variadic number of WorkGroup Attributions and
// a variadic number of Private Attributions. The number of WorkGroup
// Attributions is stored in the attr with name:
// LaunchOp::getNumWorkgroupAttributionsAttrName().
// Create the region arguments: fixed launch-config args (`index`), then
// workgroup / private attribution args. The workgroup count is stored in the
// inherent `workgroup_attributions` attribute when non-zero.
Type index = parser.getBuilder().getIndexType();
SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
LaunchOp::kNumConfigRegionAttributes + 6, index);
@@ -1101,8 +1101,9 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
unsigned numWorkgroupAttrs = regionArguments.size() -
LaunchOp::kNumConfigRegionAttributes -
(hasCluster ? 6 : 0);
result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
builder.getI64IntegerAttr(numWorkgroupAttrs));
if (numWorkgroupAttrs != 0)
result.addAttribute(LaunchOp::getWorkgroupAttributionsAttrName(result.name),
builder.getI64IntegerAttr(numWorkgroupAttrs));
// Parse private memory attributions.
if (failed(parseAttributions(parser, LaunchOp::getPrivateKeyword(),
@@ -1176,12 +1177,10 @@ void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
auto attrName = getNumWorkgroupAttributionsAttrName();
auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
(*this)->setAttr(attrName,
IntegerAttr::get(attr.getType(), attr.getValue() + 1));
int64_t cur = getWorkgroupAttributions().value_or(0);
setWorkgroupAttributions(std::optional<int64_t>(cur + 1));
return getBody().insertArgument(
LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
getNumConfigRegionAttributes() + static_cast<unsigned>(cur), type, loc);
}
/// Adds a new block argument that corresponds to buffers located in
@@ -1391,8 +1390,7 @@ LaunchFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return diag;
}
if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
GPUDialect::getKernelFuncAttrName()))
if (!GPUDialect::isKernel(kernelFunc))
return launchOp.emitOpError("kernel function is missing the '")
<< GPUDialect::getKernelFuncAttrName() << "' attribute";
@@ -1576,12 +1574,10 @@ void BarrierOp::build(OpBuilder &builder, OperationState &odsState,
/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
auto attrName = getNumWorkgroupAttributionsAttrName();
auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
(*this)->setAttr(attrName,
IntegerAttr::get(attr.getType(), attr.getValue() + 1));
int64_t cur = getWorkgroupAttributions().value_or(0);
setWorkgroupAttributions(std::optional<int64_t>(cur + 1));
return getBody().insertArgument(
getFunctionType().getNumInputs() + attr.getInt(), type, loc);
getFunctionType().getNumInputs() + static_cast<unsigned>(cur), type, loc);
}
/// Adds a new block argument that corresponds to buffers located in
@@ -1603,7 +1599,7 @@ void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
builder.getStringAttr(name));
result.addAttribute(getFunctionTypeAttrName(result.name),
TypeAttr::get(type));
result.addAttribute(getNumWorkgroupAttributionsAttrName(),
result.addAttribute(getWorkgroupAttributionsAttrName(result.name),
builder.getI64IntegerAttr(workgroupAttributions.size()));
result.addAttributes(attrs);
Region *body = result.addRegion();
@@ -1712,8 +1708,10 @@ ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
// Store the number of operands we just parsed as the number of workgroup
// memory attributions.
unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
builder.getI64IntegerAttr(numWorkgroupAttrs));
if (numWorkgroupAttrs != 0)
result.addAttribute(
GPUFuncOp::getWorkgroupAttributionsAttrName(result.name),
builder.getI64IntegerAttr(numWorkgroupAttrs));
if (workgroupAttributionAttrs)
result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
workgroupAttributionAttrs);
@@ -1729,7 +1727,7 @@ ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
// Parse the kernel attribute if present.
if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
result.addAttribute(GPUDialect::getKernelFuncAttrName(),
result.addAttribute(GPUFuncOp::getKernelAttrName(result.name),
builder.getUnitAttr());
// Parse attributes.
@@ -1751,7 +1749,7 @@ void GPUFuncOp::print(OpAsmPrinter &p) {
/*isVariadic=*/false,
type.getResults());
printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions(),
printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributionBBArgs(),
getWorkgroupAttribAttrs().value_or(nullptr));
printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),
getPrivateAttribAttrs().value_or(nullptr));
@@ -1760,7 +1758,7 @@ void GPUFuncOp::print(OpAsmPrinter &p) {
function_interface_impl::printFunctionAttributes(
p, *this,
{getNumWorkgroupAttributionsAttrName(),
{getWorkgroupAttributionsAttrName(), getKernelAttrName(),
GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
getArgAttrsAttrName(), getResAttrsAttrName(),
getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
@@ -1914,7 +1912,7 @@ LogicalResult GPUFuncOp::verifyBody() {
<< blockArgType;
}
if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
if (failed(verifyAttributions(getOperation(), getWorkgroupAttributionBBArgs(),
GPUDialect::getWorkgroupAddressSpace())) ||
failed(verifyAttributions(getOperation(), getPrivateAttributions(),
GPUDialect::getPrivateAddressSpace())))

View File

@@ -197,10 +197,9 @@ static gpu::GPUFuncOp outlineKernelFuncImpl(gpu::LaunchOp launchOp,
FunctionType::get(launchOp.getContext(), kernelOperandTypes, {});
auto outlinedFunc = gpu::GPUFuncOp::create(
builder, loc, kernelFnName, type,
TypeRange(ValueRange(launchOp.getWorkgroupAttributions())),
TypeRange(ValueRange(launchOp.getWorkgroupAttributionBBArgs())),
TypeRange(ValueRange(launchOp.getPrivateAttributions())));
outlinedFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
builder.getUnitAttr());
outlinedFunc.setKernel(true);
// If we can infer bounds on the grid and/or block sizes from the arguments
// to the launch op, propagate them to the generated kernel. This is safe
@@ -227,8 +226,8 @@ static gpu::GPUFuncOp outlineKernelFuncImpl(gpu::LaunchOp launchOp,
// Map memory attributions from the LaunOp op to the GPUFuncOp attributions.
for (const auto &[launchArg, funcArg] :
llvm::zip(launchOp.getWorkgroupAttributions(),
outlinedFunc.getWorkgroupAttributions()))
llvm::zip(launchOp.getWorkgroupAttributionBBArgs(),
outlinedFunc.getWorkgroupAttributionBBArgs()))
map.map(launchArg, funcArg);
for (const auto &[launchArg, funcArg] :
llvm::zip(launchOp.getPrivateAttributions(),

View File

@@ -84,8 +84,7 @@ static gpu::GPUFuncOp genGPUFunc(OpBuilder &builder, gpu::GPUModuleOp gpuModule,
FunctionType type = FunctionType::get(gpuModule->getContext(), argsTp, {});
auto gpuFunc =
gpu::GPUFuncOp::create(builder, gpuModule->getLoc(), kernelName, type);
gpuFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
builder.getUnitAttr());
gpuFunc.setKernel(true);
return gpuFunc;
}

View File

@@ -173,7 +173,7 @@ struct MoveFuncBodyToWarpOp : public OpRewritePattern<gpu::GPUFuncOp> {
gpuFuncOp, "expected gpu.func terminator to be gpu.return");
// Create a new function with the same signature and same attributes.
SmallVector<Type> workgroupAttributionsTypes =
llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributions(),
llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributionBBArgs(),
[](BlockArgument arg) { return arg.getType(); });
SmallVector<Type> privateAttributionsTypes =
llvm::map_to_vector(gpuFuncOp.getPrivateAttributions(),