[mlir][linalg] Restrict fill initial value type to output element type (#169567)
Disallow implicit casting of the fill value, which is surprising and, in my experience, usually indicative of copy-paste errors. Because the initial value must be a scalar, I don't expect this to affect any data movement.
This commit is contained in:
@@ -311,16 +311,17 @@ An example for a rank polymorphic operation is `fill`:
|
||||
|
||||
```python
|
||||
@linalg_structured_op
|
||||
def fill(value=ScalarDef(T1),
|
||||
O=TensorDef(U, output=True)):
|
||||
O[None] = TypeFn.cast_signed(U, value)
|
||||
def fill(value=ScalarDef(T),
|
||||
O=TensorDef(T, output=True)):
|
||||
O[None] = value
|
||||
```
|
||||
|
||||
The operation sets the elements of the output tensor `O` to `value`. All
|
||||
operands are either scalars or rank zero tensors that are accessed using the
|
||||
index `None`. The operation thus performs a scalar computation that trivially
|
||||
extends to a multi-dimensional pointwise computation. As a result, we may use
|
||||
`fill` with arbitrary ranked output tensors:
|
||||
The operation sets the elements of the output tensor `O` to `value`. The value
|
||||
type must match the element type of the output tensor. All operands are either
|
||||
scalars or rank zero tensors that are accessed using the index `None`. The
|
||||
operation thus performs a scalar computation that trivially extends to a
|
||||
multi-dimensional pointwise computation. As a result, we may use `fill` with
|
||||
arbitrary ranked output tensors:
|
||||
|
||||
```python
|
||||
tensor_2d = tensor.EmptyOp([4, 8], f32)
|
||||
|
||||
@@ -6054,9 +6054,9 @@ metadata: !LinalgOpMetadata
|
||||
doc: |-
|
||||
Fills the output tensor with the given value.
|
||||
|
||||
Works for arbitrary ranked output tensors since the operation performs scalar
|
||||
accesses only and is thus rank polymorphic. Numeric casting is performed on
|
||||
the value operand, promoting it to the same data type as the output.
|
||||
Works for arbitrary ranked output tensors since the operation performs
|
||||
scalar accesses only and is thus rank polymorphic. The value operand
|
||||
type must match the element type of the output.
|
||||
implements:
|
||||
- LinalgFillOpInterface
|
||||
defines:
|
||||
@@ -6066,11 +6066,11 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !LinalgOperandDefConfig
|
||||
name: value
|
||||
kind: scalar
|
||||
type_var: T1
|
||||
type_var: T
|
||||
- !LinalgOperandDefConfig
|
||||
name: O
|
||||
kind: output_tensor
|
||||
type_var: U
|
||||
type_var: T
|
||||
shape_map: affine_map<() -> ()>
|
||||
indexing_maps: !LinalgIndexingMapsConfig
|
||||
static_indexing_maps:
|
||||
@@ -6081,13 +6081,7 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarAssign
|
||||
arg: O
|
||||
value: !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: value
|
||||
scalar_arg: value
|
||||
--- !LinalgOpConfig
|
||||
metadata: !LinalgOpMetadata
|
||||
name: fill_rng_2d
|
||||
|
||||
@@ -1057,12 +1057,15 @@ LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
|
||||
// FillOpInterface implementation
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
enum class MatchFillResult {
|
||||
Success = 0,
|
||||
NotLinalgOp,
|
||||
WrongNumOperands,
|
||||
NotScalarInput
|
||||
NotScalarInput,
|
||||
TypeMismatch
|
||||
};
|
||||
} // namespace
|
||||
|
||||
static MatchFillResult isFillInterfaceImpl(Operation *op) {
|
||||
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
|
||||
@@ -1075,17 +1078,33 @@ static MatchFillResult isFillInterfaceImpl(Operation *op) {
|
||||
if (!linalgOp.isScalar(value))
|
||||
return MatchFillResult::NotScalarInput;
|
||||
|
||||
// Check that the scalar input type matches the output element type.
|
||||
OpOperand *output = linalgOp.getDpsInitOperand(0);
|
||||
Type scalarType = value->get().getType();
|
||||
Type outputElementType = getElementTypeOrSelf(output->get().getType());
|
||||
if (scalarType != outputElementType)
|
||||
return MatchFillResult::TypeMismatch;
|
||||
|
||||
return MatchFillResult::Success;
|
||||
}
|
||||
|
||||
LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
|
||||
auto res = isFillInterfaceImpl(op);
|
||||
MatchFillResult res = isFillInterfaceImpl(op);
|
||||
if (res == MatchFillResult::NotLinalgOp)
|
||||
return op->emitError("expected a LinalgOp");
|
||||
if (res == MatchFillResult::WrongNumOperands)
|
||||
return op->emitError("expected op with 1 input and 1 output");
|
||||
if (res == MatchFillResult::NotScalarInput)
|
||||
return op->emitError("expected op with scalar input");
|
||||
if (res == MatchFillResult::TypeMismatch) {
|
||||
auto linalgOp = cast<linalg::LinalgOp>(op);
|
||||
Type scalarType = linalgOp.getDpsInputOperand(0)->get().getType();
|
||||
Type outputElementType =
|
||||
getElementTypeOrSelf(linalgOp.getDpsInitOperand(0)->get().getType());
|
||||
return op->emitOpError("expected fill value type (")
|
||||
<< scalarType << ") to match output element type ("
|
||||
<< outputElementType << ")";
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -1729,16 +1729,16 @@ def pooling_ndhwc_min(
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
def fill(value=ScalarDef(T1), O=TensorDef(U, output=True)):
|
||||
def fill(value=ScalarDef(T), O=TensorDef(T, output=True)):
|
||||
"""Fills the output tensor with the given value.
|
||||
|
||||
Works for arbitrary ranked output tensors since the operation performs scalar
|
||||
accesses only and is thus rank polymorphic. Numeric casting is performed on
|
||||
the value operand, promoting it to the same data type as the output.
|
||||
accesses only and is thus rank polymorphic. The value type must match the
|
||||
element type of the output tensor or memref.
|
||||
"""
|
||||
implements(FillOpInterface)
|
||||
defines(Canonicalizer)
|
||||
O[None] = TypeFn.cast_signed(U, value)
|
||||
O[None] = value
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
|
||||
@@ -36,13 +36,13 @@ func.func @reify_through_chain(%sz0: index, %sz2: index) -> (index, index, index
|
||||
// CHECK: "test.some_use"(%[[c5]])
|
||||
// CHECK: %[[c5:.*]] = arith.constant 5 : index
|
||||
// CHECK: "test.some_use"(%[[c5]])
|
||||
func.func @reify_slice_bound(%t: tensor<?x?xi32>, %idx: index, %ub: index, %f: f32) {
|
||||
func.func @reify_slice_bound(%t: tensor<?x?xi32>, %idx: index, %ub: index, %f: i32) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c4 = arith.constant 4 : index
|
||||
scf.for %iv = %c0 to %ub step %c4 {
|
||||
%sz = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 4)>(%iv)[%ub]
|
||||
%slice = tensor.extract_slice %t[%idx, %iv] [1, %sz] [1, 1] : tensor<?x?xi32> to tensor<1x?xi32>
|
||||
%filled = linalg.fill ins(%f : f32) outs(%slice : tensor<1x?xi32>) -> tensor<1x?xi32>
|
||||
%filled = linalg.fill ins(%f : i32) outs(%slice : tensor<1x?xi32>) -> tensor<1x?xi32>
|
||||
|
||||
%bound = "test.reify_bound"(%filled) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index)
|
||||
"test.some_use"(%bound) : (index) -> ()
|
||||
|
||||
@@ -921,30 +921,6 @@ func.func @fold_fill_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @fold_fill_generic_different_dtype
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf16>) -> tensor<?xf16> {
|
||||
// CHECK-NOT: linalg.fill
|
||||
// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
|
||||
// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf16>)
|
||||
// CHECK-SAME: outs({{.*}} : tensor<?xf16>) {
|
||||
#map0 = affine_map<(d0) -> (d0)>
|
||||
func.func @fold_fill_generic_different_dtype(%arg0: tensor<?xf16>) -> (tensor<?xf16>) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%cst = arith.constant 7.0 : f32
|
||||
%0 = tensor.dim %arg0, %c0 : tensor<?xf16>
|
||||
%1 = tensor.empty(%0) : tensor<?xf16>
|
||||
%2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?xf16>) -> tensor<?xf16>
|
||||
%3 = tensor.empty(%0) : tensor<?xf16>
|
||||
%4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor<?xf16>, tensor<?xf16>) outs (%3:tensor<?xf16>) {
|
||||
^bb0(%arg1: f16, %arg2: f16, %arg3: f16):
|
||||
%5 = arith.addf %arg1, %arg2 : f16
|
||||
linalg.yield %5 : f16
|
||||
} -> tensor<?xf16>
|
||||
return %4 : tensor<?xf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @fold_fill_generic_mixedaccess
|
||||
// CHECK-NOT: linalg.fill
|
||||
// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
|
||||
@@ -1079,4 +1055,4 @@ module {
|
||||
// CHECK-NOT: linalg.generic
|
||||
// CHECK: tensor.expand_shape
|
||||
// CHECK: linalg.generic {{.*}}, iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]}
|
||||
// CHECK-SAME: ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)
|
||||
// CHECK-SAME: ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)
|
||||
|
||||
@@ -380,8 +380,8 @@ func.func @generalize_pooling_nwc_sum_i32(%input : tensor<1x16x1xi32>, %shape: t
|
||||
|
||||
// -----
|
||||
|
||||
func.func @generalize_fill_0d(%value: f64, %O: tensor<f32>) -> tensor<f32> {
|
||||
%0 = linalg.fill ins(%value: f64) outs(%O : tensor<f32>) -> tensor<f32>
|
||||
func.func @generalize_fill_0d(%value: f32, %O: tensor<f32>) -> tensor<f32> {
|
||||
%0 = linalg.fill ins(%value: f32) outs(%O : tensor<f32>) -> tensor<f32>
|
||||
return %0: tensor<f32>
|
||||
}
|
||||
|
||||
@@ -394,8 +394,8 @@ func.func @generalize_fill_0d(%value: f64, %O: tensor<f32>) -> tensor<f32> {
|
||||
|
||||
// -----
|
||||
|
||||
func.func @generalize_fill_2d(%value: f64, %O: memref<16x32xf32>) {
|
||||
linalg.fill ins(%value: f64) outs(%O : memref<16x32xf32>)
|
||||
func.func @generalize_fill_2d(%value: f32, %O: memref<16x32xf32>) {
|
||||
linalg.fill ins(%value: f32) outs(%O : memref<16x32xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -352,6 +352,24 @@ func.func @illegal_fill_tensor_with_memref_return
|
||||
|
||||
// -----
|
||||
|
||||
func.func @illegal_fill_element_type_truncation(%arg0 : tensor<2xf32>, %arg1 : f64) -> tensor<2xf32>
|
||||
{
|
||||
// expected-error @+1 {{'linalg.fill' op expected fill value type ('f64') to match output element type ('f32')}}
|
||||
%0 = linalg.fill ins(%arg1 : f64) outs(%arg0 : tensor<2xf32>) -> tensor<2xf32>
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @illegal_fill_element_type_extension(%arg0 : tensor<2xi32>, %arg1 : i16) -> tensor<2xi32>
|
||||
{
|
||||
// expected-error @+1 {{'linalg.fill' op expected fill value type ('i16') to match output element type ('i32')}}
|
||||
%0 = linalg.fill ins(%arg1 : i16) outs(%arg0 : tensor<2xi32>) -> tensor<2xi32>
|
||||
return %0 : tensor<2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @illegal_fill_value_type(%arg0 : tensor<2x2xf32>, %arg1 : tensor<2xf32>) -> tensor<2x2xf32>
|
||||
{
|
||||
// expected-error @+1 {{expected op with scalar input}}
|
||||
|
||||
@@ -27,8 +27,8 @@ func.func @main() {
|
||||
%A_dyn = tensor.cast %A : tensor<8x2xf32> to tensor<?x?xf32>
|
||||
%B_dyn = tensor.cast %B : tensor<2x4xf32> to tensor<?x?xf32>
|
||||
|
||||
%c0_i32 = arith.constant 0 : i32
|
||||
%C_init = linalg.fill ins(%c0_i32 : i32) outs(%C_dyn : tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%c0_f32 = arith.constant 0.0 : f32
|
||||
%C_init = linalg.fill ins(%c0_f32 : f32) outs(%C_dyn : tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
|
||||
%res = linalg.matmul ins(%A_dyn, %B_dyn: tensor<?x?xf32>, tensor<?x?xf32>)
|
||||
outs(%C_init: tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
|
||||
@@ -63,11 +63,11 @@ func.func @matmul_simple(%lhs: tensor<10x20xf16>, %rhs: tensor<20x15xf32>) -> te
|
||||
}
|
||||
|
||||
func.func @matmul_with_extra_ops_in_func(%lhs: tensor<10x20xf32>, %rhs: tensor<20x15xf32>) -> tensor<10x15xf32> {
|
||||
%cst = arith.constant 0.0 : f64
|
||||
%cst = arith.constant 0.0 : f32
|
||||
%empty = tensor.empty() : tensor<10x15xf32>
|
||||
|
||||
// expected-remark @below {{fill}}
|
||||
%fill = linalg.fill ins(%cst : f64) outs(%empty : tensor<10x15xf32>) -> tensor<10x15xf32>
|
||||
%fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<10x15xf32>) -> tensor<10x15xf32>
|
||||
|
||||
%real_lhs = linalg.mul
|
||||
ins(%lhs, %lhs : tensor<10x20xf32>, tensor<10x20xf32>) outs(%lhs : tensor<10x20xf32>) -> tensor<10x20xf32>
|
||||
|
||||
@@ -25,13 +25,13 @@ func.func @main() -> i32 attributes {llvm.emit_c_interface} {
|
||||
%O1 = memref.alloc() : memref<16xi32>
|
||||
%O2 = memref.alloc() : memref<4x16xi32>
|
||||
|
||||
%val0 = arith.constant 1.0 : f32
|
||||
%val1 = arith.constant 2.0 : f32
|
||||
%val2 = arith.constant 3.0 : f32
|
||||
%val0 = arith.constant 1 : i32
|
||||
%val1 = arith.constant 2 : i32
|
||||
%val2 = arith.constant 3 : i32
|
||||
|
||||
call @fill_0d_on_buffers(%val0, %O0) : (f32, memref<i32>) -> ()
|
||||
call @fill_1d_on_buffers(%val1, %O1) : (f32, memref<16xi32>) -> ()
|
||||
call @fill_2d_on_buffers(%val2, %O2) : (f32, memref<4x16xi32>) -> ()
|
||||
call @fill_0d_on_buffers(%val0, %O0) : (i32, memref<i32>) -> ()
|
||||
call @fill_1d_on_buffers(%val1, %O1) : (i32, memref<16xi32>) -> ()
|
||||
call @fill_2d_on_buffers(%val2, %O2) : (i32, memref<4x16xi32>) -> ()
|
||||
|
||||
%c0 = arith.constant 0 : index
|
||||
%res0 = memref.load %O0[] : memref<i32>
|
||||
@@ -149,19 +149,18 @@ def transform(module, boilerplate):
|
||||
def test_fill_builtin():
|
||||
with Context() as ctx, Location.unknown():
|
||||
module = Module.create()
|
||||
f32 = F32Type.get()
|
||||
i32 = IntegerType.get_signless(32)
|
||||
with InsertionPoint(module.body):
|
||||
|
||||
@func.FuncOp.from_py_func(f32, MemRefType.get([], i32))
|
||||
@func.FuncOp.from_py_func(i32, MemRefType.get([], i32))
|
||||
def fill_0d_on_buffers(value, out):
|
||||
linalg.fill(value, outs=[out])
|
||||
|
||||
@func.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
|
||||
@func.FuncOp.from_py_func(i32, MemRefType.get([16], i32))
|
||||
def fill_1d_on_buffers(value, out):
|
||||
linalg.fill(value, outs=[out])
|
||||
|
||||
@func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
|
||||
@func.FuncOp.from_py_func(i32, MemRefType.get([4, 16], i32))
|
||||
def fill_2d_on_buffers(value, out):
|
||||
linalg.fill(value, outs=[out])
|
||||
|
||||
@@ -184,19 +183,18 @@ test_fill_builtin()
|
||||
def test_fill_generic():
|
||||
with Context() as ctx, Location.unknown():
|
||||
module = Module.create()
|
||||
f32 = F32Type.get()
|
||||
i32 = IntegerType.get_signless(32)
|
||||
with InsertionPoint(module.body):
|
||||
|
||||
@func.FuncOp.from_py_func(f32, MemRefType.get([], i32))
|
||||
@func.FuncOp.from_py_func(i32, MemRefType.get([], i32))
|
||||
def fill_0d_on_buffers(value, out):
|
||||
linalg.fill(value, outs=[out], emit_generic=True)
|
||||
|
||||
@func.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
|
||||
@func.FuncOp.from_py_func(i32, MemRefType.get([16], i32))
|
||||
def fill_1d_on_buffers(value, out):
|
||||
linalg.fill(value, outs=[out], emit_generic=True)
|
||||
|
||||
@func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
|
||||
@func.FuncOp.from_py_func(i32, MemRefType.get([4, 16], i32))
|
||||
def fill_2d_on_buffers(value, out):
|
||||
linalg.fill(value, outs=[out], emit_generic=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user