From e7063e820637498355a184a45c42c19fa58ff2f3 Mon Sep 17 00:00:00 2001 From: Igor Wodiany Date: Tue, 27 Jan 2026 15:08:25 +0000 Subject: [PATCH] [mlir][spirv] Enforce `SPIRV_Vector` to have rank of one (#178185) Currently only vector length is enforced however this allows vectors of rank >1 to pass the verification as long as the length agrees. This change restricts `SPIRV_Vector`s to be of rank 1 as required by the SPIR-V spec. This also fixes a bug where `SPIRV_Composite` allowed high ranked vectors but `spirv::CompositeType` did not leading to cast assertions where the composite type was assumed. Finally, this change adds two new common constraints that can enforce all three: rank, length and type. fixes #178127 --- .../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 6 ++--- mlir/include/mlir/IR/CommonTypeConstraints.td | 22 +++++++++++++++++++ mlir/test/Dialect/SPIRV/IR/composite-ops.mlir | 8 +++++++ mlir/test/Dialect/SPIRV/IR/gl-ops.mlir | 6 ++--- mlir/test/Dialect/SPIRV/IR/group-ops.mlir | 6 ++--- mlir/test/Dialect/SPIRV/IR/image-ops.mlir | 6 ++--- mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir | 6 ++--- mlir/test/Dialect/SPIRV/IR/logical-ops.mlir | 2 +- .../Dialect/SPIRV/IR/non-uniform-ops.mlir | 20 ++++++++--------- mlir/test/Dialect/SPIRV/IR/structure-ops.mlir | 2 +- 10 files changed, 57 insertions(+), 27 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index 4ea6d784dd88..f8093d3042c5 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -4242,7 +4242,7 @@ def SPIRV_BFloat16KHR : TypeAlias; def SPIRV_Float : FloatOfWidths<[16, 32, 64]>; def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>; def SPIRV_AnyFloat : AnyTypeOf<[SPIRV_Float, SPIRV_BFloat16KHR]>; -def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16], +def SPIRV_Vector : VectorOfRankAndLengthAndType<[1], [2, 3, 4, 8, 16], [SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat]>; // Component type check is done in the type parser for the following SPIR-V // dialect-specific types so we use "Any" here. @@ -4295,7 +4295,7 @@ class SPIRV_MatrixOfType allowedTypes> : "Matrix">; class SPIRV_VectorOf : - FixedVectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>; + FixedVectorOfRankAndLengthAndType<[1], [2, 3, 4, 8, 16], [type]>; class SPIRV_ScalarOrVectorOf : AnyTypeOf<[type, SPIRV_VectorOf]>; @@ -4314,7 +4314,7 @@ class SPIRV_MatrixOf : def SPIRV_ScalarOrVector : AnyTypeOf<[SPIRV_Scalar, SPIRV_Vector]>; def SPIRV_ScalarOrVectorOrPtr : AnyTypeOf<[SPIRV_ScalarOrVector, SPIRV_AnyPtr]>; -class SPIRV_Vec4 : VectorOfLengthAndType<[4], [type]>; +class SPIRV_Vec4 : VectorOfRankAndLengthAndType<[1], [4], [type]>; def SPIRV_IntVec4 : SPIRV_Vec4; def SPIRV_IOrUIVec4 : SPIRV_Vec4; def SPIRV_Int32Vec4 : SPIRV_Vec4; diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td index 0fb4837e528b..a49880b81e90 100644 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td @@ -649,6 +649,16 @@ class VectorOfLengthAndType allowedLengths, VectorOfNonZeroRankOf.summary # VectorOfLength.summary, "::mlir::VectorType">; +// Any vector where the number of elements is from the given +// `allowedLengths` list and the type is from the given `allowedTypes` +// list and the rank is from the given `allowedRanks` list +class VectorOfRankAndLengthAndType allowedRanks, + list allowedLengths, + list allowedTypes> : AllOfType< + [VectorOfRank, VectorOfNonZeroRankOf, VectorOfLength], + VectorOfNonZeroRankOf.summary # VectorOfLength.summary # VectorOfRank.summary, + "::mlir::VectorType">; + // Any vector where the number of elements is between // `minLength` and `maxLength` (inclusive). class VectorOfMinMaxLengthAndType allowedLengths, FixedVectorOfLength.summary, "::mlir::VectorType">; +// Any fixed-length vector where the number of elements is from the given +// `allowedLengths` list and the type is from the given `allowedTypes` list +// as the rank is from the given `allowedRanks` list +class FixedVectorOfRankAndLengthAndType allowedRanks, + list allowedLengths, + list allowedTypes> : AllOfType< + [VectorOfRank, FixedVectorOfAnyRank, FixedVectorOfLength], + FixedVectorOfAnyRank.summary # + FixedVectorOfLength.summary # + VectorOfRank.summary, + "::mlir::VectorType">; + // Any scalable vector where the number of elements is from the given // `allowedLengths` list and the type is from the given `allowedTypes` list class ScalableVectorOfLengthAndType allowedLengths, diff --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir index e71b545de11d..9323518f5037 100644 --- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir @@ -99,6 +99,14 @@ func.func @composite_construct_vector_wrong_count(%arg0: f32, %arg1: f32, %arg2 // ----- +func.func @composite_construct_vector_rank_two(%arg0: vector<2x2xi1>, %arg1: vector<2x2xi1>) -> vector<4x2xi1> { + // expected-error @+1 {{op operand #0 must be variadic of void or bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V pointer type or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V sampled image type or any SPIR-V image type or any SPIR-V tensorArm type, but got 'vector<2x2xi1>'}} + %0 = spirv.CompositeConstruct %arg0, %arg1 : (vector<2x2xi1>, vector<2x2xi1>) -> vector<4x2xi1> + return %0: vector<4x2xi1> +} + +// ----- + //===----------------------------------------------------------------------===// // spirv.CompositeExtractOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir index bab12b183faf..eea80ca3798a 100644 --- a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir @@ -736,7 +736,7 @@ func.func @cross(%arg0 : vector<3xf32>, %arg1 : vector<3xf32>) { // ----- func.func @cross_invalid_type(%arg0 : vector<3xi32>, %arg1 : vector<3xi32>) { - // expected-error @+1 {{'spirv.GL.Cross' op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'vector<3xi32>'}} + // expected-error @+1 {{'spirv.GL.Cross' op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 of ranks 1, but got 'vector<3xi32>'}} %0 = spirv.GL.Cross %arg0, %arg1 : vector<3xi32> return } @@ -1126,7 +1126,7 @@ func.func @lengthvec(%arg0 : vector<3xf32>) -> () { // ----- func.func @length_i32_in(%arg0 : i32) -> () { - // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'i32'}} + // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 of ranks 1, but got 'i32'}} %0 = spirv.GL.Length %arg0 : i32 -> f32 return } @@ -1142,7 +1142,7 @@ func.func @length_f16_in(%arg0 : f16) -> () { // ----- func.func @length_i32vec_in(%arg0 : vector<3xi32>) -> () { - // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'vector<3xi32>'}} + // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 of ranks 1, but got 'vector<3xi32>'}} %0 = spirv.GL.Length %arg0 : vector<3xi32> -> f32 return } diff --git a/mlir/test/Dialect/SPIRV/IR/group-ops.mlir b/mlir/test/Dialect/SPIRV/IR/group-ops.mlir index d7a4a6d92fcd..d26bfe9185bd 100644 --- a/mlir/test/Dialect/SPIRV/IR/group-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/group-ops.mlir @@ -220,7 +220,7 @@ func.func @group_non_uniform_ballot_bit_count_wrong_scope(%value: vector<4xi32>) // ----- func.func @group_non_uniform_ballot_bit_count_wrong_value_len(%value: vector<3xi32>) -> i32 { - // expected-error @+1 {{operand #0 must be vector of 32-bit signless/unsigned integer values of length 4, but got 'vector<3xi32>'}} + // expected-error @+1 {{operand #0 must be vector of 32-bit signless/unsigned integer values of length 4 of ranks 1, but got 'vector<3xi32>'}} %0 = spirv.GroupNonUniformBallotBitCount %value : vector<3xi32> -> i32 return %0: i32 } @@ -228,7 +228,7 @@ func.func @group_non_uniform_ballot_bit_count_wrong_value_len(%value: vector<3xi // ----- func.func @group_non_uniform_ballot_bit_count_wrong_value_type(%value: vector<4xi8>) -> i32 { - // expected-error @+1 {{operand #0 must be vector of 32-bit signless/unsigned integer values of length 4, but got 'vector<4xi8>'}} + // expected-error @+1 {{operand #0 must be vector of 32-bit signless/unsigned integer values of length 4 of ranks 1, but got 'vector<4xi8>'}} %0 = spirv.GroupNonUniformBallotBitCount %value : vector<4xi8> -> i32 return %0: i32 } @@ -236,7 +236,7 @@ func.func @group_non_uniform_ballot_bit_count_wrong_value_type(%value: vector<4x // ----- func.func @group_non_uniform_ballot_bit_count_value_sign(%value: vector<4xsi32>) -> i32 { - // expected-error @+1 {{operand #0 must be vector of 32-bit signless/unsigned integer values of length 4, but got 'vector<4xsi32>'}} + // expected-error @+1 {{operand #0 must be vector of 32-bit signless/unsigned integer values of length 4 of ranks 1, but got 'vector<4xsi32>'}} %0 = spirv.GroupNonUniformBallotBitCount %value : vector<4xsi32> -> i32 return %0: i32 } diff --git a/mlir/test/Dialect/SPIRV/IR/image-ops.mlir b/mlir/test/Dialect/SPIRV/IR/image-ops.mlir index 7369d719ca53..12b5f2ce62a6 100644 --- a/mlir/test/Dialect/SPIRV/IR/image-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/image-ops.mlir @@ -29,7 +29,7 @@ func.func @image_dref_gather_with_mismatch_imageoperands(%arg0 : !spirv.sampled_ // ----- func.func @image_dref_gather_error_result_type(%arg0 : !spirv.sampled_image>, %arg1 : vector<4xf32>, %arg2 : f32) -> () { - // expected-error @+1 {{must be vector of 8/16/32/64-bit integer values of length 4 or vector of 16/32/64-bit float values of length 4}} + // expected-error @+1 {{op result #0 must be vector of 8/16/32/64-bit integer values of length 4 of ranks 1 or vector of 16/32/64-bit float values of length 4 of ranks 1, but got 'vector<3xi32>'}} %0 = spirv.ImageDrefGather %arg0, %arg1, %arg2 : !spirv.sampled_image>, vector<4xf32>, f32 -> vector<3xi32> spirv.Return } @@ -326,7 +326,7 @@ func.func @image_fetch_type_mismatch(%arg0: !spirv.image, %arg1: vector<2xsi32>) -> () { - // expected-error @+1 {{op result #0 must be vector of 16/32/64-bit float values of length 4 or vector of 8/16/32/64-bit integer values of length 4, but got 'vector<2xf32>'}} + // expected-error @+1 {{op result #0 must be vector of 16/32/64-bit float values of length 4 of ranks 1 or vector of 8/16/32/64-bit integer values of length 4 of ranks 1, but got 'vector<2xf32>'}} %0 = spirv.ImageFetch %arg0, %arg1 : !spirv.image, vector<2xsi32> -> vector<2xf32> spirv.Return } @@ -334,7 +334,7 @@ func.func @image_fetch_2d_result(%arg0: !spirv.image, %arg1: vector<2xf32>) -> () { - // expected-error @+1 {{op operand #1 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'vector<2xf32>'}} + // expected-error @+1 {{op operand #1 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 of ranks 1, but got 'vector<2xf32>'}} %0 = spirv.ImageFetch %arg0, %arg1 : !spirv.image, vector<2xf32> -> vector<2xf32> spirv.Return } diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir index 2e2fb1a9df32..d124c0223116 100644 --- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir @@ -21,7 +21,7 @@ spirv.func @f32_to_bf16_vec(%arg0 : vector<2xf32>) "None" { // ----- spirv.func @f32_to_bf16_unsupported(%arg0 : f64) "None" { - // expected-error @+1 {{operand #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16, but got}} + // expected-error @+1 {{operand #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16 of ranks 1, but got}} %0 = spirv.INTEL.ConvertFToBF16 %arg0 : f64 to i16 spirv.Return } @@ -57,7 +57,7 @@ spirv.func @bf16_to_f32_vec(%arg0 : vector<2xi16>) "None" { // ----- spirv.func @bf16_to_f32_unsupported(%arg0 : i16) "None" { - // expected-error @+1 {{result #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16, but got}} + // expected-error @+1 {{result #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16 of ranks 1, but got}} %0 = spirv.INTEL.ConvertBF16ToF %arg0 : i16 to f16 spirv.Return } @@ -93,7 +93,7 @@ spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" { // ----- spirv.func @f32_to_tf32_unsupported(%arg0 : f64) "None" { - // expected-error @+1 {{op operand #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16, but got 'f64'}} + // expected-error @+1 {{op operand #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16 of ranks 1, but got 'f64'}} %0 = spirv.INTEL.RoundFToTF32 %arg0 : f64 to f32 spirv.Return } diff --git a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir index d7f4ed05969a..1018751cf65e 100644 --- a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir @@ -184,7 +184,7 @@ func.func @logicalUnary(%arg0 : i1) func.func @logicalUnary(%arg0 : i32) { - // expected-error @+1 {{'operand' must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}} + // expected-error @+1 {{'operand' must be bool or fixed-length vector of bool values of length 2/3/4/8/16 of ranks 1, but got 'i32'}} %0 = spirv.LogicalNot %arg0 : i32 return } diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir index b22951f90510..168823a6e9c2 100644 --- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir @@ -21,7 +21,7 @@ func.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> { // ----- func.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xsi32> { - // expected-error @+1 {{op result #0 must be vector of 8/16/32/64-bit signless/unsigned integer values of length 4, but got 'vector<4xsi32>'}} + // expected-error @+1 {{op result #0 must be vector of 8/16/32/64-bit signless/unsigned integer values of length 4 of ranks 1, but got 'vector<4xsi32>'}} %0 = spirv.GroupNonUniformBallot %predicate : vector<4xsi32> return %0: vector<4xsi32> } @@ -185,7 +185,7 @@ func.func @group_non_uniform_fmul_clustered_reduce(%val: vector<2xf32>) -> vecto // ----- func.func @group_non_uniform_bf16_fmul_reduce(%val: bf16) -> bf16 { - // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'bf16'}} + // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 of ranks 1, but got 'bf16'}} %0 = spirv.GroupNonUniformFMul %val : bf16 -> bf16 return %0: bf16 } @@ -206,7 +206,7 @@ func.func @group_non_uniform_fmax_reduce(%val: f32) -> f32 { // ----- func.func @group_non_uniform_bf16_fmax_reduce(%val: bf16) -> bf16 { - // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'bf16'}} + // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 of ranks 1, but got 'bf16'}} %0 = spirv.GroupNonUniformFMax %val : bf16 -> bf16 return %0: bf16 } @@ -511,7 +511,7 @@ func.func @group_non_uniform_bitwise_and(%val: i32) -> i32 { // ----- func.func @group_non_uniform_bitwise_and(%val: i1) -> i1 { - // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}} + // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 of ranks 1, but got 'i1'}} %0 = spirv.GroupNonUniformBitwiseAnd %val : i1 -> i1 return %0: i1 } @@ -532,7 +532,7 @@ func.func @group_non_uniform_bitwise_or(%val: i32) -> i32 { // ----- func.func @group_non_uniform_bitwise_or(%val: i1) -> i1 { - // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}} + // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 of ranks 1, but got 'i1'}} %0 = spirv.GroupNonUniformBitwiseOr %val : i1 -> i1 return %0: i1 } @@ -553,7 +553,7 @@ func.func @group_non_uniform_bitwise_xor(%val: i32) -> i32 { // ----- func.func @group_non_uniform_bitwise_xor(%val: i1) -> i1 { - // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}} + // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 of ranks 1, but got 'i1'}} %0 = spirv.GroupNonUniformBitwiseXor %val : i1 -> i1 return %0: i1 } @@ -574,7 +574,7 @@ func.func @group_non_uniform_logical_and(%val: i1) -> i1 { // ----- func.func @group_non_uniform_logical_and(%val: i32) -> i32 { - // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}} + // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16 of ranks 1, but got 'i32'}} %0 = spirv.GroupNonUniformLogicalAnd %val : i32 -> i32 return %0: i32 } @@ -595,7 +595,7 @@ func.func @group_non_uniform_logical_or(%val: i1) -> i1 { // ----- func.func @group_non_uniform_logical_or(%val: i32) -> i32 { - // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}} + // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16 of ranks 1, but got 'i32'}} %0 = spirv.GroupNonUniformLogicalOr %val : i32 -> i32 return %0: i32 } @@ -616,7 +616,7 @@ func.func @group_non_uniform_logical_xor(%val: i1) -> i1 { // ----- func.func @group_non_uniform_logical_xor(%val: i32) -> i32 { - // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}} + // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16 of ranks 1, but got 'i32'}} %0 = spirv.GroupNonUniformLogicalXor %val : i32 -> i32 return %0: i32 } @@ -807,7 +807,7 @@ func.func @group_non_uniform_quad_swap(%value: vector<4xf32>) -> vector<4xf32> { func.func @group_non_uniform_quad_swap(%value: !spirv.array<3 x i32>) -> !spirv.array<3 x i32> { %dir = spirv.Constant 0 : i32 - // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 or 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 or bool or fixed-length vector of bool values of length 2/3/4/8/16, but got '!spirv.array<3 x i32>'}} + // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 of ranks 1 or 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 of ranks 1 or bool or fixed-length vector of bool values of length 2/3/4/8/16 of ranks 1, but got '!spirv.array<3 x i32>'}} %0 = spirv.GroupNonUniformQuadSwap %value %dir : !spirv.array<3 x i32>, i32 return %0: !spirv.array<3 x i32> } diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir index 20bb4eace370..2c5dc8b9f3b0 100644 --- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir @@ -170,7 +170,7 @@ func.func @coop_matrix_const_wrong_type() -> () { //===----------------------------------------------------------------------===// func.func @ccr_result_not_composite() -> () { - // expected-error @+1 {{op result #0 must be vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16 or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V tensorArm type, but got 'i32'}} + // expected-error @+1 {{op result #0 must be vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V tensorArm type, but got 'i32'}} %0 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : i32 return }