104 lines
3.3 KiB
C++
104 lines
3.3 KiB
C++
//===-- CUFFunctionRewrite.cpp --------------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "flang/Optimizer/CodeGen/TypeConverter.h"
|
|
#include "flang/Optimizer/Dialect/FIRDialect.h"
|
|
#include "flang/Optimizer/Dialect/FIROps.h"
|
|
#include "flang/Optimizer/Dialect/FIRType.h"
|
|
#include "flang/Optimizer/Support/DataLayout.h"
|
|
#include "flang/Optimizer/Transforms/Passes.h"
|
|
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/MLIRContext.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/IR/ValueRange.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Support/LogicalResult.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
#include "llvm/ADT/StringRef.h"
|
|
#include "llvm/ADT/StringSet.h"
|
|
#include "llvm/Support/Debug.h"
|
|
#include <string_view>
|
|
|
|
#define DEBUG_TYPE "flang-cuf-function-rewrite"
|
|
|
|
namespace fir {
|
|
#define GEN_PASS_DEF_CUFFUNCTIONREWRITE
|
|
#include "flang/Optimizer/Transforms/Passes.h.inc"
|
|
} // namespace fir
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
|
|
using genFunctionType =
|
|
std::function<mlir::Value(mlir::PatternRewriter &, fir::CallOp op)>;
|
|
|
|
class CallConversion : public OpRewritePattern<fir::CallOp> {
|
|
public:
|
|
CallConversion(MLIRContext *context)
|
|
: OpRewritePattern<fir::CallOp>(context) {}
|
|
|
|
LogicalResult
|
|
matchAndRewrite(fir::CallOp op,
|
|
mlir::PatternRewriter &rewriter) const override {
|
|
auto callee = op.getCallee();
|
|
if (!callee)
|
|
return failure();
|
|
auto name = callee->getRootReference().getValue();
|
|
|
|
if (genMappings_.contains(name)) {
|
|
auto fct = genMappings_.find(name);
|
|
mlir::Value result = fct->second(rewriter, op);
|
|
if (result)
|
|
rewriter.replaceOp(op, result);
|
|
else
|
|
rewriter.eraseOp(op);
|
|
return success();
|
|
}
|
|
return failure();
|
|
}
|
|
|
|
private:
|
|
static mlir::Value genOnDevice(mlir::PatternRewriter &rewriter,
|
|
fir::CallOp op) {
|
|
assert(op.getArgs().size() == 0 && "expect 0 arguments");
|
|
mlir::Location loc = op.getLoc();
|
|
unsigned inGPUMod = op->getParentOfType<gpu::GPUModuleOp>() ? 1 : 0;
|
|
mlir::Type i1Ty = rewriter.getIntegerType(1);
|
|
mlir::Value t = mlir::arith::ConstantOp::create(
|
|
rewriter, loc, i1Ty, rewriter.getIntegerAttr(i1Ty, inGPUMod));
|
|
return fir::ConvertOp::create(rewriter, loc, op.getResult(0).getType(), t);
|
|
}
|
|
|
|
const llvm::StringMap<genFunctionType> genMappings_ = {
|
|
{"on_device", &genOnDevice}};
|
|
};
|
|
|
|
class CUFFunctionRewrite
|
|
: public fir::impl::CUFFunctionRewriteBase<CUFFunctionRewrite> {
|
|
public:
|
|
void runOnOperation() override {
|
|
auto *ctx = &getContext();
|
|
mlir::RewritePatternSet patterns(ctx);
|
|
|
|
patterns.insert<CallConversion>(patterns.getContext());
|
|
|
|
if (mlir::failed(
|
|
mlir::applyPatternsGreedily(getOperation(), std::move(patterns)))) {
|
|
mlir::emitError(mlir::UnknownLoc::get(ctx),
|
|
"error in CUFFunctionRewrite op conversion\n");
|
|
signalPassFailure();
|
|
}
|
|
}
|
|
};
|
|
|
|
} // namespace
|