[mlir][x86] Rename x86vector to x86 (#183311)

Renames the 'x86vector' dialect to 'x86'.

This is the first PR in a series of cleanups around dialects targeting x86
platforms.
The new naming scheme is shorter, cleaner, and opens the possibility of
integrating other x86-specific operations that do not strictly fit a pure
vector representation. For example, the generalization will allow a future
merger of the AMX dialect into the x86 dialect, creating a one-stop collection
of x86 operations and boosting discoverability.
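Concretely, IR written against the dialect only changes its op prefix; op names, operands, and types stay the same. A before/after sketch using one of the ops touched in this diff:

```mlir
// Before the rename:
%dst = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>

// After the rename (same op, shorter 'x86' prefix):
%dst = x86.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
```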
Adam Siemieniuk
2026-02-26 11:21:58 +01:00
committed by GitHub
parent 6d7ec4b7c3
commit 67ac275fee
80 changed files with 750 additions and 768 deletions


@@ -105,7 +105,7 @@ available, should be contacted first, as they're more active in those areas.
 * arm_sve Dialect ([@banach-space](https://github.com/banach-space))
 * ArmSME Dialect ([@banach-space](https://github.com/banach-space))
 * amx Dialect ([@adam-smnk](https://github.com/adam-smnk))
-* x86vector Dialect ([@adam-smnk](https://github.com/adam-smnk))
+* x86 Dialect ([@adam-smnk](https://github.com/adam-smnk))
 * vcix Dialect ([@mshockwave](https://github.com/mshockwave))
 #### Paradigm Dialects


@@ -6,7 +6,7 @@ overall flow is two-stage:
 1. **conversion** of the IR to a set of dialects translatable to LLVM IR, for
    example [LLVM Dialect](Dialects/LLVM.md) or one of the hardware-specific
    dialects derived from LLVM IR intrinsics such as [AMX](Dialects/AMX.md),
-   [X86Vector](Dialects/X86Vector.md) or [ArmNeon](Dialects/ArmNeon.md);
+   [X86](Dialects/X86.md) or [ArmNeon](Dialects/ArmNeon.md);
 2. **translation** of MLIR dialects to LLVM IR.
 This flow allows the non-trivial transformation to be performed within MLIR


@@ -1,4 +1,4 @@
-//===-- mlir-c/Dialect/x86Vector.h - C API for x86Vector Dialect --*- C -*-===//
+//===-- mlir-c/Dialect/x86.h - C API for x86 Dialect --------------*- C -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
 // Exceptions.
@@ -7,8 +7,8 @@
 //
 //===----------------------------------------------------------------------===//
-#ifndef MLIR_C_DIALECT_X86VECTOR_H
-#define MLIR_C_DIALECT_X86VECTOR_H
+#ifndef MLIR_C_DIALECT_X86_H
+#define MLIR_C_DIALECT_X86_H
 #include "mlir-c/IR.h"
@@ -16,10 +16,10 @@
 extern "C" {
 #endif
-MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(X86Vector, x86vector);
+MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(X86, x86);
 #ifdef __cplusplus
 }
 #endif
-#endif // MLIR_C_DIALECT_X86VECTOR_H
+#endif // MLIR_C_DIALECT_X86_H


@@ -1521,7 +1521,7 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
     operations. The lowering pass provides several options to control
     the kinds of optimizations that are allowed. It also provides options
     that enable the use of one or more architectural-specific dialects
-    (AMX, X86Vector, ArmNeon, ArmSVE, etc.) in combination with the
+    (AMX, X86, ArmNeon, ArmSVE, etc.) in combination with the
     architectural-neutral vector dialect lowering.
   }];
@@ -1564,10 +1564,9 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
           "bool", /*default=*/"false",
           "Enables the use of Arm FEAT_BF16 instructions while lowering "
           "the vector dialect.">,
-    Option<"x86Vector", "enable-x86vector",
+    Option<"x86", "enable-x86",
           "bool", /*default=*/"false",
-          "Enables the use of X86Vector dialect while lowering the vector "
-          "dialect.">,
+          "Enables the use of X86 dialect while lowering the vector dialect.">,
    Option<"vectorContractLowering", "vector-contract-lowering",
           "vector::VectorContractLowering",
           /*default=*/"vector::VectorContractLowering::Dot",
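For users of the `convert-vector-to-llvm` pass, the option rename means existing invocations need updating. A hypothetical command line (tool name and pass-option syntax assumed to follow the usual `mlir-opt` conventions; not part of this diff):

```shell
# Before: the pass option was named after the old dialect (hypothetical invocation).
mlir-opt input.mlir -convert-vector-to-llvm="enable-x86vector"

# After: the renamed option matches the new 'x86' dialect name.
mlir-opt input.mlir -convert-vector-to-llvm="enable-x86"
```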


@@ -42,5 +42,5 @@ add_subdirectory(UB)
 add_subdirectory(Utils)
 add_subdirectory(Vector)
 add_subdirectory(WasmSSA)
-add_subdirectory(X86Vector)
+add_subdirectory(X86)
 add_subdirectory(XeGPU)


@@ -19,7 +19,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
-#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86/Transforms.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/TilingInterface.h"


@@ -118,9 +118,9 @@ struct SparsifierOptions : public PassPipelineOptions<SparsifierOptions> {
      desc("Enables the use of ArmSVE dialect while lowering the vector "
           "dialect"),
      init(false)};
-  PassOptions::Option<bool> x86Vector{
-      *this, "enable-x86vector",
-      desc("Enables the use of X86Vector dialect while lowering the vector "
+  PassOptions::Option<bool> x86{
+      *this, "enable-x86",
+      desc("Enables the use of X86 dialect while lowering the vector "
           "dialect"),
      init(false)};
@@ -169,7 +169,7 @@ struct SparsifierOptions : public PassPipelineOptions<SparsifierOptions> {
    opts.armNeon = armNeon;
    opts.armSVE = armSVE;
    opts.amx = amx;
-   opts.x86Vector = x86Vector;
+   opts.x86 = x86;
    return opts;
  }
};


@@ -0,0 +1,7 @@
+add_mlir_dialect(X86 x86)
+add_mlir_doc(X86 X86 Dialects/ -gen-dialect-doc -dialect=x86)
+add_mlir_interface(X86Interfaces)
+add_dependencies(MLIRX86IncGen MLIRX86InterfacesIncGen)
+add_subdirectory(TransformOps)


@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS X86TransformOps.td)
+mlir_tablegen(X86TransformOps.h.inc -gen-op-decls)
+mlir_tablegen(X86TransformOps.cpp.inc -gen-op-defs)
+add_mlir_dialect_tablegen_target(MLIRX86TransformOpsIncGen)


@@ -1,4 +1,4 @@
-//===- X86VectorTransformOps.h - X86Vector transform ops --------*- C++ -*-===//
+//===- X86TransformOps.h - X86 transform ops --------------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,26 +6,26 @@
 //
 //===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
-#define MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
+#ifndef MLIR_DIALECT_X86_TRANSFORMOPS_X86TRANSFORMOPS_H
+#define MLIR_DIALECT_X86_TRANSFORMOPS_X86TRANSFORMOPS_H
 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
 #include "mlir/IR/OpImplementation.h"
 //===----------------------------------------------------------------------===//
-// X86Vector Transform Operations
+// X86 Transform Operations
 //===----------------------------------------------------------------------===//
 #define GET_OP_CLASSES
-#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h.inc"
+#include "mlir/Dialect/X86/TransformOps/X86TransformOps.h.inc"
 namespace mlir {
 class DialectRegistry;
-namespace x86vector {
+namespace x86 {
 void registerTransformDialectExtension(DialectRegistry &registry);
-} // namespace x86vector
+} // namespace x86
 } // namespace mlir
-#endif // MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
+#endif // MLIR_DIALECT_X86_TRANSFORMOPS_X86TRANSFORMOPS_H


@@ -1,4 +1,4 @@
-//===- X86VectorTransformOps.td - X86Vector transform ops --*- tablegen -*-===//
+//===- X86TransformOps.td - X86 transform ops --------------*- tablegen -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
-#ifndef X86VECTOR_TRANSFORM_OPS
-#define X86VECTOR_TRANSFORM_OPS
+#ifndef X86_TRANSFORM_OPS
+#define X86_TRANSFORM_OPS
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
 include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
@@ -18,7 +18,7 @@ include "mlir/Dialect/Transform/IR/TransformTypes.td"
 include "mlir/IR/RegionKindInterface.td"
 def ApplyVectorContractToFMAPatternsOp : Op<Transform_Dialect,
-    "apply_patterns.x86vector.vector_contract_to_fma",
+    "apply_patterns.x86.vector_contract_to_fma",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
     Collect patterns to lower a F32 type vector.contract operation to a FMA.
@@ -28,7 +28,7 @@ def ApplyVectorContractToFMAPatternsOp : Op<Transform_Dialect,
 }
 def ApplyVectorContractToPackedTypeDotProductPatternsOp : Op<Transform_Dialect,
-    "apply_patterns.x86vector.vector_contract_to_packed_type_dot_product",
+    "apply_patterns.x86.vector_contract_to_packed_type_dot_product",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
     Collect patterns to lower a BF16/Int8 type vector.contract operation
@@ -39,7 +39,7 @@ def ApplyVectorContractToPackedTypeDotProductPatternsOp : Op<Transform_Dialect,
 }
 def ApplyVectorContractBF16ToFMAPatternsOp : Op<Transform_Dialect,
-    "apply_patterns.x86vector.vector_contract_bf16_to_fma",
+    "apply_patterns.x86.vector_contract_bf16_to_fma",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
     Collect patterns to lower a BF16 type vector.contract operation
@@ -50,7 +50,7 @@ def ApplyVectorContractBF16ToFMAPatternsOp : Op<Transform_Dialect,
 }
 def ApplySinkVectorProducerOpsPatternsOp : Op<Transform_Dialect,
-    "apply_patterns.x86vector.sink_vector_producer_ops",
+    "apply_patterns.x86.sink_vector_producer_ops",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
     Collect patterns to sink vector producer operations forward in a block to
@@ -61,10 +61,10 @@ def ApplySinkVectorProducerOpsPatternsOp : Op<Transform_Dialect,
 }
 def ApplyShuffleVectorFMAOpsPatternsOp : Op<Transform_Dialect,
-    "apply_patterns.x86vector.shuffle_vector_fma_ops",
+    "apply_patterns.x86.shuffle_vector_fma_ops",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
-    Collect patterns to shuffle FMAs with x86vector operations as operands
+    Collect patterns to shuffle FMAs with x86 operations as operands
     such that FMAs are grouped with respect to odd/even packed index.
   }];
@@ -72,5 +72,4 @@ def ApplyShuffleVectorFMAOpsPatternsOp : Op<Transform_Dialect,
 }
-#endif // X86VECTOR_TRANSFORM_OPS
+#endif // X86_TRANSFORM_OPS


@@ -1,4 +1,4 @@
-//=- Transforms.h - X86Vector Dialect Transformation Entrypoints -*- C++ -*-=//
+//=- Transforms.h - X86 Dialect Transformation Entrypoints --------*- C++ -*-=//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMS_H
-#define MLIR_DIALECT_X86VECTOR_TRANSFORMS_H
+#ifndef MLIR_DIALECT_X86_TRANSFORMS_H
+#define MLIR_DIALECT_X86_TRANSFORMS_H
 #include "mlir/IR/Value.h"
@@ -18,7 +18,7 @@ class LLVMConversionTarget;
 class LLVMTypeConverter;
 class RewritePatternSet;
-namespace x86vector {
+namespace x86 {
 /// Helper class to factor out the creation and extraction of masks from nibs.
 struct MaskHelper {
@@ -100,7 +100,7 @@ void populateVectorContractBF16ToFMAPatterns(RewritePatternSet &patterns);
 // range by placing them at their earliest legal use site.
 void populateSinkVectorProducerOpsPatterns(RewritePatternSet &patterns);
-// Shuffles FMAs with x86vector operations as operands such that FMAs are
+// Shuffles FMAs with x86 operations as operands such that FMAs are
 // grouped with respect to odd/even packed index.
 void populateShuffleVectorFMAOpsPatterns(RewritePatternSet &patterns);
@@ -196,17 +196,17 @@ void populateSpecializedTransposeLoweringPatterns(
     int benefit = 10);
 } // namespace avx2
-} // namespace x86vector
-/// Collect a set of patterns to lower X86Vector ops to ops that map to LLVM
+} // namespace x86
+/// Collect a set of patterns to lower X86 ops to ops that map to LLVM
 /// intrinsics.
-void populateX86VectorLegalizeForLLVMExportPatterns(
+void populateX86LegalizeForLLVMExportPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns);
-/// Configure the target to support lowering X86Vector ops to ops that map to
+/// Configure the target to support lowering X86 ops to ops that map to
 /// LLVM intrinsics.
-void configureX86VectorLegalizeForExportTarget(LLVMConversionTarget &target);
+void configureX86LegalizeForExportTarget(LLVMConversionTarget &target);
 } // namespace mlir
-#endif // MLIR_DIALECT_X86VECTOR_TRANSFORMS_H
+#endif // MLIR_DIALECT_X86_TRANSFORMS_H


@@ -1,4 +1,4 @@
-//===- X86VectorUtils.h - X86Vector Utilities -------------------*- C++ -*-===//
+//===- X86Utils.h - X86 Utilities -------------------------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_X86VECTOR_UTILS_X86VECTORUTILS_H_
-#define MLIR_DIALECT_X86VECTOR_UTILS_X86VECTORUTILS_H_
+#ifndef MLIR_DIALECT_X86_UTILS_X86UTILS_H_
+#define MLIR_DIALECT_X86_UTILS_X86UTILS_H_
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/PatternMatch.h"
@@ -22,7 +22,7 @@ namespace mlir {
 class AffineMap;
 class Operation;
-namespace x86vector {
+namespace x86 {
 // Return true if the operation is in VNNI layout.
 // Optionally, the check can be constrained to a specific VNNI blocking factor.
@@ -63,7 +63,7 @@ LogicalResult shuffleBeforeWriteLikeOp(PatternRewriter &rewriter,
                                        Operation *opA, Operation *opB,
                                        int64_t nonUnitDimAcc, VectorType accTy);
-} // namespace x86vector
+} // namespace x86
 } // namespace mlir
-#endif // MLIR_DIALECT_X86VECTOR_UTILS_X86VECTORUTILS_H_
+#endif // MLIR_DIALECT_X86_UTILS_X86UTILS_H_


@@ -1,4 +1,4 @@
-//===-- X86VectorOps.td - X86Vector dialect operation defs -*- tablegen -*-===//
+//===-- X86Ops.td - X86 dialect operation defs -------------*- tablegen -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,25 +6,25 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file defines the basic operations for the X86Vector dialect.
+// This file defines the basic operations for the X86 dialect.
 //
 //===----------------------------------------------------------------------===//
-#ifndef X86VECTOR_OPS
-#define X86VECTOR_OPS
+#ifndef X86_OPS
+#define X86_OPS
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
-include "mlir/Dialect/X86Vector/X86VectorInterfaces.td"
+include "mlir/Dialect/X86/X86Interfaces.td"
 //===----------------------------------------------------------------------===//
-// X86Vector dialect definition
+// X86 dialect definition
 //===----------------------------------------------------------------------===//
-def X86Vector_Dialect : Dialect {
-  let name = "x86vector";
-  let cppNamespace = "::mlir::x86vector";
+def X86_Dialect : Dialect {
+  let name = "x86";
+  let cppNamespace = "::mlir::x86";
 }
 //===----------------------------------------------------------------------===//
@@ -33,7 +33,7 @@ def X86Vector_Dialect : Dialect {
 // Operation that is part of the input dialect.
 class AVX512_Op<string mnemonic, list<Trait> traits = []> :
-  Op<X86Vector_Dialect, "avx512." # mnemonic, traits> {}
+  Op<X86_Dialect, "avx512." # mnemonic, traits> {}
 //----------------------------------------------------------------------------//
 // MaskCompressOp
@@ -279,7 +279,7 @@ def DotBF16Op : AVX512_Op<"dot", [Pure,
     Example:
     ```mlir
-    %dst = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
+    %dst = x86.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
     ```
   }];
   let arguments = (ins VectorOfLengthAndType<[4, 8, 16], [F32]>:$src,
@@ -323,7 +323,7 @@ def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
     Example:
     ```mlir
-    %dst = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
+    %dst = x86.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
     ```
   }];
   let arguments = (ins VectorOfLengthAndType<[8, 16], [F32]>:$a);
@@ -349,7 +349,7 @@ def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
 // Operation that is part of the input dialect.
 class AVX10_Op<string mnemonic, list<Trait> traits = []> :
-  Op<X86Vector_Dialect, "avx10." # mnemonic, traits> {}
+  Op<X86_Dialect, "avx10." # mnemonic, traits> {}
 //----------------------------------------------------------------------------//
 // AVX10 Int8 Dot
@@ -376,7 +376,7 @@ def AVX10DotInt8Op : AVX10_Op<"dot.i8", [Pure,
     Example:
     ```mlir
-    %dst = x86vector.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
+    %dst = x86.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
     ```
   }];
   let arguments = (ins VectorOfLengthAndType<[16], [I32]>:$w,
@@ -401,13 +401,13 @@ def AVX10DotInt8Op : AVX10_Op<"dot.i8", [Pure,
 // Operation that is part of the input dialect.
 class AVX_Op<string mnemonic, list<Trait> traits = []> :
-  Op<X86Vector_Dialect, "avx." # mnemonic, traits> {}
+  Op<X86_Dialect, "avx." # mnemonic, traits> {}
 // Operation that may be part of the input dialect, but whose
 // form is somewhere between the user view of the operation
 // and the actual lower level intrinsic in LLVM IR.
 class AVX_LowOp<string mnemonic, list<Trait> traits = []> :
-  Op<X86Vector_Dialect, "avx.intr." # mnemonic, traits> {}
+  Op<X86_Dialect, "avx.intr." # mnemonic, traits> {}
 //----------------------------------------------------------------------------//
 // AVX Rsqrt
@@ -448,7 +448,7 @@ def DotOp : AVX_LowOp<"dot", [Pure,
     Example:
     ```mlir
-    %0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
+    %0 = x86.avx.intr.dot %a, %b : vector<8xf32>
     %1 = vector.extract %0[%i0] : f32 from vector<8xf32>
     %2 = vector.extract %0[%i4] : f32 from vector<8xf32>
     %d = arith.addf %1, %2 : f32
@@ -500,7 +500,7 @@ def DotInt8Op : AVX_Op<"dot.i8", [Pure,
     Example:
     ```mlir
-    %dst = x86vector.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
+    %dst = x86.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
     ```
   }];
   let arguments = (ins VectorOfLengthAndType<[4, 8], [I32]>:$w,
@@ -543,8 +543,8 @@ def BcstToPackedF32Op
     Example:
     ```mlir
-    %dst = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
-    %dst = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
+    %dst = x86.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
+    %dst = x86.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
     ```
   }];
   let arguments = (ins MemRefOf<[BF16, F16]>:$a);
@@ -595,8 +595,8 @@ def CvtPackedEvenIndexedToF32Op
     Example:
     ```mlir
-    %dst = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
-    %dst = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
+    %dst = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+    %dst = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
     ```
   }];
   let arguments = (ins MemRefOf<[BF16, F16]>:$a);
@@ -642,8 +642,8 @@ def CvtPackedOddIndexedToF32Op
     Example:
     ```mlir
-    %dst = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
-    %dst = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
+    %dst = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+    %dst = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
     ```
   }];
   let arguments = (ins MemRefOf<[BF16, F16]>:$a);
@@ -673,4 +673,4 @@ def CvtPackedOddIndexedToF32Op
         ::mlir::RewriterBase &rewriter);
   }];
 }
-#endif // X86VECTOR_OPS
+#endif // X86_OPS


@@ -1,4 +1,4 @@
-//===- X86VectorDialect.h - MLIR Dialect for X86Vector ----------*- C++ -*-===//
+//===- X86Dialect.h - MLIR Dialect for X86 ----------------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,12 +6,12 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file declares the Target dialect for X86Vector in MLIR.
+// This file declares the Target dialect for X86 in MLIR.
 //
 //===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_X86VECTOR_X86VECTORDIALECT_H_
-#define MLIR_DIALECT_X86VECTOR_X86VECTORDIALECT_H_
+#ifndef MLIR_DIALECT_X86_X86DIALECT_H_
+#define MLIR_DIALECT_X86_X86DIALECT_H_
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
@@ -25,11 +25,11 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 /// Include the generated interface declarations.
-#include "mlir/Dialect/X86Vector/X86VectorInterfaces.h.inc"
-#include "mlir/Dialect/X86Vector/X86VectorDialect.h.inc"
+#include "mlir/Dialect/X86/X86Interfaces.h.inc"
+#include "mlir/Dialect/X86/X86Dialect.h.inc"
 #define GET_OP_CLASSES
-#include "mlir/Dialect/X86Vector/X86Vector.h.inc"
-#endif // MLIR_DIALECT_X86VECTOR_X86VECTORDIALECT_H_
+#include "mlir/Dialect/X86/X86.h.inc"
+#endif // MLIR_DIALECT_X86_X86DIALECT_H_


@@ -1,4 +1,4 @@
-//===- X86VectorInterfaces.td - X86Vector interfaces -------*- tablegen -*-===//
+//===- X86Interfaces.td - X86 interfaces -------------------*- tablegen -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,12 +6,12 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file defines interfaces for the X86Vector dialect.
+// This file defines interfaces for the X86 dialect.
 //
 //===----------------------------------------------------------------------===//
-#ifndef X86VECTOR_INTERFACES
-#define X86VECTOR_INTERFACES
+#ifndef X86_INTERFACES
+#define X86_INTERFACES
 include "mlir/IR/Interfaces.td"
 include "mlir/Dialect/LLVMIR/LLVMInterfaces.td"
@@ -25,7 +25,7 @@ def X86IntrinsicOpInterface
   let description = [{
     A wrapper interface for operations representing x86 LLVM intrinsics.
   }];
-  let cppNamespace = "::mlir::x86vector";
+  let cppNamespace = "::mlir::x86";
 }
-#endif // X86VECTOR_INTERFACES
+#endif // X86_INTERFACES


@@ -1,7 +0,0 @@
-add_mlir_dialect(X86Vector x86vector)
-add_mlir_doc(X86Vector X86Vector Dialects/ -gen-dialect-doc -dialect=x86vector)
-add_mlir_interface(X86VectorInterfaces)
-add_dependencies(MLIRX86VectorIncGen MLIRX86VectorInterfacesIncGen)
-add_subdirectory(TransformOps)


@@ -1,4 +0,0 @@
-set(LLVM_TARGET_DEFINITIONS X86VectorTransformOps.td)
-mlir_tablegen(X86VectorTransformOps.h.inc -gen-op-decls)
-mlir_tablegen(X86VectorTransformOps.cpp.inc -gen-op-defs)
-add_mlir_dialect_tablegen_target(MLIRX86VectorTransformOpsIncGen)


@@ -517,13 +517,13 @@ add_mlir_upstream_c_api_library(MLIRCAPIWasmSSA
   MLIRWasmSSADialect
 )
-add_mlir_upstream_c_api_library(MLIRCAPIX86Vector
-  X86Vector.cpp
+add_mlir_upstream_c_api_library(MLIRCAPIX86
+  X86.cpp
   PARTIAL_SOURCES_INTENDED
   LINK_LIBS PUBLIC
   MLIRCAPIIR
-  MLIRX86VectorDialect
+  MLIRX86Dialect
 )
 add_mlir_upstream_c_api_library(MLIRCAPIXeGPU


@@ -1,4 +1,4 @@
-//===- X86Vector.cpp - C Interface for X86Vector dialect ------------------===//
+//===- X86.cpp - C Interface for X86 dialect ------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,9 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
-#include "mlir-c/Dialect/X86Vector.h"
+#include "mlir-c/Dialect/X86.h"
 #include "mlir/CAPI/Registration.h"
-#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+#include "mlir/Dialect/X86/X86Dialect.h"
-MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(X86Vector, x86vector,
-                                      mlir::x86vector::X86VectorDialect)
+MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(X86, x86, mlir::x86::X86Dialect)

View File

@@ -40,6 +40,6 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
MLIRArmSVETransforms MLIRArmSVETransforms
MLIRAMXDialect MLIRAMXDialect
MLIRAMXTransforms MLIRAMXTransforms
MLIRX86VectorDialect MLIRX86Dialect
MLIRX86VectorTransforms MLIRX86Transforms
) )

View File

@@ -22,8 +22,8 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/X86Vector/Transforms.h" #include "mlir/Dialect/X86/Transforms.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h" #include "mlir/Dialect/X86/X86Dialect.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -53,8 +53,8 @@ struct ConvertVectorToLLVMPass
registry.insert<arm_sve::ArmSVEDialect>(); registry.insert<arm_sve::ArmSVEDialect>();
if (amx) if (amx)
registry.insert<amx::AMXDialect>(); registry.insert<amx::AMXDialect>();
if (x86Vector) if (x86)
registry.insert<x86vector::X86VectorDialect>(); registry.insert<x86::X86Dialect>();
} }
void runOnOperation() override; void runOnOperation() override;
}; };
@@ -140,9 +140,9 @@ void ConvertVectorToLLVMPass::runOnOperation() {
configureAMXLegalizeForExportTarget(target); configureAMXLegalizeForExportTarget(target);
populateAMXLegalizeForLLVMExportPatterns(converter, patterns); populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
} }
if (x86Vector) { if (x86) {
configureX86VectorLegalizeForExportTarget(target); configureX86LegalizeForExportTarget(target);
populateX86VectorLegalizeForLLVMExportPatterns(converter, patterns); populateX86LegalizeForLLVMExportPatterns(converter, patterns);
} }
if (failed( if (failed(

View File

@@ -42,7 +42,7 @@ add_subdirectory(UB)
add_subdirectory(Utils) add_subdirectory(Utils)
add_subdirectory(Vector) add_subdirectory(Vector)
add_subdirectory(WasmSSA) add_subdirectory(WasmSSA)
add_subdirectory(X86Vector) add_subdirectory(X86)
add_subdirectory(XeGPU) add_subdirectory(XeGPU)
set(LLVM_OPTIONAL_SOURCES set(LLVM_OPTIONAL_SOURCES

View File

@@ -21,6 +21,6 @@ add_mlir_dialect_library(MLIRMathTransforms
MLIRSCFDialect MLIRSCFDialect
MLIRPass MLIRPass
MLIRTransforms MLIRTransforms
MLIRX86VectorDialect MLIRX86Dialect
MLIRVectorDialect MLIRVectorDialect
) )

View File

@@ -22,7 +22,7 @@
#include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h" #include "mlir/Dialect/X86/X86Dialect.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
@@ -1740,7 +1740,7 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
// Compute an approximate result. // Compute an approximate result.
Value yApprox = handleMultidimensionalVectors( Value yApprox = handleMultidimensionalVectors(
builder, op->getOperands(), 8, [&builder](ValueRange operands) -> Value { builder, op->getOperands(), 8, [&builder](ValueRange operands) -> Value {
return x86vector::RsqrtOp::create(builder, operands); return x86::RsqrtOp::create(builder, operands);
}); });
// Do a single step of Newton-Raphson iteration to improve the approximation. // Do a single step of Newton-Raphson iteration to improve the approximation.

View File

@@ -18,5 +18,5 @@ add_mlir_dialect_library(MLIRVectorTransformOps
MLIRTransformDialect MLIRTransformDialect
MLIRVectorDialect MLIRVectorDialect
MLIRVectorToSCF MLIRVectorToSCF
MLIRX86VectorTransforms MLIRX86Transforms
) )

View File

@@ -18,7 +18,7 @@
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/X86Vector/Transforms.h" #include "mlir/Dialect/X86/Transforms.h"
using namespace mlir; using namespace mlir;
using namespace mlir::vector; using namespace mlir::vector;
@@ -191,12 +191,10 @@ void transform::ApplyLowerTransposePatternsOp::populatePatterns(
vector::populateVectorTransposeLoweringPatterns(patterns, vector::populateVectorTransposeLoweringPatterns(patterns,
getLoweringStrategy()); getLoweringStrategy());
if (getAvx2LoweringStrategy()) { if (getAvx2LoweringStrategy()) {
auto avx2LoweringOptions = auto avx2LoweringOptions = x86::avx2::LoweringOptions().setTransposeOptions(
x86vector::avx2::LoweringOptions().setTransposeOptions( x86::avx2::TransposeLoweringOptions().lower4x8xf32(true).lower8x8xf32(
x86vector::avx2::TransposeLoweringOptions() true));
.lower4x8xf32(true) x86::avx2::populateSpecializedTransposeLoweringPatterns(
.lower8x8xf32(true));
x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
patterns, avx2LoweringOptions, /*benefit=*/10); patterns, avx2LoweringOptions, /*benefit=*/10);
} }
} }

View File

@@ -1,11 +1,11 @@
add_mlir_dialect_library(MLIRX86VectorDialect add_mlir_dialect_library(MLIRX86Dialect
X86VectorDialect.cpp X86Dialect.cpp
ADDITIONAL_HEADER_DIRS ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/X86Vector ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/X86
DEPENDS DEPENDS
MLIRX86VectorIncGen MLIRX86IncGen
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRIR MLIRIR

View File

@@ -1,4 +1,4 @@
//===- X86VectorDialect.cpp - MLIR X86Vector ops implementation -----------===// //===- X86Dialect.cpp - MLIR X86 ops implementation -----------------------===//
// //
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information. // See https://llvm.org/LICENSE.txt for license information.
@@ -6,25 +6,25 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// //
// This file implements the X86Vector dialect and its operations. // This file implements the X86 dialect and its operations.
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "mlir/Dialect/X86Vector/X86VectorDialect.h" #include "mlir/Dialect/X86/X86Dialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/TypeUtilities.h"
using namespace mlir; using namespace mlir;
#include "mlir/Dialect/X86Vector/X86VectorInterfaces.cpp.inc" #include "mlir/Dialect/X86/X86Interfaces.cpp.inc"
#include "mlir/Dialect/X86Vector/X86VectorDialect.cpp.inc" #include "mlir/Dialect/X86/X86Dialect.cpp.inc"
void x86vector::X86VectorDialect::initialize() { void x86::X86Dialect::initialize() {
addOperations< addOperations<
#define GET_OP_LIST #define GET_OP_LIST
#include "mlir/Dialect/X86Vector/X86Vector.cpp.inc" #include "mlir/Dialect/X86/X86.cpp.inc"
>(); >();
} }
@@ -35,7 +35,7 @@ static Value getMemrefBuffPtr(Location loc, MemRefType type, Value buffer,
return memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, type); return memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, type);
} }
LogicalResult x86vector::MaskCompressOp::verify() { LogicalResult x86::MaskCompressOp::verify() {
if (getSrc() && getConstantSrc()) if (getSrc() && getConstantSrc())
return emitError("cannot use both src and constant_src"); return emitError("cannot use both src and constant_src");
@@ -49,7 +49,7 @@ LogicalResult x86vector::MaskCompressOp::verify() {
return success(); return success();
} }
SmallVector<Value> x86vector::MaskCompressOp::getIntrinsicOperands( SmallVector<Value> x86::MaskCompressOp::getIntrinsicOperands(
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter, ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) { RewriterBase &rewriter) {
auto loc = getLoc(); auto loc = getLoc();
@@ -71,9 +71,9 @@ SmallVector<Value> x86vector::MaskCompressOp::getIntrinsicOperands(
} }
SmallVector<Value> SmallVector<Value>
x86vector::DotOp::getIntrinsicOperands(ArrayRef<Value> operands, x86::DotOp::getIntrinsicOperands(ArrayRef<Value> operands,
const LLVMTypeConverter &typeConverter, const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) { RewriterBase &rewriter) {
SmallVector<Value> intrinsicOperands(operands); SmallVector<Value> intrinsicOperands(operands);
// Dot product of all elements, broadcasted to all elements. // Dot product of all elements, broadcasted to all elements.
Value scale = Value scale =
@@ -83,7 +83,7 @@ x86vector::DotOp::getIntrinsicOperands(ArrayRef<Value> operands,
return intrinsicOperands; return intrinsicOperands;
} }
SmallVector<Value> x86vector::BcstToPackedF32Op::getIntrinsicOperands( SmallVector<Value> x86::BcstToPackedF32Op::getIntrinsicOperands(
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter, ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) { RewriterBase &rewriter) {
Adaptor adaptor(operands, *this); Adaptor adaptor(operands, *this);
@@ -91,7 +91,7 @@ SmallVector<Value> x86vector::BcstToPackedF32Op::getIntrinsicOperands(
typeConverter, rewriter)}; typeConverter, rewriter)};
} }
SmallVector<Value> x86vector::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands( SmallVector<Value> x86::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter, ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) { RewriterBase &rewriter) {
Adaptor adaptor(operands, *this); Adaptor adaptor(operands, *this);
@@ -99,7 +99,7 @@ SmallVector<Value> x86vector::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
typeConverter, rewriter)}; typeConverter, rewriter)};
} }
SmallVector<Value> x86vector::CvtPackedOddIndexedToF32Op::getIntrinsicOperands( SmallVector<Value> x86::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter, ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) { RewriterBase &rewriter) {
Adaptor adaptor(operands, *this); Adaptor adaptor(operands, *this);
@@ -108,4 +108,4 @@ SmallVector<Value> x86vector::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
} }
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "mlir/Dialect/X86Vector/X86Vector.cpp.inc" #include "mlir/Dialect/X86/X86.cpp.inc"

View File

@@ -1,8 +1,8 @@
add_mlir_dialect_library(MLIRX86VectorTransformOps add_mlir_dialect_library(MLIRX86TransformOps
X86VectorTransformOps.cpp X86TransformOps.cpp
DEPENDS DEPENDS
MLIRX86VectorTransformOpsIncGen MLIRX86TransformOpsIncGen
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRIR MLIRIR
@@ -12,6 +12,6 @@ add_mlir_dialect_library(MLIRX86VectorTransformOps
MLIRSideEffectInterfaces MLIRSideEffectInterfaces
MLIRTransformDialect MLIRTransformDialect
MLIRTransformDialectUtils MLIRTransformDialectUtils
MLIRX86VectorDialect MLIRX86Dialect
MLIRX86VectorTransforms MLIRX86Transforms
) )

View File

@@ -1,4 +1,4 @@
//===- X86VectorTransformOps.cpp ------------------------------------------===// //===- X86TransformOps.cpp ------------------------------------------------===//
// //
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information. // See https://llvm.org/LICENSE.txt for license information.
@@ -6,45 +6,45 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h" #include "mlir/Dialect/X86/TransformOps/X86TransformOps.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/X86Vector/Transforms.h" #include "mlir/Dialect/X86/Transforms.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h" #include "mlir/Dialect/X86/X86Dialect.h"
#include "mlir/IR/OpImplementation.h" #include "mlir/IR/OpImplementation.h"
#include "mlir/IR/RegionKindInterface.h" #include "mlir/IR/RegionKindInterface.h"
using namespace mlir; using namespace mlir;
using namespace mlir::x86vector; using namespace mlir::x86;
using namespace mlir::transform; using namespace mlir::transform;
void mlir::transform::ApplyVectorContractToFMAPatternsOp::populatePatterns( void mlir::transform::ApplyVectorContractToFMAPatternsOp::populatePatterns(
RewritePatternSet &patterns) { RewritePatternSet &patterns) {
x86vector::populateVectorContractToFMAPatterns(patterns); x86::populateVectorContractToFMAPatterns(patterns);
} }
void mlir::transform::ApplyVectorContractToPackedTypeDotProductPatternsOp:: void mlir::transform::ApplyVectorContractToPackedTypeDotProductPatternsOp::
populatePatterns(RewritePatternSet &patterns) { populatePatterns(RewritePatternSet &patterns) {
x86vector::populateVectorContractToPackedTypeDotProductPatterns(patterns); x86::populateVectorContractToPackedTypeDotProductPatterns(patterns);
} }
void mlir::transform::ApplyVectorContractBF16ToFMAPatternsOp::populatePatterns( void mlir::transform::ApplyVectorContractBF16ToFMAPatternsOp::populatePatterns(
RewritePatternSet &patterns) { RewritePatternSet &patterns) {
x86vector::populateVectorContractBF16ToFMAPatterns(patterns); x86::populateVectorContractBF16ToFMAPatterns(patterns);
} }
void mlir::transform::ApplySinkVectorProducerOpsPatternsOp::populatePatterns( void mlir::transform::ApplySinkVectorProducerOpsPatternsOp::populatePatterns(
RewritePatternSet &patterns) { RewritePatternSet &patterns) {
x86vector::populateSinkVectorProducerOpsPatterns(patterns); x86::populateSinkVectorProducerOpsPatterns(patterns);
} }
void mlir::transform::ApplyShuffleVectorFMAOpsPatternsOp::populatePatterns( void mlir::transform::ApplyShuffleVectorFMAOpsPatternsOp::populatePatterns(
RewritePatternSet &patterns) { RewritePatternSet &patterns) {
x86vector::populateShuffleVectorFMAOpsPatterns(patterns); x86::populateShuffleVectorFMAOpsPatterns(patterns);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@@ -52,28 +52,26 @@ void mlir::transform::ApplyShuffleVectorFMAOpsPatternsOp::populatePatterns(
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
namespace { namespace {
class X86VectorTransformDialectExtension class X86TransformDialectExtension
: public transform::TransformDialectExtension< : public transform::TransformDialectExtension<
X86VectorTransformDialectExtension> { X86TransformDialectExtension> {
public: public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(X86TransformDialectExtension)
X86VectorTransformDialectExtension)
X86VectorTransformDialectExtension() { X86TransformDialectExtension() {
declareGeneratedDialect<x86vector::X86VectorDialect>(); declareGeneratedDialect<x86::X86Dialect>();
declareGeneratedDialect<LLVM::LLVMDialect>(); declareGeneratedDialect<LLVM::LLVMDialect>();
registerTransformOps< registerTransformOps<
#define GET_OP_LIST #define GET_OP_LIST
#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc" #include "mlir/Dialect/X86/TransformOps/X86TransformOps.cpp.inc"
>(); >();
} }
}; };
} // namespace } // namespace
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc" #include "mlir/Dialect/X86/TransformOps/X86TransformOps.cpp.inc"
void mlir::x86vector::registerTransformDialectExtension( void mlir::x86::registerTransformDialectExtension(DialectRegistry &registry) {
DialectRegistry &registry) { registry.addExtensions<X86TransformDialectExtension>();
registry.addExtensions<X86VectorTransformDialectExtension>();
} }

View File

@@ -15,20 +15,21 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/Transforms.h" #include "mlir/Dialect/X86/Transforms.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Format.h" #include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/FormatVariadic.h"
using namespace mlir; using namespace mlir;
using namespace mlir::vector; using namespace mlir::vector;
using namespace mlir::x86vector; using namespace mlir::x86;
using namespace mlir::x86vector::avx2; using namespace mlir::x86::avx2;
using namespace mlir::x86vector::avx2::inline_asm; using namespace mlir::x86::avx2::inline_asm;
using namespace mlir::x86vector::avx2::intrin; using namespace mlir::x86::avx2::intrin;
Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm( Value mlir::x86::avx2::inline_asm::mm256BlendPsAsm(ImplicitLocOpBuilder &b,
ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) { Value v1, Value v2,
uint8_t mask) {
auto asmDialectAttr = auto asmDialectAttr =
LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_Intel); LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_Intel);
const auto *asmTp = "vblendps $0, $1, $2, {0}"; const auto *asmTp = "vblendps $0, $1, $2, {0}";
@@ -45,14 +46,14 @@ Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm(
return asmOp.getResult(0); return asmOp.getResult(0);
} }
Value mlir::x86vector::avx2::intrin::mm256UnpackLoPs(ImplicitLocOpBuilder &b, Value mlir::x86::avx2::intrin::mm256UnpackLoPs(ImplicitLocOpBuilder &b,
Value v1, Value v2) { Value v1, Value v2) {
return vector::ShuffleOp::create(b, v1, v2, return vector::ShuffleOp::create(b, v1, v2,
ArrayRef<int64_t>{0, 8, 1, 9, 4, 12, 5, 13}); ArrayRef<int64_t>{0, 8, 1, 9, 4, 12, 5, 13});
} }
Value mlir::x86vector::avx2::intrin::mm256UnpackHiPs(ImplicitLocOpBuilder &b, Value mlir::x86::avx2::intrin::mm256UnpackHiPs(ImplicitLocOpBuilder &b,
Value v1, Value v2) { Value v1, Value v2) {
return vector::ShuffleOp::create( return vector::ShuffleOp::create(
b, v1, v2, ArrayRef<int64_t>{2, 10, 3, 11, 6, 14, 7, 15}); b, v1, v2, ArrayRef<int64_t>{2, 10, 3, 11, 6, 14, 7, 15});
} }
@@ -60,9 +61,8 @@ Value mlir::x86vector::avx2::intrin::mm256UnpackHiPs(ImplicitLocOpBuilder &b,
/// Takes an 8 bit mask, 2 bit for each position of a[0, 3) **and** b[0, 4): /// Takes an 8 bit mask, 2 bit for each position of a[0, 3) **and** b[0, 4):
/// 0:127 | 128:255 /// 0:127 | 128:255
/// b01 b23 C8 D8 | b01+4 b23+4 C8+4 D8+4 /// b01 b23 C8 D8 | b01+4 b23+4 C8+4 D8+4
Value mlir::x86vector::avx2::intrin::mm256ShufflePs(ImplicitLocOpBuilder &b, Value mlir::x86::avx2::intrin::mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1,
Value v1, Value v2, Value v2, uint8_t mask) {
uint8_t mask) {
uint8_t b01, b23, b45, b67; uint8_t b01, b23, b45, b67;
MaskHelper::extractShuffle(mask, b01, b23, b45, b67); MaskHelper::extractShuffle(mask, b01, b23, b45, b67);
SmallVector<int64_t> shuffleMask = { SmallVector<int64_t> shuffleMask = {
@@ -76,8 +76,9 @@ Value mlir::x86vector::avx2::intrin::mm256ShufflePs(ImplicitLocOpBuilder &b,
// a[0:127] or a[128:255] or b[0:127] or b[128:255] // a[0:127] or a[128:255] or b[0:127] or b[128:255]
// 0 1 2 3 // 0 1 2 3
// imm[0:1] out of imm[4:7]. // imm[0:1] out of imm[4:7].
Value mlir::x86vector::avx2::intrin::mm256Permute2f128Ps( Value mlir::x86::avx2::intrin::mm256Permute2f128Ps(ImplicitLocOpBuilder &b,
ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) { Value v1, Value v2,
uint8_t mask) {
SmallVector<int64_t> shuffleMask; SmallVector<int64_t> shuffleMask;
auto appendToMask = [&](uint8_t control) { auto appendToMask = [&](uint8_t control) {
if (control == 0) if (control == 0)
@@ -99,9 +100,8 @@ Value mlir::x86vector::avx2::intrin::mm256Permute2f128Ps(
} }
/// If bit i of `mask` is zero, take f32@i from v1 else take it from v2. /// If bit i of `mask` is zero, take f32@i from v1 else take it from v2.
Value mlir::x86vector::avx2::intrin::mm256BlendPs(ImplicitLocOpBuilder &b, Value mlir::x86::avx2::intrin::mm256BlendPs(ImplicitLocOpBuilder &b, Value v1,
Value v1, Value v2, Value v2, uint8_t mask) {
uint8_t mask) {
SmallVector<int64_t, 8> shuffleMask; SmallVector<int64_t, 8> shuffleMask;
for (int i = 0; i < 8; ++i) { for (int i = 0; i < 8; ++i) {
bool isSet = mask & (1 << i); bool isSet = mask & (1 << i);
@@ -111,8 +111,8 @@ Value mlir::x86vector::avx2::intrin::mm256BlendPs(ImplicitLocOpBuilder &b,
} }
/// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model. /// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model.
void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib, void mlir::x86::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
MutableArrayRef<Value> vs) { MutableArrayRef<Value> vs) {
#ifndef NDEBUG #ifndef NDEBUG
auto vt = VectorType::get({8}, Float32Type::get(ib.getContext())); auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
assert(vs.size() == 4 && "expects 4 vectors"); assert(vs.size() == 4 && "expects 4 vectors");
@@ -136,8 +136,8 @@ void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
} }
/// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model. /// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model.
void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib, void mlir::x86::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
MutableArrayRef<Value> vs) { MutableArrayRef<Value> vs) {
auto vt = VectorType::get({8}, Float32Type::get(ib.getContext())); auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
(void)vt; (void)vt;
assert(vs.size() == 8 && "expects 8 vectors"); assert(vs.size() == 8 && "expects 8 vectors");
@@ -284,7 +284,7 @@ private:
LoweringOptions loweringOptions; LoweringOptions loweringOptions;
}; };
void mlir::x86vector::avx2::populateSpecializedTransposeLoweringPatterns( void mlir::x86::avx2::populateSpecializedTransposeLoweringPatterns(
RewritePatternSet &patterns, LoweringOptions options, int benefit) { RewritePatternSet &patterns, LoweringOptions options, int benefit) {
patterns.add<TransposeOpLowering>(options, patterns.getContext(), benefit); patterns.add<TransposeOpLowering>(options, patterns.getContext(), benefit);
} }

View File

@@ -1,4 +1,4 @@
add_mlir_dialect_library(MLIRX86VectorTransforms add_mlir_dialect_library(MLIRX86Transforms
AVXTranspose.cpp AVXTranspose.cpp
LegalizeForLLVMExport.cpp LegalizeForLLVMExport.cpp
VectorContractToFMA.cpp VectorContractToFMA.cpp
@@ -15,6 +15,6 @@ add_mlir_dialect_library(MLIRX86VectorTransforms
MLIRLLVMDialect MLIRLLVMDialect
MLIRVectorDialect MLIRVectorDialect
MLIRVectorUtils MLIRVectorUtils
MLIRX86VectorDialect MLIRX86Dialect
MLIRX86VectorUtils MLIRX86Utils
) )

View File

@@ -1,4 +1,4 @@
//===- LegalizeForLLVMExport.cpp - Prepare X86Vector for LLVM translation -===// //===- LegalizeForLLVMExport.cpp - Prepare X86 for LLVM translation -------===//
// //
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information. // See https://llvm.org/LICENSE.txt for license information.
@@ -6,26 +6,26 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "mlir/Dialect/X86Vector/Transforms.h" #include "mlir/Dialect/X86/Transforms.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h" #include "mlir/Dialect/X86/X86Dialect.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
using namespace mlir; using namespace mlir;
using namespace mlir::x86vector; using namespace mlir::x86;
namespace { namespace {
/// Generic one-to-one conversion of simply mappable operations into calls /// Generic one-to-one conversion of simply mappable operations into calls
/// to their respective LLVM intrinsics. /// to their respective LLVM intrinsics.
struct X86IntrinsicOpConversion struct X86IntrinsicOpConversion
: public ConvertOpInterfaceToLLVMPattern<x86vector::X86IntrinsicOp> { : public ConvertOpInterfaceToLLVMPattern<x86::X86IntrinsicOp> {
using ConvertOpInterfaceToLLVMPattern::ConvertOpInterfaceToLLVMPattern; using ConvertOpInterfaceToLLVMPattern::ConvertOpInterfaceToLLVMPattern;
LogicalResult LogicalResult
matchAndRewrite(x86vector::X86IntrinsicOp op, ArrayRef<Value> operands, matchAndRewrite(x86::X86IntrinsicOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
const LLVMTypeConverter &typeConverter = *getTypeConverter(); const LLVMTypeConverter &typeConverter = *getTypeConverter();
return LLVM::detail::intrinsicRewrite( return LLVM::detail::intrinsicRewrite(
@@ -37,13 +37,12 @@ struct X86IntrinsicOpConversion
} // namespace } // namespace
/// Populate the given list with patterns that convert from X86Vector to LLVM. /// Populate the given list with patterns that convert from X86 to LLVM.
void mlir::populateX86VectorLegalizeForLLVMExportPatterns( void mlir::populateX86LegalizeForLLVMExportPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) { const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<X86IntrinsicOpConversion>(converter); patterns.add<X86IntrinsicOpConversion>(converter);
} }
void mlir::configureX86VectorLegalizeForExportTarget( void mlir::configureX86LegalizeForExportTarget(LLVMConversionTarget &target) {
LLVMConversionTarget &target) { target.addIllegalDialect<X86Dialect>();
target.addIllegalDialect<X86VectorDialect>();
} }

View File

@@ -7,8 +7,8 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/X86Vector/Transforms.h" #include "mlir/Dialect/X86/Transforms.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h" #include "mlir/Dialect/X86/X86Dialect.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
@@ -17,17 +17,17 @@
using namespace mlir; using namespace mlir;
using namespace mlir::vector; using namespace mlir::vector;
using namespace mlir::x86vector; using namespace mlir::x86;
namespace { namespace {
// Validates whether the given operation is an x86vector operation and has only // Validates whether the given operation is an x86 operation and has only
// one consumer. // one consumer.
static bool validateFMAOperands(Value op) { static bool validateFMAOperands(Value op) {
if (auto cvt = op.getDefiningOp<x86vector::CvtPackedEvenIndexedToF32Op>()) if (auto cvt = op.getDefiningOp<x86::CvtPackedEvenIndexedToF32Op>())
return cvt.getResult().hasOneUse(); return cvt.getResult().hasOneUse();
if (auto bcst = op.getDefiningOp<x86vector::BcstToPackedF32Op>()) if (auto bcst = op.getDefiningOp<x86::BcstToPackedF32Op>())
return bcst.getResult().hasOneUse(); return bcst.getResult().hasOneUse();
return false; return false;
@@ -36,14 +36,14 @@ static bool validateFMAOperands(Value op) {
// Validates the vector.fma operation on the following conditions: // Validates the vector.fma operation on the following conditions:
// (i) one of the lhs or rhs defining operation should be // (i) one of the lhs or rhs defining operation should be
// CvtPackedEvenIndexedToF32Op, (ii) the lhs or rhs defining operation should be // CvtPackedEvenIndexedToF32Op, (ii) the lhs or rhs defining operation should be
// an x86vector operation and has only one consumer, (iii) all operations // an x86 operation and has only one consumer, (iii) all operations
// are in the same block, and (iv) this FMA has only one user. // are in the same block, and (iv) this FMA has only one user.
static bool validateVectorFMAOp(vector::FMAOp fmaOp) { static bool validateVectorFMAOp(vector::FMAOp fmaOp) {
Value lhs = fmaOp.getLhs(); Value lhs = fmaOp.getLhs();
Value rhs = fmaOp.getRhs(); Value rhs = fmaOp.getRhs();
if (!isa<x86vector::CvtPackedEvenIndexedToF32Op>(lhs.getDefiningOp()) && if (!isa<x86::CvtPackedEvenIndexedToF32Op>(lhs.getDefiningOp()) &&
!isa<x86vector::CvtPackedEvenIndexedToF32Op>(rhs.getDefiningOp())) !isa<x86::CvtPackedEvenIndexedToF32Op>(rhs.getDefiningOp()))
return false; return false;
if (!validateFMAOperands(lhs) || !validateFMAOperands(rhs)) if (!validateFMAOperands(lhs) || !validateFMAOperands(rhs))
@@ -93,38 +93,38 @@ static void moveFMA(PatternRewriter &rewriter, vector::FMAOp fmaOp) {
return; return;
} }
// Shuffle FMAs with x86vector operations as operands such that // Shuffle FMAs with x86 operations as operands such that
// FMAs are grouped with respect to odd/even packed index. // FMAs are grouped with respect to odd/even packed index.
// //
// For example: // For example:
// ``` // ```
// %1 = x86vector.avx.bcst_to_f32.packed // %1 = x86.avx.bcst_to_f32.packed
// %2 = x86vector.avx.cvt.packed.odd.indexed_to_f32 // %2 = x86.avx.cvt.packed.odd.indexed_to_f32
// %3 = vector.fma %1, %2, %arg1 // %3 = vector.fma %1, %2, %arg1
// %4 = x86vector.avx.bcst_to_f32.packed // %4 = x86.avx.bcst_to_f32.packed
// %5 = x86vector.avx.cvt.packed.even.indexed_to_f32 // %5 = x86.avx.cvt.packed.even.indexed_to_f32
// %6 = vector.fma %4, %5, %3 // %6 = vector.fma %4, %5, %3
// %7 = x86vector.avx.bcst_to_f32.packed // %7 = x86.avx.bcst_to_f32.packed
// %8 = x86vector.avx.cvt.packed.odd.indexed_to_f32 // %8 = x86.avx.cvt.packed.odd.indexed_to_f32
// %9 = vector.fma %7, %8, %arg2 // %9 = vector.fma %7, %8, %arg2
// %10 = x86vector.avx.bcst_to_f32.packed // %10 = x86.avx.bcst_to_f32.packed
// %11 = x86vector.avx.cvt.packed.even.indexed_to_f32 // %11 = x86.avx.cvt.packed.even.indexed_to_f32
// %12 = vector.fma %10, %11, %9 // %12 = vector.fma %10, %11, %9
// yield %6, %12 // yield %6, %12
// ``` // ```
// to // to
// ``` // ```
// %1 = x86vector.avx.bcst_to_f32.packed // %1 = x86.avx.bcst_to_f32.packed
// %2 = x86vector.avx.cvt.packed.odd.indexed_to_f32 // %2 = x86.avx.cvt.packed.odd.indexed_to_f32
// %3 = vector.fma %1, %2, %arg1 // %3 = vector.fma %1, %2, %arg1
// %7 = x86vector.avx.bcst_to_f32.packed // %7 = x86.avx.bcst_to_f32.packed
// %8 = x86vector.avx.cvt.packed.odd.indexed_to_f32 // %8 = x86.avx.cvt.packed.odd.indexed_to_f32
// %9 = vector.fma %7, %8, %arg2 // %9 = vector.fma %7, %8, %arg2
// %4 = x86vector.avx.bcst_to_f32.packed // %4 = x86.avx.bcst_to_f32.packed
// %5 = x86vector.avx.cvt.packed.even.indexed_to_f32 // %5 = x86.avx.cvt.packed.even.indexed_to_f32
// %6 = vector.fma %4, %5, %3 // %6 = vector.fma %4, %5, %3
// %10 = x86vector.avx.bcst_to_f32.packed // %10 = x86.avx.bcst_to_f32.packed
// %11 = x86vector.avx.cvt.packed.even.indexed_to_f32 // %11 = x86.avx.cvt.packed.even.indexed_to_f32
// %12 = vector.fma %10, %11, %9 // %12 = vector.fma %10, %11, %9
// yield %6, %12 // yield %6, %12
// ``` // ```
@@ -150,10 +150,9 @@ struct ShuffleVectorFMAOps : public OpRewritePattern<vector::FMAOp> {
if (!fma) if (!fma)
continue; continue;
bool hasX86CvtOperand = isa<x86vector::CvtPackedEvenIndexedToF32Op>( bool hasX86CvtOperand =
fma.getLhs().getDefiningOp()) || isa<x86::CvtPackedEvenIndexedToF32Op>(fma.getLhs().getDefiningOp()) ||
isa<x86vector::CvtPackedEvenIndexedToF32Op>( isa<x86::CvtPackedEvenIndexedToF32Op>(fma.getRhs().getDefiningOp());
fma.getRhs().getDefiningOp());
if (hasX86CvtOperand && stopAtNextDependentFMA) if (hasX86CvtOperand && stopAtNextDependentFMA)
break; break;
@@ -180,7 +179,6 @@ struct ShuffleVectorFMAOps : public OpRewritePattern<vector::FMAOp> {
} // namespace } // namespace
void x86vector::populateShuffleVectorFMAOpsPatterns( void x86::populateShuffleVectorFMAOpsPatterns(RewritePatternSet &patterns) {
RewritePatternSet &patterns) {
patterns.add<ShuffleVectorFMAOps>(patterns.getContext()); patterns.add<ShuffleVectorFMAOps>(patterns.getContext());
} }


@@ -8,8 +8,8 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/Transforms.h" #include "mlir/Dialect/X86/Transforms.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h" #include "mlir/Dialect/X86/X86Dialect.h"
#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dominance.h" #include "mlir/IR/Dominance.h"
@@ -20,7 +20,7 @@
using namespace mlir; using namespace mlir;
using namespace mlir::vector; using namespace mlir::vector;
using namespace mlir::x86vector; using namespace mlir::x86;
static FailureOr<llvm::SmallVector<Operation *>> static FailureOr<llvm::SmallVector<Operation *>>
getSameBlockUsers(Operation *op) { getSameBlockUsers(Operation *op) {
@@ -141,8 +141,7 @@ struct SinkVectorProducerOps final : public OpRewritePattern<producerOp> {
} }
}; };
void x86vector::populateSinkVectorProducerOpsPatterns( void x86::populateSinkVectorProducerOpsPatterns(RewritePatternSet &patterns) {
RewritePatternSet &patterns) {
patterns.add<SinkVectorProducerOps<vector::TransferReadOp>, patterns.add<SinkVectorProducerOps<vector::TransferReadOp>,
SinkVectorProducerOps<vector::LoadOp>>(patterns.getContext()); SinkVectorProducerOps<vector::LoadOp>>(patterns.getContext());
} }


@@ -11,9 +11,9 @@
#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/Transforms.h" #include "mlir/Dialect/X86/Transforms.h"
#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h" #include "mlir/Dialect/X86/Utils/X86Utils.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h" #include "mlir/Dialect/X86/X86Dialect.h"
#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dominance.h" #include "mlir/IR/Dominance.h"
@@ -25,7 +25,7 @@
using namespace mlir; using namespace mlir;
using namespace mlir::vector; using namespace mlir::vector;
using namespace mlir::x86vector; using namespace mlir::x86;
// Verifies that the LHS and RHS operands of a vector.contract are load or // Verifies that the LHS and RHS operands of a vector.contract are load or
// vector.transfer_read operations on a memref source buffer, and checks // vector.transfer_read operations on a memref source buffer, and checks
@@ -61,7 +61,7 @@ static bool validateVectorContractOperands(Value prodOp, bool isVnni) {
return false; return false;
// Return false if the two innermost strides of the memref are not contiguous. // Return false if the two innermost strides of the memref are not contiguous.
// The x86vector.avx.cvt.packed.even/odd.indexed_to_f32 operations require // The x86.avx.cvt.packed.even/odd.indexed_to_f32 operations require
// an eight-element tuple of bf16 values to be contiguous. // an eight-element tuple of bf16 values to be contiguous.
int dimsToCheck = isVnni ? 2 : 1; int dimsToCheck = isVnni ? 2 : 1;
if (!cast<mlir::MemRefType>(srcType).areTrailingDimsContiguous(dimsToCheck)) if (!cast<mlir::MemRefType>(srcType).areTrailingDimsContiguous(dimsToCheck))
@@ -162,7 +162,7 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
subviews.push_back(subview); subviews.push_back(subview);
// For unit-dims, two subviews should be created for the odd and even // For unit-dims, two subviews should be created for the odd and even
// element in the VNNI tuple (2xbf16) because x86vector.avx.bcst_to_f32.packed // element in the VNNI tuple (2xbf16) because x86.avx.bcst_to_f32.packed
// op loads and broadcast the first BF16 element into packed F32. It // op loads and broadcast the first BF16 element into packed F32. It
// cannot distinguish between even and odd BF16 elements within a // cannot distinguish between even and odd BF16 elements within a
// packed pair. // packed pair.
@@ -193,11 +193,11 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
// ``` // ```
// to // to
// ``` // ```
// %1 = x86vector.avx.bcst_to_f32.packed %m1[c1] -> vector<8xf32> // %1 = x86.avx.bcst_to_f32.packed %m1[c1] -> vector<8xf32>
// %2 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %m2 -> vector<8xf32> // %2 = x86.avx.cvt.packed.odd.indexed_to_f32 %m2 -> vector<8xf32>
// %3 = vector.fma %1, %2, %arg1 // %3 = vector.fma %1, %2, %arg1
// %4 = x86vector.avx.bcst_to_f32.packed %m1[c0] -> vector<8xf32> // %4 = x86.avx.bcst_to_f32.packed %m1[c0] -> vector<8xf32>
// %5 = x86vector.avx.cvt.packed.even.indexed_to_f32 %m2 -> vector<8xf32> // %5 = x86.avx.cvt.packed.even.indexed_to_f32 %m2 -> vector<8xf32>
// return vector.fma %4, %5, %3 // return vector.fma %4, %5, %3
// ``` // ```
// //
@@ -212,10 +212,10 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
// ``` // ```
// to // to
// ``` // ```
// %1 = x86vector.avx.bcst_to_f32.packed %m1[c0] -> vector<8xf32> // %1 = x86.avx.bcst_to_f32.packed %m1[c0] -> vector<8xf32>
// %2 = x86vector.avx.cvt.packed.even.indexed_to_f32 %m2 -> vector<8xf32> // %2 = x86.avx.cvt.packed.even.indexed_to_f32 %m2 -> vector<8xf32>
// %3 = vector.fma %1, %2, %arg1 // %3 = vector.fma %1, %2, %arg1
// %4 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %m2 -> vector<8xf32> // %4 = x86.avx.cvt.packed.odd.indexed_to_f32 %m2 -> vector<8xf32>
// %5 = vector.fma %1, %4, %arg2 // %5 = vector.fma %1, %4, %arg2
// scf.yield %3, %5 // scf.yield %3, %5
struct VectorContractBF16ToFMA struct VectorContractBF16ToFMA
@@ -446,11 +446,10 @@ struct VectorContractBF16ToFMA
VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()), VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
contractOp.getAcc()); contractOp.getAcc());
auto loadBcstBF16ElementToF32 = x86vector::BcstToPackedF32Op::create( auto loadBcstBF16ElementToF32 = x86::BcstToPackedF32Op::create(
rewriter, loc, dstType, unitDimSubview[0]); rewriter, loc, dstType, unitDimSubview[0]);
auto loadEvenIdxElementF32 = auto loadEvenIdxElementF32 = x86::CvtPackedEvenIndexedToF32Op::create(
x86vector::CvtPackedEvenIndexedToF32Op::create(rewriter, loc, dstType, rewriter, loc, dstType, nonUnitDimSubview[0]);
nonUnitDimSubview[0]);
auto evenIdxFMA = auto evenIdxFMA =
vector::FMAOp::create(rewriter, loc, loadBcstBF16ElementToF32, vector::FMAOp::create(rewriter, loc, loadBcstBF16ElementToF32,
loadEvenIdxElementF32, castAcc); loadEvenIdxElementF32, castAcc);
@@ -468,7 +467,7 @@ struct VectorContractBF16ToFMA
accTyPairCont.getElementType()), accTyPairCont.getElementType()),
pairContractOp.getAcc()); pairContractOp.getAcc());
auto loadOddIdxElementF32 = x86vector::CvtPackedOddIndexedToF32Op::create( auto loadOddIdxElementF32 = x86::CvtPackedOddIndexedToF32Op::create(
rewriter, pairContOpLoc, dstType, nonUnitDimSubview[0]); rewriter, pairContOpLoc, dstType, nonUnitDimSubview[0]);
auto oddIdxFMA = vector::FMAOp::create( auto oddIdxFMA = vector::FMAOp::create(
rewriter, pairContOpLoc, loadBcstBF16ElementToF32, rewriter, pairContOpLoc, loadBcstBF16ElementToF32,
@@ -481,18 +480,18 @@ struct VectorContractBF16ToFMA
} }
// Load, broadcast, and do FMA for odd indexed BF16 elements. // Load, broadcast, and do FMA for odd indexed BF16 elements.
auto loadBcstOddIdxElementToF32 = x86vector::BcstToPackedF32Op::create( auto loadBcstOddIdxElementToF32 = x86::BcstToPackedF32Op::create(
rewriter, loc, dstType, unitDimSubview[0]); rewriter, loc, dstType, unitDimSubview[0]);
auto loadOddIdxElementF32 = x86vector::CvtPackedOddIndexedToF32Op::create( auto loadOddIdxElementF32 = x86::CvtPackedOddIndexedToF32Op::create(
rewriter, loc, dstType, nonUnitDimSubview[0]); rewriter, loc, dstType, nonUnitDimSubview[0]);
auto oddIdxFMA = auto oddIdxFMA =
vector::FMAOp::create(rewriter, loc, loadBcstOddIdxElementToF32, vector::FMAOp::create(rewriter, loc, loadBcstOddIdxElementToF32,
loadOddIdxElementF32, castAcc); loadOddIdxElementF32, castAcc);
// Load, broadcast, and do FMA for even indexed BF16 elements. // Load, broadcast, and do FMA for even indexed BF16 elements.
auto loadBcstEvenIdxElementToF32 = x86vector::BcstToPackedF32Op::create( auto loadBcstEvenIdxElementToF32 = x86::BcstToPackedF32Op::create(
rewriter, loc, dstType, unitDimSubview[1]); rewriter, loc, dstType, unitDimSubview[1]);
auto loadEvenIdxElementF32 = x86vector::CvtPackedEvenIndexedToF32Op::create( auto loadEvenIdxElementF32 = x86::CvtPackedEvenIndexedToF32Op::create(
rewriter, loc, dstType, nonUnitDimSubview[0]); rewriter, loc, dstType, nonUnitDimSubview[0]);
vector::FMAOp fma = vector::FMAOp fma =
vector::FMAOp::create(rewriter, loc, loadBcstEvenIdxElementToF32, vector::FMAOp::create(rewriter, loc, loadBcstEvenIdxElementToF32,
@@ -504,7 +503,6 @@ struct VectorContractBF16ToFMA
} }
}; };
void x86vector::populateVectorContractBF16ToFMAPatterns( void x86::populateVectorContractBF16ToFMAPatterns(RewritePatternSet &patterns) {
RewritePatternSet &patterns) {
patterns.add<VectorContractBF16ToFMA>(patterns.getContext()); patterns.add<VectorContractBF16ToFMA>(patterns.getContext());
} }


@@ -8,8 +8,8 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/Transforms.h" #include "mlir/Dialect/X86/Transforms.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h" #include "mlir/Dialect/X86/X86Dialect.h"
#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dominance.h" #include "mlir/IR/Dominance.h"
@@ -20,7 +20,7 @@
using namespace mlir; using namespace mlir;
using namespace mlir::vector; using namespace mlir::vector;
using namespace mlir::x86vector; using namespace mlir::x86;
namespace { namespace {
@@ -137,7 +137,6 @@ struct VectorContractToFMA : public OpRewritePattern<vector::ContractionOp> {
} // namespace } // namespace
void x86vector::populateVectorContractToFMAPatterns( void x86::populateVectorContractToFMAPatterns(RewritePatternSet &patterns) {
RewritePatternSet &patterns) {
patterns.add<VectorContractToFMA>(patterns.getContext()); patterns.add<VectorContractToFMA>(patterns.getContext());
} }


@@ -11,9 +11,9 @@
#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/Transforms.h" #include "mlir/Dialect/X86/Transforms.h"
#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h" #include "mlir/Dialect/X86/Utils/X86Utils.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h" #include "mlir/Dialect/X86/X86Dialect.h"
#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dominance.h" #include "mlir/IR/Dominance.h"
@@ -24,7 +24,7 @@
using namespace mlir; using namespace mlir;
using namespace mlir::vector; using namespace mlir::vector;
using namespace mlir::x86vector; using namespace mlir::x86;
namespace { namespace {
@@ -109,7 +109,7 @@ static void packNonUnitDimOperandToVNNI(mlir::PatternRewriter &rewriter,
// to // to
// ``` // ```
// vector.broadcast %lhs to <32xbf16> // vector.broadcast %lhs to <32xbf16>
// x86vector.avx512.dot vector<32xbf16> -> vector<16xf32> // x86.avx512.dot vector<32xbf16> -> vector<16xf32>
// ``` // ```
// //
// For example - for bf16 type (Flat layout): // For example - for bf16 type (Flat layout):
@@ -126,9 +126,9 @@ static void packNonUnitDimOperandToVNNI(mlir::PatternRewriter &rewriter,
// %3 = vector.shuffle %1, %2 [0, 32, 1, ... 27, 59] // %3 = vector.shuffle %1, %2 [0, 32, 1, ... 27, 59]
// %4 = vector.shuffle %1, %2 [4, 36, 5, ... 31, 63] // %4 = vector.shuffle %1, %2 [4, 36, 5, ... 31, 63]
// vector.broadcast %lhs to <32xbf16> // vector.broadcast %lhs to <32xbf16>
// x86vector.avx512.dot vector<32xbf16>, %3 -> vector<16xf32> // x86.avx512.dot vector<32xbf16>, %3 -> vector<16xf32>
// vector.broadcast %lhs to <32xbf16> // vector.broadcast %lhs to <32xbf16>
// x86vector.avx512.dot vector<32xbf16>, %3 -> vector<16xf32> // x86.avx512.dot vector<32xbf16>, %3 -> vector<16xf32>
// ``` // ```
struct VectorContractToPackedTypeDotProduct struct VectorContractToPackedTypeDotProduct
: public OpRewritePattern<vector::ContractionOp> { : public OpRewritePattern<vector::ContractionOp> {
@@ -384,7 +384,7 @@ struct VectorContractToPackedTypeDotProduct
rewriter, loc, castNonUnitDim.getResult().getType(), broadcastUnitDim); rewriter, loc, castNonUnitDim.getResult().getType(), broadcastUnitDim);
if (lhsTy.getElementType().isBF16()) { if (lhsTy.getElementType().isBF16()) {
dp = x86vector::DotBF16Op::create( dp = x86::DotBF16Op::create(
rewriter, loc, rewriter, loc,
VectorType::get(nonUnitDimValue, rewriter.getF32Type()), castAcc, VectorType::get(nonUnitDimValue, rewriter.getF32Type()), castAcc,
bitcastUnitDimPkType, castNonUnitDim); bitcastUnitDimPkType, castNonUnitDim);
@@ -392,12 +392,12 @@ struct VectorContractToPackedTypeDotProduct
if (lhsTy.getElementType().isSignlessInteger(8)) { if (lhsTy.getElementType().isSignlessInteger(8)) {
if (nonUnitDimAcc.front() == 16) { if (nonUnitDimAcc.front() == 16) {
dp = x86vector::AVX10DotInt8Op::create( dp = x86::AVX10DotInt8Op::create(
rewriter, loc, rewriter, loc,
VectorType::get(nonUnitDimValue, rewriter.getIntegerType(32)), VectorType::get(nonUnitDimValue, rewriter.getIntegerType(32)),
castAcc, bitcastUnitDimPkType, castNonUnitDim); castAcc, bitcastUnitDimPkType, castNonUnitDim);
} else { } else {
dp = x86vector::DotInt8Op::create( dp = x86::DotInt8Op::create(
rewriter, loc, rewriter, loc,
VectorType::get(nonUnitDimValue, rewriter.getIntegerType(32)), VectorType::get(nonUnitDimValue, rewriter.getIntegerType(32)),
castAcc, bitcastUnitDimPkType, castNonUnitDim); castAcc, bitcastUnitDimPkType, castNonUnitDim);
@@ -415,7 +415,7 @@ struct VectorContractToPackedTypeDotProduct
} // namespace } // namespace
void x86vector::populateVectorContractToPackedTypeDotProductPatterns( void x86::populateVectorContractToPackedTypeDotProductPatterns(
RewritePatternSet &patterns) { RewritePatternSet &patterns) {
patterns.add<VectorContractToPackedTypeDotProduct>(patterns.getContext()); patterns.add<VectorContractToPackedTypeDotProduct>(patterns.getContext());
} }


@@ -1,8 +1,8 @@
add_mlir_dialect_library(MLIRX86VectorUtils add_mlir_dialect_library(MLIRX86Utils
X86VectorUtils.cpp X86Utils.cpp
ADDITIONAL_HEADER_DIRS ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/X86Vector/Utils ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/X86/Utils
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRAffineDialect MLIRAffineDialect


@@ -1,4 +1,4 @@
//===- X86VectorUtils.cpp - MLIR Utilities for X86VectorOps -------------===// //===- X86Utils.cpp - MLIR Utilities for X86Ops -------------------------===//
// //
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information. // See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,7 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h" #include "mlir/Dialect/X86/Utils/X86Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
@@ -23,7 +23,7 @@
#include <cassert> #include <cassert>
namespace mlir { namespace mlir {
namespace x86vector { namespace x86 {
static FailureOr<SmallVector<mlir::utils::IteratorType>> static FailureOr<SmallVector<mlir::utils::IteratorType>>
inferIteratorsFromOutMap(AffineMap map) { inferIteratorsFromOutMap(AffineMap map) {
@@ -410,5 +410,5 @@ bool validatePairVectorContract(vector::ContractionOp contractOp,
return true; return true;
} }
} // namespace x86vector } // namespace x86
} // namespace mlir } // namespace mlir


@@ -96,7 +96,7 @@
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h" #include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h"
#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h" #include "mlir/Dialect/WasmSSA/IR/WasmSSA.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h" #include "mlir/Dialect/X86/X86Dialect.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/CastInterfaces.h"
@@ -152,7 +152,7 @@ void mlir::registerAllDialects(DialectRegistry &registry) {
ub::UBDialect, ub::UBDialect,
vector::VectorDialect, vector::VectorDialect,
wasmssa::WasmSSADialect, wasmssa::WasmSSADialect,
x86vector::X86VectorDialect, x86::X86Dialect,
xegpu::XeGPUDialect, xegpu::XeGPUDialect,
xevm::XeVMDialect>(); xevm::XeVMDialect>();
// clang-format on // clang-format on


@@ -56,7 +56,7 @@
#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h" #include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h" #include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h" #include "mlir/Dialect/X86/TransformOps/X86TransformOps.h"
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h" #include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
@@ -114,7 +114,7 @@ void mlir::registerAllExtensions(DialectRegistry &registry) {
transform::registerSMTExtension(registry); transform::registerSMTExtension(registry);
transform::registerTuneExtension(registry); transform::registerTuneExtension(registry);
vector::registerTransformDialectExtension(registry); vector::registerTransformDialectExtension(registry);
x86vector::registerTransformDialectExtension(registry); x86::registerTransformDialectExtension(registry);
xegpu::registerTransformDialectExtension(registry); xegpu::registerTransformDialectExtension(registry);
arm_neon::registerTransformDialectExtension(registry); arm_neon::registerTransformDialectExtension(registry);
arm_sve::registerTransformDialectExtension(registry); arm_sve::registerTransformDialectExtension(registry);


@@ -328,11 +328,11 @@ declare_mlir_dialect_extension_python_bindings(
declare_mlir_dialect_extension_python_bindings( declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/X86VectorTransformOps.td TD_FILE dialects/X86TransformOps.td
SOURCES SOURCES
dialects/transform/x86vector.py dialects/transform/x86.py
DIALECT_NAME transform DIALECT_NAME transform
EXTENSION_NAME x86vector_transform) EXTENSION_NAME x86_transform)
declare_mlir_dialect_extension_python_bindings( declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects ADD_TO_PARENT MLIRPythonSources.Dialects
@@ -522,9 +522,9 @@ declare_mlir_dialect_python_bindings(
declare_mlir_dialect_python_bindings( declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/X86Vector.td TD_FILE dialects/X86.td
SOURCES dialects/x86vector.py SOURCES dialects/x86.py
DIALECT_NAME x86vector) DIALECT_NAME x86)
declare_mlir_dialect_python_bindings( declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects ADD_TO_PARENT MLIRPythonSources.Dialects


@@ -1,4 +1,4 @@
//===-- X86Vector.td - Entry point for x86vector bindings --*- tablegen -*-===// //===-- X86.td - Entry point for x86 bindings --------------*- tablegen -*-===//
// //
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information. // See https://llvm.org/LICENSE.txt for license information.
@@ -6,9 +6,9 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#ifndef PYTHON_BINDINGS_X86VECTOR #ifndef PYTHON_BINDINGS_X86
#define PYTHON_BINDINGS_X86VECTOR #define PYTHON_BINDINGS_X86
include "mlir/Dialect/X86Vector/X86Vector.td" include "mlir/Dialect/X86/X86.td"
#endif // PYTHON_BINDINGS_X86VECTOR #endif // PYTHON_BINDINGS_X86


@@ -1,4 +1,4 @@
//===-- X86VectorTransformOps.td ---------------------------*- tablegen -*-===// //===-- X86TransformOps.td ---------------------------------*- tablegen -*-===//
// //
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information. // See https://llvm.org/LICENSE.txt for license information.
@@ -6,9 +6,9 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#ifndef PYTHON_BINDINGS_X86VECTORTRANSFORMOPS #ifndef PYTHON_BINDINGS_X86TRANSFORMOPS
#define PYTHON_BINDINGS_X86VECTORTRANSFORMOPS #define PYTHON_BINDINGS_X86TRANSFORMOPS
include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td" include "mlir/Dialect/X86/TransformOps/X86TransformOps.td"
#endif // PYTHON_BINDINGS_X86VECTORTRANSFORMOPS #endif // PYTHON_BINDINGS_X86TRANSFORMOPS


@@ -2,4 +2,4 @@
# See https://llvm.org/LICENSE.txt for license information. # See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._x86vector_transform_ops_gen import * from .._x86_transform_ops_gen import *


@@ -2,5 +2,5 @@
# See https://llvm.org/LICENSE.txt for license information. # See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._x86vector_ops_gen import * from ._x86_ops_gen import *
from ._x86vector_ops_gen import _Dialect from ._x86_ops_gen import _Dialect


@@ -33,7 +33,7 @@ if (MLIR_INCLUDE_INTEGRATION_TESTS)
set(ARM_SME_ABI_ROUTINES_SHLIB "" CACHE STRING set(ARM_SME_ABI_ROUTINES_SHLIB "" CACHE STRING
"Path to a shared library containing Arm SME ABI routines, required for Arm SME integration tests.") "Path to a shared library containing Arm SME ABI routines, required for Arm SME integration tests.")
option(MLIR_RUN_AMX_TESTS "Run AMX tests.") option(MLIR_RUN_AMX_TESTS "Run AMX tests.")
option(MLIR_RUN_X86VECTOR_TESTS "Run X86Vector tests.") option(MLIR_RUN_X86_TESTS "Run X86 tests.")
option(MLIR_RUN_CUDA_TENSOR_CORE_TESTS "Run CUDA Tensor core WMMA tests.") option(MLIR_RUN_CUDA_TENSOR_CORE_TESTS "Run CUDA Tensor core WMMA tests.")
option(MLIR_RUN_CUDA_SM80_TESTS "Run CUDA A100 tests.") option(MLIR_RUN_CUDA_SM80_TESTS "Run CUDA A100 tests.")
option(MLIR_RUN_CUDA_SM80_LT_TESTS "Run CUDA A100 structured sparsity tests.") option(MLIR_RUN_CUDA_SM80_LT_TESTS "Run CUDA A100 structured sparsity tests.")
@@ -78,7 +78,7 @@ llvm_canonicalize_cmake_booleans(
MLIR_INCLUDE_INTEGRATION_TESTS MLIR_INCLUDE_INTEGRATION_TESTS
MLIR_RUN_AMX_TESTS MLIR_RUN_AMX_TESTS
MLIR_RUN_CUDA_TENSOR_CORE_TESTS MLIR_RUN_CUDA_TENSOR_CORE_TESTS
MLIR_RUN_X86VECTOR_TESTS MLIR_RUN_X86_TESTS
MLIR_RUN_ARM_SVE_TESTS MLIR_RUN_ARM_SVE_TESTS
MLIR_RUN_ARM_SME_TESTS MLIR_RUN_ARM_SME_TESTS
MLIR_RUN_CUDA_SM80_TESTS MLIR_RUN_CUDA_SM80_TESTS


@@ -21,7 +21,7 @@
// CHECK-SAME: enable-amx={{[aA-zZ0-9]+}} // CHECK-SAME: enable-amx={{[aA-zZ0-9]+}}
// CHECK-SAME: enable-arm-neon={{[aA-zZ0-9]+}} // CHECK-SAME: enable-arm-neon={{[aA-zZ0-9]+}}
// CHECK-SAME: enable-arm-sve={{[aA-zZ0-9]+}} // CHECK-SAME: enable-arm-sve={{[aA-zZ0-9]+}}
// CHECK-SAME: enable-x86vector={{[aA-zZ0-9]+}} // CHECK-SAME: enable-x86={{[aA-zZ0-9]+}}
// CHECK-SAME: force-32bit-vector-indices={{[aA-zZ0-9]+}} // CHECK-SAME: force-32bit-vector-indices={{[aA-zZ0-9]+}}
// CHECK-SAME: reassociate-fp-reductions={{[aA-zZ0-9]+}} // CHECK-SAME: reassociate-fp-reductions={{[aA-zZ0-9]+}}
// DEFAULT: vector-contract-lowering=dot // DEFAULT: vector-contract-lowering=dot


@@ -691,7 +691,7 @@ func.func @rsqrt_scalar(%arg0: f32) -> f32 {
// AVX2: %[[VAL_6:.*]] = arith.cmpf olt, %[[VAL_0]], %[[VAL_4]] : vector<8xf32> // AVX2: %[[VAL_6:.*]] = arith.cmpf olt, %[[VAL_0]], %[[VAL_4]] : vector<8xf32>
// AVX2: %[[VAL_7:.*]] = arith.cmpf oeq, %[[VAL_0]], %[[VAL_1]] : vector<8xf32> // AVX2: %[[VAL_7:.*]] = arith.cmpf oeq, %[[VAL_0]], %[[VAL_1]] : vector<8xf32>
// AVX2: %[[VAL_8:.*]] = arith.ori %[[VAL_6]], %[[VAL_7]] : vector<8xi1> // AVX2: %[[VAL_8:.*]] = arith.ori %[[VAL_6]], %[[VAL_7]] : vector<8xi1>
// AVX2: %[[VAL_9:.*]] = x86vector.avx.rsqrt %[[VAL_0]] : vector<8xf32> // AVX2: %[[VAL_9:.*]] = x86.avx.rsqrt %[[VAL_0]] : vector<8xf32>
// AVX2: %[[VAL_10:.*]] = arith.mulf %[[VAL_5]], %[[VAL_9]] : vector<8xf32> // AVX2: %[[VAL_10:.*]] = arith.mulf %[[VAL_5]], %[[VAL_9]] : vector<8xf32>
// AVX2: %[[VAL_11:.*]] = math.fma %[[VAL_9]], %[[VAL_10]], %[[VAL_2]] : vector<8xf32> // AVX2: %[[VAL_11:.*]] = math.fma %[[VAL_9]], %[[VAL_10]], %[[VAL_2]] : vector<8xf32>
// AVX2: %[[VAL_12:.*]] = arith.mulf %[[VAL_9]], %[[VAL_11]] : vector<8xf32> // AVX2: %[[VAL_12:.*]] = arith.mulf %[[VAL_9]], %[[VAL_11]] : vector<8xf32>
@@ -725,9 +725,9 @@ func.func @rsqrt_vector_5xf32(%arg0: vector<5xf32>) -> vector<5xf32> {
// AVX2: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x8xf32> // AVX2: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x8xf32>
// AVX2: %[[EXPAND:.*]] = vector.shape_cast %[[ARG]] : vector<16xf32> to vector<2x8xf32> // AVX2: %[[EXPAND:.*]] = vector.shape_cast %[[ARG]] : vector<16xf32> to vector<2x8xf32>
// AVX2: %[[VEC0:.*]] = vector.extract %[[EXPAND]][0] // AVX2: %[[VEC0:.*]] = vector.extract %[[EXPAND]][0]
// AVX2: %[[RSQRT0:.*]] = x86vector.avx.rsqrt %[[VEC0]] // AVX2: %[[RSQRT0:.*]] = x86.avx.rsqrt %[[VEC0]]
// AVX2: %[[VEC1:.*]] = vector.extract %[[EXPAND]][1] // AVX2: %[[VEC1:.*]] = vector.extract %[[EXPAND]][1]
// AVX2: %[[RSQRT1:.*]] = x86vector.avx.rsqrt %[[VEC1]] // AVX2: %[[RSQRT1:.*]] = x86.avx.rsqrt %[[VEC1]]
// AVX2: %[[RESULT0:.*]] = vector.insert %[[RSQRT0]], %[[INIT]] [0] // AVX2: %[[RESULT0:.*]] = vector.insert %[[RSQRT0]], %[[INIT]] [0]
// AVX2: %[[RESULT1:.*]] = vector.insert %[[RSQRT1]], %[[RESULT0]] [1] // AVX2: %[[RESULT1:.*]] = vector.insert %[[RSQRT1]], %[[RESULT0]] [1]
// AVX2: %[[RSQRT:.*]] = vector.shape_cast %[[RESULT1]] : vector<2x8xf32> to vector<16xf32> // AVX2: %[[RSQRT:.*]] = vector.shape_cast %[[RESULT1]] : vector<2x8xf32> to vector<16xf32>
@@ -746,9 +746,9 @@ func.func @rsqrt_vector_16xf32(%arg0: vector<16xf32>) -> vector<16xf32> {
// AVX2: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x8xf32> // AVX2: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x8xf32>
// AVX2-NOT: vector.shape_cast // AVX2-NOT: vector.shape_cast
// AVX2: %[[VEC0:.*]] = vector.extract %[[ARG]][0] // AVX2: %[[VEC0:.*]] = vector.extract %[[ARG]][0]
// AVX2: %[[RSQRT0:.*]] = x86vector.avx.rsqrt %[[VEC0]] // AVX2: %[[RSQRT0:.*]] = x86.avx.rsqrt %[[VEC0]]
// AVX2: %[[VEC1:.*]] = vector.extract %[[ARG]][1] // AVX2: %[[VEC1:.*]] = vector.extract %[[ARG]][1]
// AVX2: %[[RSQRT1:.*]] = x86vector.avx.rsqrt %[[VEC1]] // AVX2: %[[RSQRT1:.*]] = x86.avx.rsqrt %[[VEC1]]
// AVX2: %[[RESULT0:.*]] = vector.insert %[[RSQRT0]], %[[INIT]] [0] // AVX2: %[[RESULT0:.*]] = vector.insert %[[RSQRT0]], %[[INIT]] [0]
// AVX2: %[[RESULT1:.*]] = vector.insert %[[RSQRT1]], %[[RESULT0]] [1] // AVX2: %[[RESULT1:.*]] = vector.insert %[[RSQRT1]], %[[RESULT0]] [1]
// AVX2-NOT: vector.shape_cast // AVX2-NOT: vector.shape_cast
@@ -768,13 +768,13 @@ func.func @rsqrt_vector_2x8xf32(%arg0: vector<2x8xf32>) -> vector<2x8xf32> {
// AVX2: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x8xf32> // AVX2: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x8xf32>
// AVX2: %[[EXPAND:.*]] = vector.shape_cast %[[ARG]] : vector<2x16xf32> to vector<2x2x8xf32> // AVX2: %[[EXPAND:.*]] = vector.shape_cast %[[ARG]] : vector<2x16xf32> to vector<2x2x8xf32>
// AVX2: %[[VEC00:.*]] = vector.extract %[[EXPAND]][0, 0] // AVX2: %[[VEC00:.*]] = vector.extract %[[EXPAND]][0, 0]
// AVX2: %[[RSQRT00:.*]] = x86vector.avx.rsqrt %[[VEC00]] // AVX2: %[[RSQRT00:.*]] = x86.avx.rsqrt %[[VEC00]]
// AVX2: %[[VEC01:.*]] = vector.extract %[[EXPAND]][0, 1] // AVX2: %[[VEC01:.*]] = vector.extract %[[EXPAND]][0, 1]
// AVX2: %[[RSQRT01:.*]] = x86vector.avx.rsqrt %[[VEC01]] // AVX2: %[[RSQRT01:.*]] = x86.avx.rsqrt %[[VEC01]]
// AVX2: %[[VEC10:.*]] = vector.extract %[[EXPAND]][1, 0] // AVX2: %[[VEC10:.*]] = vector.extract %[[EXPAND]][1, 0]
// AVX2: %[[RSQRT10:.*]] = x86vector.avx.rsqrt %[[VEC10]] // AVX2: %[[RSQRT10:.*]] = x86.avx.rsqrt %[[VEC10]]
// AVX2: %[[VEC11:.*]] = vector.extract %[[EXPAND]][1, 1] // AVX2: %[[VEC11:.*]] = vector.extract %[[EXPAND]][1, 1]
// AVX2: %[[RSQRT11:.*]] = x86vector.avx.rsqrt %[[VEC11]] // AVX2: %[[RSQRT11:.*]] = x86.avx.rsqrt %[[VEC11]]
// AVX2: %[[RESULT0:.*]] = vector.insert %[[RSQRT00]], %[[INIT]] [0, 0] // AVX2: %[[RESULT0:.*]] = vector.insert %[[RSQRT00]], %[[INIT]] [0, 0]
// AVX2: %[[RESULT1:.*]] = vector.insert %[[RSQRT01]], %[[RESULT0]] [0, 1] // AVX2: %[[RESULT1:.*]] = vector.insert %[[RSQRT01]], %[[RESULT0]] [0, 1]
// AVX2: %[[RESULT2:.*]] = vector.insert %[[RSQRT10]], %[[RESULT1]] [1, 0] // AVX2: %[[RESULT2:.*]] = vector.insert %[[RSQRT10]], %[[RESULT1]] [1, 0]


@@ -1,6 +1,6 @@
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s // RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
// NOTE: This file tests lowerings that are implemented in the X86Vector // NOTE: This file tests lowerings that are implemented in the X86
// dialect. Since X86 does not support scalable vectors, all examples in this // dialect. Since X86 does not support scalable vectors, all examples in this
// file use fixed-width vectors. // file use fixed-width vectors.


@@ -1,7 +1,7 @@
// REQUIRES: target=x86{{.*}} // REQUIRES: target=x86{{.*}}
// RUN: mlir-opt %s \ // RUN: mlir-opt %s \
// RUN: -convert-vector-to-llvm="enable-x86vector" -convert-to-llvm \ // RUN: -convert-vector-to-llvm="enable-x86" -convert-to-llvm \
// RUN: -reconcile-unrealized-casts | \ // RUN: -reconcile-unrealized-casts | \
// RUN: mlir-translate --mlir-to-llvmir | \ // RUN: mlir-translate --mlir-to-llvmir | \
// RUN: llc -mcpu=sapphirerapids | \ // RUN: llc -mcpu=sapphirerapids | \
@@ -9,7 +9,7 @@
func.func @avx512bf16_cvt_packed_f32_to_bf16_256( func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
%a: vector<8xf32>) -> vector<8xbf16> { %a: vector<8xf32>) -> vector<8xbf16> {
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16> %0 = x86.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
return %0 : vector<8xbf16> return %0 : vector<8xbf16>
} }
// CHECK-LABEL: avx512bf16_cvt_packed_f32_to_bf16_256: // CHECK-LABEL: avx512bf16_cvt_packed_f32_to_bf16_256:
@@ -17,7 +17,7 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
func.func @avx512bf16_cvt_packed_f32_to_bf16_512( func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
%a: vector<16xf32>) -> vector<16xbf16> { %a: vector<16xf32>) -> vector<16xbf16> {
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16> %0 = x86.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
return %0 : vector<16xbf16> return %0 : vector<16xbf16>
} }
// CHECK-LABEL: avx512bf16_cvt_packed_f32_to_bf16_512: // CHECK-LABEL: avx512bf16_cvt_packed_f32_to_bf16_512:


@@ -1,7 +1,7 @@
// REQUIRES: target=x86{{.*}} // REQUIRES: target=x86{{.*}}
// RUN: mlir-opt %s \ // RUN: mlir-opt %s \
// RUN: -convert-vector-to-llvm="enable-x86vector" -convert-to-llvm \ // RUN: -convert-vector-to-llvm="enable-x86" -convert-to-llvm \
// RUN: -reconcile-unrealized-casts | \ // RUN: -reconcile-unrealized-casts | \
// RUN: mlir-translate --mlir-to-llvmir | \ // RUN: mlir-translate --mlir-to-llvmir | \
// RUN: llc -mcpu=sapphirerapids | \ // RUN: llc -mcpu=sapphirerapids | \
@@ -9,7 +9,7 @@
func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>, func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>,
%b: vector<8xbf16>) -> vector<4xf32> { %b: vector<8xbf16>) -> vector<4xf32> {
%0 = x86vector.avx512.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32> %0 = x86.avx512.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32> return %0 : vector<4xf32>
} }
// CHECK-LABEL: avx512bf16_dot_128: // CHECK-LABEL: avx512bf16_dot_128:
@@ -17,7 +17,7 @@ func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>,
func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>, func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>,
%b: vector<16xbf16>) -> vector<8xf32> { %b: vector<16xbf16>) -> vector<8xf32> {
%0 = x86vector.avx512.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32> %0 = x86.avx512.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32> return %0 : vector<8xf32>
} }
// CHECK-LABEL: avx512bf16_dot_256: // CHECK-LABEL: avx512bf16_dot_256:
@@ -25,7 +25,7 @@ func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>,
func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>, func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
%b: vector<32xbf16>) -> vector<16xf32> { %b: vector<32xbf16>) -> vector<16xf32> {
%0 = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32> %0 = x86.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
return %0 : vector<16xf32> return %0 : vector<16xf32>
} }
// CHECK-LABEL: avx512bf16_dot_512: // CHECK-LABEL: avx512bf16_dot_512:


@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86vector" | mlir-opt | FileCheck %s // RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86" | mlir-opt | FileCheck %s
// CHECK-LABEL: func @avx512_mask_rndscale // CHECK-LABEL: func @avx512_mask_rndscale
func.func @avx512_mask_rndscale( func.func @avx512_mask_rndscale(
@@ -9,14 +9,14 @@ func.func @avx512_mask_rndscale(
%rnd_k = arith.constant 15 : i32 %rnd_k = arith.constant 15 : i32
%rnd = arith.constant 42 : i32 %rnd = arith.constant 42 : i32
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.rndscale.ps.512" // CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.rndscale.ps.512"
%0 = x86vector.avx512.mask.rndscale %src, %rnd_k, %a, %imm_i16, %rnd : vector<16xf32> %0 = x86.avx512.mask.rndscale %src, %rnd_k, %a, %imm_i16, %rnd : vector<16xf32>
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.rndscale.pd.512" // CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.rndscale.pd.512"
%1 = x86vector.avx512.mask.rndscale %b, %rnd_k, %b, %imm_i8, %rnd : vector<8xf64> %1 = x86.avx512.mask.rndscale %b, %rnd_k, %b, %imm_i8, %rnd : vector<8xf64>
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.scalef.ps.512" // CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.scalef.ps.512"
%2 = x86vector.avx512.mask.scalef %a, %a, %a, %scale_k_i16, %rnd : vector<16xf32> %2 = x86.avx512.mask.scalef %a, %a, %a, %scale_k_i16, %rnd : vector<16xf32>
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.scalef.pd.512" // CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.scalef.pd.512"
%3 = x86vector.avx512.mask.scalef %b, %b, %b, %scale_k_i8, %rnd : vector<8xf64> %3 = x86.avx512.mask.scalef %b, %b, %b, %scale_k_i8, %rnd : vector<8xf64>
// Keep results alive. // Keep results alive.
return %0, %1, %2, %3 : vector<16xf32>, vector<8xf64>, vector<16xf32>, vector<8xf64> return %0, %1, %2, %3 : vector<16xf32>, vector<8xf64>, vector<16xf32>, vector<8xf64>
@@ -29,13 +29,13 @@ func.func @avx512_mask_compress(
{ {
// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<16xf32>) // CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<16xf32>)
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.compress" // CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.compress"
%0 = x86vector.avx512.mask.compress %k1, %a1 : vector<16xf32> %0 = x86.avx512.mask.compress %k1, %a1 : vector<16xf32>
// CHECK: llvm.mlir.constant(dense<5.000000e+00> : vector<16xf32>) // CHECK: llvm.mlir.constant(dense<5.000000e+00> : vector<16xf32>)
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.compress" // CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.compress"
%1 = x86vector.avx512.mask.compress %k1, %a1 %1 = x86.avx512.mask.compress %k1, %a1
{constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32> {constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32>
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.compress" // CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.compress"
%2 = x86vector.avx512.mask.compress %k2, %a2, %a2 : vector<8xi64>, vector<8xi64> %2 = x86.avx512.mask.compress %k2, %a2, %a2 : vector<8xi64>, vector<8xi64>
return %0, %1, %2 : vector<16xf32>, vector<16xf32>, vector<8xi64> return %0, %1, %2 : vector<16xf32>, vector<16xf32>, vector<8xi64>
} }
@@ -44,9 +44,9 @@ func.func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>)
-> (vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>) -> (vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>)
{ {
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.vp2intersect.d.512" // CHECK: llvm.call_intrinsic "llvm.x86.avx512.vp2intersect.d.512"
%0, %1 = x86vector.avx512.vp2intersect %a, %a : vector<16xi32> %0, %1 = x86.avx512.vp2intersect %a, %a : vector<16xi32>
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.vp2intersect.q.512" // CHECK: llvm.call_intrinsic "llvm.x86.avx512.vp2intersect.q.512"
%2, %3 = x86vector.avx512.vp2intersect %b, %b : vector<8xi64> %2, %3 = x86.avx512.vp2intersect %b, %b : vector<8xi64>
return %0, %1, %2, %3 : vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1> return %0, %1, %2, %3 : vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>
} }
@@ -55,7 +55,7 @@ func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>,
%b: vector<8xbf16>) -> (vector<4xf32>) %b: vector<8xbf16>) -> (vector<4xf32>)
{ {
// CHECK: llvm.call_intrinsic "llvm.x86.avx512bf16.dpbf16ps.128" // CHECK: llvm.call_intrinsic "llvm.x86.avx512bf16.dpbf16ps.128"
%0 = x86vector.avx512.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32> %0 = x86.avx512.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32> return %0 : vector<4xf32>
} }
@@ -64,7 +64,7 @@ func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>,
%b: vector<16xbf16>) -> (vector<8xf32>) %b: vector<16xbf16>) -> (vector<8xf32>)
{ {
// CHECK: llvm.call_intrinsic "llvm.x86.avx512bf16.dpbf16ps.256" // CHECK: llvm.call_intrinsic "llvm.x86.avx512bf16.dpbf16ps.256"
%0 = x86vector.avx512.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32> %0 = x86.avx512.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32> return %0 : vector<8xf32>
} }
@@ -73,7 +73,7 @@ func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
%b: vector<32xbf16>) -> (vector<16xf32>) %b: vector<32xbf16>) -> (vector<16xf32>)
{ {
// CHECK: llvm.call_intrinsic "llvm.x86.avx512bf16.dpbf16ps.512" // CHECK: llvm.call_intrinsic "llvm.x86.avx512bf16.dpbf16ps.512"
%0 = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32> %0 = x86.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
return %0 : vector<16xf32> return %0 : vector<16xf32>
} }
@@ -82,7 +82,7 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
%a: vector<8xf32>) -> (vector<8xbf16>) %a: vector<8xf32>) -> (vector<8xbf16>)
{ {
// CHECK: llvm.call_intrinsic "llvm.x86.avx512bf16.cvtneps2bf16.256" // CHECK: llvm.call_intrinsic "llvm.x86.avx512bf16.cvtneps2bf16.256"
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16> %0 = x86.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
return %0 : vector<8xbf16> return %0 : vector<8xbf16>
} }
@@ -91,7 +91,7 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
%a: vector<16xf32>) -> (vector<16xbf16>) %a: vector<16xf32>) -> (vector<16xbf16>)
{ {
// CHECK: llvm.call_intrinsic "llvm.x86.avx512bf16.cvtneps2bf16.512" // CHECK: llvm.call_intrinsic "llvm.x86.avx512bf16.cvtneps2bf16.512"
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16> %0 = x86.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
return %0 : vector<16xbf16> return %0 : vector<16xbf16>
} }
@@ -99,7 +99,7 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
func.func @avx10_dot_i8_512(%w: vector<16xi32>, %a: vector<64xi8>, func.func @avx10_dot_i8_512(%w: vector<16xi32>, %a: vector<64xi8>,
%b: vector<64xi8>) -> vector<16xi32> { %b: vector<64xi8>) -> vector<16xi32> {
// CHECK: llvm.call_intrinsic "llvm.x86.avx10.vpdpbssd.512" // CHECK: llvm.call_intrinsic "llvm.x86.avx10.vpdpbssd.512"
%0 = x86vector.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32> %0 = x86.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
return %0 : vector<16xi32> return %0 : vector<16xi32>
} }
@@ -108,7 +108,7 @@ func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
%a: memref<8xbf16>) -> vector<4xf32> %a: memref<8xbf16>) -> vector<4xf32>
{ {
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps128" // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps128"
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32> %0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32> return %0 : vector<4xf32>
} }
@@ -117,7 +117,7 @@ func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(
%a: memref<16xbf16>) -> vector<8xf32> %a: memref<16xbf16>) -> vector<8xf32>
{ {
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps256" // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps256"
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32> %0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32> return %0 : vector<8xf32>
} }
@@ -126,7 +126,7 @@ func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(
%a: memref<8xbf16>) -> vector<4xf32> %a: memref<8xbf16>) -> vector<4xf32>
{ {
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps128" // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps128"
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32> %0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32> return %0 : vector<4xf32>
} }
@@ -135,7 +135,7 @@ func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(
%a: memref<16xbf16>) -> vector<8xf32> %a: memref<16xbf16>) -> vector<8xf32>
{ {
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps256" // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps256"
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32> %0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32> return %0 : vector<8xf32>
} }
@@ -144,7 +144,7 @@ func.func @avxbf16_bsct_bf16_to_f32_packed_128(
%a: memref<1xbf16>) -> vector<4xf32> %a: memref<1xbf16>) -> vector<4xf32>
{ {
// CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps128" // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps128"
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<4xf32> %0 = x86.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
return %0 : vector<4xf32> return %0 : vector<4xf32>
} }
@@ -153,7 +153,7 @@ func.func @avxbf16_bsct_bf16_to_f32_packed_256(
%a: memref<1xbf16>) -> vector<8xf32> %a: memref<1xbf16>) -> vector<8xf32>
{ {
// CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps256" // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps256"
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32> %0 = x86.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
return %0 : vector<8xf32> return %0 : vector<8xf32>
} }
@@ -162,7 +162,7 @@ func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_128(
%a: memref<8xf16>) -> vector<4xf32> %a: memref<8xf16>) -> vector<4xf32>
{ {
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneeph2ps128" // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneeph2ps128"
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32> %0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
return %0 : vector<4xf32> return %0 : vector<4xf32>
} }
@@ -171,7 +171,7 @@ func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_256(
%a: memref<16xf16>) -> vector<8xf32> %a: memref<16xf16>) -> vector<8xf32>
{ {
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneeph2ps256" // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneeph2ps256"
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32> %0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
return %0 : vector<8xf32> return %0 : vector<8xf32>
} }
@@ -180,7 +180,7 @@ func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128(
%a: memref<8xf16>) -> vector<4xf32> %a: memref<8xf16>) -> vector<4xf32>
{ {
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneoph2ps128" // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneoph2ps128"
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32> %0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
return %0 : vector<4xf32> return %0 : vector<4xf32>
} }
@@ -189,7 +189,7 @@ func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256(
%a: memref<16xf16>) -> vector<8xf32> %a: memref<16xf16>) -> vector<8xf32>
{ {
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneoph2ps256" // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneoph2ps256"
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32> %0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
return %0 : vector<8xf32> return %0 : vector<8xf32>
} }
@@ -198,7 +198,7 @@ func.func @avxf16_bsct_f16_to_f32_packed_128(
%a: memref<1xf16>) -> vector<4xf32> %a: memref<1xf16>) -> vector<4xf32>
{ {
// CHECK: llvm.call_intrinsic "llvm.x86.vbcstnesh2ps128" // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnesh2ps128"
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<4xf32> %0 = x86.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
return %0 : vector<4xf32> return %0 : vector<4xf32>
} }
@@ -207,7 +207,7 @@ func.func @avxf16_bsct_f16_to_f32_packed_256(
%a: memref<1xf16>) -> vector<8xf32> %a: memref<1xf16>) -> vector<8xf32>
{ {
// CHECK: llvm.call_intrinsic "llvm.x86.vbcstnesh2ps256" // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnesh2ps256"
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32> %0 = x86.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
return %0 : vector<8xf32> return %0 : vector<8xf32>
} }
@@ -215,7 +215,7 @@ func.func @avxf16_bsct_f16_to_f32_packed_256(
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>) func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
{ {
// CHECK: llvm.call_intrinsic "llvm.x86.avx.rsqrt.ps.256" // CHECK: llvm.call_intrinsic "llvm.x86.avx.rsqrt.ps.256"
%0 = x86vector.avx.rsqrt %a : vector<8xf32> %0 = x86.avx.rsqrt %a : vector<8xf32>
return %0 : vector<8xf32> return %0 : vector<8xf32>
} }
@@ -224,7 +224,7 @@ func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>)
{ {
// CHECK: llvm.mlir.constant(-1 : i8) // CHECK: llvm.mlir.constant(-1 : i8)
// CHECK: llvm.call_intrinsic "llvm.x86.avx.dp.ps.256" // CHECK: llvm.call_intrinsic "llvm.x86.avx.dp.ps.256"
%0 = x86vector.avx.intr.dot %a, %b : vector<8xf32> %0 = x86.avx.intr.dot %a, %b : vector<8xf32>
return %0 : vector<8xf32> return %0 : vector<8xf32>
} }
@@ -232,7 +232,7 @@ func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>)
func.func @avx_dot_i8_128(%w: vector<4xi32>, %a: vector<16xi8>, func.func @avx_dot_i8_128(%w: vector<4xi32>, %a: vector<16xi8>,
%b: vector<16xi8>) -> vector<4xi32> { %b: vector<16xi8>) -> vector<4xi32> {
// CHECK: llvm.call_intrinsic "llvm.x86.avx2.vpdpbssd.128" // CHECK: llvm.call_intrinsic "llvm.x86.avx2.vpdpbssd.128"
%0 = x86vector.avx.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32> %0 = x86.avx.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
return %0 : vector<4xi32> return %0 : vector<4xi32>
} }
@@ -240,6 +240,6 @@ func.func @avx_dot_i8_128(%w: vector<4xi32>, %a: vector<16xi8>,
func.func @avx_dot_i8_256(%w: vector<8xi32>, %a: vector<32xi8>, func.func @avx_dot_i8_256(%w: vector<8xi32>, %a: vector<32xi8>,
%b: vector<32xi8>) -> vector<8xi32> { %b: vector<32xi8>) -> vector<8xi32> {
// CHECK: llvm.call_intrinsic "llvm.x86.avx2.vpdpbssd.256" // CHECK: llvm.call_intrinsic "llvm.x86.avx2.vpdpbssd.256"
%0 = x86vector.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32> %0 = x86.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
return %0 : vector<8xi32> return %0 : vector<8xi32>
} }


@@ -4,10 +4,10 @@
func.func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16, %i8: i8) func.func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16, %i8: i8)
-> (vector<16xf32>, vector<8xf64>) -> (vector<16xf32>, vector<8xf64>)
{ {
// CHECK: x86vector.avx512.mask.rndscale {{.*}}: vector<16xf32> // CHECK: x86.avx512.mask.rndscale {{.*}}: vector<16xf32>
%0 = x86vector.avx512.mask.rndscale %a, %i32, %a, %i16, %i32 : vector<16xf32> %0 = x86.avx512.mask.rndscale %a, %i32, %a, %i16, %i32 : vector<16xf32>
// CHECK: x86vector.avx512.mask.rndscale {{.*}}: vector<8xf64> // CHECK: x86.avx512.mask.rndscale {{.*}}: vector<8xf64>
%1 = x86vector.avx512.mask.rndscale %b, %i32, %b, %i8, %i32 : vector<8xf64> %1 = x86.avx512.mask.rndscale %b, %i32, %b, %i8, %i32 : vector<8xf64>
return %0, %1: vector<16xf32>, vector<8xf64> return %0, %1: vector<16xf32>, vector<8xf64>
} }
@@ -15,10 +15,10 @@ func.func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32
func.func @avx512_scalef(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16, %i8: i8) func.func @avx512_scalef(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16, %i8: i8)
-> (vector<16xf32>, vector<8xf64>) -> (vector<16xf32>, vector<8xf64>)
{ {
// CHECK: x86vector.avx512.mask.scalef {{.*}}: vector<16xf32> // CHECK: x86.avx512.mask.scalef {{.*}}: vector<16xf32>
%0 = x86vector.avx512.mask.scalef %a, %a, %a, %i16, %i32: vector<16xf32> %0 = x86.avx512.mask.scalef %a, %a, %a, %i16, %i32: vector<16xf32>
// CHECK: x86vector.avx512.mask.scalef {{.*}}: vector<8xf64> // CHECK: x86.avx512.mask.scalef {{.*}}: vector<8xf64>
%1 = x86vector.avx512.mask.scalef %b, %b, %b, %i8, %i32 : vector<8xf64> %1 = x86.avx512.mask.scalef %b, %b, %b, %i8, %i32 : vector<8xf64>
return %0, %1: vector<16xf32>, vector<8xf64> return %0, %1: vector<16xf32>, vector<8xf64>
} }
@@ -26,10 +26,10 @@ func.func @avx512_scalef(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16:
func.func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>) func.func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>)
-> (vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>) -> (vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>)
{ {
// CHECK: x86vector.avx512.vp2intersect {{.*}} : vector<16xi32> // CHECK: x86.avx512.vp2intersect {{.*}} : vector<16xi32>
%0, %1 = x86vector.avx512.vp2intersect %a, %a : vector<16xi32> %0, %1 = x86.avx512.vp2intersect %a, %a : vector<16xi32>
// CHECK: x86vector.avx512.vp2intersect {{.*}} : vector<8xi64> // CHECK: x86.avx512.vp2intersect {{.*}} : vector<8xi64>
%2, %3 = x86vector.avx512.vp2intersect %b, %b : vector<8xi64> %2, %3 = x86.avx512.vp2intersect %b, %b : vector<8xi64>
return %0, %1, %2, %3 : vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1> return %0, %1, %2, %3 : vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>
} }
@@ -38,12 +38,12 @@ func.func @avx512_mask_compress(%k1: vector<16xi1>, %a1: vector<16xf32>,
%k2: vector<8xi1>, %a2: vector<8xi64>) %k2: vector<8xi1>, %a2: vector<8xi64>)
-> (vector<16xf32>, vector<16xf32>, vector<8xi64>) -> (vector<16xf32>, vector<16xf32>, vector<8xi64>)
{ {
// CHECK: x86vector.avx512.mask.compress {{.*}} : vector<16xf32> // CHECK: x86.avx512.mask.compress {{.*}} : vector<16xf32>
%0 = x86vector.avx512.mask.compress %k1, %a1 : vector<16xf32> %0 = x86.avx512.mask.compress %k1, %a1 : vector<16xf32>
// CHECK: x86vector.avx512.mask.compress {{.*}} : vector<16xf32> // CHECK: x86.avx512.mask.compress {{.*}} : vector<16xf32>
%1 = x86vector.avx512.mask.compress %k1, %a1 {constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32> %1 = x86.avx512.mask.compress %k1, %a1 {constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32>
// CHECK: x86vector.avx512.mask.compress {{.*}} : vector<8xi64> // CHECK: x86.avx512.mask.compress {{.*}} : vector<8xi64>
%2 = x86vector.avx512.mask.compress %k2, %a2, %a2 : vector<8xi64>, vector<8xi64> %2 = x86.avx512.mask.compress %k2, %a2, %a2 : vector<8xi64>, vector<8xi64>
return %0, %1, %2 : vector<16xf32>, vector<16xf32>, vector<8xi64> return %0, %1, %2 : vector<16xf32>, vector<16xf32>, vector<8xi64>
} }
@@ -51,8 +51,8 @@ func.func @avx512_mask_compress(%k1: vector<16xi1>, %a1: vector<16xf32>,
func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>, func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>,
%b: vector<8xbf16>) -> (vector<4xf32>) %b: vector<8xbf16>) -> (vector<4xf32>)
{ {
// CHECK: x86vector.avx512.dot {{.*}} : vector<8xbf16> -> vector<4xf32> // CHECK: x86.avx512.dot {{.*}} : vector<8xbf16> -> vector<4xf32>
%0 = x86vector.avx512.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32> %0 = x86.avx512.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32> return %0 : vector<4xf32>
} }
@@ -60,8 +60,8 @@ func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>,
func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>, func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>,
%b: vector<16xbf16>) -> (vector<8xf32>) %b: vector<16xbf16>) -> (vector<8xf32>)
{ {
// CHECK: x86vector.avx512.dot {{.*}} : vector<16xbf16> -> vector<8xf32> // CHECK: x86.avx512.dot {{.*}} : vector<16xbf16> -> vector<8xf32>
%0 = x86vector.avx512.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32> %0 = x86.avx512.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32> return %0 : vector<8xf32>
} }
@@ -69,8 +69,8 @@ func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>,
func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>, func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
%b: vector<32xbf16>) -> (vector<16xf32>) %b: vector<32xbf16>) -> (vector<16xf32>)
{ {
// CHECK: x86vector.avx512.dot {{.*}} : vector<32xbf16> -> vector<16xf32> // CHECK: x86.avx512.dot {{.*}} : vector<32xbf16> -> vector<16xf32>
%0 = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32> %0 = x86.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
return %0 : vector<16xf32> return %0 : vector<16xf32>
} }
@@ -78,9 +78,9 @@ func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
func.func @avx512bf16_cvt_packed_f32_to_bf16_256( func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
%a: vector<8xf32>) -> (vector<8xbf16>) %a: vector<8xf32>) -> (vector<8xbf16>)
{ {
// CHECK: x86vector.avx512.cvt.packed.f32_to_bf16 {{.*}} : // CHECK: x86.avx512.cvt.packed.f32_to_bf16 {{.*}} :
// CHECK-SAME: vector<8xf32> -> vector<8xbf16> // CHECK-SAME: vector<8xf32> -> vector<8xbf16>
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16> %0 = x86.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
return %0 : vector<8xbf16> return %0 : vector<8xbf16>
} }
@@ -88,17 +88,17 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
func.func @avx512bf16_cvt_packed_f32_to_bf16_512( func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
%a: vector<16xf32>) -> (vector<16xbf16>) %a: vector<16xf32>) -> (vector<16xbf16>)
{ {
// CHECK: x86vector.avx512.cvt.packed.f32_to_bf16 {{.*}} : // CHECK: x86.avx512.cvt.packed.f32_to_bf16 {{.*}} :
// CHECK-SAME: vector<16xf32> -> vector<16xbf16> // CHECK-SAME: vector<16xf32> -> vector<16xbf16>
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16> %0 = x86.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
return %0 : vector<16xbf16> return %0 : vector<16xbf16>
} }
// CHECK-LABEL: func @avx10_dot_i8_512 // CHECK-LABEL: func @avx10_dot_i8_512
func.func @avx10_dot_i8_512(%w: vector<16xi32>, %a: vector<64xi8>, func.func @avx10_dot_i8_512(%w: vector<16xi32>, %a: vector<64xi8>,
%b: vector<64xi8>) -> vector<16xi32> { %b: vector<64xi8>) -> vector<16xi32> {
// CHECK: x86vector.avx10.dot.i8 {{.*}} : vector<64xi8> -> vector<16xi32> // CHECK: x86.avx10.dot.i8 {{.*}} : vector<64xi8> -> vector<16xi32>
%0 = x86vector.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32> %0 = x86.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
return %0 : vector<16xi32> return %0 : vector<16xi32>
} }
@@ -106,9 +106,9 @@ func.func @avx10_dot_i8_512(%w: vector<16xi32>, %a: vector<64xi8>,
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128( func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
%a: memref<8xbf16>) -> vector<4xf32> %a: memref<8xbf16>) -> vector<4xf32>
{ {
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} : // CHECK: x86.avx.cvt.packed.even.indexed_to_f32 {{.*}} :
// CHECK-SAME: memref<8xbf16> -> vector<4xf32> // CHECK-SAME: memref<8xbf16> -> vector<4xf32>
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32> %0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32> return %0 : vector<4xf32>
} }
@@ -116,9 +116,9 @@ func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256( func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(
%a: memref<16xbf16>) -> vector<8xf32> %a: memref<16xbf16>) -> vector<8xf32>
{ {
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} : // CHECK: x86.avx.cvt.packed.even.indexed_to_f32 {{.*}} :
// CHECK-SAME: memref<16xbf16> -> vector<8xf32> // CHECK-SAME: memref<16xbf16> -> vector<8xf32>
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32> %0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32> return %0 : vector<8xf32>
} }
@@ -126,9 +126,9 @@ func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(
func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128( func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(
%a: memref<8xbf16>) -> vector<4xf32> %a: memref<8xbf16>) -> vector<4xf32>
{ {
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : // CHECK: x86.avx.cvt.packed.odd.indexed_to_f32 {{.*}} :
// CHECK-SAME: memref<8xbf16> -> vector<4xf32> // CHECK-SAME: memref<8xbf16> -> vector<4xf32>
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32> %0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32> return %0 : vector<4xf32>
} }
@@ -136,9 +136,9 @@ func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(
func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256( func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(
%a: memref<16xbf16>) -> vector<8xf32> %a: memref<16xbf16>) -> vector<8xf32>
{ {
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : // CHECK: x86.avx.cvt.packed.odd.indexed_to_f32 {{.*}} :
// CHECK-SAME: memref<16xbf16> -> vector<8xf32> // CHECK-SAME: memref<16xbf16> -> vector<8xf32>
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32> %0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32> return %0 : vector<8xf32>
} }
@@ -146,9 +146,9 @@ func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(
func.func @avxbf16_bcst_bf16_to_f32_128( func.func @avxbf16_bcst_bf16_to_f32_128(
%a: memref<1xbf16>) -> vector<4xf32> %a: memref<1xbf16>) -> vector<4xf32>
{ {
// CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} : // CHECK: x86.avx.bcst_to_f32.packed {{.*}} :
// CHECK-SAME: memref<1xbf16> -> vector<4xf32> // CHECK-SAME: memref<1xbf16> -> vector<4xf32>
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<4xf32> %0 = x86.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
return %0 : vector<4xf32> return %0 : vector<4xf32>
} }
@@ -156,9 +156,9 @@ func.func @avxbf16_bcst_bf16_to_f32_128(
func.func @avxbf16_bcst_bf16_to_f32_256( func.func @avxbf16_bcst_bf16_to_f32_256(
%a: memref<1xbf16>) -> vector<8xf32> %a: memref<1xbf16>) -> vector<8xf32>
{ {
// CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} : // CHECK: x86.avx.bcst_to_f32.packed {{.*}} :
// CHECK-SAME: memref<1xbf16> -> vector<8xf32> // CHECK-SAME: memref<1xbf16> -> vector<8xf32>
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32> %0 = x86.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
return %0 : vector<8xf32> return %0 : vector<8xf32>
} }
@@ -166,9 +166,9 @@ func.func @avxbf16_bcst_bf16_to_f32_256(
func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_128( func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_128(
%a: memref<8xf16>) -> vector<4xf32> %a: memref<8xf16>) -> vector<4xf32>
{ {
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} : // CHECK: x86.avx.cvt.packed.even.indexed_to_f32 {{.*}} :
// CHECK-SAME: memref<8xf16> -> vector<4xf32> // CHECK-SAME: memref<8xf16> -> vector<4xf32>
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32> %0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
return %0 : vector<4xf32> return %0 : vector<4xf32>
} }
@@ -176,9 +176,9 @@ func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_128(
func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_256( func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_256(
%a: memref<16xf16>) -> vector<8xf32> %a: memref<16xf16>) -> vector<8xf32>
{ {
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} : // CHECK: x86.avx.cvt.packed.even.indexed_to_f32 {{.*}} :
// CHECK-SAME: memref<16xf16> -> vector<8xf32> // CHECK-SAME: memref<16xf16> -> vector<8xf32>
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32> %0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
return %0 : vector<8xf32> return %0 : vector<8xf32>
} }
@@ -186,9 +186,9 @@ func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_256(
func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128( func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128(
%a: memref<8xf16>) -> vector<4xf32> %a: memref<8xf16>) -> vector<4xf32>
{ {
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : // CHECK: x86.avx.cvt.packed.odd.indexed_to_f32 {{.*}} :
// CHECK-SAME: memref<8xf16> -> vector<4xf32> // CHECK-SAME: memref<8xf16> -> vector<4xf32>
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32> %0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
return %0 : vector<4xf32> return %0 : vector<4xf32>
} }
@@ -196,9 +196,9 @@ func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128(
func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256( func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256(
%a: memref<16xf16>) -> vector<8xf32> %a: memref<16xf16>) -> vector<8xf32>
{ {
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : // CHECK: x86.avx.cvt.packed.odd.indexed_to_f32 {{.*}} :
// CHECK-SAME: memref<16xf16> -> vector<8xf32> // CHECK-SAME: memref<16xf16> -> vector<8xf32>
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32> %0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
return %0 : vector<8xf32> return %0 : vector<8xf32>
} }
@@ -206,9 +206,9 @@ func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256(
func.func @avxf16_bcst_f16_to_f32_128( func.func @avxf16_bcst_f16_to_f32_128(
%a: memref<1xf16>) -> vector<4xf32> %a: memref<1xf16>) -> vector<4xf32>
{ {
// CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} : // CHECK: x86.avx.bcst_to_f32.packed {{.*}} :
// CHECK-SAME: memref<1xf16> -> vector<4xf32> // CHECK-SAME: memref<1xf16> -> vector<4xf32>
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<4xf32> %0 = x86.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
return %0 : vector<4xf32> return %0 : vector<4xf32>
} }
@@ -216,40 +216,40 @@ func.func @avxf16_bcst_f16_to_f32_128(
func.func @avxf16_bcst_f16_to_f32_256( func.func @avxf16_bcst_f16_to_f32_256(
%a: memref<1xf16>) -> vector<8xf32> %a: memref<1xf16>) -> vector<8xf32>
{ {
// CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} : // CHECK: x86.avx.bcst_to_f32.packed {{.*}} :
// CHECK-SAME: memref<1xf16> -> vector<8xf32> // CHECK-SAME: memref<1xf16> -> vector<8xf32>
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32> %0 = x86.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
return %0 : vector<8xf32> return %0 : vector<8xf32>
} }
// CHECK-LABEL: func @avx_rsqrt // CHECK-LABEL: func @avx_rsqrt
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>) func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
{ {
// CHECK: x86vector.avx.rsqrt {{.*}} : vector<8xf32> // CHECK: x86.avx.rsqrt {{.*}} : vector<8xf32>
%0 = x86vector.avx.rsqrt %a : vector<8xf32> %0 = x86.avx.rsqrt %a : vector<8xf32>
return %0 : vector<8xf32> return %0 : vector<8xf32>
} }
// CHECK-LABEL: func @avx_dot // CHECK-LABEL: func @avx_dot
func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>) func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>)
{ {
// CHECK: x86vector.avx.intr.dot {{.*}} : vector<8xf32> // CHECK: x86.avx.intr.dot {{.*}} : vector<8xf32>
%0 = x86vector.avx.intr.dot %a, %b : vector<8xf32> %0 = x86.avx.intr.dot %a, %b : vector<8xf32>
return %0 : vector<8xf32> return %0 : vector<8xf32>
} }
// CHECK-LABEL: func @avx_dot_i8_128 // CHECK-LABEL: func @avx_dot_i8_128
func.func @avx_dot_i8_128(%w: vector<4xi32>, %a: vector<16xi8>, func.func @avx_dot_i8_128(%w: vector<4xi32>, %a: vector<16xi8>,
%b: vector<16xi8>) -> vector<4xi32> { %b: vector<16xi8>) -> vector<4xi32> {
// CHECK: x86vector.avx.dot.i8 {{.*}} : vector<16xi8> -> vector<4xi32> // CHECK: x86.avx.dot.i8 {{.*}} : vector<16xi8> -> vector<4xi32>
%0 = x86vector.avx.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32> %0 = x86.avx.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
return %0 : vector<4xi32> return %0 : vector<4xi32>
} }
// CHECK-LABEL: func @avx_dot_i8_256 // CHECK-LABEL: func @avx_dot_i8_256
func.func @avx_dot_i8_256(%w: vector<8xi32>, %a: vector<32xi8>, func.func @avx_dot_i8_256(%w: vector<8xi32>, %a: vector<32xi8>,
%b: vector<32xi8>) -> vector<8xi32> { %b: vector<32xi8>) -> vector<8xi32> {
// CHECK: x86vector.avx.dot.i8 {{.*}} : vector<32xi8> -> vector<8xi32> // CHECK: x86.avx.dot.i8 {{.*}} : vector<32xi8> -> vector<8xi32>
%0 = x86vector.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32> %0 = x86.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
return %0 : vector<8xi32> return %0 : vector<8xi32>
} }


@@ -8,17 +8,17 @@ func.func @shuffle_fma_with_rhs_as_even.index_to_f32(
%arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA, %arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
%arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vec %arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vec
{ {
%0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec %0 = x86.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
%1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec %1 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
%2 = vector.fma %0, %1, %arg6 : !vec %2 = vector.fma %0, %1, %arg6 : !vec
%3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec %3 = x86.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
%4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec %4 = x86.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
%5 = vector.fma %3, %4, %2 : !vec %5 = vector.fma %3, %4, %2 : !vec
%6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec %6 = x86.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
%7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec %7 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
%8 = vector.fma %6, %7, %arg6 : !vec %8 = vector.fma %6, %7, %arg6 : !vec
%9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec %9 = x86.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
%10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec %10 = x86.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
%11 = vector.fma %9, %10, %8 : !vec %11 = vector.fma %9, %10, %8 : !vec
%12 = vector.fma %5, %11, %arg6 : !vec %12 = vector.fma %5, %11, %arg6 : !vec
return %12 : !vec return %12 : !vec
@@ -28,25 +28,25 @@ func.func @shuffle_fma_with_rhs_as_even.index_to_f32(
// The vector.fma at %5 is moved along with its operands after %8. // The vector.fma at %5 is moved along with its operands after %8.
// CHECK-LABEL: @shuffle_fma_with_rhs_as_even.index_to_f32 // CHECK-LABEL: @shuffle_fma_with_rhs_as_even.index_to_f32
// Odd-Indexed FMAs // Odd-Indexed FMAs
// CHECK: %[[BCST0:.*]] = x86vector.avx.bcst_to_f32.packed %arg0 // CHECK: %[[BCST0:.*]] = x86.avx.bcst_to_f32.packed %arg0
// CHECK: %[[ODD0:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 // CHECK: %[[ODD0:.*]] = x86.avx.cvt.packed.odd.indexed_to_f32 %arg2
// CHECK: %[[FMA_ODD0:.*]] = vector.fma %[[BCST0]], %[[ODD0]], %arg6 // CHECK: %[[FMA_ODD0:.*]] = vector.fma %[[BCST0]], %[[ODD0]], %arg6
// CHECK: %[[BCST1:.*]] = x86vector.avx.bcst_to_f32.packed %arg3 // CHECK: %[[BCST1:.*]] = x86.avx.bcst_to_f32.packed %arg3
// CHECK: %[[ODD1:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 // CHECK: %[[ODD1:.*]] = x86.avx.cvt.packed.odd.indexed_to_f32 %arg5
// CHECK: %[[FMA_ODD1:.*]] = vector.fma %[[BCST1]], %[[ODD1]], %arg6 // CHECK: %[[FMA_ODD1:.*]] = vector.fma %[[BCST1]], %[[ODD1]], %arg6
// Even-Indexed FMAs // Even-Indexed FMAs
// CHECK: %[[BCST2:.*]] = x86vector.avx.bcst_to_f32.packed %arg4 // CHECK: %[[BCST2:.*]] = x86.avx.bcst_to_f32.packed %arg4
// CHECK: %[[EVEN0:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 // CHECK: %[[EVEN0:.*]] = x86.avx.cvt.packed.even.indexed_to_f32 %arg5
// CHECK: %[[FMA_EVEN0:.*]] = vector.fma %[[BCST2]], %[[EVEN0]], %[[FMA_ODD1]] // CHECK: %[[FMA_EVEN0:.*]] = vector.fma %[[BCST2]], %[[EVEN0]], %[[FMA_ODD1]]
// CHECK: %[[BCST3:.*]] = x86vector.avx.bcst_to_f32.packed %arg1 // CHECK: %[[BCST3:.*]] = x86.avx.bcst_to_f32.packed %arg1
// CHECK: %[[EVEN1:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 // CHECK: %[[EVEN1:.*]] = x86.avx.cvt.packed.even.indexed_to_f32 %arg2
// CHECK: %[[FMA_EVEN1:.*]] = vector.fma %[[BCST3]], %[[EVEN1]], %[[FMA_ODD0]] // CHECK: %[[FMA_EVEN1:.*]] = vector.fma %[[BCST3]], %[[EVEN1]], %[[FMA_ODD0]]
module attributes {transform.with_named_sequence} { module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %0 { transform.apply_patterns to %0 {
transform.apply_patterns.x86vector.shuffle_vector_fma_ops transform.apply_patterns.x86.shuffle_vector_fma_ops
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -62,17 +62,17 @@ func.func @shuffle_fma_with_lhs_as_even.index_to_f32(
%arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA, %arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
%arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vec %arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vec
{ {
%0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec %0 = x86.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
%1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec %1 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
%2 = vector.fma %0, %1, %arg6 : !vec %2 = vector.fma %0, %1, %arg6 : !vec
%3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec %3 = x86.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
%4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec %4 = x86.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
%5 = vector.fma %4, %3, %2 : !vec %5 = vector.fma %4, %3, %2 : !vec
%6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec %6 = x86.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
%7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec %7 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
%8 = vector.fma %6, %7, %arg6 : !vec %8 = vector.fma %6, %7, %arg6 : !vec
%9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec %9 = x86.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
%10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec %10 = x86.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
%11 = vector.fma %9, %10, %8 : !vec %11 = vector.fma %9, %10, %8 : !vec
%12 = vector.fma %5, %11, %arg6 : !vec %12 = vector.fma %5, %11, %arg6 : !vec
return %12 : !vec return %12 : !vec
@@ -81,25 +81,25 @@ func.func @shuffle_fma_with_lhs_as_even.index_to_f32(
// The vector.fma at %5 is moved along with its operands after %8. // The vector.fma at %5 is moved along with its operands after %8.
// CHECK-LABEL: @shuffle_fma_with_lhs_as_even.index_to_f32 // CHECK-LABEL: @shuffle_fma_with_lhs_as_even.index_to_f32
// Odd-Indexed FMAs // Odd-Indexed FMAs
// CHECK: %[[BCST0:.*]] = x86vector.avx.bcst_to_f32.packed %arg0 // CHECK: %[[BCST0:.*]] = x86.avx.bcst_to_f32.packed %arg0
// CHECK: %[[ODD0:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 // CHECK: %[[ODD0:.*]] = x86.avx.cvt.packed.odd.indexed_to_f32 %arg2
// CHECK: %[[FMA_ODD0:.*]] = vector.fma %[[BCST0]], %[[ODD0]], %arg6 // CHECK: %[[FMA_ODD0:.*]] = vector.fma %[[BCST0]], %[[ODD0]], %arg6
// CHECK: %[[BCST1:.*]] = x86vector.avx.bcst_to_f32.packed %arg3 // CHECK: %[[BCST1:.*]] = x86.avx.bcst_to_f32.packed %arg3
// CHECK: %[[ODD1:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 // CHECK: %[[ODD1:.*]] = x86.avx.cvt.packed.odd.indexed_to_f32 %arg5
// CHECK: %[[FMA_ODD1:.*]] = vector.fma %[[BCST1]], %[[ODD1]], %arg6 // CHECK: %[[FMA_ODD1:.*]] = vector.fma %[[BCST1]], %[[ODD1]], %arg6
// Even-Indexed FMAs // Even-Indexed FMAs
// CHECK: %[[BCST2:.*]] = x86vector.avx.bcst_to_f32.packed %arg4 // CHECK: %[[BCST2:.*]] = x86.avx.bcst_to_f32.packed %arg4
// CHECK: %[[EVEN0:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 // CHECK: %[[EVEN0:.*]] = x86.avx.cvt.packed.even.indexed_to_f32 %arg5
// CHECK: %[[FMA_EVEN0:.*]] = vector.fma %[[BCST2]], %[[EVEN0]], %[[FMA_ODD1]] // CHECK: %[[FMA_EVEN0:.*]] = vector.fma %[[BCST2]], %[[EVEN0]], %[[FMA_ODD1]]
// CHECK: %[[EVEN1:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 // CHECK: %[[EVEN1:.*]] = x86.avx.cvt.packed.even.indexed_to_f32 %arg2
// CHECK: %[[BCST3:.*]] = x86vector.avx.bcst_to_f32.packed %arg1 // CHECK: %[[BCST3:.*]] = x86.avx.bcst_to_f32.packed %arg1
// CHECK: %[[FMA_EVEN1:.*]] = vector.fma %[[EVEN1]], %[[BCST3]], %[[FMA_ODD0]] // CHECK: %[[FMA_EVEN1:.*]] = vector.fma %[[EVEN1]], %[[BCST3]], %[[FMA_ODD0]]
module attributes {transform.with_named_sequence} { module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %0 { transform.apply_patterns to %0 {
transform.apply_patterns.x86vector.shuffle_vector_fma_ops transform.apply_patterns.x86.shuffle_vector_fma_ops
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -116,18 +116,18 @@ func.func @shuffle_fma_with_shape_cast(
%arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA, %arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
%arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vecOut %arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vecOut
{ {
%0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec %0 = x86.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
%1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec %1 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
%2 = vector.fma %0, %1, %arg6 : !vec %2 = vector.fma %0, %1, %arg6 : !vec
%3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec %3 = x86.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
%4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec %4 = x86.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
%5 = vector.fma %3, %4, %2 : !vec %5 = vector.fma %3, %4, %2 : !vec
%res1 = vector.shape_cast %5 : !vec to !vecOut %res1 = vector.shape_cast %5 : !vec to !vecOut
%6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec %6 = x86.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
%7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec %7 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
%8 = vector.fma %6, %7, %arg6 : !vec %8 = vector.fma %6, %7, %arg6 : !vec
%9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec %9 = x86.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
%10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec %10 = x86.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
%11 = vector.fma %9, %10, %8 : !vec %11 = vector.fma %9, %10, %8 : !vec
%res2 = vector.shape_cast %11 : !vec to !vecOut %res2 = vector.shape_cast %11 : !vec to !vecOut
%12 = arith.addf %res1, %res2 : !vecOut %12 = arith.addf %res1, %res2 : !vecOut
@@ -136,19 +136,19 @@ func.func @shuffle_fma_with_shape_cast(
// CHECK-LABEL: @shuffle_fma_with_shape_cast // CHECK-LABEL: @shuffle_fma_with_shape_cast
// Odd-Indexed FMAs // Odd-Indexed FMAs
// CHECK: %[[BCST0:.*]] = x86vector.avx.bcst_to_f32.packed %arg0 // CHECK: %[[BCST0:.*]] = x86.avx.bcst_to_f32.packed %arg0
// CHECK: %[[ODD0:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 // CHECK: %[[ODD0:.*]] = x86.avx.cvt.packed.odd.indexed_to_f32 %arg2
// CHECK: %[[FMA_ODD0:.*]] = vector.fma %[[BCST0]], %[[ODD0]], %arg6 // CHECK: %[[FMA_ODD0:.*]] = vector.fma %[[BCST0]], %[[ODD0]], %arg6
// CHECK: %[[BCST1:.*]] = x86vector.avx.bcst_to_f32.packed %arg3 // CHECK: %[[BCST1:.*]] = x86.avx.bcst_to_f32.packed %arg3
// CHECK: %[[ODD1:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 // CHECK: %[[ODD1:.*]] = x86.avx.cvt.packed.odd.indexed_to_f32 %arg5
// CHECK: %[[FMA_ODD1:.*]] = vector.fma %[[BCST1]], %[[ODD1]], %arg6 // CHECK: %[[FMA_ODD1:.*]] = vector.fma %[[BCST1]], %[[ODD1]], %arg6
// Even-Indexed FMAs // Even-Indexed FMAs
// CHECK: %[[BCST3:.*]] = x86vector.avx.bcst_to_f32.packed %arg4 // CHECK: %[[BCST3:.*]] = x86.avx.bcst_to_f32.packed %arg4
// CHECK: %[[EVEN1:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 // CHECK: %[[EVEN1:.*]] = x86.avx.cvt.packed.even.indexed_to_f32 %arg5
// CHECK: %[[FMA_EVEN1:.*]] = vector.fma %[[BCST3]], %[[EVEN1]], %[[FMA_ODD1]] // CHECK: %[[FMA_EVEN1:.*]] = vector.fma %[[BCST3]], %[[EVEN1]], %[[FMA_ODD1]]
// CHECK: vector.shape_cast // CHECK: vector.shape_cast
// CHECK: %[[BCST2:.*]] = x86vector.avx.bcst_to_f32.packed %arg1 // CHECK: %[[BCST2:.*]] = x86.avx.bcst_to_f32.packed %arg1
// CHECK: %[[EVEN0:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 // CHECK: %[[EVEN0:.*]] = x86.avx.cvt.packed.even.indexed_to_f32 %arg2
// CHECK: %[[FMA_EVEN0:.*]] = vector.fma %[[BCST2]], %[[EVEN0]], %[[FMA_ODD0]] // CHECK: %[[FMA_EVEN0:.*]] = vector.fma %[[BCST2]], %[[EVEN0]], %[[FMA_ODD0]]
// CHECK: vector.shape_cast // CHECK: vector.shape_cast
@@ -156,7 +156,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %0 { transform.apply_patterns to %0 {
transform.apply_patterns.x86vector.shuffle_vector_fma_ops transform.apply_patterns.x86.shuffle_vector_fma_ops
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -172,16 +172,16 @@ func.func @negative_fma_operand_has_multiple_consumer(
%arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB,
%arg3: !memrefA, %arg4: !memrefB, %arg5: !vec) -> !vec %arg3: !memrefA, %arg4: !memrefB, %arg5: !vec) -> !vec
{ {
%0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec %0 = x86.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
%1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec %1 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
%2 = vector.fma %0, %1, %arg5 : !vec %2 = vector.fma %0, %1, %arg5 : !vec
%3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec %3 = x86.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
%4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec %4 = x86.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
%5 = vector.fma %3, %4, %2 : !vec %5 = vector.fma %3, %4, %2 : !vec
%7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg4 : !memrefB -> !vec %7 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg4 : !memrefB -> !vec
%8 = vector.fma %3, %7, %arg5 : !vec %8 = vector.fma %3, %7, %arg5 : !vec
%9 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec %9 = x86.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
%10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg4 : !memrefB -> !vec %10 = x86.avx.cvt.packed.even.indexed_to_f32 %arg4 : !memrefB -> !vec
%11 = vector.fma %9, %10, %8 : !vec %11 = vector.fma %9, %10, %8 : !vec
%12 = vector.fma %5, %11, %arg5 : !vec %12 = vector.fma %5, %11, %arg5 : !vec
return %12 : !vec return %12 : !vec
@@ -190,18 +190,18 @@ func.func @negative_fma_operand_has_multiple_consumer(
// The vector.fma at %5 uses %3 as its LHS operand, which has two consumers; therefore, // The vector.fma at %5 uses %3 as its LHS operand, which has two consumers; therefore,
// the rewrite is not applied. // the rewrite is not applied.
// CHECK-LABEL: @negative_fma_operand_has_multiple_consumer // CHECK-LABEL: @negative_fma_operand_has_multiple_consumer
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 // CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
// CHECK: vector.fma // CHECK: vector.fma
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 // CHECK: x86.avx.cvt.packed.even.indexed_to_f32
// CHECK: vector.fma // CHECK: vector.fma
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 // CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
// CHECK: vector.fma // CHECK: vector.fma
module attributes {transform.with_named_sequence} { module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %0 { transform.apply_patterns to %0 {
transform.apply_patterns.x86vector.shuffle_vector_fma_ops transform.apply_patterns.x86.shuffle_vector_fma_ops
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -217,17 +217,17 @@ func.func @negative_fma_has_multiple_consumer(
%arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA, %arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
%arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vec %arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vec
{ {
%0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec %0 = x86.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
%1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec %1 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
%2 = vector.fma %0, %1, %arg6 : !vec %2 = vector.fma %0, %1, %arg6 : !vec
%3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec %3 = x86.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
%4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec %4 = x86.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
%5 = vector.fma %3, %4, %2 : !vec %5 = vector.fma %3, %4, %2 : !vec
%6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec %6 = x86.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
%7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec %7 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
%8 = vector.fma %6, %7, %5 : !vec %8 = vector.fma %6, %7, %5 : !vec
%9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec %9 = x86.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
%10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec %10 = x86.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
%11 = vector.fma %9, %10, %8 : !vec %11 = vector.fma %9, %10, %8 : !vec
%12 = vector.fma %5, %11, %arg6 : !vec %12 = vector.fma %5, %11, %arg6 : !vec
return %12 : !vec return %12 : !vec
@@ -235,17 +235,17 @@ func.func @negative_fma_has_multiple_consumer(
// vector.fma at %5 has two uses; therefore no re-write applied. // vector.fma at %5 has two uses; therefore no re-write applied.
// CHECK-LABEL: @negative_fma_has_multiple_consumer // CHECK-LABEL: @negative_fma_has_multiple_consumer
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 // CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
// CHECK: vector.fma // CHECK: vector.fma
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 // CHECK: x86.avx.cvt.packed.even.indexed_to_f32
// CHECK: vector.fma // CHECK: vector.fma
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 // CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
module attributes {transform.with_named_sequence} { module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %0 { transform.apply_patterns to %0 {
transform.apply_patterns.x86vector.shuffle_vector_fma_ops transform.apply_patterns.x86.shuffle_vector_fma_ops
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -260,28 +260,28 @@ func.func @negative_no_shuffle_outside_block(
%arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA, %arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
%arg4: !memrefA, %arg5: !memrefB, %arg6: !vec, %arg7: i1) -> !vec %arg4: !memrefA, %arg5: !memrefB, %arg6: !vec, %arg7: i1) -> !vec
{ {
%0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec %0 = x86.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
%1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec %1 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
%2 = vector.fma %0, %1, %arg6 : !vec %2 = vector.fma %0, %1, %arg6 : !vec
%3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec %3 = x86.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
%4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec %4 = x86.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
%5 = vector.fma %3, %4, %2 : !vec %5 = vector.fma %3, %4, %2 : !vec
%loop = scf.if %arg7 -> (vector<8xf32>) { %loop = scf.if %arg7 -> (vector<8xf32>) {
%6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec %6 = x86.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
%7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec %7 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
%8 = vector.fma %6, %7, %arg6 : !vec %8 = vector.fma %6, %7, %arg6 : !vec
%9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec %9 = x86.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
%10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec %10 = x86.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
%11 = vector.fma %9, %10, %8 : !vec %11 = vector.fma %9, %10, %8 : !vec
%12 = vector.fma %5, %11, %arg6 : !vec %12 = vector.fma %5, %11, %arg6 : !vec
scf.yield %12 : vector<8xf32> scf.yield %12 : vector<8xf32>
} else { } else {
%6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec %6 = x86.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
%7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec %7 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
%8 = vector.fma %6, %7, %arg6 : !vec %8 = vector.fma %6, %7, %arg6 : !vec
%9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec %9 = x86.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
%10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec %10 = x86.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
%11 = vector.fma %9, %10, %8 : !vec %11 = vector.fma %9, %10, %8 : !vec
%12 = vector.fma %5, %11, %arg6 : !vec %12 = vector.fma %5, %11, %arg6 : !vec
scf.yield %12 : vector<8xf32> scf.yield %12 : vector<8xf32>
@@ -293,19 +293,19 @@ func.func @negative_no_shuffle_outside_block(
// vector.fma at %5 has its consumer in an another block (%12); therefore rewrite is not // vector.fma at %5 has its consumer in an another block (%12); therefore rewrite is not
// applied. // applied.
// CHECK-LABEL: @negative_no_shuffle_outside_block // CHECK-LABEL: @negative_no_shuffle_outside_block
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 // CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
// CHECK: vector.fma // CHECK: vector.fma
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 // CHECK: x86.avx.cvt.packed.even.indexed_to_f32
// CHECK: vector.fma // CHECK: vector.fma
// CHECK: scf.if // CHECK: scf.if
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 // CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
// CHECK: vector.fma // CHECK: vector.fma
module attributes {transform.with_named_sequence} { module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %0 { transform.apply_patterns to %0 {
transform.apply_patterns.x86vector.shuffle_vector_fma_ops transform.apply_patterns.x86.shuffle_vector_fma_ops
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }


@@ -24,7 +24,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %0 {
-      transform.apply_patterns.x86vector.sink_vector_producer_ops
+      transform.apply_patterns.x86.sink_vector_producer_ops
     } : !transform.any_op
     transform.yield
   }
@@ -57,7 +57,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %0 {
-      transform.apply_patterns.x86vector.sink_vector_producer_ops
+      transform.apply_patterns.x86.sink_vector_producer_ops
     } : !transform.any_op
     transform.yield
   }
@@ -90,7 +90,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %0 {
-      transform.apply_patterns.x86vector.sink_vector_producer_ops
+      transform.apply_patterns.x86.sink_vector_producer_ops
     } : !transform.any_op
     transform.yield
   }
@@ -134,7 +134,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %0 {
-      transform.apply_patterns.x86vector.sink_vector_producer_ops
+      transform.apply_patterns.x86.sink_vector_producer_ops
     } : !transform.any_op
     transform.yield
   }
@@ -160,7 +160,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %0 {
-      transform.apply_patterns.x86vector.sink_vector_producer_ops
+      transform.apply_patterns.x86.sink_vector_producer_ops
     } : !transform.any_op
     transform.yield
   }
@@ -191,9 +191,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %0 {
-      transform.apply_patterns.x86vector.sink_vector_producer_ops
+      transform.apply_patterns.x86.sink_vector_producer_ops
     } : !transform.any_op
     transform.yield
   }
 }


@@ -30,18 +30,18 @@ func.func @brgemm_to_fma(
 // CHECK: memref.subview %arg0[%c0, %c0, %c0, 1] {{.*}} : memref<1x4x1x2xbf16> to memref<1x1x1x1xbf16, {{.*}}>
 // CHECK: memref.subview %arg0[%c0, %c0, %c0, 0] {{.*}} : memref<1x4x1x2xbf16> to memref<1x1x1x1xbf16, {{.*}}>
 // CHECK: memref.subview %arg1[%c0, %c0, %c0, %c0] {{.*}} : memref<1x1x32x2xbf16> to memref<1x1x8x2xbf16, {{.*}}>
-// CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} : memref<1x1x1x1xbf16, strided<[8, 2, 2, 1], offset: ?>>
-// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : memref<1x1x8x2xbf16, strided<[64, 64, 2, 1], offset: ?>>
+// CHECK: x86.avx.bcst_to_f32.packed {{.*}} : memref<1x1x1x1xbf16, strided<[8, 2, 2, 1], offset: ?>>
+// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : memref<1x1x8x2xbf16, strided<[64, 64, 2, 1], offset: ?>>
 // CHECK: vector.fma {{.*}} : vector<8xf32>
-// CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} : memref<1x1x1x1xbf16, strided<[8, 2, 2, 1], offset: ?>>
-// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} : memref<1x1x8x2xbf16, strided<[64, 64, 2, 1], offset: ?>>
+// CHECK: x86.avx.bcst_to_f32.packed {{.*}} : memref<1x1x1x1xbf16, strided<[8, 2, 2, 1], offset: ?>>
+// CHECK: x86.avx.cvt.packed.even.indexed_to_f32 {{.*}} : memref<1x1x8x2xbf16, strided<[64, 64, 2, 1], offset: ?>>
 // CHECK: vector.fma {{.*}} : vector<8xf32>
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+      transform.apply_patterns.x86.vector_contract_bf16_to_fma
     } : !transform.any_op
     transform.yield
   }
@@ -76,18 +76,18 @@ func.func @brgemm_to_fma_load(
 }
 // CHECK-LABEL: @brgemm_to_fma_load
-// CHECK: x86vector.avx.bcst_to_f32.packed
-// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: x86.avx.bcst_to_f32.packed
+// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
 // CHECK: vector.fma
-// CHECK: x86vector.avx.bcst_to_f32.packed
-// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: x86.avx.bcst_to_f32.packed
+// CHECK: x86.avx.cvt.packed.even.indexed_to_f32
 // CHECK: vector.fma
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+      transform.apply_patterns.x86.vector_contract_bf16_to_fma
     } : !transform.any_op
     transform.yield
   }
@@ -122,18 +122,18 @@ func.func @brgemm_to_fma_load_bcst_B(
 }
 // CHECK-LABEL: @brgemm_to_fma_load_bcst_B
-// CHECK: x86vector.avx.bcst_to_f32.packed
-// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: x86.avx.bcst_to_f32.packed
+// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
 // CHECK: vector.fma
-// CHECK: x86vector.avx.bcst_to_f32.packed
-// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: x86.avx.bcst_to_f32.packed
+// CHECK: x86.avx.cvt.packed.even.indexed_to_f32
 // CHECK: vector.fma
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+      transform.apply_patterns.x86.vector_contract_bf16_to_fma
     } : !transform.any_op
     transform.yield
   }
@@ -168,18 +168,18 @@ func.func @batch_matmul_fma_load(
 }
 // CHECK-LABEL: @batch_matmul_fma_load
-// CHECK: x86vector.avx.bcst_to_f32.packed
-// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: x86.avx.bcst_to_f32.packed
+// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
 // CHECK: vector.fma
-// CHECK: x86vector.avx.bcst_to_f32.packed
-// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: x86.avx.bcst_to_f32.packed
+// CHECK: x86.avx.cvt.packed.even.indexed_to_f32
 // CHECK: vector.fma
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+      transform.apply_patterns.x86.vector_contract_bf16_to_fma
     } : !transform.any_op
     transform.yield
   }
@@ -214,18 +214,18 @@ func.func @matmul_outer_product_to_fma_load(
 }
 // CHECK-LABEL: @matmul_outer_product_to_fma_load
-// CHECK: x86vector.avx.bcst_to_f32.packed
-// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: x86.avx.bcst_to_f32.packed
+// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
 // CHECK: vector.fma
-// CHECK: x86vector.avx.bcst_to_f32.packed
-// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: x86.avx.bcst_to_f32.packed
+// CHECK: x86.avx.cvt.packed.even.indexed_to_f32
 // CHECK: vector.fma
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+      transform.apply_patterns.x86.vector_contract_bf16_to_fma
     } : !transform.any_op
     transform.yield
   }
@@ -285,10 +285,10 @@ func.func @matmul_to_fma_flat_layout(
 // CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
 // CHECK: memref.subview %arg0[%c0, %c0] {{.*}} : memref<4x1xbf16> to memref<1x1xbf16, {{.*}}>
 // CHECK: memref.subview %arg1[%c0, %c0] {{.*}} : memref<1x32xbf16> to memref<1x16xbf16, {{.*}}>
-// CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} : memref<1x1xbf16, strided<[1, 1], offset: ?>>
-// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} : memref<1x16xbf16, strided<[32, 1], offset: ?>>
+// CHECK: x86.avx.bcst_to_f32.packed {{.*}} : memref<1x1xbf16, strided<[1, 1], offset: ?>>
+// CHECK: x86.avx.cvt.packed.even.indexed_to_f32 {{.*}} : memref<1x16xbf16, strided<[32, 1], offset: ?>>
 // CHECK: vector.fma {{.*}} : vector<8xf32>
-// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : memref<1x16xbf16, strided<[32, 1], offset: ?>>
+// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : memref<1x16xbf16, strided<[32, 1], offset: ?>>
 // CHECK: vector.fma {{.*}} : vector<8xf32>
 // CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
 // CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
@@ -297,7 +297,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+      transform.apply_patterns.x86.vector_contract_bf16_to_fma
     } : !transform.any_op
     transform.yield
   }
@@ -357,10 +357,10 @@ func.func @matmul_to_fma_flat_layout_load(
 // CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
 // CHECK: memref.subview %arg0[%c0, %c0] {{.*}} : memref<4x1xbf16> to memref<1x1xbf16, {{.*}}>
 // CHECK: memref.subview %arg1[%c0, %c0] {{.*}} : memref<1x32xbf16> to memref<1x16xbf16, {{.*}}>
-// CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} : memref<1x1xbf16, strided<[1, 1], offset: ?>>
-// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} : memref<1x16xbf16, strided<[32, 1], offset: ?>>
+// CHECK: x86.avx.bcst_to_f32.packed {{.*}} : memref<1x1xbf16, strided<[1, 1], offset: ?>>
+// CHECK: x86.avx.cvt.packed.even.indexed_to_f32 {{.*}} : memref<1x16xbf16, strided<[32, 1], offset: ?>>
 // CHECK: vector.fma {{.*}} : vector<8xf32>
-// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : memref<1x16xbf16, strided<[32, 1], offset: ?>>
+// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : memref<1x16xbf16, strided<[32, 1], offset: ?>>
 // CHECK: vector.fma {{.*}} : vector<8xf32>
 // CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
 // CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
@@ -369,7 +369,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+      transform.apply_patterns.x86.vector_contract_bf16_to_fma
     } : !transform.any_op
     transform.yield
   }
@@ -448,10 +448,10 @@ func.func @matmul_to_fma_flat_layout_loop(%arg0: memref<16x64x32xbf16>, %arg1: m
 // CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
 // CHECK: scf.for
 // CHECK: scf.for
-// CHECK: x86vector.avx.bcst_to_f32.packed
-// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: x86.avx.bcst_to_f32.packed
+// CHECK: x86.avx.cvt.packed.even.indexed_to_f32
 // CHECK: vector.fma {{.*}} : vector<8xf32>
-// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
 // CHECK: vector.fma {{.*}} : vector<8xf32>
 // CHECK: scf.yield
 // CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
@@ -461,7 +461,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+      transform.apply_patterns.x86.vector_contract_bf16_to_fma
     } : !transform.any_op
     transform.yield
   }
@@ -524,10 +524,10 @@ func.func @matmul_to_fma_flat_layout_bcstB(
 // CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
 // CHECK: memref.subview %arg1[%c0, %c0] {{.*}} : memref<1x4xbf16> to memref<1x1xbf16, {{.*}}>
 // CHECK: memref.subview %arg0[%c0, %c0] {{.*}} : memref<32x1xbf16> to memref<16x1xbf16, {{.*}}>
-// CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} : memref<1x1xbf16, strided<[4, 1], offset: ?>>
-// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} : memref<16x1xbf16, strided<[1, 1], offset: ?>>
+// CHECK: x86.avx.bcst_to_f32.packed {{.*}} : memref<1x1xbf16, strided<[4, 1], offset: ?>>
+// CHECK: x86.avx.cvt.packed.even.indexed_to_f32 {{.*}} : memref<16x1xbf16, strided<[1, 1], offset: ?>>
 // CHECK: vector.fma {{.*}} : vector<8xf32>
-// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : memref<16x1xbf16, strided<[1, 1], offset: ?>>
+// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : memref<16x1xbf16, strided<[1, 1], offset: ?>>
 // CHECK: vector.fma {{.*}} : vector<8xf32>
 // CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
 // CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
@@ -536,7 +536,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+      transform.apply_patterns.x86.vector_contract_bf16_to_fma
     } : !transform.any_op
     transform.yield
   }
@@ -572,18 +572,18 @@ func.func @matmul_dynamic_offset(
 // CHECK-LABEL: @matmul_dynamic_offset
 // CHECK: memref.subview %arg0[%arg3, %c0, 1]{{.*}}
-// CHECK: x86vector.avx.bcst_to_f32.packed
-// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: x86.avx.bcst_to_f32.packed
+// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
 // CHECK: vector.fma
-// CHECK: x86vector.avx.bcst_to_f32.packed
-// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: x86.avx.bcst_to_f32.packed
+// CHECK: x86.avx.cvt.packed.even.indexed_to_f32
 // CHECK: vector.fma
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+      transform.apply_patterns.x86.vector_contract_bf16_to_fma
     } : !transform.any_op
     transform.yield
   }
@@ -618,18 +618,18 @@ func.func @matmul_to_fma_load_bcst_B(
 }
 // CHECK-LABEL: @matmul_to_fma_load_bcst_B
-// CHECK: x86vector.avx.bcst_to_f32.packed
-// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: x86.avx.bcst_to_f32.packed
+// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
 // CHECK: vector.fma
-// CHECK: x86vector.avx.bcst_to_f32.packed
-// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: x86.avx.bcst_to_f32.packed
+// CHECK: x86.avx.cvt.packed.even.indexed_to_f32
 // CHECK: vector.fma
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+      transform.apply_patterns.x86.vector_contract_bf16_to_fma
     } : !transform.any_op
     transform.yield
   }
@@ -664,18 +664,18 @@ func.func @many_dimensions(
 }
 // CHECK-LABEL: @many_dimensions
-// CHECK: x86vector.avx.bcst_to_f32.packed
-// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: x86.avx.bcst_to_f32.packed
+// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
 // CHECK: vector.fma
-// CHECK: x86vector.avx.bcst_to_f32.packed
-// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: x86.avx.bcst_to_f32.packed
+// CHECK: x86.avx.cvt.packed.even.indexed_to_f32
 // CHECK: vector.fma
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+      transform.apply_patterns.x86.vector_contract_bf16_to_fma
     } : !transform.any_op
     transform.yield
   }
@@ -744,7 +744,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+      transform.apply_patterns.x86.vector_contract_bf16_to_fma
     } : !transform.any_op
     transform.yield
   }
@@ -785,7 +785,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %0 {
-      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+      transform.apply_patterns.x86.vector_contract_bf16_to_fma
     } : !transform.any_op
     transform.yield
   }
@@ -841,17 +841,17 @@ func.func @negative_offset_diff_is_not_8(
 }
 // CHECK-LABEL: @negative_offset_diff_is_not_8
-// CHECK-NOT: x86vector.avx.bcst_to_f32.packed
-// CHECK-NOT: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK-NOT: x86.avx.bcst_to_f32.packed
+// CHECK-NOT: x86.avx.cvt.packed.even.indexed_to_f32
 // CHECK-NOT: vector.fma {{.*}} : vector<8xf32>
-// CHECK-NOT: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK-NOT: x86.avx.cvt.packed.odd.indexed_to_f32
 // CHECK: vector.contract
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+      transform.apply_patterns.x86.vector_contract_bf16_to_fma
     } : !transform.any_op
     transform.yield
   }
@@ -907,17 +907,17 @@ func.func @negative_vector_contracts_not_in_order(
 }
 // CHECK-LABEL: @negative_vector_contracts_not_in_order
-// CHECK-NOT: x86vector.avx.bcst_to_f32.packed
-// CHECK-NOT: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK-NOT: x86.avx.bcst_to_f32.packed
+// CHECK-NOT: x86.avx.cvt.packed.even.indexed_to_f32
 // CHECK-NOT: vector.fma {{.*}} : vector<8xf32>
-// CHECK-NOT: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK-NOT: x86.avx.cvt.packed.odd.indexed_to_f32
 // CHECK: vector.contract
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+      transform.apply_patterns.x86.vector_contract_bf16_to_fma
     } : !transform.any_op
     transform.yield
   }
@@ -976,17 +976,17 @@ func.func @negative_flat_layout_dynamic_index(
 }
 // CHECK-LABEL: @negative_flat_layout_dynamic_index
-// CHECK-NOT: x86vector.avx.bcst_to_f32.packed
-// CHECK-NOT: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK-NOT: x86.avx.bcst_to_f32.packed
+// CHECK-NOT: x86.avx.cvt.packed.even.indexed_to_f32
 // CHECK-NOT: vector.fma {{.*}} : vector<8xf32>
-// CHECK-NOT: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK-NOT: x86.avx.cvt.packed.odd.indexed_to_f32
 // CHECK: vector.contract
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+      transform.apply_patterns.x86.vector_contract_bf16_to_fma
     } : !transform.any_op
     transform.yield
   }
@@ -1045,17 +1045,17 @@ func.func @negative_non_unit_K_dim(
 }
 // CHECK-LABEL: @negative_non_unit_K_dim
-// CHECK-NOT: x86vector.avx.bcst_to_f32.packed
-// CHECK-NOT: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK-NOT: x86.avx.bcst_to_f32.packed
+// CHECK-NOT: x86.avx.cvt.packed.even.indexed_to_f32
 // CHECK-NOT: vector.fma {{.*}} : vector<8xf32>
-// CHECK-NOT: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK-NOT: x86.avx.cvt.packed.odd.indexed_to_f32
 // CHECK: vector.contract
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+      transform.apply_patterns.x86.vector_contract_bf16_to_fma
     } : !transform.any_op
     transform.yield
   }
@@ -1089,7 +1089,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma transform.apply_patterns.x86.vector_contract_bf16_to_fma
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -1132,7 +1132,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma transform.apply_patterns.x86.vector_contract_bf16_to_fma
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -1176,7 +1176,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma transform.apply_patterns.x86.vector_contract_bf16_to_fma
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -1222,7 +1222,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma transform.apply_patterns.x86.vector_contract_bf16_to_fma
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -1266,7 +1266,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma transform.apply_patterns.x86.vector_contract_bf16_to_fma
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -1308,9 +1308,8 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma transform.apply_patterns.x86.vector_contract_bf16_to_fma
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
} }


@@ -27,7 +27,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_fma transform.apply_patterns.x86.vector_contract_to_fma
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -61,7 +61,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_fma transform.apply_patterns.x86.vector_contract_to_fma
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -95,7 +95,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_fma transform.apply_patterns.x86.vector_contract_to_fma
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -129,7 +129,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_fma transform.apply_patterns.x86.vector_contract_to_fma
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -163,7 +163,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_fma transform.apply_patterns.x86.vector_contract_to_fma
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -197,7 +197,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_fma transform.apply_patterns.x86.vector_contract_to_fma
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -233,7 +233,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_fma transform.apply_patterns.x86.vector_contract_to_fma
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -269,7 +269,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_fma transform.apply_patterns.x86.vector_contract_to_fma
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -303,7 +303,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_fma transform.apply_patterns.x86.vector_contract_to_fma
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -337,7 +337,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_fma transform.apply_patterns.x86.vector_contract_to_fma
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }


@@ -20,13 +20,13 @@ func.func @brgemm_to_bf16dp(
// CHECK-LABEL: @brgemm_to_bf16dp // CHECK-LABEL: @brgemm_to_bf16dp
// CHECK: vector.broadcast // CHECK: vector.broadcast
// CHECK: x86vector.avx512.dot // CHECK: x86.avx512.dot
module attributes {transform.with_named_sequence} { module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -54,13 +54,13 @@ func.func @brgemm_to_bf16dp_bcst_B(
// CHECK-LABEL: @brgemm_to_bf16dp_bcst_B // CHECK-LABEL: @brgemm_to_bf16dp_bcst_B
// CHECK: vector.broadcast // CHECK: vector.broadcast
// CHECK: x86vector.avx512.dot // CHECK: x86.avx512.dot
module attributes {transform.with_named_sequence} { module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -88,13 +88,13 @@ func.func @brgemm_to_avx10int8dp(
// CHECK-LABEL: @brgemm_to_avx10int8dp // CHECK-LABEL: @brgemm_to_avx10int8dp
// CHECK: vector.broadcast // CHECK: vector.broadcast
// CHECK: x86vector.avx10.dot.i8 // CHECK: x86.avx10.dot.i8
module attributes {transform.with_named_sequence} { module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -123,13 +123,13 @@ func.func @batch_matmul_avx10int8dp_bcst_B(
// CHECK-LABEL: @batch_matmul_avx10int8dp_bcst_B // CHECK-LABEL: @batch_matmul_avx10int8dp_bcst_B
// CHECK: vector.broadcast // CHECK: vector.broadcast
// CHECK: x86vector.avx10.dot.i8 // CHECK: x86.avx10.dot.i8
module attributes {transform.with_named_sequence} { module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -157,13 +157,13 @@ func.func @brgemm_to_int8dp(
// CHECK-LABEL: @brgemm_to_int8dp // CHECK-LABEL: @brgemm_to_int8dp
// CHECK: vector.broadcast // CHECK: vector.broadcast
// CHECK: x86vector.avx.dot.i8 // CHECK: x86.avx.dot.i8
module attributes {transform.with_named_sequence} { module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -191,13 +191,13 @@ func.func @batch_matmul_bf16dp(
// CHECK-LABEL: @batch_matmul_bf16dp // CHECK-LABEL: @batch_matmul_bf16dp
// CHECK: vector.broadcast // CHECK: vector.broadcast
// CHECK: x86vector.avx512.dot // CHECK: x86.avx512.dot
module attributes {transform.with_named_sequence} { module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -226,13 +226,13 @@ func.func @batch_matmul_int8dp(
// CHECK-LABEL: @batch_matmul_int8dp // CHECK-LABEL: @batch_matmul_int8dp
// CHECK: vector.broadcast // CHECK: vector.broadcast
// CHECK: x86vector.avx.dot.i8 // CHECK: x86.avx.dot.i8
module attributes {transform.with_named_sequence} { module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -261,13 +261,13 @@ func.func @batch_matmul_int8dp_bcst_B(
// CHECK-LABEL: @batch_matmul_int8dp_bcst_B // CHECK-LABEL: @batch_matmul_int8dp_bcst_B
// CHECK: vector.broadcast // CHECK: vector.broadcast
// CHECK: x86vector.avx.dot.i8 // CHECK: x86.avx.dot.i8
module attributes {transform.with_named_sequence} { module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -295,13 +295,13 @@ func.func @matmul_outer_product_to_bf16dp(
// CHECK-LABEL: @matmul_outer_product_to_bf16dp // CHECK-LABEL: @matmul_outer_product_to_bf16dp
// CHECK: vector.broadcast // CHECK: vector.broadcast
// CHECK: x86vector.avx512.dot // CHECK: x86.avx512.dot
module attributes {transform.with_named_sequence} { module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -329,13 +329,13 @@ func.func @matmul_outer_product_to_int8dp(
// CHECK-LABEL: @matmul_outer_product_to_int8dp // CHECK-LABEL: @matmul_outer_product_to_int8dp
// CHECK: vector.broadcast // CHECK: vector.broadcast
// CHECK: x86vector.avx.dot.i8 // CHECK: x86.avx.dot.i8
module attributes {transform.with_named_sequence} { module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -395,8 +395,8 @@ func.func @matmul_bf16dp_flat_layout(
// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> // CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle{{.*}}[0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43, 16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59] : vector<32xbf16>, vector<32xbf16> // CHECK: vector.shuffle{{.*}}[0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43, 16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59] : vector<32xbf16>, vector<32xbf16>
// CHECK-NEXT: vector.shuffle{{.*}}[4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47, 20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63] : vector<32xbf16>, vector<32xbf16> // CHECK-NEXT: vector.shuffle{{.*}}[4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47, 20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63] : vector<32xbf16>, vector<32xbf16>
// CHECK: x86vector.avx512.dot // CHECK: x86.avx512.dot
// CHECK: x86vector.avx512.dot // CHECK: x86.avx512.dot
// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32> // CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> // CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
@@ -404,7 +404,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -485,8 +485,8 @@ func.func @brmatmul_bf16dp_flat_layout_loop(%arg0: memref<16x64x32xbf16>, %arg1:
// CHECK: scf.for // CHECK: scf.for
// CHECK: vector.shuffle{{.*}}[0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43, 16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59] : vector<32xbf16>, vector<32xbf16> // CHECK: vector.shuffle{{.*}}[0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43, 16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59] : vector<32xbf16>, vector<32xbf16>
// CHECK-NEXT: vector.shuffle{{.*}}[4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47, 20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63] : vector<32xbf16>, vector<32xbf16> // CHECK-NEXT: vector.shuffle{{.*}}[4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47, 20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63] : vector<32xbf16>, vector<32xbf16>
// CHECK: x86vector.avx512.dot // CHECK: x86.avx512.dot
// CHECK: x86vector.avx512.dot // CHECK: x86.avx512.dot
// CHECK: scf.yield // CHECK: scf.yield
// CHECK: scf.yield // CHECK: scf.yield
// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32> // CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
@@ -496,7 +496,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -566,15 +566,15 @@ func.func @matmul_bf16dp_flat_layout_B_shuffled(
} }
// CHECK-LABEL: @matmul_bf16dp_flat_layout_B_shuffled // CHECK-LABEL: @matmul_bf16dp_flat_layout_B_shuffled
// CHECK: x86vector.avx512.dot // CHECK: x86.avx512.dot
// CHECK: x86vector.avx512.dot // CHECK: x86.avx512.dot
// CHECK-NOT: vector.contract // CHECK-NOT: vector.contract
module attributes {transform.with_named_sequence} { module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -601,14 +601,14 @@ func.func @negative_invalid_vc_kind(
} }
// CHECK-LABEL: @negative_invalid_vc_kind // CHECK-LABEL: @negative_invalid_vc_kind
// CHECK-NOT: x86vector.avx512.dot // CHECK-NOT: x86.avx512.dot
// CHECK: vector.contract // CHECK: vector.contract
module attributes {transform.with_named_sequence} { module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -635,14 +635,14 @@ func.func @negative_false_vnni_bf16(
} }
// CHECK-LABEL: @negative_false_vnni_bf16 // CHECK-LABEL: @negative_false_vnni_bf16
// CHECK-NOT: x86vector.avx512.dot // CHECK-NOT: x86.avx512.dot
// CHECK: vector.contract // CHECK: vector.contract
module attributes {transform.with_named_sequence} { module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -669,14 +669,14 @@ func.func @negative_false_vnni_int8(
} }
// CHECK-LABEL: @negative_false_vnni_int8 // CHECK-LABEL: @negative_false_vnni_int8
// CHECK-NOT: x86vector.avx.dot.i8 // CHECK-NOT: x86.avx.dot.i8
// CHECK: vector.contract // CHECK: vector.contract
module attributes {transform.with_named_sequence} { module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -703,14 +703,14 @@ func.func @negative_batch_dimension(
} }
// CHECK-LABEL: @negative_batch_dimension // CHECK-LABEL: @negative_batch_dimension
// CHECK-NOT: x86vector.avx512.dot // CHECK-NOT: x86.avx512.dot
// CHECK: vector.contract // CHECK: vector.contract
module attributes {transform.with_named_sequence} { module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -737,14 +737,14 @@ func.func @negative_brgemm_dimension(
} }
// CHECK-LABEL: @negative_brgemm_dimension // CHECK-LABEL: @negative_brgemm_dimension
// CHECK-NOT: x86vector.avx.dot.i8 // CHECK-NOT: x86.avx.dot.i8
// CHECK: vector.contract // CHECK: vector.contract
module attributes {transform.with_named_sequence} { module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -771,14 +771,14 @@ func.func @negative_float_acc_type(
} }
// CHECK-LABEL: @negative_float_acc_type // CHECK-LABEL: @negative_float_acc_type
// CHECK-NOT: x86vector.avx512.dot // CHECK-NOT: x86.avx512.dot
// CHECK: vector.contract // CHECK: vector.contract
module attributes {transform.with_named_sequence} { module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -805,14 +805,14 @@ func.func @negative_int_acc_type(
} }
// CHECK-LABEL: @negative_int_acc_type // CHECK-LABEL: @negative_int_acc_type
// CHECK-NOT: x86vector.avx.dot.i8 // CHECK-NOT: x86.avx.dot.i8
// CHECK: vector.contract // CHECK: vector.contract
module attributes {transform.with_named_sequence} { module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -839,14 +839,14 @@ func.func @negative_wrong_vnni_blocking_factor_bf16(
} }
// CHECK-LABEL: @negative_wrong_vnni_blocking_factor_bf16 // CHECK-LABEL: @negative_wrong_vnni_blocking_factor_bf16
// CHECK-NOT: x86vector.avx512.dot // CHECK-NOT: x86.avx512.dot
// CHECK: vector.contract // CHECK: vector.contract
module attributes {transform.with_named_sequence} { module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -873,14 +873,14 @@ func.func @negative_brgemm_not_vnni(
} }
// CHECK-LABEL: @negative_brgemm_not_vnni // CHECK-LABEL: @negative_brgemm_not_vnni
// CHECK-NOT: x86vector.avx512.dot // CHECK-NOT: x86.avx512.dot
// CHECK: vector.contract // CHECK: vector.contract
module attributes {transform.with_named_sequence} { module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -907,15 +907,15 @@ func.func @negative_wrong_vector_shape_int8(
} }
// CHECK-LABEL: @negative_wrong_vector_shape_int8 // CHECK-LABEL: @negative_wrong_vector_shape_int8
// CHECK-NOT: x86vector.avx.dot.i8 // CHECK-NOT: x86.avx.dot.i8
// CHECK-NOT: x86vector.avx10.dot.i8 // CHECK-NOT: x86.avx10.dot.i8
// CHECK: vector.contract // CHECK: vector.contract
module attributes {transform.with_named_sequence} { module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -942,14 +942,14 @@ func.func @negative_wrong_vector_shape_bf16(
} }
// CHECK-LABEL: @negative_wrong_vector_shape_bf16 // CHECK-LABEL: @negative_wrong_vector_shape_bf16
// CHECK-NOT: x86vector.avx512.dot // CHECK-NOT: x86.avx512.dot
// CHECK: vector.contract // CHECK: vector.contract
module attributes {transform.with_named_sequence} { module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -1005,14 +1005,14 @@ func.func @negative_flat_other_dim_is_not_2(
} }
// CHECK-LABEL: @negative_flat_other_dim_is_not_2 // CHECK-LABEL: @negative_flat_other_dim_is_not_2
// CHECK-NOT: x86vector.avx512.dot // CHECK-NOT: x86.avx512.dot
// CHECK: vector.contract // CHECK: vector.contract
module attributes {transform.with_named_sequence} { module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -1069,14 +1069,14 @@ func.func @negative_flat_offset_diff_is_not16(
} }
// CHECK-LABEL: @negative_flat_offset_diff_is_not16 // CHECK-LABEL: @negative_flat_offset_diff_is_not16
// CHECK-NOT: x86vector.avx512.dot // CHECK-NOT: x86.avx512.dot
// CHECK: vector.contract // CHECK: vector.contract
module attributes {transform.with_named_sequence} { module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -1132,14 +1132,14 @@ func.func @negative_flat_dynamic_offset(
} }
// CHECK-LABEL: @negative_flat_dynamic_offset // CHECK-LABEL: @negative_flat_dynamic_offset
// CHECK-NOT: x86vector.avx512.dot // CHECK-NOT: x86.avx512.dot
// CHECK: vector.contract // CHECK: vector.contract
module attributes {transform.with_named_sequence} { module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -1196,14 +1196,14 @@ func.func @negative_flat_read_after_contract(
} }
// CHECK-LABEL: @negative_flat_read_after_contract // CHECK-LABEL: @negative_flat_read_after_contract
// CHECK-NOT: x86vector.avx512.dot // CHECK-NOT: x86.avx512.dot
// CHECK: vector.contract // CHECK: vector.contract
module attributes {transform.with_named_sequence} { module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -1271,7 +1271,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -1338,7 +1338,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -1406,7 +1406,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }
@@ -1473,7 +1473,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func { transform.apply_patterns to %func {
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
} : !transform.any_op } : !transform.any_op
transform.yield transform.yield
} }


@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86vector" -test-lower-to-llvm | \ // RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86" -test-lower-to-llvm | \
// RUN: mlir-translate --mlir-to-llvmir | \ // RUN: mlir-translate --mlir-to-llvmir | \
// RUN: %lli --entry-function=entry --mattr="avx" --dlopen=%mlir_c_runner_utils | \ // RUN: %lli --entry-function=entry --mattr="avx" --dlopen=%mlir_c_runner_utils | \
// RUN: FileCheck %s // RUN: FileCheck %s
@@ -9,7 +9,7 @@ func.func @entry() -> i32 {
%a = arith.constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]> : vector<8xf32> %a = arith.constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]> : vector<8xf32>
%b = arith.constant dense<[9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : vector<8xf32> %b = arith.constant dense<[9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : vector<8xf32>
%r = x86vector.avx.intr.dot %a, %b : vector<8xf32> %r = x86.avx.intr.dot %a, %b : vector<8xf32>
%1 = vector.extract %r[%i0] : f32 from vector<8xf32> %1 = vector.extract %r[%i0] : f32 from vector<8xf32>
%2 = vector.extract %r[%i4] : f32 from vector<8xf32> %2 = vector.extract %r[%i4] : f32 from vector<8xf32>


@@ -1,7 +1,7 @@
import sys import sys
# X86Vector tests must be enabled via build flag. # X86 tests must be enabled via build flag.
if not config.mlir_run_x86vector_tests: if not config.mlir_run_x86_tests:
config.unsupported = True config.unsupported = True
# No JIT on win32. # No JIT on win32.


@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86vector" -test-lower-to-llvm | \ // RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86" -test-lower-to-llvm | \
// RUN: mlir-translate --mlir-to-llvmir | \ // RUN: mlir-translate --mlir-to-llvmir | \
// RUN: %lli --entry-function=entry --mattr="avx512bw" --dlopen=%mlir_c_runner_utils | \ // RUN: %lli --entry-function=entry --mattr="avx512bw" --dlopen=%mlir_c_runner_utils | \
// RUN: FileCheck %s // RUN: FileCheck %s
@@ -8,8 +8,8 @@ func.func @entry() -> i32 {
%a = arith.constant dense<[1., 0., 0., 2., 4., 3., 5., 7., 8., 1., 5., 5., 3., 1., 0., 7.]> : vector<16xf32> %a = arith.constant dense<[1., 0., 0., 2., 4., 3., 5., 7., 8., 1., 5., 5., 3., 1., 0., 7.]> : vector<16xf32>
%k = arith.constant dense<[1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0]> : vector<16xi1> %k = arith.constant dense<[1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0]> : vector<16xi1>
%r1 = x86vector.avx512.mask.compress %k, %a : vector<16xf32> %r1 = x86.avx512.mask.compress %k, %a : vector<16xf32>
%r2 = x86vector.avx512.mask.compress %k, %a {constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32> %r2 = x86.avx512.mask.compress %k, %a {constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32>
vector.print %r1 : vector<16xf32> vector.print %r1 : vector<16xf32>
// CHECK: ( 1, 0, 2, 4, 5, 5, 3, 1, 0, 0, 0, 0, 0, 0, 0, 0 ) // CHECK: ( 1, 0, 2, 4, 5, 5, 3, 1, 0, 0, 0, 0, 0, 0, 0, 0 )
@@ -18,7 +18,7 @@ func.func @entry() -> i32 {
// CHECK: ( 1, 0, 2, 4, 5, 5, 3, 1, 0, 5, 5, 5, 5, 5, 5, 5 ) // CHECK: ( 1, 0, 2, 4, 5, 5, 3, 1, 0, 5, 5, 5, 5, 5, 5, 5 )
%src = arith.constant dense<[0., 2., 1., 8., 6., 4., 4., 3., 2., 8., 5., 6., 3., 7., 6., 9.]> : vector<16xf32> %src = arith.constant dense<[0., 2., 1., 8., 6., 4., 4., 3., 2., 8., 5., 6., 3., 7., 6., 9.]> : vector<16xf32>
%r3 = x86vector.avx512.mask.compress %k, %a, %src : vector<16xf32>, vector<16xf32> %r3 = x86.avx512.mask.compress %k, %a, %src : vector<16xf32>, vector<16xf32>
vector.print %r3 : vector<16xf32> vector.print %r3 : vector<16xf32>
// CHECK: ( 1, 0, 2, 4, 5, 5, 3, 1, 0, 8, 5, 6, 3, 7, 6, 9 ) // CHECK: ( 1, 0, 2, 4, 5, 5, 3, 1, 0, 8, 5, 6, 3, 7, 6, 9 )


@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86vector" -test-lower-to-llvm | \ // RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86" -test-lower-to-llvm | \
// RUN: mlir-translate --mlir-to-llvmir | \ // RUN: mlir-translate --mlir-to-llvmir | \
// RUN: %lli --entry-function=entry --mattr="avx" --dlopen=%mlir_c_runner_utils | \ // RUN: %lli --entry-function=entry --mattr="avx" --dlopen=%mlir_c_runner_utils | \
// RUN: FileCheck %s // RUN: FileCheck %s
@@ -7,7 +7,7 @@ func.func @entry() -> i32 {
%i0 = arith.constant 0 : i32 %i0 = arith.constant 0 : i32
%v = arith.constant dense<[0.125, 0.25, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0]> : vector<8xf32> %v = arith.constant dense<[0.125, 0.25, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0]> : vector<8xf32>
%r = x86vector.avx.rsqrt %v : vector<8xf32> %r = x86.avx.rsqrt %v : vector<8xf32>
// `rsqrt` may produce slightly different results on Intel and AMD machines: accept both results here. // `rsqrt` may produce slightly different results on Intel and AMD machines: accept both results here.
// CHECK: {{( 2.82[0-9]*, 1.99[0-9]*, 1.41[0-9]*, 0.99[0-9]*, 0.70[0-9]*, 0.49[0-9]*, 0.35[0-9]*, 0.24[0-9]* )}} // CHECK: {{( 2.82[0-9]*, 1.99[0-9]*, 1.41[0-9]*, 0.99[0-9]*, 0.70[0-9]*, 0.49[0-9]*, 0.35[0-9]*, 0.24[0-9]* )}}
vector.print %r : vector<8xf32> vector.print %r : vector<8xf32>


@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86vector" -test-lower-to-llvm | \ // RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86" -test-lower-to-llvm | \
// RUN: mlir-translate --mlir-to-llvmir | \ // RUN: mlir-translate --mlir-to-llvmir | \
// RUN: %lli --entry-function=entry --mattr="avx512bw,avx512vp2intersect" --dlopen=%mlir_c_runner_utils | \ // RUN: %lli --entry-function=entry --mattr="avx512bw,avx512vp2intersect" --dlopen=%mlir_c_runner_utils | \
// RUN: FileCheck %s // RUN: FileCheck %s
@@ -35,11 +35,11 @@
func.func @vector_dot(%v_A : vector<8xi64>, %v_B : vector<8xf64>, func.func @vector_dot(%v_A : vector<8xi64>, %v_B : vector<8xf64>,
%v_C : vector<8xi64>, %v_D : vector<8xf64>) -> f64 { %v_C : vector<8xi64>, %v_D : vector<8xf64>) -> f64 {
// Compute intersection of indices. // Compute intersection of indices.
%k0, %k1 = x86vector.avx512.vp2intersect %v_A, %v_C : vector<8xi64> %k0, %k1 = x86.avx512.vp2intersect %v_A, %v_C : vector<8xi64>
// Filter out values without match and compress vector. // Filter out values without match and compress vector.
%p0 = x86vector.avx512.mask.compress %k0, %v_B : vector<8xf64> %p0 = x86.avx512.mask.compress %k0, %v_B : vector<8xf64>
%p1 = x86vector.avx512.mask.compress %k1, %v_D : vector<8xf64> %p1 = x86.avx512.mask.compress %k1, %v_D : vector<8xf64>
// Dense vector dot product. // Dense vector dot product.
%acc = arith.constant 0.0 : f64 %acc = arith.constant 0.0 : f64


@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86vector" -test-lower-to-llvm | \ // RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86" -test-lower-to-llvm | \
// RUN: mlir-translate --mlir-to-llvmir | \ // RUN: mlir-translate --mlir-to-llvmir | \
// RUN: %lli --entry-function=entry --mattr="avx512bw,avx512vp2intersect" --dlopen=%mlir_c_runner_utils | \ // RUN: %lli --entry-function=entry --mattr="avx512bw,avx512vp2intersect" --dlopen=%mlir_c_runner_utils | \
// RUN: FileCheck %s // RUN: FileCheck %s
@@ -40,7 +40,7 @@ func.func @entry() -> i32 {
vector.print %w9 : vector<16xi32> vector.print %w9 : vector<16xi32>
// CHECK: ( 1, 1, 1, 1, 2, 1, 1, -219, 12, 12, 12, 0, 0, 0, 1, 0 ) // CHECK: ( 1, 1, 1, 1, 2, 1, 1, -219, 12, 12, 12, 0, 0, 0, 1, 0 )
%k1, %k2 = x86vector.avx512.vp2intersect %v9, %w9 : vector<16xi32> %k1, %k2 = x86.avx512.vp2intersect %v9, %w9 : vector<16xi32>
vector.print %k1 : vector<16xi1> vector.print %k1 : vector<16xi1>
// CHECK: ( 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1 ) // CHECK: ( 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1 )


@@ -1,4 +1,4 @@
// RUN: mlir-opt %s --convert-vector-to-llvm="enable-x86vector" --convert-to-llvm -reconcile-unrealized-casts \ // RUN: mlir-opt %s --convert-vector-to-llvm="enable-x86" --convert-to-llvm -reconcile-unrealized-casts \
// RUN: | mlir-translate --mlir-to-llvmir \ // RUN: | mlir-translate --mlir-to-llvmir \
// RUN: | FileCheck %s // RUN: | FileCheck %s
@@ -11,9 +11,9 @@ func.func @LLVM_x86_avx512_mask_ps_512(
%rnd_k = arith.constant 15 : i32 %rnd_k = arith.constant 15 : i32
%rnd = arith.constant 42 : i32 %rnd = arith.constant 42 : i32
// CHECK: call <16 x float> @llvm.x86.avx512.mask.rndscale.ps.512(<16 x float> // CHECK: call <16 x float> @llvm.x86.avx512.mask.rndscale.ps.512(<16 x float>
%0 = x86vector.avx512.mask.rndscale %src, %rnd_k, %a, %imm, %rnd : vector<16xf32> %0 = x86.avx512.mask.rndscale %src, %rnd_k, %a, %imm, %rnd : vector<16xf32>
// CHECK: call <16 x float> @llvm.x86.avx512.mask.scalef.ps.512(<16 x float> // CHECK: call <16 x float> @llvm.x86.avx512.mask.scalef.ps.512(<16 x float>
%1 = x86vector.avx512.mask.scalef %0, %a, %b, %scale_k, %rnd : vector<16xf32> %1 = x86.avx512.mask.scalef %0, %a, %b, %scale_k, %rnd : vector<16xf32>
return %1 : vector<16xf32> return %1 : vector<16xf32>
} }
@@ -26,9 +26,9 @@ func.func @LLVM_x86_avx512_mask_pd_512(
%rnd_k = arith.constant 15 : i32 %rnd_k = arith.constant 15 : i32
%rnd = arith.constant 42 : i32 %rnd = arith.constant 42 : i32
// CHECK: call <8 x double> @llvm.x86.avx512.mask.rndscale.pd.512(<8 x double> // CHECK: call <8 x double> @llvm.x86.avx512.mask.rndscale.pd.512(<8 x double>
%0 = x86vector.avx512.mask.rndscale %src, %rnd_k, %a, %imm, %rnd : vector<8xf64> %0 = x86.avx512.mask.rndscale %src, %rnd_k, %a, %imm, %rnd : vector<8xf64>
// CHECK: call <8 x double> @llvm.x86.avx512.mask.scalef.pd.512(<8 x double> // CHECK: call <8 x double> @llvm.x86.avx512.mask.scalef.pd.512(<8 x double>
%1 = x86vector.avx512.mask.scalef %0, %a, %b, %scale_k, %rnd : vector<8xf64> %1 = x86.avx512.mask.scalef %0, %a, %b, %scale_k, %rnd : vector<8xf64>
return %1 : vector<8xf64> return %1 : vector<8xf64>
} }
@@ -37,7 +37,7 @@ func.func @LLVM_x86_mask_compress(%k: vector<16xi1>, %a: vector<16xf32>)
-> vector<16xf32> -> vector<16xf32>
{ {
// CHECK: call <16 x float> @llvm.x86.avx512.mask.compress.v16f32( // CHECK: call <16 x float> @llvm.x86.avx512.mask.compress.v16f32(
%0 = x86vector.avx512.mask.compress %k, %a, %a : vector<16xf32>, vector<16xf32> %0 = x86.avx512.mask.compress %k, %a, %a : vector<16xf32>, vector<16xf32>
return %0 : vector<16xf32> return %0 : vector<16xf32>
} }
@@ -46,7 +46,7 @@ func.func @LLVM_x86_vp2intersect_d_512(%a: vector<16xi32>, %b: vector<16xi32>)
-> (vector<16xi1>, vector<16xi1>) -> (vector<16xi1>, vector<16xi1>)
{ {
// CHECK: call { <16 x i1>, <16 x i1> } @llvm.x86.avx512.vp2intersect.d.512(<16 x i32> // CHECK: call { <16 x i1>, <16 x i1> } @llvm.x86.avx512.vp2intersect.d.512(<16 x i32>
%0, %1 = x86vector.avx512.vp2intersect %a, %b : vector<16xi32> %0, %1 = x86.avx512.vp2intersect %a, %b : vector<16xi32>
return %0, %1 : vector<16xi1>, vector<16xi1> return %0, %1 : vector<16xi1>, vector<16xi1>
} }
@@ -55,7 +55,7 @@ func.func @LLVM_x86_vp2intersect_q_512(%a: vector<8xi64>, %b: vector<8xi64>)
-> (vector<8 x i1>, vector<8 x i1>) -> (vector<8 x i1>, vector<8 x i1>)
{ {
// CHECK: call { <8 x i1>, <8 x i1> } @llvm.x86.avx512.vp2intersect.q.512(<8 x i64> // CHECK: call { <8 x i1>, <8 x i1> } @llvm.x86.avx512.vp2intersect.q.512(<8 x i64>
%0, %1 = x86vector.avx512.vp2intersect %a, %b : vector<8xi64> %0, %1 = x86.avx512.vp2intersect %a, %b : vector<8xi64>
return %0, %1 : vector<8 x i1>, vector<8 x i1> return %0, %1 : vector<8 x i1>, vector<8 x i1>
} }
@@ -65,7 +65,7 @@ func.func @LLVM_x86_avx512bf16_dpbf16ps_128(
) -> vector<4xf32> ) -> vector<4xf32>
{ {
// CHECK: call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128( // CHECK: call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(
%0 = x86vector.avx512.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32> %0 = x86.avx512.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32> return %0 : vector<4xf32>
} }
@@ -75,7 +75,7 @@ func.func @LLVM_x86_avx512bf16_dpbf16ps_256(
) -> vector<8xf32> ) -> vector<8xf32>
{ {
// CHECK: call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256( // CHECK: call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(
%0 = x86vector.avx512.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32> %0 = x86.avx512.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32> return %0 : vector<8xf32>
} }
@@ -85,7 +85,7 @@ func.func @LLVM_x86_avx512bf16_dpbf16ps_512(
) -> vector<16xf32> ) -> vector<16xf32>
{ {
// CHECK: call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512( // CHECK: call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(
%0 = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32> %0 = x86.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
return %0 : vector<16xf32> return %0 : vector<16xf32>
} }
@@ -94,7 +94,7 @@ func.func @LLVM_x86_avx512bf16_cvtneps2bf16_256(
%a: vector<8xf32>) -> vector<8xbf16> %a: vector<8xf32>) -> vector<8xbf16>
{ {
// CHECK: call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256( // CHECK: call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256(
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a %0 = x86.avx512.cvt.packed.f32_to_bf16 %a
: vector<8xf32> -> vector<8xbf16> : vector<8xf32> -> vector<8xbf16>
return %0 : vector<8xbf16> return %0 : vector<8xbf16>
} }
@@ -104,7 +104,7 @@ func.func @LLVM_x86_avx512bf16_cvtneps2bf16_512(
%a: vector<16xf32>) -> vector<16xbf16> %a: vector<16xf32>) -> vector<16xbf16>
{ {
// CHECK: call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512( // CHECK: call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512(
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a %0 = x86.avx512.cvt.packed.f32_to_bf16 %a
: vector<16xf32> -> vector<16xbf16> : vector<16xf32> -> vector<16xbf16>
return %0 : vector<16xbf16> return %0 : vector<16xbf16>
} }
@@ -113,7 +113,7 @@ func.func @LLVM_x86_avx512bf16_cvtneps2bf16_512(
func.func @LLVM_x86_avx10_vpdpbssd_512(%w: vector<16xi32>, %a: vector<64xi8>, func.func @LLVM_x86_avx10_vpdpbssd_512(%w: vector<16xi32>, %a: vector<64xi8>,
%b: vector<64xi8>) -> vector<16xi32> { %b: vector<64xi8>) -> vector<16xi32> {
// CHECK: call <16 x i32> @llvm.x86.avx10.vpdpbssd.512( // CHECK: call <16 x i32> @llvm.x86.avx10.vpdpbssd.512(
%0 = x86vector.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32> %0 = x86.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
return %0 : vector<16xi32> return %0 : vector<16xi32>
} }
@@ -122,7 +122,7 @@ func.func @LLVM_x86_avxbf16_vcvtneebf162ps128(
%a: memref<8xbf16>) -> vector<4xf32> %a: memref<8xbf16>) -> vector<4xf32>
{ {
// CHECK: call <4 x float> @llvm.x86.vcvtneebf162ps128( // CHECK: call <4 x float> @llvm.x86.vcvtneebf162ps128(
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32> %0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32> return %0 : vector<4xf32>
} }
@@ -131,7 +131,7 @@ func.func @LLVM_x86_avxbf16_vcvtneebf162ps256(
%a: memref<16xbf16>) -> vector<8xf32> %a: memref<16xbf16>) -> vector<8xf32>
{ {
// CHECK: call <8 x float> @llvm.x86.vcvtneebf162ps256( // CHECK: call <8 x float> @llvm.x86.vcvtneebf162ps256(
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32> %0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32> return %0 : vector<8xf32>
} }
@@ -140,7 +140,7 @@ func.func @LLVM_x86_avxbf16_vcvtneobf162ps128(
%a: memref<8xbf16>) -> vector<4xf32> %a: memref<8xbf16>) -> vector<4xf32>
{ {
// CHECK: call <4 x float> @llvm.x86.vcvtneobf162ps128( // CHECK: call <4 x float> @llvm.x86.vcvtneobf162ps128(
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32> %0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32> return %0 : vector<4xf32>
} }
@@ -149,7 +149,7 @@ func.func @LLVM_x86_avxbf16_vcvtneobf162ps256(
%a: memref<16xbf16>) -> vector<8xf32> %a: memref<16xbf16>) -> vector<8xf32>
{ {
// CHECK: call <8 x float> @llvm.x86.vcvtneobf162ps256( // CHECK: call <8 x float> @llvm.x86.vcvtneobf162ps256(
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32> %0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32> return %0 : vector<8xf32>
} }
@@ -158,7 +158,7 @@ func.func @LLVM_x86_avxbf16_vbcstnebf162ps128(
%a: memref<1xbf16>) -> vector<4xf32> %a: memref<1xbf16>) -> vector<4xf32>
{ {
// CHECK: call <4 x float> @llvm.x86.vbcstnebf162ps128( // CHECK: call <4 x float> @llvm.x86.vbcstnebf162ps128(
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<4xf32> %0 = x86.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
return %0 : vector<4xf32> return %0 : vector<4xf32>
} }
@@ -167,7 +167,7 @@ func.func @LLVM_x86_avxbf16_vbcstnebf162ps256(
%a: memref<1xbf16>) -> vector<8xf32> %a: memref<1xbf16>) -> vector<8xf32>
{ {
// CHECK: call <8 x float> @llvm.x86.vbcstnebf162ps256( // CHECK: call <8 x float> @llvm.x86.vbcstnebf162ps256(
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32> %0 = x86.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
return %0 : vector<8xf32> return %0 : vector<8xf32>
} }
@@ -176,7 +176,7 @@ func.func @LLVM_x86_avxf16_vcvtneeph2ps128(
%a: memref<8xf16>) -> vector<4xf32> %a: memref<8xf16>) -> vector<4xf32>
{ {
// CHECK: call <4 x float> @llvm.x86.vcvtneeph2ps128( // CHECK: call <4 x float> @llvm.x86.vcvtneeph2ps128(
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32> %0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
return %0 : vector<4xf32> return %0 : vector<4xf32>
} }
@@ -185,7 +185,7 @@ func.func @LLVM_x86_avxf16_vcvtneeph2ps256(
%a: memref<16xf16>) -> vector<8xf32> %a: memref<16xf16>) -> vector<8xf32>
{ {
// CHECK: call <8 x float> @llvm.x86.vcvtneeph2ps256( // CHECK: call <8 x float> @llvm.x86.vcvtneeph2ps256(
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32> %0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
return %0 : vector<8xf32> return %0 : vector<8xf32>
} }
@@ -194,7 +194,7 @@ func.func @LLVM_x86_avxf16_vcvtneoph2ps128(
%a: memref<8xf16>) -> vector<4xf32> %a: memref<8xf16>) -> vector<4xf32>
{ {
// CHECK: call <4 x float> @llvm.x86.vcvtneoph2ps128( // CHECK: call <4 x float> @llvm.x86.vcvtneoph2ps128(
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32> %0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
return %0 : vector<4xf32> return %0 : vector<4xf32>
} }
@@ -203,7 +203,7 @@ func.func @LLVM_x86_avxf16_vcvtneoph2ps256(
%a: memref<16xf16>) -> vector<8xf32> %a: memref<16xf16>) -> vector<8xf32>
{ {
// CHECK: call <8 x float> @llvm.x86.vcvtneoph2ps256( // CHECK: call <8 x float> @llvm.x86.vcvtneoph2ps256(
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32> %0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
return %0 : vector<8xf32> return %0 : vector<8xf32>
} }
@@ -212,7 +212,7 @@ func.func @LLVM_x86_avxf16_vbcstnesh2ps128(
%a: memref<1xf16>) -> vector<4xf32> %a: memref<1xf16>) -> vector<4xf32>
{ {
// CHECK: call <4 x float> @llvm.x86.vbcstnesh2ps128( // CHECK: call <4 x float> @llvm.x86.vbcstnesh2ps128(
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<4xf32> %0 = x86.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
return %0 : vector<4xf32> return %0 : vector<4xf32>
} }
@@ -221,7 +221,7 @@ func.func @LLVM_x86_avxf16_vbcstnesh2ps256(
%a: memref<1xf16>) -> vector<8xf32> %a: memref<1xf16>) -> vector<8xf32>
{ {
// CHECK: call <8 x float> @llvm.x86.vbcstnesh2ps256( // CHECK: call <8 x float> @llvm.x86.vbcstnesh2ps256(
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32> %0 = x86.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
return %0 : vector<8xf32> return %0 : vector<8xf32>
} }
@@ -229,7 +229,7 @@ func.func @LLVM_x86_avxf16_vbcstnesh2ps256(
func.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32> func.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
{ {
// CHECK: call <8 x float> @llvm.x86.avx.rsqrt.ps.256(<8 x float> // CHECK: call <8 x float> @llvm.x86.avx.rsqrt.ps.256(<8 x float>
%0 = x86vector.avx.rsqrt %a : vector<8xf32> %0 = x86.avx.rsqrt %a : vector<8xf32>
return %0 : vector<8xf32> return %0 : vector<8xf32>
} }
@@ -239,7 +239,7 @@ func.func @LLVM_x86_avx_dp_ps_256(
) -> vector<8xf32> ) -> vector<8xf32>
{ {
// CHECK: call <8 x float> @llvm.x86.avx.dp.ps.256( // CHECK: call <8 x float> @llvm.x86.avx.dp.ps.256(
%0 = x86vector.avx.intr.dot %a, %b : vector<8xf32> %0 = x86.avx.intr.dot %a, %b : vector<8xf32>
return %0 : vector<8xf32> return %0 : vector<8xf32>
} }
@@ -247,7 +247,7 @@ func.func @LLVM_x86_avx_dp_ps_256(
func.func @LLVM_x86_avx2_vpdpbssd_128(%w: vector<4xi32>, %a: vector<16xi8>, func.func @LLVM_x86_avx2_vpdpbssd_128(%w: vector<4xi32>, %a: vector<16xi8>,
%b: vector<16xi8>) -> vector<4xi32> { %b: vector<16xi8>) -> vector<4xi32> {
// CHECK: call <4 x i32> @llvm.x86.avx2.vpdpbssd.128( // CHECK: call <4 x i32> @llvm.x86.avx2.vpdpbssd.128(
%0 = x86vector.avx.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32> %0 = x86.avx.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
return %0 : vector<4xi32> return %0 : vector<4xi32>
} }
@@ -255,6 +255,6 @@ func.func @LLVM_x86_avx2_vpdpbssd_128(%w: vector<4xi32>, %a: vector<16xi8>,
func.func @LLVM_x86_avx2_vpdpbssd_256(%w: vector<8xi32>, %a: vector<32xi8>, func.func @LLVM_x86_avx2_vpdpbssd_256(%w: vector<8xi32>, %a: vector<32xi8>,
%b: vector<32xi8>) -> vector<8xi32> { %b: vector<32xi8>) -> vector<8xi32> {
// CHECK: call <8 x i32> @llvm.x86.avx2.vpdpbssd.256( // CHECK: call <8 x i32> @llvm.x86.avx2.vpdpbssd.256(
%0 = x86vector.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32> %0 = x86.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
return %0 : vector<8xi32> return %0 : vector<8xi32>
} }


@@ -10,5 +10,5 @@ mlir_target_link_libraries(MLIRMathTestPasses PUBLIC
MLIRPass MLIRPass
MLIRTransformUtils MLIRTransformUtils
MLIRVectorDialect MLIRVectorDialect
MLIRX86VectorDialect MLIRX86Dialect
) )


@@ -15,7 +15,7 @@
#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h" #include "mlir/Dialect/X86/X86Dialect.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -37,7 +37,7 @@ struct TestMathPolynomialApproximationPass
registry.insert<arith::ArithDialect, math::MathDialect, registry.insert<arith::ArithDialect, math::MathDialect,
vector::VectorDialect>(); vector::VectorDialect>();
if (enableAvx2) if (enableAvx2)
registry.insert<x86vector::X86VectorDialect>(); registry.insert<x86::X86Dialect>();
} }
StringRef getArgument() const final { StringRef getArgument() const final {
return "test-math-polynomial-approximation"; return "test-math-polynomial-approximation";
@@ -49,7 +49,7 @@ struct TestMathPolynomialApproximationPass
Option<bool> enableAvx2{ Option<bool> enableAvx2{
*this, "enable-avx2", *this, "enable-avx2",
llvm::cl::desc("Enable approximations that emit AVX2 intrinsics via the " llvm::cl::desc("Enable approximations that emit AVX2 intrinsics via the "
"X86Vector dialect"), "X86 dialect"),
llvm::cl::init(false)}; llvm::cl::init(false)};
}; };
} // namespace } // namespace


@@ -20,5 +20,5 @@ mlir_target_link_libraries(MLIRVectorTestPasses PUBLIC
MLIRTransformUtils MLIRTransformUtils
MLIRVectorDialect MLIRVectorDialect
MLIRVectorToSCF MLIRVectorToSCF
MLIRX86VectorDialect MLIRX86Dialect
) )


@@ -54,7 +54,7 @@ config.mlir_run_arm_sve_tests = @MLIR_RUN_ARM_SVE_TESTS@
if config.mlir_run_arm_sve_tests: if config.mlir_run_arm_sve_tests:
config.available_features.add("mlir_arm_sve_tests") config.available_features.add("mlir_arm_sve_tests")
config.mlir_run_arm_sme_tests = @MLIR_RUN_ARM_SME_TESTS@ config.mlir_run_arm_sme_tests = @MLIR_RUN_ARM_SME_TESTS@
config.mlir_run_x86vector_tests = @MLIR_RUN_X86VECTOR_TESTS@ config.mlir_run_x86_tests = @MLIR_RUN_X86_TESTS@
config.mlir_run_riscv_vector_tests = "@MLIR_RUN_RISCV_VECTOR_TESTS@" config.mlir_run_riscv_vector_tests = "@MLIR_RUN_RISCV_VECTOR_TESTS@"
config.mlir_run_cuda_tensor_core_tests = @MLIR_RUN_CUDA_TENSOR_CORE_TESTS@ config.mlir_run_cuda_tensor_core_tests = @MLIR_RUN_CUDA_TENSOR_CORE_TESTS@
config.mlir_run_cuda_sm80_tests = @MLIR_RUN_CUDA_SM80_TESTS@ config.mlir_run_cuda_sm80_tests = @MLIR_RUN_CUDA_SM80_TESTS@


@@ -41,7 +41,7 @@
// CHECK-SAME: tosa // CHECK-SAME: tosa
// CHECK-SAME: transform // CHECK-SAME: transform
// CHECK-SAME: vector // CHECK-SAME: vector
// CHECK-SAME: x86vector // CHECK-SAME: x86
// RUN: mlir-opt --help-hidden | FileCheck %s -check-prefix=CHECK-HELP // RUN: mlir-opt --help-hidden | FileCheck %s -check-prefix=CHECK-HELP
// CHECK-HELP: -p - Alias for --pass-pipeline // CHECK-HELP: -p - Alias for --pass-pipeline


@@ -0,0 +1,5 @@
import sys
# X86 tests must be enabled via build flag.
if not config.mlir_run_x86_tests:
config.unsupported = True


@@ -3,7 +3,7 @@
// RUN: -convert-scf-to-cf \ // RUN: -convert-scf-to-cf \
// RUN: -convert-arith-to-llvm \ // RUN: -convert-arith-to-llvm \
// RUN: -convert-cf-to-llvm \ // RUN: -convert-cf-to-llvm \
// RUN: -convert-vector-to-llvm="enable-x86vector" \ // RUN: -convert-vector-to-llvm="enable-x86" \
// RUN: -convert-math-to-llvm \ // RUN: -convert-math-to-llvm \
// RUN: -convert-func-to-llvm \ // RUN: -convert-func-to-llvm \
// RUN: -reconcile-unrealized-casts \ // RUN: -reconcile-unrealized-casts \


@@ -1,5 +0,0 @@
import sys
# X86Vector tests must be enabled via build flag.
if not config.mlir_run_x86vector_tests:
config.unsupported = True


@@ -2,7 +2,7 @@
from mlir.ir import * from mlir.ir import *
from mlir.dialects import transform from mlir.dialects import transform
from mlir.dialects.transform import x86vector from mlir.dialects.transform import x86
def run_apply_patterns(f): def run_apply_patterns(f):
@@ -28,13 +28,13 @@ def run_apply_patterns(f):
def non_configurable_patterns(): def non_configurable_patterns():
# CHECK-LABEL: TEST: non_configurable_patterns # CHECK-LABEL: TEST: non_configurable_patterns
# CHECK: apply_patterns # CHECK: apply_patterns
# CHECK: transform.apply_patterns.x86vector.vector_contract_to_fma # CHECK: transform.apply_patterns.x86.vector_contract_to_fma
x86vector.ApplyVectorContractToFMAPatternsOp() x86.ApplyVectorContractToFMAPatternsOp()
# CHECK: transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product # CHECK: transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
x86vector.ApplyVectorContractToPackedTypeDotProductPatternsOp() x86.ApplyVectorContractToPackedTypeDotProductPatternsOp()
# CHECK: transform.apply_patterns.x86vector.vector_contract_bf16_to_fma # CHECK: transform.apply_patterns.x86.vector_contract_bf16_to_fma
x86vector.ApplyVectorContractBF16ToFMAPatternsOp() x86.ApplyVectorContractBF16ToFMAPatternsOp()
# CHECK: transform.apply_patterns.x86vector.sink_vector_producer_ops # CHECK: transform.apply_patterns.x86.sink_vector_producer_ops
x86vector.ApplySinkVectorProducerOpsPatternsOp() x86.ApplySinkVectorProducerOpsPatternsOp()
# CHECK: transform.apply_patterns.x86vector.shuffle_vector_fma_ops # CHECK: transform.apply_patterns.x86.shuffle_vector_fma_ops
x86vector.ApplyShuffleVectorFMAOpsPatternsOp() x86.ApplyShuffleVectorFMAOpsPatternsOp()


@@ -3,7 +3,7 @@
from mlir.ir import * from mlir.ir import *
import mlir.dialects.builtin as builtin import mlir.dialects.builtin as builtin
import mlir.dialects.func as func import mlir.dialects.func as func
import mlir.dialects.x86vector as x86vector import mlir.dialects.x86 as x86
def run(f): def run(f):
@@ -21,13 +21,11 @@ def testAvxOp():
@func.FuncOp.from_py_func(MemRefType.get((1,), BF16Type.get())) @func.FuncOp.from_py_func(MemRefType.get((1,), BF16Type.get()))
def avx_op(arg): def avx_op(arg):
return x86vector.BcstToPackedF32Op( return x86.BcstToPackedF32Op(a=arg, dst=VectorType.get((8,), F32Type.get()))
a=arg, dst=VectorType.get((8,), F32Type.get())
)
# CHECK-LABEL: func @avx_op( # CHECK-LABEL: func @avx_op(
# CHECK-SAME: %[[ARG:.+]]: memref<1xbf16>) -> vector<8xf32> { # CHECK-SAME: %[[ARG:.+]]: memref<1xbf16>) -> vector<8xf32> {
# CHECK: %[[VAL:.+]] = x86vector.avx.bcst_to_f32.packed %[[ARG]] # CHECK: %[[VAL:.+]] = x86.avx.bcst_to_f32.packed %[[ARG]]
# CHECK: return %[[VAL]] : vector<8xf32> # CHECK: return %[[VAL]] : vector<8xf32>
# CHECK: } # CHECK: }
print(module) print(module)
@@ -41,13 +39,13 @@ def testAvx512Op():
@func.FuncOp.from_py_func(VectorType.get((8,), F32Type.get())) @func.FuncOp.from_py_func(VectorType.get((8,), F32Type.get()))
def avx512_op(arg): def avx512_op(arg):
return x86vector.CvtPackedF32ToBF16Op( return x86.CvtPackedF32ToBF16Op(
a=arg, dst=VectorType.get((8,), BF16Type.get()) a=arg, dst=VectorType.get((8,), BF16Type.get())
) )
# CHECK-LABEL: func @avx512_op( # CHECK-LABEL: func @avx512_op(
# CHECK-SAME: %[[ARG:.+]]: vector<8xf32>) -> vector<8xbf16> { # CHECK-SAME: %[[ARG:.+]]: vector<8xf32>) -> vector<8xbf16> {
# CHECK: %[[VAL:.+]] = x86vector.avx512.cvt.packed.f32_to_bf16 %[[ARG]] # CHECK: %[[VAL:.+]] = x86.avx512.cvt.packed.f32_to_bf16 %[[ARG]]
# CHECK: return %[[VAL]] : vector<8xbf16> # CHECK: return %[[VAL]] : vector<8xbf16>
# CHECK: } # CHECK: }
print(module) print(module)
@@ -65,12 +63,12 @@ def testAvx10Op():
VectorType.get((64,), IntegerType.get(8)), VectorType.get((64,), IntegerType.get(8)),
) )
def avx10_op(*args): def avx10_op(*args):
return x86vector.AVX10DotInt8Op(w=args[0], a=args[1], b=args[2]) return x86.AVX10DotInt8Op(w=args[0], a=args[1], b=args[2])
# CHECK-LABEL: func @avx10_op( # CHECK-LABEL: func @avx10_op(
# CHECK-SAME: %[[W:.+]]: vector<16xi32>, %[[A:.+]]: vector<64xi8>, # CHECK-SAME: %[[W:.+]]: vector<16xi32>, %[[A:.+]]: vector<64xi8>,
# CHECK-SAME: %[[B:.+]]: vector<64xi8>) -> vector<16xi32> { # CHECK-SAME: %[[B:.+]]: vector<64xi8>) -> vector<16xi32> {
# CHECK: %[[VAL:.+]] = x86vector.avx10.dot.i8 %[[W]], %[[A]], %[[B]] # CHECK: %[[VAL:.+]] = x86.avx10.dot.i8 %[[W]], %[[A]], %[[B]]
# CHECK: return %[[VAL]] : vector<16xi32> # CHECK: return %[[VAL]] : vector<16xi32>
# CHECK: } # CHECK: }
print(module) print(module)