//===- ArithToAPFloat.cpp - Arithmetic to APFloat Conversion --------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "Utils.h" #include "mlir/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Utils/Utils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Verifier.h" #include "mlir/Transforms/WalkPatternRewriteDriver.h" namespace mlir { #define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS #include "mlir/Conversion/Passes.h.inc" } // namespace mlir using namespace mlir; using namespace mlir::func; /// Helper function to look up or create the symbol for a runtime library /// function for a binary arithmetic operation. /// /// Parameter 1: APFloat semantics /// Parameter 2: Left-hand side operand /// Parameter 3: Right-hand side operand /// /// This function will return a failure if the function is found but has an /// unexpected signature. /// static FailureOr lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name, SymbolTableCollection *symbolTables = nullptr) { auto i32Type = IntegerType::get(symTable->getContext(), 32); auto i64Type = IntegerType::get(symTable->getContext(), 64); std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str(); return lookupOrCreateFnDecl(b, symTable, funcName, {i32Type, i64Type, i64Type}, symbolTables); } /// Rewrite a binary arithmetic operation to an APFloat function call. template struct BinaryArithOpToAPFloatConversion final : OpRewritePattern { BinaryArithOpToAPFloatConversion(MLIRContext *context, const char *APFloatName, SymbolOpInterface symTable, PatternBenefit benefit = 1) : OpRewritePattern(context, benefit), symTable(symTable), APFloatName(APFloatName) {}; LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { if (failed(checkPreconditions(rewriter, op))) return failure(); // Get APFloat function from runtime library. FailureOr fn = lookupOrCreateBinaryFn(rewriter, symTable, APFloatName); if (failed(fn)) return fn; // Scalarize and convert to APFloat runtime calls. Location loc = op.getLoc(); rewriter.setInsertionPoint(op); Value repl = forEachScalarValue( rewriter, loc, op.getLhs(), op.getRhs(), op.getType(), [&](Value lhs, Value rhs, Type resultType) { // Cast operands to 64-bit integers. auto floatTy = cast(resultType); auto intWType = rewriter.getIntegerType(floatTy.getWidth()); auto int64Type = rewriter.getI64Type(); Value lhsBits = arith::ExtUIOp::create( rewriter, loc, int64Type, arith::BitcastOp::create(rewriter, loc, intWType, lhs)); Value rhsBits = arith::ExtUIOp::create( rewriter, loc, int64Type, arith::BitcastOp::create(rewriter, loc, intWType, rhs)); // Call APFloat function. Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy); SmallVector params = {semValue, lhsBits, rhsBits}; auto resultOp = func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()), SymbolRefAttr::get(*fn), params); // Truncate result to the original width. Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType, resultOp->getResult(0)); return arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits); }); rewriter.replaceOp(op, repl); return success(); } SymbolOpInterface symTable; const char *APFloatName; }; template struct FpToFpConversion final : OpRewritePattern { FpToFpConversion(MLIRContext *context, SymbolOpInterface symTable, PatternBenefit benefit = 1) : OpRewritePattern(context, benefit), symTable(symTable) {} LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { if (failed(checkPreconditions(rewriter, op))) return failure(); // Get APFloat function from runtime library. auto i32Type = IntegerType::get(symTable->getContext(), 32); auto i64Type = IntegerType::get(symTable->getContext(), 64); FailureOr fn = lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_convert", {i32Type, i32Type, i64Type}); if (failed(fn)) return fn; // Scalarize and convert to APFloat runtime calls. Location loc = op.getLoc(); rewriter.setInsertionPoint(op); Value repl = forEachScalarValue( rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(), [&](Value operand1, Value operand2, Type resultType) { // Cast operands to 64-bit integers. auto inFloatTy = cast(operand1.getType()); auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth()); Value operandBits = arith::ExtUIOp::create( rewriter, loc, i64Type, arith::BitcastOp::create(rewriter, loc, inIntWType, operand1)); // Call APFloat function. Value inSemValue = getAPFloatSemanticsValue(rewriter, loc, inFloatTy); auto outFloatTy = cast(resultType); Value outSemValue = getAPFloatSemanticsValue(rewriter, loc, outFloatTy); std::array params = {inSemValue, outSemValue, operandBits}; auto resultOp = func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()), SymbolRefAttr::get(*fn), params); // Truncate result to the original width. auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth()); Value truncatedBits = arith::TruncIOp::create( rewriter, loc, outIntWType, resultOp->getResult(0)); return arith::BitcastOp::create(rewriter, loc, outFloatTy, truncatedBits); }); rewriter.replaceOp(op, repl); return success(); } SymbolOpInterface symTable; }; template struct FpToIntConversion final : OpRewritePattern { FpToIntConversion(MLIRContext *context, SymbolOpInterface symTable, bool isUnsigned, PatternBenefit benefit = 1) : OpRewritePattern(context, benefit), symTable(symTable), isUnsigned(isUnsigned) {} LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { if (failed(checkPreconditions(rewriter, op))) return failure(); // Get APFloat function from runtime library. auto i1Type = IntegerType::get(symTable->getContext(), 1); auto i32Type = IntegerType::get(symTable->getContext(), 32); auto i64Type = IntegerType::get(symTable->getContext(), 64); FailureOr fn = lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_convert_to_int", {i32Type, i32Type, i1Type, i64Type}); if (failed(fn)) return fn; // Scalarize and convert to APFloat runtime calls. Location loc = op.getLoc(); rewriter.setInsertionPoint(op); Value repl = forEachScalarValue( rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(), [&](Value operand1, Value operand2, Type resultType) { // Cast operands to 64-bit integers. auto inFloatTy = cast(operand1.getType()); auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth()); Value operandBits = arith::ExtUIOp::create( rewriter, loc, i64Type, arith::BitcastOp::create(rewriter, loc, inIntWType, operand1)); // Call APFloat function. Value inSemValue = getAPFloatSemanticsValue(rewriter, loc, inFloatTy); auto outIntTy = cast(resultType); Value outWidthValue = arith::ConstantOp::create( rewriter, loc, i32Type, rewriter.getIntegerAttr(i32Type, outIntTy.getWidth())); Value isUnsignedValue = arith::ConstantOp::create( rewriter, loc, i1Type, rewriter.getIntegerAttr(i1Type, isUnsigned)); SmallVector params = {inSemValue, outWidthValue, isUnsignedValue, operandBits}; auto resultOp = func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()), SymbolRefAttr::get(*fn), params); // Truncate result to the original width. return arith::TruncIOp::create(rewriter, loc, outIntTy, resultOp->getResult(0)); }); rewriter.replaceOp(op, repl); return success(); } SymbolOpInterface symTable; bool isUnsigned; }; template struct IntToFpConversion final : OpRewritePattern { IntToFpConversion(MLIRContext *context, SymbolOpInterface symTable, bool isUnsigned, PatternBenefit benefit = 1) : OpRewritePattern(context, benefit), symTable(symTable), isUnsigned(isUnsigned) {} LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { if (failed(checkPreconditions(rewriter, op))) return failure(); // Get APFloat function from runtime library. auto i1Type = IntegerType::get(symTable->getContext(), 1); auto i32Type = IntegerType::get(symTable->getContext(), 32); auto i64Type = IntegerType::get(symTable->getContext(), 64); FailureOr fn = lookupOrCreateFnDecl( rewriter, symTable, "_mlir_apfloat_convert_from_int", {i32Type, i32Type, i1Type, i64Type}); if (failed(fn)) return fn; // Scalarize and convert to APFloat runtime calls. Location loc = op.getLoc(); rewriter.setInsertionPoint(op); Value repl = forEachScalarValue( rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(), [&](Value operand1, Value operand2, Type resultType) { // Cast operands to 64-bit integers. auto inIntTy = cast(operand1.getType()); Value operandBits = operand1; if (operandBits.getType().getIntOrFloatBitWidth() < 64) { if (isUnsigned) { operandBits = arith::ExtUIOp::create(rewriter, loc, i64Type, operandBits); } else { operandBits = arith::ExtSIOp::create(rewriter, loc, i64Type, operandBits); } } // Call APFloat function. auto outFloatTy = cast(resultType); Value outSemValue = getAPFloatSemanticsValue(rewriter, loc, outFloatTy); Value inWidthValue = arith::ConstantOp::create( rewriter, loc, i32Type, rewriter.getIntegerAttr(i32Type, inIntTy.getWidth())); Value isUnsignedValue = arith::ConstantOp::create( rewriter, loc, i1Type, rewriter.getIntegerAttr(i1Type, isUnsigned)); SmallVector params = {outSemValue, inWidthValue, isUnsignedValue, operandBits}; auto resultOp = func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()), SymbolRefAttr::get(*fn), params); // Truncate result to the original width. auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth()); Value truncatedBits = arith::TruncIOp::create( rewriter, loc, outIntWType, resultOp->getResult(0)); return arith::BitcastOp::create(rewriter, loc, outFloatTy, truncatedBits); }); rewriter.replaceOp(op, repl); return success(); } SymbolOpInterface symTable; bool isUnsigned; }; struct CmpFOpToAPFloatConversion final : OpRewritePattern { CmpFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable, PatternBenefit benefit = 1) : OpRewritePattern(context, benefit), symTable(symTable) {} LogicalResult matchAndRewrite(arith::CmpFOp op, PatternRewriter &rewriter) const override { if (failed(checkPreconditions(rewriter, op))) return failure(); // Get APFloat function from runtime library. auto i1Type = IntegerType::get(symTable->getContext(), 1); auto i8Type = IntegerType::get(symTable->getContext(), 8); auto i32Type = IntegerType::get(symTable->getContext(), 32); auto i64Type = IntegerType::get(symTable->getContext(), 64); FailureOr fn = lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_compare", {i32Type, i64Type, i64Type}, nullptr, i8Type); if (failed(fn)) return fn; // Scalarize and convert to APFloat runtime calls. Location loc = op.getLoc(); rewriter.setInsertionPoint(op); Value repl = forEachScalarValue( rewriter, loc, op.getLhs(), op.getRhs(), op.getType(), [&](Value lhs, Value rhs, Type resultType) { // Cast operands to 64-bit integers. auto floatTy = cast(lhs.getType()); auto intWType = rewriter.getIntegerType(floatTy.getWidth()); Value lhsBits = arith::ExtUIOp::create( rewriter, loc, i64Type, arith::BitcastOp::create(rewriter, loc, intWType, lhs)); Value rhsBits = arith::ExtUIOp::create( rewriter, loc, i64Type, arith::BitcastOp::create(rewriter, loc, intWType, rhs)); // Call APFloat function. Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy); SmallVector params = {semValue, lhsBits, rhsBits}; Value comparisonResult = func::CallOp::create(rewriter, loc, TypeRange(i8Type), SymbolRefAttr::get(*fn), params) ->getResult(0); // Generate an i1 SSA value that is "true" if the comparison result // matches the given `val`. auto checkResult = [&](llvm::APFloat::cmpResult val) { return arith::CmpIOp::create( rewriter, loc, arith::CmpIPredicate::eq, comparisonResult, arith::ConstantOp::create( rewriter, loc, i8Type, rewriter.getIntegerAttr(i8Type, static_cast(val))) .getResult()); }; // Generate an i1 SSA value that is "true" if the comparison result // matches any of the given `vals`. std::function)> checkResults = [&](ArrayRef vals) { Value first = checkResult(vals.front()); if (vals.size() == 1) return first; Value rest = checkResults(vals.drop_front()); return arith::OrIOp::create(rewriter, loc, first, rest) .getResult(); }; // This switch-case statement was taken from arith::applyCmpPredicate. Value result; switch (op.getPredicate()) { case arith::CmpFPredicate::AlwaysFalse: result = arith::ConstantOp::create(rewriter, loc, i1Type, rewriter.getIntegerAttr(i1Type, 0)) .getResult(); break; case arith::CmpFPredicate::OEQ: result = checkResult(llvm::APFloat::cmpEqual); break; case arith::CmpFPredicate::OGT: result = checkResult(llvm::APFloat::cmpGreaterThan); break; case arith::CmpFPredicate::OGE: result = checkResults( {llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual}); break; case arith::CmpFPredicate::OLT: result = checkResult(llvm::APFloat::cmpLessThan); break; case arith::CmpFPredicate::OLE: result = checkResults( {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpEqual}); break; case arith::CmpFPredicate::ONE: // Not cmpUnordered and not cmpUnordered. result = checkResults( {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpGreaterThan}); break; case arith::CmpFPredicate::ORD: // Not cmpUnordered. result = checkResults({llvm::APFloat::cmpLessThan, llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual}); break; case arith::CmpFPredicate::UEQ: result = checkResults( {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpEqual}); break; case arith::CmpFPredicate::UGT: result = checkResults( {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpGreaterThan}); break; case arith::CmpFPredicate::UGE: result = checkResults({llvm::APFloat::cmpUnordered, llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual}); break; case arith::CmpFPredicate::ULT: result = checkResults( {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan}); break; case arith::CmpFPredicate::ULE: result = checkResults({llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan, llvm::APFloat::cmpEqual}); break; case arith::CmpFPredicate::UNE: // Not cmpEqual. result = checkResults({llvm::APFloat::cmpLessThan, llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpUnordered}); break; case arith::CmpFPredicate::UNO: result = checkResult(llvm::APFloat::cmpUnordered); break; case arith::CmpFPredicate::AlwaysTrue: result = arith::ConstantOp::create(rewriter, loc, i1Type, rewriter.getIntegerAttr(i1Type, 1)) .getResult(); break; } return result; }); rewriter.replaceOp(op, repl); return success(); } SymbolOpInterface symTable; }; /// Rewrite a unary floating-point op (same input/output float type) to an /// APFloat runtime call of the form `(i32 semantics, i64 bits) -> i64 bits`. template struct UnaryFloatOpToAPFloatConversion final : OpRewritePattern { UnaryFloatOpToAPFloatConversion(MLIRContext *context, const char *APFloatName, SymbolOpInterface symTable, PatternBenefit benefit = 1) : OpRewritePattern(context, benefit), symTable(symTable), APFloatName(APFloatName) {} LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { if (failed(checkPreconditions(rewriter, op))) return failure(); // Get APFloat function from runtime library. auto i32Type = IntegerType::get(symTable->getContext(), 32); auto i64Type = IntegerType::get(symTable->getContext(), 64); std::string funcName = (llvm::Twine("_mlir_apfloat_") + APFloatName).str(); FailureOr fn = lookupOrCreateFnDecl(rewriter, symTable, funcName, {i32Type, i64Type}); if (failed(fn)) return fn; // Scalarize and convert to APFloat runtime calls. Location loc = op.getLoc(); rewriter.setInsertionPoint(op); Value repl = forEachScalarValue( rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(), [&](Value operand1, Value operand2, Type resultType) { // Cast operands to 64-bit integers. auto floatTy = cast(operand1.getType()); auto intWType = rewriter.getIntegerType(floatTy.getWidth()); Value operandBits = arith::ExtUIOp::create( rewriter, loc, i64Type, arith::BitcastOp::create(rewriter, loc, intWType, operand1)); // Call APFloat function. Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy); SmallVector params = {semValue, operandBits}; Value resultBits = func::CallOp::create(rewriter, loc, TypeRange(i64Type), SymbolRefAttr::get(*fn), params) ->getResult(0); // Truncate result to the original width. Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType, resultBits); return arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits); }); rewriter.replaceOp(op, repl); return success(); } SymbolOpInterface symTable; const char *APFloatName; }; namespace { struct ArithToAPFloatConversionPass final : impl::ArithToAPFloatConversionPassBase { using Base::Base; void runOnOperation() override; }; void ArithToAPFloatConversionPass::runOnOperation() { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); patterns.add>(context, "add", getOperation()); patterns.add>( context, "subtract", getOperation()); patterns.add>( context, "multiply", getOperation()); patterns.add>( context, "divide", getOperation()); patterns.add>( context, "remainder", getOperation()); patterns.add>( context, "minnum", getOperation()); patterns.add>( context, "maxnum", getOperation()); patterns.add>( context, "minimum", getOperation()); patterns.add>( context, "maximum", getOperation()); patterns.add, FpToFpConversion, CmpFOpToAPFloatConversion>( context, getOperation()); patterns.add>(context, "neg", getOperation()); patterns.add>( context, "flush_denormals", getOperation()); patterns.add>(context, getOperation(), /*isUnsigned=*/false); patterns.add>(context, getOperation(), /*isUnsigned=*/true); patterns.add>(context, getOperation(), /*isUnsigned=*/false); patterns.add>(context, getOperation(), /*isUnsigned=*/true); LogicalResult result = success(); ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) { if (diag.getSeverity() == DiagnosticSeverity::Error) { result = failure(); } // NB: if you don't return failure, no other diag handlers will fire (see // mlir/lib/IR/Diagnostics.cpp:DiagnosticEngineImpl::emit). return failure(); }); walkAndApplyPatterns(getOperation(), std::move(patterns)); if (failed(result)) return signalPassFailure(); } } // namespace