[mlir][Vector] Add support for poison indices to Extract/IndexOp (#123488)
Following up on #122188, this PR adds support for poison indices to `ExtractOp` and `InsertOp`. It also includes canonicalization patterns to turn extract/insert ops with poison indices into `ub.poison`.
This commit is contained in:
@@ -1454,7 +1454,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
|
||||
def ConvertVectorToSPIRV : Pass<"convert-vector-to-spirv"> {
|
||||
let summary = "Convert Vector dialect to SPIR-V dialect";
|
||||
let constructor = "mlir::createConvertVectorToSPIRVPass()";
|
||||
let dependentDialects = ["spirv::SPIRVDialect"];
|
||||
let dependentDialects = [
|
||||
"spirv::SPIRVDialect",
|
||||
"ub::UBDialect"
|
||||
];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -26,6 +26,15 @@ def Vector_Dialect : Dialect {
|
||||
|
||||
// Base class for Vector dialect ops.
|
||||
class Vector_Op<string mnemonic, list<Trait> traits = []> :
|
||||
Op<Vector_Dialect, mnemonic, traits>;
|
||||
Op<Vector_Dialect, mnemonic, traits> {
|
||||
|
||||
// Includes definitions for operations that support the use of poison values
|
||||
// within positive index ranges.
|
||||
code extraPoisonClassDeclaration = [{
|
||||
// Integer to represent a poison index within a static and positive integer
|
||||
// range.
|
||||
static constexpr int64_t kPoisonIndex = -1;
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // MLIR_DIALECT_VECTOR_IR_VECTOR
|
||||
|
||||
@@ -469,10 +469,7 @@ def Vector_ShuffleOp
|
||||
```
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Integer to represent a poison value in a vector shuffle mask.
|
||||
static constexpr int64_t kMaskPoisonValue = -1;
|
||||
|
||||
let extraClassDeclaration = extraPoisonClassDeclaration # [{
|
||||
VectorType getV1VectorType() {
|
||||
return ::llvm::cast<VectorType>(getV1().getType());
|
||||
}
|
||||
@@ -693,9 +690,10 @@ def Vector_ExtractOp :
|
||||
Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at
|
||||
the proper position. Degenerates to an element type if n-k is zero.
|
||||
|
||||
Dynamic indices must be greater or equal to zero and less than the size of
|
||||
the corresponding dimension. The result is undefined if any index is
|
||||
out-of-bounds.
|
||||
Static and dynamic indices must be greater or equal to zero and less than
|
||||
the size of the corresponding dimension. The result is undefined if any
|
||||
index is out-of-bounds. The value `-1` represents a poison index, which
|
||||
specifies that the extracted element is poison.
|
||||
|
||||
Example:
|
||||
|
||||
@@ -705,9 +703,8 @@ def Vector_ExtractOp :
|
||||
%3 = vector.extract %1[]: vector<f32> from vector<f32>
|
||||
%4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32>
|
||||
%5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32>
|
||||
%6 = vector.extract %10[-1, %c]: f32 from vector<4x16xf32>
|
||||
```
|
||||
|
||||
TODO: Implement support for poison indices.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
@@ -724,7 +721,7 @@ def Vector_ExtractOp :
|
||||
OpBuilder<(ins "Value":$source, "ArrayRef<OpFoldResult>":$position)>,
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
let extraClassDeclaration = extraPoisonClassDeclaration # [{
|
||||
VectorType getSourceVectorType() {
|
||||
return ::llvm::cast<VectorType>(getVector().getType());
|
||||
}
|
||||
@@ -885,9 +882,10 @@ def Vector_InsertOp :
|
||||
and inserts the n-D source into the (n+k)-D destination at the proper
|
||||
position. Degenerates to a scalar or a 0-d vector source type when n = 0.
|
||||
|
||||
Dynamic indices must be greater or equal to zero and less than the size of
|
||||
the corresponding dimension. The result is undefined if any index is
|
||||
out-of-bounds.
|
||||
Static and dynamic indices must be greater or equal to zero and less than
|
||||
the size of the corresponding dimension. The result is undefined if any
|
||||
index is out-of-bounds. The value `-1` represents a poison index, which
|
||||
specifies that the resulting vector is poison.
|
||||
|
||||
Example:
|
||||
|
||||
@@ -897,9 +895,8 @@ def Vector_InsertOp :
|
||||
%8 = vector.insert %6, %7[] : f32 into vector<f32>
|
||||
%11 = vector.insert %9, %10[%a, %b, %c] : vector<f32> into vector<4x8x16xf32>
|
||||
%12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32>
|
||||
%13 = vector.insert %20, %1[-1, %c] : f32 into vector<4x16xf32>
|
||||
```
|
||||
|
||||
TODO: Implement support for poison indices.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
@@ -917,7 +914,7 @@ def Vector_InsertOp :
|
||||
OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<OpFoldResult>":$position)>,
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
let extraClassDeclaration = extraPoisonClassDeclaration # [{
|
||||
Type getSourceType() { return getSource().getType(); }
|
||||
VectorType getDestVectorType() {
|
||||
return ::llvm::cast<VectorType>(getDest().getType());
|
||||
@@ -990,15 +987,13 @@ def Vector_ScalableInsertOp :
|
||||
```mlir
|
||||
%2 = vector.scalable.insert %0, %1[5] : vector<4xf32> into vector<[16]xf32>
|
||||
```
|
||||
|
||||
TODO: Implement support for poison indices.
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
$source `,` $dest `[` $pos `]` attr-dict `:` type($source) `into` type($dest)
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
let extraClassDeclaration = extraPoisonClassDeclaration # [{
|
||||
VectorType getSourceVectorType() {
|
||||
return ::llvm::cast<VectorType>(getSource().getType());
|
||||
}
|
||||
@@ -1043,15 +1038,13 @@ def Vector_ScalableExtractOp :
|
||||
```mlir
|
||||
%1 = vector.scalable.extract %0[5] : vector<4xf32> from vector<[16]xf32>
|
||||
```
|
||||
|
||||
TODO: Implement support for poison indices.
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
$source `[` $pos `]` attr-dict `:` type($res) `from` type($source)
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
let extraClassDeclaration = extraPoisonClassDeclaration # [{
|
||||
VectorType getSourceVectorType() {
|
||||
return ::llvm::cast<VectorType>(getSource().getType());
|
||||
}
|
||||
@@ -1089,8 +1082,6 @@ def Vector_InsertStridedSliceOp :
|
||||
{offsets = [0, 0, 2], strides = [1, 1]}:
|
||||
vector<2x4xf32> into vector<16x4x8xf32>
|
||||
```
|
||||
|
||||
TODO: Implement support for poison indices.
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
|
||||
@@ -28,6 +28,7 @@ def Canonicalizer : Pass<"canonicalize"> {
|
||||
details.
|
||||
}];
|
||||
let constructor = "mlir::createCanonicalizerPass()";
|
||||
let dependentDialects = ["ub::UBDialect"];
|
||||
let options = [
|
||||
Option<"topDownProcessingEnabled", "top-down", "bool",
|
||||
/*default=*/"true",
|
||||
|
||||
@@ -18,7 +18,6 @@
|
||||
#include "mlir/Dialect/Arith/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
|
||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
|
||||
@@ -27,7 +26,6 @@
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include <memory>
|
||||
|
||||
#define DEBUG_TYPE "convert-to-spirv"
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
|
||||
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
|
||||
#include "mlir/Dialect/UB/IR/UBOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/UB/IR/UBOps.h"
|
||||
#include "mlir/Dialect/Utils/IndexingUtils.h"
|
||||
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
@@ -1274,6 +1275,13 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
|
||||
return srcElements[posIdx];
|
||||
}
|
||||
|
||||
// Returns `true` if `index` is either within [0, maxIndex) or equal to
|
||||
// `poisonValue`.
|
||||
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
|
||||
int64_t maxIndex) {
|
||||
return index == poisonValue || (index >= 0 && index < maxIndex);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ExtractOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -1355,11 +1363,12 @@ LogicalResult vector::ExtractOp::verify() {
|
||||
for (auto [idx, pos] : llvm::enumerate(position)) {
|
||||
if (auto attr = dyn_cast<Attribute>(pos)) {
|
||||
int64_t constIdx = cast<IntegerAttr>(attr).getInt();
|
||||
if (constIdx < 0 || constIdx >= getSourceVectorType().getDimSize(idx)) {
|
||||
if (!isValidPositiveIndexOrPoison(
|
||||
constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
|
||||
return emitOpError("expected position attribute #")
|
||||
<< (idx + 1)
|
||||
<< " to be a non-negative integer smaller than the "
|
||||
"corresponding vector dimension";
|
||||
"corresponding vector dimension or poison (-1)";
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1977,12 +1986,26 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
|
||||
return fromElementsOp.getElements()[flatIndex];
|
||||
}
|
||||
|
||||
OpFoldResult ExtractOp::fold(FoldAdaptor) {
|
||||
/// Fold an insert or extract operation into an poison value when a poison index
|
||||
/// is found at any dimension of the static position.
|
||||
static ub::PoisonAttr
|
||||
foldPoisonIndexInsertExtractOp(MLIRContext *context,
|
||||
ArrayRef<int64_t> staticPos, int64_t poisonVal) {
|
||||
if (!llvm::is_contained(staticPos, poisonVal))
|
||||
return ub::PoisonAttr();
|
||||
|
||||
return ub::PoisonAttr::get(context);
|
||||
}
|
||||
|
||||
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
|
||||
// Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
|
||||
// Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
|
||||
// mismatch).
|
||||
if (getNumIndices() == 0 && getVector().getType() == getResult().getType())
|
||||
return getVector();
|
||||
if (auto res = foldPoisonIndexInsertExtractOp(
|
||||
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
|
||||
return res;
|
||||
if (succeeded(foldExtractOpFromExtractChain(*this)))
|
||||
return getResult();
|
||||
if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
|
||||
@@ -2249,6 +2272,21 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
|
||||
resultType.getNumElements()));
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Fold an insert or extract operation into an poison value when a poison index
|
||||
/// is found at any dimension of the static position.
|
||||
template <typename OpTy>
|
||||
LogicalResult
|
||||
canonicalizePoisonIndexInsertExtractOp(OpTy op, PatternRewriter &rewriter) {
|
||||
if (auto poisonAttr = foldPoisonIndexInsertExtractOp(
|
||||
op.getContext(), op.getStaticPosition(), OpTy::kPoisonIndex)) {
|
||||
rewriter.replaceOpWithNewOp<ub::PoisonOp>(op, op.getType(), poisonAttr);
|
||||
return success();
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
@@ -2257,6 +2295,7 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
|
||||
results.add(foldExtractFromShapeCastToShapeCast);
|
||||
results.add(foldExtractFromFromElements);
|
||||
results.add(canonicalizePoisonIndexInsertExtractOp<ExtractOp>);
|
||||
}
|
||||
|
||||
static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
|
||||
@@ -2600,7 +2639,7 @@ LogicalResult ShuffleOp::verify() {
|
||||
int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
|
||||
(v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
|
||||
for (auto [idx, maskPos] : llvm::enumerate(mask)) {
|
||||
if (maskPos != kMaskPoisonValue && (maskPos < 0 || maskPos >= indexSize))
|
||||
if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
|
||||
return emitOpError("mask index #") << (idx + 1) << " out of range";
|
||||
}
|
||||
return success();
|
||||
@@ -2882,7 +2921,8 @@ LogicalResult InsertOp::verify() {
|
||||
for (auto [idx, pos] : llvm::enumerate(position)) {
|
||||
if (auto attr = pos.dyn_cast<Attribute>()) {
|
||||
int64_t constIdx = cast<IntegerAttr>(attr).getInt();
|
||||
if (constIdx < 0 || constIdx >= destVectorType.getDimSize(idx)) {
|
||||
if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
|
||||
destVectorType.getDimSize(idx))) {
|
||||
return emitOpError("expected position attribute #")
|
||||
<< (idx + 1)
|
||||
<< " to be a non-negative integer smaller than the "
|
||||
@@ -3020,6 +3060,7 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {
|
||||
results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
|
||||
InsertOpConstantFolder>(context);
|
||||
results.add(canonicalizePoisonIndexInsertExtractOp<InsertOp>);
|
||||
}
|
||||
|
||||
OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
|
||||
@@ -3028,6 +3069,10 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
|
||||
// (type mismatch).
|
||||
if (getNumIndices() == 0 && getSourceType() == getType())
|
||||
return getSource();
|
||||
if (auto res = foldPoisonIndexInsertExtractOp(
|
||||
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
|
||||
return res;
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
|
||||
@@ -37,4 +37,5 @@ add_mlir_library(MLIRTransforms
|
||||
MLIRSideEffectInterfaces
|
||||
MLIRSupport
|
||||
MLIRTransformUtils
|
||||
MLIRUBDialect
|
||||
)
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
#include "mlir/Dialect/UB/IR/UBOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
|
||||
@@ -1250,6 +1250,16 @@ func.func @extract_scalar_from_vec_1d_f32(%arg0: vector<16xf32>) -> f32 {
|
||||
|
||||
// -----
|
||||
|
||||
func.func @extract_poison_idx(%arg0: vector<16xf32>) -> f32 {
|
||||
%0 = vector.extract %arg0[-1]: f32 from vector<16xf32>
|
||||
return %0 : f32
|
||||
}
|
||||
// CHECK-LABEL: @extract_poison_idx
|
||||
// CHECK: %[[IDX:.*]] = llvm.mlir.constant(-1 : i64) : i64
|
||||
// CHECK: llvm.extractelement {{.*}}[%[[IDX]] : i64] : vector<16xf32>
|
||||
|
||||
// -----
|
||||
|
||||
func.func @extract_scalar_from_vec_1d_f32_scalable(%arg0: vector<[16]xf32>) -> f32 {
|
||||
%0 = vector.extract %arg0[15]: f32 from vector<[16]xf32>
|
||||
return %0 : f32
|
||||
|
||||
@@ -175,6 +175,14 @@ func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
|
||||
|
||||
// -----
|
||||
|
||||
func.func @extract_poison_idx(%arg0 : vector<4xf32>) -> f32 {
|
||||
// expected-error@+1 {{index -1 out of bounds for 'vector<4xf32>'}}
|
||||
%0 = vector.extract %arg0[-1] : f32 from vector<4xf32>
|
||||
return %0: f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @extract_size1_vector
|
||||
// CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>
|
||||
// CHECK: %[[R:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
|
||||
@@ -256,6 +264,14 @@ func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
|
||||
|
||||
// -----
|
||||
|
||||
func.func @insert_poison_idx(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
|
||||
// expected-error@+1 {{index -1 out of bounds for 'vector<4xf32>'}}
|
||||
%1 = vector.insert %arg1, %arg0[-1] : f32 into vector<4xf32>
|
||||
return %1: vector<4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @insert_index_vector
|
||||
// CHECK: spirv.CompositeInsert %{{.+}}, %{{.+}}[2 : i32] : i32 into vector<4xi32>
|
||||
func.func @insert_index_vector(%arg0 : vector<4xindex>, %arg1: index) -> vector<4xindex> {
|
||||
|
||||
@@ -132,6 +132,37 @@ func.func @extract_from_create_mask_dynamic_position(%dim0: index, %index: index
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @extract_scalar_poison_idx
|
||||
func.func @extract_scalar_poison_idx(%a: vector<4x5xf32>) -> f32 {
|
||||
// CHECK-NOT: vector.extract
|
||||
// CHECK-NEXT: ub.poison : f32
|
||||
%0 = vector.extract %a[-1, 0] : f32 from vector<4x5xf32>
|
||||
return %0 : f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @extract_vector_poison_idx
|
||||
func.func @extract_vector_poison_idx(%a: vector<4x5xf32>) -> vector<5xf32> {
|
||||
// CHECK-NOT: vector.extract
|
||||
// CHECK-NEXT: ub.poison : vector<5xf32>
|
||||
%0 = vector.extract %a[-1] : vector<5xf32> from vector<4x5xf32>
|
||||
return %0 : vector<5xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @extract_multiple_poison_idx
|
||||
func.func @extract_multiple_poison_idx(%a: vector<4x5x8xf32>)
|
||||
-> vector<8xf32> {
|
||||
// CHECK-NOT: vector.extract
|
||||
// CHECK-NEXT: ub.poison : vector<8xf32>
|
||||
%0 = vector.extract %a[-1, -1] : vector<8xf32> from vector<4x5x8xf32>
|
||||
return %0 : vector<8xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: extract_from_create_mask_dynamic_position_all_false
|
||||
// CHECK-SAME: %[[DIM0:.*]]: index, %[[INDEX:.*]]: index
|
||||
func.func @extract_from_create_mask_dynamic_position_all_false(%dim0: index, %index: index) -> vector<6xi1> {
|
||||
@@ -2778,7 +2809,6 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
|
||||
return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @vector_insert_const_regression(
|
||||
@@ -2792,6 +2822,39 @@ func.func @vector_insert_const_regression(%arg0: i8) -> vector<4xi8> {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @insert_scalar_poison_idx
|
||||
func.func @insert_scalar_poison_idx(%a: vector<4x5xf32>, %b: f32)
|
||||
-> vector<4x5xf32> {
|
||||
// CHECK-NOT: vector.insert
|
||||
// CHECK-NEXT: ub.poison : vector<4x5xf32>
|
||||
%0 = vector.insert %b, %a[-1, 0] : f32 into vector<4x5xf32>
|
||||
return %0 : vector<4x5xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @insert_vector_poison_idx
|
||||
func.func @insert_vector_poison_idx(%a: vector<4x5xf32>, %b: vector<5xf32>)
|
||||
-> vector<4x5xf32> {
|
||||
// CHECK-NOT: vector.insert
|
||||
// CHECK-NEXT: ub.poison : vector<4x5xf32>
|
||||
%0 = vector.insert %b, %a[-1] : vector<5xf32> into vector<4x5xf32>
|
||||
return %0 : vector<4x5xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @insert_multiple_poison_idx
|
||||
func.func @insert_multiple_poison_idx(%a: vector<4x5x8xf32>, %b: vector<8xf32>)
|
||||
-> vector<4x5x8xf32> {
|
||||
// CHECK-NOT: vector.insert
|
||||
// CHECK-NEXT: ub.poison : vector<4x5x8xf32>
|
||||
%0 = vector.insert %b, %a[-1, -1] : vector<8xf32> into vector<4x5x8xf32>
|
||||
return %0 : vector<4x5x8xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract
|
||||
// CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
|
||||
// CHECK-NEXT: return %[[EXTRACT]] : vector<4xi32>
|
||||
|
||||
@@ -186,8 +186,8 @@ func.func @extract_0d(%arg0: vector<f32>) {
|
||||
// -----
|
||||
|
||||
func.func @extract_position_overflow(%arg0: vector<4x8x16xf32>) {
|
||||
// expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension}}
|
||||
%1 = vector.extract %arg0[0, 0, -1] : f32 from vector<4x8x16xf32>
|
||||
// expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension or poison (-1)}}
|
||||
%1 = vector.extract %arg0[0, 0, -5] : f32 from vector<4x8x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
@@ -247,7 +247,7 @@ func.func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
|
||||
|
||||
func.func @insert_position_overflow(%a: f32, %b: vector<4x8x16xf32>) {
|
||||
// expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding dest vector dimension}}
|
||||
%1 = vector.insert %a, %b[0, 0, -1] : f32 into vector<4x8x16xf32>
|
||||
%1 = vector.insert %a, %b[0, 0, -5] : f32 into vector<4x8x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
@@ -247,6 +247,13 @@ func.func @extract_0d(%a: vector<f32>) -> f32 {
|
||||
return %0 : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @extract_poison_idx
|
||||
func.func @extract_poison_idx(%a: vector<4x5xf32>) -> f32 {
|
||||
// CHECK-NEXT: vector.extract %{{.*}}[-1, 0] : f32 from vector<4x5xf32>
|
||||
%0 = vector.extract %a[-1, 0] : f32 from vector<4x5xf32>
|
||||
return %0 : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @insert_element_0d
|
||||
func.func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
|
||||
// CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[] : vector<f32>
|
||||
@@ -299,6 +306,13 @@ func.func @insert_0d(%a: f32, %b: vector<f32>, %c: vector<2x3xf32>) -> (vector<f
|
||||
return %1, %2 : vector<f32>, vector<2x3xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @insert_poison_idx
|
||||
func.func @insert_poison_idx(%a: vector<4x5xf32>, %b: f32) {
|
||||
// CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[-1, 0] : f32 into vector<4x5xf32>
|
||||
vector.insert %b, %a[-1, 0] : f32 into vector<4x5xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @outerproduct
|
||||
func.func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x8xf32>) -> vector<4x8xf32> {
|
||||
// CHECK: vector.outerproduct {{.*}} : vector<4xf32>, vector<8xf32>
|
||||
|
||||
Reference in New Issue
Block a user