Fixes #179847: the math.ctlz expansion called getIntOrFloatBitWidth() on the operand element type without checking it first. The index type has no fixed bitwidth and is neither an integer nor a float type, so the assertion in Type::getIntOrFloatBitWidth() could fire.
- In convertCtlzOp, bail out with notifyMatchFailure when the element type is not an integer or float type, so the expansion is only applied to types with a defined bitwidth.
- Add a test to expand-math.mlir checking that math.ctlz on index is left unchanged by the pass (no crash, op preserved).
794 lines
33 KiB
C++
794 lines
33 KiB
C++
//===- ExpandPatterns.cpp - Code to expand various math operations. -------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements expansion of various math operations.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Math/IR/Math.h"
|
|
#include "mlir/Dialect/Math/Transforms/Passes.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/IR/TypeUtilities.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace mlir::math {
// Pull in the TableGen-generated definition of the pass base class
// (`impl::MathExpandOpsPassBase`) that `MathExpandOpsPass` derives from.
#define GEN_PASS_DEF_MATHEXPANDOPSPASS
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
} // namespace mlir::math
|
|
|
|
/// Create a float constant.
|
|
static Value createFloatConst(Location loc, Type type, APFloat value,
|
|
OpBuilder &b) {
|
|
bool losesInfo = false;
|
|
auto eltType = getElementTypeOrSelf(type);
|
|
// Convert double to the given `FloatType` with round-to-nearest-ties-to-even.
|
|
value.convert(cast<FloatType>(eltType).getFloatSemantics(),
|
|
APFloat::rmNearestTiesToEven, &losesInfo);
|
|
auto attr = b.getFloatAttr(eltType, value);
|
|
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
|
|
return arith::ConstantOp::create(b, loc,
|
|
DenseElementsAttr::get(shapedTy, attr));
|
|
}
|
|
|
|
return arith::ConstantOp::create(b, loc, attr);
|
|
}
|
|
|
|
/// Convenience overload: wraps the host `double` in an APFloat and defers to
/// the APFloat overload, which converts it to the element type of `type`.
static Value createFloatConst(Location loc, Type type, double value,
                              OpBuilder &b) {
  APFloat asAPFloat(value);
  return createFloatConst(loc, type, asAPFloat, b);
}
|
|
|
|
/// Create an integer constant.
|
|
static Value createIntConst(Location loc, Type type, int64_t value,
|
|
OpBuilder &b) {
|
|
auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value);
|
|
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
|
|
return arith::ConstantOp::create(b, loc,
|
|
DenseElementsAttr::get(shapedTy, attr));
|
|
}
|
|
|
|
return arith::ConstantOp::create(b, loc, attr);
|
|
}
|
|
|
|
/// Rounds `operand` toward zero by round-tripping it through a 64-bit signed
/// integer (fptosi followed by sitofp), then copies the sign of the input back
/// onto the result so that negative inputs truncating to zero give -0.0.
static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) {
  Type floatTy = operand.getType();
  Type intTy = b.getI64Type();
  if (auto shaped = dyn_cast<ShapedType>(floatTy))
    intTy = shaped.clone(intTy);

  Value asInt = arith::FPToSIOp::create(b, intTy, operand);
  Value backToFloat = arith::SIToFPOp::create(b, floatTy, asInt);
  // The integer round-trip cannot represent -0, so restore the input's sign.
  return math::CopySignOp::create(b, backToFloat, operand);
}
|
|
|
|
// sinhf(float x) -> (exp(x) - exp(-x)) / 2
|
|
static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) {
|
|
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
|
|
Value operand = op.getOperand();
|
|
Type opType = operand.getType();
|
|
|
|
Value exp = math::ExpOp::create(b, operand);
|
|
Value neg = arith::NegFOp::create(b, operand);
|
|
Value nexp = math::ExpOp::create(b, neg);
|
|
Value sub = arith::SubFOp::create(b, exp, nexp);
|
|
Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
|
|
Value res = arith::MulFOp::create(b, sub, half);
|
|
rewriter.replaceOp(op, res);
|
|
return success();
|
|
}
|
|
|
|
// coshf(float x) -> (exp(x) + exp(-x)) / 2
|
|
static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
|
|
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
|
|
Value operand = op.getOperand();
|
|
Type opType = operand.getType();
|
|
|
|
Value exp = math::ExpOp::create(b, operand);
|
|
Value neg = arith::NegFOp::create(b, operand);
|
|
Value nexp = math::ExpOp::create(b, neg);
|
|
Value add = arith::AddFOp::create(b, exp, nexp);
|
|
Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
|
|
Value res = arith::MulFOp::create(b, add, half);
|
|
rewriter.replaceOp(op, res);
|
|
return success();
|
|
}
|
|
|
|
/// Expands tanh op into
|
|
/// 1-exp^{-2x} / 1+exp^{-2x}
|
|
/// To avoid overflow we exploit the reflection symmetry `tanh(-x) = -tanh(x)`.
|
|
/// We compute a "signs" value which is -1 if input is negative and +1 if input
|
|
/// is positive. Then multiply the input by this value, guaranteeing that the
|
|
/// result is positive, which also guarantees `exp^{-2x * sign(x)}` is in (0,
|
|
/// 1]. Expand the computation on the input `x * sign(x)`, then multiply the
|
|
/// result by `sign(x)` to retain sign of the real result.
|
|
static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
|
|
auto floatType = op.getOperand().getType();
|
|
Location loc = op.getLoc();
|
|
Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
|
|
Value one = createFloatConst(loc, floatType, 1.0, rewriter);
|
|
Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter);
|
|
|
|
// Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
|
|
Value isNegative = arith::CmpFOp::create(
|
|
rewriter, loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
|
|
Value isNegativeFloat =
|
|
arith::UIToFPOp::create(rewriter, loc, floatType, isNegative);
|
|
Value isNegativeTimesNegTwo =
|
|
arith::MulFOp::create(rewriter, loc, isNegativeFloat, negTwo);
|
|
Value sign = arith::AddFOp::create(rewriter, loc, isNegativeTimesNegTwo, one);
|
|
|
|
// Normalize input to positive value: y = sign(x) * x
|
|
Value positiveX = arith::MulFOp::create(rewriter, loc, sign, op.getOperand());
|
|
|
|
// Decompose on normalized input
|
|
Value negDoubledX = arith::MulFOp::create(rewriter, loc, negTwo, positiveX);
|
|
Value exp2x = math::ExpOp::create(rewriter, loc, negDoubledX);
|
|
Value dividend = arith::SubFOp::create(rewriter, loc, one, exp2x);
|
|
Value divisor = arith::AddFOp::create(rewriter, loc, one, exp2x);
|
|
Value positiveRes = arith::DivFOp::create(rewriter, loc, dividend, divisor);
|
|
|
|
// Multiply result by sign(x) to retain signs from negative inputs
|
|
rewriter.replaceOpWithNewOp<arith::MulFOp>(op, sign, positiveRes);
|
|
|
|
return success();
|
|
}
|
|
|
|
// Converts math.tan to math.sin, math.cos, and arith.divf.
|
|
static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
|
|
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
|
|
Value operand = op.getOperand();
|
|
Type type = operand.getType();
|
|
Value sin = math::SinOp::create(b, type, operand);
|
|
Value cos = math::CosOp::create(b, type, operand);
|
|
Value div = arith::DivFOp::create(b, type, sin, cos);
|
|
rewriter.replaceOp(op, div);
|
|
return success();
|
|
}
|
|
|
|
// asinh(float x) -> log(x + sqrt(x**2 + 1))
|
|
static LogicalResult convertAsinhOp(math::AsinhOp op,
|
|
PatternRewriter &rewriter) {
|
|
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
|
|
Value operand = op.getOperand();
|
|
Type opType = operand.getType();
|
|
|
|
Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
|
|
Value fma = math::FmaOp::create(b, operand, operand, one);
|
|
Value sqrt = math::SqrtOp::create(b, fma);
|
|
Value add = arith::AddFOp::create(b, operand, sqrt);
|
|
Value res = math::LogOp::create(b, add);
|
|
rewriter.replaceOp(op, res);
|
|
return success();
|
|
}
|
|
|
|
// acosh(float x) -> log(x + sqrt(x**2 - 1))
|
|
static LogicalResult convertAcoshOp(math::AcoshOp op,
|
|
PatternRewriter &rewriter) {
|
|
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
|
|
Value operand = op.getOperand();
|
|
Type opType = operand.getType();
|
|
|
|
Value negOne = createFloatConst(op->getLoc(), opType, -1.0, rewriter);
|
|
Value fma = math::FmaOp::create(b, operand, operand, negOne);
|
|
Value sqrt = math::SqrtOp::create(b, fma);
|
|
Value add = arith::AddFOp::create(b, operand, sqrt);
|
|
Value res = math::LogOp::create(b, add);
|
|
rewriter.replaceOp(op, res);
|
|
return success();
|
|
}
|
|
|
|
// atanh(float x) -> log((1 + x) / (1 - x)) / 2
|
|
static LogicalResult convertAtanhOp(math::AtanhOp op,
|
|
PatternRewriter &rewriter) {
|
|
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
|
|
Value operand = op.getOperand();
|
|
Type opType = operand.getType();
|
|
|
|
Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
|
|
Value add = arith::AddFOp::create(b, operand, one);
|
|
Value neg = arith::NegFOp::create(b, operand);
|
|
Value sub = arith::AddFOp::create(b, neg, one);
|
|
Value div = arith::DivFOp::create(b, add, sub);
|
|
Value log = math::LogOp::create(b, div);
|
|
Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
|
|
Value res = arith::MulFOp::create(b, log, half);
|
|
rewriter.replaceOp(op, res);
|
|
return success();
|
|
}
|
|
|
|
// Expands math.fma(a, b, c) into an arith.mulf (a * b) followed by an
// arith.addf (+ c). The result is therefore rounded twice, unlike a fused op.
static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  Type resultTy = op.getType();
  Value product = arith::MulFOp::create(builder, resultTy, op.getOperand(0),
                                        op.getOperand(1));
  Value sum =
      arith::AddFOp::create(builder, resultTy, product, op.getOperand(2));
  rewriter.replaceOp(op, sum);
  return success();
}
|
|
|
|
// Converts a ceilf() function to the following:
|
|
// ceilf(float x) ->
|
|
// y = (float)(int) x
|
|
// if (x > y) then incr = 1 else incr = 0
|
|
// y = y + incr <= replace this op with the ceilf op.
|
|
static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
|
|
// Creating constants assumes the static shaped type.
|
|
auto shapedType = dyn_cast<ShapedType>(op.getType());
|
|
if (shapedType && !shapedType.hasStaticShape())
|
|
return failure();
|
|
|
|
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
|
|
Value operand = op.getOperand();
|
|
Type opType = operand.getType();
|
|
Type operandETy = getElementTypeOrSelf(opType);
|
|
FloatType floatTy = llvm::dyn_cast<FloatType>(operandETy);
|
|
const llvm::fltSemantics &semantics = floatTy.getFloatSemantics();
|
|
|
|
unsigned bitWidth = floatTy.getWidth();
|
|
unsigned mantissaWidth = floatTy.getFPMantissaWidth() - 1;
|
|
const int bias = (&semantics == &APFloat::Float8E8M0FNU())
|
|
? -semantics.minExponent
|
|
: -(semantics.minExponent - 1);
|
|
bool hasNegativeZeroNaNEncoding =
|
|
(semantics.nanEncoding == llvm::fltNanEncoding::NegativeZero);
|
|
|
|
Type iTy = rewriter.getIntegerType(bitWidth);
|
|
if (auto shapedTy = dyn_cast<ShapedType>(opType))
|
|
iTy = shapedTy.clone(iTy);
|
|
|
|
// For IEEE-like floating-point formats with an unbiased exponent ≥
|
|
// `mantissaWidth` falls into one of these categories:
|
|
// - a large finite value (|x| ≥ 2^mantissaWidth), where all representable
|
|
// numbers are already integral, or
|
|
// - a special value (NaN or ±Inf), which also satisfies this exponent
|
|
// condition.
|
|
// For all such cases, `ceilf(x)` is defined to return `x` directly.
|
|
Value operandBitcast = arith::BitcastOp::create(b, iTy, operand);
|
|
Value cMask = createIntConst(
|
|
op->getLoc(), iTy, static_cast<int64_t>((1ull << (bitWidth - 1)) - 1), b);
|
|
Value unsignedBits = arith::AndIOp::create(b, operandBitcast, cMask);
|
|
Value cThreshold = createIntConst(
|
|
op->getLoc(), iTy,
|
|
static_cast<int64_t>((uint64_t(bias + mantissaWidth)) << mantissaWidth),
|
|
b);
|
|
Value isLargeExp = arith::CmpIOp::create(b, arith::CmpIPredicate::uge,
|
|
unsignedBits, cThreshold);
|
|
Value isSpecialValOrLargeVal = isLargeExp;
|
|
|
|
// In FNUZ-suffixed floating point, NaN is represented by a sign bit of 1 and
|
|
// all 0s in the exponent and mantissa, therefore requires an explicit check.
|
|
if (hasNegativeZeroNaNEncoding) {
|
|
Value cNegZeroBits = createIntConst(
|
|
op->getLoc(), iTy, static_cast<int64_t>(1ull << (bitWidth - 1)), b);
|
|
Value isNegZeroEncoding = arith::CmpIOp::create(
|
|
b, arith::CmpIPredicate::eq, operandBitcast, cNegZeroBits);
|
|
isSpecialValOrLargeVal =
|
|
arith::OrIOp::create(b, isLargeExp, isNegZeroEncoding);
|
|
}
|
|
|
|
Value fpFixedConvert = createTruncatedFPValue(operand, b);
|
|
|
|
// Creating constants for later use.
|
|
Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
|
|
Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
|
|
|
|
Value gtCheck = arith::CmpFOp::create(b, arith::CmpFPredicate::OGT, operand,
|
|
fpFixedConvert);
|
|
Value incrValue =
|
|
arith::SelectOp::create(b, op->getLoc(), gtCheck, one, zero);
|
|
|
|
Value add = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue);
|
|
Value ret = arith::SelectOp::create(b, isSpecialValOrLargeVal, operand, add);
|
|
rewriter.replaceOp(op, ret);
|
|
return success();
|
|
}
|
|
|
|
// Convert `math.fpowi` to a series of `arith.mulf` operations.
|
|
// If the power is negative, we divide one by the result.
|
|
// If both the base and power are zero, the result is 1.
|
|
// In the case of non constant power, we convert the operation to `math.powf`.
|
|
static LogicalResult convertFPowIOp(math::FPowIOp op,
|
|
PatternRewriter &rewriter) {
|
|
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
|
|
Value base = op.getOperand(0);
|
|
Value power = op.getOperand(1);
|
|
Type baseType = base.getType();
|
|
|
|
auto convertFPowItoPowf = [&]() -> LogicalResult {
|
|
Value castPowerToFp =
|
|
arith::SIToFPOp::create(rewriter, op.getLoc(), baseType, power);
|
|
Value res = math::PowFOp::create(rewriter, op.getLoc(), baseType, base,
|
|
castPowerToFp);
|
|
rewriter.replaceOp(op, res);
|
|
return success();
|
|
};
|
|
|
|
Attribute cstAttr;
|
|
if (!matchPattern(power, m_Constant(&cstAttr)))
|
|
return convertFPowItoPowf();
|
|
|
|
APInt value;
|
|
if (!matchPattern(cstAttr, m_ConstantInt(&value)))
|
|
return convertFPowItoPowf();
|
|
|
|
int64_t powerInt = value.getSExtValue();
|
|
bool isNegative = powerInt < 0;
|
|
int64_t absPower = std::abs(powerInt);
|
|
Value one = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
|
|
Value res = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
|
|
|
|
while (absPower > 0) {
|
|
if (absPower & 1)
|
|
res = arith::MulFOp::create(b, baseType, base, res);
|
|
absPower >>= 1;
|
|
base = arith::MulFOp::create(b, baseType, base, base);
|
|
}
|
|
|
|
// Make sure not to introduce UB in case of negative power.
|
|
if (isNegative) {
|
|
auto &sem = dyn_cast<mlir::FloatType>(getElementTypeOrSelf(baseType))
|
|
.getFloatSemantics();
|
|
Value zero =
|
|
createFloatConst(op->getLoc(), baseType,
|
|
APFloat::getZero(sem, /*Negative=*/false), rewriter);
|
|
Value negZero =
|
|
createFloatConst(op->getLoc(), baseType,
|
|
APFloat::getZero(sem, /*Negative=*/true), rewriter);
|
|
Value posInfinity =
|
|
createFloatConst(op->getLoc(), baseType,
|
|
APFloat::getInf(sem, /*Negative=*/false), rewriter);
|
|
Value negInfinity =
|
|
createFloatConst(op->getLoc(), baseType,
|
|
APFloat::getInf(sem, /*Negative=*/true), rewriter);
|
|
Value zeroEqCheck =
|
|
arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, zero);
|
|
Value negZeroEqCheck =
|
|
arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, negZero);
|
|
res = arith::DivFOp::create(b, baseType, one, res);
|
|
res =
|
|
arith::SelectOp::create(b, op->getLoc(), zeroEqCheck, posInfinity, res);
|
|
res = arith::SelectOp::create(b, op->getLoc(), negZeroEqCheck, negInfinity,
|
|
res);
|
|
}
|
|
|
|
rewriter.replaceOp(op, res);
|
|
return success();
|
|
}
|
|
|
|
// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
|
|
// Some special cases where b is constant are handled separately:
|
|
// when b == 0, or |b| == 0.5, 1.0, or 2.0.
|
|
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
|
|
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
|
|
Value operandA = op.getOperand(0);
|
|
Value operandB = op.getOperand(1);
|
|
auto typeA = operandA.getType();
|
|
auto typeB = operandB.getType();
|
|
|
|
auto &sem =
|
|
cast<mlir::FloatType>(getElementTypeOrSelf(typeB)).getFloatSemantics();
|
|
APFloat valueB(sem);
|
|
auto mulf = [&](Value x, Value y) -> Value {
|
|
return arith::MulFOp::create(b, x, y);
|
|
};
|
|
if (matchPattern(operandB, m_ConstantFloat(&valueB))) {
|
|
if (valueB.isZero()) {
|
|
// a^0 -> 1
|
|
Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter);
|
|
rewriter.replaceOp(op, one);
|
|
return success();
|
|
}
|
|
if (valueB.isExactlyValue(1.0)) {
|
|
// a^1 -> a
|
|
rewriter.replaceOp(op, operandA);
|
|
return success();
|
|
}
|
|
if (valueB.isExactlyValue(-1.0)) {
|
|
// a^(-1) -> 1 / a
|
|
Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter);
|
|
Value div = arith::DivFOp::create(b, one, operandA);
|
|
rewriter.replaceOp(op, div);
|
|
return success();
|
|
}
|
|
if (valueB.isExactlyValue(0.5)) {
|
|
// a^(1/2) -> sqrt(a)
|
|
Value sqrt = math::SqrtOp::create(b, operandA);
|
|
rewriter.replaceOp(op, sqrt);
|
|
return success();
|
|
}
|
|
if (valueB.isExactlyValue(-0.5)) {
|
|
// a^(-1/2) -> 1 / sqrt(a)
|
|
Value rsqrt = math::RsqrtOp::create(b, operandA);
|
|
rewriter.replaceOp(op, rsqrt);
|
|
return success();
|
|
}
|
|
if (valueB.isExactlyValue(2.0)) {
|
|
// a^2 -> a * a
|
|
rewriter.replaceOp(op, mulf(operandA, operandA));
|
|
return success();
|
|
}
|
|
if (valueB.isExactlyValue(-2.0)) {
|
|
// a^(-2) -> 1 / (a * a)
|
|
Value one =
|
|
createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
|
|
Value div = arith::DivFOp::create(b, one, mulf(operandA, operandA));
|
|
rewriter.replaceOp(op, div);
|
|
return success();
|
|
}
|
|
if (valueB.isExactlyValue(3.0)) {
|
|
rewriter.replaceOp(op, mulf(mulf(operandA, operandA), operandA));
|
|
return success();
|
|
}
|
|
}
|
|
|
|
Value logA = math::LogOp::create(b, operandA);
|
|
Value mult = arith::MulFOp::create(b, operandB, logA);
|
|
Value expResult = math::ExpOp::create(b, mult);
|
|
rewriter.replaceOp(op, expResult);
|
|
return success();
|
|
}
|
|
|
|
// exp2f(float x) -> exp(x * ln(2))
|
|
// Proof: Let's say 2^x = y
|
|
// ln(2^x) = ln(y)
|
|
// x * ln(2) = ln(y) => e ^(x*ln(2)) = y
|
|
static LogicalResult convertExp2fOp(math::Exp2Op op,
|
|
PatternRewriter &rewriter) {
|
|
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
|
|
Value operand = op.getOperand();
|
|
Type opType = operand.getType();
|
|
Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b);
|
|
Value mult = arith::MulFOp::create(b, opType, operand, ln2);
|
|
Value exp = math::ExpOp::create(b, op->getLoc(), mult);
|
|
rewriter.replaceOp(op, exp);
|
|
return success();
|
|
}
|
|
|
|
// Expands f32 `math.round` (halfway cases away from zero) as
// trunc(x + copysign(0.5, x)), returning x itself when its unbiased exponent
// is >= 23 (already integral, or inf/nan).
//
// NOTE(review): for x = +-0.49999997f (the largest float below 0.5),
// x + 0.5 rounds to 1.0 in f32, so this expansion yields +-1 instead of +-0.
// Confirm whether a trunc-and-compare formulation is preferable.
static LogicalResult convertRoundOp(math::RoundOp op,
                                    PatternRewriter &rewriter) {
  Location loc = op.getLoc();
  ImplicitLocOpBuilder b(loc, rewriter);
  Value operand = op.getOperand();
  Type opType = operand.getType();
  Type opEType = getElementTypeOrSelf(opType);

  if (!opEType.isF32()) {
    return rewriter.notifyMatchFailure(op, "not a round of f32.");
  }

  // createFloatConst/createIntConst materialize shaped constants as a
  // DenseElementsAttr, which is only valid for statically shaped types.
  if (auto shapedTy = dyn_cast<ShapedType>(opType);
      shapedTy && !shapedTy.hasStaticShape())
    return rewriter.notifyMatchFailure(op, "expected static shape");

  Type i32Ty = b.getI32Type();
  if (auto shapedTy = dyn_cast<ShapedType>(opType))
    i32Ty = shapedTy.clone(i32Ty);

  Value half = createFloatConst(loc, opType, 0.5, b);
  Value c23 = createIntConst(loc, i32Ty, 23, b);
  Value c127 = createIntConst(loc, i32Ty, 127, b);
  Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b);

  Value incrValue = math::CopySignOp::create(b, half, operand);
  Value add = arith::AddFOp::create(b, opType, operand, incrValue);
  Value fpFixedConvert = createTruncatedFPValue(add, b);

  // There are three cases where adding 0.5 to the value and truncating by
  // converting to an i64 does not result in the correct behavior:
  //
  // 1. Special values: +-inf and +-nan
  //     Casting these special values to i64 has undefined behavior. To identify
  //     these values, we use the fact that these values are the only float
  //     values with the maximum possible biased exponent.
  //
  // 2. Large values: 2^23 <= |x| <= INT_64_MAX
  //     Adding 0.5 to a float larger than or equal to 2^23 results in precision
  //     errors that sometimes round the value up and sometimes round the value
  //     down. For example:
  //         8388608.0 + 0.5 = 8388608.0
  //         8388609.0 + 0.5 = 8388610.0
  //
  // 3. Very large values: |x| > INT_64_MAX
  //     Casting to i64 a value greater than the max i64 value will overflow the
  //     i64 leading to wrong outputs.
  //
  // All three cases satisfy the property `biasedExp >= 23`.
  Value operandBitcast = arith::BitcastOp::create(b, i32Ty, operand);
  Value operandExp = arith::AndIOp::create(
      b, arith::ShRUIOp::create(b, operandBitcast, c23), expMask);
  Value operandBiasedExp = arith::SubIOp::create(b, operandExp, c127);
  Value isSpecialValOrLargeVal = arith::CmpIOp::create(
      b, arith::CmpIPredicate::sge, operandBiasedExp, c23);

  Value result = arith::SelectOp::create(b, isSpecialValOrLargeVal, operand,
                                         fpFixedConvert);
  rewriter.replaceOp(op, result);
  return success();
}
|
|
|
|
// Converts math.ctlz to scf and arith operations. This is done
|
|
// by performing a binary search on the bits.
|
|
static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
|
|
PatternRewriter &rewriter) {
|
|
auto operand = op.getOperand();
|
|
auto operandTy = operand.getType();
|
|
auto eTy = getElementTypeOrSelf(operandTy);
|
|
Location loc = op.getLoc();
|
|
|
|
// Only expand for integer or float element types (index has no fixed bitwidth).
|
|
if (!eTy.isIntOrFloat()) {
|
|
return rewriter.notifyMatchFailure(op, "ctlz expansion only supports int or float types");
|
|
}
|
|
|
|
int32_t bitwidth = eTy.getIntOrFloatBitWidth();
|
|
if (bitwidth > 64)
|
|
return failure();
|
|
|
|
uint64_t allbits = -1;
|
|
if (bitwidth < 64) {
|
|
allbits = allbits >> (64 - bitwidth);
|
|
}
|
|
|
|
Value x = operand;
|
|
Value count = createIntConst(loc, operandTy, 0, rewriter);
|
|
for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) {
|
|
auto half = bw / 2;
|
|
auto bits = createIntConst(loc, operandTy, half, rewriter);
|
|
auto mask = createIntConst(loc, operandTy, allbits >> half, rewriter);
|
|
|
|
Value pred = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ule,
|
|
x, mask);
|
|
Value add = arith::AddIOp::create(rewriter, loc, count, bits);
|
|
Value shift = arith::ShLIOp::create(rewriter, loc, x, bits);
|
|
|
|
x = arith::SelectOp::create(rewriter, loc, pred, shift, x);
|
|
count = arith::SelectOp::create(rewriter, loc, pred, add, count);
|
|
}
|
|
|
|
Value zero = createIntConst(loc, operandTy, 0, rewriter);
|
|
Value pred = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
|
|
operand, zero);
|
|
|
|
Value bwval = createIntConst(loc, operandTy, bitwidth, rewriter);
|
|
Value sel = arith::SelectOp::create(rewriter, loc, pred, bwval, count);
|
|
rewriter.replaceOp(op, sel);
|
|
return success();
|
|
}
|
|
|
|
// Convert `math.roundeven` into `math.round` + arith ops
|
|
static LogicalResult convertRoundEvenOp(math::RoundEvenOp op,
|
|
PatternRewriter &rewriter) {
|
|
Location loc = op.getLoc();
|
|
ImplicitLocOpBuilder b(loc, rewriter);
|
|
auto operand = op.getOperand();
|
|
Type operandTy = operand.getType();
|
|
Type resultTy = op.getType();
|
|
Type operandETy = getElementTypeOrSelf(operandTy);
|
|
Type resultETy = getElementTypeOrSelf(resultTy);
|
|
|
|
if (!isa<FloatType>(operandETy) || !isa<FloatType>(resultETy)) {
|
|
return rewriter.notifyMatchFailure(op, "not a roundeven of f16 or f32.");
|
|
}
|
|
|
|
Type fTy = operandTy;
|
|
Type iTy = rewriter.getIntegerType(operandETy.getIntOrFloatBitWidth());
|
|
if (auto shapedTy = dyn_cast<ShapedType>(fTy)) {
|
|
iTy = shapedTy.clone(iTy);
|
|
}
|
|
|
|
unsigned bitWidth = operandETy.getIntOrFloatBitWidth();
|
|
// The width returned by getFPMantissaWidth includes the integer bit.
|
|
unsigned mantissaWidth =
|
|
llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1;
|
|
unsigned exponentWidth = bitWidth - mantissaWidth - 1;
|
|
|
|
// The names of the variables correspond to f32.
|
|
// f64: 1 bit sign | 11 bits exponent | 52 bits mantissa.
|
|
// f32: 1 bit sign | 8 bits exponent | 23 bits mantissa.
|
|
// f16: 1 bit sign | 5 bits exponent | 10 bits mantissa.
|
|
Value c1Float = createFloatConst(loc, fTy, 1.0, b);
|
|
Value c0 = createIntConst(loc, iTy, 0, b);
|
|
Value c1 = createIntConst(loc, iTy, 1, b);
|
|
Value cNeg1 = createIntConst(loc, iTy, -1, b);
|
|
Value c23 = createIntConst(loc, iTy, mantissaWidth, b);
|
|
Value c31 = createIntConst(loc, iTy, bitWidth - 1, b);
|
|
Value c127 = createIntConst(loc, iTy, (1ull << (exponentWidth - 1)) - 1, b);
|
|
Value c2To22 = createIntConst(loc, iTy, 1ull << (mantissaWidth - 1), b);
|
|
Value c23Mask = createIntConst(loc, iTy, (1ull << mantissaWidth) - 1, b);
|
|
Value expMask = createIntConst(loc, iTy, (1ull << exponentWidth) - 1, b);
|
|
|
|
Value operandBitcast = arith::BitcastOp::create(b, iTy, operand);
|
|
Value round = math::RoundOp::create(b, operand);
|
|
Value roundBitcast = arith::BitcastOp::create(b, iTy, round);
|
|
|
|
// Get biased exponents for operand and round(operand)
|
|
Value operandExp = arith::AndIOp::create(
|
|
b, arith::ShRUIOp::create(b, operandBitcast, c23), expMask);
|
|
Value operandBiasedExp = arith::SubIOp::create(b, operandExp, c127);
|
|
Value roundExp = arith::AndIOp::create(
|
|
b, arith::ShRUIOp::create(b, roundBitcast, c23), expMask);
|
|
Value roundBiasedExp = arith::SubIOp::create(b, roundExp, c127);
|
|
|
|
auto safeShiftRight = [&](Value x, Value shift) -> Value {
|
|
// Clamp shift to valid range [0, bitwidth - 1] to avoid undefined behavior
|
|
Value clampedShift = arith::MaxSIOp::create(b, shift, c0);
|
|
clampedShift = arith::MinSIOp::create(b, clampedShift, c31);
|
|
return arith::ShRUIOp::create(b, x, clampedShift);
|
|
};
|
|
|
|
auto maskMantissa = [&](Value mantissa,
|
|
Value mantissaMaskRightShift) -> Value {
|
|
Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift);
|
|
return arith::AndIOp::create(b, mantissa, shiftedMantissaMask);
|
|
};
|
|
|
|
// A whole number `x`, such that `|x| != 1`, is even if the mantissa, ignoring
|
|
// the leftmost `clamp(biasedExp - 1, 0, 23)` bits, is zero. Large numbers
|
|
// with `biasedExp > 23` (numbers where there is not enough precision to store
|
|
// decimals) are always even, and they satisfy the even condition trivially
|
|
// since the mantissa without all its bits is zero. The even condition
|
|
// is also true for +-0, since they have `biasedExp = -127` and the entire
|
|
// mantissa is zero. The case of +-1 has to be handled separately. Here
|
|
// we identify these values by noting that +-1 are the only whole numbers with
|
|
// `biasedExp == 0`.
|
|
//
|
|
// The special values +-inf and +-nan also satisfy the same property that
|
|
// whole non-unit even numbers satisfy. In particular, the special values have
|
|
// `biasedExp > 23`, so they get treated as large numbers with no room for
|
|
// decimals, which are always even.
|
|
Value roundBiasedExpEq0 =
|
|
arith::CmpIOp::create(b, arith::CmpIPredicate::eq, roundBiasedExp, c0);
|
|
Value roundBiasedExpMinus1 = arith::SubIOp::create(b, roundBiasedExp, c1);
|
|
Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1);
|
|
Value roundIsNotEvenOrSpecialVal = arith::CmpIOp::create(
|
|
b, arith::CmpIPredicate::ne, roundMaskedMantissa, c0);
|
|
roundIsNotEvenOrSpecialVal =
|
|
arith::OrIOp::create(b, roundIsNotEvenOrSpecialVal, roundBiasedExpEq0);
|
|
|
|
// A value `x` with `0 <= biasedExp < 23`, is halfway between two consecutive
|
|
// integers if the bit at index `biasedExp` starting from the left in the
|
|
// mantissa is 1 and all the bits to the right are zero. Values with
|
|
// `biasedExp >= 23` don't have decimals, so they are never halfway. The
|
|
// values +-0.5 are the only halfway values that have `biasedExp == -1 < 0`,
|
|
// so these are handled separately. In particular, if `biasedExp == -1`, the
|
|
// value is halfway if the entire mantissa is zero.
|
|
Value operandBiasedExpEqNeg1 = arith::CmpIOp::create(
|
|
b, arith::CmpIPredicate::eq, operandBiasedExp, cNeg1);
|
|
Value expectedOperandMaskedMantissa = arith::SelectOp::create(
|
|
b, operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp));
|
|
Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp);
|
|
Value operandIsHalfway =
|
|
arith::CmpIOp::create(b, arith::CmpIPredicate::eq, operandMaskedMantissa,
|
|
expectedOperandMaskedMantissa);
|
|
// Ensure `biasedExp` is in the valid range for half values.
|
|
Value operandBiasedExpGeNeg1 = arith::CmpIOp::create(
|
|
b, arith::CmpIPredicate::sge, operandBiasedExp, cNeg1);
|
|
Value operandBiasedExpLt23 = arith::CmpIOp::create(
|
|
b, arith::CmpIPredicate::slt, operandBiasedExp, c23);
|
|
operandIsHalfway =
|
|
arith::AndIOp::create(b, operandIsHalfway, operandBiasedExpLt23);
|
|
operandIsHalfway =
|
|
arith::AndIOp::create(b, operandIsHalfway, operandBiasedExpGeNeg1);
|
|
|
|
// Adjust rounded operand with `round(operand) - sign(operand)` to correct the
|
|
// case where `round` rounded in the opposite direction of `roundeven`.
|
|
Value sign = math::CopySignOp::create(b, c1Float, operand);
|
|
Value roundShifted = arith::SubFOp::create(b, round, sign);
|
|
// If the rounded value is even or a special value, we default to the behavior
|
|
// of `math.round`.
|
|
Value needsShift =
|
|
arith::AndIOp::create(b, roundIsNotEvenOrSpecialVal, operandIsHalfway);
|
|
Value result = arith::SelectOp::create(b, needsShift, roundShifted, round);
|
|
// The `x - sign` adjustment does not preserve the sign when we are adjusting
|
|
// the value -1 to -0. So here the sign is copied again to ensure that -0.5 is
|
|
// rounded to -0.0.
|
|
result = math::CopySignOp::create(b, result, operand);
|
|
rewriter.replaceOp(op, result);
|
|
return success();
|
|
}
|
|
|
|
// Convert `math.rsqrt` into `arith.divf` + `math.sqrt`
|
|
static LogicalResult convertRsqrtOp(math::RsqrtOp op,
|
|
PatternRewriter &rewriter) {
|
|
|
|
auto operand = op.getOperand();
|
|
auto operandTy = operand.getType();
|
|
// Operand type must be shatic shaped type to create const float.
|
|
auto shapedOperandType = dyn_cast<ShapedType>(operandTy);
|
|
if (shapedOperandType && !shapedOperandType.hasStaticShape())
|
|
return failure();
|
|
|
|
auto eTy = getElementTypeOrSelf(operandTy);
|
|
if (!isa<FloatType>(eTy))
|
|
return failure();
|
|
|
|
Location loc = op->getLoc();
|
|
auto constOneFloat = createFloatConst(loc, operandTy, 1.0, rewriter);
|
|
auto sqrtOp = math::SqrtOp::create(rewriter, loc, operand);
|
|
rewriter.replaceOpWithNewOp<arith::DivFOp>(op, constOneFloat, sqrtOp);
|
|
return success();
|
|
}
|
|
|
|
// Convert `math.clampf` into `arith.minimumf` + `arith.maximumf`
|
|
static LogicalResult convertClampfOp(math::ClampFOp op,
|
|
PatternRewriter &rewriter) {
|
|
auto minOp = arith::MinimumFOp::create(rewriter, op.getLoc(), op.getValue(),
|
|
op.getMax(), op.getFastmath());
|
|
rewriter.replaceOpWithNewOp<arith::MaximumFOp>(op, minOp, op.getMin(),
|
|
op.getFastmath());
|
|
return success();
|
|
}
|
|
|
|
void mlir::math::populateExpansionPatterns(RewritePatternSet &patterns,
                                           ArrayRef<StringRef> opMnemonics) {
  // True when `opName` ("math.<mnemonic>") was requested, or when no
  // mnemonics were given at all.
  auto isRequested = [&](StringRef opName) {
    // This should be a static assert and `consume_front` take a twine, but none
    // is currently possible. TODO: augment `StringRef::consume_front` and make
    // `getDialectNamespace` use `std::string_view`.
    assert("math" == MathDialect::getDialectNamespace());
    if (opMnemonics.empty())
      return true;
    opName.consume_front("math.");
    return llvm::is_contained(opMnemonics, opName);
  };

  // Registers the rewrite function `fn` only when its op was requested.
  // Registration order is kept stable across all ops.
  auto addIfRequested = [&](StringRef opName, auto fn) {
    if (isRequested(opName))
      patterns.add(fn);
  };

  addIfRequested(CountLeadingZerosOp::getOperationName(), convertCtlzOp);
  addIfRequested(SinhOp::getOperationName(), convertSinhOp);
  addIfRequested(CoshOp::getOperationName(), convertCoshOp);
  addIfRequested(TanOp::getOperationName(), convertTanOp);
  addIfRequested(TanhOp::getOperationName(), convertTanhOp);
  addIfRequested(AsinhOp::getOperationName(), convertAsinhOp);
  addIfRequested(AcoshOp::getOperationName(), convertAcoshOp);
  addIfRequested(AtanhOp::getOperationName(), convertAtanhOp);
  addIfRequested(FmaOp::getOperationName(), convertFmaFOp);
  addIfRequested(CeilOp::getOperationName(), convertCeilOp);
  addIfRequested(Exp2Op::getOperationName(), convertExp2fOp);
  addIfRequested(PowFOp::getOperationName(), convertPowfOp);
  addIfRequested(FPowIOp::getOperationName(), convertFPowIOp);
  addIfRequested(RoundOp::getOperationName(), convertRoundOp);
  addIfRequested(RoundEvenOp::getOperationName(), convertRoundEvenOp);
  addIfRequested(RsqrtOp::getOperationName(), convertRsqrtOp);
  addIfRequested(ClampFOp::getOperationName(), convertClampfOp);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MathExpandOpsPass pass
|
|
//===----------------------------------------------------------------------===//
|
|
namespace {
|
|
struct MathExpandOpsPass final
|
|
: math::impl::MathExpandOpsPassBase<MathExpandOpsPass> {
|
|
using MathExpandOpsPassBase::MathExpandOpsPassBase;
|
|
|
|
void runOnOperation() override {
|
|
RewritePatternSet patterns(&getContext());
|
|
SmallVector<StringRef> mnemonics =
|
|
llvm::to_vector_of<StringRef>(opMnemonics);
|
|
math::populateExpansionPatterns(patterns, mnemonics);
|
|
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
|
|
return signalPassFailure();
|
|
}
|
|
};
|
|
} // namespace
|