112 lines
4.5 KiB
C++
112 lines
4.5 KiB
C++
//===- ConvertComplexPow.cpp - Convert complex.pow to library calls -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
|
|
|
|
#include "flang/Common/static-multimap-view.h"
|
|
#include "flang/Optimizer/Builder/FIRBuilder.h"
|
|
#include "flang/Optimizer/Dialect/FIRDialect.h"
|
|
#include "flang/Optimizer/Transforms/Passes.h"
|
|
#include "flang/Runtime/entry-names.h"
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Complex/IR/Complex.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
|
|
// Emit the TableGen-generated definition for this pass. It provides
// fir::impl::ConvertComplexPowBase, which ConvertComplexPowPass derives from.
namespace fir {
#define GEN_PASS_DEF_CONVERTCOMPLEXPOW
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

using namespace mlir;
|
|
|
|
namespace {
|
|
class ConvertComplexPowPass
|
|
: public fir::impl::ConvertComplexPowBase<ConvertComplexPowPass> {
|
|
public:
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<fir::FIROpsDialect, complex::ComplexDialect,
|
|
arith::ArithDialect, func::FuncDialect>();
|
|
}
|
|
void runOnOperation() override;
|
|
};
|
|
} // namespace
|
|
|
|
// Helper to declare or get a math library function.
|
|
static func::FuncOp getOrDeclare(fir::FirOpBuilder &builder, Location loc,
|
|
StringRef name, FunctionType type) {
|
|
if (auto func = builder.getNamedFunction(name))
|
|
return func;
|
|
auto func = builder.createFunction(loc, name, type);
|
|
func->setAttr(fir::getSymbolAttrName(), builder.getStringAttr(name));
|
|
func->setAttr(fir::FIROpsDialect::getFirRuntimeAttrName(),
|
|
builder.getUnitAttr());
|
|
return func;
|
|
}
|
|
|
|
void ConvertComplexPowPass::runOnOperation() {
|
|
ModuleOp mod = getOperation();
|
|
fir::FirOpBuilder builder(mod, fir::getKindMapping(mod));
|
|
|
|
mod.walk([&](Operation *op) {
|
|
if (auto powIop = dyn_cast<complex::PowiOp>(op)) {
|
|
builder.setInsertionPoint(powIop);
|
|
Location loc = powIop.getLoc();
|
|
auto complexTy = cast<ComplexType>(powIop.getType());
|
|
auto elemTy = complexTy.getElementType();
|
|
Value base = powIop.getLhs();
|
|
Value intExp = powIop.getRhs();
|
|
func::FuncOp callee;
|
|
unsigned realBits = cast<FloatType>(elemTy).getWidth();
|
|
unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth();
|
|
auto funcTy = builder.getFunctionType(
|
|
{complexTy, builder.getIntegerType(intBits)}, {complexTy});
|
|
if (realBits == 32 && intBits == 32)
|
|
callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy);
|
|
else if (realBits == 32 && intBits == 64)
|
|
callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy);
|
|
else if (realBits == 64 && intBits == 32)
|
|
callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy);
|
|
else if (realBits == 64 && intBits == 64)
|
|
callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy);
|
|
else if (realBits == 128 && intBits == 32)
|
|
callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy);
|
|
else if (realBits == 128 && intBits == 64)
|
|
callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy);
|
|
else
|
|
return;
|
|
auto call = fir::CallOp::create(builder, loc, callee, {base, intExp});
|
|
if (auto fmf = powIop.getFastmathAttr())
|
|
call.setFastmathAttr(fmf);
|
|
powIop.replaceAllUsesWith(call.getResult(0));
|
|
powIop.erase();
|
|
} else if (auto powOp = dyn_cast<complex::PowOp>(op)) {
|
|
builder.setInsertionPoint(powOp);
|
|
Location loc = powOp.getLoc();
|
|
auto complexTy = cast<ComplexType>(powOp.getType());
|
|
auto elemTy = complexTy.getElementType();
|
|
unsigned realBits = cast<FloatType>(elemTy).getWidth();
|
|
func::FuncOp callee;
|
|
auto funcTy =
|
|
builder.getFunctionType({complexTy, complexTy}, {complexTy});
|
|
if (realBits == 32)
|
|
callee = getOrDeclare(builder, loc, "cpowf", funcTy);
|
|
else if (realBits == 64)
|
|
callee = getOrDeclare(builder, loc, "cpow", funcTy);
|
|
else if (realBits == 128)
|
|
callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy);
|
|
else
|
|
return;
|
|
auto call = fir::CallOp::create(builder, loc, callee,
|
|
{powOp.getLhs(), powOp.getRhs()});
|
|
if (auto fmf = powOp.getFastmathAttr())
|
|
call.setFastmathAttr(fmf);
|
|
powOp.replaceAllUsesWith(call.getResult(0));
|
|
powOp.erase();
|
|
}
|
|
});
|
|
}
|