[mlir][spirv] Enforce SPIRV_Vector to have rank of one (#178185)

Currently only vector length is enforced however this allows vectors of
rank >1 to pass the verification as long as the length agrees. This
change restricts `SPIRV_Vector`s to be of rank 1 as required by the
SPIR-V spec.

This also fixes a bug where `SPIRV_Composite` allowed high ranked
vectors but `spirv::CompositeType` did not leading to cast assertions
where the composite type was assumed.

Finally, this change adds two new common constraints that can enforce
all three: rank, length and type.

fixes #178127
This commit is contained in:
Igor Wodiany
2026-01-27 15:08:25 +00:00
committed by GitHub
parent 3d07cc3c1b
commit e7063e8206
10 changed files with 57 additions and 27 deletions

View File

@@ -4242,7 +4242,7 @@ def SPIRV_BFloat16KHR : TypeAlias<BF16, "BFloat16">;
def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
def SPIRV_AnyFloat : AnyTypeOf<[SPIRV_Float, SPIRV_BFloat16KHR]>;
def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16],
def SPIRV_Vector : VectorOfRankAndLengthAndType<[1], [2, 3, 4, 8, 16],
[SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat]>;
// Component type check is done in the type parser for the following SPIR-V
// dialect-specific types so we use "Any" here.
@@ -4295,7 +4295,7 @@ class SPIRV_MatrixOfType<list<Type> allowedTypes> :
"Matrix">;
class SPIRV_VectorOf<Type type> :
FixedVectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>;
FixedVectorOfRankAndLengthAndType<[1], [2, 3, 4, 8, 16], [type]>;
class SPIRV_ScalarOrVectorOf<Type type> :
AnyTypeOf<[type, SPIRV_VectorOf<type>]>;
@@ -4314,7 +4314,7 @@ class SPIRV_MatrixOf<Type type> :
def SPIRV_ScalarOrVector : AnyTypeOf<[SPIRV_Scalar, SPIRV_Vector]>;
def SPIRV_ScalarOrVectorOrPtr : AnyTypeOf<[SPIRV_ScalarOrVector, SPIRV_AnyPtr]>;
class SPIRV_Vec4<Type type> : VectorOfLengthAndType<[4], [type]>;
class SPIRV_Vec4<Type type> : VectorOfRankAndLengthAndType<[1], [4], [type]>;
def SPIRV_IntVec4 : SPIRV_Vec4<SPIRV_Integer>;
def SPIRV_IOrUIVec4 : SPIRV_Vec4<SPIRV_SignlessOrUnsignedInt>;
def SPIRV_Int32Vec4 : SPIRV_Vec4<AnyI32>;

View File

@@ -649,6 +649,16 @@ class VectorOfLengthAndType<list<int> allowedLengths,
VectorOfNonZeroRankOf<allowedTypes>.summary # VectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;
// Any vector where the number of elements is from the given
// `allowedLengths` list and the type is from the given `allowedTypes`
// list and the rank is from the given `allowedRanks` list
class VectorOfRankAndLengthAndType<list<int> allowedRanks,
list<int> allowedLengths,
list<Type> allowedTypes> : AllOfType<
[VectorOfRank<allowedRanks>, VectorOfNonZeroRankOf<allowedTypes>, VectorOfLength<allowedLengths>],
VectorOfNonZeroRankOf<allowedTypes>.summary # VectorOfLength<allowedLengths>.summary # VectorOfRank<allowedRanks>.summary,
"::mlir::VectorType">;
// Any vector where the number of elements is between
// `minLength` and `maxLength` (inclusive).
class VectorOfMinMaxLengthAndType<int minLength, int maxLength,
@@ -674,6 +684,18 @@ class FixedVectorOfLengthAndType<list<int> allowedLengths,
FixedVectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;
// Any fixed-length vector where the number of elements is from the given
// `allowedLengths` list and the type is from the given `allowedTypes` list
// as the rank is from the given `allowedRanks` list
class FixedVectorOfRankAndLengthAndType<list<int> allowedRanks,
list<int> allowedLengths,
list<Type> allowedTypes> : AllOfType<
[VectorOfRank<allowedRanks>, FixedVectorOfAnyRank<allowedTypes>, FixedVectorOfLength<allowedLengths>],
FixedVectorOfAnyRank<allowedTypes>.summary #
FixedVectorOfLength<allowedLengths>.summary #
VectorOfRank<allowedRanks>.summary,
"::mlir::VectorType">;
// Any scalable vector where the number of elements is from the given
// `allowedLengths` list and the type is from the given `allowedTypes` list
class ScalableVectorOfLengthAndType<list<int> allowedLengths,

