[mlir][GPU] Refactor GPUOps lowering (#188905)
This change promotes `gpu.func` / `gpu.launch` metadata that was previously carried as discardable attributes into proper inherent ODS fields (`kernel`, `workgroup_attributions`), renames the block-argument helpers to avoid clashing with generated getters, and routes `func.func` and `gpu.func` lowering through a shared helper that maps discardable `llvm.*` attributes into `llvm.func` properties. Downstream producers (Flang CUDA device func transform, kernel outlining, sparse GPU codegen, XeGPU) set kernels via `setKernel(true)` instead of manually attaching `gpu.kernel`. Fixes #185174 Assisted-by: CLion code completion, GPT 5.3 - Codex
This commit is contained in:
@@ -68,8 +68,7 @@ class CUFDeviceFuncTransform
|
||||
gpu::GPUFuncOp::create(builder, loc, funcOp.getName(), type,
|
||||
mlir::TypeRange{}, mlir::TypeRange{});
|
||||
if (isGlobal)
|
||||
deviceFuncOp->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
|
||||
builder.getUnitAttr());
|
||||
deviceFuncOp.setKernel(true);
|
||||
|
||||
mlir::Region &deviceFuncBody = deviceFuncOp.getBody();
|
||||
mlir::Block &entryBlock = deviceFuncBody.front();
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
//===- LowerFunctionDiscardablesToLLVM.h - Func discardables to llvm - C++ -*-//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Shared helpers for lowering discardable attributes on any FunctionOpInterface
|
||||
// (e.g. func.func, gpu.func) into llvm.func properties and discardables.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_CONVERSION_LLVMCOMMON_LOWERFUNCTIONDISCARDABLESTOLLVM_H
|
||||
#define MLIR_CONVERSION_LLVMCOMMON_LOWERFUNCTIONDISCARDABLESTOLLVM_H
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/Interfaces/FunctionInterfaces.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
/// Result of lowering discardable attributes from a `FunctionOpInterface` to
|
||||
/// what `llvm.func` expects: typed inherent properties plus remaining
|
||||
/// discardable attributes.
|
||||
struct LoweredLLVMFuncAttrs {
|
||||
/// Typed `llvm.func` properties; this is the struct handed to
/// `LLVM::LLVMFuncOp::setPropertiesFromAttr` by the lowering helper.
LLVM::LLVMFuncOp::Properties properties;
|
||||
/// Discardable attributes that were not turned into properties and are
/// meant to be attached to the created `llvm.func` as-is.
NamedAttrList discardableAttrs;
|
||||
};
|
||||
|
||||
/// Partitions the discardable attributes of `funcOp` into the pieces needed
/// to create an `llvm.func`.
|
||||
/// The returned `properties` carry `sym_name` (taken from the name of
/// `funcOp`) and `function_type` (set to `llvmFuncType`), plus every
/// discardable attribute named `llvm.<name>` where `<name>` is an ODS
/// attribute name of `llvm.func`, stored under `<name>` as a typed property.
|
||||
/// All other discardable attributes are forwarded unchanged in
/// `discardableAttrs`, except attributes whose name is itself an ODS
/// attribute name of `llvm.func` (no `llvm.` prefix): those are dropped.
|
||||
/// Returns failure if `LLVM::LLVMFuncOp::setPropertiesFromAttr` rejects the
/// collected attributes; its error callback reports through
/// `funcOp.emitOpError`.
|
||||
FailureOr<LoweredLLVMFuncAttrs>
|
||||
lowerDiscardableAttrsForLLVMFunc(FunctionOpInterface funcOp, Type llvmFuncType);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_CONVERSION_LLVMCOMMON_LOWERFUNCTIONDISCARDABLESTOLLVM_H
|
||||
@@ -426,7 +426,9 @@ def GPU_GPUFuncOp : GPU_Op<"func", [
|
||||
OptionalAttr<DictArrayAttr>:$private_attrib_attrs,
|
||||
GPU_OptionalDimSizeHintAttr:$known_block_size,
|
||||
GPU_OptionalDimSizeHintAttr:$known_grid_size,
|
||||
GPU_OptionalDimSizeHintAttr:$known_cluster_size);
|
||||
GPU_OptionalDimSizeHintAttr:$known_cluster_size,
|
||||
OptionalAttr<ConfinedAttr<I64Attr, [IntNonNegative]>>:$workgroup_attributions,
|
||||
UnitAttr:$kernel);
|
||||
let regions = (region AnyRegion:$body);
|
||||
|
||||
let skipDefaultBuilders = 1;
|
||||
@@ -440,22 +442,24 @@ def GPU_GPUFuncOp : GPU_Op<"func", [
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Returns `true` if the GPU function defined by this Op is a kernel, i.e.
|
||||
/// it is intended to be launched from host.
|
||||
/// it is intended to be launched from host. Also accepts legacy discardable
|
||||
/// `gpu.kernel` for compatibility with older generic IR.
|
||||
bool isKernel() {
|
||||
if (getKernel())
|
||||
return true;
|
||||
return (*this)->getAttrOfType<UnitAttr>(
|
||||
GPUDialect::getKernelFuncAttrName()) != nullptr;
|
||||
}
|
||||
|
||||
/// Returns the number of buffers located in the workgroup memory.
|
||||
unsigned getNumWorkgroupAttributions() {
|
||||
auto attr = (*this)->getAttrOfType<IntegerAttr>(
|
||||
getNumWorkgroupAttributionsAttrName());
|
||||
if (!attr)
|
||||
std::optional<int64_t> v = getWorkgroupAttributions();
|
||||
if (!v)
|
||||
return 0;
|
||||
int64_t value = attr.getInt();
|
||||
assert(value >= 0 && value < std::numeric_limits<uint32_t>::max() &&
|
||||
int64_t value = *v;
|
||||
assert(value < std::numeric_limits<uint32_t>::max() &&
|
||||
"invalid workgroup attribution count");
|
||||
return value;
|
||||
return static_cast<unsigned>(value);
|
||||
}
|
||||
|
||||
/// Return the index of the first workgroup attribution in the block argument
|
||||
@@ -466,7 +470,7 @@ def GPU_GPUFuncOp : GPU_Op<"func", [
|
||||
|
||||
/// Returns a list of block arguments that correspond to buffers located in
|
||||
/// the workgroup memory
|
||||
ArrayRef<BlockArgument> getWorkgroupAttributions() {
|
||||
ArrayRef<BlockArgument> getWorkgroupAttributionBBArgs() {
|
||||
auto begin =
|
||||
std::next(getBody().args_begin(), getFirstWorkgroupAttributionIndex());
|
||||
auto end = std::next(begin, getNumWorkgroupAttributions());
|
||||
@@ -548,12 +552,6 @@ def GPU_GPUFuncOp : GPU_Op<"func", [
|
||||
return setPrivateAttributionAttr(index, StringAttr::get((*this)->getContext(), name), value);
|
||||
}
|
||||
|
||||
/// Returns the name of the attribute containing the number of buffers
|
||||
/// located in the workgroup memory.
|
||||
static StringRef getNumWorkgroupAttributionsAttrName() {
|
||||
return "workgroup_attributions";
|
||||
}
|
||||
|
||||
/// Returns the argument types of this function.
|
||||
ArrayRef<Type> getArgumentTypes() { return getFunctionType().getInputs(); }
|
||||
|
||||
@@ -806,7 +804,8 @@ def GPU_LaunchOp : GPU_Op<"launch", [
|
||||
Optional<Index>:$clusterSizeZ,
|
||||
Optional<I32>:$dynamicSharedMemorySize,
|
||||
OptionalAttr<FlatSymbolRefAttr>:$module,
|
||||
OptionalAttr<FlatSymbolRefAttr>:$function)>,
|
||||
OptionalAttr<FlatSymbolRefAttr>:$function,
|
||||
OptionalAttr<ConfinedAttr<I64Attr, [IntNonNegative]>>:$workgroup_attributions)>,
|
||||
Results<(outs Optional<GPU_AsyncToken>:$asyncToken)> {
|
||||
let summary = "GPU kernel launch operation";
|
||||
|
||||
@@ -1004,19 +1003,17 @@ def GPU_LaunchOp : GPU_Op<"launch", [
|
||||
|
||||
/// Returns the number of buffers located in the workgroup memory.
|
||||
unsigned getNumWorkgroupAttributions() {
|
||||
auto attr = (*this)->getAttrOfType<IntegerAttr>(
|
||||
getNumWorkgroupAttributionsAttrName());
|
||||
if (!attr)
|
||||
std::optional<int64_t> v = getWorkgroupAttributions();
|
||||
if (!v)
|
||||
return 0;
|
||||
int64_t value = attr.getInt();
|
||||
assert(value >= 0 && value < std::numeric_limits<uint32_t>::max() &&
|
||||
int64_t value = *v;
|
||||
assert(value < std::numeric_limits<uint32_t>::max() &&
|
||||
"invalid workgroup attribution count");
|
||||
return value;
|
||||
return static_cast<unsigned>(value);
|
||||
}
|
||||
|
||||
/// Returns a list of block arguments that correspond to buffers located in
|
||||
/// the workgroup memory
|
||||
ArrayRef<BlockArgument> getWorkgroupAttributions() {
|
||||
/// Block arguments for workgroup memory buffers
|
||||
ArrayRef<BlockArgument> getWorkgroupAttributionBBArgs() {
|
||||
auto begin =
|
||||
std::next(getBody().args_begin(), getNumConfigRegionAttributes());
|
||||
auto end = std::next(begin, getNumWorkgroupAttributions());
|
||||
@@ -1047,12 +1044,6 @@ def GPU_LaunchOp : GPU_Op<"launch", [
|
||||
/// Adds a new block argument that corresponds to buffers located in
|
||||
/// private memory.
|
||||
BlockArgument addPrivateAttribution(Type type, Location loc);
|
||||
|
||||
/// Returns the name of the attribute containing the number of buffers
|
||||
/// located in the workgroup memory.
|
||||
static StringRef getNumWorkgroupAttributionsAttrName() {
|
||||
return "workgroup_attributions";
|
||||
}
|
||||
}];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
|
||||
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
|
||||
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
||||
#include "mlir/Conversion/LLVMCommon/LowerFunctionDiscardablesToLLVM.h"
|
||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
|
||||
@@ -77,65 +78,6 @@ static void filterFuncAttributes(FunctionOpInterface func,
|
||||
}
|
||||
}
|
||||
|
||||
/// Attributes lowered from a function op for the creation of `llvm.func`.
/// Add further custom lowered funcOp to llvm.func attributes here.
|
||||
struct LoweredFuncAttrs {
|
||||
/// Typed `llvm.func` properties filled from `llvm.*` discardable attributes.
LLVM::LLVMFuncOp::Properties properties;
|
||||
/// Discardable attributes that are forwarded without conversion.
NamedAttrList discardableAttrs;
|
||||
};
|
||||
|
||||
/// Lower discardable function attributes on `func.func` to attributes expected
|
||||
/// by `llvm.func`.
|
||||
static FailureOr<LoweredFuncAttrs>
|
||||
lowerFuncAttributes(FunctionOpInterface func) {
|
||||
MLIRContext *ctx = func->getContext();
|
||||
LoweredFuncAttrs lowered;
|
||||
|
||||
llvm::SmallDenseSet<StringRef> odsAttrNames(
|
||||
LLVM::LLVMFuncOp::getAttributeNames().begin(),
|
||||
LLVM::LLVMFuncOp::getAttributeNames().end());
|
||||
|
||||
NamedAttrList inherentAttrs;
|
||||
|
||||
for (const NamedAttribute &attr : func->getDiscardableAttrs()) {
|
||||
StringRef attrName = attr.getName().strref();
|
||||
|
||||
if (odsAttrNames.contains(attrName)) {
|
||||
LDBG() << "LLVM specific attributes: " << attrName
|
||||
<< "should use llvm.* prefix, discarding it";
|
||||
continue;
|
||||
}
|
||||
|
||||
StringRef inherent = attrName;
|
||||
if (inherent.consume_front("llvm.") && odsAttrNames.contains(inherent))
|
||||
inherentAttrs.set(inherent, attr.getValue()); // collect inherent attrs
|
||||
else
|
||||
lowered.discardableAttrs.push_back(attr);
|
||||
}
|
||||
|
||||
// Convert collected inherent attrs into typed properties.
|
||||
if (!inherentAttrs.empty()) {
|
||||
DictionaryAttr dict = inherentAttrs.getDictionary(ctx);
|
||||
auto emitError = [&] {
|
||||
return func.emitOpError("invalid llvm.func property");
|
||||
};
|
||||
if (failed(LLVM::LLVMFuncOp::setPropertiesFromAttr(lowered.properties, dict,
|
||||
emitError))) {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
return lowered;
|
||||
}
|
||||
|
||||
/// Fills the basic `llvm.func` properties in `props`: the symbol name copied
/// from `srcFunc`, the function type `llvmFuncType`, and the C calling
/// convention.
static void buildLLVMFuncProperties(PatternRewriter &rewriter,
                                    FunctionOpInterface srcFunc,
                                    Type llvmFuncType,
                                    LLVM::LLVMFuncOp::Properties &props) {
  props.function_type = TypeAttr::get(llvmFuncType);
  props.sym_name = rewriter.getStringAttr(srcFunc.getName());
  props.setCConv(
      LLVM::CConvAttr::get(rewriter.getContext(), LLVM::CConv::C));
}
|
||||
|
||||
/// Propagate argument/results attributes.
|
||||
static void propagateArgResAttrs(OpBuilder &builder, bool resultStructType,
|
||||
FunctionOpInterface funcOp,
|
||||
@@ -369,14 +311,15 @@ static FailureOr<LLVM::LLVMFunctionType> convertFuncSignature(
|
||||
static LLVM::LLVMFuncOp createLLVMFuncOp(FunctionOpInterface funcOp,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
LLVM::LLVMFunctionType llvmType,
|
||||
LoweredFuncAttrs &loweredAttrs,
|
||||
LoweredLLVMFuncAttrs &loweredAttrs,
|
||||
SymbolTableCollection *symbolTables) {
|
||||
Operation *symbolTableOp = funcOp->getParentWithTrait<OpTrait::SymbolTable>();
|
||||
if (symbolTables && symbolTableOp) {
|
||||
SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp);
|
||||
symbolTable.remove(funcOp);
|
||||
}
|
||||
buildLLVMFuncProperties(rewriter, funcOp, llvmType, loweredAttrs.properties);
|
||||
loweredAttrs.properties.setCConv(
|
||||
LLVM::CConvAttr::get(rewriter.getContext(), LLVM::CConv::C));
|
||||
auto newFuncOp = LLVM::LLVMFuncOp::create(rewriter, funcOp.getLoc(),
|
||||
loweredAttrs.properties,
|
||||
loweredAttrs.discardableAttrs);
|
||||
@@ -491,27 +434,25 @@ static void wrapWithCInterface(FunctionOpInterface funcOp,
|
||||
newFuncOp);
|
||||
}
|
||||
|
||||
// Conversion steps
|
||||
// 1. Validate function type
|
||||
// 2. Convert signature
|
||||
// 3. Validate C wrapper varargs constraint
|
||||
// 4. Lower function attrs
|
||||
// 5. Create llvm.func
|
||||
// 6. Propagate arg/result attrs
|
||||
// 7. Inline body + signature conversion
|
||||
// 8. Restore byval/byref pointee types
|
||||
// 9. C-wrapper handling
|
||||
/// Conversion steps
|
||||
/// - Validate function type
|
||||
/// - Convert signature
|
||||
/// - Validate C wrapper varargs constraint
|
||||
/// - Lower function attrs
|
||||
/// - Create llvm.func
|
||||
/// - Propagate arg/result attrs
|
||||
/// - Inline body + signature conversion
|
||||
/// - Restore byval/byref pointee types
|
||||
/// - C-wrapper handling
|
||||
FailureOr<LLVM::LLVMFuncOp> mlir::convertFuncOpToLLVMFuncOp(
|
||||
FunctionOpInterface funcOp, ConversionPatternRewriter &rewriter,
|
||||
const LLVMTypeConverter &converter, SymbolTableCollection *symbolTables) {
|
||||
// 1. Validate function type
|
||||
// Check the funcOp has `FunctionType`.
|
||||
auto funcTy = dyn_cast<FunctionType>(funcOp.getFunctionType());
|
||||
if (!funcTy)
|
||||
return rewriter.notifyMatchFailure(
|
||||
funcOp, "Only support FunctionOpInterface with FunctionType");
|
||||
|
||||
// 2. Convert signature
|
||||
bool useBarePtrCallConv = shouldUseBarePtrCallConv(funcOp, &converter);
|
||||
// Convert the original function arguments. They are converted using the
|
||||
// LLVMTypeConverter provided to this legalization pattern.
|
||||
@@ -524,28 +465,29 @@ FailureOr<LLVM::LLVMFuncOp> mlir::convertFuncOpToLLVMFuncOp(
|
||||
if (failed(llvmType))
|
||||
return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
|
||||
|
||||
// 3. Validate C wrapper varargs constraint
|
||||
// Validate C wrapper varargs constraint
|
||||
bool emitCWrapper = funcOp->hasAttrOfType<UnitAttr>(
|
||||
LLVM::LLVMDialect::getEmitCWrapperAttrName());
|
||||
if (!useBarePtrCallConv && emitCWrapper && llvmType->isVarArg())
|
||||
return funcOp.emitError("C interface for variadic functions is not "
|
||||
"supported yet.");
|
||||
|
||||
// 4. Lower function attrs
|
||||
FailureOr<LoweredFuncAttrs> loweredAttrs = lowerFuncAttributes(funcOp);
|
||||
// Lower function attrs
|
||||
FailureOr<LoweredLLVMFuncAttrs> loweredAttrs =
|
||||
lowerDiscardableAttrsForLLVMFunc(funcOp, *llvmType);
|
||||
if (failed(loweredAttrs))
|
||||
return rewriter.notifyMatchFailure(funcOp,
|
||||
"failed to lower func attributes");
|
||||
|
||||
// 5. Create llvm.func
|
||||
// Create llvm.func
|
||||
auto newFuncOp = createLLVMFuncOp(funcOp, rewriter, *llvmType, *loweredAttrs,
|
||||
symbolTables);
|
||||
|
||||
// 6. Propagate arg/result attrs
|
||||
// Propagate arg/result attrs
|
||||
propagateFunctionArgResAttrs(funcOp, rewriter, converter, result, *llvmType,
|
||||
newFuncOp);
|
||||
|
||||
// 7. Inline body + signature conversion
|
||||
// Inline body + signature conversion
|
||||
rewriter.inlineRegionBefore(funcOp.getFunctionBody(), newFuncOp.getBody(),
|
||||
newFuncOp.end());
|
||||
// Convert just the entry block. The remaining unstructured control flow is
|
||||
@@ -554,14 +496,14 @@ FailureOr<LLVM::LLVMFuncOp> mlir::convertFuncOpToLLVMFuncOp(
|
||||
rewriter.applySignatureConversion(&newFuncOp.getBody().front(), result,
|
||||
&converter);
|
||||
|
||||
// 8. Restore byval/byref pointee types
|
||||
// Restore byval/byref pointee types
|
||||
// Fix the type mismatch between the materialized `llvm.ptr` and the expected
|
||||
// pointee type in the function body when converting `llvm.byval`/`llvm.byref`
|
||||
// function arguments.
|
||||
restoreByValRefArgumentType(rewriter, converter, byValRefNonPtrAttrs,
|
||||
newFuncOp);
|
||||
|
||||
// 9. C-wrapper handling
|
||||
// C-wrapper handling
|
||||
if (!useBarePtrCallConv && emitCWrapper)
|
||||
wrapWithCInterface(funcOp, rewriter, converter, newFuncOp);
|
||||
|
||||
|
||||
@@ -14,10 +14,14 @@
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "llvm/ADT/SmallVectorExtras.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
#include "llvm/Support/DebugLog.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
|
||||
#define DEBUG_TYPE "gpu-lowering"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
LLVM::LLVMFuncOp mlir::getOrDefineFunction(Operation *moduleOp, Location loc,
|
||||
@@ -74,6 +78,55 @@ LLVM::GlobalOp mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
|
||||
name, attr, alignment, addrSpace);
|
||||
}
|
||||
|
||||
/// Builds the `llvm.func` properties and discardable attributes used to lower
/// `gpuFuncOp`.
///
/// Starts from `lowerDiscardableAttrsForLLVMFunc`, then sets the symbol name,
/// function type and calling convention (kernel vs. non-kernel), and records
/// the known block / grid / cluster sizes plus, for kernels, the GPU kernel
/// marker and the target-specific kernel attributes as discardable attributes.
/// Returns failure if the shared discardable-attribute lowering fails.
FailureOr<LoweredLLVMFuncAttrs> GPUFuncOpLowering::buildLoweredGPULLVMFuncAttrs(
    gpu::GPUFuncOp gpuFuncOp, Type llvmFuncType, OpBuilder &rewriter) const {
  FailureOr<LoweredLLVMFuncAttrs> loweredAttrs =
      lowerDiscardableAttrsForLLVMFunc(gpuFuncOp, llvmFuncType);
  if (failed(loweredAttrs))
    return failure();

  MLIRContext *ctx = rewriter.getContext();
  LLVM::LLVMFuncOp::Properties &props = loweredAttrs->properties;
  props.sym_name = rewriter.getStringAttr(gpuFuncOp.getName());
  props.function_type = TypeAttr::get(llvmFuncType);
  const bool isKernelFunc = gpuFuncOp.isKernel();
  props.setCConv(LLVM::CConvAttr::get(ctx, isKernelFunc
                                               ? kernelCallingConvention
                                               : nonKernelCallingConvention));

  NamedAttrList &discardable = loweredAttrs->discardableAttrs;
  auto *gpuDialect = cast<gpu::GPUDialect>(gpuFuncOp->getDialect());

  // `discardable` already holds the discardable attributes forwarded from
  // `gpuFuncOp`. Use `set` (replace-or-insert) rather than `append` so that an
  // attribute that is already present under the same name (e.g. a legacy
  // discardable `gpu.kernel`, which `isKernel()` accepts) does not end up in
  // the list twice.
  auto setIfNameAndValue = [&](StringAttr name, Attribute value) {
    if (name && value)
      discardable.set(name, value);
  };

  DenseI32ArrayAttr knownBlockSize = gpuFuncOp.getKnownBlockSizeAttr();
  DenseI32ArrayAttr knownGridSize = gpuFuncOp.getKnownGridSizeAttr();
  DenseI32ArrayAttr knownClusterSize = gpuFuncOp.getKnownClusterSizeAttr();

  setIfNameAndValue(gpuDialect->getKnownBlockSizeAttrHelper().getName(),
                    knownBlockSize);
  setIfNameAndValue(gpuDialect->getKnownGridSizeAttrHelper().getName(),
                    knownGridSize);
  setIfNameAndValue(gpuDialect->getKnownClusterSizeAttrHelper().getName(),
                    knownClusterSize);

  if (isKernelFunc) {
    discardable.set(gpuDialect->getKernelFuncAttrName(),
                    rewriter.getUnitAttr());
    // Add a dialect specific kernel attribute in addition to GPU kernel
    // attribute. The former is necessary for further translation while the
    // latter is expected by gpu.launch_func.
    setIfNameAndValue(kernelAttributeName, rewriter.getUnitAttr());
    setIfNameAndValue(kernelBlockSizeAttributeName, knownBlockSize);
    setIfNameAndValue(kernelClusterSizeAttributeName, knownClusterSize);
  }

  return loweredAttrs;
}
|
||||
|
||||
LogicalResult
|
||||
GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
@@ -85,7 +138,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
|
||||
// workgroup attributions.
|
||||
|
||||
ArrayRef<BlockArgument> workgroupAttributions =
|
||||
gpuFuncOp.getWorkgroupAttributions();
|
||||
gpuFuncOp.getWorkgroupAttributionBBArgs();
|
||||
size_t numAttributions = workgroupAttributions.size();
|
||||
|
||||
// Insert all arguments at the end.
|
||||
@@ -136,7 +189,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
|
||||
} else {
|
||||
workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
|
||||
for (auto [idx, attribution] :
|
||||
llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
|
||||
llvm::enumerate(gpuFuncOp.getWorkgroupAttributionBBArgs())) {
|
||||
auto type = dyn_cast<MemRefType>(attribution.getType());
|
||||
assert(type && type.hasStaticShape() && "unexpected type in attribution");
|
||||
|
||||
@@ -174,67 +227,17 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
|
||||
});
|
||||
}
|
||||
|
||||
// Create the new function operation. Only copy those attributes that are
|
||||
// not specific to function modeling.
|
||||
SmallVector<NamedAttribute, 4> attributes;
|
||||
ArrayAttr argAttrs;
|
||||
for (const auto &attr : gpuFuncOp->getAttrs()) {
|
||||
if (attr.getName() == SymbolTable::getSymbolAttrName() ||
|
||||
attr.getName() == gpuFuncOp.getFunctionTypeAttrName() ||
|
||||
attr.getName() ==
|
||||
gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName() ||
|
||||
attr.getName() == gpuFuncOp.getWorkgroupAttribAttrsAttrName() ||
|
||||
attr.getName() == gpuFuncOp.getPrivateAttribAttrsAttrName() ||
|
||||
attr.getName() == gpuFuncOp.getKnownBlockSizeAttrName() ||
|
||||
attr.getName() == gpuFuncOp.getKnownGridSizeAttrName() ||
|
||||
attr.getName() == gpuFuncOp.getKnownClusterSizeAttrName())
|
||||
continue;
|
||||
if (attr.getName() == gpuFuncOp.getArgAttrsAttrName()) {
|
||||
argAttrs = gpuFuncOp.getArgAttrsAttr();
|
||||
continue;
|
||||
}
|
||||
attributes.push_back(attr);
|
||||
}
|
||||
ArrayAttr argAttrs = gpuFuncOp.getArgAttrsAttr();
|
||||
|
||||
DenseI32ArrayAttr knownBlockSize = gpuFuncOp.getKnownBlockSizeAttr();
|
||||
DenseI32ArrayAttr knownGridSize = gpuFuncOp.getKnownGridSizeAttr();
|
||||
DenseI32ArrayAttr knownClusterSize = gpuFuncOp.getKnownClusterSizeAttr();
|
||||
// Ensure we don't lose information if the function is lowered before its
|
||||
// surrounding context.
|
||||
auto *gpuDialect = cast<gpu::GPUDialect>(gpuFuncOp->getDialect());
|
||||
if (knownBlockSize)
|
||||
attributes.emplace_back(gpuDialect->getKnownBlockSizeAttrHelper().getName(),
|
||||
knownBlockSize);
|
||||
if (knownGridSize)
|
||||
attributes.emplace_back(gpuDialect->getKnownGridSizeAttrHelper().getName(),
|
||||
knownGridSize);
|
||||
if (knownClusterSize)
|
||||
attributes.emplace_back(
|
||||
gpuDialect->getKnownClusterSizeAttrHelper().getName(),
|
||||
knownClusterSize);
|
||||
FailureOr<LoweredLLVMFuncAttrs> loweredAttrs =
|
||||
buildLoweredGPULLVMFuncAttrs(gpuFuncOp, funcType, rewriter);
|
||||
if (failed(loweredAttrs))
|
||||
return rewriter.notifyMatchFailure(gpuFuncOp,
|
||||
"failed to lower func attributes");
|
||||
|
||||
// Add a dialect specific kernel attribute in addition to GPU kernel
|
||||
// attribute. The former is necessary for further translation while the
|
||||
// latter is expected by gpu.launch_func.
|
||||
if (gpuFuncOp.isKernel()) {
|
||||
if (kernelAttributeName)
|
||||
attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr());
|
||||
// Set the dialect-specific block size attribute if there is one.
|
||||
if (kernelBlockSizeAttributeName && knownBlockSize) {
|
||||
attributes.emplace_back(kernelBlockSizeAttributeName, knownBlockSize);
|
||||
}
|
||||
// Set the dialect-specific cluster size attribute if there is one.
|
||||
if (kernelClusterSizeAttributeName && knownClusterSize) {
|
||||
attributes.emplace_back(kernelClusterSizeAttributeName, knownClusterSize);
|
||||
}
|
||||
}
|
||||
LLVM::CConv callingConvention = gpuFuncOp.isKernel()
|
||||
? kernelCallingConvention
|
||||
: nonKernelCallingConvention;
|
||||
auto llvmFuncOp = LLVM::LLVMFuncOp::create(
|
||||
rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
|
||||
LLVM::Linkage::External, /*dsoLocal=*/false, callingConvention,
|
||||
/*comdat=*/nullptr, attributes);
|
||||
auto llvmFuncOp = LLVM::LLVMFuncOp::create(rewriter, gpuFuncOp.getLoc(),
|
||||
loweredAttrs->properties,
|
||||
loweredAttrs->discardableAttrs);
|
||||
|
||||
{
|
||||
// Insert operations that correspond to converted workgroup and private
|
||||
@@ -260,8 +263,9 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
|
||||
ArrayRef<BlockArgument> attributionArguments =
|
||||
gpuFuncOp.getArguments().slice(numProperArguments - numAttributions,
|
||||
numAttributions);
|
||||
for (auto [idx, vals] : llvm::enumerate(llvm::zip_equal(
|
||||
gpuFuncOp.getWorkgroupAttributions(), attributionArguments))) {
|
||||
for (auto [idx, vals] : llvm::enumerate(
|
||||
llvm::zip_equal(gpuFuncOp.getWorkgroupAttributionBBArgs(),
|
||||
attributionArguments))) {
|
||||
auto [attribution, arg] = vals;
|
||||
auto type = cast<MemRefType>(attribution.getType());
|
||||
|
||||
@@ -287,7 +291,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
|
||||
// existing memref infrastructure. This may use more registers than
|
||||
// otherwise necessary given that memref sizes are fixed, but we can try
|
||||
// and canonicalize that away later.
|
||||
Value attribution = gpuFuncOp.getWorkgroupAttributions()[idx];
|
||||
Value attribution = gpuFuncOp.getWorkgroupAttributionBBArgs()[idx];
|
||||
auto type = cast<MemRefType>(attribution.getType());
|
||||
Value descr = MemRefDescriptor::fromStaticShape(
|
||||
rewriter, loc, *getTypeConverter(), type, memory);
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#ifndef MLIR_CONVERSION_GPUCOMMON_GPUOPSLOWERING_H_
|
||||
#define MLIR_CONVERSION_GPUCOMMON_GPUOPSLOWERING_H_
|
||||
|
||||
#include "mlir/Conversion/LLVMCommon/LowerFunctionDiscardablesToLLVM.h"
|
||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
|
||||
@@ -106,6 +107,12 @@ struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
|
||||
matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
|
||||
/// Lower discardable attrs like `func` lowering, then set `llvm.func`
|
||||
/// properties and append GPU / target-specific discardable metadata.
|
||||
FailureOr<LoweredLLVMFuncAttrs>
|
||||
buildLoweredGPULLVMFuncAttrs(gpu::GPUFuncOp gpuFuncOp, Type llvmFuncType,
|
||||
OpBuilder &rewriter) const;
|
||||
|
||||
private:
|
||||
/// The address space to use for `alloca`s in private memory.
|
||||
unsigned allocaAddrSpace;
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
add_mlir_conversion_library(MLIRLLVMCommonConversion
|
||||
ConversionTarget.cpp
|
||||
LowerFunctionDiscardablesToLLVM.cpp
|
||||
LoweringOptions.cpp
|
||||
MemRefBuilder.cpp
|
||||
Pattern.cpp
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
//===- LowerFunctionDiscardablesToLLVM.cpp - Func discardables to llvm ----===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Conversion/LLVMCommon/LowerFunctionDiscardablesToLLVM.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/Support/DebugLog.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
#define DEBUG_TYPE "lower-function-discardables-to-llvm"
|
||||
|
||||
/// Splits the discardable attributes of `funcOp` into typed `llvm.func`
/// properties and the discardable attributes to forward to the new op.
///
/// `sym_name` and `function_type` are seeded from the name of `funcOp` and
/// from `llvmFuncType`. A discardable attribute named `llvm.<name>`, where
/// `<name>` is an ODS attribute name of `llvm.func`, is collected and converted
/// into the matching property. An attribute whose name is an ODS attribute
/// name without the `llvm.` prefix is dropped (with a debug log). Every other
/// attribute is forwarded unchanged. Returns failure if the collected
/// attributes cannot be converted into properties.
FailureOr<LoweredLLVMFuncAttrs>
mlir::lowerDiscardableAttrsForLLVMFunc(FunctionOpInterface funcOp,
                                       Type llvmFuncType) {
  MLIRContext *ctx = funcOp->getContext();
  LoweredLLVMFuncAttrs result;

  result.properties.sym_name = StringAttr::get(ctx, funcOp.getName());
  result.properties.function_type = TypeAttr::get(llvmFuncType);

  llvm::SmallDenseSet<StringRef> odsAttrNames(
      LLVM::LLVMFuncOp::getAttributeNames().begin(),
      LLVM::LLVMFuncOp::getAttributeNames().end());

  NamedAttrList inherentAttrs;

  for (const NamedAttribute &attr : funcOp->getDiscardableAttrs()) {
    StringRef attrName = attr.getName().strref();

    if (odsAttrNames.contains(attrName)) {
      // Note the leading space: the attribute name is streamed right before
      // this literal and would otherwise run into the word "should".
      LDBG() << "LLVM specific attributes: " << attrName
             << " should use llvm.* prefix, discarding it";
      continue;
    }

    StringRef inherent = attrName;
    if (inherent.consume_front("llvm.") && odsAttrNames.contains(inherent))
      inherentAttrs.set(inherent, attr.getValue()); // collect inherent attrs
    else
      result.discardableAttrs.push_back(attr);
  }

  // Convert collected inherent attrs into typed properties.
  if (!inherentAttrs.empty()) {
    DictionaryAttr dict = inherentAttrs.getDictionary(ctx);
    auto emitError = [&] {
      return funcOp.emitOpError("invalid llvm.func property");
    };
    if (failed(LLVM::LLVMFuncOp::setPropertiesFromAttr(result.properties, dict,
                                                       emitError))) {
      return failure();
    }
  }
  return result;
}
|
||||
@@ -42,6 +42,7 @@
|
||||
#include "llvm/Support/StringSaver.h"
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
#include <optional>
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::gpu;
|
||||
@@ -262,8 +263,10 @@ bool GPUDialect::hasConstantMemoryAddressSpace(MemRefType type) {
|
||||
}
|
||||
|
||||
bool GPUDialect::isKernel(Operation *op) {
|
||||
UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
|
||||
return static_cast<bool>(isKernelAttr);
|
||||
if (auto gpuFunc = dyn_cast<GPUFuncOp>(op))
|
||||
return gpuFunc.isKernel();
|
||||
return static_cast<bool>(
|
||||
op->getAttrOfType<UnitAttr>(getKernelFuncAttrName()));
|
||||
}
|
||||
|
||||
namespace {
|
||||
@@ -713,10 +716,10 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result,
|
||||
FlatSymbolRefAttr module, FlatSymbolRefAttr function) {
|
||||
OpBuilder::InsertionGuard g(builder);
|
||||
|
||||
// Add a WorkGroup attribution attribute. This attribute is required to
|
||||
// identify private attributions in the list of block arguments.
|
||||
result.addAttribute(getNumWorkgroupAttributionsAttrName(),
|
||||
builder.getI64IntegerAttr(workgroupAttributions.size()));
|
||||
if (!workgroupAttributions.empty())
|
||||
result.addAttribute(
|
||||
getWorkgroupAttributionsAttrName(result.name),
|
||||
builder.getI64IntegerAttr(workgroupAttributions.size()));
|
||||
|
||||
// Add Op operands.
|
||||
result.addOperands(asyncDependencies);
|
||||
@@ -842,7 +845,7 @@ LogicalResult LaunchOp::verifyRegions() {
|
||||
}
|
||||
|
||||
// Verify Attributions Address Spaces.
|
||||
if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
|
||||
if (failed(verifyAttributions(getOperation(), getWorkgroupAttributionBBArgs(),
|
||||
GPUDialect::getWorkgroupAddressSpace())) ||
|
||||
failed(verifyAttributions(getOperation(), getPrivateAttributions(),
|
||||
GPUDialect::getPrivateAddressSpace())))
|
||||
@@ -921,7 +924,7 @@ void LaunchOp::print(OpAsmPrinter &p) {
|
||||
p << ')';
|
||||
}
|
||||
|
||||
printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
|
||||
printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributionBBArgs());
|
||||
printAttributions(p, getPrivateKeyword(), getPrivateAttributions());
|
||||
|
||||
p << ' ';
|
||||
@@ -929,7 +932,7 @@ void LaunchOp::print(OpAsmPrinter &p) {
|
||||
p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
|
||||
p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
|
||||
LaunchOp::getOperandSegmentSizeAttr(),
|
||||
getNumWorkgroupAttributionsAttrName(),
|
||||
getWorkgroupAttributionsAttrName(),
|
||||
moduleAttrName, functionAttrName});
|
||||
}
|
||||
|
||||
@@ -1072,12 +1075,9 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Create the region arguments, it has kNumConfigRegionAttributes arguments
|
||||
// that correspond to block/thread identifiers and grid/block sizes, all
|
||||
// having `index` type, a variadic number of WorkGroup Attributions and
|
||||
// a variadic number of Private Attributions. The number of WorkGroup
|
||||
// Attributions is stored in the attr with name:
|
||||
// LaunchOp::getNumWorkgroupAttributionsAttrName().
|
||||
// Create the region arguments: fixed launch-config args (`index`), then
|
||||
// workgroup / private attribution args. The workgroup count is stored in the
|
||||
// inherent `workgroup_attributions` attribute when non-zero.
|
||||
Type index = parser.getBuilder().getIndexType();
|
||||
SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
|
||||
LaunchOp::kNumConfigRegionAttributes + 6, index);
|
||||
@@ -1101,8 +1101,9 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
unsigned numWorkgroupAttrs = regionArguments.size() -
|
||||
LaunchOp::kNumConfigRegionAttributes -
|
||||
(hasCluster ? 6 : 0);
|
||||
result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
|
||||
builder.getI64IntegerAttr(numWorkgroupAttrs));
|
||||
if (numWorkgroupAttrs != 0)
|
||||
result.addAttribute(LaunchOp::getWorkgroupAttributionsAttrName(result.name),
|
||||
builder.getI64IntegerAttr(numWorkgroupAttrs));
|
||||
|
||||
// Parse private memory attributions.
|
||||
if (failed(parseAttributions(parser, LaunchOp::getPrivateKeyword(),
|
||||
@@ -1176,12 +1177,10 @@ void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
|
||||
/// Adds a new block argument that corresponds to buffers located in
|
||||
/// workgroup memory.
|
||||
BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
|
||||
auto attrName = getNumWorkgroupAttributionsAttrName();
|
||||
auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
|
||||
(*this)->setAttr(attrName,
|
||||
IntegerAttr::get(attr.getType(), attr.getValue() + 1));
|
||||
int64_t cur = getWorkgroupAttributions().value_or(0);
|
||||
setWorkgroupAttributions(std::optional<int64_t>(cur + 1));
|
||||
return getBody().insertArgument(
|
||||
LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
|
||||
getNumConfigRegionAttributes() + static_cast<unsigned>(cur), type, loc);
|
||||
}
|
||||
|
||||
/// Adds a new block argument that corresponds to buffers located in
|
||||
@@ -1391,8 +1390,7 @@ LaunchFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
||||
return diag;
|
||||
}
|
||||
|
||||
if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
|
||||
GPUDialect::getKernelFuncAttrName()))
|
||||
if (!GPUDialect::isKernel(kernelFunc))
|
||||
return launchOp.emitOpError("kernel function is missing the '")
|
||||
<< GPUDialect::getKernelFuncAttrName() << "' attribute";
|
||||
|
||||
@@ -1576,12 +1574,10 @@ void BarrierOp::build(OpBuilder &builder, OperationState &odsState,
|
||||
/// Adds a new block argument that corresponds to buffers located in
|
||||
/// workgroup memory.
|
||||
BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
|
||||
auto attrName = getNumWorkgroupAttributionsAttrName();
|
||||
auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
|
||||
(*this)->setAttr(attrName,
|
||||
IntegerAttr::get(attr.getType(), attr.getValue() + 1));
|
||||
int64_t cur = getWorkgroupAttributions().value_or(0);
|
||||
setWorkgroupAttributions(std::optional<int64_t>(cur + 1));
|
||||
return getBody().insertArgument(
|
||||
getFunctionType().getNumInputs() + attr.getInt(), type, loc);
|
||||
getFunctionType().getNumInputs() + static_cast<unsigned>(cur), type, loc);
|
||||
}
|
||||
|
||||
/// Adds a new block argument that corresponds to buffers located in
|
||||
@@ -1603,7 +1599,7 @@ void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
|
||||
builder.getStringAttr(name));
|
||||
result.addAttribute(getFunctionTypeAttrName(result.name),
|
||||
TypeAttr::get(type));
|
||||
result.addAttribute(getNumWorkgroupAttributionsAttrName(),
|
||||
result.addAttribute(getWorkgroupAttributionsAttrName(result.name),
|
||||
builder.getI64IntegerAttr(workgroupAttributions.size()));
|
||||
result.addAttributes(attrs);
|
||||
Region *body = result.addRegion();
|
||||
@@ -1712,8 +1708,10 @@ ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
// Store the number of operands we just parsed as the number of workgroup
|
||||
// memory attributions.
|
||||
unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
|
||||
result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
|
||||
builder.getI64IntegerAttr(numWorkgroupAttrs));
|
||||
if (numWorkgroupAttrs != 0)
|
||||
result.addAttribute(
|
||||
GPUFuncOp::getWorkgroupAttributionsAttrName(result.name),
|
||||
builder.getI64IntegerAttr(numWorkgroupAttrs));
|
||||
if (workgroupAttributionAttrs)
|
||||
result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
|
||||
workgroupAttributionAttrs);
|
||||
@@ -1729,7 +1727,7 @@ ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
|
||||
// Parse the kernel attribute if present.
|
||||
if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
|
||||
result.addAttribute(GPUDialect::getKernelFuncAttrName(),
|
||||
result.addAttribute(GPUFuncOp::getKernelAttrName(result.name),
|
||||
builder.getUnitAttr());
|
||||
|
||||
// Parse attributes.
|
||||
@@ -1751,7 +1749,7 @@ void GPUFuncOp::print(OpAsmPrinter &p) {
|
||||
/*isVariadic=*/false,
|
||||
type.getResults());
|
||||
|
||||
printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions(),
|
||||
printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributionBBArgs(),
|
||||
getWorkgroupAttribAttrs().value_or(nullptr));
|
||||
printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),
|
||||
getPrivateAttribAttrs().value_or(nullptr));
|
||||
@@ -1760,7 +1758,7 @@ void GPUFuncOp::print(OpAsmPrinter &p) {
|
||||
|
||||
function_interface_impl::printFunctionAttributes(
|
||||
p, *this,
|
||||
{getNumWorkgroupAttributionsAttrName(),
|
||||
{getWorkgroupAttributionsAttrName(), getKernelAttrName(),
|
||||
GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
|
||||
getArgAttrsAttrName(), getResAttrsAttrName(),
|
||||
getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
|
||||
@@ -1914,7 +1912,7 @@ LogicalResult GPUFuncOp::verifyBody() {
|
||||
<< blockArgType;
|
||||
}
|
||||
|
||||
if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
|
||||
if (failed(verifyAttributions(getOperation(), getWorkgroupAttributionBBArgs(),
|
||||
GPUDialect::getWorkgroupAddressSpace())) ||
|
||||
failed(verifyAttributions(getOperation(), getPrivateAttributions(),
|
||||
GPUDialect::getPrivateAddressSpace())))
|
||||
|
||||
@@ -197,10 +197,9 @@ static gpu::GPUFuncOp outlineKernelFuncImpl(gpu::LaunchOp launchOp,
|
||||
FunctionType::get(launchOp.getContext(), kernelOperandTypes, {});
|
||||
auto outlinedFunc = gpu::GPUFuncOp::create(
|
||||
builder, loc, kernelFnName, type,
|
||||
TypeRange(ValueRange(launchOp.getWorkgroupAttributions())),
|
||||
TypeRange(ValueRange(launchOp.getWorkgroupAttributionBBArgs())),
|
||||
TypeRange(ValueRange(launchOp.getPrivateAttributions())));
|
||||
outlinedFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
|
||||
builder.getUnitAttr());
|
||||
outlinedFunc.setKernel(true);
|
||||
|
||||
// If we can infer bounds on the grid and/or block sizes from the arguments
|
||||
// to the launch op, propagate them to the generated kernel. This is safe
|
||||
@@ -227,8 +226,8 @@ static gpu::GPUFuncOp outlineKernelFuncImpl(gpu::LaunchOp launchOp,
|
||||
|
||||
// Map memory attributions from the LaunchOp to the GPUFuncOp attributions.
|
||||
for (const auto &[launchArg, funcArg] :
|
||||
llvm::zip(launchOp.getWorkgroupAttributions(),
|
||||
outlinedFunc.getWorkgroupAttributions()))
|
||||
llvm::zip(launchOp.getWorkgroupAttributionBBArgs(),
|
||||
outlinedFunc.getWorkgroupAttributionBBArgs()))
|
||||
map.map(launchArg, funcArg);
|
||||
for (const auto &[launchArg, funcArg] :
|
||||
llvm::zip(launchOp.getPrivateAttributions(),
|
||||
|
||||
@@ -84,8 +84,7 @@ static gpu::GPUFuncOp genGPUFunc(OpBuilder &builder, gpu::GPUModuleOp gpuModule,
|
||||
FunctionType type = FunctionType::get(gpuModule->getContext(), argsTp, {});
|
||||
auto gpuFunc =
|
||||
gpu::GPUFuncOp::create(builder, gpuModule->getLoc(), kernelName, type);
|
||||
gpuFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
|
||||
builder.getUnitAttr());
|
||||
gpuFunc.setKernel(true);
|
||||
return gpuFunc;
|
||||
}
|
||||
|
||||
|
||||
@@ -173,7 +173,7 @@ struct MoveFuncBodyToWarpOp : public OpRewritePattern<gpu::GPUFuncOp> {
|
||||
gpuFuncOp, "expected gpu.func terminator to be gpu.return");
|
||||
// Create a new function with the same signature and same attributes.
|
||||
SmallVector<Type> workgroupAttributionsTypes =
|
||||
llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributions(),
|
||||
llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributionBBArgs(),
|
||||
[](BlockArgument arg) { return arg.getType(); });
|
||||
SmallVector<Type> privateAttributionsTypes =
|
||||
llvm::map_to_vector(gpuFuncOp.getPrivateAttributions(),
|
||||
|
||||
Reference in New Issue
Block a user