[mlir][linalg] Restrict fill initial value type to output element type (#169567)

Disallow implicit casting, which is surprising, and, IME, usually
indicative of copy-paste errors.

Because the initial value must be a scalar, I don't expect this to
affect any data movement.
This commit is contained in:
Jakub Kuderski
2025-11-30 09:51:37 -05:00
committed by GitHub
parent b228256312
commit 0bd2f12753
11 changed files with 81 additions and 75 deletions

View File

@@ -311,16 +311,17 @@ An example for a rank polymorphic operation is `fill`:
```python
@linalg_structured_op
def fill(value=ScalarDef(T1),
O=TensorDef(U, output=True)):
O[None] = TypeFn.cast_signed(U, value)
def fill(value=ScalarDef(T),
O=TensorDef(T, output=True)):
O[None] = value
```
The operation sets the elements of the output tensor `O` to `value`. All
operands are either scalars or rank zero tensors that are accessed using the
index `None`. The operation thus performs a scalar computation that trivially
extends to a multi-dimensional pointwise computation. As a result, we may use
`fill` with arbitrary ranked output tensors:
The operation sets the elements of the output tensor `O` to `value`. The value
type must match the element type of the output tensor. All operands are either
scalars or rank zero tensors that are accessed using the index `None`. The
operation thus performs a scalar computation that trivially extends to a
multi-dimensional pointwise computation. As a result, we may use `fill` with
arbitrary ranked output tensors:
```python
tensor_2d = tensor.EmptyOp([4, 8], f32)

View File

@@ -6054,9 +6054,9 @@ metadata: !LinalgOpMetadata
doc: |-
Fills the output tensor with the given value.
Works for arbitrary ranked output tensors since the operation performs scalar
accesses only and is thus rank polymorphic. Numeric casting is performed on
the value operand, promoting it to the same data type as the output.
Works for arbitrary ranked output tensors since the operation performs
scalar accesses only and is thus rank polymorphic. The value operand
type must match the element type of the output.
implements:
- LinalgFillOpInterface
defines:
@@ -6066,11 +6066,11 @@ structured_op: !LinalgStructuredOpConfig
- !LinalgOperandDefConfig
name: value
kind: scalar
type_var: T1
type_var: T
- !LinalgOperandDefConfig
name: O
kind: output_tensor
type_var: U
type_var: T
shape_map: affine_map<() -> ()>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
@@ -6081,13 +6081,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_fn:
kind: type
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
scalar_arg: value
scalar_arg: value
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: fill_rng_2d

View File

@@ -1057,12 +1057,15 @@ LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
// FillOpInterface implementation
//===----------------------------------------------------------------------===//
namespace {
enum class MatchFillResult {
Success = 0,
NotLinalgOp,
WrongNumOperands,
NotScalarInput
NotScalarInput,
TypeMismatch
};
} // namespace
static MatchFillResult isFillInterfaceImpl(Operation *op) {
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
@@ -1075,17 +1078,33 @@ static MatchFillResult isFillInterfaceImpl(Operation *op) {
if (!linalgOp.isScalar(value))
return MatchFillResult::NotScalarInput;
// Check that the scalar input type matches the output element type.
OpOperand *output = linalgOp.getDpsInitOperand(0);
Type scalarType = value->get().getType();
Type outputElementType = getElementTypeOrSelf(output->get().getType());
if (scalarType != outputElementType)
return MatchFillResult::TypeMismatch;
return MatchFillResult::Success;
}
LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
auto res = isFillInterfaceImpl(op);
MatchFillResult res = isFillInterfaceImpl(op);
if (res == MatchFillResult::NotLinalgOp)
return op->emitError("expected a LinalgOp");
if (res == MatchFillResult::WrongNumOperands)
return op->emitError("expected op with 1 input and 1 output");
if (res == MatchFillResult::NotScalarInput)
return op->emitError("expected op with scalar input");
if (res == MatchFillResult::TypeMismatch) {
auto linalgOp = cast<linalg::LinalgOp>(op);
Type scalarType = linalgOp.getDpsInputOperand(0)->get().getType();
Type outputElementType =
getElementTypeOrSelf(linalgOp.getDpsInitOperand(0)->get().getType());
return op->emitOpError("expected fill value type (")
<< scalarType << ") to match output element type ("
<< outputElementType << ")";
}
return success();
}

View File

@@ -1729,16 +1729,16 @@ def pooling_ndhwc_min(
@linalg_structured_op
def fill(value=ScalarDef(T1), O=TensorDef(U, output=True)):
def fill(value=ScalarDef(T), O=TensorDef(T, output=True)):
"""Fills the output tensor with the given value.
Works for arbitrary ranked output tensors since the operation performs scalar
accesses only and is thus rank polymorphic. Numeric casting is performed on
the value operand, promoting it to the same data type as the output.
accesses only and is thus rank polymorphic. The value type must match the
element type of the output tensor or memref.
"""
implements(FillOpInterface)
defines(Canonicalizer)
O[None] = TypeFn.cast_signed(U, value)
O[None] = value
@linalg_structured_op

View File

@@ -36,13 +36,13 @@ func.func @reify_through_chain(%sz0: index, %sz2: index) -> (index, index, index
// CHECK: "test.some_use"(%[[c5]])
// CHECK: %[[c5:.*]] = arith.constant 5 : index
// CHECK: "test.some_use"(%[[c5]])
func.func @reify_slice_bound(%t: tensor<?x?xi32>, %idx: index, %ub: index, %f: f32) {
func.func @reify_slice_bound(%t: tensor<?x?xi32>, %idx: index, %ub: index, %f: i32) {
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
scf.for %iv = %c0 to %ub step %c4 {
%sz = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 4)>(%iv)[%ub]
%slice = tensor.extract_slice %t[%idx, %iv] [1, %sz] [1, 1] : tensor<?x?xi32> to tensor<1x?xi32>
%filled = linalg.fill ins(%f : f32) outs(%slice : tensor<1x?xi32>) -> tensor<1x?xi32>
%filled = linalg.fill ins(%f : i32) outs(%slice : tensor<1x?xi32>) -> tensor<1x?xi32>
%bound = "test.reify_bound"(%filled) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index)
"test.some_use"(%bound) : (index) -> ()

View File

@@ -921,30 +921,6 @@ func.func @fold_fill_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
// -----
// CHECK-LABEL: func @fold_fill_generic_different_dtype
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf16>) -> tensor<?xf16> {
// CHECK-NOT: linalg.fill
// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf16>)
// CHECK-SAME: outs({{.*}} : tensor<?xf16>) {
#map0 = affine_map<(d0) -> (d0)>
func.func @fold_fill_generic_different_dtype(%arg0: tensor<?xf16>) -> (tensor<?xf16>) {
%c0 = arith.constant 0 : index
%cst = arith.constant 7.0 : f32
%0 = tensor.dim %arg0, %c0 : tensor<?xf16>
%1 = tensor.empty(%0) : tensor<?xf16>
%2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?xf16>) -> tensor<?xf16>
%3 = tensor.empty(%0) : tensor<?xf16>
%4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor<?xf16>, tensor<?xf16>) outs (%3:tensor<?xf16>) {
^bb0(%arg1: f16, %arg2: f16, %arg3: f16):
%5 = arith.addf %arg1, %arg2 : f16
linalg.yield %5 : f16
} -> tensor<?xf16>
return %4 : tensor<?xf16>
}
// -----
// CHECK-LABEL: func @fold_fill_generic_mixedaccess
// CHECK-NOT: linalg.fill
// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
@@ -1079,4 +1055,4 @@ module {
// CHECK-NOT: linalg.generic
// CHECK: tensor.expand_shape
// CHECK: linalg.generic {{.*}}, iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]}
// CHECK-SAME: ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)
// CHECK-SAME: ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)

View File

@@ -380,8 +380,8 @@ func.func @generalize_pooling_nwc_sum_i32(%input : tensor<1x16x1xi32>, %shape: t
// -----
func.func @generalize_fill_0d(%value: f64, %O: tensor<f32>) -> tensor<f32> {
%0 = linalg.fill ins(%value: f64) outs(%O : tensor<f32>) -> tensor<f32>
func.func @generalize_fill_0d(%value: f32, %O: tensor<f32>) -> tensor<f32> {
%0 = linalg.fill ins(%value: f32) outs(%O : tensor<f32>) -> tensor<f32>
return %0: tensor<f32>
}
@@ -394,8 +394,8 @@ func.func @generalize_fill_0d(%value: f64, %O: tensor<f32>) -> tensor<f32> {
// -----
func.func @generalize_fill_2d(%value: f64, %O: memref<16x32xf32>) {
linalg.fill ins(%value: f64) outs(%O : memref<16x32xf32>)
func.func @generalize_fill_2d(%value: f32, %O: memref<16x32xf32>) {
linalg.fill ins(%value: f32) outs(%O : memref<16x32xf32>)
return
}

View File

@@ -352,6 +352,24 @@ func.func @illegal_fill_tensor_with_memref_return
// -----
func.func @illegal_fill_element_type_truncation(%arg0 : tensor<2xf32>, %arg1 : f64) -> tensor<2xf32>
{
// expected-error @+1 {{'linalg.fill' op expected fill value type ('f64') to match output element type ('f32')}}
%0 = linalg.fill ins(%arg1 : f64) outs(%arg0 : tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
// -----
func.func @illegal_fill_element_type_extension(%arg0 : tensor<2xi32>, %arg1 : i16) -> tensor<2xi32>
{
// expected-error @+1 {{'linalg.fill' op expected fill value type ('i16') to match output element type ('i32')}}
%0 = linalg.fill ins(%arg1 : i16) outs(%arg0 : tensor<2xi32>) -> tensor<2xi32>
return %0 : tensor<2xi32>
}
// -----
func.func @illegal_fill_value_type(%arg0 : tensor<2x2xf32>, %arg1 : tensor<2xf32>) -> tensor<2x2xf32>
{
// expected-error @+1 {{expected op with scalar input}}

View File

@@ -27,8 +27,8 @@ func.func @main() {
%A_dyn = tensor.cast %A : tensor<8x2xf32> to tensor<?x?xf32>
%B_dyn = tensor.cast %B : tensor<2x4xf32> to tensor<?x?xf32>
%c0_i32 = arith.constant 0 : i32
%C_init = linalg.fill ins(%c0_i32 : i32) outs(%C_dyn : tensor<?x?xf32>) -> tensor<?x?xf32>
%c0_f32 = arith.constant 0.0 : f32
%C_init = linalg.fill ins(%c0_f32 : f32) outs(%C_dyn : tensor<?x?xf32>) -> tensor<?x?xf32>
%res = linalg.matmul ins(%A_dyn, %B_dyn: tensor<?x?xf32>, tensor<?x?xf32>)
outs(%C_init: tensor<?x?xf32>) -> tensor<?x?xf32>

View File

@@ -63,11 +63,11 @@ func.func @matmul_simple(%lhs: tensor<10x20xf16>, %rhs: tensor<20x15xf32>) -> te
}
func.func @matmul_with_extra_ops_in_func(%lhs: tensor<10x20xf32>, %rhs: tensor<20x15xf32>) -> tensor<10x15xf32> {
%cst = arith.constant 0.0 : f64
%cst = arith.constant 0.0 : f32
%empty = tensor.empty() : tensor<10x15xf32>
// expected-remark @below {{fill}}
%fill = linalg.fill ins(%cst : f64) outs(%empty : tensor<10x15xf32>) -> tensor<10x15xf32>
%fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<10x15xf32>) -> tensor<10x15xf32>
%real_lhs = linalg.mul
ins(%lhs, %lhs : tensor<10x20xf32>, tensor<10x20xf32>) outs(%lhs : tensor<10x20xf32>) -> tensor<10x20xf32>

View File

@@ -25,13 +25,13 @@ func.func @main() -> i32 attributes {llvm.emit_c_interface} {
%O1 = memref.alloc() : memref<16xi32>
%O2 = memref.alloc() : memref<4x16xi32>
%val0 = arith.constant 1.0 : f32
%val1 = arith.constant 2.0 : f32
%val2 = arith.constant 3.0 : f32
%val0 = arith.constant 1 : i32
%val1 = arith.constant 2 : i32
%val2 = arith.constant 3 : i32
call @fill_0d_on_buffers(%val0, %O0) : (f32, memref<i32>) -> ()
call @fill_1d_on_buffers(%val1, %O1) : (f32, memref<16xi32>) -> ()
call @fill_2d_on_buffers(%val2, %O2) : (f32, memref<4x16xi32>) -> ()
call @fill_0d_on_buffers(%val0, %O0) : (i32, memref<i32>) -> ()
call @fill_1d_on_buffers(%val1, %O1) : (i32, memref<16xi32>) -> ()
call @fill_2d_on_buffers(%val2, %O2) : (i32, memref<4x16xi32>) -> ()
%c0 = arith.constant 0 : index
%res0 = memref.load %O0[] : memref<i32>
@@ -149,19 +149,18 @@ def transform(module, boilerplate):
def test_fill_builtin():
with Context() as ctx, Location.unknown():
module = Module.create()
f32 = F32Type.get()
i32 = IntegerType.get_signless(32)
with InsertionPoint(module.body):
@func.FuncOp.from_py_func(f32, MemRefType.get([], i32))
@func.FuncOp.from_py_func(i32, MemRefType.get([], i32))
def fill_0d_on_buffers(value, out):
linalg.fill(value, outs=[out])
@func.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
@func.FuncOp.from_py_func(i32, MemRefType.get([16], i32))
def fill_1d_on_buffers(value, out):
linalg.fill(value, outs=[out])
@func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
@func.FuncOp.from_py_func(i32, MemRefType.get([4, 16], i32))
def fill_2d_on_buffers(value, out):
linalg.fill(value, outs=[out])
@@ -184,19 +183,18 @@ test_fill_builtin()
def test_fill_generic():
with Context() as ctx, Location.unknown():
module = Module.create()
f32 = F32Type.get()
i32 = IntegerType.get_signless(32)
with InsertionPoint(module.body):
@func.FuncOp.from_py_func(f32, MemRefType.get([], i32))
@func.FuncOp.from_py_func(i32, MemRefType.get([], i32))
def fill_0d_on_buffers(value, out):
linalg.fill(value, outs=[out], emit_generic=True)
@func.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
@func.FuncOp.from_py_func(i32, MemRefType.get([16], i32))
def fill_1d_on_buffers(value, out):
linalg.fill(value, outs=[out], emit_generic=True)
@func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
@func.FuncOp.from_py_func(i32, MemRefType.get([4, 16], i32))
def fill_2d_on_buffers(value, out):
linalg.fill(value, outs=[out], emit_generic=True)