View File

@@ -99,6 +99,14 @@ func.func @composite_construct_vector_wrong_count(%arg0: f32, %arg1: f32, %arg2
// -----
func.func @composite_construct_vector_rank_two(%arg0: vector<2x2xi1>, %arg1: vector<2x2xi1>) -> vector<4x2xi1> {
// expected-error @+1 {{op operand #0 must be variadic of void or bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V pointer type or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V sampled image type or any SPIR-V image type or any SPIR-V tensorArm type, but got 'vector<2x2xi1>'}}
%0 = spirv.CompositeConstruct %arg0, %arg1 : (vector<2x2xi1>, vector<2x2xi1>) -> vector<4x2xi1>
return %0: vector<4x2xi1>
}
// -----
//===----------------------------------------------------------------------===//
// spirv.CompositeExtractOp
//===----------------------------------------------------------------------===//

View File

@@ -736,7 +736,7 @@ func.func @cross(%arg0 : vector<3xf32>, %arg1 : vector<3xf32>) {
// -----
func.func @cross_invalid_type(%arg0 : vector<3xi32>, %arg1 : vector<3xi32>) {
// expected-error @+1 {{'spirv.GL.Cross' op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'vector<3xi32>'}}
// expected-error @+1 {{'spirv.GL.Cross' op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 of ranks 1, but got 'vector<3xi32>'}}
%0 = spirv.GL.Cross %arg0, %arg1 : vector<3xi32>
return
}
@@ -1126,7 +1126,7 @@ func.func @lengthvec(%arg0 : vector<3xf32>) -> () {
// -----
func.func @length_i32_in(%arg0 : i32) -> () {
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'i32'}}
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 of ranks 1, but got 'i32'}}
%0 = spirv.GL.Length %arg0 : i32 -> f32
return
}
@@ -1142,7 +1142,7 @@ func.func @length_f16_in(%arg0 : f16) -> () {
// -----
func.func @length_i32vec_in(%arg0 : vector<3xi32>) -> () {
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'vector<3xi32>'}}
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 of ranks 1, but got 'vector<3xi32>'}}
%0 = spirv.GL.Length %arg0 : vector<3xi32> -> f32
return
}

View File

