[mlir][spirv] Add SPV_EXT_FP8 type support to SPIR-V TOSA ops (#193199)
Add SPV_EXT_FP8 support for SPIR-V TOSA ops by updating the shared type definitions and extending op constraints for the newly supported element types. Also update verifier coverage to reflect the new constraints: - refresh existing negative tests whose diagnostics now list FP8 types - add negative tests for SPV_EXT_FP8-specific output, weight, accumulator, and cast restrictions Signed-off-by: Davide Grohmann <davide.grohmann@arm.com>
This commit is contained in:
@@ -156,14 +156,20 @@ class SPIRV_TosaConvolutionOp<string mnemonic, int opcode, list<Trait> traits =
|
||||
TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>,
|
||||
TypeConstraintImplicationOn<"input", F16, "output", [F16]>,
|
||||
TypeConstraintImplicationOn<"input", F32, "output", [F32]>,
|
||||
TypeConstraintImplicationOn<"input", F8E4M3FN, "output", [F16]>,
|
||||
TypeConstraintImplicationOn<"input", F8E5M2, "output", [F16]>,
|
||||
TypeConstraintImplicationOn<"input", BF16, "weight", [BF16]>,
|
||||
TypeConstraintImplicationOn<"input", F16, "weight", [F16]>,
|
||||
TypeConstraintImplicationOn<"input", F32, "weight", [F32]>,
|
||||
TypeConstraintImplicationOn<"input", F8E4M3FN, "weight", [F8E4M3FN]>,
|
||||
TypeConstraintImplicationOn<"input", F8E5M2, "weight", [F8E5M2]>,
|
||||
TypeImpliesAccType<"input", I8, ["INT32"]>,
|
||||
TypeImpliesAccType<"input", I16, ["INT48"]>,
|
||||
TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
|
||||
TypeImpliesAccType<"input", BF16, ["FP32"]>,
|
||||
TypeImpliesAccType<"input", F32, ["FP32"]>,
|
||||
TypeImpliesAccType<"input", F8E4M3FN, ["FP16"]>,
|
||||
TypeImpliesAccType<"input", F8E5M2, ["FP16"]>,
|
||||
AllElementTypesMatch<["bias", "output"]>,
|
||||
AllElementTypesMatch<["input", "input_zp"]>,
|
||||
AllElementTypesMatch<["weight", "weight_zp"]>])> {
|
||||
@@ -249,7 +255,7 @@ def SPIRV_TosaArgMaxOp : SPIRV_TosaOpWithResult<"ArgMax", 0, [Pure,
|
||||
let arguments = (ins
|
||||
SPIRV_TensorArmAxisAttr: $axis,
|
||||
SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode,
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm: $input
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm: $input
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
@@ -277,6 +283,8 @@ def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [NoMemoryEffe
|
||||
TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
|
||||
TypeImpliesAccType<"input", BF16, ["FP32"]>,
|
||||
TypeImpliesAccType<"input", F32, ["FP32"]>,
|
||||
TypeImpliesAccType<"input", F8E4M3FN, ["FP16"]>,
|
||||
TypeImpliesAccType<"input", F8E5M2, ["FP16"]>,
|
||||
AllElementTypesMatch<["input", "input_zp", "output", "output_zp"]>]> {
|
||||
let summary = "Performs average pooling on the input.";
|
||||
|
||||
@@ -304,13 +312,13 @@ def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [NoMemoryEffe
|
||||
SPIRV_I32_1DTensorArmOfLength2Attr: $stride,
|
||||
SPIRV_I32_1DTensorArmOfLength4Attr: $pad,
|
||||
SPIRV_TosaExtAccTypeAttr: $acc_type,
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $input,
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $input_zp,
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $output_zp
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $input,
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $input_zp,
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $output_zp
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $output
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $output
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
@@ -361,11 +369,11 @@ def SPIRV_TosaConv2DOp : SPIRV_TosaConvolutionOp<"Conv2D", 2> {
|
||||
SPIRV_I32_1DTensorArmOfLength2Attr: $dilation,
|
||||
SPIRV_TosaExtAccTypeAttr: $acc_type,
|
||||
SPIRV_BoolConstAttr: $local_bound,
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $input,
|
||||
SPIRV_I8OrF16OrF32OrBF16_TensorArm4D: $weight,
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $input,
|
||||
SPIRV_I8OrF16OrF32OrBF16OrFP8_TensorArm4D: $weight,
|
||||
SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm1D: $bias,
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $input_zp,
|
||||
SPIRV_I8OrF16OrF32OrBF16_1DTensorArmOfLength1: $weight_zp
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $input_zp,
|
||||
SPIRV_I8OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $weight_zp
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
@@ -416,11 +424,11 @@ def SPIRV_TosaConv3DOp : SPIRV_TosaConvolutionOp<"Conv3D", 3> {
|
||||
SPIRV_I32_1DTensorArmOfLength3Attr: $dilation,
|
||||
SPIRV_TosaExtAccTypeAttr: $acc_type,
|
||||
SPIRV_BoolConstAttr: $local_bound,
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm5D: $input,
|
||||
SPIRV_I8OrF16OrF32OrBF16_TensorArm5D: $weight,
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm5D: $input,
|
||||
SPIRV_I8OrF16OrF32OrBF16OrFP8_TensorArm5D: $weight,
|
||||
SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm1D: $bias,
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $input_zp,
|
||||
SPIRV_I8OrF16OrF32OrBF16_1DTensorArmOfLength1: $weight_zp
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $input_zp,
|
||||
SPIRV_I8OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $weight_zp
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
@@ -472,11 +480,11 @@ def SPIRV_TosaDepthwiseConv2DOp : SPIRV_TosaConvolutionOp<"DepthwiseConv2D", 4>
|
||||
SPIRV_I32_1DTensorArmOfLength2Attr: $dilation,
|
||||
SPIRV_TosaExtAccTypeAttr: $acc_type,
|
||||
SPIRV_BoolConstAttr: $local_bound,
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $input,
|
||||
SPIRV_I8OrF16OrF32OrBF16_TensorArm4D: $weight,
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $input,
|
||||
SPIRV_I8OrF16OrF32OrBF16OrFP8_TensorArm4D: $weight,
|
||||
SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm1D: $bias,
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $input_zp,
|
||||
SPIRV_I8OrF16OrF32OrBF16_1DTensorArmOfLength1: $weight_zp
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $input_zp,
|
||||
SPIRV_I8OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $weight_zp
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
@@ -557,6 +565,8 @@ def SPIRV_TosaMatMulOp : SPIRV_TosaOpWithResult<"MatMul", 6, [NoMemoryEffect,
|
||||
TypeConstraintImplicationOn<"A", BF16, "output", [F32]>,
|
||||
TypeConstraintImplicationOn<"A", F16, "output", [F16, F32]>,
|
||||
TypeConstraintImplicationOn<"A", F32, "output", [F32]>,
|
||||
TypeConstraintImplicationOn<"A", F8E4M3FN, "output", [F16]>,
|
||||
TypeConstraintImplicationOn<"A", F8E5M2, "output", [F16]>,
|
||||
AllElementTypesMatch<["A", "A_zp", "B", "B_zp"]>]> {
|
||||
let summary = "Matrix Multiplication operator.";
|
||||
|
||||
@@ -579,10 +589,10 @@ def SPIRV_TosaMatMulOp : SPIRV_TosaOpWithResult<"MatMul", 6, [NoMemoryEffect,
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm3D: $A,
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm3D: $B,
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $A_zp,
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $B_zp
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm3D: $A,
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm3D: $B,
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $A_zp,
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $B_zp
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
@@ -634,11 +644,11 @@ def SPIRV_TosaMaxPool2DOp : SPIRV_TosaOpWithResult<"MaxPool2D", 7, [Pure,
|
||||
SPIRV_I32_1DTensorArmOfLength2Attr: $stride,
|
||||
SPIRV_I32_1DTensorArmOfLength4Attr: $pad,
|
||||
SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode,
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $input
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $input
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $output
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $output
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
@@ -734,11 +744,11 @@ def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaConvolutionOp<"TransposeConv2D", 9>
|
||||
SPIRV_I32_1DTensorArmOfLength2Attr: $stride,
|
||||
SPIRV_TosaExtAccTypeAttr: $acc_type,
|
||||
SPIRV_BoolConstAttr: $local_bound,
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $input,
|
||||
SPIRV_I8OrF16OrF32OrBF16_TensorArm4D: $weight,
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $input,
|
||||
SPIRV_I8OrF16OrF32OrBF16OrFP8_TensorArm4D: $weight,
|
||||
SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm1D: $bias,
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $input_zp,
|
||||
SPIRV_I8OrF16OrF32OrBF16_1DTensorArmOfLength1: $weight_zp
|
||||
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $input_zp,
|
||||
SPIRV_I8OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $weight_zp
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
@@ -2167,11 +2177,11 @@ def SPIRV_TosaConcatOp : SPIRV_TosaOpWithResult<"Concat", 54, [Pure,
|
||||
|
||||
let arguments = (ins
|
||||
SPIRV_TensorArmAxisAttr: $axis,
|
||||
Variadic<SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm>: $input1
|
||||
Variadic<SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm>: $input1
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
|
||||
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $output
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
@@ -2214,13 +2224,13 @@ def SPIRV_TosaPadOp : SPIRV_TosaOpWithResult<"Pad", 55, [Pure,
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1,
|
||||
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $input1,
|
||||
SPIRV_I32_1DTensorArmOfEvenLength2To12: $padding,
|
||||
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_1DTensorArmOfLength1: $pad_const
|
||||
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $pad_const
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
|
||||
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $output
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
@@ -2267,12 +2277,12 @@ def SPIRV_TosaReshapeOp : SPIRV_TosaOpWithResult<"Reshape", 56, [Pure,
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1,
|
||||
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $input1,
|
||||
SPIRV_I32_1DTensorArmOfLength1To6: $shape
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
|
||||
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $output
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
@@ -2315,11 +2325,11 @@ def SPIRV_TosaReverseOp : SPIRV_TosaOpWithResult<"Reverse", 57, [Pure,
|
||||
|
||||
let arguments = (ins
|
||||
SPIRV_TensorArmAxisAttr: $axis,
|
||||
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1
|
||||
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $input1
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
|
||||
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $output
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
@@ -2362,13 +2372,13 @@ def SPIRV_TosaSliceOp : SPIRV_TosaOpWithResult<"Slice", 58, [Pure,
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1,
|
||||
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $input1,
|
||||
SPIRV_I32_1DTensorArmOfLength1To6: $start,
|
||||
SPIRV_I32_1DTensorArmOfLength1To6: $size
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
|
||||
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $output
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
@@ -2416,12 +2426,12 @@ def SPIRV_TosaTileOp : SPIRV_TosaOpWithResult<"Tile", 59, [Pure,
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1,
|
||||
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $input1,
|
||||
SPIRV_I32_1DTensorArmOfLength1To6: $multiples
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
|
||||
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $output
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
@@ -2466,11 +2476,11 @@ def SPIRV_TosaTransposeOp : SPIRV_TosaOpWithResult<"Transpose", 60, [Pure,
|
||||
|
||||
let arguments = (ins
|
||||
SPIRV_I32_1DTensorArmOfLength1To6Attr: $perms,
|
||||
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1
|
||||
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $input1
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
|
||||
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $output
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
@@ -2512,12 +2522,12 @@ def SPIRV_TosaGatherOp : SPIRV_TosaOpWithResult<"Gather", 61, [NoMemoryEffect,
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D: $values,
|
||||
SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm3D: $values,
|
||||
SPIRV_I32_TensorArm2D: $indices
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D: $output
|
||||
SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm3D: $output
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
@@ -2566,13 +2576,13 @@ def SPIRV_TosaScatterOp : SPIRV_TosaOpWithResult<"Scatter", 62, [NoMemoryEffect,
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D: $values_in,
|
||||
SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm3D: $values_in,
|
||||
SPIRV_I32_TensorArm2D: $indices,
|
||||
SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D: $input
|
||||
SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm3D: $input
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D: $values_out
|
||||
SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm3D: $values_out
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
@@ -2687,13 +2697,15 @@ def SPIRV_TosaResizeOp : SPIRV_TosaOpWithResult<"Resize", 63, [Pure,
|
||||
|
||||
def SPIRV_TosaCastOp : SPIRV_TosaOpWithResult<"Cast", 64, [Pure,
|
||||
AllShapesMatch<["input", "output"]>,
|
||||
TypeConstraintImplicationOn<"input", F16, "output", [F32, I16, I32, I8]>,
|
||||
TypeConstraintImplicationOn<"input", F32, "output", [F16, I16, I32, I8, BF16]>,
|
||||
TypeConstraintImplicationOn<"input", F16, "output", [F32, I16, I32, I8, F8E4M3FN, F8E5M2]>,
|
||||
TypeConstraintImplicationOn<"input", F32, "output", [F16, I16, I32, I8, BF16, F8E4M3FN, F8E5M2]>,
|
||||
TypeConstraintImplicationOn<"input", I16, "output", [F16, F32, I32, I8, SPIRV_Bool, BF16]>,
|
||||
TypeConstraintImplicationOn<"input", I32, "output", [F16, F32, I16, I8, SPIRV_Bool, BF16]>,
|
||||
TypeConstraintImplicationOn<"input", I8, "output", [F16, F32, I16, I32, SPIRV_Bool, BF16]>,
|
||||
TypeConstraintImplicationOn<"input", SPIRV_Bool, "output", [I16, I32, I8]>,
|
||||
TypeConstraintImplicationOn<"input", BF16, "output", [F32, I16, I32, I8]>]> {
|
||||
TypeConstraintImplicationOn<"input", BF16, "output", [F32, I16, I32, I8, F8E4M3FN, F8E5M2]>,
|
||||
TypeConstraintImplicationOn<"input", F8E4M3FN, "output", [F16, F32, BF16]>,
|
||||
TypeConstraintImplicationOn<"input", F8E5M2, "output", [F16, F32, BF16]>]> {
|
||||
let summary = "Cast operation.";
|
||||
|
||||
let description = [{
|
||||
@@ -2737,6 +2749,18 @@ def SPIRV_TosaCastOp : SPIRV_TosaOpWithResult<"Cast", 64, [Pure,
|
||||
| int16 | bf16 |
|
||||
| int32 | bf16 |
|
||||
| int8 | bf16 |
|
||||
| bf16 | fp8e4m3 |
|
||||
| fp8e4m3 | bf16 |
|
||||
| bf16 | fp8e5m2 |
|
||||
| fp8e5m2 | bf16 |
|
||||
| float16 | fp8e4m3 |
|
||||
| float32 | fp8e4m3 |
|
||||
| fp8e4m3 | float16 |
|
||||
| fp8e4m3 | float32 |
|
||||
| float16 | fp8e5m2 |
|
||||
| float32 | fp8e5m2 |
|
||||
| fp8e5m2 | float16 |
|
||||
| fp8e5m2 | float32 |
|
||||
|
||||
References:
|
||||
* https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_cast
|
||||
@@ -2750,11 +2774,11 @@ def SPIRV_TosaCastOp : SPIRV_TosaOpWithResult<"Cast", 64, [Pure,
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
SPIRV_BoolOrI8OrI16OrI32OrBF16OrF16OrF32_TensorArm: $input
|
||||
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $input
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SPIRV_I8OrI16OrI32OrBoolOrF16OrF32OrBF16_TensorArm: $output
|
||||
SPIRV_I8OrI16OrI32OrBoolOrF16OrF32OrBF16OrFP8_TensorArm: $output
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
|
||||
@@ -22,16 +22,19 @@ def SPIRV_I8OrI16OrI32OrI64 : AnyIntOfWidths<[8, 16, 32, 64]>;
|
||||
def SPIRV_I16OrI32 : AnyIntOfWidths<[16, 32]>;
|
||||
def SPIRV_I32OrI64 : AnyIntOfWidths<[32, 64]>;
|
||||
def SPIRV_F16OrF32OrBF16 : AnyTypeOf<[SPIRV_Float16, SPIRV_Float32, SPIRV_BFloat16KHR]>;
|
||||
def SPIRV_I8OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_Int8, SPIRV_F16OrF32OrBF16]>;
|
||||
def SPIRV_F16OrF32OrBF16OrFP8 : AnyTypeOf<[SPIRV_Float16, SPIRV_Float32, SPIRV_BFloat16KHR, SPIRV_Float8E4M3EXT, SPIRV_Float8E5M2EXT]>;
|
||||
def SPIRV_I8OrI16OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_I8OrI16, SPIRV_F16OrF32OrBF16]>;
|
||||
def SPIRV_I8OrF16OrF32OrBF16OrFP8 : AnyTypeOf<[SPIRV_Int8, SPIRV_F16OrF32OrBF16OrFP8]>;
|
||||
def SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8 : AnyTypeOf<[SPIRV_I8OrI16, SPIRV_F16OrF32OrBF16OrFP8]>;
|
||||
def SPIRV_I32OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_Int32, SPIRV_F16OrF32OrBF16]>;
|
||||
def SPIRV_I8OrI16OrI32OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_I8OrI16OrI32, SPIRV_F16OrF32OrBF16]>;
|
||||
def SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrFP8 : AnyTypeOf<[SPIRV_I8OrI16OrI32, SPIRV_F16OrF32OrBF16OrFP8]>;
|
||||
def SPIRV_I32OrI64OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_I32OrI64, SPIRV_F16OrF32OrBF16]>;
|
||||
def SPIRV_I32OrI64OrF16OrF32 : AnyTypeOf<[SPIRV_I32OrI64, SPIRV_Float16, SPIRV_Float32]>;
|
||||
def SPIRV_I32OrI8OrI64OrI16OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_I8OrI16OrI32OrI64, SPIRV_F16OrF32OrBF16]>;
|
||||
def SPIRV_BoolOrI8OrI16OrI32OrBF16OrF16OrF32 : AnyTypeOf<[SPIRV_Bool, SPIRV_I8OrI16OrI32, SPIRV_F16OrF32OrBF16]>;
|
||||
def SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_Bool, SPIRV_I8OrI16OrI32, SPIRV_F16OrF32OrBF16]>;
|
||||
def SPIRV_I8OrI16OrI32OrBoolOrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_I8OrI16OrI32, SPIRV_Bool, SPIRV_F16OrF32OrBF16]>;
|
||||
def SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8 : AnyTypeOf<[SPIRV_Bool, SPIRV_I8OrI16OrI32, SPIRV_F16OrF32OrBF16OrFP8]>;
|
||||
def SPIRV_I8OrI16OrI32OrBoolOrF16OrF32OrBF16OrFP8 : AnyTypeOf<[SPIRV_I8OrI16OrI32, SPIRV_Bool, SPIRV_F16OrF32OrBF16OrFP8]>;
|
||||
def SPIRV_I8OrI32 : AnyTypeOf<[SPIRV_Int8, SPIRV_Int32]>;
|
||||
|
||||
def SPIRV_TensorArmAxisAttr : ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<5>]>;
|
||||
@@ -57,23 +60,25 @@ def SPIRV_I32_TensorArm2D : TensorArmRankOf<[SPIRV_Int32], [2]>;
|
||||
def SPIRV_F32_TensorArm3D: TensorArmRankOf<[SPIRV_Float32], [3]>;
|
||||
def SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm1D : TensorArmRankOf<[SPIRV_I32OrI64OrF16OrF32OrBF16], [1]>;
|
||||
def SPIRV_I8OrI16_TensorArm1D : TensorArmRankOf<[SPIRV_I8OrI16], [1]>;
|
||||
def SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm3D : TensorArmRankOf<[SPIRV_I8OrI16OrF16OrF32OrBF16], [3]>;
|
||||
def SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D : TensorArmRankOf<[SPIRV_I8OrI16OrI32OrF16OrF32OrBF16], [3]>;
|
||||
def SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm3D : TensorArmRankOf<[SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8], [3]>;
|
||||
def SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm3D : TensorArmRankOf<[SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrFP8], [3]>;
|
||||
def SPIRV_I32OrI64OrF16OrF32_TensorArm3D : TensorArmRankOf<[SPIRV_I32OrI64OrF16OrF32], [3]>;
|
||||
def SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D : TensorArmRankOf<[SPIRV_I8OrI16OrF16OrF32OrBF16], [4]>;
|
||||
def SPIRV_I8OrF16OrF32OrBF16_TensorArm4D : TensorArmRankOf<[SPIRV_I8OrF16OrF32OrBF16], [4]>;
|
||||
def SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D : TensorArmRankOf<[SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8], [4]>;
|
||||
def SPIRV_I8OrF16OrF32OrBF16OrFP8_TensorArm4D : TensorArmRankOf<[SPIRV_I8OrF16OrF32OrBF16OrFP8], [4]>;
|
||||
def SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm4D : TensorArmRankOf<[SPIRV_I32OrI64OrF16OrF32OrBF16], [4]>;
|
||||
def SPIRV_I32OrI8OrI64OrI16OrF16OrF32OrBF16_TensorArm4D : TensorArmRankOf<[SPIRV_I32OrI8OrI64OrI16OrF16OrF32OrBF16], [4]>;
|
||||
def SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm5D : TensorArmRankOf<[SPIRV_I8OrI16OrF16OrF32OrBF16], [5]>;
|
||||
def SPIRV_I8OrF16OrF32OrBF16_TensorArm5D : TensorArmRankOf<[SPIRV_I8OrF16OrF32OrBF16], [5]>;
|
||||
def SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm5D : TensorArmRankOf<[SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8], [5]>;
|
||||
def SPIRV_I8OrF16OrF32OrBF16OrFP8_TensorArm5D : TensorArmRankOf<[SPIRV_I8OrF16OrF32OrBF16OrFP8], [5]>;
|
||||
def SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm5D : TensorArmRankOf<[SPIRV_I32OrI64OrF16OrF32OrBF16], [5]>;
|
||||
|
||||
def SPIRV_BoolOrI8OrI16OrI32OrBF16OrF16OrF32_TensorArm : TensorArmRankOf<[SPIRV_BoolOrI8OrI16OrI32OrBF16OrF16OrF32], [1, 2, 3, 4, 5, 6]>;
|
||||
def SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm : TensorArmRankOf<[SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16], [1, 2, 3, 4, 5, 6]>;
|
||||
def SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm : TensorArmRankOf<[SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8], [1, 2, 3, 4, 5, 6]>;
|
||||
def SPIRV_F16OrF32OrBF16_TensorArm : TensorArmRankOf<[SPIRV_F16OrF32OrBF16], [1, 2, 3, 4, 5, 6]>;
|
||||
def SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm : TensorArmRankOf<[SPIRV_I8OrI16OrF16OrF32OrBF16], [1, 2, 3, 4, 5, 6]>;
|
||||
def SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm : TensorArmRankOf<[SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8], [1, 2, 3, 4, 5, 6]>;
|
||||
def SPIRV_I32OrF16OrF32OrBF16_TensorArm : TensorArmRankOf<[SPIRV_I32OrF16OrF32OrBF16], [1, 2, 3, 4, 5, 6]>;
|
||||
def SPIRV_I8OrI16OrI32OrBoolOrF16OrF32OrBF16_TensorArm : TensorArmRankOf<[SPIRV_I8OrI16OrI32OrBoolOrF16OrF32OrBF16], [1, 2, 3, 4, 5, 6]>;
|
||||
def SPIRV_I8OrI16OrI32OrBoolOrF16OrF32OrBF16OrFP8_TensorArm : TensorArmRankOf<[SPIRV_I8OrI16OrI32OrBoolOrF16OrF32OrBF16OrFP8], [1, 2, 3, 4, 5, 6]>;
|
||||
def SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm : TensorArmRankOf<[SPIRV_I8OrI16OrI32OrF16OrF32OrBF16], [1, 2, 3, 4, 5, 6]>;
|
||||
def SPIRV_I8OrI16OrI32OrI64_TensorArm : TensorArmRankOf<[SPIRV_I8OrI16OrI32OrI64], [1, 2, 3, 4, 5, 6]>;
|
||||
def SPIRV_I8OrI16OrI32_TensorArm : TensorArmRankOf<[SPIRV_I8OrI16OrI32], [1, 2, 3, 4, 5, 6]>;
|
||||
@@ -121,12 +126,12 @@ def SPIRV_I32_1DTensorArmOfLength1To6Attr : ConfinedAttr<
|
||||
I32ElementsAttr, [SPIRV_DenseElementAttrsWithTensorArmType, Is1DTensorArmAttrOfLength<[1, 2, 3, 4, 5, 6]>]>;
|
||||
|
||||
def SPIRV_I8_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_Int8]>;
|
||||
def SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_I8OrI16OrF16OrF32OrBF16]>;
|
||||
def SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8]>;
|
||||
def SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_I8OrI16OrI32OrF16OrF32OrBF16]>;
|
||||
def SPIRV_I8OrI16OrI32OrI64_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_I8OrI16OrI32OrI64]>;
|
||||
def SPIRV_I8OrI16OrI32_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_I8OrI16OrI32]>;
|
||||
def SPIRV_I8OrF16OrF32OrBF16_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_I8OrF16OrF32OrBF16]>;
|
||||
def SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16]>;
|
||||
def SPIRV_I8OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_I8OrF16OrF32OrBF16OrFP8]>;
|
||||
def SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8]>;
|
||||
|
||||
// Struct type
|
||||
|
||||
|
||||
@@ -75,6 +75,22 @@ spirv.ARM.Graph @avgpool2d_accumulator_should_be_either_FP32_for_fp32_element_ty
|
||||
spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x65532x2xf32>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @avgpool2d_accumulator_must_be_FP16_for_f8e4m3fn_element_type(%arg0: !spirv.arm.tensor<1x2x2x2xf8E4M3FN>) -> (!spirv.arm.tensor<1x2x2x2xf8E4M3FN>) {
|
||||
%4 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
|
||||
%5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
|
||||
// expected-error @+1 {{op failed to verify that acc_type must be one in [FP16] when type has value f8E4M3FN type}}
|
||||
%6 = spirv.Tosa.AvgPool2D kernel = [1, 1], stride = [1, 1], pad = [0, 0, 0, 0], acc_type = <FP32>, %arg0, %4, %5 : !spirv.arm.tensor<1x2x2x2xf8E4M3FN>, !spirv.arm.tensor<1xf8E4M3FN>, !spirv.arm.tensor<1xf8E4M3FN> -> !spirv.arm.tensor<1x2x2x2xf8E4M3FN>
|
||||
spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x2x2xf8E4M3FN>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @avgpool2d_accumulator_must_be_FP16_for_f8e5m2_element_type(%arg0: !spirv.arm.tensor<1x2x2x2xf8E5M2>) -> (!spirv.arm.tensor<1x2x2x2xf8E5M2>) {
|
||||
%4 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
|
||||
%5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
|
||||
// expected-error @+1 {{op failed to verify that acc_type must be one in [FP16] when type has value f8E5M2 type}}
|
||||
%6 = spirv.Tosa.AvgPool2D kernel = [1, 1], stride = [1, 1], pad = [0, 0, 0, 0], acc_type = <FP32>, %arg0, %4, %5 : !spirv.arm.tensor<1x2x2x2xf8E5M2>, !spirv.arm.tensor<1xf8E5M2>, !spirv.arm.tensor<1xf8E5M2> -> !spirv.arm.tensor<1x2x2x2xf8E5M2>
|
||||
spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x2x2xf8E5M2>
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.TOSA.Conv2D
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -151,6 +167,54 @@ spirv.ARM.Graph @conv2d_accumulator_must_be_either_FP32_for_f32_input_element_ty
|
||||
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf32>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @conv2d_mismatch_result_element_type_f8e5m2_input(%arg0: !spirv.arm.tensor<1x4x4x4xf8E5M2>, %arg1: !spirv.arm.tensor<8x1x1x4xf8E5M2>, %arg2: !spirv.arm.tensor<8xf32>) -> (!spirv.arm.tensor<1x4x4x8xf32>) {
|
||||
%5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
|
||||
%6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
|
||||
// expected-error @+1 {{op failed to verify that if input has type f8E5M2 type then output must have a type in [16-bit float]}}
|
||||
%7 = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP16>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x4x4xf8E5M2>, !spirv.arm.tensor<8x1x1x4xf8E5M2>, !spirv.arm.tensor<8xf32>, !spirv.arm.tensor<1xf8E5M2>, !spirv.arm.tensor<1xf8E5M2> -> !spirv.arm.tensor<1x4x4x8xf32>
|
||||
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x4x4x8xf32>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @conv2d_mismatch_result_element_type_f8e4m3fn_input(%arg0: !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, %arg1: !spirv.arm.tensor<8x1x1x4xf8E4M3FN>, %arg2: !spirv.arm.tensor<8xf32>) -> (!spirv.arm.tensor<1x4x4x8xf32>) {
|
||||
%5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
|
||||
%6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
|
||||
// expected-error @+1 {{op failed to verify that if input has type f8E4M3FN type then output must have a type in [16-bit float]}}
|
||||
%7 = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP16>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, !spirv.arm.tensor<8x1x1x4xf8E4M3FN>, !spirv.arm.tensor<8xf32>, !spirv.arm.tensor<1xf8E4M3FN>, !spirv.arm.tensor<1xf8E4M3FN> -> !spirv.arm.tensor<1x4x4x8xf32>
|
||||
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x4x4x8xf32>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @conv2d_weight_element_type_must_match_f8e4m3fn_input(%arg0: !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, %arg1: !spirv.arm.tensor<8x1x1x4xf8E5M2>, %arg2: !spirv.arm.tensor<8xf16>) -> (!spirv.arm.tensor<1x4x4x8xf16>) {
|
||||
%5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
|
||||
%6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
|
||||
// expected-error @+1 {{op failed to verify that if input has type f8E4M3FN type then weight must have a type in [f8E4M3FN type]}}
|
||||
%7 = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP16>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, !spirv.arm.tensor<8x1x1x4xf8E5M2>, !spirv.arm.tensor<8xf16>, !spirv.arm.tensor<1xf8E4M3FN>, !spirv.arm.tensor<1xf8E5M2> -> !spirv.arm.tensor<1x4x4x8xf16>
|
||||
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x4x4x8xf16>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @conv2d_weight_element_type_must_match_f8e5m2_input(%arg0: !spirv.arm.tensor<1x4x4x4xf8E5M2>, %arg1: !spirv.arm.tensor<8x1x1x4xf8E4M3FN>, %arg2: !spirv.arm.tensor<8xf16>) -> (!spirv.arm.tensor<1x4x4x8xf16>) {
|
||||
%5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
|
||||
%6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
|
||||
// expected-error @+1 {{op failed to verify that if input has type f8E5M2 type then weight must have a type in [f8E5M2 type]}}
|
||||
%7 = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP16>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x4x4xf8E5M2>, !spirv.arm.tensor<8x1x1x4xf8E4M3FN>, !spirv.arm.tensor<8xf16>, !spirv.arm.tensor<1xf8E5M2>, !spirv.arm.tensor<1xf8E4M3FN> -> !spirv.arm.tensor<1x4x4x8xf16>
|
||||
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x4x4x8xf16>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @conv2d_accumulator_must_be_FP16_for_f8e4m3fn_input_element_type(%arg0: !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, %arg1: !spirv.arm.tensor<8x1x1x4xf8E4M3FN>, %arg2: !spirv.arm.tensor<8xf16>) -> (!spirv.arm.tensor<1x4x4x8xf16>) {
|
||||
%5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
|
||||
%6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
|
||||
// expected-error @+1 {{op failed to verify that acc_type must be one in [FP16] when type has value f8E4M3FN type}}
|
||||
%7 = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, !spirv.arm.tensor<8x1x1x4xf8E4M3FN>, !spirv.arm.tensor<8xf16>, !spirv.arm.tensor<1xf8E4M3FN>, !spirv.arm.tensor<1xf8E4M3FN> -> !spirv.arm.tensor<1x4x4x8xf16>
|
||||
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x4x4x8xf16>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @conv2d_accumulator_must_be_FP16_for_f8e5m2_input_element_type(%arg0: !spirv.arm.tensor<1x4x4x4xf8E5M2>, %arg1: !spirv.arm.tensor<8x1x1x4xf8E5M2>, %arg2: !spirv.arm.tensor<8xf16>) -> (!spirv.arm.tensor<1x4x4x8xf16>) {
|
||||
%5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
|
||||
%6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
|
||||
// expected-error @+1 {{op failed to verify that acc_type must be one in [FP16] when type has value f8E5M2 type}}
|
||||
%7 = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x4x4xf8E5M2>, !spirv.arm.tensor<8x1x1x4xf8E5M2>, !spirv.arm.tensor<8xf16>, !spirv.arm.tensor<1xf8E5M2>, !spirv.arm.tensor<1xf8E5M2> -> !spirv.arm.tensor<1x4x4x8xf16>
|
||||
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x4x4x8xf16>
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.TOSA.Conv3D
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -227,6 +291,54 @@ spirv.ARM.Graph @conv3d_accumulator_must_be_either_FP32_for_f32_input_element_ty
|
||||
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11x1xf32>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @conv3d_weight_element_type_must_match_f8e4m3fn_input(%arg0: !spirv.arm.tensor<1x4x4x4x4xf8E4M3FN>, %arg1: !spirv.arm.tensor<8x1x1x1x4xf8E5M2>, %arg2: !spirv.arm.tensor<8xf16>) -> (!spirv.arm.tensor<1x4x4x4x8xf16>) {
|
||||
%5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
|
||||
%6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
|
||||
// expected-error @+1 {{op failed to verify that if input has type f8E4M3FN type then weight must have a type in [f8E4M3FN type]}}
|
||||
%7 = spirv.Tosa.Conv3D pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1], dilation = [1, 1, 1], acc_type = <FP16>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x4x4x4xf8E4M3FN>, !spirv.arm.tensor<8x1x1x1x4xf8E5M2>, !spirv.arm.tensor<8xf16>, !spirv.arm.tensor<1xf8E4M3FN>, !spirv.arm.tensor<1xf8E5M2> -> !spirv.arm.tensor<1x4x4x4x8xf16>
|
||||
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x4x4x4x8xf16>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @conv3d_mismatch_result_element_type_f8e4m3fn_input(%arg0: !spirv.arm.tensor<1x4x4x4x4xf8E4M3FN>, %arg1: !spirv.arm.tensor<8x1x1x1x4xf8E4M3FN>, %arg2: !spirv.arm.tensor<8xf32>) -> (!spirv.arm.tensor<1x4x4x4x8xf32>) {
|
||||
%5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
|
||||
%6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
|
||||
// expected-error @+1 {{op failed to verify that if input has type f8E4M3FN type then output must have a type in [16-bit float]}}
|
||||
%7 = spirv.Tosa.Conv3D pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1], dilation = [1, 1, 1], acc_type = <FP16>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x4x4x4xf8E4M3FN>, !spirv.arm.tensor<8x1x1x1x4xf8E4M3FN>, !spirv.arm.tensor<8xf32>, !spirv.arm.tensor<1xf8E4M3FN>, !spirv.arm.tensor<1xf8E4M3FN> -> !spirv.arm.tensor<1x4x4x4x8xf32>
|
||||
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x4x4x4x8xf32>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @conv3d_mismatch_result_element_type_f8e5m2_input(%arg0: !spirv.arm.tensor<1x4x4x4x4xf8E5M2>, %arg1: !spirv.arm.tensor<8x1x1x1x4xf8E5M2>, %arg2: !spirv.arm.tensor<8xf32>) -> (!spirv.arm.tensor<1x4x4x4x8xf32>) {
|
||||
%5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
|
||||
%6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
|
||||
// expected-error @+1 {{op failed to verify that if input has type f8E5M2 type then output must have a type in [16-bit float]}}
|
||||
%7 = spirv.Tosa.Conv3D pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1], dilation = [1, 1, 1], acc_type = <FP16>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x4x4x4xf8E5M2>, !spirv.arm.tensor<8x1x1x1x4xf8E5M2>, !spirv.arm.tensor<8xf32>, !spirv.arm.tensor<1xf8E5M2>, !spirv.arm.tensor<1xf8E5M2> -> !spirv.arm.tensor<1x4x4x4x8xf32>
|
||||
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x4x4x4x8xf32>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @conv3d_weight_element_type_must_match_f8e5m2_input(%arg0: !spirv.arm.tensor<1x4x4x4x4xf8E5M2>, %arg1: !spirv.arm.tensor<8x1x1x1x4xf8E4M3FN>, %arg2: !spirv.arm.tensor<8xf16>) -> (!spirv.arm.tensor<1x4x4x4x8xf16>) {
|
||||
%5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
|
||||
%6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
|
||||
// expected-error @+1 {{op failed to verify that if input has type f8E5M2 type then weight must have a type in [f8E5M2 type]}}
|
||||
%7 = spirv.Tosa.Conv3D pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1], dilation = [1, 1, 1], acc_type = <FP16>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x4x4x4xf8E5M2>, !spirv.arm.tensor<8x1x1x1x4xf8E4M3FN>, !spirv.arm.tensor<8xf16>, !spirv.arm.tensor<1xf8E5M2>, !spirv.arm.tensor<1xf8E4M3FN> -> !spirv.arm.tensor<1x4x4x4x8xf16>
|
||||
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x4x4x4x8xf16>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @conv3d_accumulator_must_be_FP16_for_f8e4m3fn_input_element_type(%arg0: !spirv.arm.tensor<1x4x4x4x4xf8E4M3FN>, %arg1: !spirv.arm.tensor<8x1x1x1x4xf8E4M3FN>, %arg2: !spirv.arm.tensor<8xf16>) -> (!spirv.arm.tensor<1x4x4x4x8xf16>) {
|
||||
%5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
|
||||
%6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
|
||||
// expected-error @+1 {{op failed to verify that acc_type must be one in [FP16] when type has value f8E4M3FN type}}
|
||||
%7 = spirv.Tosa.Conv3D pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1], dilation = [1, 1, 1], acc_type = <FP32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x4x4x4xf8E4M3FN>, !spirv.arm.tensor<8x1x1x1x4xf8E4M3FN>, !spirv.arm.tensor<8xf16>, !spirv.arm.tensor<1xf8E4M3FN>, !spirv.arm.tensor<1xf8E4M3FN> -> !spirv.arm.tensor<1x4x4x4x8xf16>
|
||||
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x4x4x4x8xf16>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @conv3d_accumulator_must_be_FP16_for_f8e5m2_input_element_type(%arg0: !spirv.arm.tensor<1x4x4x4x4xf8E5M2>, %arg1: !spirv.arm.tensor<8x1x1x1x4xf8E5M2>, %arg2: !spirv.arm.tensor<8xf16>) -> (!spirv.arm.tensor<1x4x4x4x8xf16>) {
|
||||
%5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
|
||||
%6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
|
||||
// expected-error @+1 {{op failed to verify that acc_type must be one in [FP16] when type has value f8E5M2 type}}
|
||||
%7 = spirv.Tosa.Conv3D pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1], dilation = [1, 1, 1], acc_type = <FP32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x4x4x4xf8E5M2>, !spirv.arm.tensor<8x1x1x1x4xf8E5M2>, !spirv.arm.tensor<8xf16>, !spirv.arm.tensor<1xf8E5M2>, !spirv.arm.tensor<1xf8E5M2> -> !spirv.arm.tensor<1x4x4x4x8xf16>
|
||||
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x4x4x4x8xf16>
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.TOSA.DepthwiseConv2D
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -303,6 +415,54 @@ spirv.ARM.Graph @depthwise_conv2d_accumulator_must_be_either_FP32_for_f32_input_
|
||||
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf32>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @depthwise_conv2d_accumulator_must_be_FP16_for_f8e5m2_input_element_type(%arg0: !spirv.arm.tensor<1x4x4x4xf8E5M2>, %arg1: !spirv.arm.tensor<1x1x4x2xf8E5M2>, %arg2: !spirv.arm.tensor<8xf16>) -> (!spirv.arm.tensor<1x4x4x8xf16>) {
|
||||
%5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
|
||||
%6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
|
||||
// expected-error @+1 {{op failed to verify that acc_type must be one in [FP16] when type has value f8E5M2 type}}
|
||||
%7 = spirv.Tosa.DepthwiseConv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x4x4xf8E5M2>, !spirv.arm.tensor<1x1x4x2xf8E5M2>, !spirv.arm.tensor<8xf16>, !spirv.arm.tensor<1xf8E5M2>, !spirv.arm.tensor<1xf8E5M2> -> !spirv.arm.tensor<1x4x4x8xf16>
|
||||
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x4x4x8xf16>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @depthwise_conv2d_mismatch_result_element_type_f8e4m3fn_input(%arg0: !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, %arg1: !spirv.arm.tensor<1x1x4x2xf8E4M3FN>, %arg2: !spirv.arm.tensor<8xf32>) -> (!spirv.arm.tensor<1x4x4x8xf32>) {
|
||||
%5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
|
||||
%6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
|
||||
// expected-error @+1 {{op failed to verify that if input has type f8E4M3FN type then output must have a type in [16-bit float]}}
|
||||
%7 = spirv.Tosa.DepthwiseConv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP16>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, !spirv.arm.tensor<1x1x4x2xf8E4M3FN>, !spirv.arm.tensor<8xf32>, !spirv.arm.tensor<1xf8E4M3FN>, !spirv.arm.tensor<1xf8E4M3FN> -> !spirv.arm.tensor<1x4x4x8xf32>
|
||||
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x4x4x8xf32>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @depthwise_conv2d_mismatch_result_element_type_f8e5m2_input(%arg0: !spirv.arm.tensor<1x4x4x4xf8E5M2>, %arg1: !spirv.arm.tensor<1x1x4x2xf8E5M2>, %arg2: !spirv.arm.tensor<8xf32>) -> (!spirv.arm.tensor<1x4x4x8xf32>) {
|
||||
%5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
|
||||
%6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
|
||||
// expected-error @+1 {{op failed to verify that if input has type f8E5M2 type then output must have a type in [16-bit float]}}
|
||||
%7 = spirv.Tosa.DepthwiseConv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP16>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x4x4xf8E5M2>, !spirv.arm.tensor<1x1x4x2xf8E5M2>, !spirv.arm.tensor<8xf32>, !spirv.arm.tensor<1xf8E5M2>, !spirv.arm.tensor<1xf8E5M2> -> !spirv.arm.tensor<1x4x4x8xf32>
|
||||
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x4x4x8xf32>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @depthwise_conv2d_weight_element_type_must_match_f8e4m3fn_input(%arg0: !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, %arg1: !spirv.arm.tensor<1x1x4x2xf8E5M2>, %arg2: !spirv.arm.tensor<8xf16>) -> (!spirv.arm.tensor<1x4x4x8xf16>) {
|
||||
%5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
|
||||
%6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
|
||||
// expected-error @+1 {{op failed to verify that if input has type f8E4M3FN type then weight must have a type in [f8E4M3FN type]}}
|
||||
%7 = spirv.Tosa.DepthwiseConv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP16>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, !spirv.arm.tensor<1x1x4x2xf8E5M2>, !spirv.arm.tensor<8xf16>, !spirv.arm.tensor<1xf8E4M3FN>, !spirv.arm.tensor<1xf8E5M2> -> !spirv.arm.tensor<1x4x4x8xf16>
|
||||
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x4x4x8xf16>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @depthwise_conv2d_weight_element_type_must_match_f8e5m2_input(%arg0: !spirv.arm.tensor<1x4x4x4xf8E5M2>, %arg1: !spirv.arm.tensor<1x1x4x2xf8E4M3FN>, %arg2: !spirv.arm.tensor<8xf16>) -> (!spirv.arm.tensor<1x4x4x8xf16>) {
|
||||
%5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
|
||||
%6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
|
||||
// expected-error @+1 {{op failed to verify that if input has type f8E5M2 type then weight must have a type in [f8E5M2 type]}}
|
||||
%7 = spirv.Tosa.DepthwiseConv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP16>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x4x4xf8E5M2>, !spirv.arm.tensor<1x1x4x2xf8E4M3FN>, !spirv.arm.tensor<8xf16>, !spirv.arm.tensor<1xf8E5M2>, !spirv.arm.tensor<1xf8E4M3FN> -> !spirv.arm.tensor<1x4x4x8xf16>
|
||||
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x4x4x8xf16>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @depthwise_conv2d_accumulator_must_be_FP16_for_f8e4m3fn_input_element_type(%arg0: !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, %arg1: !spirv.arm.tensor<1x1x4x2xf8E4M3FN>, %arg2: !spirv.arm.tensor<8xf16>) -> (!spirv.arm.tensor<1x4x4x8xf16>) {
|
||||
%5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
|
||||
%6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
|
||||
// expected-error @+1 {{op failed to verify that acc_type must be one in [FP16] when type has value f8E4M3FN type}}
|
||||
%7 = spirv.Tosa.DepthwiseConv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <FP32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, !spirv.arm.tensor<1x1x4x2xf8E4M3FN>, !spirv.arm.tensor<8xf16>, !spirv.arm.tensor<1xf8E4M3FN>, !spirv.arm.tensor<1xf8E4M3FN> -> !spirv.arm.tensor<1x4x4x8xf16>
|
||||
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x4x4x8xf16>
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.TOSA.MatMul
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -355,6 +515,18 @@ spirv.ARM.Graph @matmul_element_types_must_match_between_input_B_and_B_zero_poin
|
||||
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<1x4x4xi32>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @matmul_mismatch_result_element_type_f8e4m3fn_input(%arg0: !spirv.arm.tensor<1x4x4xf8E4M3FN>, %arg1: !spirv.arm.tensor<1x4x4xf8E4M3FN>, %arg2: !spirv.arm.tensor<1xf8E4M3FN>, %arg3: !spirv.arm.tensor<1xf8E4M3FN>) -> (!spirv.arm.tensor<1x4x4xf32>) {
|
||||
// expected-error @+1 {{op failed to verify that if A has type f8E4M3FN type then output must have a type in [16-bit float]}}
|
||||
%0 = spirv.Tosa.MatMul %arg0, %arg1, %arg2, %arg3 : !spirv.arm.tensor<1x4x4xf8E4M3FN>, !spirv.arm.tensor<1x4x4xf8E4M3FN>, !spirv.arm.tensor<1xf8E4M3FN>, !spirv.arm.tensor<1xf8E4M3FN> -> !spirv.arm.tensor<1x4x4xf32>
|
||||
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<1x4x4xf32>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @matmul_mismatch_result_element_type_f8e5m2_input(%arg0: !spirv.arm.tensor<1x4x4xf8E5M2>, %arg1: !spirv.arm.tensor<1x4x4xf8E5M2>, %arg2: !spirv.arm.tensor<1xf8E5M2>, %arg3: !spirv.arm.tensor<1xf8E5M2>) -> (!spirv.arm.tensor<1x4x4xf32>) {
|
||||
// expected-error @+1 {{op failed to verify that if A has type f8E5M2 type then output must have a type in [16-bit float]}}
|
||||
%0 = spirv.Tosa.MatMul %arg0, %arg1, %arg2, %arg3 : !spirv.arm.tensor<1x4x4xf8E5M2>, !spirv.arm.tensor<1x4x4xf8E5M2>, !spirv.arm.tensor<1xf8E5M2>, !spirv.arm.tensor<1xf8E5M2> -> !spirv.arm.tensor<1x4x4xf32>
|
||||
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<1x4x4xf32>
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.TOSA.MaxPool2D
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -441,6 +613,54 @@ spirv.ARM.Graph @transpose_conv2d_accumulator_must_be_either_FP32_for_f32_input_
|
||||
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf32>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @transpose_conv2d_mismatch_result_element_type_f8e4m3fn_input(%arg0: !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, %arg1: !spirv.arm.tensor<8x1x1x4xf8E4M3FN>, %arg2: !spirv.arm.tensor<8xbf16>) -> (!spirv.arm.tensor<1x4x4x8xbf16>) {
|
||||
%5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
|
||||
%6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
|
||||
// expected-error @+1 {{op failed to verify that if input has type f8E4M3FN type then output must have a type in [16-bit float]}}
|
||||
%7 = spirv.Tosa.TransposeConv2D out_pad = [0, 0, 0, 0], stride = [1, 1], acc_type = <FP16>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, !spirv.arm.tensor<8x1x1x4xf8E4M3FN>, !spirv.arm.tensor<8xbf16>, !spirv.arm.tensor<1xf8E4M3FN>, !spirv.arm.tensor<1xf8E4M3FN> -> !spirv.arm.tensor<1x4x4x8xbf16>
|
||||
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x4x4x8xbf16>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @transpose_conv2d_mismatch_result_element_type_f8e5m2_input(%arg0: !spirv.arm.tensor<1x4x4x4xf8E5M2>, %arg1: !spirv.arm.tensor<8x1x1x4xf8E5M2>, %arg2: !spirv.arm.tensor<8xbf16>) -> (!spirv.arm.tensor<1x4x4x8xbf16>) {
|
||||
%5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
|
||||
%6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
|
||||
// expected-error @+1 {{op failed to verify that if input has type f8E5M2 type then output must have a type in [16-bit float]}}
|
||||
%7 = spirv.Tosa.TransposeConv2D out_pad = [0, 0, 0, 0], stride = [1, 1], acc_type = <FP16>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x4x4xf8E5M2>, !spirv.arm.tensor<8x1x1x4xf8E5M2>, !spirv.arm.tensor<8xbf16>, !spirv.arm.tensor<1xf8E5M2>, !spirv.arm.tensor<1xf8E5M2> -> !spirv.arm.tensor<1x4x4x8xbf16>
|
||||
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x4x4x8xbf16>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @transpose_conv2d_weight_element_type_must_match_f8e4m3fn_input(%arg0: !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, %arg1: !spirv.arm.tensor<8x1x1x4xf8E5M2>, %arg2: !spirv.arm.tensor<8xf16>) -> (!spirv.arm.tensor<1x4x4x8xf16>) {
|
||||
%5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
|
||||
%6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
|
||||
// expected-error @+1 {{op failed to verify that if input has type f8E4M3FN type then weight must have a type in [f8E4M3FN type]}}
|
||||
%7 = spirv.Tosa.TransposeConv2D out_pad = [0, 0, 0, 0], stride = [1, 1], acc_type = <FP16>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, !spirv.arm.tensor<8x1x1x4xf8E5M2>, !spirv.arm.tensor<8xf16>, !spirv.arm.tensor<1xf8E4M3FN>, !spirv.arm.tensor<1xf8E5M2> -> !spirv.arm.tensor<1x4x4x8xf16>
|
||||
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x4x4x8xf16>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @transpose_conv2d_weight_element_type_must_match_f8e5m2_input(%arg0: !spirv.arm.tensor<1x4x4x4xf8E5M2>, %arg1: !spirv.arm.tensor<8x1x1x4xf8E4M3FN>, %arg2: !spirv.arm.tensor<8xf16>) -> (!spirv.arm.tensor<1x4x4x8xf16>) {
|
||||
%5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
|
||||
%6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
|
||||
// expected-error @+1 {{op failed to verify that if input has type f8E5M2 type then weight must have a type in [f8E5M2 type]}}
|
||||
%7 = spirv.Tosa.TransposeConv2D out_pad = [0, 0, 0, 0], stride = [1, 1], acc_type = <FP16>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x4x4xf8E5M2>, !spirv.arm.tensor<8x1x1x4xf8E4M3FN>, !spirv.arm.tensor<8xf16>, !spirv.arm.tensor<1xf8E5M2>, !spirv.arm.tensor<1xf8E4M3FN> -> !spirv.arm.tensor<1x4x4x8xf16>
|
||||
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x4x4x8xf16>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @transpose_conv2d_accumulator_must_be_FP16_for_f8e4m3fn_input_element_type(%arg0: !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, %arg1: !spirv.arm.tensor<8x1x1x4xf8E4M3FN>, %arg2: !spirv.arm.tensor<8xf16>) -> (!spirv.arm.tensor<1x4x4x8xf16>) {
|
||||
%5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
|
||||
%6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E4M3FN>
|
||||
// expected-error @+1 {{op failed to verify that acc_type must be one in [FP16] when type has value f8E4M3FN type}}
|
||||
%7 = spirv.Tosa.TransposeConv2D out_pad = [0, 0, 0, 0], stride = [1, 1], acc_type = <FP32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x4x4xf8E4M3FN>, !spirv.arm.tensor<8x1x1x4xf8E4M3FN>, !spirv.arm.tensor<8xf16>, !spirv.arm.tensor<1xf8E4M3FN>, !spirv.arm.tensor<1xf8E4M3FN> -> !spirv.arm.tensor<1x4x4x8xf16>
|
||||
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x4x4x8xf16>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @transpose_conv2d_accumulator_must_be_FP16_for_f8e5m2_input_element_type(%arg0: !spirv.arm.tensor<1x4x4x4xf8E5M2>, %arg1: !spirv.arm.tensor<8x1x1x4xf8E5M2>, %arg2: !spirv.arm.tensor<8xf16>) -> (!spirv.arm.tensor<1x4x4x8xf16>) {
|
||||
%5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
|
||||
%6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf8E5M2>
|
||||
// expected-error @+1 {{op failed to verify that acc_type must be one in [FP16] when type has value f8E5M2 type}}
|
||||
%7 = spirv.Tosa.TransposeConv2D out_pad = [0, 0, 0, 0], stride = [1, 1], acc_type = <FP32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x4x4xf8E5M2>, !spirv.arm.tensor<8x1x1x4xf8E5M2>, !spirv.arm.tensor<8xf16>, !spirv.arm.tensor<1xf8E5M2>, !spirv.arm.tensor<1xf8E5M2> -> !spirv.arm.tensor<1x4x4x8xf16>
|
||||
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x4x4x8xf16>
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.TOSA.Clamp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -1866,23 +2086,35 @@ spirv.ARM.Graph @cast_input_output_shapes_not_matching(%arg0: !spirv.arm.tensor<
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @cast_f16_to_bf16_not_supported(%arg0: !spirv.arm.tensor<2x3x4xf16>) -> (!spirv.arm.tensor<2x3x4xbf16>) {
|
||||
// expected-error @+1 {{op failed to verify that if input has type 16-bit float then output must have a type in [32-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer]}}
|
||||
// expected-error @+1 {{op failed to verify that if input has type 16-bit float then output must have a type in [32-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer,f8E4M3FN type,f8E5M2 type]}}
|
||||
%0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<2x3x4xf16> -> !spirv.arm.tensor<2x3x4xbf16>
|
||||
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3x4xbf16>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @cast_f16_to_f16_not_supported(%arg0: !spirv.arm.tensor<2x3x4xf16>) -> (!spirv.arm.tensor<2x3x4xf16>) {
|
||||
// expected-error @+1 {{op failed to verify that if input has type 16-bit float then output must have a type in [32-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer]}}
|
||||
// expected-error @+1 {{op failed to verify that if input has type 16-bit float then output must have a type in [32-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer,f8E4M3FN type,f8E5M2 type]}}
|
||||
%0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<2x3x4xf16> -> !spirv.arm.tensor<2x3x4xf16>
|
||||
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3x4xf16>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @cast_f16_to_bool_not_supported(%arg0: !spirv.arm.tensor<2x3x4xf16>) -> (!spirv.arm.tensor<2x3x4xi1>) {
|
||||
// expected-error @+1 {{op failed to verify that if input has type 16-bit float then output must have a type in [32-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer,f8E4M3FN type,f8E5M2 type]}}
|
||||
%0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<2x3x4xf16> -> !spirv.arm.tensor<2x3x4xi1>
|
||||
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3x4xi1>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @cast_f32_to_f32_not_supported(%arg0: !spirv.arm.tensor<2x3x4xf32>) -> (!spirv.arm.tensor<2x3x4xf32>) {
|
||||
// expected-error @+1 {{op failed to verify that if input has type 32-bit float then output must have a type in [16-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer,bfloat16 type]}}
|
||||
// expected-error @+1 {{op failed to verify that if input has type 32-bit float then output must have a type in [16-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer,bfloat16 type,f8E4M3FN type,f8E5M2 type]}}
|
||||
%0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<2x3x4xf32> -> !spirv.arm.tensor<2x3x4xf32>
|
||||
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3x4xf32>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @cast_f32_to_bool_not_supported(%arg0: !spirv.arm.tensor<2x3x4xf32>) -> (!spirv.arm.tensor<2x3x4xi1>) {
|
||||
// expected-error @+1 {{op failed to verify that if input has type 32-bit float then output must have a type in [16-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer,bfloat16 type,f8E4M3FN type,f8E5M2 type]}}
|
||||
%0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<2x3x4xf32> -> !spirv.arm.tensor<2x3x4xi1>
|
||||
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3x4xi1>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @cast_i8_to_i8_not_supported(%arg0: !spirv.arm.tensor<2x3x4xi8>) -> (!spirv.arm.tensor<2x3x4xi8>) {
|
||||
// expected-error @+1 {{op failed to verify that if input has type 8-bit signless integer then output must have a type in [16-bit float,32-bit float,16-bit signless integer,32-bit signless integer,bool,bfloat16 type]}}
|
||||
%0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<2x3x4xi8> -> !spirv.arm.tensor<2x3x4xi8>
|
||||
@@ -1920,17 +2152,59 @@ spirv.ARM.Graph @cast_bool_to_bf16_not_supported(%arg0: !spirv.arm.tensor<2x3x4x
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @cast_bf16_to_f16_not_supported(%arg0: !spirv.arm.tensor<2x3x4xbf16>) -> (!spirv.arm.tensor<2x3x4xf16>) {
|
||||
// expected-error @+1 {{op failed to verify that if input has type bfloat16 type then output must have a type in [32-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer]}}
|
||||
// expected-error @+1 {{op failed to verify that if input has type bfloat16 type then output must have a type in [32-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer,f8E4M3FN type,f8E5M2 type]}}
|
||||
%0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<2x3x4xbf16> -> !spirv.arm.tensor<2x3x4xf16>
|
||||
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3x4xf16>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @cast_bf16_to_bf16_not_supported(%arg0: !spirv.arm.tensor<2x3x4xbf16>) -> (!spirv.arm.tensor<2x3x4xbf16>) {
|
||||
// expected-error @+1 {{op failed to verify that if input has type bfloat16 type then output must have a type in [32-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer]}}
|
||||
// expected-error @+1 {{op failed to verify that if input has type bfloat16 type then output must have a type in [32-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer,f8E4M3FN type,f8E5M2 type]}}
|
||||
%0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<2x3x4xbf16> -> !spirv.arm.tensor<2x3x4xbf16>
|
||||
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3x4xbf16>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @cast_bf16_to_bool_not_supported(%arg0: !spirv.arm.tensor<2x3x4xbf16>) -> (!spirv.arm.tensor<2x3x4xi1>) {
|
||||
// expected-error @+1 {{op failed to verify that if input has type bfloat16 type then output must have a type in [32-bit float,16-bit signless integer,32-bit signless integer,8-bit signless integer,f8E4M3FN type,f8E5M2 type]}}
|
||||
%0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<2x3x4xbf16> -> !spirv.arm.tensor<2x3x4xi1>
|
||||
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3x4xi1>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @cast_f8e4m3fn_to_i8_not_supported(%arg0: !spirv.arm.tensor<2x3x4xf8E4M3FN>) -> (!spirv.arm.tensor<2x3x4xi8>) {
|
||||
// expected-error @+1 {{op failed to verify that if input has type f8E4M3FN type then output must have a type in [16-bit float,32-bit float,bfloat16 type]}}
|
||||
%0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<2x3x4xf8E4M3FN> -> !spirv.arm.tensor<2x3x4xi8>
|
||||
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3x4xi8>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @cast_f8e4m3fn_to_bool_not_supported(%arg0: !spirv.arm.tensor<2x3x4xf8E4M3FN>) -> (!spirv.arm.tensor<2x3x4xi1>) {
|
||||
// expected-error @+1 {{op failed to verify that if input has type f8E4M3FN type then output must have a type in [16-bit float,32-bit float,bfloat16 type]}}
|
||||
%0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<2x3x4xf8E4M3FN> -> !spirv.arm.tensor<2x3x4xi1>
|
||||
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3x4xi1>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @cast_f8e4m3fn_to_f8e4m3fn_not_supported(%arg0: !spirv.arm.tensor<2x3x4xf8E4M3FN>) -> (!spirv.arm.tensor<2x3x4xf8E4M3FN>) {
|
||||
// expected-error @+1 {{op failed to verify that if input has type f8E4M3FN type then output must have a type in [16-bit float,32-bit float,bfloat16 type]}}
|
||||
%0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<2x3x4xf8E4M3FN> -> !spirv.arm.tensor<2x3x4xf8E4M3FN>
|
||||
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3x4xf8E4M3FN>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @cast_f8e5m2_to_i16_not_supported(%arg0: !spirv.arm.tensor<2x3x4xf8E5M2>) -> (!spirv.arm.tensor<2x3x4xi16>) {
|
||||
// expected-error @+1 {{op failed to verify that if input has type f8E5M2 type then output must have a type in [16-bit float,32-bit float,bfloat16 type]}}
|
||||
%0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<2x3x4xf8E5M2> -> !spirv.arm.tensor<2x3x4xi16>
|
||||
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3x4xi16>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @cast_f8e5m2_to_bool_not_supported(%arg0: !spirv.arm.tensor<2x3x4xf8E5M2>) -> (!spirv.arm.tensor<2x3x4xi1>) {
|
||||
// expected-error @+1 {{op failed to verify that if input has type f8E5M2 type then output must have a type in [16-bit float,32-bit float,bfloat16 type]}}
|
||||
%0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<2x3x4xf8E5M2> -> !spirv.arm.tensor<2x3x4xi1>
|
||||
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3x4xi1>
|
||||
}
|
||||
|
||||
spirv.ARM.Graph @cast_f8e5m2_to_f8e5m2_not_supported(%arg0: !spirv.arm.tensor<2x3x4xf8E5M2>) -> (!spirv.arm.tensor<2x3x4xf8E5M2>) {
|
||||
// expected-error @+1 {{op failed to verify that if input has type f8E5M2 type then output must have a type in [16-bit float,32-bit float,bfloat16 type]}}
|
||||
%0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<2x3x4xf8E5M2> -> !spirv.arm.tensor<2x3x4xf8E5M2>
|
||||
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3x4xf8E5M2>
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.TOSA.Rescale
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Reference in New Issue
Block a user