[mlir][x86] Rename x86vector to x86 (#183311)
Renames 'x86vector' dialect to 'x86'. This is the first PR in series of cleanups around dialects targeting x86 platforms. The new naming scheme is shorter, cleaner, and opens possibility of integrating other x86-specific operations not strictly fitting pure vector representation. For example, the generalization will allow for future merger of AMX dialect into the x86 dialect to create one-stop x86 operations collection and boost discoverability.
This commit is contained in:
@@ -105,7 +105,7 @@ available, should be contacted first, as they're more active in those areas.
|
|||||||
* ‘arm_sve’ Dialect ([@banach-space](https://github.com/banach-space))
|
* ‘arm_sve’ Dialect ([@banach-space](https://github.com/banach-space))
|
||||||
* ‘ArmSME’ Dialect ([@banach-space](https://github.com/banach-space))
|
* ‘ArmSME’ Dialect ([@banach-space](https://github.com/banach-space))
|
||||||
* ‘amx’ Dialect ([@adam-smnk](https://github.com/adam-smnk))
|
* ‘amx’ Dialect ([@adam-smnk](https://github.com/adam-smnk))
|
||||||
* ‘x86vector’ Dialect ([@adam-smnk](https://github.com/adam-smnk))
|
* ‘x86’ Dialect ([@adam-smnk](https://github.com/adam-smnk))
|
||||||
* ‘vcix’ Dialect ([@mshockwave](https://github.com/mshockwave))
|
* ‘vcix’ Dialect ([@mshockwave](https://github.com/mshockwave))
|
||||||
|
|
||||||
#### Paradigm Dialects
|
#### Paradigm Dialects
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ overall flow is two-stage:
|
|||||||
1. **conversion** of the IR to a set of dialects translatable to LLVM IR, for
|
1. **conversion** of the IR to a set of dialects translatable to LLVM IR, for
|
||||||
example [LLVM Dialect](Dialects/LLVM.md) or one of the hardware-specific
|
example [LLVM Dialect](Dialects/LLVM.md) or one of the hardware-specific
|
||||||
dialects derived from LLVM IR intrinsics such as [AMX](Dialects/AMX.md),
|
dialects derived from LLVM IR intrinsics such as [AMX](Dialects/AMX.md),
|
||||||
[X86Vector](Dialects/X86Vector.md) or [ArmNeon](Dialects/ArmNeon.md);
|
[X86](Dialects/X86.md) or [ArmNeon](Dialects/ArmNeon.md);
|
||||||
2. **translation** of MLIR dialects to LLVM IR.
|
2. **translation** of MLIR dialects to LLVM IR.
|
||||||
|
|
||||||
This flow allows the non-trivial transformation to be performed within MLIR
|
This flow allows the non-trivial transformation to be performed within MLIR
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//===-- mlir-c/Dialect/x86Vector.h - C API for x86Vector Dialect --*- C -*-===//
|
//===-- mlir-c/Dialect/x86.h - C API for x86 Dialect --------------*- C -*-===//
|
||||||
//
|
//
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
|
||||||
// Exceptions.
|
// Exceptions.
|
||||||
@@ -7,8 +7,8 @@
|
|||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#ifndef MLIR_C_DIALECT_X86VECTOR_H
|
#ifndef MLIR_C_DIALECT_X86_H
|
||||||
#define MLIR_C_DIALECT_X86VECTOR_H
|
#define MLIR_C_DIALECT_X86_H
|
||||||
|
|
||||||
#include "mlir-c/IR.h"
|
#include "mlir-c/IR.h"
|
||||||
|
|
||||||
@@ -16,10 +16,10 @@
|
|||||||
extern "C" {
|
extern "C" {
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(X86Vector, x86vector);
|
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(X86, x86);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#endif // MLIR_C_DIALECT_X86VECTOR_H
|
#endif // MLIR_C_DIALECT_X86_H
|
||||||
@@ -1521,7 +1521,7 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
|
|||||||
operations. The lowering pass provides several options to control
|
operations. The lowering pass provides several options to control
|
||||||
the kinds of optimizations that are allowed. It also provides options
|
the kinds of optimizations that are allowed. It also provides options
|
||||||
that enable the use of one or more architectural-specific dialects
|
that enable the use of one or more architectural-specific dialects
|
||||||
(AMX, X86Vector, ArmNeon, ArmSVE, etc.) in combination with the
|
(AMX, X86, ArmNeon, ArmSVE, etc.) in combination with the
|
||||||
architectural-neutral vector dialect lowering.
|
architectural-neutral vector dialect lowering.
|
||||||
|
|
||||||
}];
|
}];
|
||||||
@@ -1564,10 +1564,9 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
|
|||||||
"bool", /*default=*/"false",
|
"bool", /*default=*/"false",
|
||||||
"Enables the use of Arm FEAT_BF16 instructions while lowering "
|
"Enables the use of Arm FEAT_BF16 instructions while lowering "
|
||||||
"the vector dialect.">,
|
"the vector dialect.">,
|
||||||
Option<"x86Vector", "enable-x86vector",
|
Option<"x86", "enable-x86",
|
||||||
"bool", /*default=*/"false",
|
"bool", /*default=*/"false",
|
||||||
"Enables the use of X86Vector dialect while lowering the vector "
|
"Enables the use of X86 dialect while lowering the vector dialect.">,
|
||||||
"dialect.">,
|
|
||||||
Option<"vectorContractLowering", "vector-contract-lowering",
|
Option<"vectorContractLowering", "vector-contract-lowering",
|
||||||
"vector::VectorContractLowering",
|
"vector::VectorContractLowering",
|
||||||
/*default=*/"vector::VectorContractLowering::Dot",
|
/*default=*/"vector::VectorContractLowering::Dot",
|
||||||
|
|||||||
@@ -42,5 +42,5 @@ add_subdirectory(UB)
|
|||||||
add_subdirectory(Utils)
|
add_subdirectory(Utils)
|
||||||
add_subdirectory(Vector)
|
add_subdirectory(Vector)
|
||||||
add_subdirectory(WasmSSA)
|
add_subdirectory(WasmSSA)
|
||||||
add_subdirectory(X86Vector)
|
add_subdirectory(X86)
|
||||||
add_subdirectory(XeGPU)
|
add_subdirectory(XeGPU)
|
||||||
|
|||||||
@@ -19,7 +19,7 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
||||||
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
|
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
|
||||||
#include "mlir/Dialect/X86Vector/Transforms.h"
|
#include "mlir/Dialect/X86/Transforms.h"
|
||||||
#include "mlir/IR/OpDefinition.h"
|
#include "mlir/IR/OpDefinition.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/Interfaces/TilingInterface.h"
|
#include "mlir/Interfaces/TilingInterface.h"
|
||||||
|
|||||||
@@ -118,9 +118,9 @@ struct SparsifierOptions : public PassPipelineOptions<SparsifierOptions> {
|
|||||||
desc("Enables the use of ArmSVE dialect while lowering the vector "
|
desc("Enables the use of ArmSVE dialect while lowering the vector "
|
||||||
"dialect"),
|
"dialect"),
|
||||||
init(false)};
|
init(false)};
|
||||||
PassOptions::Option<bool> x86Vector{
|
PassOptions::Option<bool> x86{
|
||||||
*this, "enable-x86vector",
|
*this, "enable-x86",
|
||||||
desc("Enables the use of X86Vector dialect while lowering the vector "
|
desc("Enables the use of X86 dialect while lowering the vector "
|
||||||
"dialect"),
|
"dialect"),
|
||||||
init(false)};
|
init(false)};
|
||||||
|
|
||||||
@@ -169,7 +169,7 @@ struct SparsifierOptions : public PassPipelineOptions<SparsifierOptions> {
|
|||||||
opts.armNeon = armNeon;
|
opts.armNeon = armNeon;
|
||||||
opts.armSVE = armSVE;
|
opts.armSVE = armSVE;
|
||||||
opts.amx = amx;
|
opts.amx = amx;
|
||||||
opts.x86Vector = x86Vector;
|
opts.x86 = x86;
|
||||||
return opts;
|
return opts;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
7
mlir/include/mlir/Dialect/X86/CMakeLists.txt
Normal file
7
mlir/include/mlir/Dialect/X86/CMakeLists.txt
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
add_mlir_dialect(X86 x86)
|
||||||
|
add_mlir_doc(X86 X86 Dialects/ -gen-dialect-doc -dialect=x86)
|
||||||
|
|
||||||
|
add_mlir_interface(X86Interfaces)
|
||||||
|
add_dependencies(MLIRX86IncGen MLIRX86InterfacesIncGen)
|
||||||
|
|
||||||
|
add_subdirectory(TransformOps)
|
||||||
@@ -0,0 +1,4 @@
|
|||||||
|
set(LLVM_TARGET_DEFINITIONS X86TransformOps.td)
|
||||||
|
mlir_tablegen(X86TransformOps.h.inc -gen-op-decls)
|
||||||
|
mlir_tablegen(X86TransformOps.cpp.inc -gen-op-defs)
|
||||||
|
add_mlir_dialect_tablegen_target(MLIRX86TransformOpsIncGen)
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
//===- X86VectorTransformOps.h - X86Vector transform ops --------*- C++ -*-===//
|
//===- X86TransformOps.h - X86 transform ops --------------------*- C++ -*-===//
|
||||||
//
|
//
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
@@ -6,26 +6,26 @@
|
|||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
|
#ifndef MLIR_DIALECT_X86_TRANSFORMOPS_X86TRANSFORMOPS_H
|
||||||
#define MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
|
#define MLIR_DIALECT_X86_TRANSFORMOPS_X86TRANSFORMOPS_H
|
||||||
|
|
||||||
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
|
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
|
||||||
#include "mlir/IR/OpImplementation.h"
|
#include "mlir/IR/OpImplementation.h"
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// X86Vector Transform Operations
|
// X86 Transform Operations
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h.inc"
|
#include "mlir/Dialect/X86/TransformOps/X86TransformOps.h.inc"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
class DialectRegistry;
|
class DialectRegistry;
|
||||||
|
|
||||||
namespace x86vector {
|
namespace x86 {
|
||||||
void registerTransformDialectExtension(DialectRegistry ®istry);
|
void registerTransformDialectExtension(DialectRegistry ®istry);
|
||||||
|
|
||||||
} // namespace x86vector
|
} // namespace x86
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
#endif // MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
|
#endif // MLIR_DIALECT_X86_TRANSFORMOPS_X86TRANSFORMOPS_H
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
//===- X86VectorTransformOps.td - X86Vector transform ops --*- tablegen -*-===//
|
//===- X86TransformOps.td - X86 transform ops --------------*- tablegen -*-===//
|
||||||
//
|
//
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
@@ -6,8 +6,8 @@
|
|||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#ifndef X86VECTOR_TRANSFORM_OPS
|
#ifndef X86_TRANSFORM_OPS
|
||||||
#define X86VECTOR_TRANSFORM_OPS
|
#define X86_TRANSFORM_OPS
|
||||||
|
|
||||||
include "mlir/Dialect/Transform/IR/TransformDialect.td"
|
include "mlir/Dialect/Transform/IR/TransformDialect.td"
|
||||||
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
|
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
|
||||||
@@ -18,7 +18,7 @@ include "mlir/Dialect/Transform/IR/TransformTypes.td"
|
|||||||
include "mlir/IR/RegionKindInterface.td"
|
include "mlir/IR/RegionKindInterface.td"
|
||||||
|
|
||||||
def ApplyVectorContractToFMAPatternsOp : Op<Transform_Dialect,
|
def ApplyVectorContractToFMAPatternsOp : Op<Transform_Dialect,
|
||||||
"apply_patterns.x86vector.vector_contract_to_fma",
|
"apply_patterns.x86.vector_contract_to_fma",
|
||||||
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
|
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
|
||||||
let description = [{
|
let description = [{
|
||||||
Collect patterns to lower a F32 type vector.contract operation to a FMA.
|
Collect patterns to lower a F32 type vector.contract operation to a FMA.
|
||||||
@@ -28,7 +28,7 @@ def ApplyVectorContractToFMAPatternsOp : Op<Transform_Dialect,
|
|||||||
}
|
}
|
||||||
|
|
||||||
def ApplyVectorContractToPackedTypeDotProductPatternsOp : Op<Transform_Dialect,
|
def ApplyVectorContractToPackedTypeDotProductPatternsOp : Op<Transform_Dialect,
|
||||||
"apply_patterns.x86vector.vector_contract_to_packed_type_dot_product",
|
"apply_patterns.x86.vector_contract_to_packed_type_dot_product",
|
||||||
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
|
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
|
||||||
let description = [{
|
let description = [{
|
||||||
Collect patterns to lower a BF16/Int8 type vector.contract operation
|
Collect patterns to lower a BF16/Int8 type vector.contract operation
|
||||||
@@ -39,7 +39,7 @@ def ApplyVectorContractToPackedTypeDotProductPatternsOp : Op<Transform_Dialect,
|
|||||||
}
|
}
|
||||||
|
|
||||||
def ApplyVectorContractBF16ToFMAPatternsOp : Op<Transform_Dialect,
|
def ApplyVectorContractBF16ToFMAPatternsOp : Op<Transform_Dialect,
|
||||||
"apply_patterns.x86vector.vector_contract_bf16_to_fma",
|
"apply_patterns.x86.vector_contract_bf16_to_fma",
|
||||||
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
|
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
|
||||||
let description = [{
|
let description = [{
|
||||||
Collect patterns to lower a BF16 type vector.contract operation
|
Collect patterns to lower a BF16 type vector.contract operation
|
||||||
@@ -50,7 +50,7 @@ def ApplyVectorContractBF16ToFMAPatternsOp : Op<Transform_Dialect,
|
|||||||
}
|
}
|
||||||
|
|
||||||
def ApplySinkVectorProducerOpsPatternsOp : Op<Transform_Dialect,
|
def ApplySinkVectorProducerOpsPatternsOp : Op<Transform_Dialect,
|
||||||
"apply_patterns.x86vector.sink_vector_producer_ops",
|
"apply_patterns.x86.sink_vector_producer_ops",
|
||||||
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
|
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
|
||||||
let description = [{
|
let description = [{
|
||||||
Collect patterns to sink vector producer operations forward in a block to
|
Collect patterns to sink vector producer operations forward in a block to
|
||||||
@@ -61,10 +61,10 @@ def ApplySinkVectorProducerOpsPatternsOp : Op<Transform_Dialect,
|
|||||||
}
|
}
|
||||||
|
|
||||||
def ApplyShuffleVectorFMAOpsPatternsOp : Op<Transform_Dialect,
|
def ApplyShuffleVectorFMAOpsPatternsOp : Op<Transform_Dialect,
|
||||||
"apply_patterns.x86vector.shuffle_vector_fma_ops",
|
"apply_patterns.x86.shuffle_vector_fma_ops",
|
||||||
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
|
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
|
||||||
let description = [{
|
let description = [{
|
||||||
Collect patterns to shuffle FMAs with x86vector operations as operands
|
Collect patterns to shuffle FMAs with x86 operations as operands
|
||||||
such that FMAs are grouped with respect to odd/even packed index.
|
such that FMAs are grouped with respect to odd/even packed index.
|
||||||
}];
|
}];
|
||||||
|
|
||||||
@@ -72,5 +72,4 @@ def ApplyShuffleVectorFMAOpsPatternsOp : Op<Transform_Dialect,
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#endif // X86VECTOR_TRANSFORM_OPS
|
#endif // X86_TRANSFORM_OPS
|
||||||
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
//=- Transforms.h - X86Vector Dialect Transformation Entrypoints -*- C++ -*-=//
|
//=- Transforms.h - X86 Dialect Transformation Entrypoints --------*- C++ -*-=//
|
||||||
//
|
//
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
@@ -6,8 +6,8 @@
|
|||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMS_H
|
#ifndef MLIR_DIALECT_X86_TRANSFORMS_H
|
||||||
#define MLIR_DIALECT_X86VECTOR_TRANSFORMS_H
|
#define MLIR_DIALECT_X86_TRANSFORMS_H
|
||||||
|
|
||||||
#include "mlir/IR/Value.h"
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
@@ -18,7 +18,7 @@ class LLVMConversionTarget;
|
|||||||
class LLVMTypeConverter;
|
class LLVMTypeConverter;
|
||||||
class RewritePatternSet;
|
class RewritePatternSet;
|
||||||
|
|
||||||
namespace x86vector {
|
namespace x86 {
|
||||||
|
|
||||||
/// Helper class to factor out the creation and extraction of masks from nibs.
|
/// Helper class to factor out the creation and extraction of masks from nibs.
|
||||||
struct MaskHelper {
|
struct MaskHelper {
|
||||||
@@ -100,7 +100,7 @@ void populateVectorContractBF16ToFMAPatterns(RewritePatternSet &patterns);
|
|||||||
// range by placing them at their earliest legal use site.
|
// range by placing them at their earliest legal use site.
|
||||||
void populateSinkVectorProducerOpsPatterns(RewritePatternSet &patterns);
|
void populateSinkVectorProducerOpsPatterns(RewritePatternSet &patterns);
|
||||||
|
|
||||||
// Shuffles FMAs with x86vector operations as operands such that FMAs are
|
// Shuffles FMAs with x86 operations as operands such that FMAs are
|
||||||
// grouped with respect to odd/even packed index.
|
// grouped with respect to odd/even packed index.
|
||||||
void populateShuffleVectorFMAOpsPatterns(RewritePatternSet &patterns);
|
void populateShuffleVectorFMAOpsPatterns(RewritePatternSet &patterns);
|
||||||
|
|
||||||
@@ -196,17 +196,17 @@ void populateSpecializedTransposeLoweringPatterns(
|
|||||||
int benefit = 10);
|
int benefit = 10);
|
||||||
|
|
||||||
} // namespace avx2
|
} // namespace avx2
|
||||||
} // namespace x86vector
|
} // namespace x86
|
||||||
|
|
||||||
/// Collect a set of patterns to lower X86Vector ops to ops that map to LLVM
|
/// Collect a set of patterns to lower X86 ops to ops that map to LLVM
|
||||||
/// intrinsics.
|
/// intrinsics.
|
||||||
void populateX86VectorLegalizeForLLVMExportPatterns(
|
void populateX86LegalizeForLLVMExportPatterns(
|
||||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
|
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
|
||||||
|
|
||||||
/// Configure the target to support lowering X86Vector ops to ops that map to
|
/// Configure the target to support lowering X86 ops to ops that map to
|
||||||
/// LLVM intrinsics.
|
/// LLVM intrinsics.
|
||||||
void configureX86VectorLegalizeForExportTarget(LLVMConversionTarget &target);
|
void configureX86LegalizeForExportTarget(LLVMConversionTarget &target);
|
||||||
|
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
#endif // MLIR_DIALECT_X86VECTOR_TRANSFORMS_H
|
#endif // MLIR_DIALECT_X86_TRANSFORMS_H
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
//===- X86VectorUtils.h - X86Vector Utilities -------------------*- C++ -*-===//
|
//===- X86Utils.h - X86 Utilities -------------------------------*- C++ -*-===//
|
||||||
//
|
//
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
@@ -6,8 +6,8 @@
|
|||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#ifndef MLIR_DIALECT_X86VECTOR_UTILS_X86VECTORUTILS_H_
|
#ifndef MLIR_DIALECT_X86_UTILS_X86UTILS_H_
|
||||||
#define MLIR_DIALECT_X86VECTOR_UTILS_X86VECTORUTILS_H_
|
#define MLIR_DIALECT_X86_UTILS_X86UTILS_H_
|
||||||
|
|
||||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
@@ -22,7 +22,7 @@ namespace mlir {
|
|||||||
class AffineMap;
|
class AffineMap;
|
||||||
class Operation;
|
class Operation;
|
||||||
|
|
||||||
namespace x86vector {
|
namespace x86 {
|
||||||
|
|
||||||
// Return true if the operation is in VNNI layout.
|
// Return true if the operation is in VNNI layout.
|
||||||
// Optionally, the check can be constrained to a specific VNNI blocking factor.
|
// Optionally, the check can be constrained to a specific VNNI blocking factor.
|
||||||
@@ -63,7 +63,7 @@ LogicalResult shuffleBeforeWriteLikeOp(PatternRewriter &rewriter,
|
|||||||
Operation *opA, Operation *opB,
|
Operation *opA, Operation *opB,
|
||||||
int64_t nonUnitDimAcc, VectorType accTy);
|
int64_t nonUnitDimAcc, VectorType accTy);
|
||||||
|
|
||||||
} // namespace x86vector
|
} // namespace x86
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
#endif // MLIR_DIALECT_X86VECTOR_UTILS_X86VECTORUTILS_H_
|
#endif // MLIR_DIALECT_X86_UTILS_X86UTILS_H_
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
//===-- X86VectorOps.td - X86Vector dialect operation defs -*- tablegen -*-===//
|
//===-- X86Ops.td - X86 dialect operation defs -------------*- tablegen -*-===//
|
||||||
//
|
//
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
@@ -6,25 +6,25 @@
|
|||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
//
|
//
|
||||||
// This file defines the basic operations for the X86Vector dialect.
|
// This file defines the basic operations for the X86 dialect.
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#ifndef X86VECTOR_OPS
|
#ifndef X86_OPS
|
||||||
#define X86VECTOR_OPS
|
#define X86_OPS
|
||||||
|
|
||||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||||
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
|
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
|
||||||
include "mlir/Dialect/X86Vector/X86VectorInterfaces.td"
|
include "mlir/Dialect/X86/X86Interfaces.td"
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// X86Vector dialect definition
|
// X86 dialect definition
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def X86Vector_Dialect : Dialect {
|
def X86_Dialect : Dialect {
|
||||||
let name = "x86vector";
|
let name = "x86";
|
||||||
let cppNamespace = "::mlir::x86vector";
|
let cppNamespace = "::mlir::x86";
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@@ -33,7 +33,7 @@ def X86Vector_Dialect : Dialect {
|
|||||||
|
|
||||||
// Operation that is part of the input dialect.
|
// Operation that is part of the input dialect.
|
||||||
class AVX512_Op<string mnemonic, list<Trait> traits = []> :
|
class AVX512_Op<string mnemonic, list<Trait> traits = []> :
|
||||||
Op<X86Vector_Dialect, "avx512." # mnemonic, traits> {}
|
Op<X86_Dialect, "avx512." # mnemonic, traits> {}
|
||||||
|
|
||||||
//----------------------------------------------------------------------------//
|
//----------------------------------------------------------------------------//
|
||||||
// MaskCompressOp
|
// MaskCompressOp
|
||||||
@@ -279,7 +279,7 @@ def DotBF16Op : AVX512_Op<"dot", [Pure,
|
|||||||
|
|
||||||
Example:
|
Example:
|
||||||
```mlir
|
```mlir
|
||||||
%dst = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
|
%dst = x86.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
|
||||||
```
|
```
|
||||||
}];
|
}];
|
||||||
let arguments = (ins VectorOfLengthAndType<[4, 8, 16], [F32]>:$src,
|
let arguments = (ins VectorOfLengthAndType<[4, 8, 16], [F32]>:$src,
|
||||||
@@ -323,7 +323,7 @@ def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
|
|||||||
|
|
||||||
Example:
|
Example:
|
||||||
```mlir
|
```mlir
|
||||||
%dst = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
|
%dst = x86.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
|
||||||
```
|
```
|
||||||
}];
|
}];
|
||||||
let arguments = (ins VectorOfLengthAndType<[8, 16], [F32]>:$a);
|
let arguments = (ins VectorOfLengthAndType<[8, 16], [F32]>:$a);
|
||||||
@@ -349,7 +349,7 @@ def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
|
|||||||
|
|
||||||
// Operation that is part of the input dialect.
|
// Operation that is part of the input dialect.
|
||||||
class AVX10_Op<string mnemonic, list<Trait> traits = []> :
|
class AVX10_Op<string mnemonic, list<Trait> traits = []> :
|
||||||
Op<X86Vector_Dialect, "avx10." # mnemonic, traits> {}
|
Op<X86_Dialect, "avx10." # mnemonic, traits> {}
|
||||||
|
|
||||||
//----------------------------------------------------------------------------//
|
//----------------------------------------------------------------------------//
|
||||||
// AVX10 Int8 Dot
|
// AVX10 Int8 Dot
|
||||||
@@ -376,7 +376,7 @@ def AVX10DotInt8Op : AVX10_Op<"dot.i8", [Pure,
|
|||||||
|
|
||||||
Example:
|
Example:
|
||||||
```mlir
|
```mlir
|
||||||
%dst = x86vector.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
|
%dst = x86.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
|
||||||
```
|
```
|
||||||
}];
|
}];
|
||||||
let arguments = (ins VectorOfLengthAndType<[16], [I32]>:$w,
|
let arguments = (ins VectorOfLengthAndType<[16], [I32]>:$w,
|
||||||
@@ -401,13 +401,13 @@ def AVX10DotInt8Op : AVX10_Op<"dot.i8", [Pure,
|
|||||||
|
|
||||||
// Operation that is part of the input dialect.
|
// Operation that is part of the input dialect.
|
||||||
class AVX_Op<string mnemonic, list<Trait> traits = []> :
|
class AVX_Op<string mnemonic, list<Trait> traits = []> :
|
||||||
Op<X86Vector_Dialect, "avx." # mnemonic, traits> {}
|
Op<X86_Dialect, "avx." # mnemonic, traits> {}
|
||||||
|
|
||||||
// Operation that may be part of the input dialect, but whose
|
// Operation that may be part of the input dialect, but whose
|
||||||
// form is somewhere between the user view of the operation
|
// form is somewhere between the user view of the operation
|
||||||
// and the actual lower level intrinsic in LLVM IR.
|
// and the actual lower level intrinsic in LLVM IR.
|
||||||
class AVX_LowOp<string mnemonic, list<Trait> traits = []> :
|
class AVX_LowOp<string mnemonic, list<Trait> traits = []> :
|
||||||
Op<X86Vector_Dialect, "avx.intr." # mnemonic, traits> {}
|
Op<X86_Dialect, "avx.intr." # mnemonic, traits> {}
|
||||||
|
|
||||||
//----------------------------------------------------------------------------//
|
//----------------------------------------------------------------------------//
|
||||||
// AVX Rsqrt
|
// AVX Rsqrt
|
||||||
@@ -448,7 +448,7 @@ def DotOp : AVX_LowOp<"dot", [Pure,
|
|||||||
Example:
|
Example:
|
||||||
|
|
||||||
```mlir
|
```mlir
|
||||||
%0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
|
%0 = x86.avx.intr.dot %a, %b : vector<8xf32>
|
||||||
%1 = vector.extract %0[%i0] : f32 from vector<8xf32>
|
%1 = vector.extract %0[%i0] : f32 from vector<8xf32>
|
||||||
%2 = vector.extract %0[%i4] : f32 from vector<8xf32>
|
%2 = vector.extract %0[%i4] : f32 from vector<8xf32>
|
||||||
%d = arith.addf %1, %2 : f32
|
%d = arith.addf %1, %2 : f32
|
||||||
@@ -500,7 +500,7 @@ def DotInt8Op : AVX_Op<"dot.i8", [Pure,
|
|||||||
|
|
||||||
Example:
|
Example:
|
||||||
```mlir
|
```mlir
|
||||||
%dst = x86vector.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
|
%dst = x86.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
|
||||||
```
|
```
|
||||||
}];
|
}];
|
||||||
let arguments = (ins VectorOfLengthAndType<[4, 8], [I32]>:$w,
|
let arguments = (ins VectorOfLengthAndType<[4, 8], [I32]>:$w,
|
||||||
@@ -543,8 +543,8 @@ def BcstToPackedF32Op
|
|||||||
|
|
||||||
Example:
|
Example:
|
||||||
```mlir
|
```mlir
|
||||||
%dst = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
|
%dst = x86.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
|
||||||
%dst = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
|
%dst = x86.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
|
||||||
```
|
```
|
||||||
}];
|
}];
|
||||||
let arguments = (ins MemRefOf<[BF16, F16]>:$a);
|
let arguments = (ins MemRefOf<[BF16, F16]>:$a);
|
||||||
@@ -595,8 +595,8 @@ def CvtPackedEvenIndexedToF32Op
|
|||||||
|
|
||||||
Example:
|
Example:
|
||||||
```mlir
|
```mlir
|
||||||
%dst = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
|
%dst = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
|
||||||
%dst = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
|
%dst = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
|
||||||
```
|
```
|
||||||
}];
|
}];
|
||||||
let arguments = (ins MemRefOf<[BF16, F16]>:$a);
|
let arguments = (ins MemRefOf<[BF16, F16]>:$a);
|
||||||
@@ -642,8 +642,8 @@ def CvtPackedOddIndexedToF32Op
|
|||||||
|
|
||||||
Example:
|
Example:
|
||||||
```mlir
|
```mlir
|
||||||
%dst = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
|
%dst = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
|
||||||
%dst = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
|
%dst = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
|
||||||
```
|
```
|
||||||
}];
|
}];
|
||||||
let arguments = (ins MemRefOf<[BF16, F16]>:$a);
|
let arguments = (ins MemRefOf<[BF16, F16]>:$a);
|
||||||
@@ -673,4 +673,4 @@ def CvtPackedOddIndexedToF32Op
|
|||||||
::mlir::RewriterBase &rewriter);
|
::mlir::RewriterBase &rewriter);
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
#endif // X86VECTOR_OPS
|
#endif // X86_OPS
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
//===- X86VectorDialect.h - MLIR Dialect for X86Vector ----------*- C++ -*-===//
|
//===- X86Dialect.h - MLIR Dialect for X86 ----------------------*- C++ -*-===//
|
||||||
//
|
//
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
@@ -6,12 +6,12 @@
|
|||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
//
|
//
|
||||||
// This file declares the Target dialect for X86Vector in MLIR.
|
// This file declares the Target dialect for X86 in MLIR.
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#ifndef MLIR_DIALECT_X86VECTOR_X86VECTORDIALECT_H_
|
#ifndef MLIR_DIALECT_X86_X86DIALECT_H_
|
||||||
#define MLIR_DIALECT_X86VECTOR_X86VECTORDIALECT_H_
|
#define MLIR_DIALECT_X86_X86DIALECT_H_
|
||||||
|
|
||||||
#include "mlir/Bytecode/BytecodeOpInterface.h"
|
#include "mlir/Bytecode/BytecodeOpInterface.h"
|
||||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||||
@@ -25,11 +25,11 @@
|
|||||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||||
|
|
||||||
/// Include the generated interface declarations.
|
/// Include the generated interface declarations.
|
||||||
#include "mlir/Dialect/X86Vector/X86VectorInterfaces.h.inc"
|
#include "mlir/Dialect/X86/X86Interfaces.h.inc"
|
||||||
|
|
||||||
#include "mlir/Dialect/X86Vector/X86VectorDialect.h.inc"
|
#include "mlir/Dialect/X86/X86Dialect.h.inc"
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "mlir/Dialect/X86Vector/X86Vector.h.inc"
|
#include "mlir/Dialect/X86/X86.h.inc"
|
||||||
|
|
||||||
#endif // MLIR_DIALECT_X86VECTOR_X86VECTORDIALECT_H_
|
#endif // MLIR_DIALECT_X86_X86DIALECT_H_
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
//===- X86VectorInterfaces.td - X86Vector interfaces -------*- tablegen -*-===//
|
//===- X86Interfaces.td - X86 interfaces -------------------*- tablegen -*-===//
|
||||||
//
|
//
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
@@ -6,12 +6,12 @@
|
|||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
//
|
//
|
||||||
// This file defines interfaces for the X86Vector dialect.
|
// This file defines interfaces for the X86 dialect.
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#ifndef X86VECTOR_INTERFACES
|
#ifndef X86_INTERFACES
|
||||||
#define X86VECTOR_INTERFACES
|
#define X86_INTERFACES
|
||||||
|
|
||||||
include "mlir/IR/Interfaces.td"
|
include "mlir/IR/Interfaces.td"
|
||||||
include "mlir/Dialect/LLVMIR/LLVMInterfaces.td"
|
include "mlir/Dialect/LLVMIR/LLVMInterfaces.td"
|
||||||
@@ -25,7 +25,7 @@ def X86IntrinsicOpInterface
|
|||||||
let description = [{
|
let description = [{
|
||||||
A wrapper interface for operations representing x86 LLVM intrinsics.
|
A wrapper interface for operations representing x86 LLVM intrinsics.
|
||||||
}];
|
}];
|
||||||
let cppNamespace = "::mlir::x86vector";
|
let cppNamespace = "::mlir::x86";
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // X86VECTOR_INTERFACES
|
#endif // X86_INTERFACES
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
add_mlir_dialect(X86Vector x86vector)
|
|
||||||
add_mlir_doc(X86Vector X86Vector Dialects/ -gen-dialect-doc -dialect=x86vector)
|
|
||||||
|
|
||||||
add_mlir_interface(X86VectorInterfaces)
|
|
||||||
add_dependencies(MLIRX86VectorIncGen MLIRX86VectorInterfacesIncGen)
|
|
||||||
|
|
||||||
add_subdirectory(TransformOps)
|
|
||||||
@@ -1,4 +0,0 @@
|
|||||||
set(LLVM_TARGET_DEFINITIONS X86VectorTransformOps.td)
|
|
||||||
mlir_tablegen(X86VectorTransformOps.h.inc -gen-op-decls)
|
|
||||||
mlir_tablegen(X86VectorTransformOps.cpp.inc -gen-op-defs)
|
|
||||||
add_mlir_dialect_tablegen_target(MLIRX86VectorTransformOpsIncGen)
|
|
||||||
@@ -517,13 +517,13 @@ add_mlir_upstream_c_api_library(MLIRCAPIWasmSSA
|
|||||||
MLIRWasmSSADialect
|
MLIRWasmSSADialect
|
||||||
)
|
)
|
||||||
|
|
||||||
add_mlir_upstream_c_api_library(MLIRCAPIX86Vector
|
add_mlir_upstream_c_api_library(MLIRCAPIX86
|
||||||
X86Vector.cpp
|
X86.cpp
|
||||||
|
|
||||||
PARTIAL_SOURCES_INTENDED
|
PARTIAL_SOURCES_INTENDED
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
MLIRCAPIIR
|
MLIRCAPIIR
|
||||||
MLIRX86VectorDialect
|
MLIRX86Dialect
|
||||||
)
|
)
|
||||||
|
|
||||||
add_mlir_upstream_c_api_library(MLIRCAPIXeGPU
|
add_mlir_upstream_c_api_library(MLIRCAPIXeGPU
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//===- X86Vector.cpp - C Interface for X86Vector dialect ------------------===//
|
//===- X86.cpp - C Interface for X86 dialect ------------------------------===//
|
||||||
//
|
//
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
@@ -6,9 +6,8 @@
|
|||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "mlir-c/Dialect/X86Vector.h"
|
#include "mlir-c/Dialect/X86.h"
|
||||||
#include "mlir/CAPI/Registration.h"
|
#include "mlir/CAPI/Registration.h"
|
||||||
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
|
#include "mlir/Dialect/X86/X86Dialect.h"
|
||||||
|
|
||||||
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(X86Vector, x86vector,
|
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(X86, x86, mlir::x86::X86Dialect)
|
||||||
mlir::x86vector::X86VectorDialect)
|
|
||||||
@@ -40,6 +40,6 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
|
|||||||
MLIRArmSVETransforms
|
MLIRArmSVETransforms
|
||||||
MLIRAMXDialect
|
MLIRAMXDialect
|
||||||
MLIRAMXTransforms
|
MLIRAMXTransforms
|
||||||
MLIRX86VectorDialect
|
MLIRX86Dialect
|
||||||
MLIRX86VectorTransforms
|
MLIRX86Transforms
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -22,8 +22,8 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
|
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
|
||||||
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
|
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
|
||||||
#include "mlir/Dialect/X86Vector/Transforms.h"
|
#include "mlir/Dialect/X86/Transforms.h"
|
||||||
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
|
#include "mlir/Dialect/X86/X86Dialect.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
|
||||||
@@ -53,8 +53,8 @@ struct ConvertVectorToLLVMPass
|
|||||||
registry.insert<arm_sve::ArmSVEDialect>();
|
registry.insert<arm_sve::ArmSVEDialect>();
|
||||||
if (amx)
|
if (amx)
|
||||||
registry.insert<amx::AMXDialect>();
|
registry.insert<amx::AMXDialect>();
|
||||||
if (x86Vector)
|
if (x86)
|
||||||
registry.insert<x86vector::X86VectorDialect>();
|
registry.insert<x86::X86Dialect>();
|
||||||
}
|
}
|
||||||
void runOnOperation() override;
|
void runOnOperation() override;
|
||||||
};
|
};
|
||||||
@@ -140,9 +140,9 @@ void ConvertVectorToLLVMPass::runOnOperation() {
|
|||||||
configureAMXLegalizeForExportTarget(target);
|
configureAMXLegalizeForExportTarget(target);
|
||||||
populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
|
populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
|
||||||
}
|
}
|
||||||
if (x86Vector) {
|
if (x86) {
|
||||||
configureX86VectorLegalizeForExportTarget(target);
|
configureX86LegalizeForExportTarget(target);
|
||||||
populateX86VectorLegalizeForLLVMExportPatterns(converter, patterns);
|
populateX86LegalizeForLLVMExportPatterns(converter, patterns);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (failed(
|
if (failed(
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ add_subdirectory(UB)
|
|||||||
add_subdirectory(Utils)
|
add_subdirectory(Utils)
|
||||||
add_subdirectory(Vector)
|
add_subdirectory(Vector)
|
||||||
add_subdirectory(WasmSSA)
|
add_subdirectory(WasmSSA)
|
||||||
add_subdirectory(X86Vector)
|
add_subdirectory(X86)
|
||||||
add_subdirectory(XeGPU)
|
add_subdirectory(XeGPU)
|
||||||
|
|
||||||
set(LLVM_OPTIONAL_SOURCES
|
set(LLVM_OPTIONAL_SOURCES
|
||||||
|
|||||||
@@ -21,6 +21,6 @@ add_mlir_dialect_library(MLIRMathTransforms
|
|||||||
MLIRSCFDialect
|
MLIRSCFDialect
|
||||||
MLIRPass
|
MLIRPass
|
||||||
MLIRTransforms
|
MLIRTransforms
|
||||||
MLIRX86VectorDialect
|
MLIRX86Dialect
|
||||||
MLIRVectorDialect
|
MLIRVectorDialect
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -22,7 +22,7 @@
|
|||||||
#include "mlir/Dialect/Utils/IndexingUtils.h"
|
#include "mlir/Dialect/Utils/IndexingUtils.h"
|
||||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||||
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
|
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
|
||||||
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
|
#include "mlir/Dialect/X86/X86Dialect.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "mlir/IR/OpDefinition.h"
|
#include "mlir/IR/OpDefinition.h"
|
||||||
@@ -1740,7 +1740,7 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
|
|||||||
// Compute an approximate result.
|
// Compute an approximate result.
|
||||||
Value yApprox = handleMultidimensionalVectors(
|
Value yApprox = handleMultidimensionalVectors(
|
||||||
builder, op->getOperands(), 8, [&builder](ValueRange operands) -> Value {
|
builder, op->getOperands(), 8, [&builder](ValueRange operands) -> Value {
|
||||||
return x86vector::RsqrtOp::create(builder, operands);
|
return x86::RsqrtOp::create(builder, operands);
|
||||||
});
|
});
|
||||||
|
|
||||||
// Do a single step of Newton-Raphson iteration to improve the approximation.
|
// Do a single step of Newton-Raphson iteration to improve the approximation.
|
||||||
|
|||||||
@@ -18,5 +18,5 @@ add_mlir_dialect_library(MLIRVectorTransformOps
|
|||||||
MLIRTransformDialect
|
MLIRTransformDialect
|
||||||
MLIRVectorDialect
|
MLIRVectorDialect
|
||||||
MLIRVectorToSCF
|
MLIRVectorToSCF
|
||||||
MLIRX86VectorTransforms
|
MLIRX86Transforms
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -18,7 +18,7 @@
|
|||||||
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
|
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
|
||||||
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
|
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
|
||||||
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
|
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
|
||||||
#include "mlir/Dialect/X86Vector/Transforms.h"
|
#include "mlir/Dialect/X86/Transforms.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::vector;
|
using namespace mlir::vector;
|
||||||
@@ -191,12 +191,10 @@ void transform::ApplyLowerTransposePatternsOp::populatePatterns(
|
|||||||
vector::populateVectorTransposeLoweringPatterns(patterns,
|
vector::populateVectorTransposeLoweringPatterns(patterns,
|
||||||
getLoweringStrategy());
|
getLoweringStrategy());
|
||||||
if (getAvx2LoweringStrategy()) {
|
if (getAvx2LoweringStrategy()) {
|
||||||
auto avx2LoweringOptions =
|
auto avx2LoweringOptions = x86::avx2::LoweringOptions().setTransposeOptions(
|
||||||
x86vector::avx2::LoweringOptions().setTransposeOptions(
|
x86::avx2::TransposeLoweringOptions().lower4x8xf32(true).lower8x8xf32(
|
||||||
x86vector::avx2::TransposeLoweringOptions()
|
true));
|
||||||
.lower4x8xf32(true)
|
x86::avx2::populateSpecializedTransposeLoweringPatterns(
|
||||||
.lower8x8xf32(true));
|
|
||||||
x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
|
|
||||||
patterns, avx2LoweringOptions, /*benefit=*/10);
|
patterns, avx2LoweringOptions, /*benefit=*/10);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
add_mlir_dialect_library(MLIRX86VectorDialect
|
add_mlir_dialect_library(MLIRX86Dialect
|
||||||
X86VectorDialect.cpp
|
X86Dialect.cpp
|
||||||
|
|
||||||
ADDITIONAL_HEADER_DIRS
|
ADDITIONAL_HEADER_DIRS
|
||||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/X86Vector
|
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/X86
|
||||||
|
|
||||||
DEPENDS
|
DEPENDS
|
||||||
MLIRX86VectorIncGen
|
MLIRX86IncGen
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
MLIRIR
|
MLIRIR
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
//===- X86VectorDialect.cpp - MLIR X86Vector ops implementation -----------===//
|
//===- X86Dialect.cpp - MLIR X86 ops implementation -----------------------===//
|
||||||
//
|
//
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
@@ -6,25 +6,25 @@
|
|||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
//
|
//
|
||||||
// This file implements the X86Vector dialect and its operations.
|
// This file implements the X86 dialect and its operations.
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
|
#include "mlir/Dialect/X86/X86Dialect.h"
|
||||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/TypeUtilities.h"
|
#include "mlir/IR/TypeUtilities.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
#include "mlir/Dialect/X86Vector/X86VectorInterfaces.cpp.inc"
|
#include "mlir/Dialect/X86/X86Interfaces.cpp.inc"
|
||||||
|
|
||||||
#include "mlir/Dialect/X86Vector/X86VectorDialect.cpp.inc"
|
#include "mlir/Dialect/X86/X86Dialect.cpp.inc"
|
||||||
|
|
||||||
void x86vector::X86VectorDialect::initialize() {
|
void x86::X86Dialect::initialize() {
|
||||||
addOperations<
|
addOperations<
|
||||||
#define GET_OP_LIST
|
#define GET_OP_LIST
|
||||||
#include "mlir/Dialect/X86Vector/X86Vector.cpp.inc"
|
#include "mlir/Dialect/X86/X86.cpp.inc"
|
||||||
>();
|
>();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -35,7 +35,7 @@ static Value getMemrefBuffPtr(Location loc, MemRefType type, Value buffer,
|
|||||||
return memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, type);
|
return memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, type);
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult x86vector::MaskCompressOp::verify() {
|
LogicalResult x86::MaskCompressOp::verify() {
|
||||||
if (getSrc() && getConstantSrc())
|
if (getSrc() && getConstantSrc())
|
||||||
return emitError("cannot use both src and constant_src");
|
return emitError("cannot use both src and constant_src");
|
||||||
|
|
||||||
@@ -49,7 +49,7 @@ LogicalResult x86vector::MaskCompressOp::verify() {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<Value> x86vector::MaskCompressOp::getIntrinsicOperands(
|
SmallVector<Value> x86::MaskCompressOp::getIntrinsicOperands(
|
||||||
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
|
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
|
||||||
RewriterBase &rewriter) {
|
RewriterBase &rewriter) {
|
||||||
auto loc = getLoc();
|
auto loc = getLoc();
|
||||||
@@ -71,9 +71,9 @@ SmallVector<Value> x86vector::MaskCompressOp::getIntrinsicOperands(
|
|||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<Value>
|
SmallVector<Value>
|
||||||
x86vector::DotOp::getIntrinsicOperands(ArrayRef<Value> operands,
|
x86::DotOp::getIntrinsicOperands(ArrayRef<Value> operands,
|
||||||
const LLVMTypeConverter &typeConverter,
|
const LLVMTypeConverter &typeConverter,
|
||||||
RewriterBase &rewriter) {
|
RewriterBase &rewriter) {
|
||||||
SmallVector<Value> intrinsicOperands(operands);
|
SmallVector<Value> intrinsicOperands(operands);
|
||||||
// Dot product of all elements, broadcasted to all elements.
|
// Dot product of all elements, broadcasted to all elements.
|
||||||
Value scale =
|
Value scale =
|
||||||
@@ -83,7 +83,7 @@ x86vector::DotOp::getIntrinsicOperands(ArrayRef<Value> operands,
|
|||||||
return intrinsicOperands;
|
return intrinsicOperands;
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<Value> x86vector::BcstToPackedF32Op::getIntrinsicOperands(
|
SmallVector<Value> x86::BcstToPackedF32Op::getIntrinsicOperands(
|
||||||
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
|
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
|
||||||
RewriterBase &rewriter) {
|
RewriterBase &rewriter) {
|
||||||
Adaptor adaptor(operands, *this);
|
Adaptor adaptor(operands, *this);
|
||||||
@@ -91,7 +91,7 @@ SmallVector<Value> x86vector::BcstToPackedF32Op::getIntrinsicOperands(
|
|||||||
typeConverter, rewriter)};
|
typeConverter, rewriter)};
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<Value> x86vector::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
|
SmallVector<Value> x86::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
|
||||||
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
|
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
|
||||||
RewriterBase &rewriter) {
|
RewriterBase &rewriter) {
|
||||||
Adaptor adaptor(operands, *this);
|
Adaptor adaptor(operands, *this);
|
||||||
@@ -99,7 +99,7 @@ SmallVector<Value> x86vector::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
|
|||||||
typeConverter, rewriter)};
|
typeConverter, rewriter)};
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<Value> x86vector::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
|
SmallVector<Value> x86::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
|
||||||
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
|
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
|
||||||
RewriterBase &rewriter) {
|
RewriterBase &rewriter) {
|
||||||
Adaptor adaptor(operands, *this);
|
Adaptor adaptor(operands, *this);
|
||||||
@@ -108,4 +108,4 @@ SmallVector<Value> x86vector::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "mlir/Dialect/X86Vector/X86Vector.cpp.inc"
|
#include "mlir/Dialect/X86/X86.cpp.inc"
|
||||||
@@ -1,8 +1,8 @@
|
|||||||
add_mlir_dialect_library(MLIRX86VectorTransformOps
|
add_mlir_dialect_library(MLIRX86TransformOps
|
||||||
X86VectorTransformOps.cpp
|
X86TransformOps.cpp
|
||||||
|
|
||||||
DEPENDS
|
DEPENDS
|
||||||
MLIRX86VectorTransformOpsIncGen
|
MLIRX86TransformOpsIncGen
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
MLIRIR
|
MLIRIR
|
||||||
@@ -12,6 +12,6 @@ add_mlir_dialect_library(MLIRX86VectorTransformOps
|
|||||||
MLIRSideEffectInterfaces
|
MLIRSideEffectInterfaces
|
||||||
MLIRTransformDialect
|
MLIRTransformDialect
|
||||||
MLIRTransformDialectUtils
|
MLIRTransformDialectUtils
|
||||||
MLIRX86VectorDialect
|
MLIRX86Dialect
|
||||||
MLIRX86VectorTransforms
|
MLIRX86Transforms
|
||||||
)
|
)
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
//===- X86VectorTransformOps.cpp ------------------------------------------===//
|
//===- X86TransformOps.cpp ------------------------------------------------===//
|
||||||
//
|
//
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
@@ -6,45 +6,45 @@
|
|||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h"
|
#include "mlir/Dialect/X86/TransformOps/X86TransformOps.h"
|
||||||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||||
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
||||||
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
|
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
|
||||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||||
#include "mlir/Dialect/X86Vector/Transforms.h"
|
#include "mlir/Dialect/X86/Transforms.h"
|
||||||
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
|
#include "mlir/Dialect/X86/X86Dialect.h"
|
||||||
|
|
||||||
#include "mlir/IR/OpImplementation.h"
|
#include "mlir/IR/OpImplementation.h"
|
||||||
#include "mlir/IR/RegionKindInterface.h"
|
#include "mlir/IR/RegionKindInterface.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::x86vector;
|
using namespace mlir::x86;
|
||||||
using namespace mlir::transform;
|
using namespace mlir::transform;
|
||||||
|
|
||||||
void mlir::transform::ApplyVectorContractToFMAPatternsOp::populatePatterns(
|
void mlir::transform::ApplyVectorContractToFMAPatternsOp::populatePatterns(
|
||||||
RewritePatternSet &patterns) {
|
RewritePatternSet &patterns) {
|
||||||
x86vector::populateVectorContractToFMAPatterns(patterns);
|
x86::populateVectorContractToFMAPatterns(patterns);
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlir::transform::ApplyVectorContractToPackedTypeDotProductPatternsOp::
|
void mlir::transform::ApplyVectorContractToPackedTypeDotProductPatternsOp::
|
||||||
populatePatterns(RewritePatternSet &patterns) {
|
populatePatterns(RewritePatternSet &patterns) {
|
||||||
x86vector::populateVectorContractToPackedTypeDotProductPatterns(patterns);
|
x86::populateVectorContractToPackedTypeDotProductPatterns(patterns);
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlir::transform::ApplyVectorContractBF16ToFMAPatternsOp::populatePatterns(
|
void mlir::transform::ApplyVectorContractBF16ToFMAPatternsOp::populatePatterns(
|
||||||
RewritePatternSet &patterns) {
|
RewritePatternSet &patterns) {
|
||||||
x86vector::populateVectorContractBF16ToFMAPatterns(patterns);
|
x86::populateVectorContractBF16ToFMAPatterns(patterns);
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlir::transform::ApplySinkVectorProducerOpsPatternsOp::populatePatterns(
|
void mlir::transform::ApplySinkVectorProducerOpsPatternsOp::populatePatterns(
|
||||||
RewritePatternSet &patterns) {
|
RewritePatternSet &patterns) {
|
||||||
x86vector::populateSinkVectorProducerOpsPatterns(patterns);
|
x86::populateSinkVectorProducerOpsPatterns(patterns);
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlir::transform::ApplyShuffleVectorFMAOpsPatternsOp::populatePatterns(
|
void mlir::transform::ApplyShuffleVectorFMAOpsPatternsOp::populatePatterns(
|
||||||
RewritePatternSet &patterns) {
|
RewritePatternSet &patterns) {
|
||||||
x86vector::populateShuffleVectorFMAOpsPatterns(patterns);
|
x86::populateShuffleVectorFMAOpsPatterns(patterns);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@@ -52,28 +52,26 @@ void mlir::transform::ApplyShuffleVectorFMAOpsPatternsOp::populatePatterns(
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class X86VectorTransformDialectExtension
|
class X86TransformDialectExtension
|
||||||
: public transform::TransformDialectExtension<
|
: public transform::TransformDialectExtension<
|
||||||
X86VectorTransformDialectExtension> {
|
X86TransformDialectExtension> {
|
||||||
public:
|
public:
|
||||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(X86TransformDialectExtension)
|
||||||
X86VectorTransformDialectExtension)
|
|
||||||
|
|
||||||
X86VectorTransformDialectExtension() {
|
X86TransformDialectExtension() {
|
||||||
declareGeneratedDialect<x86vector::X86VectorDialect>();
|
declareGeneratedDialect<x86::X86Dialect>();
|
||||||
declareGeneratedDialect<LLVM::LLVMDialect>();
|
declareGeneratedDialect<LLVM::LLVMDialect>();
|
||||||
registerTransformOps<
|
registerTransformOps<
|
||||||
#define GET_OP_LIST
|
#define GET_OP_LIST
|
||||||
#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
|
#include "mlir/Dialect/X86/TransformOps/X86TransformOps.cpp.inc"
|
||||||
>();
|
>();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
|
#include "mlir/Dialect/X86/TransformOps/X86TransformOps.cpp.inc"
|
||||||
|
|
||||||
void mlir::x86vector::registerTransformDialectExtension(
|
void mlir::x86::registerTransformDialectExtension(DialectRegistry ®istry) {
|
||||||
DialectRegistry ®istry) {
|
registry.addExtensions<X86TransformDialectExtension>();
|
||||||
registry.addExtensions<X86VectorTransformDialectExtension>();
|
|
||||||
}
|
}
|
||||||
@@ -15,20 +15,21 @@
|
|||||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||||
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
|
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
|
||||||
#include "mlir/Dialect/X86Vector/Transforms.h"
|
#include "mlir/Dialect/X86/Transforms.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "llvm/Support/Format.h"
|
#include "llvm/Support/Format.h"
|
||||||
#include "llvm/Support/FormatVariadic.h"
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::vector;
|
using namespace mlir::vector;
|
||||||
using namespace mlir::x86vector;
|
using namespace mlir::x86;
|
||||||
using namespace mlir::x86vector::avx2;
|
using namespace mlir::x86::avx2;
|
||||||
using namespace mlir::x86vector::avx2::inline_asm;
|
using namespace mlir::x86::avx2::inline_asm;
|
||||||
using namespace mlir::x86vector::avx2::intrin;
|
using namespace mlir::x86::avx2::intrin;
|
||||||
|
|
||||||
Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm(
|
Value mlir::x86::avx2::inline_asm::mm256BlendPsAsm(ImplicitLocOpBuilder &b,
|
||||||
ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
|
Value v1, Value v2,
|
||||||
|
uint8_t mask) {
|
||||||
auto asmDialectAttr =
|
auto asmDialectAttr =
|
||||||
LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_Intel);
|
LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_Intel);
|
||||||
const auto *asmTp = "vblendps $0, $1, $2, {0}";
|
const auto *asmTp = "vblendps $0, $1, $2, {0}";
|
||||||
@@ -45,14 +46,14 @@ Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm(
|
|||||||
return asmOp.getResult(0);
|
return asmOp.getResult(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
Value mlir::x86vector::avx2::intrin::mm256UnpackLoPs(ImplicitLocOpBuilder &b,
|
Value mlir::x86::avx2::intrin::mm256UnpackLoPs(ImplicitLocOpBuilder &b,
|
||||||
Value v1, Value v2) {
|
Value v1, Value v2) {
|
||||||
return vector::ShuffleOp::create(b, v1, v2,
|
return vector::ShuffleOp::create(b, v1, v2,
|
||||||
ArrayRef<int64_t>{0, 8, 1, 9, 4, 12, 5, 13});
|
ArrayRef<int64_t>{0, 8, 1, 9, 4, 12, 5, 13});
|
||||||
}
|
}
|
||||||
|
|
||||||
Value mlir::x86vector::avx2::intrin::mm256UnpackHiPs(ImplicitLocOpBuilder &b,
|
Value mlir::x86::avx2::intrin::mm256UnpackHiPs(ImplicitLocOpBuilder &b,
|
||||||
Value v1, Value v2) {
|
Value v1, Value v2) {
|
||||||
return vector::ShuffleOp::create(
|
return vector::ShuffleOp::create(
|
||||||
b, v1, v2, ArrayRef<int64_t>{2, 10, 3, 11, 6, 14, 7, 15});
|
b, v1, v2, ArrayRef<int64_t>{2, 10, 3, 11, 6, 14, 7, 15});
|
||||||
}
|
}
|
||||||
@@ -60,9 +61,8 @@ Value mlir::x86vector::avx2::intrin::mm256UnpackHiPs(ImplicitLocOpBuilder &b,
|
|||||||
/// Takes an 8 bit mask, 2 bit for each position of a[0, 3) **and** b[0, 4):
|
/// Takes an 8 bit mask, 2 bit for each position of a[0, 3) **and** b[0, 4):
|
||||||
/// 0:127 | 128:255
|
/// 0:127 | 128:255
|
||||||
/// b01 b23 C8 D8 | b01+4 b23+4 C8+4 D8+4
|
/// b01 b23 C8 D8 | b01+4 b23+4 C8+4 D8+4
|
||||||
Value mlir::x86vector::avx2::intrin::mm256ShufflePs(ImplicitLocOpBuilder &b,
|
Value mlir::x86::avx2::intrin::mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1,
|
||||||
Value v1, Value v2,
|
Value v2, uint8_t mask) {
|
||||||
uint8_t mask) {
|
|
||||||
uint8_t b01, b23, b45, b67;
|
uint8_t b01, b23, b45, b67;
|
||||||
MaskHelper::extractShuffle(mask, b01, b23, b45, b67);
|
MaskHelper::extractShuffle(mask, b01, b23, b45, b67);
|
||||||
SmallVector<int64_t> shuffleMask = {
|
SmallVector<int64_t> shuffleMask = {
|
||||||
@@ -76,8 +76,9 @@ Value mlir::x86vector::avx2::intrin::mm256ShufflePs(ImplicitLocOpBuilder &b,
|
|||||||
// a[0:127] or a[128:255] or b[0:127] or b[128:255]
|
// a[0:127] or a[128:255] or b[0:127] or b[128:255]
|
||||||
// 0 1 2 3
|
// 0 1 2 3
|
||||||
// imm[0:1] out of imm[4:7].
|
// imm[0:1] out of imm[4:7].
|
||||||
Value mlir::x86vector::avx2::intrin::mm256Permute2f128Ps(
|
Value mlir::x86::avx2::intrin::mm256Permute2f128Ps(ImplicitLocOpBuilder &b,
|
||||||
ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
|
Value v1, Value v2,
|
||||||
|
uint8_t mask) {
|
||||||
SmallVector<int64_t> shuffleMask;
|
SmallVector<int64_t> shuffleMask;
|
||||||
auto appendToMask = [&](uint8_t control) {
|
auto appendToMask = [&](uint8_t control) {
|
||||||
if (control == 0)
|
if (control == 0)
|
||||||
@@ -99,9 +100,8 @@ Value mlir::x86vector::avx2::intrin::mm256Permute2f128Ps(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// If bit i of `mask` is zero, take f32@i from v1 else take it from v2.
|
/// If bit i of `mask` is zero, take f32@i from v1 else take it from v2.
|
||||||
Value mlir::x86vector::avx2::intrin::mm256BlendPs(ImplicitLocOpBuilder &b,
|
Value mlir::x86::avx2::intrin::mm256BlendPs(ImplicitLocOpBuilder &b, Value v1,
|
||||||
Value v1, Value v2,
|
Value v2, uint8_t mask) {
|
||||||
uint8_t mask) {
|
|
||||||
SmallVector<int64_t, 8> shuffleMask;
|
SmallVector<int64_t, 8> shuffleMask;
|
||||||
for (int i = 0; i < 8; ++i) {
|
for (int i = 0; i < 8; ++i) {
|
||||||
bool isSet = mask & (1 << i);
|
bool isSet = mask & (1 << i);
|
||||||
@@ -111,8 +111,8 @@ Value mlir::x86vector::avx2::intrin::mm256BlendPs(ImplicitLocOpBuilder &b,
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model.
|
/// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model.
|
||||||
void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
|
void mlir::x86::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
|
||||||
MutableArrayRef<Value> vs) {
|
MutableArrayRef<Value> vs) {
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
|
auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
|
||||||
assert(vs.size() == 4 && "expects 4 vectors");
|
assert(vs.size() == 4 && "expects 4 vectors");
|
||||||
@@ -136,8 +136,8 @@ void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model.
|
/// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model.
|
||||||
void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
|
void mlir::x86::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
|
||||||
MutableArrayRef<Value> vs) {
|
MutableArrayRef<Value> vs) {
|
||||||
auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
|
auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
|
||||||
(void)vt;
|
(void)vt;
|
||||||
assert(vs.size() == 8 && "expects 8 vectors");
|
assert(vs.size() == 8 && "expects 8 vectors");
|
||||||
@@ -284,7 +284,7 @@ private:
|
|||||||
LoweringOptions loweringOptions;
|
LoweringOptions loweringOptions;
|
||||||
};
|
};
|
||||||
|
|
||||||
void mlir::x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
|
void mlir::x86::avx2::populateSpecializedTransposeLoweringPatterns(
|
||||||
RewritePatternSet &patterns, LoweringOptions options, int benefit) {
|
RewritePatternSet &patterns, LoweringOptions options, int benefit) {
|
||||||
patterns.add<TransposeOpLowering>(options, patterns.getContext(), benefit);
|
patterns.add<TransposeOpLowering>(options, patterns.getContext(), benefit);
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
add_mlir_dialect_library(MLIRX86VectorTransforms
|
add_mlir_dialect_library(MLIRX86Transforms
|
||||||
AVXTranspose.cpp
|
AVXTranspose.cpp
|
||||||
LegalizeForLLVMExport.cpp
|
LegalizeForLLVMExport.cpp
|
||||||
VectorContractToFMA.cpp
|
VectorContractToFMA.cpp
|
||||||
@@ -15,6 +15,6 @@ add_mlir_dialect_library(MLIRX86VectorTransforms
|
|||||||
MLIRLLVMDialect
|
MLIRLLVMDialect
|
||||||
MLIRVectorDialect
|
MLIRVectorDialect
|
||||||
MLIRVectorUtils
|
MLIRVectorUtils
|
||||||
MLIRX86VectorDialect
|
MLIRX86Dialect
|
||||||
MLIRX86VectorUtils
|
MLIRX86Utils
|
||||||
)
|
)
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
//===- LegalizeForLLVMExport.cpp - Prepare X86Vector for LLVM translation -===//
|
//===- LegalizeForLLVMExport.cpp - Prepare X86 for LLVM translation -------===//
|
||||||
//
|
//
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
@@ -6,26 +6,26 @@
|
|||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "mlir/Dialect/X86Vector/Transforms.h"
|
#include "mlir/Dialect/X86/Transforms.h"
|
||||||
|
|
||||||
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
||||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||||
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
|
#include "mlir/Dialect/X86/X86Dialect.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::x86vector;
|
using namespace mlir::x86;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
/// Generic one-to-one conversion of simply mappable operations into calls
|
/// Generic one-to-one conversion of simply mappable operations into calls
|
||||||
/// to their respective LLVM intrinsics.
|
/// to their respective LLVM intrinsics.
|
||||||
struct X86IntrinsicOpConversion
|
struct X86IntrinsicOpConversion
|
||||||
: public ConvertOpInterfaceToLLVMPattern<x86vector::X86IntrinsicOp> {
|
: public ConvertOpInterfaceToLLVMPattern<x86::X86IntrinsicOp> {
|
||||||
using ConvertOpInterfaceToLLVMPattern::ConvertOpInterfaceToLLVMPattern;
|
using ConvertOpInterfaceToLLVMPattern::ConvertOpInterfaceToLLVMPattern;
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(x86vector::X86IntrinsicOp op, ArrayRef<Value> operands,
|
matchAndRewrite(x86::X86IntrinsicOp op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
const LLVMTypeConverter &typeConverter = *getTypeConverter();
|
const LLVMTypeConverter &typeConverter = *getTypeConverter();
|
||||||
return LLVM::detail::intrinsicRewrite(
|
return LLVM::detail::intrinsicRewrite(
|
||||||
@@ -37,13 +37,12 @@ struct X86IntrinsicOpConversion
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
/// Populate the given list with patterns that convert from X86Vector to LLVM.
|
/// Populate the given list with patterns that convert from X86 to LLVM.
|
||||||
void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
|
void mlir::populateX86LegalizeForLLVMExportPatterns(
|
||||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
||||||
patterns.add<X86IntrinsicOpConversion>(converter);
|
patterns.add<X86IntrinsicOpConversion>(converter);
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlir::configureX86VectorLegalizeForExportTarget(
|
void mlir::configureX86LegalizeForExportTarget(LLVMConversionTarget &target) {
|
||||||
LLVMConversionTarget &target) {
|
target.addIllegalDialect<X86Dialect>();
|
||||||
target.addIllegalDialect<X86VectorDialect>();
|
|
||||||
}
|
}
|
||||||
@@ -7,8 +7,8 @@
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||||
#include "mlir/Dialect/X86Vector/Transforms.h"
|
#include "mlir/Dialect/X86/Transforms.h"
|
||||||
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
|
#include "mlir/Dialect/X86/X86Dialect.h"
|
||||||
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
@@ -17,17 +17,17 @@
|
|||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::vector;
|
using namespace mlir::vector;
|
||||||
using namespace mlir::x86vector;
|
using namespace mlir::x86;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// Validates whether the given operation is an x86vector operation and has only
|
// Validates whether the given operation is an x86 operation and has only
|
||||||
// one consumer.
|
// one consumer.
|
||||||
static bool validateFMAOperands(Value op) {
|
static bool validateFMAOperands(Value op) {
|
||||||
if (auto cvt = op.getDefiningOp<x86vector::CvtPackedEvenIndexedToF32Op>())
|
if (auto cvt = op.getDefiningOp<x86::CvtPackedEvenIndexedToF32Op>())
|
||||||
return cvt.getResult().hasOneUse();
|
return cvt.getResult().hasOneUse();
|
||||||
|
|
||||||
if (auto bcst = op.getDefiningOp<x86vector::BcstToPackedF32Op>())
|
if (auto bcst = op.getDefiningOp<x86::BcstToPackedF32Op>())
|
||||||
return bcst.getResult().hasOneUse();
|
return bcst.getResult().hasOneUse();
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
@@ -36,14 +36,14 @@ static bool validateFMAOperands(Value op) {
|
|||||||
// Validates the vector.fma operation on the following conditions:
|
// Validates the vector.fma operation on the following conditions:
|
||||||
// (i) one of the lhs or rhs defining operation should be
|
// (i) one of the lhs or rhs defining operation should be
|
||||||
// CvtPackedEvenIndexedToF32Op, (ii) the lhs or rhs defining operation should be
|
// CvtPackedEvenIndexedToF32Op, (ii) the lhs or rhs defining operation should be
|
||||||
// an x86vector operation and has only one consumer, (iii) all operations
|
// an x86 operation and has only one consumer, (iii) all operations
|
||||||
// are in the same block, and (iv) ths FMA has only one user.
|
// are in the same block, and (iv) ths FMA has only one user.
|
||||||
static bool validateVectorFMAOp(vector::FMAOp fmaOp) {
|
static bool validateVectorFMAOp(vector::FMAOp fmaOp) {
|
||||||
Value lhs = fmaOp.getLhs();
|
Value lhs = fmaOp.getLhs();
|
||||||
Value rhs = fmaOp.getRhs();
|
Value rhs = fmaOp.getRhs();
|
||||||
|
|
||||||
if (!isa<x86vector::CvtPackedEvenIndexedToF32Op>(lhs.getDefiningOp()) &&
|
if (!isa<x86::CvtPackedEvenIndexedToF32Op>(lhs.getDefiningOp()) &&
|
||||||
!isa<x86vector::CvtPackedEvenIndexedToF32Op>(rhs.getDefiningOp()))
|
!isa<x86::CvtPackedEvenIndexedToF32Op>(rhs.getDefiningOp()))
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
if (!validateFMAOperands(lhs) || !validateFMAOperands(rhs))
|
if (!validateFMAOperands(lhs) || !validateFMAOperands(rhs))
|
||||||
@@ -93,38 +93,38 @@ static void moveFMA(PatternRewriter &rewriter, vector::FMAOp fmaOp) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shuffle FMAs with x86vector operations as operands such that
|
// Shuffle FMAs with x86 operations as operands such that
|
||||||
// FMAs are grouped with respect to odd/even packed index.
|
// FMAs are grouped with respect to odd/even packed index.
|
||||||
//
|
//
|
||||||
// For example:
|
// For example:
|
||||||
// ```
|
// ```
|
||||||
// %1 = x86vector.avx.bcst_to_f32.packed
|
// %1 = x86.avx.bcst_to_f32.packed
|
||||||
// %2 = x86vector.avx.cvt.packed.odd.indexed_to_f32
|
// %2 = x86.avx.cvt.packed.odd.indexed_to_f32
|
||||||
// %3 = vector.fma %1, %2, %arg1
|
// %3 = vector.fma %1, %2, %arg1
|
||||||
// %4 = x86vector.avx.bcst_to_f32.packed
|
// %4 = x86.avx.bcst_to_f32.packed
|
||||||
// %5 = x86vector.avx.cvt.packed.even.indexed_to_f32
|
// %5 = x86.avx.cvt.packed.even.indexed_to_f32
|
||||||
// %6 = vector.fma %4, %5, %3
|
// %6 = vector.fma %4, %5, %3
|
||||||
// %7 = x86vector.avx.bcst_to_f32.packed
|
// %7 = x86.avx.bcst_to_f32.packed
|
||||||
// %8 = x86vector.avx.cvt.packed.odd.indexed_to_f32
|
// %8 = x86.avx.cvt.packed.odd.indexed_to_f32
|
||||||
// %9 = vector.fma %7, %8, %arg2
|
// %9 = vector.fma %7, %8, %arg2
|
||||||
// %10 = x86vector.avx.bcst_to_f32.packed
|
// %10 = x86.avx.bcst_to_f32.packed
|
||||||
// %11 = x86vector.avx.cvt.packed.even.indexed_to_f32
|
// %11 = x86.avx.cvt.packed.even.indexed_to_f32
|
||||||
// %12 = vector.fma %10, %11, %9
|
// %12 = vector.fma %10, %11, %9
|
||||||
// yield %6, %12
|
// yield %6, %12
|
||||||
// ```
|
// ```
|
||||||
// to
|
// to
|
||||||
// ```
|
// ```
|
||||||
// %1 = x86vector.avx.bcst_to_f32.packed
|
// %1 = x86.avx.bcst_to_f32.packed
|
||||||
// %2 = x86vector.avx.cvt.packed.odd.indexed_to_f32
|
// %2 = x86.avx.cvt.packed.odd.indexed_to_f32
|
||||||
// %3 = vector.fma %1, %2, %arg1
|
// %3 = vector.fma %1, %2, %arg1
|
||||||
// %7 = x86vector.avx.bcst_to_f32.packed
|
// %7 = x86.avx.bcst_to_f32.packed
|
||||||
// %8 = x86vector.avx.cvt.packed.odd.indexed_to_f32
|
// %8 = x86.avx.cvt.packed.odd.indexed_to_f32
|
||||||
// %9 = vector.fma %7, %8, %arg2
|
// %9 = vector.fma %7, %8, %arg2
|
||||||
// %4 = x86vector.avx.bcst_to_f32.packed
|
// %4 = x86.avx.bcst_to_f32.packed
|
||||||
// %5 = x86vector.avx.cvt.packed.even.indexed_to_f32
|
// %5 = x86.avx.cvt.packed.even.indexed_to_f32
|
||||||
// %6 = vector.fma %4, %5, %3
|
// %6 = vector.fma %4, %5, %3
|
||||||
// %10 = x86vector.avx.bcst_to_f32.packed
|
// %10 = x86.avx.bcst_to_f32.packed
|
||||||
// %11 = x86vector.avx.cvt.packed.even.indexed_to_f32
|
// %11 = x86.avx.cvt.packed.even.indexed_to_f32
|
||||||
// %12 = vector.fma %10, %11, %9
|
// %12 = vector.fma %10, %11, %9
|
||||||
// yield %9, %12
|
// yield %9, %12
|
||||||
// ```
|
// ```
|
||||||
@@ -150,10 +150,9 @@ struct ShuffleVectorFMAOps : public OpRewritePattern<vector::FMAOp> {
|
|||||||
if (!fma)
|
if (!fma)
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
bool hasX86CvtOperand = isa<x86vector::CvtPackedEvenIndexedToF32Op>(
|
bool hasX86CvtOperand =
|
||||||
fma.getLhs().getDefiningOp()) ||
|
isa<x86::CvtPackedEvenIndexedToF32Op>(fma.getLhs().getDefiningOp()) ||
|
||||||
isa<x86vector::CvtPackedEvenIndexedToF32Op>(
|
isa<x86::CvtPackedEvenIndexedToF32Op>(fma.getRhs().getDefiningOp());
|
||||||
fma.getRhs().getDefiningOp());
|
|
||||||
|
|
||||||
if (hasX86CvtOperand && stopAtNextDependentFMA)
|
if (hasX86CvtOperand && stopAtNextDependentFMA)
|
||||||
break;
|
break;
|
||||||
@@ -180,7 +179,6 @@ struct ShuffleVectorFMAOps : public OpRewritePattern<vector::FMAOp> {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void x86vector::populateShuffleVectorFMAOpsPatterns(
|
void x86::populateShuffleVectorFMAOpsPatterns(RewritePatternSet &patterns) {
|
||||||
RewritePatternSet &patterns) {
|
|
||||||
patterns.add<ShuffleVectorFMAOps>(patterns.getContext());
|
patterns.add<ShuffleVectorFMAOps>(patterns.getContext());
|
||||||
}
|
}
|
||||||
@@ -8,8 +8,8 @@
|
|||||||
|
|
||||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||||
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
|
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
|
||||||
#include "mlir/Dialect/X86Vector/Transforms.h"
|
#include "mlir/Dialect/X86/Transforms.h"
|
||||||
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
|
#include "mlir/Dialect/X86/X86Dialect.h"
|
||||||
|
|
||||||
#include "mlir/IR/BuiltinAttributes.h"
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
#include "mlir/IR/Dominance.h"
|
#include "mlir/IR/Dominance.h"
|
||||||
@@ -20,7 +20,7 @@
|
|||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::vector;
|
using namespace mlir::vector;
|
||||||
using namespace mlir::x86vector;
|
using namespace mlir::x86;
|
||||||
|
|
||||||
static FailureOr<llvm::SmallVector<Operation *>>
|
static FailureOr<llvm::SmallVector<Operation *>>
|
||||||
getSameBlockUsers(Operation *op) {
|
getSameBlockUsers(Operation *op) {
|
||||||
@@ -141,8 +141,7 @@ struct SinkVectorProducerOps final : public OpRewritePattern<producerOp> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
void x86vector::populateSinkVectorProducerOpsPatterns(
|
void x86::populateSinkVectorProducerOpsPatterns(RewritePatternSet &patterns) {
|
||||||
RewritePatternSet &patterns) {
|
|
||||||
patterns.add<SinkVectorProducerOps<vector::TransferReadOp>,
|
patterns.add<SinkVectorProducerOps<vector::TransferReadOp>,
|
||||||
SinkVectorProducerOps<vector::LoadOp>>(patterns.getContext());
|
SinkVectorProducerOps<vector::LoadOp>>(patterns.getContext());
|
||||||
}
|
}
|
||||||
@@ -11,9 +11,9 @@
|
|||||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||||
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
|
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
|
||||||
#include "mlir/Dialect/X86Vector/Transforms.h"
|
#include "mlir/Dialect/X86/Transforms.h"
|
||||||
#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h"
|
#include "mlir/Dialect/X86/Utils/X86Utils.h"
|
||||||
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
|
#include "mlir/Dialect/X86/X86Dialect.h"
|
||||||
|
|
||||||
#include "mlir/IR/BuiltinAttributes.h"
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
#include "mlir/IR/Dominance.h"
|
#include "mlir/IR/Dominance.h"
|
||||||
@@ -25,7 +25,7 @@
|
|||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::vector;
|
using namespace mlir::vector;
|
||||||
using namespace mlir::x86vector;
|
using namespace mlir::x86;
|
||||||
|
|
||||||
// Verifies that the LHS and RHS operands of a vector.contract are load or
|
// Verifies that the LHS and RHS operands of a vector.contract are load or
|
||||||
// vector.transfer_read operations on a memref source buffer, and checks
|
// vector.transfer_read operations on a memref source buffer, and checks
|
||||||
@@ -61,7 +61,7 @@ static bool validateVectorContractOperands(Value prodOp, bool isVnni) {
|
|||||||
return false;
|
return false;
|
||||||
|
|
||||||
// Return false if the two innermost strides of the memref are not contiguous.
|
// Return false if the two innermost strides of the memref are not contiguous.
|
||||||
// The x86vector.avx.cvt.packed.even/odd.indexed_to_f32 operations require
|
// The x86.avx.cvt.packed.even/odd.indexed_to_f32 operations require
|
||||||
// an eight-element tuple of bf16 values to be contiguous.
|
// an eight-element tuple of bf16 values to be contiguous.
|
||||||
int dimsToCheck = isVnni ? 2 : 1;
|
int dimsToCheck = isVnni ? 2 : 1;
|
||||||
if (!cast<mlir::MemRefType>(srcType).areTrailingDimsContiguous(dimsToCheck))
|
if (!cast<mlir::MemRefType>(srcType).areTrailingDimsContiguous(dimsToCheck))
|
||||||
@@ -162,7 +162,7 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
|
|||||||
subviews.push_back(subview);
|
subviews.push_back(subview);
|
||||||
|
|
||||||
// For unit-dims, two subviews should be created for the odd and even
|
// For unit-dims, two subviews should be created for the odd and even
|
||||||
// element in the VNNI tuple (2xbf16) because x86vector.avx.bcst_to_f32.packed
|
// element in the VNNI tuple (2xbf16) because x86.avx.bcst_to_f32.packed
|
||||||
// op loads and broadcast the first BF16 element into packed F32. It
|
// op loads and broadcast the first BF16 element into packed F32. It
|
||||||
// cannot distinguish between even and odd BF16 elements within a
|
// cannot distinguish between even and odd BF16 elements within a
|
||||||
// packed pair.
|
// packed pair.
|
||||||
@@ -193,11 +193,11 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
|
|||||||
// ```
|
// ```
|
||||||
// to
|
// to
|
||||||
// ```
|
// ```
|
||||||
// %1 = x86vector.avx.bcst_to_f32.packed %m1[c1] -> vector<8xf32>
|
// %1 = x86.avx.bcst_to_f32.packed %m1[c1] -> vector<8xf32>
|
||||||
// %2 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %m2 -> vector<8xf32>
|
// %2 = x86.avx.cvt.packed.odd.indexed_to_f32 %m2 -> vector<8xf32>
|
||||||
// %3 = vector.fma %1, %2, %arg1
|
// %3 = vector.fma %1, %2, %arg1
|
||||||
// %4 = x86vector.avx.bcst_to_f32.packed %m1[c0] -> vector<8xf32>
|
// %4 = x86.avx.bcst_to_f32.packed %m1[c0] -> vector<8xf32>
|
||||||
// %5 = x86vector.avx.cvt.packed.even.indexed_to_f32 %m2 -> vector<8xf32>
|
// %5 = x86.avx.cvt.packed.even.indexed_to_f32 %m2 -> vector<8xf32>
|
||||||
// return vector.fma %4, %5, %3
|
// return vector.fma %4, %5, %3
|
||||||
// ```
|
// ```
|
||||||
//
|
//
|
||||||
@@ -212,10 +212,10 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
|
|||||||
// ```
|
// ```
|
||||||
// to
|
// to
|
||||||
// ```
|
// ```
|
||||||
// %1 = x86vector.avx.bcst_to_f32.packed %m1[c0] -> vector<8xf32>
|
// %1 = x86.avx.bcst_to_f32.packed %m1[c0] -> vector<8xf32>
|
||||||
// %2 = x86vector.avx.cvt.packed.even.indexed_to_f32 %m2 -> vector<8xf32>
|
// %2 = x86.avx.cvt.packed.even.indexed_to_f32 %m2 -> vector<8xf32>
|
||||||
// %3 = vector.fma %1, %2, %arg1
|
// %3 = vector.fma %1, %2, %arg1
|
||||||
// %4 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %m2 -> vector<8xf32>
|
// %4 = x86.avx.cvt.packed.odd.indexed_to_f32 %m2 -> vector<8xf32>
|
||||||
// %5 = vector.fma %1, %4, %arg2
|
// %5 = vector.fma %1, %4, %arg2
|
||||||
// scf.yield %3, %5
|
// scf.yield %3, %5
|
||||||
struct VectorContractBF16ToFMA
|
struct VectorContractBF16ToFMA
|
||||||
@@ -446,11 +446,10 @@ struct VectorContractBF16ToFMA
|
|||||||
VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
|
VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
|
||||||
contractOp.getAcc());
|
contractOp.getAcc());
|
||||||
|
|
||||||
auto loadBcstBF16ElementToF32 = x86vector::BcstToPackedF32Op::create(
|
auto loadBcstBF16ElementToF32 = x86::BcstToPackedF32Op::create(
|
||||||
rewriter, loc, dstType, unitDimSubview[0]);
|
rewriter, loc, dstType, unitDimSubview[0]);
|
||||||
auto loadEvenIdxElementF32 =
|
auto loadEvenIdxElementF32 = x86::CvtPackedEvenIndexedToF32Op::create(
|
||||||
x86vector::CvtPackedEvenIndexedToF32Op::create(rewriter, loc, dstType,
|
rewriter, loc, dstType, nonUnitDimSubview[0]);
|
||||||
nonUnitDimSubview[0]);
|
|
||||||
auto evenIdxFMA =
|
auto evenIdxFMA =
|
||||||
vector::FMAOp::create(rewriter, loc, loadBcstBF16ElementToF32,
|
vector::FMAOp::create(rewriter, loc, loadBcstBF16ElementToF32,
|
||||||
loadEvenIdxElementF32, castAcc);
|
loadEvenIdxElementF32, castAcc);
|
||||||
@@ -468,7 +467,7 @@ struct VectorContractBF16ToFMA
|
|||||||
accTyPairCont.getElementType()),
|
accTyPairCont.getElementType()),
|
||||||
pairContractOp.getAcc());
|
pairContractOp.getAcc());
|
||||||
|
|
||||||
auto loadOddIdxElementF32 = x86vector::CvtPackedOddIndexedToF32Op::create(
|
auto loadOddIdxElementF32 = x86::CvtPackedOddIndexedToF32Op::create(
|
||||||
rewriter, pairContOpLoc, dstType, nonUnitDimSubview[0]);
|
rewriter, pairContOpLoc, dstType, nonUnitDimSubview[0]);
|
||||||
auto oddIdxFMA = vector::FMAOp::create(
|
auto oddIdxFMA = vector::FMAOp::create(
|
||||||
rewriter, pairContOpLoc, loadBcstBF16ElementToF32,
|
rewriter, pairContOpLoc, loadBcstBF16ElementToF32,
|
||||||
@@ -481,18 +480,18 @@ struct VectorContractBF16ToFMA
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Load, broadcast, and do FMA for odd indexed BF16 elements.
|
// Load, broadcast, and do FMA for odd indexed BF16 elements.
|
||||||
auto loadBcstOddIdxElementToF32 = x86vector::BcstToPackedF32Op::create(
|
auto loadBcstOddIdxElementToF32 = x86::BcstToPackedF32Op::create(
|
||||||
rewriter, loc, dstType, unitDimSubview[0]);
|
rewriter, loc, dstType, unitDimSubview[0]);
|
||||||
auto loadOddIdxElementF32 = x86vector::CvtPackedOddIndexedToF32Op::create(
|
auto loadOddIdxElementF32 = x86::CvtPackedOddIndexedToF32Op::create(
|
||||||
rewriter, loc, dstType, nonUnitDimSubview[0]);
|
rewriter, loc, dstType, nonUnitDimSubview[0]);
|
||||||
auto oddIdxFMA =
|
auto oddIdxFMA =
|
||||||
vector::FMAOp::create(rewriter, loc, loadBcstOddIdxElementToF32,
|
vector::FMAOp::create(rewriter, loc, loadBcstOddIdxElementToF32,
|
||||||
loadOddIdxElementF32, castAcc);
|
loadOddIdxElementF32, castAcc);
|
||||||
|
|
||||||
// Load, broadcast, and do FMA for even indexed BF16 elements.
|
// Load, broadcast, and do FMA for even indexed BF16 elements.
|
||||||
auto loadBcstEvenIdxElementToF32 = x86vector::BcstToPackedF32Op::create(
|
auto loadBcstEvenIdxElementToF32 = x86::BcstToPackedF32Op::create(
|
||||||
rewriter, loc, dstType, unitDimSubview[1]);
|
rewriter, loc, dstType, unitDimSubview[1]);
|
||||||
auto loadEvenIdxElementF32 = x86vector::CvtPackedEvenIndexedToF32Op::create(
|
auto loadEvenIdxElementF32 = x86::CvtPackedEvenIndexedToF32Op::create(
|
||||||
rewriter, loc, dstType, nonUnitDimSubview[0]);
|
rewriter, loc, dstType, nonUnitDimSubview[0]);
|
||||||
vector::FMAOp fma =
|
vector::FMAOp fma =
|
||||||
vector::FMAOp::create(rewriter, loc, loadBcstEvenIdxElementToF32,
|
vector::FMAOp::create(rewriter, loc, loadBcstEvenIdxElementToF32,
|
||||||
@@ -504,7 +503,6 @@ struct VectorContractBF16ToFMA
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
void x86vector::populateVectorContractBF16ToFMAPatterns(
|
void x86::populateVectorContractBF16ToFMAPatterns(RewritePatternSet &patterns) {
|
||||||
RewritePatternSet &patterns) {
|
|
||||||
patterns.add<VectorContractBF16ToFMA>(patterns.getContext());
|
patterns.add<VectorContractBF16ToFMA>(patterns.getContext());
|
||||||
}
|
}
|
||||||
@@ -8,8 +8,8 @@
|
|||||||
|
|
||||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||||
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
|
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
|
||||||
#include "mlir/Dialect/X86Vector/Transforms.h"
|
#include "mlir/Dialect/X86/Transforms.h"
|
||||||
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
|
#include "mlir/Dialect/X86/X86Dialect.h"
|
||||||
|
|
||||||
#include "mlir/IR/BuiltinAttributes.h"
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
#include "mlir/IR/Dominance.h"
|
#include "mlir/IR/Dominance.h"
|
||||||
@@ -20,7 +20,7 @@
|
|||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::vector;
|
using namespace mlir::vector;
|
||||||
using namespace mlir::x86vector;
|
using namespace mlir::x86;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
@@ -137,7 +137,6 @@ struct VectorContractToFMA : public OpRewritePattern<vector::ContractionOp> {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void x86vector::populateVectorContractToFMAPatterns(
|
void x86::populateVectorContractToFMAPatterns(RewritePatternSet &patterns) {
|
||||||
RewritePatternSet &patterns) {
|
|
||||||
patterns.add<VectorContractToFMA>(patterns.getContext());
|
patterns.add<VectorContractToFMA>(patterns.getContext());
|
||||||
}
|
}
|
||||||
@@ -11,9 +11,9 @@
|
|||||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||||
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
|
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
|
||||||
#include "mlir/Dialect/X86Vector/Transforms.h"
|
#include "mlir/Dialect/X86/Transforms.h"
|
||||||
#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h"
|
#include "mlir/Dialect/X86/Utils/X86Utils.h"
|
||||||
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
|
#include "mlir/Dialect/X86/X86Dialect.h"
|
||||||
|
|
||||||
#include "mlir/IR/BuiltinAttributes.h"
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
#include "mlir/IR/Dominance.h"
|
#include "mlir/IR/Dominance.h"
|
||||||
@@ -24,7 +24,7 @@
|
|||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::vector;
|
using namespace mlir::vector;
|
||||||
using namespace mlir::x86vector;
|
using namespace mlir::x86;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
@@ -109,7 +109,7 @@ static void packNonUnitDimOperandToVNNI(mlir::PatternRewriter &rewriter,
|
|||||||
// to
|
// to
|
||||||
// ```
|
// ```
|
||||||
// vector.broadcast %lhs to <32xbf16>
|
// vector.broadcast %lhs to <32xbf16>
|
||||||
// x86vector.avx512.dot vector<32xbf16> -> vector<16xf32>
|
// x86.avx512.dot vector<32xbf16> -> vector<16xf32>
|
||||||
// ```
|
// ```
|
||||||
//
|
//
|
||||||
// For example - for bf16 type (Flat layout):
|
// For example - for bf16 type (Flat layout):
|
||||||
@@ -126,9 +126,9 @@ static void packNonUnitDimOperandToVNNI(mlir::PatternRewriter &rewriter,
|
|||||||
// %3 = vector.shuffle %1, %2 [0, 32, 1, ... 27, 59]
|
// %3 = vector.shuffle %1, %2 [0, 32, 1, ... 27, 59]
|
||||||
// %4 = vector.shuffle %1, %2 [4, 36, 5, ... 31, 63]
|
// %4 = vector.shuffle %1, %2 [4, 36, 5, ... 31, 63]
|
||||||
// vector.broadcast %lhs to <32xbf16>
|
// vector.broadcast %lhs to <32xbf16>
|
||||||
// x86vector.avx512.dot vector<32xbf16>, %3 -> vector<16xf32>
|
// x86.avx512.dot vector<32xbf16>, %3 -> vector<16xf32>
|
||||||
// vector.broadcast %lhs to <32xbf16>
|
// vector.broadcast %lhs to <32xbf16>
|
||||||
// x86vector.avx512.dot vector<32xbf16>, %3 -> vector<16xf32>
|
// x86.avx512.dot vector<32xbf16>, %3 -> vector<16xf32>
|
||||||
// ```
|
// ```
|
||||||
struct VectorContractToPackedTypeDotProduct
|
struct VectorContractToPackedTypeDotProduct
|
||||||
: public OpRewritePattern<vector::ContractionOp> {
|
: public OpRewritePattern<vector::ContractionOp> {
|
||||||
@@ -384,7 +384,7 @@ struct VectorContractToPackedTypeDotProduct
|
|||||||
rewriter, loc, castNonUnitDim.getResult().getType(), broadcastUnitDim);
|
rewriter, loc, castNonUnitDim.getResult().getType(), broadcastUnitDim);
|
||||||
|
|
||||||
if (lhsTy.getElementType().isBF16()) {
|
if (lhsTy.getElementType().isBF16()) {
|
||||||
dp = x86vector::DotBF16Op::create(
|
dp = x86::DotBF16Op::create(
|
||||||
rewriter, loc,
|
rewriter, loc,
|
||||||
VectorType::get(nonUnitDimValue, rewriter.getF32Type()), castAcc,
|
VectorType::get(nonUnitDimValue, rewriter.getF32Type()), castAcc,
|
||||||
bitcastUnitDimPkType, castNonUnitDim);
|
bitcastUnitDimPkType, castNonUnitDim);
|
||||||
@@ -392,12 +392,12 @@ struct VectorContractToPackedTypeDotProduct
|
|||||||
|
|
||||||
if (lhsTy.getElementType().isSignlessInteger(8)) {
|
if (lhsTy.getElementType().isSignlessInteger(8)) {
|
||||||
if (nonUnitDimAcc.front() == 16) {
|
if (nonUnitDimAcc.front() == 16) {
|
||||||
dp = x86vector::AVX10DotInt8Op::create(
|
dp = x86::AVX10DotInt8Op::create(
|
||||||
rewriter, loc,
|
rewriter, loc,
|
||||||
VectorType::get(nonUnitDimValue, rewriter.getIntegerType(32)),
|
VectorType::get(nonUnitDimValue, rewriter.getIntegerType(32)),
|
||||||
castAcc, bitcastUnitDimPkType, castNonUnitDim);
|
castAcc, bitcastUnitDimPkType, castNonUnitDim);
|
||||||
} else {
|
} else {
|
||||||
dp = x86vector::DotInt8Op::create(
|
dp = x86::DotInt8Op::create(
|
||||||
rewriter, loc,
|
rewriter, loc,
|
||||||
VectorType::get(nonUnitDimValue, rewriter.getIntegerType(32)),
|
VectorType::get(nonUnitDimValue, rewriter.getIntegerType(32)),
|
||||||
castAcc, bitcastUnitDimPkType, castNonUnitDim);
|
castAcc, bitcastUnitDimPkType, castNonUnitDim);
|
||||||
@@ -415,7 +415,7 @@ struct VectorContractToPackedTypeDotProduct
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void x86vector::populateVectorContractToPackedTypeDotProductPatterns(
|
void x86::populateVectorContractToPackedTypeDotProductPatterns(
|
||||||
RewritePatternSet &patterns) {
|
RewritePatternSet &patterns) {
|
||||||
patterns.add<VectorContractToPackedTypeDotProduct>(patterns.getContext());
|
patterns.add<VectorContractToPackedTypeDotProduct>(patterns.getContext());
|
||||||
}
|
}
|
||||||
@@ -1,8 +1,8 @@
|
|||||||
add_mlir_dialect_library(MLIRX86VectorUtils
|
add_mlir_dialect_library(MLIRX86Utils
|
||||||
X86VectorUtils.cpp
|
X86Utils.cpp
|
||||||
|
|
||||||
ADDITIONAL_HEADER_DIRS
|
ADDITIONAL_HEADER_DIRS
|
||||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/X86Vector/Utils
|
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/X86/Utils
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
MLIRAffineDialect
|
MLIRAffineDialect
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
//===- X86VectorUtils.cpp - MLIR Utilities for X86VectorOps -------------===//
|
//===- X86Utils.cpp - MLIR Utilities for X86Ops -------------------------===//
|
||||||
//
|
//
|
||||||
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
|
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
@@ -6,7 +6,7 @@
|
|||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h"
|
#include "mlir/Dialect/X86/Utils/X86Utils.h"
|
||||||
|
|
||||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
|
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
|
||||||
@@ -23,7 +23,7 @@
|
|||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace x86vector {
|
namespace x86 {
|
||||||
|
|
||||||
static FailureOr<SmallVector<mlir::utils::IteratorType>>
|
static FailureOr<SmallVector<mlir::utils::IteratorType>>
|
||||||
inferIteratorsFromOutMap(AffineMap map) {
|
inferIteratorsFromOutMap(AffineMap map) {
|
||||||
@@ -410,5 +410,5 @@ bool validatePairVectorContract(vector::ContractionOp contractOp,
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace x86vector
|
} // namespace x86
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
@@ -96,7 +96,7 @@
|
|||||||
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
|
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
|
||||||
#include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h"
|
#include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h"
|
||||||
#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h"
|
#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h"
|
||||||
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
|
#include "mlir/Dialect/X86/X86Dialect.h"
|
||||||
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
|
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
|
||||||
#include "mlir/IR/Dialect.h"
|
#include "mlir/IR/Dialect.h"
|
||||||
#include "mlir/Interfaces/CastInterfaces.h"
|
#include "mlir/Interfaces/CastInterfaces.h"
|
||||||
@@ -152,7 +152,7 @@ void mlir::registerAllDialects(DialectRegistry ®istry) {
|
|||||||
ub::UBDialect,
|
ub::UBDialect,
|
||||||
vector::VectorDialect,
|
vector::VectorDialect,
|
||||||
wasmssa::WasmSSADialect,
|
wasmssa::WasmSSADialect,
|
||||||
x86vector::X86VectorDialect,
|
x86::X86Dialect,
|
||||||
xegpu::XeGPUDialect,
|
xegpu::XeGPUDialect,
|
||||||
xevm::XeVMDialect>();
|
xevm::XeVMDialect>();
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|||||||
@@ -56,7 +56,7 @@
|
|||||||
#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
|
#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
|
||||||
#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
|
#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
|
||||||
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
|
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
|
||||||
#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h"
|
#include "mlir/Dialect/X86/TransformOps/X86TransformOps.h"
|
||||||
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
|
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
|
||||||
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
|
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
|
||||||
#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
|
#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
|
||||||
@@ -114,7 +114,7 @@ void mlir::registerAllExtensions(DialectRegistry ®istry) {
|
|||||||
transform::registerSMTExtension(registry);
|
transform::registerSMTExtension(registry);
|
||||||
transform::registerTuneExtension(registry);
|
transform::registerTuneExtension(registry);
|
||||||
vector::registerTransformDialectExtension(registry);
|
vector::registerTransformDialectExtension(registry);
|
||||||
x86vector::registerTransformDialectExtension(registry);
|
x86::registerTransformDialectExtension(registry);
|
||||||
xegpu::registerTransformDialectExtension(registry);
|
xegpu::registerTransformDialectExtension(registry);
|
||||||
arm_neon::registerTransformDialectExtension(registry);
|
arm_neon::registerTransformDialectExtension(registry);
|
||||||
arm_sve::registerTransformDialectExtension(registry);
|
arm_sve::registerTransformDialectExtension(registry);
|
||||||
|
|||||||
@@ -328,11 +328,11 @@ declare_mlir_dialect_extension_python_bindings(
|
|||||||
declare_mlir_dialect_extension_python_bindings(
|
declare_mlir_dialect_extension_python_bindings(
|
||||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||||
TD_FILE dialects/X86VectorTransformOps.td
|
TD_FILE dialects/X86TransformOps.td
|
||||||
SOURCES
|
SOURCES
|
||||||
dialects/transform/x86vector.py
|
dialects/transform/x86.py
|
||||||
DIALECT_NAME transform
|
DIALECT_NAME transform
|
||||||
EXTENSION_NAME x86vector_transform)
|
EXTENSION_NAME x86_transform)
|
||||||
|
|
||||||
declare_mlir_dialect_extension_python_bindings(
|
declare_mlir_dialect_extension_python_bindings(
|
||||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||||
@@ -522,9 +522,9 @@ declare_mlir_dialect_python_bindings(
|
|||||||
declare_mlir_dialect_python_bindings(
|
declare_mlir_dialect_python_bindings(
|
||||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||||
TD_FILE dialects/X86Vector.td
|
TD_FILE dialects/X86.td
|
||||||
SOURCES dialects/x86vector.py
|
SOURCES dialects/x86.py
|
||||||
DIALECT_NAME x86vector)
|
DIALECT_NAME x86)
|
||||||
|
|
||||||
declare_mlir_dialect_python_bindings(
|
declare_mlir_dialect_python_bindings(
|
||||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//===-- X86Vector.td - Entry point for x86vector bindings --*- tablegen -*-===//
|
//===-- X86.td - Entry point for x86 bindings --------------*- tablegen -*-===//
|
||||||
//
|
//
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
@@ -6,9 +6,9 @@
|
|||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#ifndef PYTHON_BINDINGS_X86VECTOR
|
#ifndef PYTHON_BINDINGS_X86
|
||||||
#define PYTHON_BINDINGS_X86VECTOR
|
#define PYTHON_BINDINGS_X86
|
||||||
|
|
||||||
include "mlir/Dialect/X86Vector/X86Vector.td"
|
include "mlir/Dialect/X86/X86.td"
|
||||||
|
|
||||||
#endif // PYTHON_BINDINGS_X86VECTOR
|
#endif // PYTHON_BINDINGS_X86
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
//===-- X86VectorTransformOps.td ---------------------------*- tablegen -*-===//
|
//===-- X86TransformOps.td ---------------------------------*- tablegen -*-===//
|
||||||
//
|
//
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
@@ -6,9 +6,9 @@
|
|||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#ifndef PYTHON_BINDINGS_X86VECTORTRANSFORMOPS
|
#ifndef PYTHON_BINDINGS_X86TRANSFORMOPS
|
||||||
#define PYTHON_BINDINGS_X86VECTORTRANSFORMOPS
|
#define PYTHON_BINDINGS_X86TRANSFORMOPS
|
||||||
|
|
||||||
include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td"
|
include "mlir/Dialect/X86/TransformOps/X86TransformOps.td"
|
||||||
|
|
||||||
#endif // PYTHON_BINDINGS_X86VECTORTRANSFORMOPS
|
#endif // PYTHON_BINDINGS_X86TRANSFORMOPS
|
||||||
@@ -2,4 +2,4 @@
|
|||||||
# See https://llvm.org/LICENSE.txt for license information.
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
from .._x86vector_transform_ops_gen import *
|
from .._x86_transform_ops_gen import *
|
||||||
@@ -2,5 +2,5 @@
|
|||||||
# See https://llvm.org/LICENSE.txt for license information.
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
from ._x86vector_ops_gen import *
|
from ._x86_ops_gen import *
|
||||||
from ._x86vector_ops_gen import _Dialect
|
from ._x86_ops_gen import _Dialect
|
||||||
@@ -33,7 +33,7 @@ if (MLIR_INCLUDE_INTEGRATION_TESTS)
|
|||||||
set(ARM_SME_ABI_ROUTINES_SHLIB "" CACHE STRING
|
set(ARM_SME_ABI_ROUTINES_SHLIB "" CACHE STRING
|
||||||
"Path to a shared library containing Arm SME ABI routines, required for Arm SME integration tests.")
|
"Path to a shared library containing Arm SME ABI routines, required for Arm SME integration tests.")
|
||||||
option(MLIR_RUN_AMX_TESTS "Run AMX tests.")
|
option(MLIR_RUN_AMX_TESTS "Run AMX tests.")
|
||||||
option(MLIR_RUN_X86VECTOR_TESTS "Run X86Vector tests.")
|
option(MLIR_RUN_X86_TESTS "Run X86 tests.")
|
||||||
option(MLIR_RUN_CUDA_TENSOR_CORE_TESTS "Run CUDA Tensor core WMMA tests.")
|
option(MLIR_RUN_CUDA_TENSOR_CORE_TESTS "Run CUDA Tensor core WMMA tests.")
|
||||||
option(MLIR_RUN_CUDA_SM80_TESTS "Run CUDA A100 tests.")
|
option(MLIR_RUN_CUDA_SM80_TESTS "Run CUDA A100 tests.")
|
||||||
option(MLIR_RUN_CUDA_SM80_LT_TESTS "Run CUDA A100 structured sparsity tests.")
|
option(MLIR_RUN_CUDA_SM80_LT_TESTS "Run CUDA A100 structured sparsity tests.")
|
||||||
@@ -78,7 +78,7 @@ llvm_canonicalize_cmake_booleans(
|
|||||||
MLIR_INCLUDE_INTEGRATION_TESTS
|
MLIR_INCLUDE_INTEGRATION_TESTS
|
||||||
MLIR_RUN_AMX_TESTS
|
MLIR_RUN_AMX_TESTS
|
||||||
MLIR_RUN_CUDA_TENSOR_CORE_TESTS
|
MLIR_RUN_CUDA_TENSOR_CORE_TESTS
|
||||||
MLIR_RUN_X86VECTOR_TESTS
|
MLIR_RUN_X86_TESTS
|
||||||
MLIR_RUN_ARM_SVE_TESTS
|
MLIR_RUN_ARM_SVE_TESTS
|
||||||
MLIR_RUN_ARM_SME_TESTS
|
MLIR_RUN_ARM_SME_TESTS
|
||||||
MLIR_RUN_CUDA_SM80_TESTS
|
MLIR_RUN_CUDA_SM80_TESTS
|
||||||
|
|||||||
@@ -21,7 +21,7 @@
|
|||||||
// CHECK-SAME: enable-amx={{[aA-zZ0-9]+}}
|
// CHECK-SAME: enable-amx={{[aA-zZ0-9]+}}
|
||||||
// CHECK-SAME: enable-arm-neon={{[aA-zZ0-9]+}}
|
// CHECK-SAME: enable-arm-neon={{[aA-zZ0-9]+}}
|
||||||
// CHECK-SAME: enable-arm-sve={{[aA-zZ0-9]+}}
|
// CHECK-SAME: enable-arm-sve={{[aA-zZ0-9]+}}
|
||||||
// CHECK-SAME: enable-x86vector={{[aA-zZ0-9]+}}
|
// CHECK-SAME: enable-x86={{[aA-zZ0-9]+}}
|
||||||
// CHECK-SAME: force-32bit-vector-indices={{[aA-zZ0-9]+}}
|
// CHECK-SAME: force-32bit-vector-indices={{[aA-zZ0-9]+}}
|
||||||
// CHECK-SAME: reassociate-fp-reductions={{[aA-zZ0-9]+}}
|
// CHECK-SAME: reassociate-fp-reductions={{[aA-zZ0-9]+}}
|
||||||
// DEFAULT: vector-contract-lowering=dot
|
// DEFAULT: vector-contract-lowering=dot
|
||||||
|
|||||||
@@ -691,7 +691,7 @@ func.func @rsqrt_scalar(%arg0: f32) -> f32 {
|
|||||||
// AVX2: %[[VAL_6:.*]] = arith.cmpf olt, %[[VAL_0]], %[[VAL_4]] : vector<8xf32>
|
// AVX2: %[[VAL_6:.*]] = arith.cmpf olt, %[[VAL_0]], %[[VAL_4]] : vector<8xf32>
|
||||||
// AVX2: %[[VAL_7:.*]] = arith.cmpf oeq, %[[VAL_0]], %[[VAL_1]] : vector<8xf32>
|
// AVX2: %[[VAL_7:.*]] = arith.cmpf oeq, %[[VAL_0]], %[[VAL_1]] : vector<8xf32>
|
||||||
// AVX2: %[[VAL_8:.*]] = arith.ori %[[VAL_6]], %[[VAL_7]] : vector<8xi1>
|
// AVX2: %[[VAL_8:.*]] = arith.ori %[[VAL_6]], %[[VAL_7]] : vector<8xi1>
|
||||||
// AVX2: %[[VAL_9:.*]] = x86vector.avx.rsqrt %[[VAL_0]] : vector<8xf32>
|
// AVX2: %[[VAL_9:.*]] = x86.avx.rsqrt %[[VAL_0]] : vector<8xf32>
|
||||||
// AVX2: %[[VAL_10:.*]] = arith.mulf %[[VAL_5]], %[[VAL_9]] : vector<8xf32>
|
// AVX2: %[[VAL_10:.*]] = arith.mulf %[[VAL_5]], %[[VAL_9]] : vector<8xf32>
|
||||||
// AVX2: %[[VAL_11:.*]] = math.fma %[[VAL_9]], %[[VAL_10]], %[[VAL_2]] : vector<8xf32>
|
// AVX2: %[[VAL_11:.*]] = math.fma %[[VAL_9]], %[[VAL_10]], %[[VAL_2]] : vector<8xf32>
|
||||||
// AVX2: %[[VAL_12:.*]] = arith.mulf %[[VAL_9]], %[[VAL_11]] : vector<8xf32>
|
// AVX2: %[[VAL_12:.*]] = arith.mulf %[[VAL_9]], %[[VAL_11]] : vector<8xf32>
|
||||||
@@ -725,9 +725,9 @@ func.func @rsqrt_vector_5xf32(%arg0: vector<5xf32>) -> vector<5xf32> {
|
|||||||
// AVX2: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x8xf32>
|
// AVX2: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x8xf32>
|
||||||
// AVX2: %[[EXPAND:.*]] = vector.shape_cast %[[ARG]] : vector<16xf32> to vector<2x8xf32>
|
// AVX2: %[[EXPAND:.*]] = vector.shape_cast %[[ARG]] : vector<16xf32> to vector<2x8xf32>
|
||||||
// AVX2: %[[VEC0:.*]] = vector.extract %[[EXPAND]][0]
|
// AVX2: %[[VEC0:.*]] = vector.extract %[[EXPAND]][0]
|
||||||
// AVX2: %[[RSQRT0:.*]] = x86vector.avx.rsqrt %[[VEC0]]
|
// AVX2: %[[RSQRT0:.*]] = x86.avx.rsqrt %[[VEC0]]
|
||||||
// AVX2: %[[VEC1:.*]] = vector.extract %[[EXPAND]][1]
|
// AVX2: %[[VEC1:.*]] = vector.extract %[[EXPAND]][1]
|
||||||
// AVX2: %[[RSQRT1:.*]] = x86vector.avx.rsqrt %[[VEC1]]
|
// AVX2: %[[RSQRT1:.*]] = x86.avx.rsqrt %[[VEC1]]
|
||||||
// AVX2: %[[RESULT0:.*]] = vector.insert %[[RSQRT0]], %[[INIT]] [0]
|
// AVX2: %[[RESULT0:.*]] = vector.insert %[[RSQRT0]], %[[INIT]] [0]
|
||||||
// AVX2: %[[RESULT1:.*]] = vector.insert %[[RSQRT1]], %[[RESULT0]] [1]
|
// AVX2: %[[RESULT1:.*]] = vector.insert %[[RSQRT1]], %[[RESULT0]] [1]
|
||||||
// AVX2: %[[RSQRT:.*]] = vector.shape_cast %[[RESULT1]] : vector<2x8xf32> to vector<16xf32>
|
// AVX2: %[[RSQRT:.*]] = vector.shape_cast %[[RESULT1]] : vector<2x8xf32> to vector<16xf32>
|
||||||
@@ -746,9 +746,9 @@ func.func @rsqrt_vector_16xf32(%arg0: vector<16xf32>) -> vector<16xf32> {
|
|||||||
// AVX2: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x8xf32>
|
// AVX2: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x8xf32>
|
||||||
// AVX2-NOT: vector.shape_cast
|
// AVX2-NOT: vector.shape_cast
|
||||||
// AVX2: %[[VEC0:.*]] = vector.extract %[[ARG]][0]
|
// AVX2: %[[VEC0:.*]] = vector.extract %[[ARG]][0]
|
||||||
// AVX2: %[[RSQRT0:.*]] = x86vector.avx.rsqrt %[[VEC0]]
|
// AVX2: %[[RSQRT0:.*]] = x86.avx.rsqrt %[[VEC0]]
|
||||||
// AVX2: %[[VEC1:.*]] = vector.extract %[[ARG]][1]
|
// AVX2: %[[VEC1:.*]] = vector.extract %[[ARG]][1]
|
||||||
// AVX2: %[[RSQRT1:.*]] = x86vector.avx.rsqrt %[[VEC1]]
|
// AVX2: %[[RSQRT1:.*]] = x86.avx.rsqrt %[[VEC1]]
|
||||||
// AVX2: %[[RESULT0:.*]] = vector.insert %[[RSQRT0]], %[[INIT]] [0]
|
// AVX2: %[[RESULT0:.*]] = vector.insert %[[RSQRT0]], %[[INIT]] [0]
|
||||||
// AVX2: %[[RESULT1:.*]] = vector.insert %[[RSQRT1]], %[[RESULT0]] [1]
|
// AVX2: %[[RESULT1:.*]] = vector.insert %[[RSQRT1]], %[[RESULT0]] [1]
|
||||||
// AVX2-NOT: vector.shape_cast
|
// AVX2-NOT: vector.shape_cast
|
||||||
@@ -768,13 +768,13 @@ func.func @rsqrt_vector_2x8xf32(%arg0: vector<2x8xf32>) -> vector<2x8xf32> {
|
|||||||
// AVX2: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x8xf32>
|
// AVX2: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x8xf32>
|
||||||
// AVX2: %[[EXPAND:.*]] = vector.shape_cast %[[ARG]] : vector<2x16xf32> to vector<2x2x8xf32>
|
// AVX2: %[[EXPAND:.*]] = vector.shape_cast %[[ARG]] : vector<2x16xf32> to vector<2x2x8xf32>
|
||||||
// AVX2: %[[VEC00:.*]] = vector.extract %[[EXPAND]][0, 0]
|
// AVX2: %[[VEC00:.*]] = vector.extract %[[EXPAND]][0, 0]
|
||||||
// AVX2: %[[RSQRT00:.*]] = x86vector.avx.rsqrt %[[VEC00]]
|
// AVX2: %[[RSQRT00:.*]] = x86.avx.rsqrt %[[VEC00]]
|
||||||
// AVX2: %[[VEC01:.*]] = vector.extract %[[EXPAND]][0, 1]
|
// AVX2: %[[VEC01:.*]] = vector.extract %[[EXPAND]][0, 1]
|
||||||
// AVX2: %[[RSQRT01:.*]] = x86vector.avx.rsqrt %[[VEC01]]
|
// AVX2: %[[RSQRT01:.*]] = x86.avx.rsqrt %[[VEC01]]
|
||||||
// AVX2: %[[VEC10:.*]] = vector.extract %[[EXPAND]][1, 0]
|
// AVX2: %[[VEC10:.*]] = vector.extract %[[EXPAND]][1, 0]
|
||||||
// AVX2: %[[RSQRT10:.*]] = x86vector.avx.rsqrt %[[VEC10]]
|
// AVX2: %[[RSQRT10:.*]] = x86.avx.rsqrt %[[VEC10]]
|
||||||
// AVX2: %[[VEC11:.*]] = vector.extract %[[EXPAND]][1, 1]
|
// AVX2: %[[VEC11:.*]] = vector.extract %[[EXPAND]][1, 1]
|
||||||
// AVX2: %[[RSQRT11:.*]] = x86vector.avx.rsqrt %[[VEC11]]
|
// AVX2: %[[RSQRT11:.*]] = x86.avx.rsqrt %[[VEC11]]
|
||||||
// AVX2: %[[RESULT0:.*]] = vector.insert %[[RSQRT00]], %[[INIT]] [0, 0]
|
// AVX2: %[[RESULT0:.*]] = vector.insert %[[RSQRT00]], %[[INIT]] [0, 0]
|
||||||
// AVX2: %[[RESULT1:.*]] = vector.insert %[[RSQRT01]], %[[RESULT0]] [0, 1]
|
// AVX2: %[[RESULT1:.*]] = vector.insert %[[RSQRT01]], %[[RESULT0]] [0, 1]
|
||||||
// AVX2: %[[RESULT2:.*]] = vector.insert %[[RSQRT10]], %[[RESULT1]] [1, 0]
|
// AVX2: %[[RESULT2:.*]] = vector.insert %[[RSQRT10]], %[[RESULT1]] [1, 0]
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
|
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
|
||||||
|
|
||||||
// NOTE: This file tests lowerings that are implemented in the X86Vector
|
// NOTE: This file tests lowerings that are implemented in the X86
|
||||||
// dialect. Since X86 does not support scalable vectors, all examples in this
|
// dialect. Since X86 does not support scalable vectors, all examples in this
|
||||||
// file use fixed-width vectors.
|
// file use fixed-width vectors.
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
// REQUIRES: target=x86{{.*}}
|
// REQUIRES: target=x86{{.*}}
|
||||||
|
|
||||||
// RUN: mlir-opt %s \
|
// RUN: mlir-opt %s \
|
||||||
// RUN: -convert-vector-to-llvm="enable-x86vector" -convert-to-llvm \
|
// RUN: -convert-vector-to-llvm="enable-x86" -convert-to-llvm \
|
||||||
// RUN: -reconcile-unrealized-casts | \
|
// RUN: -reconcile-unrealized-casts | \
|
||||||
// RUN: mlir-translate --mlir-to-llvmir | \
|
// RUN: mlir-translate --mlir-to-llvmir | \
|
||||||
// RUN: llc -mcpu=sapphirerapids | \
|
// RUN: llc -mcpu=sapphirerapids | \
|
||||||
@@ -9,7 +9,7 @@
|
|||||||
|
|
||||||
func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
|
func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
|
||||||
%a: vector<8xf32>) -> vector<8xbf16> {
|
%a: vector<8xf32>) -> vector<8xbf16> {
|
||||||
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
|
%0 = x86.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
|
||||||
return %0 : vector<8xbf16>
|
return %0 : vector<8xbf16>
|
||||||
}
|
}
|
||||||
// CHECK-LABEL: avx512bf16_cvt_packed_f32_to_bf16_256:
|
// CHECK-LABEL: avx512bf16_cvt_packed_f32_to_bf16_256:
|
||||||
@@ -17,7 +17,7 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
|
|||||||
|
|
||||||
func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
|
func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
|
||||||
%a: vector<16xf32>) -> vector<16xbf16> {
|
%a: vector<16xf32>) -> vector<16xbf16> {
|
||||||
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
|
%0 = x86.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
|
||||||
return %0 : vector<16xbf16>
|
return %0 : vector<16xbf16>
|
||||||
}
|
}
|
||||||
// CHECK-LABEL: avx512bf16_cvt_packed_f32_to_bf16_512:
|
// CHECK-LABEL: avx512bf16_cvt_packed_f32_to_bf16_512:
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
// REQUIRES: target=x86{{.*}}
|
// REQUIRES: target=x86{{.*}}
|
||||||
|
|
||||||
// RUN: mlir-opt %s \
|
// RUN: mlir-opt %s \
|
||||||
// RUN: -convert-vector-to-llvm="enable-x86vector" -convert-to-llvm \
|
// RUN: -convert-vector-to-llvm="enable-x86" -convert-to-llvm \
|
||||||
// RUN: -reconcile-unrealized-casts | \
|
// RUN: -reconcile-unrealized-casts | \
|
||||||
// RUN: mlir-translate --mlir-to-llvmir | \
|
// RUN: mlir-translate --mlir-to-llvmir | \
|
||||||
// RUN: llc -mcpu=sapphirerapids | \
|
// RUN: llc -mcpu=sapphirerapids | \
|
||||||
@@ -9,7 +9,7 @@
|
|||||||
|
|
||||||
func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>,
|
func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>,
|
||||||
%b: vector<8xbf16>) -> vector<4xf32> {
|
%b: vector<8xbf16>) -> vector<4xf32> {
|
||||||
%0 = x86vector.avx512.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
|
%0 = x86.avx512.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
|
||||||
return %0 : vector<4xf32>
|
return %0 : vector<4xf32>
|
||||||
}
|
}
|
||||||
// CHECK-LABEL: avx512bf16_dot_128:
|
// CHECK-LABEL: avx512bf16_dot_128:
|
||||||
@@ -17,7 +17,7 @@ func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>,
|
|||||||
|
|
||||||
func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>,
|
func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>,
|
||||||
%b: vector<16xbf16>) -> vector<8xf32> {
|
%b: vector<16xbf16>) -> vector<8xf32> {
|
||||||
%0 = x86vector.avx512.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
|
%0 = x86.avx512.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
|
||||||
return %0 : vector<8xf32>
|
return %0 : vector<8xf32>
|
||||||
}
|
}
|
||||||
// CHECK-LABEL: avx512bf16_dot_256:
|
// CHECK-LABEL: avx512bf16_dot_256:
|
||||||
@@ -25,7 +25,7 @@ func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>,
|
|||||||
|
|
||||||
func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
|
func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
|
||||||
%b: vector<32xbf16>) -> vector<16xf32> {
|
%b: vector<32xbf16>) -> vector<16xf32> {
|
||||||
%0 = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
|
%0 = x86.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
|
||||||
return %0 : vector<16xf32>
|
return %0 : vector<16xf32>
|
||||||
}
|
}
|
||||||
// CHECK-LABEL: avx512bf16_dot_512:
|
// CHECK-LABEL: avx512bf16_dot_512:
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86vector" | mlir-opt | FileCheck %s
|
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86" | mlir-opt | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: func @avx512_mask_rndscale
|
// CHECK-LABEL: func @avx512_mask_rndscale
|
||||||
func.func @avx512_mask_rndscale(
|
func.func @avx512_mask_rndscale(
|
||||||
@@ -9,14 +9,14 @@ func.func @avx512_mask_rndscale(
|
|||||||
%rnd_k = arith.constant 15 : i32
|
%rnd_k = arith.constant 15 : i32
|
||||||
%rnd = arith.constant 42 : i32
|
%rnd = arith.constant 42 : i32
|
||||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.rndscale.ps.512"
|
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.rndscale.ps.512"
|
||||||
%0 = x86vector.avx512.mask.rndscale %src, %rnd_k, %a, %imm_i16, %rnd : vector<16xf32>
|
%0 = x86.avx512.mask.rndscale %src, %rnd_k, %a, %imm_i16, %rnd : vector<16xf32>
|
||||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.rndscale.pd.512"
|
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.rndscale.pd.512"
|
||||||
%1 = x86vector.avx512.mask.rndscale %b, %rnd_k, %b, %imm_i8, %rnd : vector<8xf64>
|
%1 = x86.avx512.mask.rndscale %b, %rnd_k, %b, %imm_i8, %rnd : vector<8xf64>
|
||||||
|
|
||||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.scalef.ps.512"
|
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.scalef.ps.512"
|
||||||
%2 = x86vector.avx512.mask.scalef %a, %a, %a, %scale_k_i16, %rnd : vector<16xf32>
|
%2 = x86.avx512.mask.scalef %a, %a, %a, %scale_k_i16, %rnd : vector<16xf32>
|
||||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.scalef.pd.512"
|
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.scalef.pd.512"
|
||||||
%3 = x86vector.avx512.mask.scalef %b, %b, %b, %scale_k_i8, %rnd : vector<8xf64>
|
%3 = x86.avx512.mask.scalef %b, %b, %b, %scale_k_i8, %rnd : vector<8xf64>
|
||||||
|
|
||||||
// Keep results alive.
|
// Keep results alive.
|
||||||
return %0, %1, %2, %3 : vector<16xf32>, vector<8xf64>, vector<16xf32>, vector<8xf64>
|
return %0, %1, %2, %3 : vector<16xf32>, vector<8xf64>, vector<16xf32>, vector<8xf64>
|
||||||
@@ -29,13 +29,13 @@ func.func @avx512_mask_compress(
|
|||||||
{
|
{
|
||||||
// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<16xf32>)
|
// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<16xf32>)
|
||||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.compress"
|
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.compress"
|
||||||
%0 = x86vector.avx512.mask.compress %k1, %a1 : vector<16xf32>
|
%0 = x86.avx512.mask.compress %k1, %a1 : vector<16xf32>
|
||||||
// CHECK: llvm.mlir.constant(dense<5.000000e+00> : vector<16xf32>)
|
// CHECK: llvm.mlir.constant(dense<5.000000e+00> : vector<16xf32>)
|
||||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.compress"
|
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.compress"
|
||||||
%1 = x86vector.avx512.mask.compress %k1, %a1
|
%1 = x86.avx512.mask.compress %k1, %a1
|
||||||
{constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32>
|
{constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32>
|
||||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.compress"
|
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.compress"
|
||||||
%2 = x86vector.avx512.mask.compress %k2, %a2, %a2 : vector<8xi64>, vector<8xi64>
|
%2 = x86.avx512.mask.compress %k2, %a2, %a2 : vector<8xi64>, vector<8xi64>
|
||||||
return %0, %1, %2 : vector<16xf32>, vector<16xf32>, vector<8xi64>
|
return %0, %1, %2 : vector<16xf32>, vector<16xf32>, vector<8xi64>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,9 +44,9 @@ func.func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>)
|
|||||||
-> (vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>)
|
-> (vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>)
|
||||||
{
|
{
|
||||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.vp2intersect.d.512"
|
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.vp2intersect.d.512"
|
||||||
%0, %1 = x86vector.avx512.vp2intersect %a, %a : vector<16xi32>
|
%0, %1 = x86.avx512.vp2intersect %a, %a : vector<16xi32>
|
||||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.vp2intersect.q.512"
|
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.vp2intersect.q.512"
|
||||||
%2, %3 = x86vector.avx512.vp2intersect %b, %b : vector<8xi64>
|
%2, %3 = x86.avx512.vp2intersect %b, %b : vector<8xi64>
|
||||||
return %0, %1, %2, %3 : vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>
|
return %0, %1, %2, %3 : vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -55,7 +55,7 @@ func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>,
|
|||||||
%b: vector<8xbf16>) -> (vector<4xf32>)
|
%b: vector<8xbf16>) -> (vector<4xf32>)
|
||||||
{
|
{
|
||||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx512bf16.dpbf16ps.128"
|
// CHECK: llvm.call_intrinsic "llvm.x86.avx512bf16.dpbf16ps.128"
|
||||||
%0 = x86vector.avx512.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
|
%0 = x86.avx512.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
|
||||||
return %0 : vector<4xf32>
|
return %0 : vector<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -64,7 +64,7 @@ func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>,
|
|||||||
%b: vector<16xbf16>) -> (vector<8xf32>)
|
%b: vector<16xbf16>) -> (vector<8xf32>)
|
||||||
{
|
{
|
||||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx512bf16.dpbf16ps.256"
|
// CHECK: llvm.call_intrinsic "llvm.x86.avx512bf16.dpbf16ps.256"
|
||||||
%0 = x86vector.avx512.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
|
%0 = x86.avx512.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
|
||||||
return %0 : vector<8xf32>
|
return %0 : vector<8xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -73,7 +73,7 @@ func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
|
|||||||
%b: vector<32xbf16>) -> (vector<16xf32>)
|
%b: vector<32xbf16>) -> (vector<16xf32>)
|
||||||
{
|
{
|
||||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx512bf16.dpbf16ps.512"
|
// CHECK: llvm.call_intrinsic "llvm.x86.avx512bf16.dpbf16ps.512"
|
||||||
%0 = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
|
%0 = x86.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
|
||||||
return %0 : vector<16xf32>
|
return %0 : vector<16xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -82,7 +82,7 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
|
|||||||
%a: vector<8xf32>) -> (vector<8xbf16>)
|
%a: vector<8xf32>) -> (vector<8xbf16>)
|
||||||
{
|
{
|
||||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx512bf16.cvtneps2bf16.256"
|
// CHECK: llvm.call_intrinsic "llvm.x86.avx512bf16.cvtneps2bf16.256"
|
||||||
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
|
%0 = x86.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
|
||||||
return %0 : vector<8xbf16>
|
return %0 : vector<8xbf16>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -91,7 +91,7 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
|
|||||||
%a: vector<16xf32>) -> (vector<16xbf16>)
|
%a: vector<16xf32>) -> (vector<16xbf16>)
|
||||||
{
|
{
|
||||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx512bf16.cvtneps2bf16.512"
|
// CHECK: llvm.call_intrinsic "llvm.x86.avx512bf16.cvtneps2bf16.512"
|
||||||
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
|
%0 = x86.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
|
||||||
return %0 : vector<16xbf16>
|
return %0 : vector<16xbf16>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -99,7 +99,7 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
|
|||||||
func.func @avx10_dot_i8_512(%w: vector<16xi32>, %a: vector<64xi8>,
|
func.func @avx10_dot_i8_512(%w: vector<16xi32>, %a: vector<64xi8>,
|
||||||
%b: vector<64xi8>) -> vector<16xi32> {
|
%b: vector<64xi8>) -> vector<16xi32> {
|
||||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx10.vpdpbssd.512"
|
// CHECK: llvm.call_intrinsic "llvm.x86.avx10.vpdpbssd.512"
|
||||||
%0 = x86vector.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
|
%0 = x86.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
|
||||||
return %0 : vector<16xi32>
|
return %0 : vector<16xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -108,7 +108,7 @@ func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
|
|||||||
%a: memref<8xbf16>) -> vector<4xf32>
|
%a: memref<8xbf16>) -> vector<4xf32>
|
||||||
{
|
{
|
||||||
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps128"
|
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps128"
|
||||||
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
|
%0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
|
||||||
return %0 : vector<4xf32>
|
return %0 : vector<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -117,7 +117,7 @@ func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(
|
|||||||
%a: memref<16xbf16>) -> vector<8xf32>
|
%a: memref<16xbf16>) -> vector<8xf32>
|
||||||
{
|
{
|
||||||
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps256"
|
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps256"
|
||||||
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
|
%0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
|
||||||
return %0 : vector<8xf32>
|
return %0 : vector<8xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -126,7 +126,7 @@ func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(
|
|||||||
%a: memref<8xbf16>) -> vector<4xf32>
|
%a: memref<8xbf16>) -> vector<4xf32>
|
||||||
{
|
{
|
||||||
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps128"
|
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps128"
|
||||||
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
|
%0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
|
||||||
return %0 : vector<4xf32>
|
return %0 : vector<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -135,7 +135,7 @@ func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(
|
|||||||
%a: memref<16xbf16>) -> vector<8xf32>
|
%a: memref<16xbf16>) -> vector<8xf32>
|
||||||
{
|
{
|
||||||
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps256"
|
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps256"
|
||||||
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
|
%0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
|
||||||
return %0 : vector<8xf32>
|
return %0 : vector<8xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -144,7 +144,7 @@ func.func @avxbf16_bsct_bf16_to_f32_packed_128(
|
|||||||
%a: memref<1xbf16>) -> vector<4xf32>
|
%a: memref<1xbf16>) -> vector<4xf32>
|
||||||
{
|
{
|
||||||
// CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps128"
|
// CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps128"
|
||||||
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
|
%0 = x86.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
|
||||||
return %0 : vector<4xf32>
|
return %0 : vector<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -153,7 +153,7 @@ func.func @avxbf16_bsct_bf16_to_f32_packed_256(
|
|||||||
%a: memref<1xbf16>) -> vector<8xf32>
|
%a: memref<1xbf16>) -> vector<8xf32>
|
||||||
{
|
{
|
||||||
// CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps256"
|
// CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps256"
|
||||||
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
|
%0 = x86.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
|
||||||
return %0 : vector<8xf32>
|
return %0 : vector<8xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -162,7 +162,7 @@ func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_128(
|
|||||||
%a: memref<8xf16>) -> vector<4xf32>
|
%a: memref<8xf16>) -> vector<4xf32>
|
||||||
{
|
{
|
||||||
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneeph2ps128"
|
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneeph2ps128"
|
||||||
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
|
%0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
|
||||||
return %0 : vector<4xf32>
|
return %0 : vector<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -171,7 +171,7 @@ func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_256(
|
|||||||
%a: memref<16xf16>) -> vector<8xf32>
|
%a: memref<16xf16>) -> vector<8xf32>
|
||||||
{
|
{
|
||||||
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneeph2ps256"
|
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneeph2ps256"
|
||||||
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
|
%0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
|
||||||
return %0 : vector<8xf32>
|
return %0 : vector<8xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -180,7 +180,7 @@ func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128(
|
|||||||
%a: memref<8xf16>) -> vector<4xf32>
|
%a: memref<8xf16>) -> vector<4xf32>
|
||||||
{
|
{
|
||||||
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneoph2ps128"
|
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneoph2ps128"
|
||||||
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
|
%0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
|
||||||
return %0 : vector<4xf32>
|
return %0 : vector<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -189,7 +189,7 @@ func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256(
|
|||||||
%a: memref<16xf16>) -> vector<8xf32>
|
%a: memref<16xf16>) -> vector<8xf32>
|
||||||
{
|
{
|
||||||
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneoph2ps256"
|
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneoph2ps256"
|
||||||
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
|
%0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
|
||||||
return %0 : vector<8xf32>
|
return %0 : vector<8xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -198,7 +198,7 @@ func.func @avxf16_bsct_f16_to_f32_packed_128(
|
|||||||
%a: memref<1xf16>) -> vector<4xf32>
|
%a: memref<1xf16>) -> vector<4xf32>
|
||||||
{
|
{
|
||||||
// CHECK: llvm.call_intrinsic "llvm.x86.vbcstnesh2ps128"
|
// CHECK: llvm.call_intrinsic "llvm.x86.vbcstnesh2ps128"
|
||||||
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
|
%0 = x86.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
|
||||||
return %0 : vector<4xf32>
|
return %0 : vector<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -207,7 +207,7 @@ func.func @avxf16_bsct_f16_to_f32_packed_256(
|
|||||||
%a: memref<1xf16>) -> vector<8xf32>
|
%a: memref<1xf16>) -> vector<8xf32>
|
||||||
{
|
{
|
||||||
// CHECK: llvm.call_intrinsic "llvm.x86.vbcstnesh2ps256"
|
// CHECK: llvm.call_intrinsic "llvm.x86.vbcstnesh2ps256"
|
||||||
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
|
%0 = x86.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
|
||||||
return %0 : vector<8xf32>
|
return %0 : vector<8xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -215,7 +215,7 @@ func.func @avxf16_bsct_f16_to_f32_packed_256(
|
|||||||
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
|
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
|
||||||
{
|
{
|
||||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx.rsqrt.ps.256"
|
// CHECK: llvm.call_intrinsic "llvm.x86.avx.rsqrt.ps.256"
|
||||||
%0 = x86vector.avx.rsqrt %a : vector<8xf32>
|
%0 = x86.avx.rsqrt %a : vector<8xf32>
|
||||||
return %0 : vector<8xf32>
|
return %0 : vector<8xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -224,7 +224,7 @@ func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>)
|
|||||||
{
|
{
|
||||||
// CHECK: llvm.mlir.constant(-1 : i8)
|
// CHECK: llvm.mlir.constant(-1 : i8)
|
||||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx.dp.ps.256"
|
// CHECK: llvm.call_intrinsic "llvm.x86.avx.dp.ps.256"
|
||||||
%0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
|
%0 = x86.avx.intr.dot %a, %b : vector<8xf32>
|
||||||
return %0 : vector<8xf32>
|
return %0 : vector<8xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -232,7 +232,7 @@ func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>)
|
|||||||
func.func @avx_dot_i8_128(%w: vector<4xi32>, %a: vector<16xi8>,
|
func.func @avx_dot_i8_128(%w: vector<4xi32>, %a: vector<16xi8>,
|
||||||
%b: vector<16xi8>) -> vector<4xi32> {
|
%b: vector<16xi8>) -> vector<4xi32> {
|
||||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx2.vpdpbssd.128"
|
// CHECK: llvm.call_intrinsic "llvm.x86.avx2.vpdpbssd.128"
|
||||||
%0 = x86vector.avx.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
|
%0 = x86.avx.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
|
||||||
return %0 : vector<4xi32>
|
return %0 : vector<4xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -240,6 +240,6 @@ func.func @avx_dot_i8_128(%w: vector<4xi32>, %a: vector<16xi8>,
|
|||||||
func.func @avx_dot_i8_256(%w: vector<8xi32>, %a: vector<32xi8>,
|
func.func @avx_dot_i8_256(%w: vector<8xi32>, %a: vector<32xi8>,
|
||||||
%b: vector<32xi8>) -> vector<8xi32> {
|
%b: vector<32xi8>) -> vector<8xi32> {
|
||||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx2.vpdpbssd.256"
|
// CHECK: llvm.call_intrinsic "llvm.x86.avx2.vpdpbssd.256"
|
||||||
%0 = x86vector.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
|
%0 = x86.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
|
||||||
return %0 : vector<8xi32>
|
return %0 : vector<8xi32>
|
||||||
}
|
}
|
||||||
@@ -4,10 +4,10 @@
|
|||||||
func.func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16, %i8: i8)
|
func.func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16, %i8: i8)
|
||||||
-> (vector<16xf32>, vector<8xf64>)
|
-> (vector<16xf32>, vector<8xf64>)
|
||||||
{
|
{
|
||||||
// CHECK: x86vector.avx512.mask.rndscale {{.*}}: vector<16xf32>
|
// CHECK: x86.avx512.mask.rndscale {{.*}}: vector<16xf32>
|
||||||
%0 = x86vector.avx512.mask.rndscale %a, %i32, %a, %i16, %i32 : vector<16xf32>
|
%0 = x86.avx512.mask.rndscale %a, %i32, %a, %i16, %i32 : vector<16xf32>
|
||||||
// CHECK: x86vector.avx512.mask.rndscale {{.*}}: vector<8xf64>
|
// CHECK: x86.avx512.mask.rndscale {{.*}}: vector<8xf64>
|
||||||
%1 = x86vector.avx512.mask.rndscale %b, %i32, %b, %i8, %i32 : vector<8xf64>
|
%1 = x86.avx512.mask.rndscale %b, %i32, %b, %i8, %i32 : vector<8xf64>
|
||||||
return %0, %1: vector<16xf32>, vector<8xf64>
|
return %0, %1: vector<16xf32>, vector<8xf64>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -15,10 +15,10 @@ func.func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32
|
|||||||
func.func @avx512_scalef(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16, %i8: i8)
|
func.func @avx512_scalef(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16, %i8: i8)
|
||||||
-> (vector<16xf32>, vector<8xf64>)
|
-> (vector<16xf32>, vector<8xf64>)
|
||||||
{
|
{
|
||||||
// CHECK: x86vector.avx512.mask.scalef {{.*}}: vector<16xf32>
|
// CHECK: x86.avx512.mask.scalef {{.*}}: vector<16xf32>
|
||||||
%0 = x86vector.avx512.mask.scalef %a, %a, %a, %i16, %i32: vector<16xf32>
|
%0 = x86.avx512.mask.scalef %a, %a, %a, %i16, %i32: vector<16xf32>
|
||||||
// CHECK: x86vector.avx512.mask.scalef {{.*}}: vector<8xf64>
|
// CHECK: x86.avx512.mask.scalef {{.*}}: vector<8xf64>
|
||||||
%1 = x86vector.avx512.mask.scalef %b, %b, %b, %i8, %i32 : vector<8xf64>
|
%1 = x86.avx512.mask.scalef %b, %b, %b, %i8, %i32 : vector<8xf64>
|
||||||
return %0, %1: vector<16xf32>, vector<8xf64>
|
return %0, %1: vector<16xf32>, vector<8xf64>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -26,10 +26,10 @@ func.func @avx512_scalef(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16:
|
|||||||
func.func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>)
|
func.func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>)
|
||||||
-> (vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>)
|
-> (vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>)
|
||||||
{
|
{
|
||||||
// CHECK: x86vector.avx512.vp2intersect {{.*}} : vector<16xi32>
|
// CHECK: x86.avx512.vp2intersect {{.*}} : vector<16xi32>
|
||||||
%0, %1 = x86vector.avx512.vp2intersect %a, %a : vector<16xi32>
|
%0, %1 = x86.avx512.vp2intersect %a, %a : vector<16xi32>
|
||||||
// CHECK: x86vector.avx512.vp2intersect {{.*}} : vector<8xi64>
|
// CHECK: x86.avx512.vp2intersect {{.*}} : vector<8xi64>
|
||||||
%2, %3 = x86vector.avx512.vp2intersect %b, %b : vector<8xi64>
|
%2, %3 = x86.avx512.vp2intersect %b, %b : vector<8xi64>
|
||||||
return %0, %1, %2, %3 : vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>
|
return %0, %1, %2, %3 : vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -38,12 +38,12 @@ func.func @avx512_mask_compress(%k1: vector<16xi1>, %a1: vector<16xf32>,
|
|||||||
%k2: vector<8xi1>, %a2: vector<8xi64>)
|
%k2: vector<8xi1>, %a2: vector<8xi64>)
|
||||||
-> (vector<16xf32>, vector<16xf32>, vector<8xi64>)
|
-> (vector<16xf32>, vector<16xf32>, vector<8xi64>)
|
||||||
{
|
{
|
||||||
// CHECK: x86vector.avx512.mask.compress {{.*}} : vector<16xf32>
|
// CHECK: x86.avx512.mask.compress {{.*}} : vector<16xf32>
|
||||||
%0 = x86vector.avx512.mask.compress %k1, %a1 : vector<16xf32>
|
%0 = x86.avx512.mask.compress %k1, %a1 : vector<16xf32>
|
||||||
// CHECK: x86vector.avx512.mask.compress {{.*}} : vector<16xf32>
|
// CHECK: x86.avx512.mask.compress {{.*}} : vector<16xf32>
|
||||||
%1 = x86vector.avx512.mask.compress %k1, %a1 {constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32>
|
%1 = x86.avx512.mask.compress %k1, %a1 {constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32>
|
||||||
// CHECK: x86vector.avx512.mask.compress {{.*}} : vector<8xi64>
|
// CHECK: x86.avx512.mask.compress {{.*}} : vector<8xi64>
|
||||||
%2 = x86vector.avx512.mask.compress %k2, %a2, %a2 : vector<8xi64>, vector<8xi64>
|
%2 = x86.avx512.mask.compress %k2, %a2, %a2 : vector<8xi64>, vector<8xi64>
|
||||||
return %0, %1, %2 : vector<16xf32>, vector<16xf32>, vector<8xi64>
|
return %0, %1, %2 : vector<16xf32>, vector<16xf32>, vector<8xi64>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -51,8 +51,8 @@ func.func @avx512_mask_compress(%k1: vector<16xi1>, %a1: vector<16xf32>,
|
|||||||
func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>,
|
func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>,
|
||||||
%b: vector<8xbf16>) -> (vector<4xf32>)
|
%b: vector<8xbf16>) -> (vector<4xf32>)
|
||||||
{
|
{
|
||||||
// CHECK: x86vector.avx512.dot {{.*}} : vector<8xbf16> -> vector<4xf32>
|
// CHECK: x86.avx512.dot {{.*}} : vector<8xbf16> -> vector<4xf32>
|
||||||
%0 = x86vector.avx512.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
|
%0 = x86.avx512.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
|
||||||
return %0 : vector<4xf32>
|
return %0 : vector<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -60,8 +60,8 @@ func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>,
|
|||||||
func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>,
|
func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>,
|
||||||
%b: vector<16xbf16>) -> (vector<8xf32>)
|
%b: vector<16xbf16>) -> (vector<8xf32>)
|
||||||
{
|
{
|
||||||
// CHECK: x86vector.avx512.dot {{.*}} : vector<16xbf16> -> vector<8xf32>
|
// CHECK: x86.avx512.dot {{.*}} : vector<16xbf16> -> vector<8xf32>
|
||||||
%0 = x86vector.avx512.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
|
%0 = x86.avx512.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
|
||||||
return %0 : vector<8xf32>
|
return %0 : vector<8xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -69,8 +69,8 @@ func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>,
|
|||||||
func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
|
func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
|
||||||
%b: vector<32xbf16>) -> (vector<16xf32>)
|
%b: vector<32xbf16>) -> (vector<16xf32>)
|
||||||
{
|
{
|
||||||
// CHECK: x86vector.avx512.dot {{.*}} : vector<32xbf16> -> vector<16xf32>
|
// CHECK: x86.avx512.dot {{.*}} : vector<32xbf16> -> vector<16xf32>
|
||||||
%0 = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
|
%0 = x86.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
|
||||||
return %0 : vector<16xf32>
|
return %0 : vector<16xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -78,9 +78,9 @@ func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
|
|||||||
func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
|
func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
|
||||||
%a: vector<8xf32>) -> (vector<8xbf16>)
|
%a: vector<8xf32>) -> (vector<8xbf16>)
|
||||||
{
|
{
|
||||||
// CHECK: x86vector.avx512.cvt.packed.f32_to_bf16 {{.*}} :
|
// CHECK: x86.avx512.cvt.packed.f32_to_bf16 {{.*}} :
|
||||||
// CHECK-SAME: vector<8xf32> -> vector<8xbf16>
|
// CHECK-SAME: vector<8xf32> -> vector<8xbf16>
|
||||||
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
|
%0 = x86.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
|
||||||
return %0 : vector<8xbf16>
|
return %0 : vector<8xbf16>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -88,17 +88,17 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
|
|||||||
func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
|
func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
|
||||||
%a: vector<16xf32>) -> (vector<16xbf16>)
|
%a: vector<16xf32>) -> (vector<16xbf16>)
|
||||||
{
|
{
|
||||||
// CHECK: x86vector.avx512.cvt.packed.f32_to_bf16 {{.*}} :
|
// CHECK: x86.avx512.cvt.packed.f32_to_bf16 {{.*}} :
|
||||||
// CHECK-SAME: vector<16xf32> -> vector<16xbf16>
|
// CHECK-SAME: vector<16xf32> -> vector<16xbf16>
|
||||||
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
|
%0 = x86.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
|
||||||
return %0 : vector<16xbf16>
|
return %0 : vector<16xbf16>
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @avx10_dot_i8_512
|
// CHECK-LABEL: func @avx10_dot_i8_512
|
||||||
func.func @avx10_dot_i8_512(%w: vector<16xi32>, %a: vector<64xi8>,
|
func.func @avx10_dot_i8_512(%w: vector<16xi32>, %a: vector<64xi8>,
|
||||||
%b: vector<64xi8>) -> vector<16xi32> {
|
%b: vector<64xi8>) -> vector<16xi32> {
|
||||||
// CHECK: x86vector.avx10.dot.i8 {{.*}} : vector<64xi8> -> vector<16xi32>
|
// CHECK: x86.avx10.dot.i8 {{.*}} : vector<64xi8> -> vector<16xi32>
|
||||||
%0 = x86vector.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
|
%0 = x86.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
|
||||||
return %0 : vector<16xi32>
|
return %0 : vector<16xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -106,9 +106,9 @@ func.func @avx10_dot_i8_512(%w: vector<16xi32>, %a: vector<64xi8>,
|
|||||||
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
|
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
|
||||||
%a: memref<8xbf16>) -> vector<4xf32>
|
%a: memref<8xbf16>) -> vector<4xf32>
|
||||||
{
|
{
|
||||||
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} :
|
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32 {{.*}} :
|
||||||
// CHECK-SAME: memref<8xbf16> -> vector<4xf32>
|
// CHECK-SAME: memref<8xbf16> -> vector<4xf32>
|
||||||
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
|
%0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
|
||||||
return %0 : vector<4xf32>
|
return %0 : vector<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -116,9 +116,9 @@ func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
|
|||||||
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(
|
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(
|
||||||
%a: memref<16xbf16>) -> vector<8xf32>
|
%a: memref<16xbf16>) -> vector<8xf32>
|
||||||
{
|
{
|
||||||
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} :
|
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32 {{.*}} :
|
||||||
// CHECK-SAME: memref<16xbf16> -> vector<8xf32>
|
// CHECK-SAME: memref<16xbf16> -> vector<8xf32>
|
||||||
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
|
%0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
|
||||||
return %0 : vector<8xf32>
|
return %0 : vector<8xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -126,9 +126,9 @@ func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(
|
|||||||
func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(
|
func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(
|
||||||
%a: memref<8xbf16>) -> vector<4xf32>
|
%a: memref<8xbf16>) -> vector<4xf32>
|
||||||
{
|
{
|
||||||
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} :
|
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32 {{.*}} :
|
||||||
// CHECK-SAME: memref<8xbf16> -> vector<4xf32>
|
// CHECK-SAME: memref<8xbf16> -> vector<4xf32>
|
||||||
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
|
%0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
|
||||||
return %0 : vector<4xf32>
|
return %0 : vector<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -136,9 +136,9 @@ func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(
|
|||||||
func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(
|
func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(
|
||||||
%a: memref<16xbf16>) -> vector<8xf32>
|
%a: memref<16xbf16>) -> vector<8xf32>
|
||||||
{
|
{
|
||||||
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} :
|
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32 {{.*}} :
|
||||||
// CHECK-SAME: memref<16xbf16> -> vector<8xf32>
|
// CHECK-SAME: memref<16xbf16> -> vector<8xf32>
|
||||||
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
|
%0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
|
||||||
return %0 : vector<8xf32>
|
return %0 : vector<8xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -146,9 +146,9 @@ func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(
|
|||||||
func.func @avxbf16_bcst_bf16_to_f32_128(
|
func.func @avxbf16_bcst_bf16_to_f32_128(
|
||||||
%a: memref<1xbf16>) -> vector<4xf32>
|
%a: memref<1xbf16>) -> vector<4xf32>
|
||||||
{
|
{
|
||||||
// CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} :
|
// CHECK: x86.avx.bcst_to_f32.packed {{.*}} :
|
||||||
// CHECK-SAME: memref<1xbf16> -> vector<4xf32>
|
// CHECK-SAME: memref<1xbf16> -> vector<4xf32>
|
||||||
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
|
%0 = x86.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
|
||||||
return %0 : vector<4xf32>
|
return %0 : vector<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -156,9 +156,9 @@ func.func @avxbf16_bcst_bf16_to_f32_128(
|
|||||||
func.func @avxbf16_bcst_bf16_to_f32_256(
|
func.func @avxbf16_bcst_bf16_to_f32_256(
|
||||||
%a: memref<1xbf16>) -> vector<8xf32>
|
%a: memref<1xbf16>) -> vector<8xf32>
|
||||||
{
|
{
|
||||||
// CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} :
|
// CHECK: x86.avx.bcst_to_f32.packed {{.*}} :
|
||||||
// CHECK-SAME: memref<1xbf16> -> vector<8xf32>
|
// CHECK-SAME: memref<1xbf16> -> vector<8xf32>
|
||||||
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
|
%0 = x86.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
|
||||||
return %0 : vector<8xf32>
|
return %0 : vector<8xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -166,9 +166,9 @@ func.func @avxbf16_bcst_bf16_to_f32_256(
|
|||||||
func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_128(
|
func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_128(
|
||||||
%a: memref<8xf16>) -> vector<4xf32>
|
%a: memref<8xf16>) -> vector<4xf32>
|
||||||
{
|
{
|
||||||
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} :
|
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32 {{.*}} :
|
||||||
// CHECK-SAME: memref<8xf16> -> vector<4xf32>
|
// CHECK-SAME: memref<8xf16> -> vector<4xf32>
|
||||||
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
|
%0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
|
||||||
return %0 : vector<4xf32>
|
return %0 : vector<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -176,9 +176,9 @@ func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_128(
|
|||||||
func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_256(
|
func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_256(
|
||||||
%a: memref<16xf16>) -> vector<8xf32>
|
%a: memref<16xf16>) -> vector<8xf32>
|
||||||
{
|
{
|
||||||
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} :
|
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32 {{.*}} :
|
||||||
// CHECK-SAME: memref<16xf16> -> vector<8xf32>
|
// CHECK-SAME: memref<16xf16> -> vector<8xf32>
|
||||||
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
|
%0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
|
||||||
return %0 : vector<8xf32>
|
return %0 : vector<8xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -186,9 +186,9 @@ func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_256(
|
|||||||
func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128(
|
func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128(
|
||||||
%a: memref<8xf16>) -> vector<4xf32>
|
%a: memref<8xf16>) -> vector<4xf32>
|
||||||
{
|
{
|
||||||
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} :
|
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32 {{.*}} :
|
||||||
// CHECK-SAME: memref<8xf16> -> vector<4xf32>
|
// CHECK-SAME: memref<8xf16> -> vector<4xf32>
|
||||||
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
|
%0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
|
||||||
return %0 : vector<4xf32>
|
return %0 : vector<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -196,9 +196,9 @@ func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128(
|
|||||||
func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256(
|
func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256(
|
||||||
%a: memref<16xf16>) -> vector<8xf32>
|
%a: memref<16xf16>) -> vector<8xf32>
|
||||||
{
|
{
|
||||||
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} :
|
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32 {{.*}} :
|
||||||
// CHECK-SAME: memref<16xf16> -> vector<8xf32>
|
// CHECK-SAME: memref<16xf16> -> vector<8xf32>
|
||||||
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
|
%0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
|
||||||
return %0 : vector<8xf32>
|
return %0 : vector<8xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -206,9 +206,9 @@ func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256(
|
|||||||
func.func @avxf16_bcst_f16_to_f32_128(
|
func.func @avxf16_bcst_f16_to_f32_128(
|
||||||
%a: memref<1xf16>) -> vector<4xf32>
|
%a: memref<1xf16>) -> vector<4xf32>
|
||||||
{
|
{
|
||||||
// CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} :
|
// CHECK: x86.avx.bcst_to_f32.packed {{.*}} :
|
||||||
// CHECK-SAME: memref<1xf16> -> vector<4xf32>
|
// CHECK-SAME: memref<1xf16> -> vector<4xf32>
|
||||||
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
|
%0 = x86.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
|
||||||
return %0 : vector<4xf32>
|
return %0 : vector<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -216,40 +216,40 @@ func.func @avxf16_bcst_f16_to_f32_128(
|
|||||||
func.func @avxf16_bcst_f16_to_f32_256(
|
func.func @avxf16_bcst_f16_to_f32_256(
|
||||||
%a: memref<1xf16>) -> vector<8xf32>
|
%a: memref<1xf16>) -> vector<8xf32>
|
||||||
{
|
{
|
||||||
// CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} :
|
// CHECK: x86.avx.bcst_to_f32.packed {{.*}} :
|
||||||
// CHECK-SAME: memref<1xf16> -> vector<8xf32>
|
// CHECK-SAME: memref<1xf16> -> vector<8xf32>
|
||||||
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
|
%0 = x86.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
|
||||||
return %0 : vector<8xf32>
|
return %0 : vector<8xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @avx_rsqrt
|
// CHECK-LABEL: func @avx_rsqrt
|
||||||
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
|
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
|
||||||
{
|
{
|
||||||
// CHECK: x86vector.avx.rsqrt {{.*}} : vector<8xf32>
|
// CHECK: x86.avx.rsqrt {{.*}} : vector<8xf32>
|
||||||
%0 = x86vector.avx.rsqrt %a : vector<8xf32>
|
%0 = x86.avx.rsqrt %a : vector<8xf32>
|
||||||
return %0 : vector<8xf32>
|
return %0 : vector<8xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @avx_dot
|
// CHECK-LABEL: func @avx_dot
|
||||||
func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>)
|
func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>)
|
||||||
{
|
{
|
||||||
// CHECK: x86vector.avx.intr.dot {{.*}} : vector<8xf32>
|
// CHECK: x86.avx.intr.dot {{.*}} : vector<8xf32>
|
||||||
%0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
|
%0 = x86.avx.intr.dot %a, %b : vector<8xf32>
|
||||||
return %0 : vector<8xf32>
|
return %0 : vector<8xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @avx_dot_i8_128
|
// CHECK-LABEL: func @avx_dot_i8_128
|
||||||
func.func @avx_dot_i8_128(%w: vector<4xi32>, %a: vector<16xi8>,
|
func.func @avx_dot_i8_128(%w: vector<4xi32>, %a: vector<16xi8>,
|
||||||
%b: vector<16xi8>) -> vector<4xi32> {
|
%b: vector<16xi8>) -> vector<4xi32> {
|
||||||
// CHECK: x86vector.avx.dot.i8 {{.*}} : vector<16xi8> -> vector<4xi32>
|
// CHECK: x86.avx.dot.i8 {{.*}} : vector<16xi8> -> vector<4xi32>
|
||||||
%0 = x86vector.avx.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
|
%0 = x86.avx.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
|
||||||
return %0 : vector<4xi32>
|
return %0 : vector<4xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @avx_dot_i8_256
|
// CHECK-LABEL: func @avx_dot_i8_256
|
||||||
func.func @avx_dot_i8_256(%w: vector<8xi32>, %a: vector<32xi8>,
|
func.func @avx_dot_i8_256(%w: vector<8xi32>, %a: vector<32xi8>,
|
||||||
%b: vector<32xi8>) -> vector<8xi32> {
|
%b: vector<32xi8>) -> vector<8xi32> {
|
||||||
// CHECK: x86vector.avx.dot.i8 {{.*}} : vector<32xi8> -> vector<8xi32>
|
// CHECK: x86.avx.dot.i8 {{.*}} : vector<32xi8> -> vector<8xi32>
|
||||||
%0 = x86vector.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
|
%0 = x86.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
|
||||||
return %0 : vector<8xi32>
|
return %0 : vector<8xi32>
|
||||||
}
|
}
|
||||||
@@ -8,17 +8,17 @@ func.func @shuffle_fma_with_rhs_as_even.index_to_f32(
|
|||||||
%arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
|
%arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
|
||||||
%arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vec
|
%arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vec
|
||||||
{
|
{
|
||||||
%0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
|
%0 = x86.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
|
||||||
%1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
|
%1 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
|
||||||
%2 = vector.fma %0, %1, %arg6 : !vec
|
%2 = vector.fma %0, %1, %arg6 : !vec
|
||||||
%3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
|
%3 = x86.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
|
||||||
%4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
|
%4 = x86.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
|
||||||
%5 = vector.fma %3, %4, %2 : !vec
|
%5 = vector.fma %3, %4, %2 : !vec
|
||||||
%6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
|
%6 = x86.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
|
||||||
%7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
|
%7 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
|
||||||
%8 = vector.fma %6, %7, %arg6 : !vec
|
%8 = vector.fma %6, %7, %arg6 : !vec
|
||||||
%9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
|
%9 = x86.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
|
||||||
%10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
|
%10 = x86.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
|
||||||
%11 = vector.fma %9, %10, %8 : !vec
|
%11 = vector.fma %9, %10, %8 : !vec
|
||||||
%12 = vector.fma %5, %11, %arg6 : !vec
|
%12 = vector.fma %5, %11, %arg6 : !vec
|
||||||
return %12 : !vec
|
return %12 : !vec
|
||||||
@@ -28,25 +28,25 @@ func.func @shuffle_fma_with_rhs_as_even.index_to_f32(
|
|||||||
// The vector.fma at %5 is moved along with its operands after %8.
|
// The vector.fma at %5 is moved along with its operands after %8.
|
||||||
// CHECK-LABEL: @shuffle_fma_with_rhs_as_even.index_to_f32
|
// CHECK-LABEL: @shuffle_fma_with_rhs_as_even.index_to_f32
|
||||||
// Odd-Indexed FMAs
|
// Odd-Indexed FMAs
|
||||||
// CHECK: %[[BCST0:.*]] = x86vector.avx.bcst_to_f32.packed %arg0
|
// CHECK: %[[BCST0:.*]] = x86.avx.bcst_to_f32.packed %arg0
|
||||||
// CHECK: %[[ODD0:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2
|
// CHECK: %[[ODD0:.*]] = x86.avx.cvt.packed.odd.indexed_to_f32 %arg2
|
||||||
// CHECK: %[[FMA_ODD0:.*]] = vector.fma %[[BCST0]], %[[ODD0]], %arg6
|
// CHECK: %[[FMA_ODD0:.*]] = vector.fma %[[BCST0]], %[[ODD0]], %arg6
|
||||||
// CHECK: %[[BCST1:.*]] = x86vector.avx.bcst_to_f32.packed %arg3
|
// CHECK: %[[BCST1:.*]] = x86.avx.bcst_to_f32.packed %arg3
|
||||||
// CHECK: %[[ODD1:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5
|
// CHECK: %[[ODD1:.*]] = x86.avx.cvt.packed.odd.indexed_to_f32 %arg5
|
||||||
// CHECK: %[[FMA_ODD1:.*]] = vector.fma %[[BCST1]], %[[ODD1]], %arg6
|
// CHECK: %[[FMA_ODD1:.*]] = vector.fma %[[BCST1]], %[[ODD1]], %arg6
|
||||||
// Even-Indexed FMAs
|
// Even-Indexed FMAs
|
||||||
// CHECK: %[[BCST2:.*]] = x86vector.avx.bcst_to_f32.packed %arg4
|
// CHECK: %[[BCST2:.*]] = x86.avx.bcst_to_f32.packed %arg4
|
||||||
// CHECK: %[[EVEN0:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5
|
// CHECK: %[[EVEN0:.*]] = x86.avx.cvt.packed.even.indexed_to_f32 %arg5
|
||||||
// CHECK: %[[FMA_EVEN0:.*]] = vector.fma %[[BCST2]], %[[EVEN0]], %[[FMA_ODD1]]
|
// CHECK: %[[FMA_EVEN0:.*]] = vector.fma %[[BCST2]], %[[EVEN0]], %[[FMA_ODD1]]
|
||||||
// CHECK: %[[BCST3:.*]] = x86vector.avx.bcst_to_f32.packed %arg1
|
// CHECK: %[[BCST3:.*]] = x86.avx.bcst_to_f32.packed %arg1
|
||||||
// CHECK: %[[EVEN1:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2
|
// CHECK: %[[EVEN1:.*]] = x86.avx.cvt.packed.even.indexed_to_f32 %arg2
|
||||||
// CHECK: %[[FMA_EVEN1:.*]] = vector.fma %[[BCST3]], %[[EVEN1]], %[[FMA_ODD0]]
|
// CHECK: %[[FMA_EVEN1:.*]] = vector.fma %[[BCST3]], %[[EVEN1]], %[[FMA_ODD0]]
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||||
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %0 {
|
transform.apply_patterns to %0 {
|
||||||
transform.apply_patterns.x86vector.shuffle_vector_fma_ops
|
transform.apply_patterns.x86.shuffle_vector_fma_ops
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -62,17 +62,17 @@ func.func @shuffle_fma_with_lhs_as_even.index_to_f32(
|
|||||||
%arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
|
%arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
|
||||||
%arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vec
|
%arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vec
|
||||||
{
|
{
|
||||||
%0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
|
%0 = x86.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
|
||||||
%1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
|
%1 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
|
||||||
%2 = vector.fma %0, %1, %arg6 : !vec
|
%2 = vector.fma %0, %1, %arg6 : !vec
|
||||||
%3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
|
%3 = x86.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
|
||||||
%4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
|
%4 = x86.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
|
||||||
%5 = vector.fma %4, %3, %2 : !vec
|
%5 = vector.fma %4, %3, %2 : !vec
|
||||||
%6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
|
%6 = x86.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
|
||||||
%7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
|
%7 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
|
||||||
%8 = vector.fma %6, %7, %arg6 : !vec
|
%8 = vector.fma %6, %7, %arg6 : !vec
|
||||||
%9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
|
%9 = x86.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
|
||||||
%10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
|
%10 = x86.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
|
||||||
%11 = vector.fma %9, %10, %8 : !vec
|
%11 = vector.fma %9, %10, %8 : !vec
|
||||||
%12 = vector.fma %5, %11, %arg6 : !vec
|
%12 = vector.fma %5, %11, %arg6 : !vec
|
||||||
return %12 : !vec
|
return %12 : !vec
|
||||||
@@ -81,25 +81,25 @@ func.func @shuffle_fma_with_lhs_as_even.index_to_f32(
|
|||||||
// The vector.fma at %5 is moved along with its operands after %8.
|
// The vector.fma at %5 is moved along with its operands after %8.
|
||||||
// CHECK-LABEL: @shuffle_fma_with_lhs_as_even.index_to_f32
|
// CHECK-LABEL: @shuffle_fma_with_lhs_as_even.index_to_f32
|
||||||
// Odd-Indexed FMAs
|
// Odd-Indexed FMAs
|
||||||
// CHECK: %[[BCST0:.*]] = x86vector.avx.bcst_to_f32.packed %arg0
|
// CHECK: %[[BCST0:.*]] = x86.avx.bcst_to_f32.packed %arg0
|
||||||
// CHECK: %[[ODD0:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2
|
// CHECK: %[[ODD0:.*]] = x86.avx.cvt.packed.odd.indexed_to_f32 %arg2
|
||||||
// CHECK: %[[FMA_ODD0:.*]] = vector.fma %[[BCST0]], %[[ODD0]], %arg6
|
// CHECK: %[[FMA_ODD0:.*]] = vector.fma %[[BCST0]], %[[ODD0]], %arg6
|
||||||
// CHECK: %[[BCST1:.*]] = x86vector.avx.bcst_to_f32.packed %arg3
|
// CHECK: %[[BCST1:.*]] = x86.avx.bcst_to_f32.packed %arg3
|
||||||
// CHECK: %[[ODD1:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5
|
// CHECK: %[[ODD1:.*]] = x86.avx.cvt.packed.odd.indexed_to_f32 %arg5
|
||||||
// CHECK: %[[FMA_ODD1:.*]] = vector.fma %[[BCST1]], %[[ODD1]], %arg6
|
// CHECK: %[[FMA_ODD1:.*]] = vector.fma %[[BCST1]], %[[ODD1]], %arg6
|
||||||
// Even-Indexed FMAs
|
// Even-Indexed FMAs
|
||||||
// CHECK: %[[BCST2:.*]] = x86vector.avx.bcst_to_f32.packed %arg4
|
// CHECK: %[[BCST2:.*]] = x86.avx.bcst_to_f32.packed %arg4
|
||||||
// CHECK: %[[EVEN0:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5
|
// CHECK: %[[EVEN0:.*]] = x86.avx.cvt.packed.even.indexed_to_f32 %arg5
|
||||||
// CHECK: %[[FMA_EVEN0:.*]] = vector.fma %[[BCST2]], %[[EVEN0]], %[[FMA_ODD1]]
|
// CHECK: %[[FMA_EVEN0:.*]] = vector.fma %[[BCST2]], %[[EVEN0]], %[[FMA_ODD1]]
|
||||||
// CHECK: %[[EVEN1:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2
|
// CHECK: %[[EVEN1:.*]] = x86.avx.cvt.packed.even.indexed_to_f32 %arg2
|
||||||
// CHECK: %[[BCST3:.*]] = x86vector.avx.bcst_to_f32.packed %arg1
|
// CHECK: %[[BCST3:.*]] = x86.avx.bcst_to_f32.packed %arg1
|
||||||
// CHECK: %[[FMA_EVEN1:.*]] = vector.fma %[[EVEN1]], %[[BCST3]], %[[FMA_ODD0]]
|
// CHECK: %[[FMA_EVEN1:.*]] = vector.fma %[[EVEN1]], %[[BCST3]], %[[FMA_ODD0]]
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||||
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %0 {
|
transform.apply_patterns to %0 {
|
||||||
transform.apply_patterns.x86vector.shuffle_vector_fma_ops
|
transform.apply_patterns.x86.shuffle_vector_fma_ops
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -116,18 +116,18 @@ func.func @shuffle_fma_with_shape_cast(
|
|||||||
%arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
|
%arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
|
||||||
%arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vecOut
|
%arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vecOut
|
||||||
{
|
{
|
||||||
%0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
|
%0 = x86.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
|
||||||
%1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
|
%1 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
|
||||||
%2 = vector.fma %0, %1, %arg6 : !vec
|
%2 = vector.fma %0, %1, %arg6 : !vec
|
||||||
%3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
|
%3 = x86.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
|
||||||
%4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
|
%4 = x86.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
|
||||||
%5 = vector.fma %3, %4, %2 : !vec
|
%5 = vector.fma %3, %4, %2 : !vec
|
||||||
%res1 = vector.shape_cast %5 : !vec to !vecOut
|
%res1 = vector.shape_cast %5 : !vec to !vecOut
|
||||||
%6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
|
%6 = x86.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
|
||||||
%7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
|
%7 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
|
||||||
%8 = vector.fma %6, %7, %arg6 : !vec
|
%8 = vector.fma %6, %7, %arg6 : !vec
|
||||||
%9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
|
%9 = x86.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
|
||||||
%10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
|
%10 = x86.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
|
||||||
%11 = vector.fma %9, %10, %8 : !vec
|
%11 = vector.fma %9, %10, %8 : !vec
|
||||||
%res2 = vector.shape_cast %11 : !vec to !vecOut
|
%res2 = vector.shape_cast %11 : !vec to !vecOut
|
||||||
%12 = arith.addf %res1, %res2 : !vecOut
|
%12 = arith.addf %res1, %res2 : !vecOut
|
||||||
@@ -136,19 +136,19 @@ func.func @shuffle_fma_with_shape_cast(
|
|||||||
|
|
||||||
// CHECK-LABEL: @shuffle_fma_with_shape_cast
|
// CHECK-LABEL: @shuffle_fma_with_shape_cast
|
||||||
// Odd-Indexed FMAs
|
// Odd-Indexed FMAs
|
||||||
// CHECK: %[[BCST0:.*]] = x86vector.avx.bcst_to_f32.packed %arg0
|
// CHECK: %[[BCST0:.*]] = x86.avx.bcst_to_f32.packed %arg0
|
||||||
// CHECK: %[[ODD0:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2
|
// CHECK: %[[ODD0:.*]] = x86.avx.cvt.packed.odd.indexed_to_f32 %arg2
|
||||||
// CHECK: %[[FMA_ODD0:.*]] = vector.fma %[[BCST0]], %[[ODD0]], %arg6
|
// CHECK: %[[FMA_ODD0:.*]] = vector.fma %[[BCST0]], %[[ODD0]], %arg6
|
||||||
// CHECK: %[[BCST1:.*]] = x86vector.avx.bcst_to_f32.packed %arg3
|
// CHECK: %[[BCST1:.*]] = x86.avx.bcst_to_f32.packed %arg3
|
||||||
// CHECK: %[[ODD1:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5
|
// CHECK: %[[ODD1:.*]] = x86.avx.cvt.packed.odd.indexed_to_f32 %arg5
|
||||||
// CHECK: %[[FMA_ODD1:.*]] = vector.fma %[[BCST1]], %[[ODD1]], %arg6
|
// CHECK: %[[FMA_ODD1:.*]] = vector.fma %[[BCST1]], %[[ODD1]], %arg6
|
||||||
// Even-Indexed FMAs
|
// Even-Indexed FMAs
|
||||||
// CHECK: %[[BCST3:.*]] = x86vector.avx.bcst_to_f32.packed %arg4
|
// CHECK: %[[BCST3:.*]] = x86.avx.bcst_to_f32.packed %arg4
|
||||||
// CHECK: %[[EVEN1:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5
|
// CHECK: %[[EVEN1:.*]] = x86.avx.cvt.packed.even.indexed_to_f32 %arg5
|
||||||
// CHECK: %[[FMA_EVEN1:.*]] = vector.fma %[[BCST3]], %[[EVEN1]], %[[FMA_ODD1]]
|
// CHECK: %[[FMA_EVEN1:.*]] = vector.fma %[[BCST3]], %[[EVEN1]], %[[FMA_ODD1]]
|
||||||
// CHECK: vector.shape_cast
|
// CHECK: vector.shape_cast
|
||||||
// CHECK: %[[BCST2:.*]] = x86vector.avx.bcst_to_f32.packed %arg1
|
// CHECK: %[[BCST2:.*]] = x86.avx.bcst_to_f32.packed %arg1
|
||||||
// CHECK: %[[EVEN0:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2
|
// CHECK: %[[EVEN0:.*]] = x86.avx.cvt.packed.even.indexed_to_f32 %arg2
|
||||||
// CHECK: %[[FMA_EVEN0:.*]] = vector.fma %[[BCST2]], %[[EVEN0]], %[[FMA_ODD0]]
|
// CHECK: %[[FMA_EVEN0:.*]] = vector.fma %[[BCST2]], %[[EVEN0]], %[[FMA_ODD0]]
|
||||||
// CHECK: vector.shape_cast
|
// CHECK: vector.shape_cast
|
||||||
|
|
||||||
@@ -156,7 +156,7 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||||
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %0 {
|
transform.apply_patterns to %0 {
|
||||||
transform.apply_patterns.x86vector.shuffle_vector_fma_ops
|
transform.apply_patterns.x86.shuffle_vector_fma_ops
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -172,16 +172,16 @@ func.func @negative_fma_operand_has_multiple_consumer(
|
|||||||
%arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB,
|
%arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB,
|
||||||
%arg3: !memrefA, %arg4: !memrefB, %arg5: !vec) -> !vec
|
%arg3: !memrefA, %arg4: !memrefB, %arg5: !vec) -> !vec
|
||||||
{
|
{
|
||||||
%0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
|
%0 = x86.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
|
||||||
%1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
|
%1 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
|
||||||
%2 = vector.fma %0, %1, %arg5 : !vec
|
%2 = vector.fma %0, %1, %arg5 : !vec
|
||||||
%3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
|
%3 = x86.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
|
||||||
%4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
|
%4 = x86.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
|
||||||
%5 = vector.fma %3, %4, %2 : !vec
|
%5 = vector.fma %3, %4, %2 : !vec
|
||||||
%7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg4 : !memrefB -> !vec
|
%7 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg4 : !memrefB -> !vec
|
||||||
%8 = vector.fma %3, %7, %arg5 : !vec
|
%8 = vector.fma %3, %7, %arg5 : !vec
|
||||||
%9 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
|
%9 = x86.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
|
||||||
%10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg4 : !memrefB -> !vec
|
%10 = x86.avx.cvt.packed.even.indexed_to_f32 %arg4 : !memrefB -> !vec
|
||||||
%11 = vector.fma %9, %10, %8 : !vec
|
%11 = vector.fma %9, %10, %8 : !vec
|
||||||
%12 = vector.fma %5, %11, %arg5 : !vec
|
%12 = vector.fma %5, %11, %arg5 : !vec
|
||||||
return %12 : !vec
|
return %12 : !vec
|
||||||
@@ -190,18 +190,18 @@ func.func @negative_fma_operand_has_multiple_consumer(
|
|||||||
// The vector.fma at %5 uses %3 as its LHS operand, which has two consumers; therefore,
|
// The vector.fma at %5 uses %3 as its LHS operand, which has two consumers; therefore,
|
||||||
// the rewrite is not applied.
|
// the rewrite is not applied.
|
||||||
// CHECK-LABEL: @negative_fma_operand_has_multiple_consumer
|
// CHECK-LABEL: @negative_fma_operand_has_multiple_consumer
|
||||||
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
|
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
|
||||||
// CHECK: vector.fma
|
// CHECK: vector.fma
|
||||||
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
|
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32
|
||||||
// CHECK: vector.fma
|
// CHECK: vector.fma
|
||||||
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
|
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
|
||||||
// CHECK: vector.fma
|
// CHECK: vector.fma
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||||
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %0 {
|
transform.apply_patterns to %0 {
|
||||||
transform.apply_patterns.x86vector.shuffle_vector_fma_ops
|
transform.apply_patterns.x86.shuffle_vector_fma_ops
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -217,17 +217,17 @@ func.func @negative_fma_has_multiple_consumer(
|
|||||||
%arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
|
%arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
|
||||||
%arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vec
|
%arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vec
|
||||||
{
|
{
|
||||||
%0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
|
%0 = x86.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
|
||||||
%1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
|
%1 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
|
||||||
%2 = vector.fma %0, %1, %arg6 : !vec
|
%2 = vector.fma %0, %1, %arg6 : !vec
|
||||||
%3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
|
%3 = x86.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
|
||||||
%4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
|
%4 = x86.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
|
||||||
%5 = vector.fma %3, %4, %2 : !vec
|
%5 = vector.fma %3, %4, %2 : !vec
|
||||||
%6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
|
%6 = x86.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
|
||||||
%7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
|
%7 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
|
||||||
%8 = vector.fma %6, %7, %5 : !vec
|
%8 = vector.fma %6, %7, %5 : !vec
|
||||||
%9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
|
%9 = x86.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
|
||||||
%10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
|
%10 = x86.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
|
||||||
%11 = vector.fma %9, %10, %8 : !vec
|
%11 = vector.fma %9, %10, %8 : !vec
|
||||||
%12 = vector.fma %5, %11, %arg6 : !vec
|
%12 = vector.fma %5, %11, %arg6 : !vec
|
||||||
return %12 : !vec
|
return %12 : !vec
|
||||||
@@ -235,17 +235,17 @@ func.func @negative_fma_has_multiple_consumer(
|
|||||||
|
|
||||||
// vector.fma at %5 has two uses; therefore no re-write applied.
|
// vector.fma at %5 has two uses; therefore no re-write applied.
|
||||||
// CHECK-LABEL: @negative_fma_has_multiple_consumer
|
// CHECK-LABEL: @negative_fma_has_multiple_consumer
|
||||||
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
|
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
|
||||||
// CHECK: vector.fma
|
// CHECK: vector.fma
|
||||||
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
|
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32
|
||||||
// CHECK: vector.fma
|
// CHECK: vector.fma
|
||||||
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
|
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||||
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %0 {
|
transform.apply_patterns to %0 {
|
||||||
transform.apply_patterns.x86vector.shuffle_vector_fma_ops
|
transform.apply_patterns.x86.shuffle_vector_fma_ops
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -260,28 +260,28 @@ func.func @negative_no_shuffle_outside_block(
|
|||||||
%arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
|
%arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
|
||||||
%arg4: !memrefA, %arg5: !memrefB, %arg6: !vec, %arg7: i1) -> !vec
|
%arg4: !memrefA, %arg5: !memrefB, %arg6: !vec, %arg7: i1) -> !vec
|
||||||
{
|
{
|
||||||
%0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
|
%0 = x86.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
|
||||||
%1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
|
%1 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
|
||||||
%2 = vector.fma %0, %1, %arg6 : !vec
|
%2 = vector.fma %0, %1, %arg6 : !vec
|
||||||
%3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
|
%3 = x86.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
|
||||||
%4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
|
%4 = x86.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
|
||||||
%5 = vector.fma %3, %4, %2 : !vec
|
%5 = vector.fma %3, %4, %2 : !vec
|
||||||
|
|
||||||
%loop = scf.if %arg7 -> (vector<8xf32>) {
|
%loop = scf.if %arg7 -> (vector<8xf32>) {
|
||||||
%6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
|
%6 = x86.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
|
||||||
%7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
|
%7 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
|
||||||
%8 = vector.fma %6, %7, %arg6 : !vec
|
%8 = vector.fma %6, %7, %arg6 : !vec
|
||||||
%9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
|
%9 = x86.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
|
||||||
%10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
|
%10 = x86.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
|
||||||
%11 = vector.fma %9, %10, %8 : !vec
|
%11 = vector.fma %9, %10, %8 : !vec
|
||||||
%12 = vector.fma %5, %11, %arg6 : !vec
|
%12 = vector.fma %5, %11, %arg6 : !vec
|
||||||
scf.yield %12 : vector<8xf32>
|
scf.yield %12 : vector<8xf32>
|
||||||
} else {
|
} else {
|
||||||
%6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
|
%6 = x86.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
|
||||||
%7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
|
%7 = x86.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
|
||||||
%8 = vector.fma %6, %7, %arg6 : !vec
|
%8 = vector.fma %6, %7, %arg6 : !vec
|
||||||
%9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
|
%9 = x86.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
|
||||||
%10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
|
%10 = x86.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
|
||||||
%11 = vector.fma %9, %10, %8 : !vec
|
%11 = vector.fma %9, %10, %8 : !vec
|
||||||
%12 = vector.fma %5, %11, %arg6 : !vec
|
%12 = vector.fma %5, %11, %arg6 : !vec
|
||||||
scf.yield %12 : vector<8xf32>
|
scf.yield %12 : vector<8xf32>
|
||||||
@@ -293,19 +293,19 @@ func.func @negative_no_shuffle_outside_block(
|
|||||||
// vector.fma at %5 has its consumer in an another block (%12); therefore rewrite is not
|
// vector.fma at %5 has its consumer in an another block (%12); therefore rewrite is not
|
||||||
// applied.
|
// applied.
|
||||||
// CHECK-LABEL: @negative_no_shuffle_outside_block
|
// CHECK-LABEL: @negative_no_shuffle_outside_block
|
||||||
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
|
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
|
||||||
// CHECK: vector.fma
|
// CHECK: vector.fma
|
||||||
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
|
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32
|
||||||
// CHECK: vector.fma
|
// CHECK: vector.fma
|
||||||
// CHECK: scf.if
|
// CHECK: scf.if
|
||||||
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
|
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
|
||||||
// CHECK: vector.fma
|
// CHECK: vector.fma
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||||
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %0 {
|
transform.apply_patterns to %0 {
|
||||||
transform.apply_patterns.x86vector.shuffle_vector_fma_ops
|
transform.apply_patterns.x86.shuffle_vector_fma_ops
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -24,7 +24,7 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||||
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %0 {
|
transform.apply_patterns to %0 {
|
||||||
transform.apply_patterns.x86vector.sink_vector_producer_ops
|
transform.apply_patterns.x86.sink_vector_producer_ops
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -57,7 +57,7 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||||
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %0 {
|
transform.apply_patterns to %0 {
|
||||||
transform.apply_patterns.x86vector.sink_vector_producer_ops
|
transform.apply_patterns.x86.sink_vector_producer_ops
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -90,7 +90,7 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||||
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %0 {
|
transform.apply_patterns to %0 {
|
||||||
transform.apply_patterns.x86vector.sink_vector_producer_ops
|
transform.apply_patterns.x86.sink_vector_producer_ops
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -134,7 +134,7 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||||
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %0 {
|
transform.apply_patterns to %0 {
|
||||||
transform.apply_patterns.x86vector.sink_vector_producer_ops
|
transform.apply_patterns.x86.sink_vector_producer_ops
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -160,7 +160,7 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||||
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %0 {
|
transform.apply_patterns to %0 {
|
||||||
transform.apply_patterns.x86vector.sink_vector_producer_ops
|
transform.apply_patterns.x86.sink_vector_producer_ops
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -191,9 +191,8 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||||
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %0 {
|
transform.apply_patterns to %0 {
|
||||||
transform.apply_patterns.x86vector.sink_vector_producer_ops
|
transform.apply_patterns.x86.sink_vector_producer_ops
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -30,18 +30,18 @@ func.func @brgemm_to_fma(
|
|||||||
// CHECK: memref.subview %arg0[%c0, %c0, %c0, 1] {{.*}} : memref<1x4x1x2xbf16> to memref<1x1x1x1xbf16, {{.*}}>
|
// CHECK: memref.subview %arg0[%c0, %c0, %c0, 1] {{.*}} : memref<1x4x1x2xbf16> to memref<1x1x1x1xbf16, {{.*}}>
|
||||||
// CHECK: memref.subview %arg0[%c0, %c0, %c0, 0] {{.*}} : memref<1x4x1x2xbf16> to memref<1x1x1x1xbf16, {{.*}}>
|
// CHECK: memref.subview %arg0[%c0, %c0, %c0, 0] {{.*}} : memref<1x4x1x2xbf16> to memref<1x1x1x1xbf16, {{.*}}>
|
||||||
// CHECK: memref.subview %arg1[%c0, %c0, %c0, %c0] {{.*}} : memref<1x1x32x2xbf16> to memref<1x1x8x2xbf16, {{.*}}>
|
// CHECK: memref.subview %arg1[%c0, %c0, %c0, %c0] {{.*}} : memref<1x1x32x2xbf16> to memref<1x1x8x2xbf16, {{.*}}>
|
||||||
// CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} : memref<1x1x1x1xbf16, strided<[8, 2, 2, 1], offset: ?>>
|
// CHECK: x86.avx.bcst_to_f32.packed {{.*}} : memref<1x1x1x1xbf16, strided<[8, 2, 2, 1], offset: ?>>
|
||||||
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : memref<1x1x8x2xbf16, strided<[64, 64, 2, 1], offset: ?>>
|
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : memref<1x1x8x2xbf16, strided<[64, 64, 2, 1], offset: ?>>
|
||||||
// CHECK: vector.fma {{.*}} : vector<8xf32>
|
// CHECK: vector.fma {{.*}} : vector<8xf32>
|
||||||
// CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} : memref<1x1x1x1xbf16, strided<[8, 2, 2, 1], offset: ?>>
|
// CHECK: x86.avx.bcst_to_f32.packed {{.*}} : memref<1x1x1x1xbf16, strided<[8, 2, 2, 1], offset: ?>>
|
||||||
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} : memref<1x1x8x2xbf16, strided<[64, 64, 2, 1], offset: ?>>
|
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32 {{.*}} : memref<1x1x8x2xbf16, strided<[64, 64, 2, 1], offset: ?>>
|
||||||
// CHECK: vector.fma {{.*}} : vector<8xf32>
|
// CHECK: vector.fma {{.*}} : vector<8xf32>
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
|
transform.apply_patterns.x86.vector_contract_bf16_to_fma
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -76,18 +76,18 @@ func.func @brgemm_to_fma_load(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @brgemm_to_fma_load
|
// CHECK-LABEL: @brgemm_to_fma_load
|
||||||
// CHECK: x86vector.avx.bcst_to_f32.packed
|
// CHECK: x86.avx.bcst_to_f32.packed
|
||||||
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
|
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
|
||||||
// CHECK: vector.fma
|
// CHECK: vector.fma
|
||||||
// CHECK: x86vector.avx.bcst_to_f32.packed
|
// CHECK: x86.avx.bcst_to_f32.packed
|
||||||
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
|
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32
|
||||||
// CHECK: vector.fma
|
// CHECK: vector.fma
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
|
transform.apply_patterns.x86.vector_contract_bf16_to_fma
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -122,18 +122,18 @@ func.func @brgemm_to_fma_load_bcst_B(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @brgemm_to_fma_load_bcst_B
|
// CHECK-LABEL: @brgemm_to_fma_load_bcst_B
|
||||||
// CHECK: x86vector.avx.bcst_to_f32.packed
|
// CHECK: x86.avx.bcst_to_f32.packed
|
||||||
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
|
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
|
||||||
// CHECK: vector.fma
|
// CHECK: vector.fma
|
||||||
// CHECK: x86vector.avx.bcst_to_f32.packed
|
// CHECK: x86.avx.bcst_to_f32.packed
|
||||||
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
|
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32
|
||||||
// CHECK: vector.fma
|
// CHECK: vector.fma
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
|
transform.apply_patterns.x86.vector_contract_bf16_to_fma
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -168,18 +168,18 @@ func.func @batch_matmul_fma_load(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @batch_matmul_fma_load
|
// CHECK-LABEL: @batch_matmul_fma_load
|
||||||
// CHECK: x86vector.avx.bcst_to_f32.packed
|
// CHECK: x86.avx.bcst_to_f32.packed
|
||||||
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
|
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
|
||||||
// CHECK: vector.fma
|
// CHECK: vector.fma
|
||||||
// CHECK: x86vector.avx.bcst_to_f32.packed
|
// CHECK: x86.avx.bcst_to_f32.packed
|
||||||
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
|
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32
|
||||||
// CHECK: vector.fma
|
// CHECK: vector.fma
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
|
transform.apply_patterns.x86.vector_contract_bf16_to_fma
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -214,18 +214,18 @@ func.func @matmul_outer_product_to_fma_load(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @matmul_outer_product_to_fma_load
|
// CHECK-LABEL: @matmul_outer_product_to_fma_load
|
||||||
// CHECK: x86vector.avx.bcst_to_f32.packed
|
// CHECK: x86.avx.bcst_to_f32.packed
|
||||||
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
|
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
|
||||||
// CHECK: vector.fma
|
// CHECK: vector.fma
|
||||||
// CHECK: x86vector.avx.bcst_to_f32.packed
|
// CHECK: x86.avx.bcst_to_f32.packed
|
||||||
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
|
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32
|
||||||
// CHECK: vector.fma
|
// CHECK: vector.fma
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
|
transform.apply_patterns.x86.vector_contract_bf16_to_fma
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -285,10 +285,10 @@ func.func @matmul_to_fma_flat_layout(
|
|||||||
// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
|
// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
|
||||||
// CHECK: memref.subview %arg0[%c0, %c0] {{.*}} : memref<4x1xbf16> to memref<1x1xbf16, {{.*}}>
|
// CHECK: memref.subview %arg0[%c0, %c0] {{.*}} : memref<4x1xbf16> to memref<1x1xbf16, {{.*}}>
|
||||||
// CHECK: memref.subview %arg1[%c0, %c0] {{.*}} : memref<1x32xbf16> to memref<1x16xbf16, {{.*}}>
|
// CHECK: memref.subview %arg1[%c0, %c0] {{.*}} : memref<1x32xbf16> to memref<1x16xbf16, {{.*}}>
|
||||||
// CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} : memref<1x1xbf16, strided<[1, 1], offset: ?>>
|
// CHECK: x86.avx.bcst_to_f32.packed {{.*}} : memref<1x1xbf16, strided<[1, 1], offset: ?>>
|
||||||
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} : memref<1x16xbf16, strided<[32, 1], offset: ?>>
|
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32 {{.*}} : memref<1x16xbf16, strided<[32, 1], offset: ?>>
|
||||||
// CHECK: vector.fma {{.*}} : vector<8xf32>
|
// CHECK: vector.fma {{.*}} : vector<8xf32>
|
||||||
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : memref<1x16xbf16, strided<[32, 1], offset: ?>>
|
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : memref<1x16xbf16, strided<[32, 1], offset: ?>>
|
||||||
// CHECK: vector.fma {{.*}} : vector<8xf32>
|
// CHECK: vector.fma {{.*}} : vector<8xf32>
|
||||||
// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
|
// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
|
||||||
// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
|
// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
|
||||||
@@ -297,7 +297,7 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
|
transform.apply_patterns.x86.vector_contract_bf16_to_fma
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -357,10 +357,10 @@ func.func @matmul_to_fma_flat_layout_load(
|
|||||||
// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
|
// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
|
||||||
// CHECK: memref.subview %arg0[%c0, %c0] {{.*}} : memref<4x1xbf16> to memref<1x1xbf16, {{.*}}>
|
// CHECK: memref.subview %arg0[%c0, %c0] {{.*}} : memref<4x1xbf16> to memref<1x1xbf16, {{.*}}>
|
||||||
// CHECK: memref.subview %arg1[%c0, %c0] {{.*}} : memref<1x32xbf16> to memref<1x16xbf16, {{.*}}>
|
// CHECK: memref.subview %arg1[%c0, %c0] {{.*}} : memref<1x32xbf16> to memref<1x16xbf16, {{.*}}>
|
||||||
// CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} : memref<1x1xbf16, strided<[1, 1], offset: ?>>
|
// CHECK: x86.avx.bcst_to_f32.packed {{.*}} : memref<1x1xbf16, strided<[1, 1], offset: ?>>
|
||||||
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} : memref<1x16xbf16, strided<[32, 1], offset: ?>>
|
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32 {{.*}} : memref<1x16xbf16, strided<[32, 1], offset: ?>>
|
||||||
// CHECK: vector.fma {{.*}} : vector<8xf32>
|
// CHECK: vector.fma {{.*}} : vector<8xf32>
|
||||||
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : memref<1x16xbf16, strided<[32, 1], offset: ?>>
|
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : memref<1x16xbf16, strided<[32, 1], offset: ?>>
|
||||||
// CHECK: vector.fma {{.*}} : vector<8xf32>
|
// CHECK: vector.fma {{.*}} : vector<8xf32>
|
||||||
// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
|
// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
|
||||||
// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
|
// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
|
||||||
@@ -369,7 +369,7 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
|
transform.apply_patterns.x86.vector_contract_bf16_to_fma
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -448,10 +448,10 @@ func.func @matmul_to_fma_flat_layout_loop(%arg0: memref<16x64x32xbf16>, %arg1: m
|
|||||||
// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
|
// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
|
||||||
// CHECK: scf.for
|
// CHECK: scf.for
|
||||||
// CHECK: scf.for
|
// CHECK: scf.for
|
||||||
// CHECK: x86vector.avx.bcst_to_f32.packed
|
// CHECK: x86.avx.bcst_to_f32.packed
|
||||||
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
|
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32
|
||||||
// CHECK: vector.fma {{.*}} : vector<8xf32>
|
// CHECK: vector.fma {{.*}} : vector<8xf32>
|
||||||
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
|
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
|
||||||
// CHECK: vector.fma {{.*}} : vector<8xf32>
|
// CHECK: vector.fma {{.*}} : vector<8xf32>
|
||||||
// CHECK: scf.yield
|
// CHECK: scf.yield
|
||||||
// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
|
// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
|
||||||
@@ -461,7 +461,7 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
|
transform.apply_patterns.x86.vector_contract_bf16_to_fma
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -524,10 +524,10 @@ func.func @matmul_to_fma_flat_layout_bcstB(
|
|||||||
// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
|
// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
|
||||||
// CHECK: memref.subview %arg1[%c0, %c0] {{.*}} : memref<1x4xbf16> to memref<1x1xbf16, {{.*}}>
|
// CHECK: memref.subview %arg1[%c0, %c0] {{.*}} : memref<1x4xbf16> to memref<1x1xbf16, {{.*}}>
|
||||||
// CHECK: memref.subview %arg0[%c0, %c0] {{.*}} : memref<32x1xbf16> to memref<16x1xbf16, {{.*}}>
|
// CHECK: memref.subview %arg0[%c0, %c0] {{.*}} : memref<32x1xbf16> to memref<16x1xbf16, {{.*}}>
|
||||||
// CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} : memref<1x1xbf16, strided<[4, 1], offset: ?>>
|
// CHECK: x86.avx.bcst_to_f32.packed {{.*}} : memref<1x1xbf16, strided<[4, 1], offset: ?>>
|
||||||
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} : memref<16x1xbf16, strided<[1, 1], offset: ?>>
|
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32 {{.*}} : memref<16x1xbf16, strided<[1, 1], offset: ?>>
|
||||||
// CHECK: vector.fma {{.*}} : vector<8xf32>
|
// CHECK: vector.fma {{.*}} : vector<8xf32>
|
||||||
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : memref<16x1xbf16, strided<[1, 1], offset: ?>>
|
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : memref<16x1xbf16, strided<[1, 1], offset: ?>>
|
||||||
// CHECK: vector.fma {{.*}} : vector<8xf32>
|
// CHECK: vector.fma {{.*}} : vector<8xf32>
|
||||||
// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
|
// CHECK: vector.shuffle{{.*}}[0, 8, 1, 9, 2, 10, 3, 11] : vector<8xf32>, vector<8xf32>
|
||||||
// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
|
// CHECK-NEXT: vector.shuffle{{.*}}[4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
|
||||||
@@ -536,7 +536,7 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
|
transform.apply_patterns.x86.vector_contract_bf16_to_fma
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -572,18 +572,18 @@ func.func @matmul_dynamic_offset(
|
|||||||
|
|
||||||
// CHECK-LABEL: @matmul_dynamic_offset
|
// CHECK-LABEL: @matmul_dynamic_offset
|
||||||
// CHECK: memref.subview %arg0[%arg3, %c0, 1]{{.*}}
|
// CHECK: memref.subview %arg0[%arg3, %c0, 1]{{.*}}
|
||||||
// CHECK: x86vector.avx.bcst_to_f32.packed
|
// CHECK: x86.avx.bcst_to_f32.packed
|
||||||
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
|
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
|
||||||
// CHECK: vector.fma
|
// CHECK: vector.fma
|
||||||
// CHECK: x86vector.avx.bcst_to_f32.packed
|
// CHECK: x86.avx.bcst_to_f32.packed
|
||||||
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
|
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32
|
||||||
// CHECK: vector.fma
|
// CHECK: vector.fma
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
|
transform.apply_patterns.x86.vector_contract_bf16_to_fma
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -618,18 +618,18 @@ func.func @matmul_to_fma_load_bcst_B(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @matmul_to_fma_load_bcst_B
|
// CHECK-LABEL: @matmul_to_fma_load_bcst_B
|
||||||
// CHECK: x86vector.avx.bcst_to_f32.packed
|
// CHECK: x86.avx.bcst_to_f32.packed
|
||||||
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
|
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
|
||||||
// CHECK: vector.fma
|
// CHECK: vector.fma
|
||||||
// CHECK: x86vector.avx.bcst_to_f32.packed
|
// CHECK: x86.avx.bcst_to_f32.packed
|
||||||
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
|
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32
|
||||||
// CHECK: vector.fma
|
// CHECK: vector.fma
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
|
transform.apply_patterns.x86.vector_contract_bf16_to_fma
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -664,18 +664,18 @@ func.func @many_dimensions(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @many_dimensions
|
// CHECK-LABEL: @many_dimensions
|
||||||
// CHECK: x86vector.avx.bcst_to_f32.packed
|
// CHECK: x86.avx.bcst_to_f32.packed
|
||||||
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
|
// CHECK: x86.avx.cvt.packed.odd.indexed_to_f32
|
||||||
// CHECK: vector.fma
|
// CHECK: vector.fma
|
||||||
// CHECK: x86vector.avx.bcst_to_f32.packed
|
// CHECK: x86.avx.bcst_to_f32.packed
|
||||||
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
|
// CHECK: x86.avx.cvt.packed.even.indexed_to_f32
|
||||||
// CHECK: vector.fma
|
// CHECK: vector.fma
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
|
transform.apply_patterns.x86.vector_contract_bf16_to_fma
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -744,7 +744,7 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
|
transform.apply_patterns.x86.vector_contract_bf16_to_fma
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -785,7 +785,7 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||||
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %0 {
|
transform.apply_patterns to %0 {
|
||||||
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
|
transform.apply_patterns.x86.vector_contract_bf16_to_fma
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -841,17 +841,17 @@ func.func @negative_offset_diff_is_not_8(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @negative_offset_diff_is_not_8
|
// CHECK-LABEL: @negative_offset_diff_is_not_8
|
||||||
// CHECK-NOT: x86vector.avx.bcst_to_f32.packed
|
// CHECK-NOT: x86.avx.bcst_to_f32.packed
|
||||||
// CHECK-NOT: x86vector.avx.cvt.packed.even.indexed_to_f32
|
// CHECK-NOT: x86.avx.cvt.packed.even.indexed_to_f32
|
||||||
// CHECK-NOT: vector.fma {{.*}} : vector<8xf32>
|
// CHECK-NOT: vector.fma {{.*}} : vector<8xf32>
|
||||||
// CHECK-NOT: x86vector.avx.cvt.packed.odd.indexed_to_f32
|
// CHECK-NOT: x86.avx.cvt.packed.odd.indexed_to_f32
|
||||||
// CHECK: vector.contract
|
// CHECK: vector.contract
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
|
transform.apply_patterns.x86.vector_contract_bf16_to_fma
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -907,17 +907,17 @@ func.func @negative_vector_contracts_not_in_order(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @negative_vector_contracts_not_in_order
|
// CHECK-LABEL: @negative_vector_contracts_not_in_order
|
||||||
// CHECK-NOT: x86vector.avx.bcst_to_f32.packed
|
// CHECK-NOT: x86.avx.bcst_to_f32.packed
|
||||||
// CHECK-NOT: x86vector.avx.cvt.packed.even.indexed_to_f32
|
// CHECK-NOT: x86.avx.cvt.packed.even.indexed_to_f32
|
||||||
// CHECK-NOT: vector.fma {{.*}} : vector<8xf32>
|
// CHECK-NOT: vector.fma {{.*}} : vector<8xf32>
|
||||||
// CHECK-NOT: x86vector.avx.cvt.packed.odd.indexed_to_f32
|
// CHECK-NOT: x86.avx.cvt.packed.odd.indexed_to_f32
|
||||||
// CHECK: vector.contract
|
// CHECK: vector.contract
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
|
transform.apply_patterns.x86.vector_contract_bf16_to_fma
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -976,17 +976,17 @@ func.func @negative_flat_layout_dynamic_index(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @negative_flat_layout_dynamic_index
|
// CHECK-LABEL: @negative_flat_layout_dynamic_index
|
||||||
// CHECK-NOT: x86vector.avx.bcst_to_f32.packed
|
// CHECK-NOT: x86.avx.bcst_to_f32.packed
|
||||||
// CHECK-NOT: x86vector.avx.cvt.packed.even.indexed_to_f32
|
// CHECK-NOT: x86.avx.cvt.packed.even.indexed_to_f32
|
||||||
// CHECK-NOT: vector.fma {{.*}} : vector<8xf32>
|
// CHECK-NOT: vector.fma {{.*}} : vector<8xf32>
|
||||||
// CHECK-NOT: x86vector.avx.cvt.packed.odd.indexed_to_f32
|
// CHECK-NOT: x86.avx.cvt.packed.odd.indexed_to_f32
|
||||||
// CHECK: vector.contract
|
// CHECK: vector.contract
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
|
transform.apply_patterns.x86.vector_contract_bf16_to_fma
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -1045,17 +1045,17 @@ func.func @negative_non_unit_K_dim(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @negative_non_unit_K_dim
|
// CHECK-LABEL: @negative_non_unit_K_dim
|
||||||
// CHECK-NOT: x86vector.avx.bcst_to_f32.packed
|
// CHECK-NOT: x86.avx.bcst_to_f32.packed
|
||||||
// CHECK-NOT: x86vector.avx.cvt.packed.even.indexed_to_f32
|
// CHECK-NOT: x86.avx.cvt.packed.even.indexed_to_f32
|
||||||
// CHECK-NOT: vector.fma {{.*}} : vector<8xf32>
|
// CHECK-NOT: vector.fma {{.*}} : vector<8xf32>
|
||||||
// CHECK-NOT: x86vector.avx.cvt.packed.odd.indexed_to_f32
|
// CHECK-NOT: x86.avx.cvt.packed.odd.indexed_to_f32
|
||||||
// CHECK: vector.contract
|
// CHECK: vector.contract
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
|
transform.apply_patterns.x86.vector_contract_bf16_to_fma
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -1089,7 +1089,7 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
|
transform.apply_patterns.x86.vector_contract_bf16_to_fma
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -1132,7 +1132,7 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
|
transform.apply_patterns.x86.vector_contract_bf16_to_fma
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -1176,7 +1176,7 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
|
transform.apply_patterns.x86.vector_contract_bf16_to_fma
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -1222,7 +1222,7 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
|
transform.apply_patterns.x86.vector_contract_bf16_to_fma
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -1266,7 +1266,7 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
|
transform.apply_patterns.x86.vector_contract_bf16_to_fma
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -1308,9 +1308,8 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
|
transform.apply_patterns.x86.vector_contract_bf16_to_fma
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -27,7 +27,7 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_fma
|
transform.apply_patterns.x86.vector_contract_to_fma
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -61,7 +61,7 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_fma
|
transform.apply_patterns.x86.vector_contract_to_fma
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -95,7 +95,7 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_fma
|
transform.apply_patterns.x86.vector_contract_to_fma
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -129,7 +129,7 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_fma
|
transform.apply_patterns.x86.vector_contract_to_fma
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -163,7 +163,7 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_fma
|
transform.apply_patterns.x86.vector_contract_to_fma
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -197,7 +197,7 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_fma
|
transform.apply_patterns.x86.vector_contract_to_fma
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -233,7 +233,7 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_fma
|
transform.apply_patterns.x86.vector_contract_to_fma
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -269,7 +269,7 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_fma
|
transform.apply_patterns.x86.vector_contract_to_fma
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -303,7 +303,7 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_fma
|
transform.apply_patterns.x86.vector_contract_to_fma
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -337,7 +337,7 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_fma
|
transform.apply_patterns.x86.vector_contract_to_fma
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -20,13 +20,13 @@ func.func @brgemm_to_bf16dp(
|
|||||||
|
|
||||||
// CHECK-LABEL: @brgemm_to_bf16dp
|
// CHECK-LABEL: @brgemm_to_bf16dp
|
||||||
// CHECK: vector.broadcast
|
// CHECK: vector.broadcast
|
||||||
// CHECK: x86vector.avx512.dot
|
// CHECK: x86.avx512.dot
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
|
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -54,13 +54,13 @@ func.func @brgemm_to_bf16dp_bcst_B(
|
|||||||
|
|
||||||
// CHECK-LABEL: @brgemm_to_bf16dp_bcst_B
|
// CHECK-LABEL: @brgemm_to_bf16dp_bcst_B
|
||||||
// CHECK: vector.broadcast
|
// CHECK: vector.broadcast
|
||||||
// CHECK: x86vector.avx512.dot
|
// CHECK: x86.avx512.dot
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
|
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -88,13 +88,13 @@ func.func @brgemm_to_avx10int8dp(
|
|||||||
|
|
||||||
// CHECK-LABEL: @brgemm_to_avx10int8dp
|
// CHECK-LABEL: @brgemm_to_avx10int8dp
|
||||||
// CHECK: vector.broadcast
|
// CHECK: vector.broadcast
|
||||||
// CHECK: x86vector.avx10.dot.i8
|
// CHECK: x86.avx10.dot.i8
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
|
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -123,13 +123,13 @@ func.func @batch_matmul_avx10int8dp_bcst_B(
|
|||||||
|
|
||||||
// CHECK-LABEL: @batch_matmul_avx10int8dp_bcst_B
|
// CHECK-LABEL: @batch_matmul_avx10int8dp_bcst_B
|
||||||
// CHECK: vector.broadcast
|
// CHECK: vector.broadcast
|
||||||
// CHECK: x86vector.avx10.dot.i8
|
// CHECK: x86.avx10.dot.i8
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
|
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -157,13 +157,13 @@ func.func @brgemm_to_int8dp(
|
|||||||
|
|
||||||
// CHECK-LABEL: @brgemm_to_int8dp
|
// CHECK-LABEL: @brgemm_to_int8dp
|
||||||
// CHECK: vector.broadcast
|
// CHECK: vector.broadcast
|
||||||
// CHECK: x86vector.avx.dot.i8
|
// CHECK: x86.avx.dot.i8
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
|
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -191,13 +191,13 @@ func.func @batch_matmul_bf16dp(
|
|||||||
|
|
||||||
// CHECK-LABEL: @batch_matmul_bf16dp
|
// CHECK-LABEL: @batch_matmul_bf16dp
|
||||||
// CHECK: vector.broadcast
|
// CHECK: vector.broadcast
|
||||||
// CHECK: x86vector.avx512.dot
|
// CHECK: x86.avx512.dot
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
|
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -226,13 +226,13 @@ func.func @batch_matmul_int8dp(
|
|||||||
|
|
||||||
// CHECK-LABEL: @batch_matmul_int8dp
|
// CHECK-LABEL: @batch_matmul_int8dp
|
||||||
// CHECK: vector.broadcast
|
// CHECK: vector.broadcast
|
||||||
// CHECK: x86vector.avx.dot.i8
|
// CHECK: x86.avx.dot.i8
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
|
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -261,13 +261,13 @@ func.func @batch_matmul_int8dp_bcst_B(
|
|||||||
|
|
||||||
// CHECK-LABEL: @batch_matmul_int8dp_bcst_B
|
// CHECK-LABEL: @batch_matmul_int8dp_bcst_B
|
||||||
// CHECK: vector.broadcast
|
// CHECK: vector.broadcast
|
||||||
// CHECK: x86vector.avx.dot.i8
|
// CHECK: x86.avx.dot.i8
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
|
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -295,13 +295,13 @@ func.func @matmul_outer_product_to_bf16dp(
|
|||||||
|
|
||||||
// CHECK-LABEL: @matmul_outer_product_to_bf16dp
|
// CHECK-LABEL: @matmul_outer_product_to_bf16dp
|
||||||
// CHECK: vector.broadcast
|
// CHECK: vector.broadcast
|
||||||
// CHECK: x86vector.avx512.dot
|
// CHECK: x86.avx512.dot
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
|
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -329,13 +329,13 @@ func.func @matmul_outer_product_to_int8dp(
|
|||||||
|
|
||||||
// CHECK-LABEL: @matmul_outer_product_to_int8dp
|
// CHECK-LABEL: @matmul_outer_product_to_int8dp
|
||||||
// CHECK: vector.broadcast
|
// CHECK: vector.broadcast
|
||||||
// CHECK: x86vector.avx.dot.i8
|
// CHECK: x86.avx.dot.i8
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
|
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -395,8 +395,8 @@ func.func @matmul_bf16dp_flat_layout(
|
|||||||
// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
|
// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
|
||||||
// CHECK: vector.shuffle{{.*}}[0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43, 16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59] : vector<32xbf16>, vector<32xbf16>
|
// CHECK: vector.shuffle{{.*}}[0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43, 16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59] : vector<32xbf16>, vector<32xbf16>
|
||||||
// CHECK-NEXT: vector.shuffle{{.*}}[4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47, 20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63] : vector<32xbf16>, vector<32xbf16>
|
// CHECK-NEXT: vector.shuffle{{.*}}[4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47, 20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63] : vector<32xbf16>, vector<32xbf16>
|
||||||
// CHECK: x86vector.avx512.dot
|
// CHECK: x86.avx512.dot
|
||||||
// CHECK: x86vector.avx512.dot
|
// CHECK: x86.avx512.dot
|
||||||
// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
|
// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
|
||||||
// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
|
// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
|
||||||
|
|
||||||
@@ -404,7 +404,7 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
|
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -485,8 +485,8 @@ func.func @brmatmul_bf16dp_flat_layout_loop(%arg0: memref<16x64x32xbf16>, %arg1:
|
|||||||
// CHECK: scf.for
|
// CHECK: scf.for
|
||||||
// CHECK: vector.shuffle{{.*}}[0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43, 16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59] : vector<32xbf16>, vector<32xbf16>
|
// CHECK: vector.shuffle{{.*}}[0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43, 16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59] : vector<32xbf16>, vector<32xbf16>
|
||||||
// CHECK-NEXT: vector.shuffle{{.*}}[4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47, 20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63] : vector<32xbf16>, vector<32xbf16>
|
// CHECK-NEXT: vector.shuffle{{.*}}[4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47, 20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63] : vector<32xbf16>, vector<32xbf16>
|
||||||
// CHECK: x86vector.avx512.dot
|
// CHECK: x86.avx512.dot
|
||||||
// CHECK: x86vector.avx512.dot
|
// CHECK: x86.avx512.dot
|
||||||
// CHECK: scf.yield
|
// CHECK: scf.yield
|
||||||
// CHECK: scf.yield
|
// CHECK: scf.yield
|
||||||
// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
|
// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
|
||||||
@@ -496,7 +496,7 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
|
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -566,15 +566,15 @@ func.func @matmul_bf16dp_flat_layout_B_shuffled(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @matmul_bf16dp_flat_layout_B_shuffled
|
// CHECK-LABEL: @matmul_bf16dp_flat_layout_B_shuffled
|
||||||
// CHECK: x86vector.avx512.dot
|
// CHECK: x86.avx512.dot
|
||||||
// CHECK: x86vector.avx512.dot
|
// CHECK: x86.avx512.dot
|
||||||
// CHECK-NOT: vector.contract
|
// CHECK-NOT: vector.contract
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
|
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -601,14 +601,14 @@ func.func @negative_invalid_vc_kind(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @negative_invalid_vc_kind
|
// CHECK-LABEL: @negative_invalid_vc_kind
|
||||||
// CHECK-NOT: x86vector.avx512.dot
|
// CHECK-NOT: x86.avx512.dot
|
||||||
// CHECK: vector.contract
|
// CHECK: vector.contract
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
|
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -635,14 +635,14 @@ func.func @negative_false_vnni_bf16(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @negative_false_vnni_bf16
|
// CHECK-LABEL: @negative_false_vnni_bf16
|
||||||
// CHECK-NOT: x86vector.avx512.dot
|
// CHECK-NOT: x86.avx512.dot
|
||||||
// CHECK: vector.contract
|
// CHECK: vector.contract
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
|
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -669,14 +669,14 @@ func.func @negative_false_vnni_int8(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @negative_false_vnni_int8
|
// CHECK-LABEL: @negative_false_vnni_int8
|
||||||
// CHECK-NOT: x86vector.avx.dot.i8
|
// CHECK-NOT: x86.avx.dot.i8
|
||||||
// CHECK: vector.contract
|
// CHECK: vector.contract
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
|
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -703,14 +703,14 @@ func.func @negative_batch_dimension(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @negative_batch_dimension
|
// CHECK-LABEL: @negative_batch_dimension
|
||||||
// CHECK-NOT: x86vector.avx512.dot
|
// CHECK-NOT: x86.avx512.dot
|
||||||
// CHECK: vector.contract
|
// CHECK: vector.contract
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
|
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -737,14 +737,14 @@ func.func @negative_brgemm_dimension(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @negative_brgemm_dimension
|
// CHECK-LABEL: @negative_brgemm_dimension
|
||||||
// CHECK-NOT: x86vector.avx.dot.i8
|
// CHECK-NOT: x86.avx.dot.i8
|
||||||
// CHECK: vector.contract
|
// CHECK: vector.contract
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
|
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -771,14 +771,14 @@ func.func @negative_float_acc_type(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @negative_float_acc_type
|
// CHECK-LABEL: @negative_float_acc_type
|
||||||
// CHECK-NOT: x86vector.avx512.dot
|
// CHECK-NOT: x86.avx512.dot
|
||||||
// CHECK: vector.contract
|
// CHECK: vector.contract
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
|
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -805,14 +805,14 @@ func.func @negative_int_acc_type(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @negative_int_acc_type
|
// CHECK-LABEL: @negative_int_acc_type
|
||||||
// CHECK-NOT: x86vector.avx.dot.i8
|
// CHECK-NOT: x86.avx.dot.i8
|
||||||
// CHECK: vector.contract
|
// CHECK: vector.contract
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
|
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -839,14 +839,14 @@ func.func @negative_wrong_vnni_blocking_factor_bf16(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @negative_wrong_vnni_blocking_factor_bf16
|
// CHECK-LABEL: @negative_wrong_vnni_blocking_factor_bf16
|
||||||
// CHECK-NOT: x86vector.avx512.dot
|
// CHECK-NOT: x86.avx512.dot
|
||||||
// CHECK: vector.contract
|
// CHECK: vector.contract
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
|
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -873,14 +873,14 @@ func.func @negative_brgemm_not_vnni(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @negative_brgemm_not_vnni
|
// CHECK-LABEL: @negative_brgemm_not_vnni
|
||||||
// CHECK-NOT: x86vector.avx512.dot
|
// CHECK-NOT: x86.avx512.dot
|
||||||
// CHECK: vector.contract
|
// CHECK: vector.contract
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
|
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -907,15 +907,15 @@ func.func @negative_wrong_vector_shape_int8(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @negative_wrong_vector_shape_int8
|
// CHECK-LABEL: @negative_wrong_vector_shape_int8
|
||||||
// CHECK-NOT: x86vector.avx.dot.i8
|
// CHECK-NOT: x86.avx.dot.i8
|
||||||
// CHECK-NOT: x86vector.avx10.dot.i8
|
// CHECK-NOT: x86.avx10.dot.i8
|
||||||
// CHECK: vector.contract
|
// CHECK: vector.contract
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
|
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -942,14 +942,14 @@ func.func @negative_wrong_vector_shape_bf16(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @negative_wrong_vector_shape_bf16
|
// CHECK-LABEL: @negative_wrong_vector_shape_bf16
|
||||||
// CHECK-NOT: x86vector.avx512.dot
|
// CHECK-NOT: x86.avx512.dot
|
||||||
// CHECK: vector.contract
|
// CHECK: vector.contract
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
|
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -1005,14 +1005,14 @@ func.func @negative_flat_other_dim_is_not_2(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @negative_flat_other_dim_is_not_2
|
// CHECK-LABEL: @negative_flat_other_dim_is_not_2
|
||||||
// CHECK-NOT: x86vector.avx512.dot
|
// CHECK-NOT: x86.avx512.dot
|
||||||
// CHECK: vector.contract
|
// CHECK: vector.contract
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
|
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -1069,14 +1069,14 @@ func.func @negative_flat_offset_diff_is_not16(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @negative_flat_offset_diff_is_not16
|
// CHECK-LABEL: @negative_flat_offset_diff_is_not16
|
||||||
// CHECK-NOT: x86vector.avx512.dot
|
// CHECK-NOT: x86.avx512.dot
|
||||||
// CHECK: vector.contract
|
// CHECK: vector.contract
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
|
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -1132,14 +1132,14 @@ func.func @negative_flat_dynamic_offset(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @negative_flat_dynamic_offset
|
// CHECK-LABEL: @negative_flat_dynamic_offset
|
||||||
// CHECK-NOT: x86vector.avx512.dot
|
// CHECK-NOT: x86.avx512.dot
|
||||||
// CHECK: vector.contract
|
// CHECK: vector.contract
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
|
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -1196,14 +1196,14 @@ func.func @negative_flat_read_after_contract(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @negative_flat_read_after_contract
|
// CHECK-LABEL: @negative_flat_read_after_contract
|
||||||
// CHECK-NOT: x86vector.avx512.dot
|
// CHECK-NOT: x86.avx512.dot
|
||||||
// CHECK: vector.contract
|
// CHECK: vector.contract
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
|
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -1271,7 +1271,7 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
|
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -1338,7 +1338,7 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
|
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -1406,7 +1406,7 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
|
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -1473,7 +1473,7 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.apply_patterns to %func {
|
transform.apply_patterns to %func {
|
||||||
transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
|
transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
|
||||||
} : !transform.any_op
|
} : !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86vector" -test-lower-to-llvm | \
|
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86" -test-lower-to-llvm | \
|
||||||
// RUN: mlir-translate --mlir-to-llvmir | \
|
// RUN: mlir-translate --mlir-to-llvmir | \
|
||||||
// RUN: %lli --entry-function=entry --mattr="avx" --dlopen=%mlir_c_runner_utils | \
|
// RUN: %lli --entry-function=entry --mattr="avx" --dlopen=%mlir_c_runner_utils | \
|
||||||
// RUN: FileCheck %s
|
// RUN: FileCheck %s
|
||||||
@@ -9,7 +9,7 @@ func.func @entry() -> i32 {
|
|||||||
|
|
||||||
%a = arith.constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]> : vector<8xf32>
|
%a = arith.constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]> : vector<8xf32>
|
||||||
%b = arith.constant dense<[9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : vector<8xf32>
|
%b = arith.constant dense<[9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : vector<8xf32>
|
||||||
%r = x86vector.avx.intr.dot %a, %b : vector<8xf32>
|
%r = x86.avx.intr.dot %a, %b : vector<8xf32>
|
||||||
|
|
||||||
%1 = vector.extract %r[%i0] : f32 from vector<8xf32>
|
%1 = vector.extract %r[%i0] : f32 from vector<8xf32>
|
||||||
%2 = vector.extract %r[%i4] : f32 from vector<8xf32>
|
%2 = vector.extract %r[%i4] : f32 from vector<8xf32>
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
# X86Vector tests must be enabled via build flag.
|
# X86 tests must be enabled via build flag.
|
||||||
if not config.mlir_run_x86vector_tests:
|
if not config.mlir_run_x86_tests:
|
||||||
config.unsupported = True
|
config.unsupported = True
|
||||||
|
|
||||||
# No JIT on win32.
|
# No JIT on win32.
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86vector" -test-lower-to-llvm | \
|
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86" -test-lower-to-llvm | \
|
||||||
// RUN: mlir-translate --mlir-to-llvmir | \
|
// RUN: mlir-translate --mlir-to-llvmir | \
|
||||||
// RUN: %lli --entry-function=entry --mattr="avx512bw" --dlopen=%mlir_c_runner_utils | \
|
// RUN: %lli --entry-function=entry --mattr="avx512bw" --dlopen=%mlir_c_runner_utils | \
|
||||||
// RUN: FileCheck %s
|
// RUN: FileCheck %s
|
||||||
@@ -8,8 +8,8 @@ func.func @entry() -> i32 {
|
|||||||
|
|
||||||
%a = arith.constant dense<[1., 0., 0., 2., 4., 3., 5., 7., 8., 1., 5., 5., 3., 1., 0., 7.]> : vector<16xf32>
|
%a = arith.constant dense<[1., 0., 0., 2., 4., 3., 5., 7., 8., 1., 5., 5., 3., 1., 0., 7.]> : vector<16xf32>
|
||||||
%k = arith.constant dense<[1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0]> : vector<16xi1>
|
%k = arith.constant dense<[1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0]> : vector<16xi1>
|
||||||
%r1 = x86vector.avx512.mask.compress %k, %a : vector<16xf32>
|
%r1 = x86.avx512.mask.compress %k, %a : vector<16xf32>
|
||||||
%r2 = x86vector.avx512.mask.compress %k, %a {constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32>
|
%r2 = x86.avx512.mask.compress %k, %a {constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32>
|
||||||
|
|
||||||
vector.print %r1 : vector<16xf32>
|
vector.print %r1 : vector<16xf32>
|
||||||
// CHECK: ( 1, 0, 2, 4, 5, 5, 3, 1, 0, 0, 0, 0, 0, 0, 0, 0 )
|
// CHECK: ( 1, 0, 2, 4, 5, 5, 3, 1, 0, 0, 0, 0, 0, 0, 0, 0 )
|
||||||
@@ -18,7 +18,7 @@ func.func @entry() -> i32 {
|
|||||||
// CHECK: ( 1, 0, 2, 4, 5, 5, 3, 1, 0, 5, 5, 5, 5, 5, 5, 5 )
|
// CHECK: ( 1, 0, 2, 4, 5, 5, 3, 1, 0, 5, 5, 5, 5, 5, 5, 5 )
|
||||||
|
|
||||||
%src = arith.constant dense<[0., 2., 1., 8., 6., 4., 4., 3., 2., 8., 5., 6., 3., 7., 6., 9.]> : vector<16xf32>
|
%src = arith.constant dense<[0., 2., 1., 8., 6., 4., 4., 3., 2., 8., 5., 6., 3., 7., 6., 9.]> : vector<16xf32>
|
||||||
%r3 = x86vector.avx512.mask.compress %k, %a, %src : vector<16xf32>, vector<16xf32>
|
%r3 = x86.avx512.mask.compress %k, %a, %src : vector<16xf32>, vector<16xf32>
|
||||||
|
|
||||||
vector.print %r3 : vector<16xf32>
|
vector.print %r3 : vector<16xf32>
|
||||||
// CHECK: ( 1, 0, 2, 4, 5, 5, 3, 1, 0, 8, 5, 6, 3, 7, 6, 9 )
|
// CHECK: ( 1, 0, 2, 4, 5, 5, 3, 1, 0, 8, 5, 6, 3, 7, 6, 9 )
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86vector" -test-lower-to-llvm | \
|
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86" -test-lower-to-llvm | \
|
||||||
// RUN: mlir-translate --mlir-to-llvmir | \
|
// RUN: mlir-translate --mlir-to-llvmir | \
|
||||||
// RUN: %lli --entry-function=entry --mattr="avx" --dlopen=%mlir_c_runner_utils | \
|
// RUN: %lli --entry-function=entry --mattr="avx" --dlopen=%mlir_c_runner_utils | \
|
||||||
// RUN: FileCheck %s
|
// RUN: FileCheck %s
|
||||||
@@ -7,7 +7,7 @@ func.func @entry() -> i32 {
|
|||||||
%i0 = arith.constant 0 : i32
|
%i0 = arith.constant 0 : i32
|
||||||
|
|
||||||
%v = arith.constant dense<[0.125, 0.25, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0]> : vector<8xf32>
|
%v = arith.constant dense<[0.125, 0.25, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0]> : vector<8xf32>
|
||||||
%r = x86vector.avx.rsqrt %v : vector<8xf32>
|
%r = x86.avx.rsqrt %v : vector<8xf32>
|
||||||
// `rsqrt` may produce slightly different results on Intel and AMD machines: accept both results here.
|
// `rsqrt` may produce slightly different results on Intel and AMD machines: accept both results here.
|
||||||
// CHECK: {{( 2.82[0-9]*, 1.99[0-9]*, 1.41[0-9]*, 0.99[0-9]*, 0.70[0-9]*, 0.49[0-9]*, 0.35[0-9]*, 0.24[0-9]* )}}
|
// CHECK: {{( 2.82[0-9]*, 1.99[0-9]*, 1.41[0-9]*, 0.99[0-9]*, 0.70[0-9]*, 0.49[0-9]*, 0.35[0-9]*, 0.24[0-9]* )}}
|
||||||
vector.print %r : vector<8xf32>
|
vector.print %r : vector<8xf32>
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86vector" -test-lower-to-llvm | \
|
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86" -test-lower-to-llvm | \
|
||||||
// RUN: mlir-translate --mlir-to-llvmir | \
|
// RUN: mlir-translate --mlir-to-llvmir | \
|
||||||
// RUN: %lli --entry-function=entry --mattr="avx512bw,avx512vp2intersect" --dlopen=%mlir_c_runner_utils | \
|
// RUN: %lli --entry-function=entry --mattr="avx512bw,avx512vp2intersect" --dlopen=%mlir_c_runner_utils | \
|
||||||
// RUN: FileCheck %s
|
// RUN: FileCheck %s
|
||||||
@@ -35,11 +35,11 @@
|
|||||||
func.func @vector_dot(%v_A : vector<8xi64>, %v_B : vector<8xf64>,
|
func.func @vector_dot(%v_A : vector<8xi64>, %v_B : vector<8xf64>,
|
||||||
%v_C : vector<8xi64>, %v_D : vector<8xf64>) -> f64 {
|
%v_C : vector<8xi64>, %v_D : vector<8xf64>) -> f64 {
|
||||||
// Compute intersection of indices.
|
// Compute intersection of indices.
|
||||||
%k0, %k1 = x86vector.avx512.vp2intersect %v_A, %v_C : vector<8xi64>
|
%k0, %k1 = x86.avx512.vp2intersect %v_A, %v_C : vector<8xi64>
|
||||||
|
|
||||||
// Filter out values without match and compress vector.
|
// Filter out values without match and compress vector.
|
||||||
%p0 = x86vector.avx512.mask.compress %k0, %v_B : vector<8xf64>
|
%p0 = x86.avx512.mask.compress %k0, %v_B : vector<8xf64>
|
||||||
%p1 = x86vector.avx512.mask.compress %k1, %v_D : vector<8xf64>
|
%p1 = x86.avx512.mask.compress %k1, %v_D : vector<8xf64>
|
||||||
|
|
||||||
// Dense vector dot product.
|
// Dense vector dot product.
|
||||||
%acc = arith.constant 0.0 : f64
|
%acc = arith.constant 0.0 : f64
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86vector" -test-lower-to-llvm | \
|
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86" -test-lower-to-llvm | \
|
||||||
// RUN: mlir-translate --mlir-to-llvmir | \
|
// RUN: mlir-translate --mlir-to-llvmir | \
|
||||||
// RUN: %lli --entry-function=entry --mattr="avx512bw,avx512vp2intersect" --dlopen=%mlir_c_runner_utils | \
|
// RUN: %lli --entry-function=entry --mattr="avx512bw,avx512vp2intersect" --dlopen=%mlir_c_runner_utils | \
|
||||||
// RUN: FileCheck %s
|
// RUN: FileCheck %s
|
||||||
@@ -40,7 +40,7 @@ func.func @entry() -> i32 {
|
|||||||
vector.print %w9 : vector<16xi32>
|
vector.print %w9 : vector<16xi32>
|
||||||
// CHECK: ( 1, 1, 1, 1, 2, 1, 1, -219, 12, 12, 12, 0, 0, 0, 1, 0 )
|
// CHECK: ( 1, 1, 1, 1, 2, 1, 1, -219, 12, 12, 12, 0, 0, 0, 1, 0 )
|
||||||
|
|
||||||
%k1, %k2 = x86vector.avx512.vp2intersect %v9, %w9 : vector<16xi32>
|
%k1, %k2 = x86.avx512.vp2intersect %v9, %w9 : vector<16xi32>
|
||||||
|
|
||||||
vector.print %k1 : vector<16xi1>
|
vector.print %k1 : vector<16xi1>
|
||||||
// CHECK: ( 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1 )
|
// CHECK: ( 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1 )
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
// RUN: mlir-opt %s --convert-vector-to-llvm="enable-x86vector" --convert-to-llvm -reconcile-unrealized-casts \
|
// RUN: mlir-opt %s --convert-vector-to-llvm="enable-x86" --convert-to-llvm -reconcile-unrealized-casts \
|
||||||
// RUN: | mlir-translate --mlir-to-llvmir \
|
// RUN: | mlir-translate --mlir-to-llvmir \
|
||||||
// RUN: | FileCheck %s
|
// RUN: | FileCheck %s
|
||||||
|
|
||||||
@@ -11,9 +11,9 @@ func.func @LLVM_x86_avx512_mask_ps_512(
|
|||||||
%rnd_k = arith.constant 15 : i32
|
%rnd_k = arith.constant 15 : i32
|
||||||
%rnd = arith.constant 42 : i32
|
%rnd = arith.constant 42 : i32
|
||||||
// CHECK: call <16 x float> @llvm.x86.avx512.mask.rndscale.ps.512(<16 x float>
|
// CHECK: call <16 x float> @llvm.x86.avx512.mask.rndscale.ps.512(<16 x float>
|
||||||
%0 = x86vector.avx512.mask.rndscale %src, %rnd_k, %a, %imm, %rnd : vector<16xf32>
|
%0 = x86.avx512.mask.rndscale %src, %rnd_k, %a, %imm, %rnd : vector<16xf32>
|
||||||
// CHECK: call <16 x float> @llvm.x86.avx512.mask.scalef.ps.512(<16 x float>
|
// CHECK: call <16 x float> @llvm.x86.avx512.mask.scalef.ps.512(<16 x float>
|
||||||
%1 = x86vector.avx512.mask.scalef %0, %a, %b, %scale_k, %rnd : vector<16xf32>
|
%1 = x86.avx512.mask.scalef %0, %a, %b, %scale_k, %rnd : vector<16xf32>
|
||||||
return %1 : vector<16xf32>
|
return %1 : vector<16xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -26,9 +26,9 @@ func.func @LLVM_x86_avx512_mask_pd_512(
|
|||||||
%rnd_k = arith.constant 15 : i32
|
%rnd_k = arith.constant 15 : i32
|
||||||
%rnd = arith.constant 42 : i32
|
%rnd = arith.constant 42 : i32
|
||||||
// CHECK: call <8 x double> @llvm.x86.avx512.mask.rndscale.pd.512(<8 x double>
|
// CHECK: call <8 x double> @llvm.x86.avx512.mask.rndscale.pd.512(<8 x double>
|
||||||
%0 = x86vector.avx512.mask.rndscale %src, %rnd_k, %a, %imm, %rnd : vector<8xf64>
|
%0 = x86.avx512.mask.rndscale %src, %rnd_k, %a, %imm, %rnd : vector<8xf64>
|
||||||
// CHECK: call <8 x double> @llvm.x86.avx512.mask.scalef.pd.512(<8 x double>
|
// CHECK: call <8 x double> @llvm.x86.avx512.mask.scalef.pd.512(<8 x double>
|
||||||
%1 = x86vector.avx512.mask.scalef %0, %a, %b, %scale_k, %rnd : vector<8xf64>
|
%1 = x86.avx512.mask.scalef %0, %a, %b, %scale_k, %rnd : vector<8xf64>
|
||||||
return %1 : vector<8xf64>
|
return %1 : vector<8xf64>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -37,7 +37,7 @@ func.func @LLVM_x86_mask_compress(%k: vector<16xi1>, %a: vector<16xf32>)
|
|||||||
-> vector<16xf32>
|
-> vector<16xf32>
|
||||||
{
|
{
|
||||||
// CHECK: call <16 x float> @llvm.x86.avx512.mask.compress.v16f32(
|
// CHECK: call <16 x float> @llvm.x86.avx512.mask.compress.v16f32(
|
||||||
%0 = x86vector.avx512.mask.compress %k, %a, %a : vector<16xf32>, vector<16xf32>
|
%0 = x86.avx512.mask.compress %k, %a, %a : vector<16xf32>, vector<16xf32>
|
||||||
return %0 : vector<16xf32>
|
return %0 : vector<16xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -46,7 +46,7 @@ func.func @LLVM_x86_vp2intersect_d_512(%a: vector<16xi32>, %b: vector<16xi32>)
|
|||||||
-> (vector<16xi1>, vector<16xi1>)
|
-> (vector<16xi1>, vector<16xi1>)
|
||||||
{
|
{
|
||||||
// CHECK: call { <16 x i1>, <16 x i1> } @llvm.x86.avx512.vp2intersect.d.512(<16 x i32>
|
// CHECK: call { <16 x i1>, <16 x i1> } @llvm.x86.avx512.vp2intersect.d.512(<16 x i32>
|
||||||
%0, %1 = x86vector.avx512.vp2intersect %a, %b : vector<16xi32>
|
%0, %1 = x86.avx512.vp2intersect %a, %b : vector<16xi32>
|
||||||
return %0, %1 : vector<16xi1>, vector<16xi1>
|
return %0, %1 : vector<16xi1>, vector<16xi1>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -55,7 +55,7 @@ func.func @LLVM_x86_vp2intersect_q_512(%a: vector<8xi64>, %b: vector<8xi64>)
|
|||||||
-> (vector<8 x i1>, vector<8 x i1>)
|
-> (vector<8 x i1>, vector<8 x i1>)
|
||||||
{
|
{
|
||||||
// CHECK: call { <8 x i1>, <8 x i1> } @llvm.x86.avx512.vp2intersect.q.512(<8 x i64>
|
// CHECK: call { <8 x i1>, <8 x i1> } @llvm.x86.avx512.vp2intersect.q.512(<8 x i64>
|
||||||
%0, %1 = x86vector.avx512.vp2intersect %a, %b : vector<8xi64>
|
%0, %1 = x86.avx512.vp2intersect %a, %b : vector<8xi64>
|
||||||
return %0, %1 : vector<8 x i1>, vector<8 x i1>
|
return %0, %1 : vector<8 x i1>, vector<8 x i1>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -65,7 +65,7 @@ func.func @LLVM_x86_avx512bf16_dpbf16ps_128(
|
|||||||
) -> vector<4xf32>
|
) -> vector<4xf32>
|
||||||
{
|
{
|
||||||
// CHECK: call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(
|
// CHECK: call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(
|
||||||
%0 = x86vector.avx512.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
|
%0 = x86.avx512.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
|
||||||
return %0 : vector<4xf32>
|
return %0 : vector<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -75,7 +75,7 @@ func.func @LLVM_x86_avx512bf16_dpbf16ps_256(
|
|||||||
) -> vector<8xf32>
|
) -> vector<8xf32>
|
||||||
{
|
{
|
||||||
// CHECK: call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(
|
// CHECK: call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(
|
||||||
%0 = x86vector.avx512.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
|
%0 = x86.avx512.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
|
||||||
return %0 : vector<8xf32>
|
return %0 : vector<8xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -85,7 +85,7 @@ func.func @LLVM_x86_avx512bf16_dpbf16ps_512(
|
|||||||
) -> vector<16xf32>
|
) -> vector<16xf32>
|
||||||
{
|
{
|
||||||
// CHECK: call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(
|
// CHECK: call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(
|
||||||
%0 = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
|
%0 = x86.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
|
||||||
return %0 : vector<16xf32>
|
return %0 : vector<16xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -94,7 +94,7 @@ func.func @LLVM_x86_avx512bf16_cvtneps2bf16_256(
|
|||||||
%a: vector<8xf32>) -> vector<8xbf16>
|
%a: vector<8xf32>) -> vector<8xbf16>
|
||||||
{
|
{
|
||||||
// CHECK: call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256(
|
// CHECK: call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256(
|
||||||
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a
|
%0 = x86.avx512.cvt.packed.f32_to_bf16 %a
|
||||||
: vector<8xf32> -> vector<8xbf16>
|
: vector<8xf32> -> vector<8xbf16>
|
||||||
return %0 : vector<8xbf16>
|
return %0 : vector<8xbf16>
|
||||||
}
|
}
|
||||||
@@ -104,7 +104,7 @@ func.func @LLVM_x86_avx512bf16_cvtneps2bf16_512(
|
|||||||
%a: vector<16xf32>) -> vector<16xbf16>
|
%a: vector<16xf32>) -> vector<16xbf16>
|
||||||
{
|
{
|
||||||
// CHECK: call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512(
|
// CHECK: call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512(
|
||||||
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a
|
%0 = x86.avx512.cvt.packed.f32_to_bf16 %a
|
||||||
: vector<16xf32> -> vector<16xbf16>
|
: vector<16xf32> -> vector<16xbf16>
|
||||||
return %0 : vector<16xbf16>
|
return %0 : vector<16xbf16>
|
||||||
}
|
}
|
||||||
@@ -113,7 +113,7 @@ func.func @LLVM_x86_avx512bf16_cvtneps2bf16_512(
|
|||||||
func.func @LLVM_x86_avx10_vpdpbssd_512(%w: vector<16xi32>, %a: vector<64xi8>,
|
func.func @LLVM_x86_avx10_vpdpbssd_512(%w: vector<16xi32>, %a: vector<64xi8>,
|
||||||
%b: vector<64xi8>) -> vector<16xi32> {
|
%b: vector<64xi8>) -> vector<16xi32> {
|
||||||
// CHECK: call <16 x i32> @llvm.x86.avx10.vpdpbssd.512(
|
// CHECK: call <16 x i32> @llvm.x86.avx10.vpdpbssd.512(
|
||||||
%0 = x86vector.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
|
%0 = x86.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
|
||||||
return %0 : vector<16xi32>
|
return %0 : vector<16xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -122,7 +122,7 @@ func.func @LLVM_x86_avxbf16_vcvtneebf162ps128(
|
|||||||
%a: memref<8xbf16>) -> vector<4xf32>
|
%a: memref<8xbf16>) -> vector<4xf32>
|
||||||
{
|
{
|
||||||
// CHECK: call <4 x float> @llvm.x86.vcvtneebf162ps128(
|
// CHECK: call <4 x float> @llvm.x86.vcvtneebf162ps128(
|
||||||
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
|
%0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
|
||||||
return %0 : vector<4xf32>
|
return %0 : vector<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -131,7 +131,7 @@ func.func @LLVM_x86_avxbf16_vcvtneebf162ps256(
|
|||||||
%a: memref<16xbf16>) -> vector<8xf32>
|
%a: memref<16xbf16>) -> vector<8xf32>
|
||||||
{
|
{
|
||||||
// CHECK: call <8 x float> @llvm.x86.vcvtneebf162ps256(
|
// CHECK: call <8 x float> @llvm.x86.vcvtneebf162ps256(
|
||||||
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
|
%0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
|
||||||
return %0 : vector<8xf32>
|
return %0 : vector<8xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -140,7 +140,7 @@ func.func @LLVM_x86_avxbf16_vcvtneobf162ps128(
|
|||||||
%a: memref<8xbf16>) -> vector<4xf32>
|
%a: memref<8xbf16>) -> vector<4xf32>
|
||||||
{
|
{
|
||||||
// CHECK: call <4 x float> @llvm.x86.vcvtneobf162ps128(
|
// CHECK: call <4 x float> @llvm.x86.vcvtneobf162ps128(
|
||||||
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
|
%0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
|
||||||
return %0 : vector<4xf32>
|
return %0 : vector<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -149,7 +149,7 @@ func.func @LLVM_x86_avxbf16_vcvtneobf162ps256(
|
|||||||
%a: memref<16xbf16>) -> vector<8xf32>
|
%a: memref<16xbf16>) -> vector<8xf32>
|
||||||
{
|
{
|
||||||
// CHECK: call <8 x float> @llvm.x86.vcvtneobf162ps256(
|
// CHECK: call <8 x float> @llvm.x86.vcvtneobf162ps256(
|
||||||
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
|
%0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
|
||||||
return %0 : vector<8xf32>
|
return %0 : vector<8xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -158,7 +158,7 @@ func.func @LLVM_x86_avxbf16_vbcstnebf162ps128(
|
|||||||
%a: memref<1xbf16>) -> vector<4xf32>
|
%a: memref<1xbf16>) -> vector<4xf32>
|
||||||
{
|
{
|
||||||
// CHECK: call <4 x float> @llvm.x86.vbcstnebf162ps128(
|
// CHECK: call <4 x float> @llvm.x86.vbcstnebf162ps128(
|
||||||
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
|
%0 = x86.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
|
||||||
return %0 : vector<4xf32>
|
return %0 : vector<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -167,7 +167,7 @@ func.func @LLVM_x86_avxbf16_vbcstnebf162ps256(
|
|||||||
%a: memref<1xbf16>) -> vector<8xf32>
|
%a: memref<1xbf16>) -> vector<8xf32>
|
||||||
{
|
{
|
||||||
// CHECK: call <8 x float> @llvm.x86.vbcstnebf162ps256(
|
// CHECK: call <8 x float> @llvm.x86.vbcstnebf162ps256(
|
||||||
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
|
%0 = x86.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
|
||||||
return %0 : vector<8xf32>
|
return %0 : vector<8xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -176,7 +176,7 @@ func.func @LLVM_x86_avxf16_vcvtneeph2ps128(
|
|||||||
%a: memref<8xf16>) -> vector<4xf32>
|
%a: memref<8xf16>) -> vector<4xf32>
|
||||||
{
|
{
|
||||||
// CHECK: call <4 x float> @llvm.x86.vcvtneeph2ps128(
|
// CHECK: call <4 x float> @llvm.x86.vcvtneeph2ps128(
|
||||||
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
|
%0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
|
||||||
return %0 : vector<4xf32>
|
return %0 : vector<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -185,7 +185,7 @@ func.func @LLVM_x86_avxf16_vcvtneeph2ps256(
|
|||||||
%a: memref<16xf16>) -> vector<8xf32>
|
%a: memref<16xf16>) -> vector<8xf32>
|
||||||
{
|
{
|
||||||
// CHECK: call <8 x float> @llvm.x86.vcvtneeph2ps256(
|
// CHECK: call <8 x float> @llvm.x86.vcvtneeph2ps256(
|
||||||
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
|
%0 = x86.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
|
||||||
return %0 : vector<8xf32>
|
return %0 : vector<8xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -194,7 +194,7 @@ func.func @LLVM_x86_avxf16_vcvtneoph2ps128(
|
|||||||
%a: memref<8xf16>) -> vector<4xf32>
|
%a: memref<8xf16>) -> vector<4xf32>
|
||||||
{
|
{
|
||||||
// CHECK: call <4 x float> @llvm.x86.vcvtneoph2ps128(
|
// CHECK: call <4 x float> @llvm.x86.vcvtneoph2ps128(
|
||||||
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
|
%0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
|
||||||
return %0 : vector<4xf32>
|
return %0 : vector<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -203,7 +203,7 @@ func.func @LLVM_x86_avxf16_vcvtneoph2ps256(
|
|||||||
%a: memref<16xf16>) -> vector<8xf32>
|
%a: memref<16xf16>) -> vector<8xf32>
|
||||||
{
|
{
|
||||||
// CHECK: call <8 x float> @llvm.x86.vcvtneoph2ps256(
|
// CHECK: call <8 x float> @llvm.x86.vcvtneoph2ps256(
|
||||||
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
|
%0 = x86.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
|
||||||
return %0 : vector<8xf32>
|
return %0 : vector<8xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -212,7 +212,7 @@ func.func @LLVM_x86_avxf16_vbcstnesh2ps128(
|
|||||||
%a: memref<1xf16>) -> vector<4xf32>
|
%a: memref<1xf16>) -> vector<4xf32>
|
||||||
{
|
{
|
||||||
// CHECK: call <4 x float> @llvm.x86.vbcstnesh2ps128(
|
// CHECK: call <4 x float> @llvm.x86.vbcstnesh2ps128(
|
||||||
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
|
%0 = x86.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
|
||||||
return %0 : vector<4xf32>
|
return %0 : vector<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -221,7 +221,7 @@ func.func @LLVM_x86_avxf16_vbcstnesh2ps256(
|
|||||||
%a: memref<1xf16>) -> vector<8xf32>
|
%a: memref<1xf16>) -> vector<8xf32>
|
||||||
{
|
{
|
||||||
// CHECK: call <8 x float> @llvm.x86.vbcstnesh2ps256(
|
// CHECK: call <8 x float> @llvm.x86.vbcstnesh2ps256(
|
||||||
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
|
%0 = x86.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
|
||||||
return %0 : vector<8xf32>
|
return %0 : vector<8xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -229,7 +229,7 @@ func.func @LLVM_x86_avxf16_vbcstnesh2ps256(
|
|||||||
func.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
|
func.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
|
||||||
{
|
{
|
||||||
// CHECK: call <8 x float> @llvm.x86.avx.rsqrt.ps.256(<8 x float>
|
// CHECK: call <8 x float> @llvm.x86.avx.rsqrt.ps.256(<8 x float>
|
||||||
%0 = x86vector.avx.rsqrt %a : vector<8xf32>
|
%0 = x86.avx.rsqrt %a : vector<8xf32>
|
||||||
return %0 : vector<8xf32>
|
return %0 : vector<8xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -239,7 +239,7 @@ func.func @LLVM_x86_avx_dp_ps_256(
|
|||||||
) -> vector<8xf32>
|
) -> vector<8xf32>
|
||||||
{
|
{
|
||||||
// CHECK: call <8 x float> @llvm.x86.avx.dp.ps.256(
|
// CHECK: call <8 x float> @llvm.x86.avx.dp.ps.256(
|
||||||
%0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
|
%0 = x86.avx.intr.dot %a, %b : vector<8xf32>
|
||||||
return %0 : vector<8xf32>
|
return %0 : vector<8xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -247,7 +247,7 @@ func.func @LLVM_x86_avx_dp_ps_256(
|
|||||||
func.func @LLVM_x86_avx2_vpdpbssd_128(%w: vector<4xi32>, %a: vector<16xi8>,
|
func.func @LLVM_x86_avx2_vpdpbssd_128(%w: vector<4xi32>, %a: vector<16xi8>,
|
||||||
%b: vector<16xi8>) -> vector<4xi32> {
|
%b: vector<16xi8>) -> vector<4xi32> {
|
||||||
// CHECK: call <4 x i32> @llvm.x86.avx2.vpdpbssd.128(
|
// CHECK: call <4 x i32> @llvm.x86.avx2.vpdpbssd.128(
|
||||||
%0 = x86vector.avx.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
|
%0 = x86.avx.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
|
||||||
return %0 : vector<4xi32>
|
return %0 : vector<4xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -255,6 +255,6 @@ func.func @LLVM_x86_avx2_vpdpbssd_128(%w: vector<4xi32>, %a: vector<16xi8>,
|
|||||||
func.func @LLVM_x86_avx2_vpdpbssd_256(%w: vector<8xi32>, %a: vector<32xi8>,
|
func.func @LLVM_x86_avx2_vpdpbssd_256(%w: vector<8xi32>, %a: vector<32xi8>,
|
||||||
%b: vector<32xi8>) -> vector<8xi32> {
|
%b: vector<32xi8>) -> vector<8xi32> {
|
||||||
// CHECK: call <8 x i32> @llvm.x86.avx2.vpdpbssd.256(
|
// CHECK: call <8 x i32> @llvm.x86.avx2.vpdpbssd.256(
|
||||||
%0 = x86vector.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
|
%0 = x86.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
|
||||||
return %0 : vector<8xi32>
|
return %0 : vector<8xi32>
|
||||||
}
|
}
|
||||||
@@ -10,5 +10,5 @@ mlir_target_link_libraries(MLIRMathTestPasses PUBLIC
|
|||||||
MLIRPass
|
MLIRPass
|
||||||
MLIRTransformUtils
|
MLIRTransformUtils
|
||||||
MLIRVectorDialect
|
MLIRVectorDialect
|
||||||
MLIRX86VectorDialect
|
MLIRX86Dialect
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
#include "mlir/Dialect/Math/IR/Math.h"
|
#include "mlir/Dialect/Math/IR/Math.h"
|
||||||
#include "mlir/Dialect/Math/Transforms/Passes.h"
|
#include "mlir/Dialect/Math/Transforms/Passes.h"
|
||||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||||
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
|
#include "mlir/Dialect/X86/X86Dialect.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
|
||||||
@@ -37,7 +37,7 @@ struct TestMathPolynomialApproximationPass
|
|||||||
registry.insert<arith::ArithDialect, math::MathDialect,
|
registry.insert<arith::ArithDialect, math::MathDialect,
|
||||||
vector::VectorDialect>();
|
vector::VectorDialect>();
|
||||||
if (enableAvx2)
|
if (enableAvx2)
|
||||||
registry.insert<x86vector::X86VectorDialect>();
|
registry.insert<x86::X86Dialect>();
|
||||||
}
|
}
|
||||||
StringRef getArgument() const final {
|
StringRef getArgument() const final {
|
||||||
return "test-math-polynomial-approximation";
|
return "test-math-polynomial-approximation";
|
||||||
@@ -49,7 +49,7 @@ struct TestMathPolynomialApproximationPass
|
|||||||
Option<bool> enableAvx2{
|
Option<bool> enableAvx2{
|
||||||
*this, "enable-avx2",
|
*this, "enable-avx2",
|
||||||
llvm::cl::desc("Enable approximations that emit AVX2 intrinsics via the "
|
llvm::cl::desc("Enable approximations that emit AVX2 intrinsics via the "
|
||||||
"X86Vector dialect"),
|
"X86 dialect"),
|
||||||
llvm::cl::init(false)};
|
llvm::cl::init(false)};
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|||||||
@@ -20,5 +20,5 @@ mlir_target_link_libraries(MLIRVectorTestPasses PUBLIC
|
|||||||
MLIRTransformUtils
|
MLIRTransformUtils
|
||||||
MLIRVectorDialect
|
MLIRVectorDialect
|
||||||
MLIRVectorToSCF
|
MLIRVectorToSCF
|
||||||
MLIRX86VectorDialect
|
MLIRX86Dialect
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ config.mlir_run_arm_sve_tests = @MLIR_RUN_ARM_SVE_TESTS@
|
|||||||
if config.mlir_run_arm_sve_tests:
|
if config.mlir_run_arm_sve_tests:
|
||||||
config.available_features.add("mlir_arm_sve_tests")
|
config.available_features.add("mlir_arm_sve_tests")
|
||||||
config.mlir_run_arm_sme_tests = @MLIR_RUN_ARM_SME_TESTS@
|
config.mlir_run_arm_sme_tests = @MLIR_RUN_ARM_SME_TESTS@
|
||||||
config.mlir_run_x86vector_tests = @MLIR_RUN_X86VECTOR_TESTS@
|
config.mlir_run_x86_tests = @MLIR_RUN_X86_TESTS@
|
||||||
config.mlir_run_riscv_vector_tests = "@MLIR_RUN_RISCV_VECTOR_TESTS@"
|
config.mlir_run_riscv_vector_tests = "@MLIR_RUN_RISCV_VECTOR_TESTS@"
|
||||||
config.mlir_run_cuda_tensor_core_tests = @MLIR_RUN_CUDA_TENSOR_CORE_TESTS@
|
config.mlir_run_cuda_tensor_core_tests = @MLIR_RUN_CUDA_TENSOR_CORE_TESTS@
|
||||||
config.mlir_run_cuda_sm80_tests = @MLIR_RUN_CUDA_SM80_TESTS@
|
config.mlir_run_cuda_sm80_tests = @MLIR_RUN_CUDA_SM80_TESTS@
|
||||||
|
|||||||
@@ -41,7 +41,7 @@
|
|||||||
// CHECK-SAME: tosa
|
// CHECK-SAME: tosa
|
||||||
// CHECK-SAME: transform
|
// CHECK-SAME: transform
|
||||||
// CHECK-SAME: vector
|
// CHECK-SAME: vector
|
||||||
// CHECK-SAME: x86vector
|
// CHECK-SAME: x86
|
||||||
|
|
||||||
// RUN: mlir-opt --help-hidden | FileCheck %s -check-prefix=CHECK-HELP
|
// RUN: mlir-opt --help-hidden | FileCheck %s -check-prefix=CHECK-HELP
|
||||||
// CHECK-HELP: -p - Alias for --pass-pipeline
|
// CHECK-HELP: -p - Alias for --pass-pipeline
|
||||||
|
|||||||
5
mlir/test/mlir-runner/X86/lit.local.cfg
Normal file
5
mlir/test/mlir-runner/X86/lit.local.cfg
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
import sys
|
||||||
|
|
||||||
|
# X86 tests must be enabled via build flag.
|
||||||
|
if not config.mlir_run_x86_tests:
|
||||||
|
config.unsupported = True
|
||||||
@@ -3,7 +3,7 @@
|
|||||||
// RUN: -convert-scf-to-cf \
|
// RUN: -convert-scf-to-cf \
|
||||||
// RUN: -convert-arith-to-llvm \
|
// RUN: -convert-arith-to-llvm \
|
||||||
// RUN: -convert-cf-to-llvm \
|
// RUN: -convert-cf-to-llvm \
|
||||||
// RUN: -convert-vector-to-llvm="enable-x86vector" \
|
// RUN: -convert-vector-to-llvm="enable-x86" \
|
||||||
// RUN: -convert-math-to-llvm \
|
// RUN: -convert-math-to-llvm \
|
||||||
// RUN: -convert-func-to-llvm \
|
// RUN: -convert-func-to-llvm \
|
||||||
// RUN: -reconcile-unrealized-casts \
|
// RUN: -reconcile-unrealized-casts \
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
import sys
|
|
||||||
|
|
||||||
# X86Vector tests must be enabled via build flag.
|
|
||||||
if not config.mlir_run_x86vector_tests:
|
|
||||||
config.unsupported = True
|
|
||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from mlir.ir import *
|
from mlir.ir import *
|
||||||
from mlir.dialects import transform
|
from mlir.dialects import transform
|
||||||
from mlir.dialects.transform import x86vector
|
from mlir.dialects.transform import x86
|
||||||
|
|
||||||
|
|
||||||
def run_apply_patterns(f):
|
def run_apply_patterns(f):
|
||||||
@@ -28,13 +28,13 @@ def run_apply_patterns(f):
|
|||||||
def non_configurable_patterns():
|
def non_configurable_patterns():
|
||||||
# CHECK-LABEL: TEST: non_configurable_patterns
|
# CHECK-LABEL: TEST: non_configurable_patterns
|
||||||
# CHECK: apply_patterns
|
# CHECK: apply_patterns
|
||||||
# CHECK: transform.apply_patterns.x86vector.vector_contract_to_fma
|
# CHECK: transform.apply_patterns.x86.vector_contract_to_fma
|
||||||
x86vector.ApplyVectorContractToFMAPatternsOp()
|
x86.ApplyVectorContractToFMAPatternsOp()
|
||||||
# CHECK: transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
|
# CHECK: transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
|
||||||
x86vector.ApplyVectorContractToPackedTypeDotProductPatternsOp()
|
x86.ApplyVectorContractToPackedTypeDotProductPatternsOp()
|
||||||
# CHECK: transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
|
# CHECK: transform.apply_patterns.x86.vector_contract_bf16_to_fma
|
||||||
x86vector.ApplyVectorContractBF16ToFMAPatternsOp()
|
x86.ApplyVectorContractBF16ToFMAPatternsOp()
|
||||||
# CHECK: transform.apply_patterns.x86vector.sink_vector_producer_ops
|
# CHECK: transform.apply_patterns.x86.sink_vector_producer_ops
|
||||||
x86vector.ApplySinkVectorProducerOpsPatternsOp()
|
x86.ApplySinkVectorProducerOpsPatternsOp()
|
||||||
# CHECK: transform.apply_patterns.x86vector.shuffle_vector_fma_ops
|
# CHECK: transform.apply_patterns.x86.shuffle_vector_fma_ops
|
||||||
x86vector.ApplyShuffleVectorFMAOpsPatternsOp()
|
x86.ApplyShuffleVectorFMAOpsPatternsOp()
|
||||||
@@ -3,7 +3,7 @@
|
|||||||
from mlir.ir import *
|
from mlir.ir import *
|
||||||
import mlir.dialects.builtin as builtin
|
import mlir.dialects.builtin as builtin
|
||||||
import mlir.dialects.func as func
|
import mlir.dialects.func as func
|
||||||
import mlir.dialects.x86vector as x86vector
|
import mlir.dialects.x86 as x86
|
||||||
|
|
||||||
|
|
||||||
def run(f):
|
def run(f):
|
||||||
@@ -21,13 +21,11 @@ def testAvxOp():
|
|||||||
|
|
||||||
@func.FuncOp.from_py_func(MemRefType.get((1,), BF16Type.get()))
|
@func.FuncOp.from_py_func(MemRefType.get((1,), BF16Type.get()))
|
||||||
def avx_op(arg):
|
def avx_op(arg):
|
||||||
return x86vector.BcstToPackedF32Op(
|
return x86.BcstToPackedF32Op(a=arg, dst=VectorType.get((8,), F32Type.get()))
|
||||||
a=arg, dst=VectorType.get((8,), F32Type.get())
|
|
||||||
)
|
|
||||||
|
|
||||||
# CHECK-LABEL: func @avx_op(
|
# CHECK-LABEL: func @avx_op(
|
||||||
# CHECK-SAME: %[[ARG:.+]]: memref<1xbf16>) -> vector<8xf32> {
|
# CHECK-SAME: %[[ARG:.+]]: memref<1xbf16>) -> vector<8xf32> {
|
||||||
# CHECK: %[[VAL:.+]] = x86vector.avx.bcst_to_f32.packed %[[ARG]]
|
# CHECK: %[[VAL:.+]] = x86.avx.bcst_to_f32.packed %[[ARG]]
|
||||||
# CHECK: return %[[VAL]] : vector<8xf32>
|
# CHECK: return %[[VAL]] : vector<8xf32>
|
||||||
# CHECK: }
|
# CHECK: }
|
||||||
print(module)
|
print(module)
|
||||||
@@ -41,13 +39,13 @@ def testAvx512Op():
|
|||||||
|
|
||||||
@func.FuncOp.from_py_func(VectorType.get((8,), F32Type.get()))
|
@func.FuncOp.from_py_func(VectorType.get((8,), F32Type.get()))
|
||||||
def avx512_op(arg):
|
def avx512_op(arg):
|
||||||
return x86vector.CvtPackedF32ToBF16Op(
|
return x86.CvtPackedF32ToBF16Op(
|
||||||
a=arg, dst=VectorType.get((8,), BF16Type.get())
|
a=arg, dst=VectorType.get((8,), BF16Type.get())
|
||||||
)
|
)
|
||||||
|
|
||||||
# CHECK-LABEL: func @avx512_op(
|
# CHECK-LABEL: func @avx512_op(
|
||||||
# CHECK-SAME: %[[ARG:.+]]: vector<8xf32>) -> vector<8xbf16> {
|
# CHECK-SAME: %[[ARG:.+]]: vector<8xf32>) -> vector<8xbf16> {
|
||||||
# CHECK: %[[VAL:.+]] = x86vector.avx512.cvt.packed.f32_to_bf16 %[[ARG]]
|
# CHECK: %[[VAL:.+]] = x86.avx512.cvt.packed.f32_to_bf16 %[[ARG]]
|
||||||
# CHECK: return %[[VAL]] : vector<8xbf16>
|
# CHECK: return %[[VAL]] : vector<8xbf16>
|
||||||
# CHECK: }
|
# CHECK: }
|
||||||
print(module)
|
print(module)
|
||||||
@@ -65,12 +63,12 @@ def testAvx10Op():
|
|||||||
VectorType.get((64,), IntegerType.get(8)),
|
VectorType.get((64,), IntegerType.get(8)),
|
||||||
)
|
)
|
||||||
def avx10_op(*args):
|
def avx10_op(*args):
|
||||||
return x86vector.AVX10DotInt8Op(w=args[0], a=args[1], b=args[2])
|
return x86.AVX10DotInt8Op(w=args[0], a=args[1], b=args[2])
|
||||||
|
|
||||||
# CHECK-LABEL: func @avx10_op(
|
# CHECK-LABEL: func @avx10_op(
|
||||||
# CHECK-SAME: %[[W:.+]]: vector<16xi32>, %[[A:.+]]: vector<64xi8>,
|
# CHECK-SAME: %[[W:.+]]: vector<16xi32>, %[[A:.+]]: vector<64xi8>,
|
||||||
# CHECK-SAME: %[[B:.+]]: vector<64xi8>) -> vector<16xi32> {
|
# CHECK-SAME: %[[B:.+]]: vector<64xi8>) -> vector<16xi32> {
|
||||||
# CHECK: %[[VAL:.+]] = x86vector.avx10.dot.i8 %[[W]], %[[A]], %[[B]]
|
# CHECK: %[[VAL:.+]] = x86.avx10.dot.i8 %[[W]], %[[A]], %[[B]]
|
||||||
# CHECK: return %[[VAL]] : vector<16xi32>
|
# CHECK: return %[[VAL]] : vector<16xi32>
|
||||||
# CHECK: }
|
# CHECK: }
|
||||||
print(module)
|
print(module)
|
||||||
Reference in New Issue
Block a user