@@ -220,7 +220,7 @@ func.func @group_non_uniform_ballot_bit_count_wrong_scope(%value: vector<4xi32>)
// -----
func.func @group_non_uniform_ballot_bit_count_wrong_value_len(%value: vector<3xi32>) -> i32 {
// expected-error @+1 {{operand #0 must be vector of 32-bit signless/unsigned integer values of length 4, but got 'vector<3xi32>'}}
// expected-error @+1 {{operand #0 must be vector of 32-bit signless/unsigned integer values of length 4 of ranks 1, but got 'vector<3xi32>'}}
%0 = spirv.GroupNonUniformBallotBitCount <Subgroup> <InclusiveScan> %value : vector<3xi32> -> i32
return %0: i32
}
@@ -228,7 +228,7 @@ func.func @group_non_uniform_ballot_bit_count_wrong_value_len(%value: vector<3xi
// -----
func.func @group_non_uniform_ballot_bit_count_wrong_value_type(%value: vector<4xi8>) -> i32 {
// expected-error @+1 {{operand #0 must be vector of 32-bit signless/unsigned integer values of length 4, but got 'vector<4xi8>'}}
// expected-error @+1 {{operand #0 must be vector of 32-bit signless/unsigned integer values of length 4 of ranks 1, but got 'vector<4xi8>'}}
%0 = spirv.GroupNonUniformBallotBitCount <Subgroup> <InclusiveScan> %value : vector<4xi8> -> i32
return %0: i32
}
@@ -236,7 +236,7 @@ func.func @group_non_uniform_ballot_bit_count_wrong_value_type(%value: vector<4x
// -----
func.func @group_non_uniform_ballot_bit_count_value_sign(%value: vector<4xsi32>) -> i32 {
// expected-error @+1 {{operand #0 must be vector of 32-bit signless/unsigned integer values of length 4, but got 'vector<4xsi32>'}}
// expected-error @+1 {{operand #0 must be vector of 32-bit signless/unsigned integer values of length 4 of ranks 1, but got 'vector<4xsi32>'}}
%0 = spirv.GroupNonUniformBallotBitCount <Subgroup> <InclusiveScan> %value : vector<4xsi32> -> i32
return %0: i32
}

View File

@@ -29,7 +29,7 @@ func.func @image_dref_gather_with_mismatch_imageoperands(%arg0 : !spirv.sampled_
// -----
func.func @image_dref_gather_error_result_type(%arg0 : !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, %arg1 : vector<4xf32>, %arg2 : f32) -> () {
// expected-error @+1 {{must be vector of 8/16/32/64-bit integer values of length 4 or vector of 16/32/64-bit float values of length 4}}
// expected-error @+1 {{op result #0 must be vector of 8/16/32/64-bit integer values of length 4 of ranks 1 or vector of 16/32/64-bit float values of length 4 of ranks 1, but got 'vector<3xi32>'}}
%0 = spirv.ImageDrefGather %arg0, %arg1, %arg2 : !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, vector<4xf32>, f32 -> vector<3xi32>
spirv.Return
}
@@ -326,7 +326,7 @@ func.func @image_fetch_type_mismatch(%arg0: !spirv.image<f32, Dim2D, NoDepth, No
// -----
func.func @image_fetch_2d_result(%arg0: !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Rgba8>, %arg1: vector<2xsi32>) -> () {
// expected-error @+1 {{op result #0 must be vector of 16/32/64-bit float values of length 4 or vector of 8/16/32/64-bit integer values of length 4, but got 'vector<2xf32>'}}
// expected-error @+1 {{op result #0 must be vector of 16/32/64-bit float values of length 4 of ranks 1 or vector of 8/16/32/64-bit integer values of length 4 of ranks 1, but got 'vector<2xf32>'}}
%0 = spirv.ImageFetch %arg0, %arg1 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Rgba8>, vector<2xsi32> -> vector<2xf32>
spirv.Return
}
@@ -334,7 +334,7 @@ func.func @image_fetch_2d_result(%arg0: !spirv.image<f32, Dim2D, NoDepth, NonArr
// -----
func.func @image_fetch_float_coords(%arg0: !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Rgba8>, %arg1: vector<2xf32>) -> () {
// expected-error @+1 {{op operand #1 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'vector<2xf32>'}}
// expected-error @+1 {{op operand #1 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 of ranks 1, but got 'vector<2xf32>'}}
%0 = spirv.ImageFetch %arg0, %arg1 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Rgba8>, vector<2xf32> -> vector<2xf32>
spirv.Return
}

View File

@@ -21,7 +21,7 @@ spirv.func @f32_to_bf16_vec(%arg0 : vector<2xf32>) "None" {
// -----
spirv.func @f32_to_bf16_unsupported(%arg0 : f64) "None" {
// expected-error @+1 {{operand #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16, but got}}
// expected-error @+1 {{operand #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16 of ranks 1, but got}}
%0 = spirv.INTEL.ConvertFToBF16 %arg0 : f64 to i16
spirv.Return
}
@@ -57,7 +57,7 @@ spirv.func @bf16_to_f32_vec(%arg0 : vector<2xi16>) "None" {
// -----
spirv.func @bf16_to_f32_unsupported(%arg0 : i16) "None" {
// expected-error @+1 {{result #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16, but got}}
// expected-error @+1 {{result #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16 of ranks 1, but got}}
%0 = spirv.INTEL.ConvertBF16ToF %arg0 : i16 to f16
spirv.Return
}
@@ -93,7 +93,7 @@ spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" {
// -----
spirv.func @f32_to_tf32_unsupported(%arg0 : f64) "None" {
// expected-error @+1 {{op operand #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16, but got 'f64'}}
// expected-error @+1 {{op operand #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16 of ranks 1, but got 'f64'}}
%0 = spirv.INTEL.RoundFToTF32 %arg0 : f64 to f32
spirv.Return
}

View File

@@ -184,7 +184,7 @@ func.func @logicalUnary(%arg0 : i1)
func.func @logicalUnary(%arg0 : i32)
{
// expected-error @+1 {{'operand' must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}}
// expected-error @+1 {{'operand' must be bool or fixed-length vector of bool values of length 2/3/4/8/16 of ranks 1, but got 'i32'}}
%0 = spirv.LogicalNot %arg0 : i32
return
}

View File

@@ -21,7 +21,7 @@ func.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> {
// -----
func.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xsi32> {
// expected-error @+1 {{op result #0 must be vector of 8/16/32/64-bit signless/unsigned integer values of length 4, but got 'vector<4xsi32>'}}
// expected-error @+1 {{op result #0 must be vector of 8/16/32/64-bit signless/unsigned integer values of length 4 of ranks 1, but got 'vector<4xsi32>'}}
%0 = spirv.GroupNonUniformBallot <Workgroup> %predicate : vector<4xsi32>
return %0: vector<4xsi32>
}
@@ -185,7 +185,7 @@ func.func @group_non_uniform_fmul_clustered_reduce(%val: vector<2xf32>) -> vecto
// -----
func.func @group_non_uniform_bf16_fmul_reduce(%val: bf16) -> bf16 {
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'bf16'}}
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 of ranks 1, but got 'bf16'}}
%0 = spirv.GroupNonUniformFMul <Workgroup> <Reduce> %val : bf16 -> bf16
return %0: bf16
}
@@ -206,7 +206,7 @@ func.func @group_non_uniform_fmax_reduce(%val: f32) -> f32 {
// -----
func.func @group_non_uniform_bf16_fmax_reduce(%val: bf16) -> bf16 {
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'bf16'}}
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 of ranks 1, but got 'bf16'}}
%0 = spirv.GroupNonUniformFMax <Workgroup> <Reduce> %val : bf16 -> bf16
return %0: bf16
}
@@ -511,7 +511,7 @@ func.func @group_non_uniform_bitwise_and(%val: i32) -> i32 {
// -----
func.func @group_non_uniform_bitwise_and(%val: i1) -> i1 {
// expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}}
// expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 of ranks 1, but got 'i1'}}
%0 = spirv.GroupNonUniformBitwiseAnd <Workgroup> <Reduce> %val : i1 -> i1
return %0: i1
}
@@ -532,7 +532,7 @@ func.func @group_non_uniform_bitwise_or(%val: i32) -> i32 {
// -----
func.func @group_non_uniform_bitwise_or(%val: i1) -> i1 {
// expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}}
// expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 of ranks 1, but got 'i1'}}
%0 = spirv.GroupNonUniformBitwiseOr <Workgroup> <Reduce> %val : i1 -> i1
return %0: i1
}
@@ -553,7 +553,7 @@ func.func @group_non_uniform_bitwise_xor(%val: i32) -> i32 {
// -----
func.func @group_non_uniform_bitwise_xor(%val: i1) -> i1 {
// expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}}
// expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 of ranks 1, but got 'i1'}}
%0 = spirv.GroupNonUniformBitwiseXor <Workgroup> <Reduce> %val : i1 -> i1
return %0: i1
}
@@ -574,7 +574,7 @@ func.func @group_non_uniform_logical_and(%val: i1) -> i1 {
// -----
func.func @group_non_uniform_logical_and(%val: i32) -> i32 {
// expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}}
// expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16 of ranks 1, but got 'i32'}}
%0 = spirv.GroupNonUniformLogicalAnd <Workgroup> <Reduce> %val : i32 -> i32
return %0: i32
}
@@ -595,7 +595,7 @@ func.func @group_non_uniform_logical_or(%val: i1) -> i1 {
// -----
func.func @group_non_uniform_logical_or(%val: i32) -> i32 {
// expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}}
// expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16 of ranks 1, but got 'i32'}}
%0 = spirv.GroupNonUniformLogicalOr <Workgroup> <Reduce> %val : i32 -> i32
return %0: i32
}
@@ -616,7 +616,7 @@ func.func @group_non_uniform_logical_xor(%val: i1) -> i1 {
// -----
func.func @group_non_uniform_logical_xor(%val: i32) -> i32 {
// expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}}
// expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16 of ranks 1, but got 'i32'}}
%0 = spirv.GroupNonUniformLogicalXor <Workgroup> <Reduce> %val : i32 -> i32
return %0: i32
}
@@ -807,7 +807,7 @@ func.func @group_non_uniform_quad_swap(%value: vector<4xf32>) -> vector<4xf32> {
func.func @group_non_uniform_quad_swap(%value: !spirv.array<3 x i32>) -> !spirv.array<3 x i32> {
%dir = spirv.Constant 0 : i32
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 or 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 or bool or fixed-length vector of bool values of length 2/3/4/8/16, but got '!spirv.array<3 x i32>'}}
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 of ranks 1 or 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 of ranks 1 or bool or fixed-length vector of bool values of length 2/3/4/8/16 of ranks 1, but got '!spirv.array<3 x i32>'}}
%0 = spirv.GroupNonUniformQuadSwap <Device> %value %dir : !spirv.array<3 x i32>, i32
return %0: !spirv.array<3 x i32>
}

View File

@@ -170,7 +170,7 @@ func.func @coop_matrix_const_wrong_type() -> () {
//===----------------------------------------------------------------------===//
func.func @ccr_result_not_composite() -> () {
// expected-error @+1 {{op result #0 must be vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16 or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V tensorArm type, but got 'i32'}}
// expected-error @+1 {{op result #0 must be vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V tensorArm type, but got 'i32'}}
%0 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : i32
return
}