[mlir][x86] Rename x86vector to x86 (#183311)

Renames 'x86vector' dialect to 'x86'.

This is the first PR in a series of cleanups around dialects targeting x86
platforms.
The new naming scheme is shorter, cleaner, and opens the possibility of
integrating other x86-specific operations that do not strictly fit a pure
vector representation. For example, the generalization will allow for a
future merger of the AMX dialect into the x86 dialect, creating a one-stop
collection of x86 operations and boosting discoverability.
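In practice only the dialect prefix in the IR changes; a minimal before/after
sketch using the AVX-512 dot op touched by this diff:

```mlir
// Before:
%dst = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
// After:
%dst = x86.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
```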
Adam Siemieniuk
2026-02-26 11:21:58 +01:00
committed by GitHub
parent 6d7ec4b7c3
commit 67ac275fee
80 changed files with 750 additions and 768 deletions

View File

@@ -105,7 +105,7 @@ available, should be contacted first, as they're more active in those areas.
* arm_sve Dialect ([@banach-space](https://github.com/banach-space))
* ArmSME Dialect ([@banach-space](https://github.com/banach-space))
* amx Dialect ([@adam-smnk](https://github.com/adam-smnk))
* x86vector Dialect ([@adam-smnk](https://github.com/adam-smnk))
* x86 Dialect ([@adam-smnk](https://github.com/adam-smnk))
* vcix Dialect ([@mshockwave](https://github.com/mshockwave))
#### Paradigm Dialects

View File

@@ -6,7 +6,7 @@ overall flow is two-stage:
1. **conversion** of the IR to a set of dialects translatable to LLVM IR, for
example [LLVM Dialect](Dialects/LLVM.md) or one of the hardware-specific
dialects derived from LLVM IR intrinsics such as [AMX](Dialects/AMX.md),
[X86Vector](Dialects/X86Vector.md) or [ArmNeon](Dialects/ArmNeon.md);
[X86](Dialects/X86.md) or [ArmNeon](Dialects/ArmNeon.md);
2. **translation** of MLIR dialects to LLVM IR.
This flow allows the non-trivial transformation to be performed within MLIR

View File

@@ -1,4 +1,4 @@
//===-- mlir-c/Dialect/x86Vector.h - C API for x86Vector Dialect --*- C -*-===//
//===-- mlir-c/Dialect/x86.h - C API for x86 Dialect --------------*- C -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
@@ -7,8 +7,8 @@
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_C_DIALECT_X86VECTOR_H
#define MLIR_C_DIALECT_X86VECTOR_H
#ifndef MLIR_C_DIALECT_X86_H
#define MLIR_C_DIALECT_X86_H
#include "mlir-c/IR.h"
@@ -16,10 +16,10 @@
extern "C" {
#endif
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(X86Vector, x86vector);
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(X86, x86);
#ifdef __cplusplus
}
#endif
#endif // MLIR_C_DIALECT_X86VECTOR_H
#endif // MLIR_C_DIALECT_X86_H

View File

@@ -1521,7 +1521,7 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
operations. The lowering pass provides several options to control
the kinds of optimizations that are allowed. It also provides options
that enable the use of one or more architectural-specific dialects
(AMX, X86Vector, ArmNeon, ArmSVE, etc.) in combination with the
(AMX, X86, ArmNeon, ArmSVE, etc.) in combination with the
architectural-neutral vector dialect lowering.
}];
@@ -1564,10 +1564,9 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"false",
"Enables the use of Arm FEAT_BF16 instructions while lowering "
"the vector dialect.">,
Option<"x86Vector", "enable-x86vector",
Option<"x86", "enable-x86",
"bool", /*default=*/"false",
"Enables the use of X86Vector dialect while lowering the vector "
"dialect.">,
"Enables the use of X86 dialect while lowering the vector dialect.">,
Option<"vectorContractLowering", "vector-contract-lowering",
"vector::VectorContractLowering",
/*default=*/"vector::VectorContractLowering::Dot",

View File

@@ -42,5 +42,5 @@ add_subdirectory(UB)
add_subdirectory(Utils)
add_subdirectory(Vector)
add_subdirectory(WasmSSA)
add_subdirectory(X86Vector)
add_subdirectory(X86)
add_subdirectory(XeGPU)

View File

@@ -19,7 +19,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Dialect/X86/Transforms.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/TilingInterface.h"

View File

@@ -118,9 +118,9 @@ struct SparsifierOptions : public PassPipelineOptions<SparsifierOptions> {
desc("Enables the use of ArmSVE dialect while lowering the vector "
"dialect"),
init(false)};
PassOptions::Option<bool> x86Vector{
*this, "enable-x86vector",
desc("Enables the use of X86Vector dialect while lowering the vector "
PassOptions::Option<bool> x86{
*this, "enable-x86",
desc("Enables the use of X86 dialect while lowering the vector "
"dialect"),
init(false)};
@@ -169,7 +169,7 @@ struct SparsifierOptions : public PassPipelineOptions<SparsifierOptions> {
opts.armNeon = armNeon;
opts.armSVE = armSVE;
opts.amx = amx;
opts.x86Vector = x86Vector;
opts.x86 = x86;
return opts;
}
};

View File

@@ -0,0 +1,7 @@
add_mlir_dialect(X86 x86)
add_mlir_doc(X86 X86 Dialects/ -gen-dialect-doc -dialect=x86)
add_mlir_interface(X86Interfaces)
add_dependencies(MLIRX86IncGen MLIRX86InterfacesIncGen)
add_subdirectory(TransformOps)

View File

@@ -0,0 +1,4 @@
set(LLVM_TARGET_DEFINITIONS X86TransformOps.td)
mlir_tablegen(X86TransformOps.h.inc -gen-op-decls)
mlir_tablegen(X86TransformOps.cpp.inc -gen-op-defs)
add_mlir_dialect_tablegen_target(MLIRX86TransformOpsIncGen)

View File

@@ -1,4 +1,4 @@
//===- X86VectorTransformOps.h - X86Vector transform ops --------*- C++ -*-===//
//===- X86TransformOps.h - X86 transform ops --------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,26 +6,26 @@
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
#define MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
#ifndef MLIR_DIALECT_X86_TRANSFORMOPS_X86TRANSFORMOPS_H
#define MLIR_DIALECT_X86_TRANSFORMOPS_X86TRANSFORMOPS_H
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/OpImplementation.h"
//===----------------------------------------------------------------------===//
// X86Vector Transform Operations
// X86 Transform Operations
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h.inc"
#include "mlir/Dialect/X86/TransformOps/X86TransformOps.h.inc"
namespace mlir {
class DialectRegistry;
namespace x86vector {
namespace x86 {
void registerTransformDialectExtension(DialectRegistry &registry);
} // namespace x86vector
} // namespace x86
} // namespace mlir
#endif // MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
#endif // MLIR_DIALECT_X86_TRANSFORMOPS_X86TRANSFORMOPS_H

View File

@@ -1,4 +1,4 @@
//===- X86VectorTransformOps.td - X86Vector transform ops --*- tablegen -*-===//
//===- X86TransformOps.td - X86 transform ops --------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
#ifndef X86VECTOR_TRANSFORM_OPS
#define X86VECTOR_TRANSFORM_OPS
#ifndef X86_TRANSFORM_OPS
#define X86_TRANSFORM_OPS
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
@@ -18,7 +18,7 @@ include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/IR/RegionKindInterface.td"
def ApplyVectorContractToFMAPatternsOp : Op<Transform_Dialect,
"apply_patterns.x86vector.vector_contract_to_fma",
"apply_patterns.x86.vector_contract_to_fma",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Collect patterns to lower an F32 type vector.contract operation to an FMA.
@@ -28,7 +28,7 @@ def ApplyVectorContractToFMAPatternsOp : Op<Transform_Dialect,
}
def ApplyVectorContractToPackedTypeDotProductPatternsOp : Op<Transform_Dialect,
"apply_patterns.x86vector.vector_contract_to_packed_type_dot_product",
"apply_patterns.x86.vector_contract_to_packed_type_dot_product",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Collect patterns to lower a BF16/Int8 type vector.contract operation
@@ -39,7 +39,7 @@ def ApplyVectorContractToPackedTypeDotProductPatternsOp : Op<Transform_Dialect,
}
def ApplyVectorContractBF16ToFMAPatternsOp : Op<Transform_Dialect,
"apply_patterns.x86vector.vector_contract_bf16_to_fma",
"apply_patterns.x86.vector_contract_bf16_to_fma",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Collect patterns to lower a BF16 type vector.contract operation
@@ -50,7 +50,7 @@ def ApplyVectorContractBF16ToFMAPatternsOp : Op<Transform_Dialect,
}
def ApplySinkVectorProducerOpsPatternsOp : Op<Transform_Dialect,
"apply_patterns.x86vector.sink_vector_producer_ops",
"apply_patterns.x86.sink_vector_producer_ops",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Collect patterns to sink vector producer operations forward in a block to
@@ -61,10 +61,10 @@ def ApplySinkVectorProducerOpsPatternsOp : Op<Transform_Dialect,
}
def ApplyShuffleVectorFMAOpsPatternsOp : Op<Transform_Dialect,
"apply_patterns.x86vector.shuffle_vector_fma_ops",
"apply_patterns.x86.shuffle_vector_fma_ops",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Collect patterns to shuffle FMAs with x86vector operations as operands
Collect patterns to shuffle FMAs with x86 operations as operands
such that FMAs are grouped with respect to odd/even packed index.
}];
@@ -72,5 +72,4 @@ def ApplyShuffleVectorFMAOpsPatternsOp : Op<Transform_Dialect,
}
#endif // X86VECTOR_TRANSFORM_OPS
#endif // X86_TRANSFORM_OPS
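Only the mnemonics of the transform ops change; a minimal sketch of a
transform script exercising one renamed op (the named-sequence scaffolding
and match step are assumed here, not part of this diff):

```mlir
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    %func = transform.structured.match ops{["func.func"]} in %root
      : (!transform.any_op) -> !transform.any_op
    // Previously: transform.apply_patterns.x86vector.vector_contract_to_fma
    transform.apply_patterns to %func {
      transform.apply_patterns.x86.vector_contract_to_fma
    } : !transform.any_op
    transform.yield
  }
}
```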

View File

@@ -1,4 +1,4 @@
//=- Transforms.h - X86Vector Dialect Transformation Entrypoints -*- C++ -*-=//
//=- Transforms.h - X86 Dialect Transformation Entrypoints --------*- C++ -*-=//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMS_H
#define MLIR_DIALECT_X86VECTOR_TRANSFORMS_H
#ifndef MLIR_DIALECT_X86_TRANSFORMS_H
#define MLIR_DIALECT_X86_TRANSFORMS_H
#include "mlir/IR/Value.h"
@@ -18,7 +18,7 @@ class LLVMConversionTarget;
class LLVMTypeConverter;
class RewritePatternSet;
namespace x86vector {
namespace x86 {
/// Helper class to factor out the creation and extraction of masks from nibs.
struct MaskHelper {
@@ -100,7 +100,7 @@ void populateVectorContractBF16ToFMAPatterns(RewritePatternSet &patterns);
// range by placing them at their earliest legal use site.
void populateSinkVectorProducerOpsPatterns(RewritePatternSet &patterns);
// Shuffles FMAs with x86vector operations as operands such that FMAs are
// Shuffles FMAs with x86 operations as operands such that FMAs are
// grouped with respect to odd/even packed index.
void populateShuffleVectorFMAOpsPatterns(RewritePatternSet &patterns);
@@ -196,17 +196,17 @@ void populateSpecializedTransposeLoweringPatterns(
int benefit = 10);
} // namespace avx2
} // namespace x86vector
} // namespace x86
/// Collect a set of patterns to lower X86Vector ops to ops that map to LLVM
/// Collect a set of patterns to lower X86 ops to ops that map to LLVM
/// intrinsics.
void populateX86VectorLegalizeForLLVMExportPatterns(
void populateX86LegalizeForLLVMExportPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
/// Configure the target to support lowering X86Vector ops to ops that map to
/// Configure the target to support lowering X86 ops to ops that map to
/// LLVM intrinsics.
void configureX86VectorLegalizeForExportTarget(LLVMConversionTarget &target);
void configureX86LegalizeForExportTarget(LLVMConversionTarget &target);
} // namespace mlir
#endif // MLIR_DIALECT_X86VECTOR_TRANSFORMS_H
#endif // MLIR_DIALECT_X86_TRANSFORMS_H

View File

@@ -1,4 +1,4 @@
//===- X86VectorUtils.h - X86Vector Utilities -------------------*- C++ -*-===//
//===- X86Utils.h - X86 Utilities -------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_X86VECTOR_UTILS_X86VECTORUTILS_H_
#define MLIR_DIALECT_X86VECTOR_UTILS_X86VECTORUTILS_H_
#ifndef MLIR_DIALECT_X86_UTILS_X86UTILS_H_
#define MLIR_DIALECT_X86_UTILS_X86UTILS_H_
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
@@ -22,7 +22,7 @@ namespace mlir {
class AffineMap;
class Operation;
namespace x86vector {
namespace x86 {
// Return true if the operation is in VNNI layout.
// Optionally, the check can be constrained to a specific VNNI blocking factor.
@@ -63,7 +63,7 @@ LogicalResult shuffleBeforeWriteLikeOp(PatternRewriter &rewriter,
Operation *opA, Operation *opB,
int64_t nonUnitDimAcc, VectorType accTy);
} // namespace x86vector
} // namespace x86
} // namespace mlir
#endif // MLIR_DIALECT_X86VECTOR_UTILS_X86VECTORUTILS_H_
#endif // MLIR_DIALECT_X86_UTILS_X86UTILS_H_

View File

@@ -1,4 +1,4 @@
//===-- X86VectorOps.td - X86Vector dialect operation defs -*- tablegen -*-===//
//===-- X86Ops.td - X86 dialect operation defs -------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,25 +6,25 @@
//
//===----------------------------------------------------------------------===//
//
// This file defines the basic operations for the X86Vector dialect.
// This file defines the basic operations for the X86 dialect.
//
//===----------------------------------------------------------------------===//
#ifndef X86VECTOR_OPS
#define X86VECTOR_OPS
#ifndef X86_OPS
#define X86_OPS
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Dialect/X86Vector/X86VectorInterfaces.td"
include "mlir/Dialect/X86/X86Interfaces.td"
//===----------------------------------------------------------------------===//
// X86Vector dialect definition
// X86 dialect definition
//===----------------------------------------------------------------------===//
def X86Vector_Dialect : Dialect {
let name = "x86vector";
let cppNamespace = "::mlir::x86vector";
def X86_Dialect : Dialect {
let name = "x86";
let cppNamespace = "::mlir::x86";
}
//===----------------------------------------------------------------------===//
@@ -33,7 +33,7 @@ def X86Vector_Dialect : Dialect {
// Operation that is part of the input dialect.
class AVX512_Op<string mnemonic, list<Trait> traits = []> :
Op<X86Vector_Dialect, "avx512." # mnemonic, traits> {}
Op<X86_Dialect, "avx512." # mnemonic, traits> {}
//----------------------------------------------------------------------------//
// MaskCompressOp
@@ -279,7 +279,7 @@ def DotBF16Op : AVX512_Op<"dot", [Pure,
Example:
```mlir
%dst = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
%dst = x86.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
```
}];
let arguments = (ins VectorOfLengthAndType<[4, 8, 16], [F32]>:$src,
@@ -323,7 +323,7 @@ def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
Example:
```mlir
%dst = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
%dst = x86.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
```
}];
let arguments = (ins VectorOfLengthAndType<[8, 16], [F32]>:$a);
@@ -349,7 +349,7 @@ def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
// Operation that is part of the input dialect.
class AVX10_Op<string mnemonic, list<Trait> traits = []> :
Op<X86Vector_Dialect, "avx10." # mnemonic, traits> {}
Op<X86_Dialect, "avx10." # mnemonic, traits> {}
//----------------------------------------------------------------------------//
// AVX10 Int8 Dot
@@ -376,7 +376,7 @@ def AVX10DotInt8Op : AVX10_Op<"dot.i8", [Pure,
Example:
```mlir
%dst = x86vector.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
%dst = x86.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
```
}];
let arguments = (ins VectorOfLengthAndType<[16], [I32]>:$w,
@@ -401,13 +401,13 @@ def AVX10DotInt8Op : AVX10_Op<"dot.i8", [Pure,
// Operation that is part of the input dialect.
class AVX_Op<string mnemonic, list<Trait> traits = []> :
Op<X86Vector_Dialect, "avx." # mnemonic, traits> {}
Op<X86_Dialect, "avx." # mnemonic, traits> {}
// Operation that may be part of the input dialect, but whose
// form is somewhere between the user view of the operation
// and the actual lower level intrinsic in LLVM IR.
class AVX_LowOp<string mnemonic, list<Trait> traits = []> :
Op<X86Vector_Dialect, "avx.intr." # mnemonic, traits> {}
Op<X86_Dialect, "avx.intr." # mnemonic, traits> {}
//----------------------------------------------------------------------------//
// AVX Rsqrt
@@ -448,7 +448,7 @@ def DotOp : AVX_LowOp<"dot", [Pure,
Example:
```mlir
%0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
%0 = x86.avx.intr.dot %a, %b : vector<8xf32>
%1 = vector.extract %0[%i0] : f32 from vector<8xf32>
%2 = vector.extract %0[%i4] : f32 from vector<8xf32>
%d = arith.addf %1, %2 : f32
@@ -500,7 +500,7 @@ def DotInt8Op : AVX_Op<"dot.i8", [Pure,
Example:
```mlir
%dst = x86vector.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
%dst = x86.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
```
}];
let arguments = (ins VectorOfLengthAndType<[4, 8], [I32]>:$w,
@@ -543,8 +543,8 @@ def BcstToPackedF32Op
Example:
```mlir
%dst = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
%dst = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
%dst = x86.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
%dst = x86.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
```
}];
let arguments = (ins MemRefOf<[BF16, F16]>:$a);
@@ -595,8 +595,8 @@ def CvtPackedEvenIndexedToF32Op
Example:
```mlir
%dst = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
%dst = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
%dst = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
%dst = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
```
}];
let arguments = (ins MemRefOf<[BF16, F16]>:$a);
@@ -642,8 +642,8 @@ def CvtPackedOddIndexedToF32Op
Example:
```mlir
%dst = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
%dst = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
%dst = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
%dst = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
```
}];
let arguments = (ins MemRefOf<[BF16, F16]>:$a);
@@ -673,4 +673,4 @@ def CvtPackedOddIndexedToF32Op
::mlir::RewriterBase &rewriter);
}];
}
#endif // X86VECTOR_OPS
#endif // X86_OPS

View File

@@ -1,4 +1,4 @@
//===- X86VectorDialect.h - MLIR Dialect for X86Vector ----------*- C++ -*-===//
//===- X86Dialect.h - MLIR Dialect for X86 ----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,12 +6,12 @@
//
//===----------------------------------------------------------------------===//
//
// This file declares the Target dialect for X86Vector in MLIR.
// This file declares the Target dialect for X86 in MLIR.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_X86VECTOR_X86VECTORDIALECT_H_
#define MLIR_DIALECT_X86VECTOR_X86VECTORDIALECT_H_
#ifndef MLIR_DIALECT_X86_X86DIALECT_H_
#define MLIR_DIALECT_X86_X86DIALECT_H_
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
@@ -25,11 +25,11 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
/// Include the generated interface declarations.
#include "mlir/Dialect/X86Vector/X86VectorInterfaces.h.inc"
#include "mlir/Dialect/X86/X86Interfaces.h.inc"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h.inc"
#include "mlir/Dialect/X86/X86Dialect.h.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/X86Vector/X86Vector.h.inc"
#include "mlir/Dialect/X86/X86.h.inc"
#endif // MLIR_DIALECT_X86VECTOR_X86VECTORDIALECT_H_
#endif // MLIR_DIALECT_X86_X86DIALECT_H_

View File

@@ -1,4 +1,4 @@
//===- X86VectorInterfaces.td - X86Vector interfaces -------*- tablegen -*-===//
//===- X86Interfaces.td - X86 interfaces -------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,12 +6,12 @@
//
//===----------------------------------------------------------------------===//
//
// This file defines interfaces for the X86Vector dialect.
// This file defines interfaces for the X86 dialect.
//
//===----------------------------------------------------------------------===//
#ifndef X86VECTOR_INTERFACES
#define X86VECTOR_INTERFACES
#ifndef X86_INTERFACES
#define X86_INTERFACES
include "mlir/IR/Interfaces.td"
include "mlir/Dialect/LLVMIR/LLVMInterfaces.td"
@@ -25,7 +25,7 @@ def X86IntrinsicOpInterface
let description = [{
A wrapper interface for operations representing x86 LLVM intrinsics.
}];
let cppNamespace = "::mlir::x86vector";
let cppNamespace = "::mlir::x86";
}
#endif // X86VECTOR_INTERFACES
#endif // X86_INTERFACES

View File

@@ -1,7 +0,0 @@
add_mlir_dialect(X86Vector x86vector)
add_mlir_doc(X86Vector X86Vector Dialects/ -gen-dialect-doc -dialect=x86vector)
add_mlir_interface(X86VectorInterfaces)
add_dependencies(MLIRX86VectorIncGen MLIRX86VectorInterfacesIncGen)
add_subdirectory(TransformOps)

View File

@@ -1,4 +0,0 @@
set(LLVM_TARGET_DEFINITIONS X86VectorTransformOps.td)
mlir_tablegen(X86VectorTransformOps.h.inc -gen-op-decls)
mlir_tablegen(X86VectorTransformOps.cpp.inc -gen-op-defs)
add_mlir_dialect_tablegen_target(MLIRX86VectorTransformOpsIncGen)

View File

@@ -517,13 +517,13 @@ add_mlir_upstream_c_api_library(MLIRCAPIWasmSSA
MLIRWasmSSADialect
)
add_mlir_upstream_c_api_library(MLIRCAPIX86Vector
X86Vector.cpp
add_mlir_upstream_c_api_library(MLIRCAPIX86
X86.cpp
PARTIAL_SOURCES_INTENDED
LINK_LIBS PUBLIC
MLIRCAPIIR
MLIRX86VectorDialect
MLIRX86Dialect
)
add_mlir_upstream_c_api_library(MLIRCAPIXeGPU

View File

@@ -1,4 +1,4 @@
//===- X86Vector.cpp - C Interface for X86Vector dialect ------------------===//
//===- X86.cpp - C Interface for X86 dialect ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,9 +6,8 @@
//
//===----------------------------------------------------------------------===//
#include "mlir-c/Dialect/X86Vector.h"
#include "mlir-c/Dialect/X86.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/Dialect/X86/X86Dialect.h"
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(X86Vector, x86vector,
mlir::x86vector::X86VectorDialect)
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(X86, x86, mlir::x86::X86Dialect)

View File

@@ -40,6 +40,6 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
MLIRArmSVETransforms
MLIRAMXDialect
MLIRAMXTransforms
MLIRX86VectorDialect
MLIRX86VectorTransforms
MLIRX86Dialect
MLIRX86Transforms
)

View File

@@ -22,8 +22,8 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/Dialect/X86/Transforms.h"
#include "mlir/Dialect/X86/X86Dialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -53,8 +53,8 @@ struct ConvertVectorToLLVMPass
registry.insert<arm_sve::ArmSVEDialect>();
if (amx)
registry.insert<amx::AMXDialect>();
if (x86Vector)
registry.insert<x86vector::X86VectorDialect>();
if (x86)
registry.insert<x86::X86Dialect>();
}
void runOnOperation() override;
};
@@ -140,9 +140,9 @@ void ConvertVectorToLLVMPass::runOnOperation() {
configureAMXLegalizeForExportTarget(target);
populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
}
if (x86Vector) {
configureX86VectorLegalizeForExportTarget(target);
populateX86VectorLegalizeForLLVMExportPatterns(converter, patterns);
if (x86) {
configureX86LegalizeForExportTarget(target);
populateX86LegalizeForLLVMExportPatterns(converter, patterns);
}
if (failed(

View File

@@ -42,7 +42,7 @@ add_subdirectory(UB)
add_subdirectory(Utils)
add_subdirectory(Vector)
add_subdirectory(WasmSSA)
add_subdirectory(X86Vector)
add_subdirectory(X86)
add_subdirectory(XeGPU)
set(LLVM_OPTIONAL_SOURCES

View File

@@ -21,6 +21,6 @@ add_mlir_dialect_library(MLIRMathTransforms
MLIRSCFDialect
MLIRPass
MLIRTransforms
MLIRX86VectorDialect
MLIRX86Dialect
MLIRVectorDialect
)

View File

@@ -22,7 +22,7 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/Dialect/X86/X86Dialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
@@ -1740,7 +1740,7 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
// Compute an approximate result.
Value yApprox = handleMultidimensionalVectors(
builder, op->getOperands(), 8, [&builder](ValueRange operands) -> Value {
return x86vector::RsqrtOp::create(builder, operands);
return x86::RsqrtOp::create(builder, operands);
});
// Do a single step of Newton-Raphson iteration to improve the approximation.
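The builder call above now creates an op in the renamed dialect; a minimal
sketch of the resulting IR (assembly format assumed from the dialect's other
AVX ops):

```mlir
// Previously spelled x86vector.avx.rsqrt; the math transform refines this
// approximation with one Newton-Raphson step afterwards.
%approx = x86.avx.rsqrt %v : vector<8xf32>
```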

View File

@@ -18,5 +18,5 @@ add_mlir_dialect_library(MLIRVectorTransformOps
MLIRTransformDialect
MLIRVectorDialect
MLIRVectorToSCF
MLIRX86VectorTransforms
MLIRX86Transforms
)

View File

@@ -18,7 +18,7 @@
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Dialect/X86/Transforms.h"
using namespace mlir;
using namespace mlir::vector;
@@ -191,12 +191,10 @@ void transform::ApplyLowerTransposePatternsOp::populatePatterns(
vector::populateVectorTransposeLoweringPatterns(patterns,
getLoweringStrategy());
if (getAvx2LoweringStrategy()) {
auto avx2LoweringOptions =
x86vector::avx2::LoweringOptions().setTransposeOptions(
x86vector::avx2::TransposeLoweringOptions()
.lower4x8xf32(true)
.lower8x8xf32(true));
x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
auto avx2LoweringOptions = x86::avx2::LoweringOptions().setTransposeOptions(
x86::avx2::TransposeLoweringOptions().lower4x8xf32(true).lower8x8xf32(
true));
x86::avx2::populateSpecializedTransposeLoweringPatterns(
patterns, avx2LoweringOptions, /*benefit=*/10);
}
}

View File

@@ -1,11 +1,11 @@
add_mlir_dialect_library(MLIRX86VectorDialect
X86VectorDialect.cpp
add_mlir_dialect_library(MLIRX86Dialect
X86Dialect.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/X86Vector
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/X86
DEPENDS
MLIRX86VectorIncGen
MLIRX86IncGen
LINK_LIBS PUBLIC
MLIRIR

View File

@@ -1,4 +1,4 @@
//===- X86VectorDialect.cpp - MLIR X86Vector ops implementation -----------===//
//===- X86Dialect.cpp - MLIR X86 ops implementation -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,25 +6,25 @@
//
//===----------------------------------------------------------------------===//
//
// This file implements the X86Vector dialect and its operations.
// This file implements the X86 dialect and its operations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/Dialect/X86/X86Dialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"
using namespace mlir;
#include "mlir/Dialect/X86Vector/X86VectorInterfaces.cpp.inc"
#include "mlir/Dialect/X86/X86Interfaces.cpp.inc"
#include "mlir/Dialect/X86Vector/X86VectorDialect.cpp.inc"
#include "mlir/Dialect/X86/X86Dialect.cpp.inc"
void x86vector::X86VectorDialect::initialize() {
void x86::X86Dialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/X86Vector/X86Vector.cpp.inc"
#include "mlir/Dialect/X86/X86.cpp.inc"
>();
}
@@ -35,7 +35,7 @@ static Value getMemrefBuffPtr(Location loc, MemRefType type, Value buffer,
return memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, type);
}
LogicalResult x86vector::MaskCompressOp::verify() {
LogicalResult x86::MaskCompressOp::verify() {
if (getSrc() && getConstantSrc())
return emitError("cannot use both src and constant_src");
@@ -49,7 +49,7 @@ LogicalResult x86vector::MaskCompressOp::verify() {
return success();
}
SmallVector<Value> x86vector::MaskCompressOp::getIntrinsicOperands(
SmallVector<Value> x86::MaskCompressOp::getIntrinsicOperands(
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) {
auto loc = getLoc();
@@ -71,9 +71,9 @@ SmallVector<Value> x86vector::MaskCompressOp::getIntrinsicOperands(
}
SmallVector<Value>
x86vector::DotOp::getIntrinsicOperands(ArrayRef<Value> operands,
const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) {
x86::DotOp::getIntrinsicOperands(ArrayRef<Value> operands,
const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) {
SmallVector<Value> intrinsicOperands(operands);
// Dot product of all elements, broadcasted to all elements.
Value scale =
@@ -83,7 +83,7 @@ x86vector::DotOp::getIntrinsicOperands(ArrayRef<Value> operands,
return intrinsicOperands;
}
SmallVector<Value> x86vector::BcstToPackedF32Op::getIntrinsicOperands(
SmallVector<Value> x86::BcstToPackedF32Op::getIntrinsicOperands(
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) {
Adaptor adaptor(operands, *this);
@@ -91,7 +91,7 @@ SmallVector<Value> x86vector::BcstToPackedF32Op::getIntrinsicOperands(
typeConverter, rewriter)};
}
SmallVector<Value> x86vector::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
SmallVector<Value> x86::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) {
Adaptor adaptor(operands, *this);
@@ -99,7 +99,7 @@ SmallVector<Value> x86vector::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
typeConverter, rewriter)};
}
SmallVector<Value> x86vector::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
SmallVector<Value> x86::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) {
Adaptor adaptor(operands, *this);
@@ -108,4 +108,4 @@ SmallVector<Value> x86vector::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
}
#define GET_OP_CLASSES
#include "mlir/Dialect/X86Vector/X86Vector.cpp.inc"
#include "mlir/Dialect/X86/X86.cpp.inc"

View File

@@ -1,8 +1,8 @@
add_mlir_dialect_library(MLIRX86VectorTransformOps
X86VectorTransformOps.cpp
add_mlir_dialect_library(MLIRX86TransformOps
X86TransformOps.cpp
DEPENDS
MLIRX86VectorTransformOpsIncGen
MLIRX86TransformOpsIncGen
LINK_LIBS PUBLIC
MLIRIR
@@ -12,6 +12,6 @@ add_mlir_dialect_library(MLIRX86VectorTransformOps
MLIRSideEffectInterfaces
MLIRTransformDialect
MLIRTransformDialectUtils
MLIRX86VectorDialect
MLIRX86VectorTransforms
MLIRX86Dialect
MLIRX86Transforms
)

View File

@@ -1,4 +1,4 @@
//===- X86VectorTransformOps.cpp ------------------------------------------===//
//===- X86TransformOps.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,45 +6,45 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h"
#include "mlir/Dialect/X86/TransformOps/X86TransformOps.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/Dialect/X86/Transforms.h"
#include "mlir/Dialect/X86/X86Dialect.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/RegionKindInterface.h"
using namespace mlir;
using namespace mlir::x86vector;
using namespace mlir::x86;
using namespace mlir::transform;
void mlir::transform::ApplyVectorContractToFMAPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
x86vector::populateVectorContractToFMAPatterns(patterns);
x86::populateVectorContractToFMAPatterns(patterns);
}
void mlir::transform::ApplyVectorContractToPackedTypeDotProductPatternsOp::
populatePatterns(RewritePatternSet &patterns) {
x86vector::populateVectorContractToPackedTypeDotProductPatterns(patterns);
x86::populateVectorContractToPackedTypeDotProductPatterns(patterns);
}
void mlir::transform::ApplyVectorContractBF16ToFMAPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
x86vector::populateVectorContractBF16ToFMAPatterns(patterns);
x86::populateVectorContractBF16ToFMAPatterns(patterns);
}
void mlir::transform::ApplySinkVectorProducerOpsPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
x86vector::populateSinkVectorProducerOpsPatterns(patterns);
x86::populateSinkVectorProducerOpsPatterns(patterns);
}
void mlir::transform::ApplyShuffleVectorFMAOpsPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
x86vector::populateShuffleVectorFMAOpsPatterns(patterns);
x86::populateShuffleVectorFMAOpsPatterns(patterns);
}
//===----------------------------------------------------------------------===//
@@ -52,28 +52,26 @@ void mlir::transform::ApplyShuffleVectorFMAOpsPatternsOp::populatePatterns(
//===----------------------------------------------------------------------===//
namespace {
class X86VectorTransformDialectExtension
class X86TransformDialectExtension
: public transform::TransformDialectExtension<
X86VectorTransformDialectExtension> {
X86TransformDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
X86VectorTransformDialectExtension)
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(X86TransformDialectExtension)
X86VectorTransformDialectExtension() {
declareGeneratedDialect<x86vector::X86VectorDialect>();
X86TransformDialectExtension() {
declareGeneratedDialect<x86::X86Dialect>();
declareGeneratedDialect<LLVM::LLVMDialect>();
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
#include "mlir/Dialect/X86/TransformOps/X86TransformOps.cpp.inc"
>();
}
};
} // namespace
#define GET_OP_CLASSES
#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
#include "mlir/Dialect/X86/TransformOps/X86TransformOps.cpp.inc"
void mlir::x86vector::registerTransformDialectExtension(
DialectRegistry &registry) {
registry.addExtensions<X86VectorTransformDialectExtension>();
void mlir::x86::registerTransformDialectExtension(DialectRegistry &registry) {
registry.addExtensions<X86TransformDialectExtension>();
}

View File

@@ -15,20 +15,21 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Dialect/X86/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"
using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;
using namespace mlir::x86vector::avx2;
using namespace mlir::x86vector::avx2::inline_asm;
using namespace mlir::x86vector::avx2::intrin;
using namespace mlir::x86;
using namespace mlir::x86::avx2;
using namespace mlir::x86::avx2::inline_asm;
using namespace mlir::x86::avx2::intrin;
Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm(
ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
Value mlir::x86::avx2::inline_asm::mm256BlendPsAsm(ImplicitLocOpBuilder &b,
Value v1, Value v2,
uint8_t mask) {
auto asmDialectAttr =
LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_Intel);
const auto *asmTp = "vblendps $0, $1, $2, {0}";
@@ -45,14 +46,14 @@ Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm(
return asmOp.getResult(0);
}
Value mlir::x86vector::avx2::intrin::mm256UnpackLoPs(ImplicitLocOpBuilder &b,
Value v1, Value v2) {
Value mlir::x86::avx2::intrin::mm256UnpackLoPs(ImplicitLocOpBuilder &b,
Value v1, Value v2) {
return vector::ShuffleOp::create(b, v1, v2,
ArrayRef<int64_t>{0, 8, 1, 9, 4, 12, 5, 13});
}
Value mlir::x86vector::avx2::intrin::mm256UnpackHiPs(ImplicitLocOpBuilder &b,
Value v1, Value v2) {
Value mlir::x86::avx2::intrin::mm256UnpackHiPs(ImplicitLocOpBuilder &b,
Value v1, Value v2) {
return vector::ShuffleOp::create(
b, v1, v2, ArrayRef<int64_t>{2, 10, 3, 11, 6, 14, 7, 15});
}
@@ -60,9 +61,8 @@ Value mlir::x86vector::avx2::intrin::mm256UnpackHiPs(ImplicitLocOpBuilder &b,
/// Takes an 8 bit mask, 2 bit for each position of a[0, 3) **and** b[0, 4):
/// 0:127 | 128:255
/// b01 b23 C8 D8 | b01+4 b23+4 C8+4 D8+4
Value mlir::x86vector::avx2::intrin::mm256ShufflePs(ImplicitLocOpBuilder &b,
Value v1, Value v2,
uint8_t mask) {
Value mlir::x86::avx2::intrin::mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1,
Value v2, uint8_t mask) {
uint8_t b01, b23, b45, b67;
MaskHelper::extractShuffle(mask, b01, b23, b45, b67);
SmallVector<int64_t> shuffleMask = {
@@ -76,8 +76,9 @@ Value mlir::x86vector::avx2::intrin::mm256ShufflePs(ImplicitLocOpBuilder &b,
// a[0:127] or a[128:255] or b[0:127] or b[128:255]
// 0 1 2 3
// imm[0:1] out of imm[4:7].
Value mlir::x86vector::avx2::intrin::mm256Permute2f128Ps(
ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
Value mlir::x86::avx2::intrin::mm256Permute2f128Ps(ImplicitLocOpBuilder &b,
Value v1, Value v2,
uint8_t mask) {
SmallVector<int64_t> shuffleMask;
auto appendToMask = [&](uint8_t control) {
if (control == 0)
@@ -99,9 +100,8 @@ Value mlir::x86vector::avx2::intrin::mm256Permute2f128Ps(
}
/// If bit i of `mask` is zero, take f32@i from v1 else take it from v2.
Value mlir::x86vector::avx2::intrin::mm256BlendPs(ImplicitLocOpBuilder &b,
Value v1, Value v2,
uint8_t mask) {
Value mlir::x86::avx2::intrin::mm256BlendPs(ImplicitLocOpBuilder &b, Value v1,
Value v2, uint8_t mask) {
SmallVector<int64_t, 8> shuffleMask;
for (int i = 0; i < 8; ++i) {
bool isSet = mask & (1 << i);
@@ -111,8 +111,8 @@ Value mlir::x86vector::avx2::intrin::mm256BlendPs(ImplicitLocOpBuilder &b,
}
/// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model.
void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
MutableArrayRef<Value> vs) {
void mlir::x86::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
MutableArrayRef<Value> vs) {
#ifndef NDEBUG
auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
assert(vs.size() == 4 && "expects 4 vectors");
@@ -136,8 +136,8 @@ void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
}
/// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model.
void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
MutableArrayRef<Value> vs) {
void mlir::x86::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
MutableArrayRef<Value> vs) {
auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
(void)vt;
assert(vs.size() == 8 && "expects 8 vectors");
@@ -284,7 +284,7 @@ private:
LoweringOptions loweringOptions;
};
void mlir::x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
void mlir::x86::avx2::populateSpecializedTransposeLoweringPatterns(
RewritePatternSet &patterns, LoweringOptions options, int benefit) {
patterns.add<TransposeOpLowering>(options, patterns.getContext(), benefit);
}

View File

@@ -1,4 +1,4 @@
add_mlir_dialect_library(MLIRX86VectorTransforms
add_mlir_dialect_library(MLIRX86Transforms
AVXTranspose.cpp
LegalizeForLLVMExport.cpp
VectorContractToFMA.cpp
@@ -15,6 +15,6 @@ add_mlir_dialect_library(MLIRX86VectorTransforms
MLIRLLVMDialect
MLIRVectorDialect
MLIRVectorUtils
MLIRX86VectorDialect
MLIRX86VectorUtils
MLIRX86Dialect
MLIRX86Utils
)

View File

@@ -1,4 +1,4 @@
//===- LegalizeForLLVMExport.cpp - Prepare X86Vector for LLVM translation -===//
//===- LegalizeForLLVMExport.cpp - Prepare X86 for LLVM translation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,26 +6,26 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Dialect/X86/Transforms.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/Dialect/X86/X86Dialect.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::x86vector;
using namespace mlir::x86;
namespace {
/// Generic one-to-one conversion of simply mappable operations into calls
/// to their respective LLVM intrinsics.
struct X86IntrinsicOpConversion
: public ConvertOpInterfaceToLLVMPattern<x86vector::X86IntrinsicOp> {
: public ConvertOpInterfaceToLLVMPattern<x86::X86IntrinsicOp> {
using ConvertOpInterfaceToLLVMPattern::ConvertOpInterfaceToLLVMPattern;
LogicalResult
matchAndRewrite(x86vector::X86IntrinsicOp op, ArrayRef<Value> operands,
matchAndRewrite(x86::X86IntrinsicOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
const LLVMTypeConverter &typeConverter = *getTypeConverter();
return LLVM::detail::intrinsicRewrite(
@@ -37,13 +37,12 @@ struct X86IntrinsicOpConversion
} // namespace
/// Populate the given list with patterns that convert from X86Vector to LLVM.
void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
/// Populate the given list with patterns that convert from X86 to LLVM.
void mlir::populateX86LegalizeForLLVMExportPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<X86IntrinsicOpConversion>(converter);
}
void mlir::configureX86VectorLegalizeForExportTarget(
LLVMConversionTarget &target) {
target.addIllegalDialect<X86VectorDialect>();
void mlir::configureX86LegalizeForExportTarget(LLVMConversionTarget &target) {
target.addIllegalDialect<X86Dialect>();
}

View File

@@ -7,8 +7,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/Dialect/X86/Transforms.h"
#include "mlir/Dialect/X86/X86Dialect.h"
#include "mlir/IR/PatternMatch.h"
@@ -17,17 +17,17 @@
using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;
using namespace mlir::x86;
namespace {
// Validates whether the given operation is an x86vector operation and has only
// Validates whether the given operation is an x86 operation and has only
// one consumer.
static bool validateFMAOperands(Value op) {
if (auto cvt = op.getDefiningOp<x86vector::CvtPackedEvenIndexedToF32Op>())
if (auto cvt = op.getDefiningOp<x86::CvtPackedEvenIndexedToF32Op>())
return cvt.getResult().hasOneUse();
if (auto bcst = op.getDefiningOp<x86vector::BcstToPackedF32Op>())
if (auto bcst = op.getDefiningOp<x86::BcstToPackedF32Op>())
return bcst.getResult().hasOneUse();
return false;
@@ -36,14 +36,14 @@ static bool validateFMAOperands(Value op) {
// Validates the vector.fma operation on the following conditions:
// (i) one of the lhs or rhs defining operation should be
// CvtPackedEvenIndexedToF32Op, (ii) the lhs or rhs defining operation should be
// an x86vector operation and has only one consumer, (iii) all operations
// an x86 operation and has only one consumer, (iii) all operations
// are in the same block, and (iv) the FMA has only one user.
static bool validateVectorFMAOp(vector::FMAOp fmaOp) {
Value lhs = fmaOp.getLhs();
Value rhs = fmaOp.getRhs();
if (!isa<x86vector::CvtPackedEvenIndexedToF32Op>(lhs.getDefiningOp()) &&
!isa<x86vector::CvtPackedEvenIndexedToF32Op>(rhs.getDefiningOp()))
if (!isa<x86::CvtPackedEvenIndexedToF32Op>(lhs.getDefiningOp()) &&
!isa<x86::CvtPackedEvenIndexedToF32Op>(rhs.getDefiningOp()))
return false;
if (!validateFMAOperands(lhs) || !validateFMAOperands(rhs))
@@ -93,38 +93,38 @@ static void moveFMA(PatternRewriter &rewriter, vector::FMAOp fmaOp) {
return;
}
// Shuffle FMAs with x86vector operations as operands such that
// Shuffle FMAs with x86 operations as operands such that
// FMAs are grouped with respect to odd/even packed index.
//
// For example:
// ```
// %1 = x86vector.avx.bcst_to_f32.packed
// %2 = x86vector.avx.cvt.packed.odd.indexed_to_f32
// %1 = x86.avx.bcst_to_f32.packed
// %2 = x86.avx.cvt.packed.odd.indexed_to_f32
// %3 = vector.fma %1, %2, %arg1
// %4 = x86vector.avx.bcst_to_f32.packed
// %5 = x86vector.avx.cvt.packed.even.indexed_to_f32
// %4 = x86.avx.bcst_to_f32.packed
// %5 = x86.avx.cvt.packed.even.indexed_to_f32
// %6 = vector.fma %4, %5, %3
// %7 = x86vector.avx.bcst_to_f32.packed
// %8 = x86vector.avx.cvt.packed.odd.indexed_to_f32
// %7 = x86.avx.bcst_to_f32.packed
// %8 = x86.avx.cvt.packed.odd.indexed_to_f32
// %9 = vector.fma %7, %8, %arg2
// %10 = x86vector.avx.bcst_to_f32.packed
// %11 = x86vector.avx.cvt.packed.even.indexed_to_f32
// %10 = x86.avx.bcst_to_f32.packed
// %11 = x86.avx.cvt.packed.even.indexed_to_f32
// %12 = vector.fma %10, %11, %9
// yield %6, %12
// ```
// to
// ```
// %1 = x86vector.avx.bcst_to_f32.packed
// %2 = x86vector.avx.cvt.packed.odd.indexed_to_f32
// %1 = x86.avx.bcst_to_f32.packed
// %2 = x86.avx.cvt.packed.odd.indexed_to_f32
// %3 = vector.fma %1, %2, %arg1
// %7 = x86vector.avx.bcst_to_f32.packed
// %8 = x86vector.avx.cvt.packed.odd.indexed_to_f32
// %7 = x86.avx.bcst_to_f32.packed
// %8 = x86.avx.cvt.packed.odd.indexed_to_f32
// %9 = vector.fma %7, %8, %arg2
// %4 = x86vector.avx.bcst_to_f32.packed
// %5 = x86vector.avx.cvt.packed.even.indexed_to_f32
// %4 = x86.avx.bcst_to_f32.packed
// %5 = x86.avx.cvt.packed.even.indexed_to_f32
// %6 = vector.fma %4, %5, %3
// %10 = x86vector.avx.bcst_to_f32.packed
// %11 = x86vector.avx.cvt.packed.even.indexed_to_f32
// %10 = x86.avx.bcst_to_f32.packed
// %11 = x86.avx.cvt.packed.even.indexed_to_f32
// %12 = vector.fma %10, %11, %9
// yield %9, %12
// ```
@@ -150,10 +150,9 @@ struct ShuffleVectorFMAOps : public OpRewritePattern<vector::FMAOp> {
if (!fma)
continue;
bool hasX86CvtOperand = isa<x86vector::CvtPackedEvenIndexedToF32Op>(
fma.getLhs().getDefiningOp()) ||
isa<x86vector::CvtPackedEvenIndexedToF32Op>(
fma.getRhs().getDefiningOp());
bool hasX86CvtOperand =
isa<x86::CvtPackedEvenIndexedToF32Op>(fma.getLhs().getDefiningOp()) ||
isa<x86::CvtPackedEvenIndexedToF32Op>(fma.getRhs().getDefiningOp());
if (hasX86CvtOperand && stopAtNextDependentFMA)
break;
@@ -180,7 +179,6 @@ struct ShuffleVectorFMAOps : public OpRewritePattern<vector::FMAOp> {
} // namespace
void x86vector::populateShuffleVectorFMAOpsPatterns(
RewritePatternSet &patterns) {
void x86::populateShuffleVectorFMAOpsPatterns(RewritePatternSet &patterns) {
patterns.add<ShuffleVectorFMAOps>(patterns.getContext());
}

View File

@@ -8,8 +8,8 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/Dialect/X86/Transforms.h"
#include "mlir/Dialect/X86/X86Dialect.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dominance.h"
@@ -20,7 +20,7 @@
using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;
using namespace mlir::x86;
static FailureOr<llvm::SmallVector<Operation *>>
getSameBlockUsers(Operation *op) {
@@ -141,8 +141,7 @@ struct SinkVectorProducerOps final : public OpRewritePattern<producerOp> {
}
};
void x86vector::populateSinkVectorProducerOpsPatterns(
RewritePatternSet &patterns) {
void x86::populateSinkVectorProducerOpsPatterns(RewritePatternSet &patterns) {
patterns.add<SinkVectorProducerOps<vector::TransferReadOp>,
SinkVectorProducerOps<vector::LoadOp>>(patterns.getContext());
}

View File

@@ -11,9 +11,9 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/Dialect/X86/Transforms.h"
#include "mlir/Dialect/X86/Utils/X86Utils.h"
#include "mlir/Dialect/X86/X86Dialect.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dominance.h"
@@ -25,7 +25,7 @@
using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;
using namespace mlir::x86;
// Verifies that the LHS and RHS operands of a vector.contract are load or
// vector.transfer_read operations on a memref source buffer, and checks
@@ -61,7 +61,7 @@ static bool validateVectorContractOperands(Value prodOp, bool isVnni) {
return false;
// Return false if the two innermost strides of the memref are not contiguous.
// The x86vector.avx.cvt.packed.even/odd.indexed_to_f32 operations require
// The x86.avx.cvt.packed.even/odd.indexed_to_f32 operations require
// an eight-element tuple of bf16 values to be contiguous.
int dimsToCheck = isVnni ? 2 : 1;
if (!cast<mlir::MemRefType>(srcType).areTrailingDimsContiguous(dimsToCheck))
@@ -162,7 +162,7 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
subviews.push_back(subview);
// For unit-dims, two subviews should be created for the odd and even
// element in the VNNI tuple (2xbf16) because x86vector.avx.bcst_to_f32.packed
// element in the VNNI tuple (2xbf16) because x86.avx.bcst_to_f32.packed
// op loads and broadcasts the first BF16 element into packed F32. It
// cannot distinguish between even and odd BF16 elements within a
// packed pair.
@@ -193,11 +193,11 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
// ```
// to
// ```
// %1 = x86vector.avx.bcst_to_f32.packed %m1[c1] -> vector<8xf32>
// %2 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %m2 -> vector<8xf32>
// %1 = x86.avx.bcst_to_f32.packed %m1[c1] -> vector<8xf32>
// %2 = x86.avx.cvt.packed.odd.indexed_to_f32 %m2 -> vector<8xf32>
// %3 = vector.fma %1, %2, %arg1
// %4 = x86vector.avx.bcst_to_f32.packed %m1[c0] -> vector<8xf32>
// %5 = x86vector.avx.cvt.packed.even.indexed_to_f32 %m2 -> vector<8xf32>
// %4 = x86.avx.bcst_to_f32.packed %m1[c0] -> vector<8xf32>
// %5 = x86.avx.cvt.packed.even.indexed_to_f32 %m2 -> vector<8xf32>
// return vector.fma %4, %5, %3
// ```
//
@@ -212,10 +212,10 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
// ```
// to
// ```
// %1 = x86vector.avx.bcst_to_f32.packed %m1[c0] -> vector<8xf32>
// %2 = x86vector.avx.cvt.packed.even.indexed_to_f32 %m2 -> vector<8xf32>
// %1 = x86.avx.bcst_to_f32.packed %m1[c0] -> vector<8xf32>
// %2 = x86.avx.cvt.packed.even.indexed_to_f32 %m2 -> vector<8xf32>
// %3 = vector.fma %1, %2, %arg1
// %4 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %m2 -> vector<8xf32>
// %4 = x86.avx.cvt.packed.odd.indexed_to_f32 %m2 -> vector<8xf32>
// %5 = vector.fma %1, %4, %arg2
// scf.yield %3, %5
struct VectorContractBF16ToFMA
@@ -446,11 +446,10 @@ struct VectorContractBF16ToFMA
VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
contractOp.getAcc());
auto loadBcstBF16ElementToF32 = x86vector::BcstToPackedF32Op::create(
auto loadBcstBF16ElementToF32 = x86::BcstToPackedF32Op::create(
rewriter, loc, dstType, unitDimSubview[0]);
auto loadEvenIdxElementF32 =
x86vector::CvtPackedEvenIndexedToF32Op::create(rewriter, loc, dstType,
nonUnitDimSubview[0]);
auto loadEvenIdxElementF32 = x86::CvtPackedEvenIndexedToF32Op::create(
rewriter, loc, dstType, nonUnitDimSubview[0]);
auto evenIdxFMA =
vector::FMAOp::create(rewriter, loc, loadBcstBF16ElementToF32,
loadEvenIdxElementF32, castAcc);
@@ -468,7 +467,7 @@ struct VectorContractBF16ToFMA
accTyPairCont.getElementType()),
pairContractOp.getAcc());
auto loadOddIdxElementF32 = x86vector::CvtPackedOddIndexedToF32Op::create(
auto loadOddIdxElementF32 = x86::CvtPackedOddIndexedToF32Op::create(
rewriter, pairContOpLoc, dstType, nonUnitDimSubview[0]);
auto oddIdxFMA = vector::FMAOp::create(
rewriter, pairContOpLoc, loadBcstBF16ElementToF32,
@@ -481,18 +480,18 @@ struct VectorContractBF16ToFMA
}
// Load, broadcast, and do FMA for odd indexed BF16 elements.
auto loadBcstOddIdxElementToF32 = x86vector::BcstToPackedF32Op::create(
auto loadBcstOddIdxElementToF32 = x86::BcstToPackedF32Op::create(
rewriter, loc, dstType, unitDimSubview[0]);
auto loadOddIdxElementF32 = x86vector::CvtPackedOddIndexedToF32Op::create(
auto loadOddIdxElementF32 = x86::CvtPackedOddIndexedToF32Op::create(
rewriter, loc, dstType, nonUnitDimSubview[0]);
auto oddIdxFMA =
vector::FMAOp::create(rewriter, loc, loadBcstOddIdxElementToF32,
loadOddIdxElementF32, castAcc);
// Load, broadcast, and do FMA for even indexed BF16 elements.
auto loadBcstEvenIdxElementToF32 = x86vector::BcstToPackedF32Op::create(
auto loadBcstEvenIdxElementToF32 = x86::BcstToPackedF32Op::create(
rewriter, loc, dstType, unitDimSubview[1]);
auto loadEvenIdxElementF32 = x86vector::CvtPackedEvenIndexedToF32Op::create(
auto loadEvenIdxElementF32 = x86::CvtPackedEvenIndexedToF32Op::create(
rewriter, loc, dstType, nonUnitDimSubview[0]);
vector::FMAOp fma =
vector::FMAOp::create(rewriter, loc, loadBcstEvenIdxElementToF32,
@@ -504,7 +503,6 @@ struct VectorContractBF16ToFMA
}
};
void x86vector::populateVectorContractBF16ToFMAPatterns(
RewritePatternSet &patterns) {
void x86::populateVectorContractBF16ToFMAPatterns(RewritePatternSet &patterns) {
patterns.add<VectorContractBF16ToFMA>(patterns.getContext());
}

View File

@@ -8,8 +8,8 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/Dialect/X86/Transforms.h"
#include "mlir/Dialect/X86/X86Dialect.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dominance.h"
@@ -20,7 +20,7 @@
using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;
using namespace mlir::x86;
namespace {
@@ -137,7 +137,6 @@ struct VectorContractToFMA : public OpRewritePattern<vector::ContractionOp> {
} // namespace
void x86vector::populateVectorContractToFMAPatterns(
RewritePatternSet &patterns) {
void x86::populateVectorContractToFMAPatterns(RewritePatternSet &patterns) {
patterns.add<VectorContractToFMA>(patterns.getContext());
}

View File

@@ -11,9 +11,9 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/Dialect/X86/Transforms.h"
#include "mlir/Dialect/X86/Utils/X86Utils.h"
#include "mlir/Dialect/X86/X86Dialect.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dominance.h"
@@ -24,7 +24,7 @@
using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;
using namespace mlir::x86;
namespace {
@@ -109,7 +109,7 @@ static void packNonUnitDimOperandToVNNI(mlir::PatternRewriter &rewriter,
// to
// ```
// vector.broadcast %lhs to <32xbf16>
// x86vector.avx512.dot vector<32xbf16> -> vector<16xf32>
// x86.avx512.dot vector<32xbf16> -> vector<16xf32>
// ```
//
// For example - for bf16 type (Flat layout):
@@ -126,9 +126,9 @@ static void packNonUnitDimOperandToVNNI(mlir::PatternRewriter &rewriter,
// %3 = vector.shuffle %1, %2 [0, 32, 1, ... 27, 59]
// %4 = vector.shuffle %1, %2 [4, 36, 5, ... 31, 63]
// vector.broadcast %lhs to <32xbf16>
// x86vector.avx512.dot vector<32xbf16>, %3 -> vector<16xf32>
// x86.avx512.dot vector<32xbf16>, %3 -> vector<16xf32>
// vector.broadcast %lhs to <32xbf16>
// x86vector.avx512.dot vector<32xbf16>, %4 -> vector<16xf32>
// x86.avx512.dot vector<32xbf16>, %4 -> vector<16xf32>
// ```
struct VectorContractToPackedTypeDotProduct
: public OpRewritePattern<vector::ContractionOp> {
@@ -384,7 +384,7 @@ struct VectorContractToPackedTypeDotProduct
rewriter, loc, castNonUnitDim.getResult().getType(), broadcastUnitDim);
if (lhsTy.getElementType().isBF16()) {
dp = x86vector::DotBF16Op::create(
dp = x86::DotBF16Op::create(
rewriter, loc,
VectorType::get(nonUnitDimValue, rewriter.getF32Type()), castAcc,
bitcastUnitDimPkType, castNonUnitDim);
@@ -392,12 +392,12 @@ struct VectorContractToPackedTypeDotProduct
if (lhsTy.getElementType().isSignlessInteger(8)) {
if (nonUnitDimAcc.front() == 16) {
dp = x86vector::AVX10DotInt8Op::create(
dp = x86::AVX10DotInt8Op::create(
rewriter, loc,
VectorType::get(nonUnitDimValue, rewriter.getIntegerType(32)),
castAcc, bitcastUnitDimPkType, castNonUnitDim);
} else {
dp = x86vector::DotInt8Op::create(
dp = x86::DotInt8Op::create(
rewriter, loc,
VectorType::get(nonUnitDimValue, rewriter.getIntegerType(32)),
castAcc, bitcastUnitDimPkType, castNonUnitDim);
@@ -415,7 +415,7 @@ struct VectorContractToPackedTypeDotProduct
} // namespace
void x86vector::populateVectorContractToPackedTypeDotProductPatterns(
void x86::populateVectorContractToPackedTypeDotProductPatterns(
RewritePatternSet &patterns) {
patterns.add<VectorContractToPackedTypeDotProduct>(patterns.getContext());
}

View File

@@ -1,8 +1,8 @@
add_mlir_dialect_library(MLIRX86VectorUtils
X86VectorUtils.cpp
add_mlir_dialect_library(MLIRX86Utils
X86Utils.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/X86Vector/Utils
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/X86/Utils
LINK_LIBS PUBLIC
MLIRAffineDialect

View File

@@ -1,4 +1,4 @@
//===- X86VectorUtils.cpp - MLIR Utilities for X86VectorOps -------------===//
//===- X86Utils.cpp - MLIR Utilities for X86Ops -------------------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h"
#include "mlir/Dialect/X86/Utils/X86Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
@@ -23,7 +23,7 @@
#include <cassert>
namespace mlir {
namespace x86vector {
namespace x86 {
static FailureOr<SmallVector<mlir::utils::IteratorType>>
inferIteratorsFromOutMap(AffineMap map) {
@@ -410,5 +410,5 @@ bool validatePairVectorContract(vector::ContractionOp contractOp,
return true;
}
} // namespace x86vector
} // namespace x86
} // namespace mlir

View File

@@ -96,7 +96,7 @@
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h"
#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/Dialect/X86/X86Dialect.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/CastInterfaces.h"
@@ -152,7 +152,7 @@ void mlir::registerAllDialects(DialectRegistry &registry) {
ub::UBDialect,
vector::VectorDialect,
wasmssa::WasmSSADialect,
x86vector::X86VectorDialect,
x86::X86Dialect,
xegpu::XeGPUDialect,
xevm::XeVMDialect>();
// clang-format on

View File

@@ -56,7 +56,7 @@
#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h"
#include "mlir/Dialect/X86/TransformOps/X86TransformOps.h"
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
@@ -114,7 +114,7 @@ void mlir::registerAllExtensions(DialectRegistry &registry) {
transform::registerSMTExtension(registry);
transform::registerTuneExtension(registry);
vector::registerTransformDialectExtension(registry);
x86vector::registerTransformDialectExtension(registry);
x86::registerTransformDialectExtension(registry);
xegpu::registerTransformDialectExtension(registry);
arm_neon::registerTransformDialectExtension(registry);
arm_sve::registerTransformDialectExtension(registry);

View File

@@ -328,11 +328,11 @@ declare_mlir_dialect_extension_python_bindings(
declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/X86VectorTransformOps.td
TD_FILE dialects/X86TransformOps.td
SOURCES
dialects/transform/x86vector.py
dialects/transform/x86.py
DIALECT_NAME transform
EXTENSION_NAME x86vector_transform)
EXTENSION_NAME x86_transform)
declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
@@ -522,9 +522,9 @@ declare_mlir_dialect_python_bindings(
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/X86Vector.td
SOURCES dialects/x86vector.py
DIALECT_NAME x86vector)
TD_FILE dialects/X86.td
SOURCES dialects/x86.py
DIALECT_NAME x86)
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects

View File

@@ -1,4 +1,4 @@
//===-- X86Vector.td - Entry point for x86vector bindings --*- tablegen -*-===//
//===-- X86.td - Entry point for x86 bindings --------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,9 +6,9 @@
//
//===----------------------------------------------------------------------===//
#ifndef PYTHON_BINDINGS_X86VECTOR
#define PYTHON_BINDINGS_X86VECTOR
#ifndef PYTHON_BINDINGS_X86
#define PYTHON_BINDINGS_X86
include "mlir/Dialect/X86Vector/X86Vector.td"
include "mlir/Dialect/X86/X86.td"
#endif // PYTHON_BINDINGS_X86VECTOR
#endif // PYTHON_BINDINGS_X86

View File

@@ -1,4 +1,4 @@
//===-- X86VectorTransformOps.td ---------------------------*- tablegen -*-===//
//===-- X86TransformOps.td ---------------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,9 +6,9 @@
//
//===----------------------------------------------------------------------===//
#ifndef PYTHON_BINDINGS_X86VECTORTRANSFORMOPS
#define PYTHON_BINDINGS_X86VECTORTRANSFORMOPS
#ifndef PYTHON_BINDINGS_X86TRANSFORMOPS
#define PYTHON_BINDINGS_X86TRANSFORMOPS
include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td"
include "mlir/Dialect/X86/TransformOps/X86TransformOps.td"
#endif // PYTHON_BINDINGS_X86VECTORTRANSFORMOPS
#endif // PYTHON_BINDINGS_X86TRANSFORMOPS

View File

@@ -2,4 +2,4 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._x86vector_transform_ops_gen import *
from .._x86_transform_ops_gen import *
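For downstream users of the transform bindings, only the module path changes here. A minimal sketch of the updated import, assuming the renamed sources above land under the usual mlir.dialects package layout:

# Sketch (hypothetical downstream usage): the transform extension now lives
# at dialects/transform/x86.py, so the import path drops the 'vector' suffix.
from mlir.dialects.transform import x86 as x86_transform  # was: ... import x86vector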

View File

@@ -2,5 +2,5 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._x86vector_ops_gen import *
from ._x86vector_ops_gen import _Dialect
from ._x86_ops_gen import *
from ._x86_ops_gen import _Dialect
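Likewise, the main dialect bindings become importable as mlir.dialects.x86. A minimal, hedged usage sketch follows; it assumes the standard upstream Python packaging (all dialects registered on Context creation), and the op shown, x86.avx.rsqrt, is taken from the tests further below:

# Sketch (assumption: standard upstream mlir wheel). Importing the dialect
# module makes 'x86.*' ops parseable under the shortened dialect name.
from mlir.ir import Context, Module
from mlir.dialects import x86  # was: from mlir.dialects import x86vector

with Context():
    module = Module.parse(r"""
      func.func @rsqrt(%a: vector<8xf32>) -> vector<8xf32> {
        %0 = x86.avx.rsqrt %a : vector<8xf32>
        return %0 : vector<8xf32>
      }
    """)
    print(module)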

View File

@@ -33,7 +33,7 @@ if (MLIR_INCLUDE_INTEGRATION_TESTS)
set(ARM_SME_ABI_ROUTINES_SHLIB "" CACHE STRING
"Path to a shared library containing Arm SME ABI routines, required for Arm SME integration tests.")
option(MLIR_RUN_AMX_TESTS "Run AMX tests.")
option(MLIR_RUN_X86VECTOR_TESTS "Run X86Vector tests.")
option(MLIR_RUN_X86_TESTS "Run X86 tests.")
option(MLIR_RUN_CUDA_TENSOR_CORE_TESTS "Run CUDA Tensor core WMMA tests.")
option(MLIR_RUN_CUDA_SM80_TESTS "Run CUDA A100 tests.")
option(MLIR_RUN_CUDA_SM80_LT_TESTS "Run CUDA A100 structured sparsity tests.")
@@ -78,7 +78,7 @@ llvm_canonicalize_cmake_booleans(
MLIR_INCLUDE_INTEGRATION_TESTS
MLIR_RUN_AMX_TESTS
MLIR_RUN_CUDA_TENSOR_CORE_TESTS
MLIR_RUN_X86VECTOR_TESTS
MLIR_RUN_X86_TESTS
MLIR_RUN_ARM_SVE_TESTS
MLIR_RUN_ARM_SME_TESTS
MLIR_RUN_CUDA_SM80_TESTS

View File

@@ -21,7 +21,7 @@
// CHECK-SAME: enable-amx={{[aA-zZ0-9]+}}
// CHECK-SAME: enable-arm-neon={{[aA-zZ0-9]+}}
// CHECK-SAME: enable-arm-sve={{[aA-zZ0-9]+}}
// CHECK-SAME: enable-x86vector={{[aA-zZ0-9]+}}
// CHECK-SAME: enable-x86={{[aA-zZ0-9]+}}
// CHECK-SAME: force-32bit-vector-indices={{[aA-zZ0-9]+}}
// CHECK-SAME: reassociate-fp-reductions={{[aA-zZ0-9]+}}
// DEFAULT: vector-contract-lowering=dot

View File

@@ -691,7 +691,7 @@ func.func @rsqrt_scalar(%arg0: f32) -> f32 {
// AVX2: %[[VAL_6:.*]] = arith.cmpf olt, %[[VAL_0]], %[[VAL_4]] : vector<8xf32>
// AVX2: %[[VAL_7:.*]] = arith.cmpf oeq, %[[VAL_0]], %[[VAL_1]] : vector<8xf32>
// AVX2: %[[VAL_8:.*]] = arith.ori %[[VAL_6]], %[[VAL_7]] : vector<8xi1>
// AVX2: %[[VAL_9:.*]] = x86vector.avx.rsqrt %[[VAL_0]] : vector<8xf32>
// AVX2: %[[VAL_9:.*]] = x86.avx.rsqrt %[[VAL_0]] : vector<8xf32>
// AVX2: %[[VAL_10:.*]] = arith.mulf %[[VAL_5]], %[[VAL_9]] : vector<8xf32>
// AVX2: %[[VAL_11:.*]] = math.fma %[[VAL_9]], %[[VAL_10]], %[[VAL_2]] : vector<8xf32>
// AVX2: %[[VAL_12:.*]] = arith.mulf %[[VAL_9]], %[[VAL_11]] : vector<8xf32>
@@ -725,9 +725,9 @@ func.func @rsqrt_vector_5xf32(%arg0: vector<5xf32>) -> vector<5xf32> {
// AVX2: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x8xf32>
// AVX2: %[[EXPAND:.*]] = vector.shape_cast %[[ARG]] : vector<16xf32> to vector<2x8xf32>
// AVX2: %[[VEC0:.*]] = vector.extract %[[EXPAND]][0]
// AVX2: %[[RSQRT0:.*]] = x86vector.avx.rsqrt %[[VEC0]]
// AVX2: %[[RSQRT0:.*]] = x86.avx.rsqrt %[[VEC0]]
// AVX2: %[[VEC1:.*]] = vector.extract %[[EXPAND]][1]
// AVX2: %[[RSQRT1:.*]] = x86vector.avx.rsqrt %[[VEC1]]
// AVX2: %[[RSQRT1:.*]] = x86.avx.rsqrt %[[VEC1]]
// AVX2: %[[RESULT0:.*]] = vector.insert %[[RSQRT0]], %[[INIT]] [0]
// AVX2: %[[RESULT1:.*]] = vector.insert %[[RSQRT1]], %[[RESULT0]] [1]
// AVX2: %[[RSQRT:.*]] = vector.shape_cast %[[RESULT1]] : vector<2x8xf32> to vector<16xf32>
@@ -746,9 +746,9 @@ func.func @rsqrt_vector_16xf32(%arg0: vector<16xf32>) -> vector<16xf32> {
// AVX2: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x8xf32>
// AVX2-NOT: vector.shape_cast
// AVX2: %[[VEC0:.*]] = vector.extract %[[ARG]][0]
// AVX2: %[[RSQRT0:.*]] = x86vector.avx.rsqrt %[[VEC0]]
// AVX2: %[[RSQRT0:.*]] = x86.avx.rsqrt %[[VEC0]]
// AVX2: %[[VEC1:.*]] = vector.extract %[[ARG]][1]
// AVX2: %[[RSQRT1:.*]] = x86vector.avx.rsqrt %[[VEC1]]
// AVX2: %[[RSQRT1:.*]] = x86.avx.rsqrt %[[VEC1]]
// AVX2: %[[RESULT0:.*]] = vector.insert %[[RSQRT0]], %[[INIT]] [0]
// AVX2: %[[RESULT1:.*]] = vector.insert %[[RSQRT1]], %[[RESULT0]] [1]
// AVX2-NOT: vector.shape_cast
@@ -768,13 +768,13 @@ func.func @rsqrt_vector_2x8xf32(%arg0: vector<2x8xf32>) -> vector<2x8xf32> {
// AVX2: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x8xf32>
// AVX2: %[[EXPAND:.*]] = vector.shape_cast %[[ARG]] : vector<2x16xf32> to vector<2x2x8xf32>
// AVX2: %[[VEC00:.*]] = vector.extract %[[EXPAND]][0, 0]
// AVX2: %[[RSQRT00:.*]] = x86vector.avx.rsqrt %[[VEC00]]
// AVX2: %[[RSQRT00:.*]] = x86.avx.rsqrt %[[VEC00]]
// AVX2: %[[VEC01:.*]] = vector.extract %[[EXPAND]][0, 1]
// AVX2: %[[RSQRT01:.*]] = x86vector.avx.rsqrt %[[VEC01]]
// AVX2: %[[RSQRT01:.*]] = x86.avx.rsqrt %[[VEC01]]
// AVX2: %[[VEC10:.*]] = vector.extract %[[EXPAND]][1, 0]
// AVX2: %[[RSQRT10:.*]] = x86vector.avx.rsqrt %[[VEC10]]
// AVX2: %[[RSQRT10:.*]] = x86.avx.rsqrt %[[VEC10]]
// AVX2: %[[VEC11:.*]] = vector.extract %[[EXPAND]][1, 1]
// AVX2: %[[RSQRT11:.*]] = x86vector.avx.rsqrt %[[VEC11]]
// AVX2: %[[RSQRT11:.*]] = x86.avx.rsqrt %[[VEC11]]
// AVX2: %[[RESULT0:.*]] = vector.insert %[[RSQRT00]], %[[INIT]] [0, 0]
// AVX2: %[[RESULT1:.*]] = vector.insert %[[RSQRT01]], %[[RESULT0]] [0, 1]
// AVX2: %[[RESULT2:.*]] = vector.insert %[[RSQRT10]], %[[RESULT1]] [1, 0]

View File

@@ -1,6 +1,6 @@
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
// NOTE: This file tests lowerings that are implemented in the X86Vector
// NOTE: This file tests lowerings that are implemented in the X86
// dialect. Since X86 does not support scalable vectors, all examples in this
// file use fixed-width vectors.

View File

@@ -1,7 +1,7 @@
// REQUIRES: target=x86{{.*}}
// RUN: mlir-opt %s \
// RUN: -convert-vector-to-llvm="enable-x86vector" -convert-to-llvm \
// RUN: -convert-vector-to-llvm="enable-x86" -convert-to-llvm \
// RUN: -reconcile-unrealized-casts | \
// RUN: mlir-translate --mlir-to-llvmir | \
// RUN: llc -mcpu=sapphirerapids | \
@@ -9,7 +9,7 @@
func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
%a: vector<8xf32>) -> vector<8xbf16> {
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
%0 = x86.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
return %0 : vector<8xbf16>
}
// CHECK-LABEL: avx512bf16_cvt_packed_f32_to_bf16_256:
@@ -17,7 +17,7 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
%a: vector<16xf32>) -> vector<16xbf16> {
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
%0 = x86.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
return %0 : vector<16xbf16>
}
// CHECK-LABEL: avx512bf16_cvt_packed_f32_to_bf16_512:

View File

@@ -1,7 +1,7 @@
// REQUIRES: target=x86{{.*}}
// RUN: mlir-opt %s \
// RUN: -convert-vector-to-llvm="enable-x86vector" -convert-to-llvm \
// RUN: -convert-vector-to-llvm="enable-x86" -convert-to-llvm \
// RUN: -reconcile-unrealized-casts | \
// RUN: mlir-translate --mlir-to-llvmir | \
// RUN: llc -mcpu=sapphirerapids | \
@@ -9,7 +9,7 @@
func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>,
%b: vector<8xbf16>) -> vector<4xf32> {
%0 = x86vector.avx512.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
%0 = x86.avx512.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
// CHECK-LABEL: avx512bf16_dot_128:
@@ -17,7 +17,7 @@ func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>,
func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>,
%b: vector<16xbf16>) -> vector<8xf32> {
%0 = x86vector.avx512.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
%0 = x86.avx512.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
// CHECK-LABEL: avx512bf16_dot_256:
@@ -25,7 +25,7 @@ func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>,
func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
%b: vector<32xbf16>) -> vector<16xf32> {
%0 = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
%0 = x86.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
return %0 : vector<16xf32>
}
// CHECK-LABEL: avx512bf16_dot_512:

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86vector" | mlir-opt | FileCheck %s
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86" | mlir-opt | FileCheck %s
// CHECK-LABEL: func @avx512_mask_rndscale
func.func @avx512_mask_rndscale(
@@ -9,14 +9,14 @@ func.func @avx512_mask_rndscale(
%rnd_k = arith.constant 15 : i32
%rnd = arith.constant 42 : i32
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.rndscale.ps.512"
%0 = x86vector.avx512.mask.rndscale %src, %rnd_k, %a, %imm_i16, %rnd : vector<16xf32>
%0 = x86.avx512.mask.rndscale %src, %rnd_k, %a, %imm_i16, %rnd : vector<16xf32>
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.rndscale.pd.512"
%1 = x86vector.avx512.mask.rndscale %b, %rnd_k, %b, %imm_i8, %rnd : vector<8xf64>
%1 = x86.avx512.mask.rndscale %b, %rnd_k, %b, %imm_i8, %rnd : vector<8xf64>
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.scalef.ps.512"
%2 = x86vector.avx512.mask.scalef %a, %a, %a, %scale_k_i16, %rnd : vector<16xf32>
%2 = x86.avx512.mask.scalef %a, %a, %a, %scale_k_i16, %rnd : vector<16xf32>
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.scalef.pd.512"
%3 = x86vector.avx512.mask.scalef %b, %b, %b, %scale_k_i8, %rnd : vector<8xf64>
%3 = x86.avx512.mask.scalef %b, %b, %b, %scale_k_i8, %rnd : vector<8xf64>
// Keep results alive.
return %0, %1, %2, %3 : vector<16xf32>, vector<8xf64>, vector<16xf32>, vector<8xf64>
@@ -29,13 +29,13 @@ func.func @avx512_mask_compress(
{
// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<16xf32>)
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.compress"
%0 = x86vector.avx512.mask.compress %k1, %a1 : vector<16xf32>
%0 = x86.avx512.mask.compress %k1, %a1 : vector<16xf32>
// CHECK: llvm.mlir.constant(dense<5.000000e+00> : vector<16xf32>)
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.compress"
%1 = x86vector.avx512.mask.compress %k1, %a1
%1 = x86.avx512.mask.compress %k1, %a1
{constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32>
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.compress"
%2 = x86vector.avx512.mask.compress %k2, %a2, %a2 : vector<8xi64>, vector<8xi64>
%2 = x86.avx512.mask.compress %k2, %a2, %a2 : vector<8xi64>, vector<8xi64>
return %0, %1, %2 : vector<16xf32>, vector<16xf32>, vector<8xi64>
}
@@ -44,9 +44,9 @@ func.func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>)
-> (vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>)
{
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.vp2intersect.d.512"
%0, %1 = x86vector.avx512.vp2intersect %a, %a : vector<16xi32>
%0, %1 = x86.avx512.vp2intersect %a, %a : vector<16xi32>
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.vp2intersect.q.512"
%2, %3 = x86vector.avx512.vp2intersect %b, %b : vector<8xi64>
%2, %3 = x86.avx512.vp2intersect %b, %b : vector<8xi64>
return %0, %1, %2, %3 : vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>
}
@@ -55,7 +55,7 @@ func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>,
%b: vector<8xbf16>) -> (vector<4xf32>)
{
// CHECK: llvm.call_intrinsic "llvm.x86.avx512bf16.dpbf16ps.128"
%0 = x86vector.avx512.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
%0 = x86.avx512.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -64,7 +64,7 @@ func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>,
%b: vector<16xbf16>) -> (vector<8xf32>)
{
// CHECK: llvm.call_intrinsic "llvm.x86.avx512bf16.dpbf16ps.256"
%0 = x86vector.avx512.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
%0 = x86.avx512.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -73,7 +73,7 @@ func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
%b: vector<32xbf16>) -> (vector<16xf32>)
{
// CHECK: llvm.call_intrinsic "llvm.x86.avx512bf16.dpbf16ps.512"
%0 = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
%0 = x86.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
return %0 : vector<16xf32>
}
@@ -82,7 +82,7 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
%a: vector<8xf32>) -> (vector<8xbf16>)
{
// CHECK: llvm.call_intrinsic "llvm.x86.avx512bf16.cvtneps2bf16.256"
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
%0 = x86.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
return %0 : vector<8xbf16>
}
@@ -91,7 +91,7 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
%a: vector<16xf32>) -> (vector<16xbf16>)
{
// CHECK: llvm.call_intrinsic "llvm.x86.avx512bf16.cvtneps2bf16.512"
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
%0 = x86.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
return %0 : vector<16xbf16>
}
@@ -99,7 +99,7 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
func.func @avx10_dot_i8_512(%w: vector<16xi32>, %a: vector<64xi8>,
%b: vector<64xi8>) -> vector<16xi32> {
// CHECK: llvm.call_intrinsic "llvm.x86.avx10.vpdpbssd.512"
%0 = x86vector.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
%0 = x86.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
return %0 : vector<16xi32>
}
@@ -108,7 +108,7 @@ func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
%a: memref<8xbf16>) -> vector<4xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps128"
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
%0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -117,7 +117,7 @@ func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(
%a: memref<16xbf16>) -> vector<8xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps256"
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
%0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -126,7 +126,7 @@ func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(
%a: memref<8xbf16>) -> vector<4xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps128"
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
%0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -135,7 +135,7 @@ func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(
%a: memref<16xbf16>) -> vector<8xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps256"
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
%0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -144,7 +144,7 @@ func.func @avxbf16_bsct_bf16_to_f32_packed_128(
%a: memref<1xbf16>) -> vector<4xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps128"
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
%0 = x86.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -153,7 +153,7 @@ func.func @avxbf16_bsct_bf16_to_f32_packed_256(
%a: memref<1xbf16>) -> vector<8xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps256"
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
%0 = x86.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -162,7 +162,7 @@ func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_128(
%a: memref<8xf16>) -> vector<4xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneeph2ps128"
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
%0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -171,7 +171,7 @@ func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_256(
%a: memref<16xf16>) -> vector<8xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneeph2ps256"
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
%0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -180,7 +180,7 @@ func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128(
%a: memref<8xf16>) -> vector<4xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneoph2ps128"
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
%0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -189,7 +189,7 @@ func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256(
%a: memref<16xf16>) -> vector<8xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneoph2ps256"
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
%0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -198,7 +198,7 @@ func.func @avxf16_bsct_f16_to_f32_packed_128(
%a: memref<1xf16>) -> vector<4xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vbcstnesh2ps128"
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
%0 = x86.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -207,7 +207,7 @@ func.func @avxf16_bsct_f16_to_f32_packed_256(
%a: memref<1xf16>) -> vector<8xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vbcstnesh2ps256"
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
%0 = x86.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -215,7 +215,7 @@ func.func @avxf16_bsct_f16_to_f32_packed_256(
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
{
// CHECK: llvm.call_intrinsic "llvm.x86.avx.rsqrt.ps.256"
%0 = x86vector.avx.rsqrt %a : vector<8xf32>
%0 = x86.avx.rsqrt %a : vector<8xf32>
return %0 : vector<8xf32>
}
@@ -224,7 +224,7 @@ func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>)
{
// CHECK: llvm.mlir.constant(-1 : i8)
// CHECK: llvm.call_intrinsic "llvm.x86.avx.dp.ps.256"
%0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
%0 = x86.avx.intr.dot %a, %b : vector<8xf32>
return %0 : vector<8xf32>
}
@@ -232,7 +232,7 @@ func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>)
func.func @avx_dot_i8_128(%w: vector<4xi32>, %a: vector<16xi8>,
%b: vector<16xi8>) -> vector<4xi32> {
// CHECK: llvm.call_intrinsic "llvm.x86.avx2.vpdpbssd.128"
%0 = x86vector.avx.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
%0 = x86.avx.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
return %0 : vector<4xi32>
}
@@ -240,6 +240,6 @@ func.func @avx_dot_i8_128(%w: vector<4xi32>, %a: vector<16xi8>,
func.func @avx_dot_i8_256(%w: vector<8xi32>, %a: vector<32xi8>,
%b: vector<32xi8>) -> vector<8xi32> {
// CHECK: llvm.call_intrinsic "llvm.x86.avx2.vpdpbssd.256"
%0 = x86vector.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
%0 = x86.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
return %0 : vector<8xi32>
}

View File

@@ -4,10 +4,10 @@
func.func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16, %i8: i8)
-> (vector<16xf32>, vector<8xf64>)
{
// CHECK: x86vector.avx512.mask.rndscale {{.*}}: vector<16xf32>
%0 = x86vector.avx512.mask.rndscale %a, %i32, %a, %i16, %i32 : vector<16xf32>
// CHECK: x86vector.avx512.mask.rndscale {{.*}}: vector<8xf64>
%1 = x86vector.avx512.mask.rndscale %b, %i32, %b, %i8, %i32 : vector<8xf64>
// CHECK: x86.avx512.mask.rndscale {{.*}}: vector<16xf32>
%0 = x86.avx512.mask.rndscale %a, %i32, %a, %i16, %i32 : vector<16xf32>
// CHECK: x86.avx512.mask.rndscale {{.*}}: vector<8xf64>
%1 = x86.avx512.mask.rndscale %b, %i32, %b, %i8, %i32 : vector<8xf64>
return %0, %1: vector<16xf32>, vector<8xf64>
}
@@ -15,10 +15,10 @@ func.func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32
func.func @avx512_scalef(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16, %i8: i8)
-> (vector<16xf32>, vector<8xf64>)
{
// CHECK: x86vector.avx512.mask.scalef {{.*}}: vector<16xf32>
%0 = x86vector.avx512.mask.scalef %a, %a, %a, %i16, %i32: vector<16xf32>
// CHECK: x86vector.avx512.mask.scalef {{.*}}: vector<8xf64>
%1 = x86vector.avx512.mask.scalef %b, %b, %b, %i8, %i32 : vector<8xf64>
// CHECK: x86.avx512.mask.scalef {{.*}}: vector<16xf32>
%0 = x86.avx512.mask.scalef %a, %a, %a, %i16, %i32: vector<16xf32>
// CHECK: x86.avx512.mask.scalef {{.*}}: vector<8xf64>
%1 = x86.avx512.mask.scalef %b, %b, %b, %i8, %i32 : vector<8xf64>
return %0, %1: vector<16xf32>, vector<8xf64>
}
@@ -26,10 +26,10 @@ func.func @avx512_scalef(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16:
func.func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>)
-> (vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>)
{
// CHECK: x86vector.avx512.vp2intersect {{.*}} : vector<16xi32>
%0, %1 = x86vector.avx512.vp2intersect %a, %a : vector<16xi32>
// CHECK: x86vector.avx512.vp2intersect {{.*}} : vector<8xi64>
%2, %3 = x86vector.avx512.vp2intersect %b, %b : vector<8xi64>
// CHECK: x86.avx512.vp2intersect {{.*}} : vector<16xi32>
%0, %1 = x86.avx512.vp2intersect %a, %a : vector<16xi32>
// CHECK: x86.avx512.vp2intersect {{.*}} : vector<8xi64>
%2, %3 = x86.avx512.vp2intersect %b, %b : vector<8xi64>
return %0, %1, %2, %3 : vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>
}
@@ -38,12 +38,12 @@ func.func @avx512_mask_compress(%k1: vector<16xi1>, %a1: vector<16xf32>,
%k2: vector<8xi1>, %a2: vector<8xi64>)
-> (vector<16xf32>, vector<16xf32>, vector<8xi64>)
{
// CHECK: x86vector.avx512.mask.compress {{.*}} : vector<16xf32>
%0 = x86vector.avx512.mask.compress %k1, %a1 : vector<16xf32>
// CHECK: x86vector.avx512.mask.compress {{.*}} : vector<16xf32>
%1 = x86vector.avx512.mask.compress %k1, %a1 {constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32>
// CHECK: x86vector.avx512.mask.compress {{.*}} : vector<8xi64>
%2 = x86vector.avx512.mask.compress %k2, %a2, %a2 : vector<8xi64>, vector<8xi64>
// CHECK: x86.avx512.mask.compress {{.*}} : vector<16xf32>
%0 = x86.avx512.mask.compress %k1, %a1 : vector<16xf32>
// CHECK: x86.avx512.mask.compress {{.*}} : vector<16xf32>
%1 = x86.avx512.mask.compress %k1, %a1 {constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32>
// CHECK: x86.avx512.mask.compress {{.*}} : vector<8xi64>
%2 = x86.avx512.mask.compress %k2, %a2, %a2 : vector<8xi64>, vector<8xi64>
return %0, %1, %2 : vector<16xf32>, vector<16xf32>, vector<8xi64>
}
@@ -51,8 +51,8 @@ func.func @avx512_mask_compress(%k1: vector<16xi1>, %a1: vector<16xf32>,
func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>,
%b: vector<8xbf16>) -> (vector<4xf32>)
{
// CHECK: x86vector.avx512.dot {{.*}} : vector<8xbf16> -> vector<4xf32>
%0 = x86vector.avx512.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
// CHECK: x86.avx512.dot {{.*}} : vector<8xbf16> -> vector<4xf32>
%0 = x86.avx512.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -60,8 +60,8 @@ func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>,
func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>,
%b: vector<16xbf16>) -> (vector<8xf32>)
{
// CHECK: x86vector.avx512.dot {{.*}} : vector<16xbf16> -> vector<8xf32>
%0 = x86vector.avx512.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
// CHECK: x86.avx512.dot {{.*}} : vector<16xbf16> -> vector<8xf32>
%0 = x86.avx512.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -69,8 +69,8 @@ func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>,
func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
%b: vector<32xbf16>) -> (vector<16xf32>)
{
// CHECK: x86vector.avx512.dot {{.*}} : vector<32xbf16> -> vector<16xf32>
%0 = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
// CHECK: x86.avx512.dot {{.*}} : vector<32xbf16> -> vector<16xf32>
%0 = x86.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
return %0 : vector<16xf32>
}
@@ -78,9 +78,9 @@ func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
%a: vector<8xf32>) -> (vector<8xbf16>)
{
// CHECK: x86vector.avx512.cvt.packed.f32_to_bf16 {{.*}} :
// CHECK: x86.avx512.cvt.packed.f32_to_bf16 {{.*}} :
// CHECK-SAME: vector<8xf32> -> vector<8xbf16>
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
%0 = x86.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
return %0 : vector<8xbf16>
}
@@ -88,17 +88,17 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
%a: vector<16xf32>) -> (vector<16xbf16>)
{
// CHECK: x86vector.avx512.cvt.packed.f32_to_bf16 {{.*}} :
// CHECK: x86.avx512.cvt.packed.f32_to_bf16 {{.*}} :
// CHECK-SAME: vector<16xf32> -> vector<16xbf16>
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
%0 = x86.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
return %0 : vector<16xbf16>
}
// CHECK-LABEL: func @avx10_dot_i8_512
func.func @avx10_dot_i8_512(%w: vector<16xi32>, %a: vector<64xi8>,
%b: vector<64xi8>) -> vector<16xi32> {
// CHECK: x86vector.avx10.dot.i8 {{.*}} : vector<64xi8> -> vector<16xi32>
%0 = x86vector.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
// CHECK: x86.avx10.dot.i8 {{.*}} : vector<64xi8> -> vector<16xi32>
%0 = x86.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
return %0 : vector<16xi32>
}
@@ -106,9 +106,9 @@ func.func @avx10_dot_i8_512(%w: vector<16xi32>, %a: vector<64xi8>,
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
%a: memref<8xbf16>) -> vector<4xf32>
{
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} :
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32 {{.*}} :
// CHECK-SAME: memref<8xbf16> -> vector<4xf32>
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
%0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -116,9 +116,9 @@ func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(
%a: memref<16xbf16>) -> vector<8xf32>
{
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} :
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32 {{.*}} :
// CHECK-SAME: memref<16xbf16> -> vector<8xf32>
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
%0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -126,9 +126,9 @@ func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(
func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(
%a: memref<8xbf16>) -> vector<4xf32>
{
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} :
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32 {{.*}} :
// CHECK-SAME: memref<8xbf16> -> vector<4xf32>
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
%0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -136,9 +136,9 @@ func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(
func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(
%a: memref<16xbf16>) -> vector<8xf32>
{
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} :
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32 {{.*}} :
// CHECK-SAME: memref<16xbf16> -> vector<8xf32>
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
%0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -146,9 +146,9 @@ func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(
func.func @avxbf16_bcst_bf16_to_f32_128(
%a: memref<1xbf16>) -> vector<4xf32>
{
// CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} :
// CHECK: x86.avx.bcst_to_f32.packed {{.*}} :
// CHECK-SAME: memref<1xbf16> -> vector<4xf32>
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
%0 = x86.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -156,9 +156,9 @@ func.func @avxbf16_bcst_bf16_to_f32_128(
func.func @avxbf16_bcst_bf16_to_f32_256(
%a: memref<1xbf16>) -> vector<8xf32>
{
// CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} :
// CHECK: x86.avx.bcst_to_f32.packed {{.*}} :
// CHECK-SAME: memref<1xbf16> -> vector<8xf32>
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
%0 = x86.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -166,9 +166,9 @@ func.func @avxbf16_bcst_bf16_to_f32_256(
func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_128(
%a: memref<8xf16>) -> vector<4xf32>
{
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} :
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32 {{.*}} :
// CHECK-SAME: memref<8xf16> -> vector<4xf32>
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
%0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -176,9 +176,9 @@ func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_128(
func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_256(
%a: memref<16xf16>) -> vector<8xf32>
{
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} :
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32 {{.*}} :
// CHECK-SAME: memref<16xf16> -> vector<8xf32>
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
%0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -186,9 +186,9 @@ func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_256(
func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128(
%a: memref<8xf16>) -> vector<4xf32>
{
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} :
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32 {{.*}} :
// CHECK-SAME: memref<8xf16> -> vector<4xf32>
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
%0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -196,9 +196,9 @@ func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128(
func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256(
%a: memref<16xf16>) -> vector<8xf32>
{
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} :
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32 {{.*}} :
// CHECK-SAME: memref<16xf16> -> vector<8xf32>
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
%0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -206,9 +206,9 @@ func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256(
func.func @avxf16_bcst_f16_to_f32_128(
%a: memref<1xf16>) -> vector<4xf32>
{
// CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} :
// CHECK: x86.avx.bcst_to_f32.packed {{.*}} :
// CHECK-SAME: memref<1xf16> -> vector<4xf32>
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
%0 = x86.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -216,40 +216,40 @@ func.func @avxf16_bcst_f16_to_f32_128(
func.func @avxf16_bcst_f16_to_f32_256(
%a: memref<1xf16>) -> vector<8xf32>
{
// CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} :
// CHECK: x86.avx.bcst_to_f32.packed {{.*}} :
// CHECK-SAME: memref<1xf16> -> vector<8xf32>
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
%0 = x86.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
// CHECK-LABEL: func @avx_rsqrt
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
{
// CHECK: x86vector.avx.rsqrt {{.*}} : vector<8xf32>
%0 = x86vector.avx.rsqrt %a : vector<8xf32>
// CHECK: x86.avx.rsqrt {{.*}} : vector<8xf32>
%0 = x86.avx.rsqrt %a : vector<8xf32>
return %0 : vector<8xf32>
}
// CHECK-LABEL: func @avx_dot
func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>)
{
// CHECK: x86vector.avx.intr.dot {{.*}} : vector<8xf32>
%0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
// CHECK: x86.avx.intr.dot {{.*}} : vector<8xf32>
%0 = x86.avx.intr.dot %a, %b : vector<8xf32>
return %0 : vector<8xf32>
}
// CHECK-LABEL: func @avx_dot_i8_128
func.func @avx_dot_i8_128(%w: vector<4xi32>, %a: vector<16xi8>,
%b: vector<16xi8>) -> vector<4xi32> {
// CHECK: x86vector.avx.dot.i8 {{.*}} : vector<16xi8> -> vector<4xi32>
%0 = x86vector.avx.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
// CHECK: x86.avx.dot.i8 {{.*}} : vector<16xi8> -> vector<4xi32>
%0 = x86.avx.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
return %0 : vector<4xi32>
}
// CHECK-LABEL: func @avx_dot_i8_256
func.func @avx_dot_i8_256(%w: vector<8xi32>, %a: vector<32xi8>,
%b: vector<32xi8>) -> vector<8xi32> {
// CHECK: x86vector.avx.dot.i8 {{.*}} : vector<32xi8> -> vector<8xi32>
%0 = x86vector.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
// CHECK: x86.avx.dot.i8 {{.*}} : vector<32xi8> -> vector<8xi32>
%0 = x86.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
return %0 : vector<8xi32>
}

View File

@@ -8,17 +8,17 @@ func.func @shuffle_fma_with_rhs_as_even.index_to_f32(
%arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
%arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vec
{
%0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
%1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
%0 = x86.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
%1 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
%2 = vector.fma %0, %1, %arg6 : !vec
%3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
%4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
%3 = x86.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
%4 = x86.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
%5 = vector.fma %3, %4, %2 : !vec
%6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
%7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
%6 = x86.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
%7 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
%8 = vector.fma %6, %7, %arg6 : !vec
%9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
%10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
%9 = x86.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
%10 = x86.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
%11 = vector.fma %9, %10, %8 : !vec
%12 = vector.fma %5, %11, %arg6 : !vec
return %12 : !vec
@@ -28,25 +28,25 @@ func.func @shuffle_fma_with_rhs_as_even.index_to_f32(
// The vector.fma at %5 is moved along with its operands after %8.
// CHECK-LABEL: @shuffle_fma_with_rhs_as_even.index_to_f32
// Odd-Indexed FMAs
// CHECK: %[[BCST0:.*]] = x86vector.avx.bcst_to_f32.packed %arg0
// CHECK: %[[ODD0:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2
// CHECK: %[[BCST0:.*]] = x86.avx.bcst_to_f32.packed %arg0
// CHECK: %[[ODD0:.*]] = x86.avx.cvt.packed.odd.indexed_to_f32 %arg2
// CHECK: %[[FMA_ODD0:.*]] = vector.fma %[[BCST0]], %[[ODD0]], %arg6
// CHECK: %[[BCST1:.*]] = x86vector.avx.bcst_to_f32.packed %arg3
// CHECK: %[[ODD1:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5
// CHECK: %[[BCST1:.*]] = x86.avx.bcst_to_f32.packed %arg3
// CHECK: %[[ODD1:.*]] = x86.avx.cvt.packed.odd.indexed_to_f32 %arg5
// CHECK: %[[FMA_ODD1:.*]] = vector.fma %[[BCST1]], %[[ODD1]], %arg6
// Even-Indexed FMAs
// CHECK: %[[BCST2:.*]] = x86vector.avx.bcst_to_f32.packed %arg4
// CHECK: %[[EVEN0:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5
// CHECK: %[[BCST2:.*]] = x86.avx.bcst_to_f32.packed %arg4
// CHECK: %[[EVEN0:.*]] = x86.avx.cvt.packed.even.indexed_to_f32 %arg5
// CHECK: %[[FMA_EVEN0:.*]] = vector.fma %[[BCST2]], %[[EVEN0]], %[[FMA_ODD1]]
// CHECK: %[[BCST3:.*]] = x86vector.avx.bcst_to_f32.packed %arg1
// CHECK: %[[EVEN1:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2
// CHECK: %[[BCST3:.*]] = x86.avx.bcst_to_f32.packed %arg1
// CHECK: %[[EVEN1:.*]] = x86.avx.cvt.packed.even.indexed_to_f32 %arg2
// CHECK: %[[FMA_EVEN1:.*]] = vector.fma %[[BCST3]], %[[EVEN1]], %[[FMA_ODD0]]
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %0 {
transform.apply_patterns.x86vector.shuffle_vector_fma_ops
transform.apply_patterns.x86.shuffle_vector_fma_ops
} : !transform.any_op
transform.yield
}
@@ -62,17 +62,17 @@ func.func @shuffle_fma_with_lhs_as_even.index_to_f32(
%arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
%arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vec
{
%0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
%1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
%0 = x86.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
%1 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
%2 = vector.fma %0, %1, %arg6 : !vec
%3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
%4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
%3 = x86.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
%4 = x86.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
%5 = vector.fma %4, %3, %2 : !vec
%6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
%7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
%6 = x86.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
%7 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
%8 = vector.fma %6, %7, %arg6 : !vec
%9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
%10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
%9 = x86.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
%10 = x86.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
%11 = vector.fma %9, %10, %8 : !vec
%12 = vector.fma %5, %11, %arg6 : !vec
return %12 : !vec
@@ -81,25 +81,25 @@ func.func @shuffle_fma_with_lhs_as_even.index_to_f32(
// The vector.fma at %5 is moved along with its operands after %8.
// CHECK-LABEL: @shuffle_fma_with_lhs_as_even.index_to_f32
// Odd-Indexed FMAs
// CHECK: %[[BCST0:.*]] = x86vector.avx.bcst_to_f32.packed %arg0
// CHECK: %[[ODD0:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2
// CHECK: %[[BCST0:.*]] = x86.avx.bcst_to_f32.packed %arg0
// CHECK: %[[ODD0:.*]] = x86.avx.cvt.packed.odd.indexed_to_f32 %arg2
// CHECK: %[[FMA_ODD0:.*]] = vector.fma %[[BCST0]], %[[ODD0]], %arg6
// CHECK: %[[BCST1:.*]] = x86vector.avx.bcst_to_f32.packed %arg3
// CHECK: %[[ODD1:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5
// CHECK: %[[BCST1:.*]] = x86.avx.bcst_to_f32.packed %arg3
// CHECK: %[[ODD1:.*]] = x86.avx.cvt.packed.odd.indexed_to_f32 %arg5
// CHECK: %[[FMA_ODD1:.*]] = vector.fma %[[BCST1]], %[[ODD1]], %arg6
// Even-Indexed FMAs
// CHECK: %[[BCST2:.*]] = x86vector.avx.bcst_to_f32.packed %arg4
// CHECK: %[[EVEN0:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5
// CHECK: %[[BCST2:.*]] = x86.avx.bcst_to_f32.packed %arg4
// CHECK: %[[EVEN0:.*]] = x86.avx.cvt.packed.even.indexed_to_f32 %arg5
// CHECK: %[[FMA_EVEN0:.*]] = vector.fma %[[BCST2]], %[[EVEN0]], %[[FMA_ODD1]]
// CHECK: %[[EVEN1:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2
// CHECK: %[[BCST3:.*]] = x86vector.avx.bcst_to_f32.packed %arg1
// CHECK: %[[EVEN1:.*]] = x86.avx.cvt.packed.even.indexed_to_f32 %arg2
// CHECK: %[[BCST3:.*]] = x86.avx.bcst_to_f32.packed %arg1
// CHECK: %[[FMA_EVEN1:.*]] = vector.fma %[[EVEN1]], %[[BCST3]], %[[FMA_ODD0]]
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %0 {
transform.apply_patterns.x86vector.shuffle_vector_fma_ops
transform.apply_patterns.x86.shuffle_vector_fma_ops
} : !transform.any_op
transform.yield
}
@@ -116,18 +116,18 @@ func.func @shuffle_fma_with_shape_cast(
%arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
%arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vecOut
{
%0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
%1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
%0 = x86.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
%1 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
%2 = vector.fma %0, %1, %arg6 : !vec
%3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
%4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
%3 = x86.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
%4 = x86.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
%5 = vector.fma %3, %4, %2 : !vec
%res1 = vector.shape_cast %5 : !vec to !vecOut
%6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
%7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
%6 = x86.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
%7 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
%8 = vector.fma %6, %7, %arg6 : !vec
%9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
%10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
%9 = x86.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
%10 = x86.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
%11 = vector.fma %9, %10, %8 : !vec
%res2 = vector.shape_cast %11 : !vec to !vecOut
%12 = arith.addf %res1, %res2 : !vecOut
@@ -136,19 +136,19 @@ func.func @shuffle_fma_with_shape_cast(
// CHECK-LABEL: @shuffle_fma_with_shape_cast
// Odd-Indexed FMAs
// CHECK: %[[BCST0:.*]] = x86vector.avx.bcst_to_f32.packed %arg0
// CHECK: %[[ODD0:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2
// CHECK: %[[BCST0:.*]] = x86.avx.bcst_to_f32.packed %arg0
// CHECK: %[[ODD0:.*]] = x86.avx.cvt.packed.odd.indexed_to_f32 %arg2
// CHECK: %[[FMA_ODD0:.*]] = vector.fma %[[BCST0]], %[[ODD0]], %arg6
// CHECK: %[[BCST1:.*]] = x86vector.avx.bcst_to_f32.packed %arg3
// CHECK: %[[ODD1:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5
// CHECK: %[[BCST1:.*]] = x86.avx.bcst_to_f32.packed %arg3
// CHECK: %[[ODD1:.*]] = x86.avx.cvt.packed.odd.indexed_to_f32 %arg5
// CHECK: %[[FMA_ODD1:.*]] = vector.fma %[[BCST1]], %[[ODD1]], %arg6
// Even-Indexed FMAs
// CHECK: %[[BCST3:.*]] = x86vector.avx.bcst_to_f32.packed %arg4
// CHECK: %[[EVEN1:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5
// CHECK: %[[BCST3:.*]] = x86.avx.bcst_to_f32.packed %arg4
// CHECK: %[[EVEN1:.*]] = x86.avx.cvt.packed.even.indexed_to_f32 %arg5
// CHECK: %[[FMA_EVEN1:.*]] = vector.fma %[[BCST3]], %[[EVEN1]], %[[FMA_ODD1]]
// CHECK: vector.shape_cast
// CHECK: %[[BCST2:.*]] = x86vector.avx.bcst_to_f32.packed %arg1
// CHECK: %[[EVEN0:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2
// CHECK: %[[BCST2:.*]] = x86.avx.bcst_to_f32.packed %arg1
// CHECK: %[[EVEN0:.*]] = x86.avx.cvt.packed.even.indexed_to_f32 %arg2
// CHECK: %[[FMA_EVEN0:.*]] = vector.fma %[[BCST2]], %[[EVEN0]], %[[FMA_ODD0]]
// CHECK: vector.shape_cast
@@ -156,7 +156,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %0 {
transform.apply_patterns.x86vector.shuffle_vector_fma_ops
transform.apply_patterns.x86.shuffle_vector_fma_ops
} : !transform.any_op
transform.yield
}
@@ -172,16 +172,16 @@ func.func @negative_fma_operand_has_multiple_consumer(
%arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB,
%arg3: !memrefA, %arg4: !memrefB, %arg5: !vec) -> !vec
{
%0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
%1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
%0 = x86.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
%1 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
%2 = vector.fma %0, %1, %arg5 : !vec
%3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
%4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
%3 = x86.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
%4 = x86.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
%5 = vector.fma %3, %4, %2 : !vec
%7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg4 : !memrefB -> !vec
%7 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg4 : !memrefB -> !vec
%8 = vector.fma %3, %7, %arg5 : !vec
%9 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
%10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg4 : !memrefB -> !vec
%9 = x86.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
%10 = x86.avx.cvt.packed.even.indexed_to_f32 %arg4 : !memrefB -> !vec
%11 = vector.fma %9, %10, %8 : !vec
%12 = vector.fma %5, %11, %arg5 : !vec
return %12 : !vec
@@ -190,18 +190,18 @@ func.func @negative_fma_operand_has_multiple_consumer(
// The vector.fma at %5 uses %3 as its LHS operand, which has two consumers; therefore,
// the rewrite is not applied.
// CHECK-LABEL: @negative_fma_operand_has_multiple_consumer
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
// CHECK: vector.fma
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32
// CHECK: vector.fma
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
// CHECK: vector.fma
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %0 {
transform.apply_patterns.x86vector.shuffle_vector_fma_ops
transform.apply_patterns.x86.shuffle_vector_fma_ops
} : !transform.any_op
transform.yield
}
@@ -217,17 +217,17 @@ func.func @negative_fma_has_multiple_consumer(
%arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
%arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vec
{
%0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
%1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
%0 = x86.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
%1 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
%2 = vector.fma %0, %1, %arg6 : !vec
%3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
%4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
%3 = x86.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
%4 = x86.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
%5 = vector.fma %3, %4, %2 : !vec
%6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
%7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
%6 = x86.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
%7 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
%8 = vector.fma %6, %7, %5 : !vec
%9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
%10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
%9 = x86.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
%10 = x86.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
%11 = vector.fma %9, %10, %8 : !vec
%12 = vector.fma %5, %11, %arg6 : !vec
return %12 : !vec
@@ -235,17 +235,17 @@ func.func @negative_fma_has_multiple_consumer(
// The vector.fma at %5 has two uses; therefore, the rewrite is not applied.
// CHECK-LABEL: @negative_fma_has_multiple_consumer
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
// CHECK: vector.fma
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32
// CHECK: vector.fma
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %0 {
transform.apply_patterns.x86vector.shuffle_vector_fma_ops
transform.apply_patterns.x86.shuffle_vector_fma_ops
} : !transform.any_op
transform.yield
}
@@ -260,28 +260,28 @@ func.func @negative_no_shuffle_outside_block(
%arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
%arg4: !memrefA, %arg5: !memrefB, %arg6: !vec, %arg7: i1) -> !vec
{
%0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
%1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
%0 = x86.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
%1 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
%2 = vector.fma %0, %1, %arg6 : !vec
%3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
%4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
%3 = x86.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
%4 = x86.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
%5 = vector.fma %3, %4, %2 : !vec
%loop = scf.if %arg7 -> (vector<8xf32>) {
%6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
%7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
%6 = x86.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
%7 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
%8 = vector.fma %6, %7, %arg6 : !vec
%9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
%10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
%9 = x86.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
%10 = x86.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
%11 = vector.fma %9, %10, %8 : !vec
%12 = vector.fma %5, %11, %arg6 : !vec
scf.yield %12 : vector<8xf32>
} else {
%6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
%7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
%6 = x86.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
%7 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
%8 = vector.fma %6, %7, %arg6 : !vec
%9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
%10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
%9 = x86.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
%10 = x86.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
%11 = vector.fma %9, %10, %8 : !vec
%12 = vector.fma %5, %11, %arg6 : !vec
scf.yield %12 : vector<8xf32>
@@ -293,19 +293,19 @@ func.func @negative_no_shuffle_outside_block(
// vector.fma at %5 has its consumer in another block (%12); therefore the rewrite is not
// applied.
// CHECK-LABEL: @negative_no_shuffle_outside_block
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
// CHECK: vector.fma
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32
// CHECK: vector.fma
// CHECK: scf.if
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
// CHECK: vector.fma
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %0 {
transform.apply_patterns.x86vector.shuffle_vector_fma_ops
transform.apply_patterns.x86.shuffle_vector_fma_ops
} : !transform.any_op
transform.yield
}

@@ -24,7 +24,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %0 {
transform.apply_patterns.x86vector.sink_vector_producer_ops
transform.apply_patterns.x86.sink_vector_producer_ops
} : !transform.any_op
transform.yield
}
@@ -57,7 +57,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %0 {
transform.apply_patterns.x86vector.sink_vector_producer_ops
transform.apply_patterns.x86.sink_vector_producer_ops
} : !transform.any_op
transform.yield
}
@@ -90,7 +90,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %0 {
transform.apply_patterns.x86vector.sink_vector_producer_ops
transform.apply_patterns.x86.sink_vector_producer_ops
} : !transform.any_op
transform.yield
}
@@ -134,7 +134,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %0 {
transform.apply_patterns.x86vector.sink_vector_producer_ops
transform.apply_patterns.x86.sink_vector_producer_ops
} : !transform.any_op
transform.yield
}
@@ -160,7 +160,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %0 {
transform.apply_patterns.x86vector.sink_vector_producer_ops
transform.apply_patterns.x86.sink_vector_producer_ops
} : !transform.any_op
transform.yield
}
@@ -191,9 +191,8 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %0 {
transform.apply_patterns.x86vector.sink_vector_producer_ops
transform.apply_patterns.x86.sink_vector_producer_ops
} : !transform.any_op
transform.yield
}
}

@@ -30,18 +30,18 @@ func.func @brgemm_to_fma(
// CHECK: memref.subview %arg0[%c0, %c0, %c0, 1] {{.*}} : memref<1x4x1x2xbf16> to memref<1x1x1x1xbf16, {{.*}}>
// CHECK: memref.subview %arg0[%c0, %c0, %c0, 0] {{.*}} : memref<1x4x1x2xbf16> to memref<1x1x1x1xbf16, {{.*}}>
// CHECK: memref.subview %arg1[%c0, %c0, %c0, %c0] {{.*}} : memref<1x1x32x2xbf16> to memref<1x1x8x2xbf16, {{.*}}>
// CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} : memref<1x1x1x1xbf16, strided<[8, 2, 2, 1], offset: ?>>
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : memref<1x1x8x2xbf16, strided<[64, 64, 2, 1], offset: ?>>
// CHECK: x86.avx.bcst_to_f32.packed {{.*}} : memref<1x1x1x1xbf16, strided<[8, 2, 2, 1], offset: ?>>
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : memref<1x1x8x2xbf16, strided<[64, 64, 2, 1], offset: ?>>
// CHECK: vector.fma {{.*}} : vector<8xf32>
// CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} : memref<1x1x1x1xbf16, strided<[8, 2, 2, 1], offset: ?>>
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} : memref<1x1x8x2xbf16, strided<[64, 64, 2, 1], offset: ?>>
// CHECK: x86.avx.bcst_to_f32.packed {{.*}} : memref<1x1x1x1xbf16, strided<[8, 2, 2, 1], offset: ?>>
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32 {{.*}} : memref<1x1x8x2xbf16, strided<[64, 64, 2, 1], offset: ?>>
// CHECK: vector.fma {{.*}} : vector<8xf32>
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
transform.apply_patterns.x86.vector_contract_bf16_to_fma
} : !transform.any_op
transform.yield
}
@@ -76,18 +76,18 @@ func.func @brgemm_to_fma_load(
}
// CHECK-LABEL: @brgemm_to_fma_load
// CHECK: x86vector.avx.bcst_to_f32.packed
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
// CHECK: x86.avx.bcst_to_f32.packed
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
// CHECK: vector.fma
// CHECK: x86vector.avx.bcst_to_f32.packed
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
// CHECK: x86.avx.bcst_to_f32.packed
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32
// CHECK: vector.fma
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
transform.apply_patterns.x86.vector_contract_bf16_to_fma
} : !transform.any_op
transform.yield
}
@@ -122,18 +122,18 @@ func.func @brgemm_to_fma_load_bcst_B(
}
// CHECK-LABEL: @brgemm_to_fma_load_bcst_B
// CHECK: x86vector.avx.bcst_to_f32.packed
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
// CHECK: x86.avx.bcst_to_f32.packed
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
// CHECK: vector.fma
// CHECK: x86vector.avx.bcst_to_f32.packed
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
// CHECK: x86.avx.bcst_to_f32.packed
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32
// CHECK: vector.fma
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
transform.apply_patterns.x86.vector_contract_bf16_to_fma
} : !transform.any_op
transform.yield
}
@@ -168,18 +168,18 @@ func.func @batch_matmul_fma_load(
}
// CHECK-LABEL: @batch_matmul_fma_load
// CHECK: x86vector.avx.bcst_to_f32.packed
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
// CHECK: x86.avx.bcst_to_f32.packed
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
// CHECK: vector.fma
// CHECK: x86vector.avx.bcst_to_f32.packed
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
// CHECK: x86.avx.bcst_to_f32.packed
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32
// CHECK: vector.fma
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
transform.apply_patterns.x86.vector_contract_bf16_to_fma
} : !transform.any_op
transform.yield
}
@@ -214,18 +214,18 @@ func.func @matmul_outer_product_to_fma_load(
}
// CHECK-LABEL: @matmul_outer_product_to_fma_load
// CHECK: x86vector.avx.bcst_to_f32.packed
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
// CHECK: x86.avx.bcst_to_f32.packed
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
// CHECK: vector.fma
// CHECK: x86vector.avx.bcst_to_f32.packed
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
// CHECK: x86.avx.bcst_to_f32.packed
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32
// CHECK: vector.fma
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
transform.apply_patterns.x86.vector_contract_bf16_to_fma
} : !transform.any_op
transform.yield
}
@@ -285,10 +285,10 @@ func.func @matmul_to_fma_flat_layout(
// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
// CHECK: memref.subview %arg0[%c0, %c0] {{.*}} : memref<4x1xbf16> to memref<1x1xbf16, {{.*}}>
// CHECK: memref.subview %arg1[%c0, %c0] {{.*}} : memref<1x32xbf16> to memref<1x16xbf16, {{.*}}>
// CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} : memref<1x1xbf16, strided<[1, 1], offset: ?>>
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} : memref<1x16xbf16, strided<[32, 1], offset: ?>>
// CHECK: x86.avx.bcst_to_f32.packed {{.*}} : memref<1x1xbf16, strided<[1, 1], offset: ?>>
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32 {{.*}} : memref<1x16xbf16, strided<[32, 1], offset: ?>>
// CHECK: vector.fma {{.*}} : vector<8xf32>
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : memref<1x16xbf16, strided<[32, 1], offset: ?>>
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : memref<1x16xbf16, strided<[32, 1], offset: ?>>
// CHECK: vector.fma {{.*}} : vector<8xf32>
// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
@@ -297,7 +297,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
transform.apply_patterns.x86.vector_contract_bf16_to_fma
} : !transform.any_op
transform.yield
}
@@ -357,10 +357,10 @@ func.func @matmul_to_fma_flat_layout_load(
// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
// CHECK: memref.subview %arg0[%c0, %c0] {{.*}} : memref<4x1xbf16> to memref<1x1xbf16, {{.*}}>
// CHECK: memref.subview %arg1[%c0, %c0] {{.*}} : memref<1x32xbf16> to memref<1x16xbf16, {{.*}}>
// CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} : memref<1x1xbf16, strided<[1, 1], offset: ?>>
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} : memref<1x16xbf16, strided<[32, 1], offset: ?>>
// CHECK: x86.avx.bcst_to_f32.packed {{.*}} : memref<1x1xbf16, strided<[1, 1], offset: ?>>
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32 {{.*}} : memref<1x16xbf16, strided<[32, 1], offset: ?>>
// CHECK: vector.fma {{.*}} : vector<8xf32>
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : memref<1x16xbf16, strided<[32, 1], offset: ?>>
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : memref<1x16xbf16, strided<[32, 1], offset: ?>>
// CHECK: vector.fma {{.*}} : vector<8xf32>
// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
@@ -369,7 +369,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
transform.apply_patterns.x86.vector_contract_bf16_to_fma
} : !transform.any_op
transform.yield
}
@@ -448,10 +448,10 @@ func.func @matmul_to_fma_flat_layout_loop(%arg0: memref<16x64x32xbf16>, %arg1: m
// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
// CHECK: scf.for
// CHECK: scf.for
// CHECK: x86vector.avx.bcst_to_f32.packed
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
// CHECK: x86.avx.bcst_to_f32.packed
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32
// CHECK: vector.fma {{.*}} : vector<8xf32>
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
// CHECK: vector.fma {{.*}} : vector<8xf32>
// CHECK: scf.yield
// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
@@ -461,7 +461,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
transform.apply_patterns.x86.vector_contract_bf16_to_fma
} : !transform.any_op
transform.yield
}
@@ -524,10 +524,10 @@ func.func @matmul_to_fma_flat_layout_bcstB(
// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
// CHECK: memref.subview %arg1[%c0, %c0] {{.*}} : memref<1x4xbf16> to memref<1x1xbf16, {{.*}}>
// CHECK: memref.subview %arg0[%c0, %c0] {{.*}} : memref<32x1xbf16> to memref<16x1xbf16, {{.*}}>
// CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} : memref<1x1xbf16, strided<[4, 1], offset: ?>>
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} : memref<16x1xbf16, strided<[1, 1], offset: ?>>
// CHECK: x86.avx.bcst_to_f32.packed {{.*}} : memref<1x1xbf16, strided<[4, 1], offset: ?>>
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32 {{.*}} : memref<16x1xbf16, strided<[1, 1], offset: ?>>
// CHECK: vector.fma {{.*}} : vector<8xf32>
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : memref<16x1xbf16, strided<[1, 1], offset: ?>>
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : memref<16x1xbf16, strided<[1, 1], offset: ?>>
// CHECK: vector.fma {{.*}} : vector<8xf32>
// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
@@ -536,7 +536,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
transform.apply_patterns.x86.vector_contract_bf16_to_fma
} : !transform.any_op
transform.yield
}
@@ -572,18 +572,18 @@ func.func @matmul_dynamic_offset(
// CHECK-LABEL: @matmul_dynamic_offset
// CHECK: memref.subview %arg0[%arg3, %c0, 1]{{.*}}
// CHECK: x86vector.avx.bcst_to_f32.packed
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
// CHECK: x86.avx.bcst_to_f32.packed
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
// CHECK: vector.fma
// CHECK: x86vector.avx.bcst_to_f32.packed
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
// CHECK: x86.avx.bcst_to_f32.packed
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32
// CHECK: vector.fma
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
transform.apply_patterns.x86.vector_contract_bf16_to_fma
} : !transform.any_op
transform.yield
}
@@ -618,18 +618,18 @@ func.func @matmul_to_fma_load_bcst_B(
}
// CHECK-LABEL: @matmul_to_fma_load_bcst_B
// CHECK: x86vector.avx.bcst_to_f32.packed
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
// CHECK: x86.avx.bcst_to_f32.packed
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
// CHECK: vector.fma
// CHECK: x86vector.avx.bcst_to_f32.packed
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
// CHECK: x86.avx.bcst_to_f32.packed
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32
// CHECK: vector.fma
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
transform.apply_patterns.x86.vector_contract_bf16_to_fma
} : !transform.any_op
transform.yield
}
@@ -664,18 +664,18 @@ func.func @many_dimensions(
}
// CHECK-LABEL: @many_dimensions
// CHECK: x86vector.avx.bcst_to_f32.packed
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
// CHECK: x86.avx.bcst_to_f32.packed
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
// CHECK: vector.fma
// CHECK: x86vector.avx.bcst_to_f32.packed
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
// CHECK: x86.avx.bcst_to_f32.packed
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32
// CHECK: vector.fma
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
transform.apply_patterns.x86.vector_contract_bf16_to_fma
} : !transform.any_op
transform.yield
}
@@ -744,7 +744,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
transform.apply_patterns.x86.vector_contract_bf16_to_fma
} : !transform.any_op
transform.yield
}
@@ -785,7 +785,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %0 {
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
transform.apply_patterns.x86.vector_contract_bf16_to_fma
} : !transform.any_op
transform.yield
}
@@ -841,17 +841,17 @@ func.func @negative_offset_diff_is_not_8(
}
// CHECK-LABEL: @negative_offset_diff_is_not_8
// CHECK-NOT: x86vector.avx.bcst_to_f32.packed
// CHECK-NOT: x86vector.avx.cvt.packed.even.indexed_to_f32
// CHECK-NOT: x86.avx.bcst_to_f32.packed
// CHECK-NOT: x86.avx.cvt.packed.even.indexed_to_f32
// CHECK-NOT: vector.fma {{.*}} : vector<8xf32>
// CHECK-NOT: x86vector.avx.cvt.packed.odd.indexed_to_f32
// CHECK-NOT: x86.avx.cvt.packed.odd.indexed_to_f32
// CHECK: vector.contract
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
transform.apply_patterns.x86.vector_contract_bf16_to_fma
} : !transform.any_op
transform.yield
}
@@ -907,17 +907,17 @@ func.func @negative_vector_contracts_not_in_order(
}
// CHECK-LABEL: @negative_vector_contracts_not_in_order
// CHECK-NOT: x86vector.avx.bcst_to_f32.packed
// CHECK-NOT: x86vector.avx.cvt.packed.even.indexed_to_f32
// CHECK-NOT: x86.avx.bcst_to_f32.packed
// CHECK-NOT: x86.avx.cvt.packed.even.indexed_to_f32
// CHECK-NOT: vector.fma {{.*}} : vector<8xf32>
// CHECK-NOT: x86vector.avx.cvt.packed.odd.indexed_to_f32
// CHECK-NOT: x86.avx.cvt.packed.odd.indexed_to_f32
// CHECK: vector.contract
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
transform.apply_patterns.x86.vector_contract_bf16_to_fma
} : !transform.any_op
transform.yield
}
@@ -976,17 +976,17 @@ func.func @negative_flat_layout_dynamic_index(
}
// CHECK-LABEL: @negative_flat_layout_dynamic_index
// CHECK-NOT: x86vector.avx.bcst_to_f32.packed
// CHECK-NOT: x86vector.avx.cvt.packed.even.indexed_to_f32
// CHECK-NOT: x86.avx.bcst_to_f32.packed
// CHECK-NOT: x86.avx.cvt.packed.even.indexed_to_f32
// CHECK-NOT: vector.fma {{.*}} : vector<8xf32>
// CHECK-NOT: x86vector.avx.cvt.packed.odd.indexed_to_f32
// CHECK-NOT: x86.avx.cvt.packed.odd.indexed_to_f32
// CHECK: vector.contract
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
transform.apply_patterns.x86.vector_contract_bf16_to_fma
} : !transform.any_op
transform.yield
}
@@ -1045,17 +1045,17 @@ func.func @negative_non_unit_K_dim(
}
// CHECK-LABEL: @negative_non_unit_K_dim
// CHECK-NOT: x86vector.avx.bcst_to_f32.packed
// CHECK-NOT: x86vector.avx.cvt.packed.even.indexed_to_f32
// CHECK-NOT: x86.avx.bcst_to_f32.packed
// CHECK-NOT: x86.avx.cvt.packed.even.indexed_to_f32
// CHECK-NOT: vector.fma {{.*}} : vector<8xf32>
// CHECK-NOT: x86vector.avx.cvt.packed.odd.indexed_to_f32
// CHECK-NOT: x86.avx.cvt.packed.odd.indexed_to_f32
// CHECK: vector.contract
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
transform.apply_patterns.x86.vector_contract_bf16_to_fma
} : !transform.any_op
transform.yield
}
@@ -1089,7 +1089,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
transform.apply_patterns.x86.vector_contract_bf16_to_fma
} : !transform.any_op
transform.yield
}
@@ -1132,7 +1132,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
transform.apply_patterns.x86.vector_contract_bf16_to_fma
} : !transform.any_op
transform.yield
}
@@ -1176,7 +1176,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
transform.apply_patterns.x86.vector_contract_bf16_to_fma
} : !transform.any_op
transform.yield
}
@@ -1222,7 +1222,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
transform.apply_patterns.x86.vector_contract_bf16_to_fma
} : !transform.any_op
transform.yield
}
@@ -1266,7 +1266,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
transform.apply_patterns.x86.vector_contract_bf16_to_fma
} : !transform.any_op
transform.yield
}
@@ -1308,9 +1308,8 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
transform.apply_patterns.x86.vector_contract_bf16_to_fma
} : !transform.any_op
transform.yield
}
}

@@ -27,7 +27,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_fma
transform.apply_patterns.x86.vector_contract_to_fma
} : !transform.any_op
transform.yield
}
@@ -61,7 +61,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_fma
transform.apply_patterns.x86.vector_contract_to_fma
} : !transform.any_op
transform.yield
}
@@ -95,7 +95,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_fma
transform.apply_patterns.x86.vector_contract_to_fma
} : !transform.any_op
transform.yield
}
@@ -129,7 +129,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_fma
transform.apply_patterns.x86.vector_contract_to_fma
} : !transform.any_op
transform.yield
}
@@ -163,7 +163,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_fma
transform.apply_patterns.x86.vector_contract_to_fma
} : !transform.any_op
transform.yield
}
@@ -197,7 +197,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_fma
transform.apply_patterns.x86.vector_contract_to_fma
} : !transform.any_op
transform.yield
}
@@ -233,7 +233,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_fma
transform.apply_patterns.x86.vector_contract_to_fma
} : !transform.any_op
transform.yield
}
@@ -269,7 +269,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_fma
transform.apply_patterns.x86.vector_contract_to_fma
} : !transform.any_op
transform.yield
}
@@ -303,7 +303,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_fma
transform.apply_patterns.x86.vector_contract_to_fma
} : !transform.any_op
transform.yield
}
@@ -337,7 +337,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_fma
transform.apply_patterns.x86.vector_contract_to_fma
} : !transform.any_op
transform.yield
}

@@ -20,13 +20,13 @@ func.func @brgemm_to_bf16dp(
// CHECK-LABEL: @brgemm_to_bf16dp
// CHECK: vector.broadcast
// CHECK: x86vector.avx512.dot
// CHECK: x86.avx512.dot
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op
transform.yield
}
@@ -54,13 +54,13 @@ func.func @brgemm_to_bf16dp_bcst_B(
// CHECK-LABEL: @brgemm_to_bf16dp_bcst_B
// CHECK: vector.broadcast
// CHECK: x86vector.avx512.dot
// CHECK: x86.avx512.dot
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op
transform.yield
}
@@ -88,13 +88,13 @@ func.func @brgemm_to_avx10int8dp(
// CHECK-LABEL: @brgemm_to_avx10int8dp
// CHECK: vector.broadcast
// CHECK: x86vector.avx10.dot.i8
// CHECK: x86.avx10.dot.i8
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op
transform.yield
}
@@ -123,13 +123,13 @@ func.func @batch_matmul_avx10int8dp_bcst_B(
// CHECK-LABEL: @batch_matmul_avx10int8dp_bcst_B
// CHECK: vector.broadcast
// CHECK: x86vector.avx10.dot.i8
// CHECK: x86.avx10.dot.i8
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op
transform.yield
}
@@ -157,13 +157,13 @@ func.func @brgemm_to_int8dp(
// CHECK-LABEL: @brgemm_to_int8dp
// CHECK: vector.broadcast
// CHECK: x86vector.avx.dot.i8
// CHECK: x86.avx.dot.i8
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op
transform.yield
}
@@ -191,13 +191,13 @@ func.func @batch_matmul_bf16dp(
// CHECK-LABEL: @batch_matmul_bf16dp
// CHECK: vector.broadcast
// CHECK: x86vector.avx512.dot
// CHECK: x86.avx512.dot
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op
transform.yield
}
@@ -226,13 +226,13 @@ func.func @batch_matmul_int8dp(
// CHECK-LABEL: @batch_matmul_int8dp
// CHECK: vector.broadcast
// CHECK: x86vector.avx.dot.i8
// CHECK: x86.avx.dot.i8
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op
transform.yield
}
@@ -261,13 +261,13 @@ func.func @batch_matmul_int8dp_bcst_B(
// CHECK-LABEL: @batch_matmul_int8dp_bcst_B
// CHECK: vector.broadcast
// CHECK: x86vector.avx.dot.i8
// CHECK: x86.avx.dot.i8
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op
transform.yield
}
@@ -295,13 +295,13 @@ func.func @matmul_outer_product_to_bf16dp(
// CHECK-LABEL: @matmul_outer_product_to_bf16dp
// CHECK: vector.broadcast
// CHECK: x86vector.avx512.dot
// CHECK: x86.avx512.dot
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op
transform.yield
}
@@ -329,13 +329,13 @@ func.func @matmul_outer_product_to_int8dp(
// CHECK-LABEL: @matmul_outer_product_to_int8dp
// CHECK: vector.broadcast
// CHECK: x86vector.avx.dot.i8
// CHECK: x86.avx.dot.i8
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op
transform.yield
}
@@ -395,8 +395,8 @@ func.func @matmul_bf16dp_flat_layout(
// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle{{.*}}[0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43, 16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59] : vector<32xbf16>, vector<32xbf16>
// CHECK-NEXT: vector.shuffle{{.*}}[4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47, 20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63] : vector<32xbf16>, vector<32xbf16>
// CHECK: x86vector.avx512.dot
// CHECK: x86vector.avx512.dot
// CHECK: x86.avx512.dot
// CHECK: x86.avx512.dot
// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
@@ -404,7 +404,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op
transform.yield
}
@@ -485,8 +485,8 @@ func.func @brmatmul_bf16dp_flat_layout_loop(%arg0: memref<16x64x32xbf16>, %arg1:
// CHECK: scf.for
// CHECK: vector.shuffle{{.*}}[0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43, 16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59] : vector<32xbf16>, vector<32xbf16>
// CHECK-NEXT: vector.shuffle{{.*}}[4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47, 20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63] : vector<32xbf16>, vector<32xbf16>
// CHECK: x86vector.avx512.dot
// CHECK: x86vector.avx512.dot
// CHECK: x86.avx512.dot
// CHECK: x86.avx512.dot
// CHECK: scf.yield
// CHECK: scf.yield
// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
@@ -496,7 +496,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op
transform.yield
}
@@ -566,15 +566,15 @@ func.func @matmul_bf16dp_flat_layout_B_shuffled(
}
// CHECK-LABEL: @matmul_bf16dp_flat_layout_B_shuffled
// CHECK: x86vector.avx512.dot
// CHECK: x86vector.avx512.dot
// CHECK: x86.avx512.dot
// CHECK: x86.avx512.dot
// CHECK-NOT: vector.contract
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op
transform.yield
}
@@ -601,14 +601,14 @@ func.func @negative_invalid_vc_kind(
}
// CHECK-LABEL: @negative_invalid_vc_kind
// CHECK-NOT: x86vector.avx512.dot
// CHECK-NOT: x86.avx512.dot
// CHECK: vector.contract
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op
transform.yield
}
@@ -635,14 +635,14 @@ func.func @negative_false_vnni_bf16(
}
// CHECK-LABEL: @negative_false_vnni_bf16
// CHECK-NOT: x86vector.avx512.dot
// CHECK-NOT: x86.avx512.dot
// CHECK: vector.contract
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op
transform.yield
}
@@ -669,14 +669,14 @@ func.func @negative_false_vnni_int8(
}
// CHECK-LABEL: @negative_false_vnni_int8
// CHECK-NOT: x86vector.avx.dot.i8
// CHECK-NOT: x86.avx.dot.i8
// CHECK: vector.contract
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op
transform.yield
}
@@ -703,14 +703,14 @@ func.func @negative_batch_dimension(
}
// CHECK-LABEL: @negative_batch_dimension
// CHECK-NOT: x86vector.avx512.dot
// CHECK-NOT: x86.avx512.dot
// CHECK: vector.contract
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op
transform.yield
}
@@ -737,14 +737,14 @@ func.func @negative_brgemm_dimension(
}
// CHECK-LABEL: @negative_brgemm_dimension
// CHECK-NOT: x86vector.avx.dot.i8
// CHECK-NOT: x86.avx.dot.i8
// CHECK: vector.contract
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op
transform.yield
}
@@ -771,14 +771,14 @@ func.func @negative_float_acc_type(
}
// CHECK-LABEL: @negative_float_acc_type
// CHECK-NOT: x86vector.avx512.dot
// CHECK-NOT: x86.avx512.dot
// CHECK: vector.contract
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op
transform.yield
}
@@ -805,14 +805,14 @@ func.func @negative_int_acc_type(
}
// CHECK-LABEL: @negative_int_acc_type
// CHECK-NOT: x86vector.avx.dot.i8
// CHECK-NOT: x86.avx.dot.i8
// CHECK: vector.contract
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op
transform.yield
}
@@ -839,14 +839,14 @@ func.func @negative_wrong_vnni_blocking_factor_bf16(
}
// CHECK-LABEL: @negative_wrong_vnni_blocking_factor_bf16
// CHECK-NOT: x86vector.avx512.dot
// CHECK-NOT: x86.avx512.dot
// CHECK: vector.contract
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op
transform.yield
}
@@ -873,14 +873,14 @@ func.func @negative_brgemm_not_vnni(
}
// CHECK-LABEL: @negative_brgemm_not_vnni
// CHECK-NOT: x86vector.avx512.dot
// CHECK-NOT: x86.avx512.dot
// CHECK: vector.contract
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op
transform.yield
}
@@ -907,15 +907,15 @@ func.func @negative_wrong_vector_shape_int8(
}
// CHECK-LABEL: @negative_wrong_vector_shape_int8
// CHECK-NOT: x86vector.avx.dot.i8
// CHECK-NOT: x86vector.avx10.dot.i8
// CHECK-NOT: x86.avx.dot.i8
// CHECK-NOT: x86.avx10.dot.i8
// CHECK: vector.contract
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op
transform.yield
}
@@ -942,14 +942,14 @@ func.func @negative_wrong_vector_shape_bf16(
}
// CHECK-LABEL: @negative_wrong_vector_shape_bf16
// CHECK-NOT: x86vector.avx512.dot
// CHECK-NOT: x86.avx512.dot
// CHECK: vector.contract
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op
transform.yield
}
@@ -1005,14 +1005,14 @@ func.func @negative_flat_other_dim_is_not_2(
}
// CHECK-LABEL: @negative_flat_other_dim_is_not_2
// CHECK-NOT: x86vector.avx512.dot
// CHECK-NOT: x86.avx512.dot
// CHECK: vector.contract
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op
transform.yield
}
@@ -1069,14 +1069,14 @@ func.func @negative_flat_offset_diff_is_not16(
}
// CHECK-LABEL: @negative_flat_offset_diff_is_not16
// CHECK-NOT: x86vector.avx512.dot
// CHECK-NOT: x86.avx512.dot
// CHECK: vector.contract
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op
transform.yield
}
@@ -1132,14 +1132,14 @@ func.func @negative_flat_dynamic_offset(
}
// CHECK-LABEL: @negative_flat_dynamic_offset
// CHECK-NOT: x86vector.avx512.dot
// CHECK-NOT: x86.avx512.dot
// CHECK: vector.contract
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op
transform.yield
}
@@ -1196,14 +1196,14 @@ func.func @negative_flat_read_after_contract(
}
// CHECK-LABEL: @negative_flat_read_after_contract
// CHECK-NOT: x86vector.avx512.dot
// CHECK-NOT: x86.avx512.dot
// CHECK: vector.contract
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op
transform.yield
}
@@ -1271,7 +1271,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op
transform.yield
}
@@ -1338,7 +1338,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op
transform.yield
}
@@ -1406,7 +1406,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op
transform.yield
}
@@ -1473,7 +1473,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op
transform.yield
}

@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86vector" -test-lower-to-llvm | \
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86" -test-lower-to-llvm | \
// RUN: mlir-translate --mlir-to-llvmir | \
// RUN: %lli --entry-function=entry --mattr="avx" --dlopen=%mlir_c_runner_utils | \
// RUN: FileCheck %s
@@ -9,7 +9,7 @@ func.func @entry() -> i32 {
%a = arith.constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]> : vector<8xf32>
%b = arith.constant dense<[9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : vector<8xf32>
%r = x86vector.avx.intr.dot %a, %b : vector<8xf32>
%r = x86.avx.intr.dot %a, %b : vector<8xf32>
%1 = vector.extract %r[%i0] : f32 from vector<8xf32>
%2 = vector.extract %r[%i4] : f32 from vector<8xf32>

@@ -1,7 +1,7 @@
import sys
# X86Vector tests must be enabled via build flag.
if not config.mlir_run_x86vector_tests:
# X86 tests must be enabled via build flag.
if not config.mlir_run_x86_tests:
config.unsupported = True
# No JIT on win32.

@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86vector" -test-lower-to-llvm | \
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86" -test-lower-to-llvm | \
// RUN: mlir-translate --mlir-to-llvmir | \
// RUN: %lli --entry-function=entry --mattr="avx512bw" --dlopen=%mlir_c_runner_utils | \
// RUN: FileCheck %s
@@ -8,8 +8,8 @@ func.func @entry() -> i32 {
%a = arith.constant dense<[1., 0., 0., 2., 4., 3., 5., 7., 8., 1., 5., 5., 3., 1., 0., 7.]> : vector<16xf32>
%k = arith.constant dense<[1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0]> : vector<16xi1>
%r1 = x86vector.avx512.mask.compress %k, %a : vector<16xf32>
%r2 = x86vector.avx512.mask.compress %k, %a {constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32>
%r1 = x86.avx512.mask.compress %k, %a : vector<16xf32>
%r2 = x86.avx512.mask.compress %k, %a {constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32>
vector.print %r1 : vector<16xf32>
// CHECK: ( 1, 0, 2, 4, 5, 5, 3, 1, 0, 0, 0, 0, 0, 0, 0, 0 )
@@ -18,7 +18,7 @@ func.func @entry() -> i32 {
// CHECK: ( 1, 0, 2, 4, 5, 5, 3, 1, 0, 5, 5, 5, 5, 5, 5, 5 )
%src = arith.constant dense<[0., 2., 1., 8., 6., 4., 4., 3., 2., 8., 5., 6., 3., 7., 6., 9.]> : vector<16xf32>
%r3 = x86vector.avx512.mask.compress %k, %a, %src : vector<16xf32>, vector<16xf32>
%r3 = x86.avx512.mask.compress %k, %a, %src : vector<16xf32>, vector<16xf32>
vector.print %r3 : vector<16xf32>
// CHECK: ( 1, 0, 2, 4, 5, 5, 3, 1, 0, 8, 5, 6, 3, 7, 6, 9 )

@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86vector" -test-lower-to-llvm | \
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86" -test-lower-to-llvm | \
// RUN: mlir-translate --mlir-to-llvmir | \
// RUN: %lli --entry-function=entry --mattr="avx" --dlopen=%mlir_c_runner_utils | \
// RUN: FileCheck %s
@@ -7,7 +7,7 @@ func.func @entry() -> i32 {
%i0 = arith.constant 0 : i32
%v = arith.constant dense<[0.125, 0.25, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0]> : vector<8xf32>
%r = x86vector.avx.rsqrt %v : vector<8xf32>
%r = x86.avx.rsqrt %v : vector<8xf32>
// `rsqrt` may produce slightly different results on Intel and AMD machines: accept both results here.
// CHECK: {{( 2.82[0-9]*, 1.99[0-9]*, 1.41[0-9]*, 0.99[0-9]*, 0.70[0-9]*, 0.49[0-9]*, 0.35[0-9]*, 0.24[0-9]* )}}
vector.print %r : vector<8xf32>
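
Because rsqrt is only a hardware estimate (hence the tolerant CHECK pattern), a common follow-up is one Newton-Raphson step, y' = y * (1.5 - 0.5 * x * y * y), to tighten the approximation. A self-contained sketch using standard arith ops (all value names hypothetical):

  %y = x86.avx.rsqrt %v : vector<8xf32>
  %half = arith.constant dense<0.5> : vector<8xf32>
  %three_halves = arith.constant dense<1.5> : vector<8xf32>
  %yy = arith.mulf %y, %y : vector<8xf32>           // y * y
  %xyy = arith.mulf %v, %yy : vector<8xf32>         // x * y * y
  %hxyy = arith.mulf %half, %xyy : vector<8xf32>    // 0.5 * x * y * y
  %corr = arith.subf %three_halves, %hxyy : vector<8xf32>
  %res = arith.mulf %y, %corr : vector<8xf32>       // refined 1/sqrt(x)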

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86vector" -test-lower-to-llvm | \
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86" -test-lower-to-llvm | \
// RUN: mlir-translate --mlir-to-llvmir | \
// RUN: %lli --entry-function=entry --mattr="avx512bw,avx512vp2intersect" --dlopen=%mlir_c_runner_utils | \
// RUN: FileCheck %s
@@ -35,11 +35,11 @@
func.func @vector_dot(%v_A : vector<8xi64>, %v_B : vector<8xf64>,
                      %v_C : vector<8xi64>, %v_D : vector<8xf64>) -> f64 {
  // Compute intersection of indices.
  %k0, %k1 = x86vector.avx512.vp2intersect %v_A, %v_C : vector<8xi64>
  %k0, %k1 = x86.avx512.vp2intersect %v_A, %v_C : vector<8xi64>
  // Filter out values without a match and compress the vector.
  %p0 = x86vector.avx512.mask.compress %k0, %v_B : vector<8xf64>
  %p1 = x86vector.avx512.mask.compress %k1, %v_D : vector<8xf64>
  %p0 = x86.avx512.mask.compress %k0, %v_B : vector<8xf64>
  %p1 = x86.avx512.mask.compress %k1, %v_D : vector<8xf64>
  // Dense vector dot product.
  %acc = arith.constant 0.0 : f64
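
The hunk ends just before the dense reduction: once both value vectors are compressed to their matching lanes, the sparse dot product reduces to an elementwise multiply plus a reduction. A sketch of the remaining steps (value names hypothetical, reusing %p0, %p1, and %acc from above):

  %prod = arith.mulf %p0, %p1 : vector<8xf64>
  %dot = vector.reduction <add>, %prod, %acc : vector<8xf64> into f64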

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86vector" -test-lower-to-llvm | \
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86" -test-lower-to-llvm | \
// RUN: mlir-translate --mlir-to-llvmir | \
// RUN: %lli --entry-function=entry --mattr="avx512bw,avx512vp2intersect" --dlopen=%mlir_c_runner_utils | \
// RUN: FileCheck %s
@@ -40,7 +40,7 @@ func.func @entry() -> i32 {
vector.print %w9 : vector<16xi32>
// CHECK: ( 1, 1, 1, 1, 2, 1, 1, -219, 12, 12, 12, 0, 0, 0, 1, 0 )
%k1, %k2 = x86vector.avx512.vp2intersect %v9, %w9 : vector<16xi32>
%k1, %k2 = x86.avx512.vp2intersect %v9, %w9 : vector<16xi32>
vector.print %k1 : vector<16xi1>
// CHECK: ( 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1 )
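
vp2intersect produces two masks: %k1 flags each lane of %v9 whose value occurs anywhere in %w9, and %k2 mirrors this for %w9 against %v9, which is what the CHECK vector reflects. A hypothetical follow-up that keeps only the matching lanes:

  %matched = x86.avx512.mask.compress %k1, %v9 : vector<16xi32>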

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt %s --convert-vector-to-llvm="enable-x86vector" --convert-to-llvm -reconcile-unrealized-casts \
// RUN: mlir-opt %s --convert-vector-to-llvm="enable-x86" --convert-to-llvm -reconcile-unrealized-casts \
// RUN: | mlir-translate --mlir-to-llvmir \
// RUN: | FileCheck %s
@@ -11,9 +11,9 @@ func.func @LLVM_x86_avx512_mask_ps_512(
%rnd_k = arith.constant 15 : i32
%rnd = arith.constant 42 : i32
// CHECK: call <16 x float> @llvm.x86.avx512.mask.rndscale.ps.512(<16 x float>
%0 = x86vector.avx512.mask.rndscale %src, %rnd_k, %a, %imm, %rnd : vector<16xf32>
%0 = x86.avx512.mask.rndscale %src, %rnd_k, %a, %imm, %rnd : vector<16xf32>
// CHECK: call <16 x float> @llvm.x86.avx512.mask.scalef.ps.512(<16 x float>
%1 = x86vector.avx512.mask.scalef %0, %a, %b, %scale_k, %rnd : vector<16xf32>
%1 = x86.avx512.mask.scalef %0, %a, %b, %scale_k, %rnd : vector<16xf32>
return %1 : vector<16xf32>
}
@@ -26,9 +26,9 @@ func.func @LLVM_x86_avx512_mask_pd_512(
%rnd_k = arith.constant 15 : i32
%rnd = arith.constant 42 : i32
// CHECK: call <8 x double> @llvm.x86.avx512.mask.rndscale.pd.512(<8 x double>
%0 = x86vector.avx512.mask.rndscale %src, %rnd_k, %a, %imm, %rnd : vector<8xf64>
%0 = x86.avx512.mask.rndscale %src, %rnd_k, %a, %imm, %rnd : vector<8xf64>
// CHECK: call <8 x double> @llvm.x86.avx512.mask.scalef.pd.512(<8 x double>
%1 = x86vector.avx512.mask.scalef %0, %a, %b, %scale_k, %rnd : vector<8xf64>
%1 = x86.avx512.mask.scalef %0, %a, %b, %scale_k, %rnd : vector<8xf64>
return %1 : vector<8xf64>
}
@@ -37,7 +37,7 @@ func.func @LLVM_x86_mask_compress(%k: vector<16xi1>, %a: vector<16xf32>)
-> vector<16xf32>
{
// CHECK: call <16 x float> @llvm.x86.avx512.mask.compress.v16f32(
%0 = x86vector.avx512.mask.compress %k, %a, %a : vector<16xf32>, vector<16xf32>
%0 = x86.avx512.mask.compress %k, %a, %a : vector<16xf32>, vector<16xf32>
return %0 : vector<16xf32>
}
@@ -46,7 +46,7 @@ func.func @LLVM_x86_vp2intersect_d_512(%a: vector<16xi32>, %b: vector<16xi32>)
-> (vector<16xi1>, vector<16xi1>)
{
// CHECK: call { <16 x i1>, <16 x i1> } @llvm.x86.avx512.vp2intersect.d.512(<16 x i32>
%0, %1 = x86vector.avx512.vp2intersect %a, %b : vector<16xi32>
%0, %1 = x86.avx512.vp2intersect %a, %b : vector<16xi32>
return %0, %1 : vector<16xi1>, vector<16xi1>
}
@@ -55,7 +55,7 @@ func.func @LLVM_x86_vp2intersect_q_512(%a: vector<8xi64>, %b: vector<8xi64>)
-> (vector<8 x i1>, vector<8 x i1>)
{
// CHECK: call { <8 x i1>, <8 x i1> } @llvm.x86.avx512.vp2intersect.q.512(<8 x i64>
%0, %1 = x86vector.avx512.vp2intersect %a, %b : vector<8xi64>
%0, %1 = x86.avx512.vp2intersect %a, %b : vector<8xi64>
return %0, %1 : vector<8 x i1>, vector<8 x i1>
}
@@ -65,7 +65,7 @@ func.func @LLVM_x86_avx512bf16_dpbf16ps_128(
) -> vector<4xf32>
{
// CHECK: call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(
%0 = x86vector.avx512.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
%0 = x86.avx512.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -75,7 +75,7 @@ func.func @LLVM_x86_avx512bf16_dpbf16ps_256(
) -> vector<8xf32>
{
// CHECK: call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(
%0 = x86vector.avx512.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
%0 = x86.avx512.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -85,7 +85,7 @@ func.func @LLVM_x86_avx512bf16_dpbf16ps_512(
) -> vector<16xf32>
{
// CHECK: call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(
%0 = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
%0 = x86.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
return %0 : vector<16xf32>
}
@@ -94,7 +94,7 @@ func.func @LLVM_x86_avx512bf16_cvtneps2bf16_256(
%a: vector<8xf32>) -> vector<8xbf16>
{
// CHECK: call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256(
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a
%0 = x86.avx512.cvt.packed.f32_to_bf16 %a
: vector<8xf32> -> vector<8xbf16>
return %0 : vector<8xbf16>
}
@@ -104,7 +104,7 @@ func.func @LLVM_x86_avx512bf16_cvtneps2bf16_512(
%a: vector<16xf32>) -> vector<16xbf16>
{
// CHECK: call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512(
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a
%0 = x86.avx512.cvt.packed.f32_to_bf16 %a
: vector<16xf32> -> vector<16xbf16>
return %0 : vector<16xbf16>
}
@@ -113,7 +113,7 @@ func.func @LLVM_x86_avx512bf16_cvtneps2bf16_512(
func.func @LLVM_x86_avx10_vpdpbssd_512(%w: vector<16xi32>, %a: vector<64xi8>,
%b: vector<64xi8>) -> vector<16xi32> {
// CHECK: call <16 x i32> @llvm.x86.avx10.vpdpbssd.512(
%0 = x86vector.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
%0 = x86.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
return %0 : vector<16xi32>
}
@@ -122,7 +122,7 @@ func.func @LLVM_x86_avxbf16_vcvtneebf162ps128(
%a: memref<8xbf16>) -> vector<4xf32>
{
// CHECK: call <4 x float> @llvm.x86.vcvtneebf162ps128(
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
%0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -131,7 +131,7 @@ func.func @LLVM_x86_avxbf16_vcvtneebf162ps256(
%a: memref<16xbf16>) -> vector<8xf32>
{
// CHECK: call <8 x float> @llvm.x86.vcvtneebf162ps256(
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
%0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -140,7 +140,7 @@ func.func @LLVM_x86_avxbf16_vcvtneobf162ps128(
%a: memref<8xbf16>) -> vector<4xf32>
{
// CHECK: call <4 x float> @llvm.x86.vcvtneobf162ps128(
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
%0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -149,7 +149,7 @@ func.func @LLVM_x86_avxbf16_vcvtneobf162ps256(
%a: memref<16xbf16>) -> vector<8xf32>
{
// CHECK: call <8 x float> @llvm.x86.vcvtneobf162ps256(
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
%0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -158,7 +158,7 @@ func.func @LLVM_x86_avxbf16_vbcstnebf162ps128(
%a: memref<1xbf16>) -> vector<4xf32>
{
// CHECK: call <4 x float> @llvm.x86.vbcstnebf162ps128(
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
%0 = x86.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -167,7 +167,7 @@ func.func @LLVM_x86_avxbf16_vbcstnebf162ps256(
%a: memref<1xbf16>) -> vector<8xf32>
{
// CHECK: call <8 x float> @llvm.x86.vbcstnebf162ps256(
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
%0 = x86.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -176,7 +176,7 @@ func.func @LLVM_x86_avxf16_vcvtneeph2ps128(
%a: memref<8xf16>) -> vector<4xf32>
{
// CHECK: call <4 x float> @llvm.x86.vcvtneeph2ps128(
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
%0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -185,7 +185,7 @@ func.func @LLVM_x86_avxf16_vcvtneeph2ps256(
%a: memref<16xf16>) -> vector<8xf32>
{
// CHECK: call <8 x float> @llvm.x86.vcvtneeph2ps256(
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
%0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -194,7 +194,7 @@ func.func @LLVM_x86_avxf16_vcvtneoph2ps128(
%a: memref<8xf16>) -> vector<4xf32>
{
// CHECK: call <4 x float> @llvm.x86.vcvtneoph2ps128(
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
%0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -203,7 +203,7 @@ func.func @LLVM_x86_avxf16_vcvtneoph2ps256(
%a: memref<16xf16>) -> vector<8xf32>
{
// CHECK: call <8 x float> @llvm.x86.vcvtneoph2ps256(
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
%0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -212,7 +212,7 @@ func.func @LLVM_x86_avxf16_vbcstnesh2ps128(
%a: memref<1xf16>) -> vector<4xf32>
{
// CHECK: call <4 x float> @llvm.x86.vbcstnesh2ps128(
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
%0 = x86.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -221,7 +221,7 @@ func.func @LLVM_x86_avxf16_vbcstnesh2ps256(
%a: memref<1xf16>) -> vector<8xf32>
{
// CHECK: call <8 x float> @llvm.x86.vbcstnesh2ps256(
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
%0 = x86.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -229,7 +229,7 @@ func.func @LLVM_x86_avxf16_vbcstnesh2ps256(
func.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
{
// CHECK: call <8 x float> @llvm.x86.avx.rsqrt.ps.256(<8 x float>
%0 = x86vector.avx.rsqrt %a : vector<8xf32>
%0 = x86.avx.rsqrt %a : vector<8xf32>
return %0 : vector<8xf32>
}
@@ -239,7 +239,7 @@ func.func @LLVM_x86_avx_dp_ps_256(
) -> vector<8xf32>
{
// CHECK: call <8 x float> @llvm.x86.avx.dp.ps.256(
%0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
%0 = x86.avx.intr.dot %a, %b : vector<8xf32>
return %0 : vector<8xf32>
}
@@ -247,7 +247,7 @@ func.func @LLVM_x86_avx_dp_ps_256(
func.func @LLVM_x86_avx2_vpdpbssd_128(%w: vector<4xi32>, %a: vector<16xi8>,
%b: vector<16xi8>) -> vector<4xi32> {
// CHECK: call <4 x i32> @llvm.x86.avx2.vpdpbssd.128(
%0 = x86vector.avx.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
%0 = x86.avx.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
return %0 : vector<4xi32>
}
@@ -255,6 +255,6 @@ func.func @LLVM_x86_avx2_vpdpbssd_128(%w: vector<4xi32>, %a: vector<16xi8>,
func.func @LLVM_x86_avx2_vpdpbssd_256(%w: vector<8xi32>, %a: vector<32xi8>,
%b: vector<32xi8>) -> vector<8xi32> {
// CHECK: call <8 x i32> @llvm.x86.avx2.vpdpbssd.256(
%0 = x86vector.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
%0 = x86.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
return %0 : vector<8xi32>
}
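
The dot.i8 ops map onto the VNNI vpdpbssd family: each i32 lane of the accumulator %w is incremented by the sum of four signed i8-by-i8 products from the corresponding 4-byte groups of %a and %b. The lowering is intended to be 1:1 with the intrinsic, as in this sketch (LLVM IR shown as a comment, mirroring the CHECK lines above):

  %r = x86.avx.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
  // -> call <4 x i32> @llvm.x86.avx2.vpdpbssd.128(...)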

View File

@@ -10,5 +10,5 @@ mlir_target_link_libraries(MLIRMathTestPasses PUBLIC
MLIRPass
MLIRTransformUtils
MLIRVectorDialect
MLIRX86VectorDialect
MLIRX86Dialect
)

View File

@@ -15,7 +15,7 @@
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/Dialect/X86/X86Dialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -37,7 +37,7 @@ struct TestMathPolynomialApproximationPass
registry.insert<arith::ArithDialect, math::MathDialect,
vector::VectorDialect>();
if (enableAvx2)
registry.insert<x86vector::X86VectorDialect>();
registry.insert<x86::X86Dialect>();
}
StringRef getArgument() const final {
return "test-math-polynomial-approximation";
@@ -49,7 +49,7 @@ struct TestMathPolynomialApproximationPass
Option<bool> enableAvx2{
*this, "enable-avx2",
llvm::cl::desc("Enable approximations that emit AVX2 intrinsics via the "
"X86Vector dialect"),
"X86 dialect"),
llvm::cl::init(false)};
};
} // namespace

View File

@@ -20,5 +20,5 @@ mlir_target_link_libraries(MLIRVectorTestPasses PUBLIC
MLIRTransformUtils
MLIRVectorDialect
MLIRVectorToSCF
MLIRX86VectorDialect
MLIRX86Dialect
)

View File

@@ -54,7 +54,7 @@ config.mlir_run_arm_sve_tests = @MLIR_RUN_ARM_SVE_TESTS@
if config.mlir_run_arm_sve_tests:
config.available_features.add("mlir_arm_sve_tests")
config.mlir_run_arm_sme_tests = @MLIR_RUN_ARM_SME_TESTS@
config.mlir_run_x86vector_tests = @MLIR_RUN_X86VECTOR_TESTS@
config.mlir_run_x86_tests = @MLIR_RUN_X86_TESTS@
config.mlir_run_riscv_vector_tests = "@MLIR_RUN_RISCV_VECTOR_TESTS@"
config.mlir_run_cuda_tensor_core_tests = @MLIR_RUN_CUDA_TENSOR_CORE_TESTS@
config.mlir_run_cuda_sm80_tests = @MLIR_RUN_CUDA_SM80_TESTS@

View File

@@ -41,7 +41,7 @@
// CHECK-SAME: tosa
// CHECK-SAME: transform
// CHECK-SAME: vector
// CHECK-SAME: x86vector
// CHECK-SAME: x86
// RUN: mlir-opt --help-hidden | FileCheck %s -check-prefix=CHECK-HELP
// CHECK-HELP: -p - Alias for --pass-pipeline

View File

@@ -0,0 +1,5 @@
import sys
# X86 tests must be enabled via build flag.
if not config.mlir_run_x86_tests:
config.unsupported = True

View File

@@ -3,7 +3,7 @@
// RUN: -convert-scf-to-cf \
// RUN: -convert-arith-to-llvm \
// RUN: -convert-cf-to-llvm \
// RUN: -convert-vector-to-llvm="enable-x86vector" \
// RUN: -convert-vector-to-llvm="enable-x86" \
// RUN: -convert-math-to-llvm \
// RUN: -convert-func-to-llvm \
// RUN: -reconcile-unrealized-casts \

View File

@@ -1,5 +0,0 @@
import sys
# X86Vector tests must be enabled via build flag.
if not config.mlir_run_x86vector_tests:
config.unsupported = True

View File

@@ -2,7 +2,7 @@
from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects.transform import x86vector
from mlir.dialects.transform import x86
def run_apply_patterns(f):
@@ -28,13 +28,13 @@ def run_apply_patterns(f):
def non_configurable_patterns():
    # CHECK-LABEL: TEST: non_configurable_patterns
    # CHECK: apply_patterns
    # CHECK: transform.apply_patterns.x86vector.vector_contract_to_fma
    x86vector.ApplyVectorContractToFMAPatternsOp()
    # CHECK: transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
    x86vector.ApplyVectorContractToPackedTypeDotProductPatternsOp()
    # CHECK: transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
    x86vector.ApplyVectorContractBF16ToFMAPatternsOp()
    # CHECK: transform.apply_patterns.x86vector.sink_vector_producer_ops
    x86vector.ApplySinkVectorProducerOpsPatternsOp()
    # CHECK: transform.apply_patterns.x86vector.shuffle_vector_fma_ops
    x86vector.ApplyShuffleVectorFMAOpsPatternsOp()
    # CHECK: transform.apply_patterns.x86.vector_contract_to_fma
    x86.ApplyVectorContractToFMAPatternsOp()
    # CHECK: transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
    x86.ApplyVectorContractToPackedTypeDotProductPatternsOp()
    # CHECK: transform.apply_patterns.x86.vector_contract_bf16_to_fma
    x86.ApplyVectorContractBF16ToFMAPatternsOp()
    # CHECK: transform.apply_patterns.x86.sink_vector_producer_ops
    x86.ApplySinkVectorProducerOpsPatternsOp()
    # CHECK: transform.apply_patterns.x86.shuffle_vector_fma_ops
    x86.ApplyShuffleVectorFMAOpsPatternsOp()
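
Each builder emits the corresponding transform op into the enclosing apply_patterns region, producing IR of the same shape as the transform script shown earlier in this diff, e.g. (sketch):

  transform.apply_patterns to %func {
    transform.apply_patterns.x86.vector_contract_to_fma
  } : !transform.any_op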

View File

@@ -3,7 +3,7 @@
from mlir.ir import *
import mlir.dialects.builtin as builtin
import mlir.dialects.func as func
import mlir.dialects.x86vector as x86vector
import mlir.dialects.x86 as x86
def run(f):
@@ -21,13 +21,11 @@ def testAvxOp():
    @func.FuncOp.from_py_func(MemRefType.get((1,), BF16Type.get()))
    def avx_op(arg):
        return x86vector.BcstToPackedF32Op(
            a=arg, dst=VectorType.get((8,), F32Type.get())
        )
        return x86.BcstToPackedF32Op(a=arg, dst=VectorType.get((8,), F32Type.get()))
    # CHECK-LABEL: func @avx_op(
    # CHECK-SAME: %[[ARG:.+]]: memref<1xbf16>) -> vector<8xf32> {
    # CHECK: %[[VAL:.+]] = x86vector.avx.bcst_to_f32.packed %[[ARG]]
    # CHECK: %[[VAL:.+]] = x86.avx.bcst_to_f32.packed %[[ARG]]
    # CHECK: return %[[VAL]] : vector<8xf32>
    # CHECK: }
    print(module)
@@ -41,13 +39,13 @@ def testAvx512Op():
    @func.FuncOp.from_py_func(VectorType.get((8,), F32Type.get()))
    def avx512_op(arg):
        return x86vector.CvtPackedF32ToBF16Op(
        return x86.CvtPackedF32ToBF16Op(
            a=arg, dst=VectorType.get((8,), BF16Type.get())
        )
    # CHECK-LABEL: func @avx512_op(
    # CHECK-SAME: %[[ARG:.+]]: vector<8xf32>) -> vector<8xbf16> {
    # CHECK: %[[VAL:.+]] = x86vector.avx512.cvt.packed.f32_to_bf16 %[[ARG]]
    # CHECK: %[[VAL:.+]] = x86.avx512.cvt.packed.f32_to_bf16 %[[ARG]]
    # CHECK: return %[[VAL]] : vector<8xbf16>
    # CHECK: }
    print(module)
@@ -65,12 +63,12 @@ def testAvx10Op():
        VectorType.get((64,), IntegerType.get(8)),
    )
    def avx10_op(*args):
        return x86vector.AVX10DotInt8Op(w=args[0], a=args[1], b=args[2])
        return x86.AVX10DotInt8Op(w=args[0], a=args[1], b=args[2])
    # CHECK-LABEL: func @avx10_op(
    # CHECK-SAME: %[[W:.+]]: vector<16xi32>, %[[A:.+]]: vector<64xi8>,
    # CHECK-SAME: %[[B:.+]]: vector<64xi8>) -> vector<16xi32> {
    # CHECK: %[[VAL:.+]] = x86vector.avx10.dot.i8 %[[W]], %[[A]], %[[B]]
    # CHECK: %[[VAL:.+]] = x86.avx10.dot.i8 %[[W]], %[[A]], %[[B]]
    # CHECK: return %[[VAL]] : vector<16xi32>
    # CHECK: }
    print(module)