llvm-project/mlir/test/Dialect/common_folders.mlir
Matthias Guenther de2bac367f [MLIR] Allow constFoldBinaryOp to fold (T1, T1) -> T2 (#151410)
The `constFoldBinaryOp` helper function had limited support for
different input and output types, but the static type of the underlying
value (e.g. `APInt`) had to match between the inputs and the output.

This worked fine for int comparisons of the form `(intN, intN) -> int1`,
as the static type signature was `(APInt, APInt) -> APInt`. However,
float comparisons map `(floatN, floatN) -> int1`, with a static type
signature of `(APFloat, APFloat) -> APInt`. This use case wasn't
supported by `constFoldBinaryOp`.

`constFoldBinaryOp` now accepts an optional template argument overriding
the return type in case it differs from the input type. If the new
template argument isn't provided, the default behavior is unchanged
(i.e. the return type will be assumed to match the input type).
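To make the default-argument mechanism concrete, here is a minimal C++ sketch of an element-wise binary folder with an optional result-type template parameter. It is not MLIR's `constFoldBinaryOp`. The names `foldBinary`, `foldLessThan` and `foldAdd` are hypothetical, `std::vector` stands in for a dense constant attribute, and the plain scalars `double`, `long` and `bool` stand in for `APFloat` and `APInt`. `foldLessThan` mirrors the `test.less_than` case in the test file and passes `bool` explicitly because the result type differs from the operand type. `foldAdd` omits it and gets the old `(T, T) -> T` behavior.

```cpp
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical, simplified model of a binary constant folder (not the MLIR
// signature). ResultT defaults to ElemT, so callers folding (T, T) -> T
// never need to name it.
template <class ElemT, class ResultT = ElemT, class CalcT>
std::optional<std::vector<ResultT>> foldBinary(const std::vector<ElemT> &lhs,
                                               const std::vector<ElemT> &rhs,
                                               CalcT &&calculate) {
  if (lhs.size() != rhs.size())
    return std::nullopt; // refuse to fold operands of different shapes
  std::vector<ResultT> result;
  result.reserve(lhs.size());
  for (std::size_t i = 0; i < lhs.size(); ++i) {
    std::optional<ResultT> folded = calculate(lhs[i], rhs[i]);
    if (!folded)
      return std::nullopt; // the calculation declined, so nothing is folded
    result.push_back(*folded);
  }
  return result;
}

// (float, float) -> bool: the result type differs from the operand type,
// so it is supplied as the second template argument.
std::optional<std::vector<bool>> foldLessThan(const std::vector<double> &lhs,
                                              const std::vector<double> &rhs) {
  return foldBinary<double, bool>(
      lhs, rhs, [](double a, double b) -> std::optional<bool> { return a < b; });
}

// (int, int) -> int: the default case, where the result type is the operand type.
std::optional<std::vector<long>> foldAdd(const std::vector<long> &lhs,
                                         const std::vector<long> &rhs) {
  return foldBinary(
      lhs, rhs, [](long a, long b) -> std::optional<long> { return a + b; });
}
```

Because `ResultT` appears only in the return type, it can never be deduced from the call arguments, so leaving it out always selects the default. That is what keeps existing call sites unchanged.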

`constFoldUnaryOp` received similar changes in order to support folding
non-cast ops of the form `(T1) -> T2` (e.g. a `sign` op mapping
`(floatN) -> sint32`).
2025-08-07 17:52:03 +02:00
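The unary change described in the commit message can be sketched the same way. Again this is a hypothetical model rather than MLIR's `constFoldUnaryOp`. `foldUnary` and `foldSign` are illustrative names, and refusing to fold a NaN operand is an assumption of this sketch, not documented behavior of the `test.sign` op. `foldSign` maps `float` to `std::int32_t`, matching the `(floatN) -> sint32` example from the commit message.

```cpp
#include <cmath>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical, simplified model of a unary constant folder (not the MLIR
// signature). ResultT defaults to ElemT, as in the binary case.
template <class ElemT, class ResultT = ElemT, class CalcT>
std::optional<std::vector<ResultT>> foldUnary(const std::vector<ElemT> &operand,
                                              CalcT &&calculate) {
  std::vector<ResultT> result;
  result.reserve(operand.size());
  for (const ElemT &value : operand) {
    std::optional<ResultT> folded = calculate(value);
    if (!folded)
      return std::nullopt; // the calculation declined, so nothing is folded
    result.push_back(*folded);
  }
  return result;
}

// A sign-like op mapping float -> int32 (-1, 0 or +1). NaN has no sign, so
// this sketch declines to fold it (an assumption, see the lead-in).
std::optional<std::vector<std::int32_t>> foldSign(const std::vector<float> &operand) {
  return foldUnary<float, std::int32_t>(
      operand, [](float x) -> std::optional<std::int32_t> {
        if (std::isnan(x))
          return std::nullopt;
        return (x > 0.0f) - (x < 0.0f); // -0.0f and 0.0f both yield 0
      });
}
```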


// RUN: mlir-opt %s --test-fold-type-converting-op --split-input-file | FileCheck %s

// CHECK-LABEL: @test_fold_unary_op_f32_to_si32(
func.func @test_fold_unary_op_f32_to_si32() -> tensor<4x2xsi32> {
  // CHECK-NEXT: %[[POSITIVE_ONE:.*]] = arith.constant dense<1> : tensor<4x2xsi32>
  // CHECK-NEXT: return %[[POSITIVE_ONE]] : tensor<4x2xsi32>
  %operand = arith.constant dense<5.1> : tensor<4x2xf32>
  %sign = test.sign %operand : (tensor<4x2xf32>) -> tensor<4x2xsi32>
  return %sign : tensor<4x2xsi32>
}

// -----

// CHECK-LABEL: @test_fold_binary_op_f32_to_i1(
func.func @test_fold_binary_op_f32_to_i1() -> tensor<8xi1> {
  // CHECK-NEXT: %[[FALSE:.*]] = arith.constant dense<false> : tensor<8xi1>
  // CHECK-NEXT: return %[[FALSE]] : tensor<8xi1>
  %lhs = arith.constant dense<5.1> : tensor<8xf32>
  %rhs = arith.constant dense<4.2> : tensor<8xf32>
  %less_than = test.less_than %lhs, %rhs : (tensor<8xf32>, tensor<8xf32>) -> tensor<8xi1>
  return %less_than : tensor<8xi1>
}