[mlir][x86] Move AMX dialect into X86 dialect (#183717)

Unifies the two dialects that define x86 operations into a single one.
The AMX dialect is moved into X86 in line with other x86 extensions.

Following the dialect renaming, the X86 dialect is now a suitable home for
a wider range of operations targeting specific hardware features. Moving
the AMX definitions into the X86 dialect creates a single, centralized hub
for defining all x86 intrinsic-like operations. This grouping aims to
eliminate the need for new dialects as further hardware extensions become
available.

The two dialects are simply merged together; refactoring of the X86
dialect will be addressed separately.

List of changes:
  - operations: 'amx.tile_*' => 'x86.amx.tile_*'
  - types: '!amx.tile' => '!x86.amx.tile'
  - namespace: 'mlir::amx' => 'mlir::x86::amx'
  - test define: 'MLIR_RUN_AMX_TESTS' => 'MLIR_RUN_X86_AMX_TESTS'
  - vector lowering: AMX is enabled by default together with X86

The MLIR AMX tests are now nested under the X86 directory. To enable the AMX
integration tests, 'MLIR_RUN_X86_TESTS' must also be defined.
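
As an illustration of the renaming, a minimal before/after sketch built from
the examples in the op documentation (tile shapes are illustrative only):

```mlir
// Before: standalone AMX dialect.
%0 = amx.tile_zero : !amx.tile<16x16xbf16>
%1 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into !amx.tile<16x64xi8>

// After: AMX ops and types nested under the X86 dialect.
%0 = x86.amx.tile_zero : !x86.amx.tile<16x16xbf16>
%1 = x86.amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into !x86.amx.tile<16x64xi8>
```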
Adam Siemieniuk
2026-03-02 11:47:30 +01:00
committed by GitHub
parent e3b01e1329
commit e44fd05035
56 changed files with 1210 additions and 1531 deletions

View File

@@ -104,7 +104,6 @@ available, should be contacted first, as they're more active in those areas.
* arm_neon Dialect ([@banach-space](https://github.com/banach-space))
* arm_sve Dialect ([@banach-space](https://github.com/banach-space))
* ArmSME Dialect ([@banach-space](https://github.com/banach-space))
* amx Dialect ([@adam-smnk](https://github.com/adam-smnk))
* x86 Dialect ([@adam-smnk](https://github.com/adam-smnk))
* vcix Dialect ([@mshockwave](https://github.com/mshockwave))

View File

@@ -5,8 +5,8 @@ overall flow is two-stage:
1. **conversion** of the IR to a set of dialects translatable to LLVM IR, for
example [LLVM Dialect](Dialects/LLVM.md) or one of the hardware-specific
dialects derived from LLVM IR intrinsics such as [AMX](Dialects/AMX.md),
[X86](Dialects/X86.md) or [ArmNeon](Dialects/ArmNeon.md);
dialects derived from LLVM IR intrinsics such as [X86](Dialects/X86.md)
or [ArmNeon](Dialects/ArmNeon.md);
2. **translation** of MLIR dialects to LLVM IR.
This flow allows the non-trivial transformation to be performed within MLIR

View File

@@ -1,25 +0,0 @@
//===-- mlir-c/Dialect/AMX.h - C API for AMX Dialect --------*- C -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_C_DIALECT_AMX_H
#define MLIR_C_DIALECT_AMX_H
#include "mlir-c/IR.h"
#ifdef __cplusplus
extern "C" {
#endif
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(AMX, amx);
#ifdef __cplusplus
}
#endif
#endif // MLIR_C_DIALECT_AMX_H

View File

@@ -1521,8 +1521,8 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
operations. The lowering pass provides several options to control
the kinds of optimizations that are allowed. It also provides options
that enable the use of one or more architectural-specific dialects
(AMX, X86, ArmNeon, ArmSVE, etc.) in combination with the
architectural-neutral vector dialect lowering.
(X86, ArmNeon, ArmSVE, etc.) in combination with the architectural-neutral
vector dialect lowering.
}];
// Override explicitly in C++ to allow conditional dialect dependence.
@@ -1544,10 +1544,6 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
"vector access are naturally aligned. If operations have an "
"alignment attribute set, the alignment attribute takes priority "
"over this option ">,
Option<"amx", "enable-amx",
"bool", /*default=*/"false",
"Enables the use of AMX dialect while lowering the vector "
"dialect.">,
Option<"armNeon", "enable-arm-neon",
"bool", /*default=*/"false",
"Enables the use of ArmNeon dialect while lowering the vector "
@@ -1626,10 +1622,10 @@ def ConvertVectorToXeGPU : Pass<"convert-vector-to-xegpu"> {
//===----------------------------------------------------------------------===//
def ConvertVectorToAMX : Pass<"convert-vector-to-amx"> {
let summary = "Lower the operations from the vector dialect into the AMX "
"dialect";
let summary = "Lower the operations from the vector dialect into the X86 "
"dialect AMX operations";
let dependentDialects = [
"affine::AffineDialect", "amx::AMXDialect", "arith::ArithDialect",
"affine::AffineDialect", "x86::X86Dialect", "arith::ArithDialect",
"memref::MemRefDialect", "scf::SCFDialect", "vector::VectorDialect"
];
}

View File

@@ -1,4 +1,4 @@
//===- VectorToAMX.h - Convert vector to AMX dialect ------------*- C++ -*-===//
//===- VectorToAMX.h - Convert vector to X86 dialect AMX ops ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -18,7 +18,7 @@ class RewritePatternSet;
#define GEN_PASS_DECL_CONVERTVECTORTOAMX
#include "mlir/Conversion/Passes.h.inc"
/// Collect a set of patterns to convert from the vector to AMX ops.
/// Collect a set of patterns to convert from the vector to X86 AMX ops.
void populateVectorToAMXConversionPatterns(RewritePatternSet &patterns);
} // namespace mlir

View File

@@ -1,440 +0,0 @@
//===-- AMX.td - AMX dialect operation definitions *- tablegen -*----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the basic operations for the AMX dialect.
//
// The Intel Advanced Matrix Extensions (AMX) provide a tile matrix
// multiply unit (TMUL), a tile control register (TILECFG), and eight
// tile registers TMM0 through TMM7 (TILEDATA).
//
// The AMX dialect provides a bridge between MLIR concepts, such as
// 2-d vector, operations, and memrefs, and the lower level details
// of Intel AMX, such as configuration setup, tile sizes, instructions,
// and tile release.
//
// Note that since configuration changes (implicit at dialect level) are
// costly, it is highly recommended to use the AMX dialect on same-shaped
// vectors, at least within a single method.
//
// https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
//
//===----------------------------------------------------------------------===//
#ifndef AMX
#define AMX
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Dialect/AMX/AMXInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinTypes.td"
//===----------------------------------------------------------------------===//
// AMX dialect definition.
//===----------------------------------------------------------------------===//
def AMX_Dialect : Dialect {
let name = "amx";
let cppNamespace = "::mlir::amx";
let description = [{
The Intel Advanced Matrix Extensions (AMX) provide a tile matrix
multiply unit (TMUL), a tile control register (TILECFG), and eight
tile registers TMM0 through TMM7 (TILEDATA).
This `AMX` dialect provides a bridge between MLIR concepts such as
vectors and memrefs and the lower level LLVM IR support of AMX.
Note that since configuration changes (implicit at dialect level) are
costly, it is highly recommended to use the AMX dialect on same-shaped
vectors, at least within a single method.
For details, see the Intel documentation:
https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
}];
let useDefaultTypePrinterParser = 1;
}
//===----------------------------------------------------------------------===//
// AMX Tile definition.
//===----------------------------------------------------------------------===//
class AMX_Type<string typeName, string typeMnemonic, list<Trait> traits = []>
: TypeDef<AMX_Dialect, typeName, traits> {
let mnemonic = typeMnemonic;
}
def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
let cppFunctionName = "isValidTileTypeElementType";
}
def AMX_TileType : AMX_Type<"Tile", "tile", [ShapedTypeInterface, ValueSemantics]> {
let summary = "AMX 2D tile to be used by AMX opertaions.";
let description = [{
This type is used to represent values in AMX tile registers. All AMX operations
work on AMX tiles and these tiles cannot be used in other operations directly.
LLVM IR type for AMX tile is a primitive type, but in MLIR we provide shape and
element type for IR verification and lowering to LLVMIR dialect.
}];
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
AMX_TileTypeElementType:$elementType
);
let builders = [
TypeBuilderWithInferredContext<(ins
"ArrayRef<int64_t>":$shape, "Type":$elementType), [{
return $_get(elementType.getContext(), shape, elementType);
}]>
];
let extraClassDeclaration = [{
/// Returns if this type is ranked (always true).
bool hasRank() const { return true; }
/// Clone this tile type with the given shape and element type. If the
/// provided shape is `std::nullopt`, the current shape of the type is used.
TileType cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const {
return get(shape.value_or(getShape()), elementType);
}
}];
let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;
}
def IsAMXTilePred : And<[CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">,
CPred<[{::llvm::cast<::mlir::amx::TileType>($_self).getRank() == 2}]>]>;
class AMXTileOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsAMXTilePred, "tile",
"::mlir::amx::TileType">;
def AnyAMXTile : AMXTileOf<[F32, F16, BF16, I32, I8]>;
def AMXTileF32 : AMXTileOf<[F32]>;
def AMXTileF16OrBF16 : AMXTileOf<[F16, BF16]>;
def AMXTileI32 : AMXTileOf<[I32]>;
def AMXTileI8 : AMXTileOf<[I8]>;
//===----------------------------------------------------------------------===//
// AMX Op and IntrOp definitions.
//===----------------------------------------------------------------------===//
class AMX_Op<string mnemonic, list<Trait> traits = []> :
Op<AMX_Dialect, mnemonic, traits> {}
//===----------------------------------------------------------------------===//
// AMX Op definitions
//===----------------------------------------------------------------------===//
//
// Tile reset.
//
def TileZeroOp : AMX_Op<"tile_zero", [
AMXIntrinsicOpInterface,
MemoryEffects<[MemWrite]>
]> {
let summary = "tile zero operation";
let description = [{
Zeroes the destination tile, with the shape defined by the 2-dim
vector type of the result.
The operation is eventually lowered into the "tilezero" instruction
with the corresponding tile configuration.
With the write memory effect, each `amx.tile_zero` operation serves as
a compilation hint to use a separate tile register.
Example:
```mlir
%0 = amx.tile_zero : !amx.tile<16x16xbf16>
```
}];
let results = (outs AnyAMXTile:$res);
let extraClassDeclaration = [{
TileType getTileType() {
return ::llvm::cast<TileType>(getRes().getType());
}
std::string getIntrinsicName() {
return "llvm.x86.tilezero.internal";
}
SmallVector<Value> getIntrinsicOperands(
::mlir::ArrayRef<Value> operands,
const ::mlir::LLVMTypeConverter &typeConverter,
::mlir::RewriterBase &rewriter);
}];
let assemblyFormat = "attr-dict `:` qualified(type($res))";
let hasVerifier = 1;
}
//
// Tile memory operations.
//
def TileLoadOp : AMX_Op<"tile_load", [
AMXIntrinsicOpInterface,
MemoryEffects<[MemWrite]>,
AttrSizedOperandSegments
]> {
let summary = "tile load operation";
let description = [{
Loads a tile from memory defined by a `base` and `indices`, with the
shape defined by the 2-dim vector type of the result.
The tile's rows are populated by reading contiguous elements starting
at the `base`. For each tile row, the `base` is incremented by `stride`
number of elements.
The tile is loaded using the following indexing scheme:
```
for row in enumerate(tile_rows):
mem_row = base[i0, i1, ..., iN + row * stride]
for col in enumerate(tile_cols):
tile[row, col] = mem_row[col]
```
If the `stride` is not provided, then the `base` buffer must be at least
2-dimensional, and the `stride` is automatically inferred and corresponds
to the stride of the buffer's second innermost dimension.
The operation is eventually lowered into the "tileloadd" instruction
with the corresponding tile configuration.
With the write memory effect, each `amx.tile_load` operation serves as
a compilation hint to use a separate tile register.
Example:
```mlir
// Tile load from a 2-D memref with implicit stride.
%0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into !amx.tile<16x64xi8>
// Tile load from a 1-D memref with explicit stride.
%0 = amx.tile_load %arg0[%c0], %stride : memref<?xi8> into !amx.tile<16x64xi8>
```
}];
let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
Variadic<Index>:$indices,
Optional<Index>:$stride);
let results = (outs AnyAMXTile:$res);
let builders = [
OpBuilder<(ins "Type":$res, "Value":$base, "ValueRange":$indices)>
];
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return ::llvm::cast<MemRefType>(getBase().getType());
}
TileType getTileType() {
return ::llvm::cast<TileType>(getRes().getType());
}
std::string getIntrinsicName() {
return "llvm.x86.tileloadd64.internal";
}
SmallVector<Value> getIntrinsicOperands(
::mlir::ArrayRef<Value> operands,
const ::mlir::LLVMTypeConverter &typeConverter,
::mlir::RewriterBase &rewriter);
}];
let assemblyFormat = "$base `[` $indices `]` (`,` $stride^ )? attr-dict"
"`:` type($base) `into` qualified(type($res))";
let hasVerifier = 1;
}
def TileStoreOp : AMX_Op<"tile_store", [
AMXIntrinsicOpInterface,
AttrSizedOperandSegments
]> {
let summary = "tile store operation";
let description = [{
Stores a tile to memory defined by a `base` and `indices`, with the
shape defined by the 2-dim vector type of the value.
The tile's rows are written contiguously to the buffer starting at
the `base`. For each tile row, the `base` is incremented by `stride`
number of elements.
The tile is stored using the following indexing scheme:
```
for row in enumerate(tile_rows):
mem_row = base[i0, i1, ..., iN + row * stride]
for col in enumerate(tile_cols):
mem_row[col] = tile[row, col]
```
If the `stride` is not provided, then the `base` buffer must be at least
2-dimensional, and the `stride` is automatically inferred and corresponds
to the stride of the buffer's second innermost dimension.
The operation is eventually lowered into the "tilestored" instruction
with the corresponding tile configuration.
Example:
```mlir
// Tile store to a 2-D memref with implicit stride.
amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, !amx.tile<16x64xi8>
// Tile store to a 1-D memref with explicit stride.
amx.tile_store %arg1[%c0], %0, %stride : memref<?xi8>, !amx.tile<16x64xi8>
```
}];
let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
Variadic<Index>:$indices,
AnyAMXTile:$val,
Optional<Index>:$stride);
let builders = [
OpBuilder<(ins "Value":$base, "ValueRange":$indices, "Value":$val)>
];
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return ::llvm::cast<MemRefType>(getBase().getType());
}
TileType getTileType() {
return ::llvm::cast<TileType>(getVal().getType());
}
std::string getIntrinsicName() {
return "llvm.x86.tilestored64.internal";
}
SmallVector<Value> getIntrinsicOperands(
::mlir::ArrayRef<Value> operands,
const ::mlir::LLVMTypeConverter &typeConverter,
::mlir::RewriterBase &rewriter);
}];
let assemblyFormat = "$base `[` $indices `]` `,` $val (`,` $stride^ )?"
"attr-dict `:` type($base) `,` qualified(type($val))";
let hasVerifier = 1;
}
//
// Tile arithmetic operations.
//
def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
AMXIntrinsicOpInterface,
AllTypesMatch<["acc", "res"]>
]> {
let summary = "tile multiplication operation (floating-point)";
let description = [{
Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
into a "m x n" destination tile. Supports "f32 <- bf16 x bf16" (with
pairs of "bf16").
The operation is eventually lowered into the "tdpbf16ps" instruction with
the corresponding tile configuration.
Example:
```mlir
%0 = amx.tile_mulf %a, %b, %c
: !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
```
}];
let arguments = (ins AMXTileF16OrBF16:$lhs,
AMXTileF16OrBF16:$rhs,
AMXTileF32:$acc);
let results = (outs AMXTileF32:$res);
let extraClassDeclaration = [{
TileType getLhsTileType() {
return ::llvm::cast<TileType>(getLhs().getType());
}
TileType getRhsTileType() {
return ::llvm::cast<TileType>(getRhs().getType());
}
TileType getTileType() {
return ::llvm::cast<TileType>(getRes().getType());
}
std::string getIntrinsicName() {
std::string intr = "llvm.x86.tdp";
auto elementType =
getLhsTileType().getElementType();
intr += elementType.isF16() ? "fp16" : "bf16";
intr += "ps.internal";
return intr;
}
SmallVector<Value> getIntrinsicOperands(
::mlir::ArrayRef<Value> operands,
const ::mlir::LLVMTypeConverter &typeConverter,
::mlir::RewriterBase &rewriter);
}];
let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
"qualified(type($lhs)) `,` qualified(type($rhs))"
" `,` qualified(type($acc)) ";
let hasVerifier = 1;
}
def TileMulIOp : AMX_Op<"tile_muli", [Pure,
AMXIntrinsicOpInterface,
AllTypesMatch<["acc", "res"]>
]> {
let summary = "tile multiplication operation (integer)";
let description = [{
Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
into a "m x n" destination tile. Supports all "si32 <- s/ui8 x s/ui8"
combinations (4 bytes packed into dwords in the columns of both the
source operand tiles; the zero or sign extension is specified with
the attributes and defaults to sign extension).
The operation is eventually lowered into one of the "tdpbssd",
"tdpbsud", "tdpbusd", or "tdpbuud" instructions with the corresponding
tile configuration.
Example:
```mlir
%0 = amx.tile_muli %a zext, %b zext, %c
: !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
```
}];
let arguments = (ins AMXTileI8:$lhs,
AMXTileI8:$rhs,
AMXTileI32:$acc,
UnitAttr:$isZextLhs,
UnitAttr:$isZextRhs
);
let results = (outs AMXTileI32:$res);
let extraClassDeclaration = [{
TileType getLhsTileType() {
return ::llvm::cast<TileType>(getLhs().getType());
}
TileType getRhsTileType() {
return ::llvm::cast<TileType>(getRhs().getType());
}
TileType getTileType() {
return ::llvm::cast<TileType>(getRes().getType());
}
std::string getIntrinsicName() {
std::string intr = "llvm.x86.tdpb";
intr += getIsZextLhs() ? "u" : "s";
intr += getIsZextRhs() ? "u" : "s";
intr += "d.internal";
return intr;
}
SmallVector<Value> getIntrinsicOperands(
::mlir::ArrayRef<Value> operands,
const ::mlir::LLVMTypeConverter &typeConverter,
::mlir::RewriterBase &rewriter);
}];
let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
"qualified(type($lhs)) `,` qualified(type($rhs)) `,` qualified(type($acc)) ";
let hasVerifier = 1;
}
#endif // AMX

View File

@@ -1,34 +0,0 @@
//===- AMXDialect.h - MLIR Dialect for AMX ----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares the Target dialect for AMX in MLIR.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_AMX_AMXDIALECT_H_
#define MLIR_DIALECT_AMX_AMXDIALECT_H_
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
/// Include the generated interface declarations.
#include "mlir/Dialect/AMX/AMXInterfaces.h.inc"
#include "mlir/Dialect/AMX/AMXDialect.h.inc"
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/AMX/AMXTypes.h.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/AMX/AMX.h.inc"
#endif // MLIR_DIALECT_AMX_AMXDIALECT_H_

View File

@@ -1,31 +0,0 @@
//===- AMXInterfaces.td - AMX interfaces -------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines interfaces for the AMX dialect.
//
//===----------------------------------------------------------------------===//
#ifndef AMX_INTERFACES
#define AMX_INTERFACES
include "mlir/IR/Interfaces.td"
include "mlir/Dialect/LLVMIR/LLVMInterfaces.td"
//===----------------------------------------------------------------------===//
// AMX Intrinsic Interface
//===----------------------------------------------------------------------===//
def AMXIntrinsicOpInterface
: OpInterface<"AMXIntrinsicOp", [OneToOneIntrinsicOpInterface]> {
let description = [{
A wrapper interface for operations representing AMX LLVM intrinsics.
}];
let cppNamespace = "::mlir::amx";
}
#endif // AMX_INTERFACES

View File

@@ -1,5 +0,0 @@
add_mlir_dialect(AMX amx)
add_mlir_doc(AMX AMX Dialects/ -gen-dialect-doc -dialect=amx)
add_mlir_interface(AMXInterfaces)
add_dependencies(MLIRAMXIncGen MLIRAMXInterfacesIncGen)

View File

@@ -1,33 +0,0 @@
//===- Transforms.h - AMX Dialect Transformation Entrypoints ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_AMX_TRANSFORMS_H
#define MLIR_DIALECT_AMX_TRANSFORMS_H
namespace mlir {
class LLVMConversionTarget;
class LLVMTypeConverter;
class RewritePatternSet;
class DialectRegistry;
/// Collect a set of patterns to lower AMX ops to ops that map to LLVM
/// intrinsics.
void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns);
/// Configure the target to support lowering AMX ops to ops that map to LLVM
/// intrinsics.
void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target);
/// Register LLVM conversion interface for AMX dialect.
void registerConvertAMXToLLVMInterface(DialectRegistry &registry);
} // namespace mlir
#endif // MLIR_DIALECT_AMX_TRANSFORMS_H

View File

@@ -1,6 +1,5 @@
add_subdirectory(Affine)
add_subdirectory(AMDGPU)
add_subdirectory(AMX)
add_subdirectory(Arith)
add_subdirectory(ArmNeon)
add_subdirectory(ArmSME)

View File

@@ -104,10 +104,6 @@ struct SparsifierOptions : public PassPipelineOptions<SparsifierOptions> {
desc("Allows compiler to assume indices fit in 32-bit if that yields "
"faster code"),
init(true)};
PassOptions::Option<bool> amx{
*this, "enable-amx",
desc("Enables the use of AMX dialect while lowering the vector dialect"),
init(false)};
PassOptions::Option<bool> armNeon{
*this, "enable-arm-neon",
desc("Enables the use of ArmNeon dialect while lowering the vector "
@@ -168,7 +164,6 @@ struct SparsifierOptions : public PassPipelineOptions<SparsifierOptions> {
opts.force32BitVectorIndices = force32BitVectorIndices;
opts.armNeon = armNeon;
opts.armSVE = armSVE;
opts.amx = amx;
opts.x86 = x86;
return opts;
}

View File

@@ -200,13 +200,16 @@ void populateSpecializedTransposeLoweringPatterns(
/// Collect a set of patterns to lower X86 ops to ops that map to LLVM
/// intrinsics.
void populateX86LegalizeForLLVMExportPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
void populateX86LegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns);
/// Configure the target to support lowering X86 ops to ops that map to
/// LLVM intrinsics.
void configureX86LegalizeForExportTarget(LLVMConversionTarget &target);
/// Register LLVM conversion interface for X86 dialect.
void registerConvertX86ToLLVMInterface(DialectRegistry &registry);
} // namespace mlir
#endif // MLIR_DIALECT_X86_TRANSFORMS_H

View File

@@ -17,6 +17,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Dialect/X86/X86Interfaces.td"
include "mlir/IR/BuiltinTypes.td"
//===----------------------------------------------------------------------===//
// X86 dialect definition
@@ -25,6 +26,8 @@ include "mlir/Dialect/X86/X86Interfaces.td"
def X86_Dialect : Dialect {
let name = "x86";
let cppNamespace = "::mlir::x86";
let useDefaultTypePrinterParser = 1;
}
//===----------------------------------------------------------------------===//
@@ -673,4 +676,385 @@ def CvtPackedOddIndexedToF32Op
::mlir::RewriterBase &rewriter);
}];
}
//===----------------------------------------------------------------------===//
// AMX Tile definition
//===----------------------------------------------------------------------===//
class AMX_Type<string typeName, string typeMnemonic, list<Trait> traits = []>
: TypeDef<X86_Dialect, "AMX" # typeName, traits> {
let mnemonic = "amx." # typeMnemonic;
}
def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
let cppFunctionName = "isValidTileTypeElementType";
}
def AMX_TileType : AMX_Type<"Tile", "tile", [ShapedTypeInterface, ValueSemantics]> {
let summary = "AMX 2D tile to be used by AMX opertaions.";
let description = [{
This type is used to represent values in AMX tile registers. All AMX operations
work on AMX tiles and these tiles cannot be used in other operations directly.
LLVM IR type for AMX tile is a primitive type, but in MLIR we provide shape and
element type for IR verification and lowering to LLVMIR dialect.
}];
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
AMX_TileTypeElementType:$elementType
);
let builders = [
TypeBuilderWithInferredContext<(ins
"ArrayRef<int64_t>":$shape, "Type":$elementType), [{
return $_get(elementType.getContext(), shape, elementType);
}]>
];
let extraClassDeclaration = [{
/// Returns if this type is ranked (always true).
bool hasRank() const { return true; }
/// Clone this tile type with the given shape and element type. If the
/// provided shape is `std::nullopt`, the current shape of the type is used.
AMXTileType cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const {
return get(shape.value_or(getShape()), elementType);
}
}];
let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;
}
def IsAMXTilePred : And<[CPred<"::llvm::isa<::mlir::x86::AMXTileType>($_self)">,
CPred<[{::llvm::cast<::mlir::x86::AMXTileType>($_self).getRank() == 2}]>]>;
class AMXTileOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsAMXTilePred, "tile",
"::mlir::x86::AMXTileType">;
def AnyAMXTile : AMXTileOf<[F32, F16, BF16, I32, I8]>;
def AMXTileF32 : AMXTileOf<[F32]>;
def AMXTileF16OrBF16 : AMXTileOf<[F16, BF16]>;
def AMXTileI32 : AMXTileOf<[I32]>;
def AMXTileI8 : AMXTileOf<[I8]>;
//===----------------------------------------------------------------------===//
// AMX Op definitions
//===----------------------------------------------------------------------===//
class AMX_Op<string mnemonic, list<Trait> traits = []>
: Op<X86_Dialect, "amx." # mnemonic, traits> {
let cppNamespace = X86_Dialect.cppNamespace # "::amx";
}
//===----------------------------------------------------------------------===//
// AMX Tile Zero
//===----------------------------------------------------------------------===//
def TileZeroOp : AMX_Op<"tile_zero", [
X86IntrinsicOpInterface,
MemoryEffects<[MemWrite]>
]> {
let summary = "tile zero operation";
let description = [{
Zeroes the destination tile, with the shape defined by the 2-dim
vector type of the result.
The operation is eventually lowered into the "tilezero" instruction
with the corresponding tile configuration.
With the write memory effect, each `x86.amx.tile_zero` operation serves as
a compilation hint to use a separate tile register.
Example:
```mlir
%0 = x86.amx.tile_zero : !x86.amx.tile<16x16xbf16>
```
}];
let results = (outs AnyAMXTile:$res);
let extraClassDeclaration = [{
AMXTileType getTileType() {
return ::llvm::cast<AMXTileType>(getRes().getType());
}
std::string getIntrinsicName() {
return "llvm.x86.tilezero.internal";
}
SmallVector<Value> getIntrinsicOperands(
::mlir::ArrayRef<Value> operands,
const ::mlir::LLVMTypeConverter &typeConverter,
::mlir::RewriterBase &rewriter);
}];
let assemblyFormat = "attr-dict `:` qualified(type($res))";
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
// AMX Tile Load
//===----------------------------------------------------------------------===//
def TileLoadOp : AMX_Op<"tile_load", [
X86IntrinsicOpInterface,
MemoryEffects<[MemWrite]>,
AttrSizedOperandSegments
]> {
let summary = "tile load operation";
let description = [{
Loads a tile from memory defined by a `base` and `indices`, with the
shape defined by the 2-dim vector type of the result.
The tile's rows are populated by reading contiguous elements starting
at the `base`. For each tile row, the `base` is incremented by `stride`
number of elements.
The tile is loaded using the following indexing scheme:
```
for row in enumerate(tile_rows):
mem_row = base[i0, i1, ..., iN + row * stride]
for col in enumerate(tile_cols):
tile[row, col] = mem_row[col]
```
If the `stride` is not provided, then the `base` buffer must be at least
2-dimensional, and the `stride` is automatically inferred and corresponds
to the stride of the buffer's second innermost dimension.
The operation is eventually lowered into the "tileloadd" instruction
with the corresponding tile configuration.
With the write memory effect, each `x86.amx.tile_load` operation serves as
a compilation hint to use a separate tile register.
Example:
```mlir
// Tile load from a 2-D memref with implicit stride.
%0 = x86.amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into !x86.amx.tile<16x64xi8>
// Tile load from a 1-D memref with explicit stride.
%0 = x86.amx.tile_load %arg0[%c0], %stride : memref<?xi8> into !x86.amx.tile<16x64xi8>
```
}];
let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
Variadic<Index>:$indices,
Optional<Index>:$stride);
let results = (outs AnyAMXTile:$res);
let builders = [
OpBuilder<(ins "Type":$res, "Value":$base, "ValueRange":$indices)>
];
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return ::llvm::cast<MemRefType>(getBase().getType());
}
AMXTileType getTileType() {
return ::llvm::cast<AMXTileType>(getRes().getType());
}
std::string getIntrinsicName() {
return "llvm.x86.tileloadd64.internal";
}
SmallVector<Value> getIntrinsicOperands(
::mlir::ArrayRef<Value> operands,
const ::mlir::LLVMTypeConverter &typeConverter,
::mlir::RewriterBase &rewriter);
}];
let assemblyFormat = "$base `[` $indices `]` (`,` $stride^ )? attr-dict"
"`:` type($base) `into` qualified(type($res))";
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
// AMX Tile Store
//===----------------------------------------------------------------------===//
def TileStoreOp : AMX_Op<"tile_store", [
X86IntrinsicOpInterface,
AttrSizedOperandSegments
]> {
let summary = "tile store operation";
let description = [{
Stores a tile to memory defined by a `base` and `indices`, with the
shape defined by the 2-dim vector type of the value.
The tile's rows are written contiguously to the buffer starting at
the `base`. For each tile row, the `base` is incremented by `stride`
number of elements.
The tile is stored using the following indexing scheme:
```
for row in enumerate(tile_rows):
mem_row = base[i0, i1, ..., iN + row * stride]
for col in enumerate(tile_cols):
mem_row[col] = tile[row, col]
```
If the `stride` is not provided, then the `base` buffer must be at least
2-dimensional, and the `stride` is automatically inferred and corresponds
to the stride of the buffer's second innermost dimension.
The operation is eventually lowered into the "tilestored" instruction
with the corresponding tile configuration.
Example:
```mlir
// Tile store to a 2-D memref with implicit stride.
x86.amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, !x86.amx.tile<16x64xi8>
// Tile store to a 1-D memref with explicit stride.
x86.amx.tile_store %arg1[%c0], %0, %stride : memref<?xi8>, !x86.amx.tile<16x64xi8>
```
}];
let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
Variadic<Index>:$indices,
AnyAMXTile:$val,
Optional<Index>:$stride);
let builders = [
OpBuilder<(ins "Value":$base, "ValueRange":$indices, "Value":$val)>
];
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return ::llvm::cast<MemRefType>(getBase().getType());
}
AMXTileType getTileType() {
return ::llvm::cast<AMXTileType>(getVal().getType());
}
std::string getIntrinsicName() {
return "llvm.x86.tilestored64.internal";
}
SmallVector<Value> getIntrinsicOperands(
::mlir::ArrayRef<Value> operands,
const ::mlir::LLVMTypeConverter &typeConverter,
::mlir::RewriterBase &rewriter);
}];
let assemblyFormat = "$base `[` $indices `]` `,` $val (`,` $stride^ )?"
"attr-dict `:` type($base) `,` qualified(type($val))";
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
// AMX Tile Multiply
//===----------------------------------------------------------------------===//
def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
X86IntrinsicOpInterface,
AllTypesMatch<["acc", "res"]>
]> {
let summary = "tile multiplication operation (floating-point)";
let description = [{
Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
into a "m x n" destination tile. Supports "f32 <- bf16 x bf16" (with
pairs of "bf16").
The operation is eventually lowered into the "tdpbf16ps" instruction with
the corresponding tile configuration.
Example:
```mlir
%0 = x86.amx.tile_mulf %a, %b, %c
: !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x16xf32>
```
}];
let arguments = (ins AMXTileF16OrBF16:$lhs,
AMXTileF16OrBF16:$rhs,
AMXTileF32:$acc);
let results = (outs AMXTileF32:$res);
let extraClassDeclaration = [{
AMXTileType getLhsTileType() {
return ::llvm::cast<AMXTileType>(getLhs().getType());
}
AMXTileType getRhsTileType() {
return ::llvm::cast<AMXTileType>(getRhs().getType());
}
AMXTileType getTileType() {
return ::llvm::cast<AMXTileType>(getRes().getType());
}
std::string getIntrinsicName() {
std::string intr = "llvm.x86.tdp";
auto elementType =
getLhsTileType().getElementType();
intr += elementType.isF16() ? "fp16" : "bf16";
intr += "ps.internal";
return intr;
}
SmallVector<Value> getIntrinsicOperands(
::mlir::ArrayRef<Value> operands,
const ::mlir::LLVMTypeConverter &typeConverter,
::mlir::RewriterBase &rewriter);
}];
let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
"qualified(type($lhs)) `,` qualified(type($rhs))"
" `,` qualified(type($acc)) ";
let hasVerifier = 1;
}
def TileMulIOp : AMX_Op<"tile_muli", [Pure,
X86IntrinsicOpInterface,
AllTypesMatch<["acc", "res"]>
]> {
let summary = "tile multiplication operation (integer)";
let description = [{
Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
into a "m x n" destination tile. Supports all "si32 <- s/ui8 x s/ui8"
combinations (4 bytes packed into dwords in the columns of both the
source operand tiles; the zero or sign extension is specified with
the attributes and defaults to sign extension).
The operation is eventually lowered into one of the "tdpbssd",
"tdpbsud", "tdpbusd", or "tdpbuud" instructions with the corresponding
tile configuration.
Example:
```mlir
%0 = x86.amx.tile_muli %a zext, %b zext, %c
: !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
```
}];
let arguments = (ins AMXTileI8:$lhs,
AMXTileI8:$rhs,
AMXTileI32:$acc,
UnitAttr:$isZextLhs,
UnitAttr:$isZextRhs
);
let results = (outs AMXTileI32:$res);
let extraClassDeclaration = [{
AMXTileType getLhsTileType() {
return ::llvm::cast<AMXTileType>(getLhs().getType());
}
AMXTileType getRhsTileType() {
return ::llvm::cast<AMXTileType>(getRhs().getType());
}
AMXTileType getTileType() {
return ::llvm::cast<AMXTileType>(getRes().getType());
}
std::string getIntrinsicName() {
std::string intr = "llvm.x86.tdpb";
intr += getIsZextLhs() ? "u" : "s";
intr += getIsZextRhs() ? "u" : "s";
intr += "d.internal";
return intr;
}
SmallVector<Value> getIntrinsicOperands(
::mlir::ArrayRef<Value> operands,
const ::mlir::LLVMTypeConverter &typeConverter,
::mlir::RewriterBase &rewriter);
}];
let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
"qualified(type($lhs)) `,` qualified(type($rhs)) `,` qualified(type($acc)) ";
let hasVerifier = 1;
}
#endif // X86_OPS

View File

@@ -29,6 +29,19 @@
#include "mlir/Dialect/X86/X86Dialect.h.inc"
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/X86/X86Types.h.inc"
namespace mlir {
namespace x86 {
namespace amx {
// Alias to allow access to AMX type through nested namespaces
// analogously to AMX operations.
using TileType = mlir::x86::AMXTileType;
} // namespace amx
} // namespace x86
} // namespace mlir
#define GET_OP_CLASSES
#include "mlir/Dialect/X86/X86.h.inc"

View File

@@ -1,13 +0,0 @@
//===- AMX.cpp - C Interface for AMX dialect ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir-c/Dialect/AMX.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/Dialect/AMX/AMXDialect.h"
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(AMX, amx, mlir::amx::AMXDialect)

View File

@@ -26,15 +26,6 @@ add_mlir_upstream_c_api_library(MLIRCAPIAMDGPU
MLIRAMDGPUTransforms
)
add_mlir_upstream_c_api_library(MLIRCAPIAMX
AMX.cpp
PARTIAL_SOURCES_INTENDED
LINK_LIBS PUBLIC
MLIRCAPIIR
MLIRAMXDialect
)
add_mlir_upstream_c_api_library(MLIRCAPIArith
Arith.cpp
ArithPasses.cpp

View File

@@ -8,7 +8,6 @@ add_mlir_conversion_library(MLIRVectorToAMX
MLIRConversionPassIncGen
LINK_LIBS PUBLIC
MLIRAMXDialect
MLIRAffineUtils
MLIRArithDialect
MLIRLinalgUtils
@@ -16,4 +15,5 @@ add_mlir_conversion_library(MLIRVectorToAMX
MLIRSCFDialect
MLIRTransforms
MLIRVectorDialect
MLIRX86Dialect
)

View File

@@ -1,4 +1,4 @@
//===- VectorToAMX.cpp - Convert vector to AMX dialect ----------*- C++ -*-===//
//===- VectorToAMX.cpp - Convert vector to X86 dialect AMX ops --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -8,7 +8,6 @@
#include "mlir/Conversion/VectorToAMX/VectorToAMX.h"
#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
@@ -16,6 +15,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/X86/X86Dialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -197,7 +197,7 @@ static TypedValue<MemRefType> collapseLastDim(PatternRewriter &rewriter,
static Operation *
loadStoreFromTransfer(PatternRewriter &rewriter,
VectorTransferOpInterface xferOp, bool isPacked,
TypedValue<amx::TileType> tileToStore = nullptr) {
TypedValue<x86::amx::TileType> tileToStore = nullptr) {
if (!xferOp || !isa<vector::TransferReadOp, vector::TransferWriteOp>(xferOp))
return nullptr;
if (xferOp.hasOutOfBoundsDim() ||
@@ -267,18 +267,18 @@ loadStoreFromTransfer(PatternRewriter &rewriter,
src = collapseLastDim(rewriter, src);
int64_t rows = vecShape[0];
int64_t cols = llvm::product_of(vecShape.drop_front());
auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());
auto tileType = x86::amx::TileType::get({rows, cols}, vecTy.getElementType());
Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
SmallVector<Value> tileIndicides(src.getType().getRank(), zeroIndex);
Operation *amxTileOp = nullptr;
if (isa<vector::TransferReadOp>(xferOp)) {
amxTileOp =
amx::TileLoadOp::create(rewriter, loc, tileType, src, tileIndicides);
amxTileOp = x86::amx::TileLoadOp::create(rewriter, loc, tileType, src,
tileIndicides);
} else if (isa<vector::TransferWriteOp>(xferOp)) {
amxTileOp = amx::TileStoreOp::create(rewriter, loc, src, tileIndicides,
tileToStore);
amxTileOp = x86::amx::TileStoreOp::create(rewriter, loc, src, tileIndicides,
tileToStore);
} else {
llvm_unreachable("unsupported vector transfer op");
}
@@ -289,10 +289,10 @@ loadStoreFromTransfer(PatternRewriter &rewriter,
/// Attempt to create an AMX tile load operation equivalent to the given
/// vector transfer `readOp`.
/// Returns loaded AMX tile if successful.
static FailureOr<TypedValue<amx::TileType>>
static FailureOr<TypedValue<x86::amx::TileType>>
loadFromTransfer(PatternRewriter &rewriter, vector::TransferReadOp readOp,
bool isPacked) {
amx::TileLoadOp loadOp = dyn_cast_if_present<amx::TileLoadOp>(
x86::amx::TileLoadOp loadOp = dyn_cast_if_present<x86::amx::TileLoadOp>(
loadStoreFromTransfer(rewriter, readOp, isPacked));
if (!loadOp)
return failure();
@@ -301,16 +301,16 @@ loadFromTransfer(PatternRewriter &rewriter, vector::TransferReadOp readOp,
/// Attempt to create an AMX tile store operation equivalent to the given
/// vector transfer `writeOp`.
static LogicalResult storeFromTransfer(PatternRewriter &rewriter,
vector::TransferWriteOp writeOp,
TypedValue<amx::TileType> tileToStore) {
static LogicalResult
storeFromTransfer(PatternRewriter &rewriter, vector::TransferWriteOp writeOp,
TypedValue<x86::amx::TileType> tileToStore) {
return success(loadStoreFromTransfer(rewriter, writeOp, /*isPacked=*/false,
tileToStore));
}
/// Load vector values to an AMX tile.
static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter,
TypedValue<VectorType> vec) {
static TypedValue<x86::amx::TileType> loadTile(PatternRewriter &rewriter,
TypedValue<VectorType> vec) {
Location loc = vec.getLoc();
VectorType vecTy = vec.getType();
@@ -318,7 +318,7 @@ static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter,
// Try to load tile directly from vector producer's buffer.
auto readOp = vec.getDefiningOp<vector::TransferReadOp>();
FailureOr<TypedValue<amx::TileType>> tile =
FailureOr<TypedValue<x86::amx::TileType>> tile =
loadFromTransfer(rewriter, readOp, isPacked);
if (succeeded(tile))
return *tile;
@@ -337,25 +337,25 @@ static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter,
ArrayRef<int64_t> shape = vecTy.getShape();
int64_t rows = shape[0];
int64_t cols = llvm::product_of(shape.drop_front());
auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());
auto tileType = x86::amx::TileType::get({rows, cols}, vecTy.getElementType());
return amx::TileLoadOp::create(rewriter, loc, tileType, buf,
{zeroIndex, zeroIndex});
return x86::amx::TileLoadOp::create(rewriter, loc, tileType, buf,
{zeroIndex, zeroIndex});
}
/// Store an AMX tile in a vector.
static TypedValue<VectorType> storeTile(PatternRewriter &rewriter,
TypedValue<amx::TileType> tile) {
TypedValue<x86::amx::TileType> tile) {
Location loc = tile.getLoc();
// Transfer the tile to a vector through an intermediate buffer.
amx::TileType tileTy = tile.getType();
x86::amx::TileType tileTy = tile.getType();
Value buf = memref::AllocaOp::create(
rewriter, loc,
MemRefType::get(tileTy.getShape(), tileTy.getElementType()));
Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
SmallVector<Value> indices(2, zeroIndex);
amx::TileStoreOp::create(rewriter, loc, buf, indices, tile);
x86::amx::TileStoreOp::create(rewriter, loc, buf, indices, tile);
auto vecTy = VectorType::get(tileTy.getShape(), tileTy.getElementType());
return vector::TransferReadOp::create(rewriter, loc, vecTy, buf, indices, {});
@@ -374,19 +374,21 @@ struct ContractionToAMX : public OpRewritePattern<vector::ContractionOp> {
if (failed(validateOperands(rewriter, contractOp)))
return failure();
TypedValue<amx::TileType> lhsTile = loadTile(rewriter, contractOp.getLhs());
TypedValue<amx::TileType> rhsTile = loadTile(rewriter, contractOp.getRhs());
TypedValue<x86::amx::TileType> lhsTile =
loadTile(rewriter, contractOp.getLhs());
TypedValue<x86::amx::TileType> rhsTile =
loadTile(rewriter, contractOp.getRhs());
auto acc = dyn_cast<TypedValue<VectorType>>(contractOp.getAcc());
assert(acc && "Invalid accumulator type");
TypedValue<amx::TileType> accTile = loadTile(rewriter, acc);
TypedValue<x86::amx::TileType> accTile = loadTile(rewriter, acc);
TypedValue<amx::TileType> tileMul;
TypedValue<x86::amx::TileType> tileMul;
if (acc.getType().getElementType().isFloat()) {
tileMul = amx::TileMulFOp::create(rewriter, loc, accTile.getType(),
lhsTile, rhsTile, accTile);
tileMul = x86::amx::TileMulFOp::create(rewriter, loc, accTile.getType(),
lhsTile, rhsTile, accTile);
} else {
tileMul = amx::TileMulIOp::create(rewriter, loc, accTile.getType(),
lhsTile, rhsTile, accTile);
tileMul = x86::amx::TileMulIOp::create(rewriter, loc, accTile.getType(),
lhsTile, rhsTile, accTile);
}
// If the contraction result is only written back to memory, try to replace

View File

@@ -38,8 +38,6 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
MLIRArmNeonTransforms
MLIRArmSVEDialect
MLIRArmSVETransforms
MLIRAMXDialect
MLIRAMXTransforms
MLIRX86Dialect
MLIRX86Transforms
)

View File

@@ -10,8 +10,6 @@
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmNeon/Transforms.h"
@@ -51,8 +49,6 @@ struct ConvertVectorToLLVMPass
registry.insert<arm_neon::ArmNeonDialect>();
if (armSVE)
registry.insert<arm_sve::ArmSVEDialect>();
if (amx)
registry.insert<amx::AMXDialect>();
if (x86)
registry.insert<x86::X86Dialect>();
}
@@ -136,10 +132,6 @@ void ConvertVectorToLLVMPass::runOnOperation() {
configureArmSVELegalizeForExportTarget(target);
populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
}
if (amx) {
configureAMXLegalizeForExportTarget(target);
populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
}
if (x86) {
configureX86LegalizeForExportTarget(target);
populateX86LegalizeForLLVMExportPatterns(converter, patterns);

View File

@@ -1,2 +0,0 @@
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@@ -1,318 +0,0 @@
//===- AMXDialect.cpp - MLIR AMX ops implementation -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the AMX dialect and its operations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
#include "mlir/Dialect/AMX/AMXInterfaces.cpp.inc"
#include "mlir/Dialect/AMX/AMXDialect.cpp.inc"
void amx::AMXDialect::initialize() {
addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
>();
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMX/AMX.cpp.inc"
>();
}
/// Verify that AMX supports the implied tile shape.
static LogicalResult verifyTileSize(Operation *op, amx::TileType tp) {
const unsigned kMaxRows = 16;
const unsigned kBitsPerRow = 64 * 8;
unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
if (tp.getDimSize(0) > kMaxRows)
return op->emitOpError("bad row height: ") << tp.getDimSize(0);
if (col > kBitsPerRow || col & 0x1f)
return op->emitOpError("bad column width: ") << (col >> 3);
return success();
}
/// Verify that AMX supports the multiplication.
static LogicalResult verifyMultShape(Operation *op, amx::TileType atp,
amx::TileType btp, amx::TileType ctp,
unsigned scale) {
unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1);
if (cm != am || cn != bn || ak != bk)
return op->emitOpError("bad mult shape: ")
<< cm << " x " << cn << " x " << ak;
return success();
}
/// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first
/// dimension directly translates into the number of rows of the tiles.
/// The second dimensions needs to be scaled by the number of bytes.
static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType,
RewriterBase &rewriter) {
Type llvmInt16Type = rewriter.getIntegerType(16);
unsigned width = tType.getElementType().getIntOrFloatBitWidth();
assert(llvm::isPowerOf2_64(width) && width >= 8);
unsigned bytes = width >> 3;
auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0));
auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes);
return SmallVector<Value>{
LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, mattr),
LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, nattr)};
}
/// Returns stride expressed in number of bytes for the given `elementStride`
/// stride encoded in number of elements of the type `mType`.
static Value computeStrideInBytes(Location loc, MemRefType mType,
Value elementStride, RewriterBase &rewriter) {
Type llvmInt64Type = rewriter.getIntegerType(64);
unsigned bytes = mType.getElementType().getIntOrFloatBitWidth() / 8;
auto attr = rewriter.getI64IntegerAttr(bytes);
Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr);
return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale, elementStride)
.getResult();
}
/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
/// shape may "envelop" the actual tile shape, and may be dynamically sized.
static Value inferStride(Location loc, MemRefType mType, Value base,
RewriterBase &rewriter) {
assert(mType.getRank() >= 2 && "Invalid shape for AMX strides");
int64_t preLast = mType.getRank() - 2;
Type llvmInt64Type = rewriter.getIntegerType(64);
unsigned width = mType.getElementType().getIntOrFloatBitWidth();
assert(llvm::isPowerOf2_64(width) && width >= 8);
unsigned bytes = width >> 3;
auto [strides, offset] = mType.getStridesAndOffset();
if (strides[preLast] == ShapedType::kDynamic) {
// Dynamic stride needs code to compute the stride at runtime.
MemRefDescriptor memrefDescriptor(base);
return computeStrideInBytes(
loc, mType, memrefDescriptor.stride(rewriter, loc, preLast), rewriter);
}
// Use direct constant for static stride.
auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
return LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr)
.getResult();
}
LogicalResult amx::TileZeroOp::verify() {
return verifyTileSize(*this, getTileType());
}
SmallVector<Value>
amx::TileZeroOp::getIntrinsicOperands(ArrayRef<Value> operands,
const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) {
return getTileSizes(getLoc(), getTileType(), rewriter);
}
template <typename OpTy,
typename = std::enable_if_t<std::is_same_v<OpTy, amx::TileLoadOp> ||
std::is_same_v<OpTy, amx::TileStoreOp>>>
static LogicalResult tileTransferVerifier(OpTy op) {
MemRefType memrefTy = op.getMemRefType();
unsigned rank = memrefTy.getRank();
if (op.getIndices().size() != rank)
return op.emitOpError("requires ") << rank << " indices";
if (failed(verifyTileSize(op, op.getTileType())))
return failure();
// Validate basic buffer properties when the stride is implicit.
if (!op.getStride()) {
if (rank < 2)
return op.emitOpError("requires at least 2D memref");
SmallVector<int64_t> strides;
int64_t offset;
if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
strides.back() != 1)
return op.emitOpError("requires memref with unit innermost stride");
}
return success();
}
void amx::TileLoadOp::build(OpBuilder &builder, OperationState &state, Type res,
Value base, ValueRange indices) {
build(builder, state, res, base, indices, /*stride=*/nullptr);
}
LogicalResult amx::TileLoadOp::verify() { return tileTransferVerifier(*this); }
SmallVector<Value>
amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands,
const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) {
auto loc = getLoc();
Adaptor adaptor(operands, *this);
SmallVector<Value> intrinsicOperands;
intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter));
intrinsicOperands.push_back(
LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
adaptor.getBase(), adaptor.getIndices()));
if (Value stride = adaptor.getStride())
intrinsicOperands.push_back(
computeStrideInBytes(loc, getMemRefType(), stride, rewriter));
else
intrinsicOperands.push_back(
inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
return intrinsicOperands;
}
void amx::TileStoreOp::build(OpBuilder &builder, OperationState &state,
Value base, ValueRange indices, Value val) {
build(builder, state, base, indices, val, /*stride=*/nullptr);
}
LogicalResult amx::TileStoreOp::verify() { return tileTransferVerifier(*this); }
SmallVector<Value>
amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands,
const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) {
auto loc = getLoc();
Adaptor adaptor(operands, *this);
SmallVector<Value> intrinsicOperands;
intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter));
intrinsicOperands.push_back(
LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
adaptor.getBase(), adaptor.getIndices()));
if (Value stride = adaptor.getStride())
intrinsicOperands.push_back(
computeStrideInBytes(loc, getMemRefType(), stride, rewriter));
else
intrinsicOperands.push_back(
inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
intrinsicOperands.push_back(adaptor.getVal());
return intrinsicOperands;
}
LogicalResult amx::TileMulFOp::verify() {
amx::TileType aType = getLhsTileType();
amx::TileType bType = getRhsTileType();
amx::TileType cType = getTileType();
if (failed(verifyTileSize(*this, aType)) ||
failed(verifyTileSize(*this, bType)) ||
failed(verifyTileSize(*this, cType)) ||
failed(verifyMultShape(*this, aType, bType, cType, 1)))
return failure();
Type ta = aType.getElementType();
Type tb = bType.getElementType();
Type tc = cType.getElementType();
if ((!ta.isBF16() && !ta.isF16()) || (ta != tb) || !tc.isF32())
return emitOpError("unsupported type combination");
return success();
}
SmallVector<Value>
amx::TileMulFOp::getIntrinsicOperands(ArrayRef<Value> operands,
const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) {
auto loc = getLoc();
Adaptor adaptor(operands, *this);
amx::TileType aType = getLhsTileType();
amx::TileType bType = getRhsTileType();
SmallVector<Value> tsza = getTileSizes(loc, aType, rewriter);
SmallVector<Value> tszb = getTileSizes(loc, bType, rewriter);
SmallVector<Value> intrinsicOperands = {tsza[0], tszb[1],
tsza[1], adaptor.getAcc(),
adaptor.getLhs(), adaptor.getRhs()};
return intrinsicOperands;
}
LogicalResult amx::TileMulIOp::verify() {
amx::TileType aType = getLhsTileType();
amx::TileType bType = getRhsTileType();
amx::TileType cType = getTileType();
if (failed(verifyTileSize(*this, aType)) ||
failed(verifyTileSize(*this, bType)) ||
failed(verifyTileSize(*this, cType)) ||
failed(verifyMultShape(*this, aType, bType, cType, 2)))
return failure();
Type ta = aType.getElementType();
Type tb = bType.getElementType();
Type tc = cType.getElementType();
if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32))
return emitOpError("unsupported type combination");
return success();
}
SmallVector<Value>
amx::TileMulIOp::getIntrinsicOperands(ArrayRef<Value> operands,
const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) {
auto loc = getLoc();
Adaptor adaptor(operands, *this);
amx::TileType aType = getLhsTileType();
amx::TileType bType = getRhsTileType();
SmallVector<Value> tsza = getTileSizes(loc, aType, rewriter);
SmallVector<Value> tszb = getTileSizes(loc, bType, rewriter);
SmallVector<Value> intrinsicOperands = {tsza[0], tszb[1],
tsza[1], adaptor.getAcc(),
adaptor.getLhs(), adaptor.getRhs()};
return intrinsicOperands;
}
Type amx::TileType::parse(AsmParser &parser) {
if (parser.parseLess())
return nullptr;
SmallVector<int64_t, 2> shape;
if (parser.parseDimensionList(shape, false, true))
return nullptr;
Type elementType;
if (parser.parseType(elementType))
return nullptr;
if (parser.parseGreater())
return nullptr;
return TileType::getChecked(
[&] { return parser.emitError(parser.getNameLoc()); }, shape,
elementType);
}
void amx::TileType::print(AsmPrinter &os) const {
os << "<";
os.printDimensionList(getShape());
os << 'x';
os.printType(getElementType());
os << '>';
}
#define GET_OP_CLASSES
#include "mlir/Dialect/AMX/AMX.cpp.inc"
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/AMX/AMXTypes.cpp.inc"

View File

@@ -1,15 +0,0 @@
add_mlir_dialect_library(MLIRAMXDialect
AMXDialect.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMX
DEPENDS
MLIRAMXIncGen
LINK_LIBS PUBLIC
MLIRIR
MLIRLLVMCommonConversion
MLIRLLVMDialect
MLIRSideEffectInterfaces
)

View File

@@ -1,9 +0,0 @@
add_mlir_dialect_library(MLIRAMXTransforms
LegalizeForLLVMExport.cpp
LINK_LIBS PUBLIC
MLIRAMXDialect
MLIRIR
MLIRLLVMCommonConversion
MLIRLLVMDialect
)

View File

@@ -1,70 +0,0 @@
//===- LegalizeForLLVMExport.cpp - Prepare AMX for LLVM translation ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::amx;
namespace {
/// Generic one-to-one conversion of simply mappable operations into calls
/// to their respective LLVM intrinsics.
struct AMXIntrinsicOpConversion
: public ConvertOpInterfaceToLLVMPattern<amx::AMXIntrinsicOp> {
using ConvertOpInterfaceToLLVMPattern::ConvertOpInterfaceToLLVMPattern;
LogicalResult
matchAndRewrite(amx::AMXIntrinsicOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
const LLVMTypeConverter &typeConverter = *getTypeConverter();
return LLVM::detail::intrinsicRewrite(
op, rewriter.getStringAttr(op.getIntrinsicName()),
op.getIntrinsicOperands(operands, typeConverter, rewriter),
typeConverter, rewriter);
}
};
} // namespace
void mlir::populateAMXLegalizeForLLVMExportPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<AMXIntrinsicOpConversion>(converter);
converter.addConversion([&](amx::TileType type) {
return LLVM::LLVMX86AMXType::get(&converter.getContext());
});
}
void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
target.addIllegalDialect<AMXDialect>();
}
namespace {
/// Implement the interface to convert AMX to LLVM.
struct AMXToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
void populateConvertToLLVMConversionPatterns(
ConversionTarget &target, LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns) const final {
populateAMXLegalizeForLLVMExportPatterns(typeConverter, patterns);
}
};
} // namespace
void mlir::registerConvertAMXToLLVMInterface(DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) {
dialect->addInterfaces<AMXToLLVMDialectInterface>();
});
}

View File

@@ -1,6 +1,5 @@
add_subdirectory(Affine)
add_subdirectory(AMDGPU)
add_subdirectory(AMX)
add_subdirectory(Arith)
add_subdirectory(ArmNeon)
add_subdirectory(ArmSME)

View File

@@ -11,10 +11,16 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/X86/X86Dialect.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
#include "mlir/Dialect/X86/X86Interfaces.cpp.inc"
@@ -22,6 +28,11 @@ using namespace mlir;
#include "mlir/Dialect/X86/X86Dialect.cpp.inc"
void x86::X86Dialect::initialize() {
addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/X86/X86Types.cpp.inc"
>();
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/X86/X86.cpp.inc"
@@ -107,5 +118,279 @@ SmallVector<Value> x86::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
typeConverter, rewriter)};
}
/// Verify that AMX supports the implied tile shape.
static LogicalResult verifyTileSize(Operation *op, x86::amx::TileType tp) {
const unsigned kMaxRows = 16;
const unsigned kBitsPerRow = 64 * 8;
unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
if (tp.getDimSize(0) > kMaxRows)
return op->emitOpError("bad row height: ") << tp.getDimSize(0);
if (col > kBitsPerRow || col & 0x1f)
return op->emitOpError("bad column width: ") << (col >> 3);
return success();
}
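// Illustrative examples (values taken from the invalid.mlir tests below):
// !x86.amx.tile<16x32xbf16> verifies (16 rows <= 16; 32 * 16 = 512 column
// bits, i.e. 64 bytes, within the 64-byte row limit and a multiple of 4
// bytes), while !x86.amx.tile<17x16xbf16> trips the row check and
// !x86.amx.tile<16x65xi8> trips the column check.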
/// Verify that AMX supports the multiplication.
static LogicalResult verifyMultShape(Operation *op, x86::amx::TileType atp,
x86::amx::TileType btp,
x86::amx::TileType ctp, unsigned scale) {
unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1);
if (cm != am || cn != bn || ak != bk)
return op->emitOpError("bad mult shape: ")
<< cm << " x " << cn << " x " << ak;
return success();
}
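// Illustrative example: for tile_mulf (scale = 1), A = !x86.amx.tile<16x32xbf16>
// yields m = 16 and k = 32 >> 1 = 16, B = !x86.amx.tile<16x32xbf16> yields
// k = 16 and n = 16, so the accumulator must be !x86.amx.tile<16x16xf32>.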
/// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first
/// dimension directly translates into the number of rows of the tiles.
/// The second dimension needs to be scaled by the number of bytes.
static SmallVector<Value> getTileSizes(Location loc, x86::amx::TileType tType,
RewriterBase &rewriter) {
Type llvmInt16Type = rewriter.getIntegerType(16);
unsigned width = tType.getElementType().getIntOrFloatBitWidth();
assert(llvm::isPowerOf2_64(width) && width >= 8);
unsigned bytes = width >> 3;
auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0));
auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes);
return SmallVector<Value>{
LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, mattr),
LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, nattr)};
}
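// Illustrative example: !x86.amx.tile<16x32xbf16> produces the i16 pair
// (16, 64), i.e. 16 rows and 32 * 2 bytes per row, as expected by the AMX
// tile intrinsics.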
/// Returns the stride in bytes for the given `elementStride`, which is
/// expressed in number of elements of the element type of `mType`.
static Value computeStrideInBytes(Location loc, MemRefType mType,
Value elementStride, RewriterBase &rewriter) {
Type llvmInt64Type = rewriter.getIntegerType(64);
unsigned bytes = mType.getElementType().getIntOrFloatBitWidth() / 8;
auto attr = rewriter.getI64IntegerAttr(bytes);
Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr);
return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale, elementStride)
.getResult();
}
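// Illustrative example: for an f32 memref and an element stride of 64, the
// emitted multiply computes 4 * 64 = 256 bytes at runtime.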
/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
/// shape may "envelop" the actual tile shape, and may be dynamically sized.
static Value inferStride(Location loc, MemRefType mType, Value base,
RewriterBase &rewriter) {
assert(mType.getRank() >= 2 && "Invalid shape for AMX strides");
int64_t preLast = mType.getRank() - 2;
Type llvmInt64Type = rewriter.getIntegerType(64);
unsigned width = mType.getElementType().getIntOrFloatBitWidth();
assert(llvm::isPowerOf2_64(width) && width >= 8);
unsigned bytes = width >> 3;
auto [strides, offset] = mType.getStridesAndOffset();
if (strides[preLast] == ShapedType::kDynamic) {
// Dynamic stride needs code to compute the stride at runtime.
MemRefDescriptor memrefDescriptor(base);
return computeStrideInBytes(
loc, mType, memrefDescriptor.stride(rewriter, loc, preLast), rewriter);
}
// Use direct constant for static stride.
auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
return LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr)
.getResult();
}
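// Illustrative example: memref<32x32xbf16, strided<[64, 1]>> has a static
// penultimate stride, folded into the constant 64 * 2 = 128 bytes, whereas a
// strided<[?, 1]> layout loads the stride from the memref descriptor and
// scales it at runtime via computeStrideInBytes.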
LogicalResult x86::amx::TileZeroOp::verify() {
return verifyTileSize(*this, getTileType());
}
SmallVector<Value> x86::amx::TileZeroOp::getIntrinsicOperands(
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) {
return getTileSizes(getLoc(), getTileType(), rewriter);
}
template <typename OpTy, typename = std::enable_if_t<
std::is_same_v<OpTy, x86::amx::TileLoadOp> ||
std::is_same_v<OpTy, x86::amx::TileStoreOp>>>
static LogicalResult tileTransferVerifier(OpTy op) {
MemRefType memrefTy = op.getMemRefType();
unsigned rank = memrefTy.getRank();
if (op.getIndices().size() != rank)
return op.emitOpError("requires ") << rank << " indices";
if (failed(verifyTileSize(op, op.getTileType())))
return failure();
// Validate basic buffer properties when the stride is implicit.
if (!op.getStride()) {
if (rank < 2)
return op.emitOpError("requires at least 2D memref");
SmallVector<int64_t> strides;
int64_t offset;
if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
strides.back() != 1)
return op.emitOpError("requires memref with unit innermost stride");
}
return success();
}
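// Illustrative example: without an explicit stride,
//   x86.amx.tile_load %m[%i, %j] : memref<?x?xbf16, strided<[64, 1]>> ...
// requires a rank >= 2 base with unit innermost stride; a 1-D base such as
// memref<?xbf16> is only accepted together with an explicit stride operand.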
void x86::amx::TileLoadOp::build(OpBuilder &builder, OperationState &state,
Type res, Value base, ValueRange indices) {
build(builder, state, res, base, indices, /*stride=*/nullptr);
}
LogicalResult x86::amx::TileLoadOp::verify() {
return tileTransferVerifier(*this);
}
SmallVector<Value> x86::amx::TileLoadOp::getIntrinsicOperands(
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) {
auto loc = getLoc();
Adaptor adaptor(operands, *this);
SmallVector<Value> intrinsicOperands;
intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter));
intrinsicOperands.push_back(
LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
adaptor.getBase(), adaptor.getIndices()));
if (Value stride = adaptor.getStride())
intrinsicOperands.push_back(
computeStrideInBytes(loc, getMemRefType(), stride, rewriter));
else
intrinsicOperands.push_back(
inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
return intrinsicOperands;
}
void x86::amx::TileStoreOp::build(OpBuilder &builder, OperationState &state,
Value base, ValueRange indices, Value val) {
build(builder, state, base, indices, val, /*stride=*/nullptr);
}
LogicalResult x86::amx::TileStoreOp::verify() {
return tileTransferVerifier(*this);
}
SmallVector<Value> x86::amx::TileStoreOp::getIntrinsicOperands(
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) {
auto loc = getLoc();
Adaptor adaptor(operands, *this);
SmallVector<Value> intrinsicOperands;
intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter));
intrinsicOperands.push_back(
LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
adaptor.getBase(), adaptor.getIndices()));
if (Value stride = adaptor.getStride())
intrinsicOperands.push_back(
computeStrideInBytes(loc, getMemRefType(), stride, rewriter));
else
intrinsicOperands.push_back(
inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
intrinsicOperands.push_back(adaptor.getVal());
return intrinsicOperands;
}
LogicalResult x86::amx::TileMulFOp::verify() {
x86::amx::TileType aType = getLhsTileType();
x86::amx::TileType bType = getRhsTileType();
x86::amx::TileType cType = getTileType();
if (failed(verifyTileSize(*this, aType)) ||
failed(verifyTileSize(*this, bType)) ||
failed(verifyTileSize(*this, cType)) ||
failed(verifyMultShape(*this, aType, bType, cType, 1)))
return failure();
Type ta = aType.getElementType();
Type tb = bType.getElementType();
Type tc = cType.getElementType();
if ((!ta.isBF16() && !ta.isF16()) || (ta != tb) || !tc.isF32())
return emitOpError("unsupported type combination");
return success();
}
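// Illustrative example: tile_mulf accepts bf16 x bf16 -> f32 and
// f16 x f16 -> f32 accumulation; mixing element kinds, e.g. a bf16 LHS with
// an f16 RHS, is rejected as an unsupported type combination.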
SmallVector<Value> x86::amx::TileMulFOp::getIntrinsicOperands(
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) {
auto loc = getLoc();
Adaptor adaptor(operands, *this);
x86::amx::TileType aType = getLhsTileType();
x86::amx::TileType bType = getRhsTileType();
SmallVector<Value> tsza = getTileSizes(loc, aType, rewriter);
SmallVector<Value> tszb = getTileSizes(loc, bType, rewriter);
SmallVector<Value> intrinsicOperands = {tsza[0], tszb[1],
tsza[1], adaptor.getAcc(),
adaptor.getLhs(), adaptor.getRhs()};
return intrinsicOperands;
}
LogicalResult x86::amx::TileMulIOp::verify() {
x86::amx::TileType aType = getLhsTileType();
x86::amx::TileType bType = getRhsTileType();
x86::amx::TileType cType = getTileType();
if (failed(verifyTileSize(*this, aType)) ||
failed(verifyTileSize(*this, bType)) ||
failed(verifyTileSize(*this, cType)) ||
failed(verifyMultShape(*this, aType, bType, cType, 2)))
return failure();
Type ta = aType.getElementType();
Type tb = bType.getElementType();
Type tc = cType.getElementType();
if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32))
return emitOpError("unsupported type combination");
return success();
}
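// Illustrative example: tile_muli requires i8 x i8 -> i32 tiles; the
// optional per-operand 'zext' flags only select zero vs. sign extension and
// do not change the verified element types.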
SmallVector<Value> x86::amx::TileMulIOp::getIntrinsicOperands(
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) {
auto loc = getLoc();
Adaptor adaptor(operands, *this);
x86::amx::TileType aType = getLhsTileType();
x86::amx::TileType bType = getRhsTileType();
SmallVector<Value> tsza = getTileSizes(loc, aType, rewriter);
SmallVector<Value> tszb = getTileSizes(loc, bType, rewriter);
SmallVector<Value> intrinsicOperands = {tsza[0], tszb[1],
tsza[1], adaptor.getAcc(),
adaptor.getLhs(), adaptor.getRhs()};
return intrinsicOperands;
}
Type x86::amx::TileType::parse(AsmParser &parser) {
if (parser.parseLess())
return nullptr;
SmallVector<int64_t, 2> shape;
if (parser.parseDimensionList(shape, false, true))
return nullptr;
Type elementType;
if (parser.parseType(elementType))
return nullptr;
if (parser.parseGreater())
return nullptr;
return TileType::getChecked(
[&] { return parser.emitError(parser.getNameLoc()); }, shape,
elementType);
}
void x86::amx::TileType::print(AsmPrinter &os) const {
os << "<";
os.printDimensionList(getShape());
os << 'x';
os.printType(getElementType());
os << '>';
}
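// Illustrative example: a 2-D bf16 tile round-trips through print/parse as
// !x86.amx.tile<16x32xbf16>.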
#define GET_OP_CLASSES
#include "mlir/Dialect/X86/X86.cpp.inc"
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/X86/X86Types.cpp.inc"

View File

@@ -8,6 +8,7 @@
#include "mlir/Dialect/X86/Transforms.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/X86/X86Dialect.h"
@@ -39,10 +40,32 @@ struct X86IntrinsicOpConversion
/// Populate the given list with patterns that convert from X86 to LLVM.
void mlir::populateX86LegalizeForLLVMExportPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<X86IntrinsicOpConversion>(converter);
converter.addConversion([&](x86::amx::TileType type) {
return LLVM::LLVMX86AMXType::get(&converter.getContext());
});
}
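// Illustrative example: the added type conversion rewrites values of type
// !x86.amx.tile<16x16xf32> to the opaque LLVM x86_amx type expected by the
// llvm.x86.tile* intrinsics.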
void mlir::configureX86LegalizeForExportTarget(LLVMConversionTarget &target) {
target.addIllegalDialect<X86Dialect>();
}
namespace {
/// Implement the interface to convert X86 to LLVM.
struct X86ToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
void populateConvertToLLVMConversionPatterns(
ConversionTarget &target, LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns) const final {
populateX86LegalizeForLLVMExportPatterns(typeConverter, patterns);
}
};
} // namespace
void mlir::registerConvertX86ToLLVMInterface(DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, x86::X86Dialect *dialect) {
dialect->addInterfaces<X86ToLLVMDialectInterface>();
});
}

View File

@@ -14,7 +14,6 @@
#include "mlir/InitAllDialects.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -111,7 +110,6 @@ void mlir::registerAllDialects(DialectRegistry &registry) {
registry.insert<acc::OpenACCDialect,
affine::AffineDialect,
amdgpu::AMDGPUDialect,
amx::AMXDialect,
arith::ArithDialect,
arm_neon::ArmNeonDialect,
arm_sme::ArmSMEDialect,

View File

@@ -32,7 +32,6 @@
#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"
@@ -57,6 +56,7 @@
#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
#include "mlir/Dialect/X86/TransformOps/X86TransformOps.h"
#include "mlir/Dialect/X86/Transforms.h"
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
@@ -90,10 +90,10 @@ void mlir::registerAllExtensions(DialectRegistry &registry) {
registerConvertOpenMPToLLVMInterface(registry);
registerConvertSCFToEmitCInterface(registry);
ub::registerConvertUBToLLVMInterface(registry);
registerConvertAMXToLLVMInterface(registry);
gpu::registerConvertGpuToLLVMInterface(registry);
NVVM::registerConvertGpuToNVVMInterface(registry);
vector::registerConvertVectorToLLVMInterface(registry);
registerConvertX86ToLLVMInterface(registry);
// Register all transform dialect extensions.
affine::registerTransformDialectExtension(registry);

View File

@@ -32,8 +32,8 @@ if (MLIR_INCLUDE_INTEGRATION_TESTS)
"The GPU compilation format used by the tests.")
set(ARM_SME_ABI_ROUTINES_SHLIB "" CACHE STRING
"Path to a shared library containing Arm SME ABI routines, required for Arm SME integration tests.")
option(MLIR_RUN_AMX_TESTS "Run AMX tests.")
option(MLIR_RUN_X86_TESTS "Run X86 tests.")
option(MLIR_RUN_X86_AMX_TESTS "Run X86 AMX tests.")
option(MLIR_RUN_CUDA_TENSOR_CORE_TESTS "Run CUDA Tensor core WMMA tests.")
option(MLIR_RUN_CUDA_SM80_TESTS "Run CUDA A100 tests.")
option(MLIR_RUN_CUDA_SM80_LT_TESTS "Run CUDA A100 structured sparsity tests.")
@@ -77,7 +77,7 @@ llvm_canonicalize_cmake_booleans(
MLIR_ENABLE_SPIRV_CPU_RUNNER
MLIR_ENABLE_VULKAN_RUNNER
MLIR_INCLUDE_INTEGRATION_TESTS
MLIR_RUN_AMX_TESTS
MLIR_RUN_X86_AMX_TESTS
MLIR_RUN_CUDA_TENSOR_CORE_TESTS
MLIR_RUN_X86_TESTS
MLIR_RUN_ARM_SVE_TESTS

View File

@@ -34,27 +34,27 @@ func.func @contract_vnni_f16(%A: vector<4x8x2xf16>, %B: vector<8x16x2xf16>,
// CHECK: vector.transfer_write %[[A]], %[[A_BUF]]
// CHECK: %[[A_BUF_2D:.+]] = memref.collapse_shape %[[A_BUF]]
// CHECK-SAME: {{\[}}[0], [1, 2]] : memref<4x8x2xf16> into memref<4x16xf16>
// CHECK: %[[A_TILE:.+]] = amx.tile_load %[[A_BUF_2D]]
// CHECK: %[[A_TILE:.+]] = x86.amx.tile_load %[[A_BUF_2D]]
/// Load B vector into an AMX tile
// CHECK: %[[B_BUF:.+]] = memref.alloca() : memref<8x16x2xf16>
// CHECK: vector.transfer_write %[[B]], %[[B_BUF]]
// CHECK: %[[B_BUF_2D:.+]] = memref.collapse_shape %[[B_BUF]]
// CHECK-SAME: {{\[}}[0], [1, 2]] : memref<8x16x2xf16> into memref<8x32xf16>
// CHECK: %[[B_TILE:.+]] = amx.tile_load %[[B_BUF_2D]]
// CHECK: %[[B_TILE:.+]] = x86.amx.tile_load %[[B_BUF_2D]]
/// Load C vector into an AMX tile
// CHECK: %[[C_BUF:.+]] = memref.alloca() : memref<4x16xf32>
// CHECK: vector.transfer_write %[[C]], %[[C_BUF]]
// CHECK: %[[C_TILE:.+]] = amx.tile_load %[[C_BUF]]
// CHECK: %[[C_TILE:.+]] = x86.amx.tile_load %[[C_BUF]]
/// Perform tile multiplication
// CHECK: %[[RES:.+]] = amx.tile_mulf
// CHECK: %[[RES:.+]] = x86.amx.tile_mulf
// CHECK-SAME: %[[A_TILE]], %[[B_TILE]], %[[C_TILE]]
/// Load the result back into a vector
// CHECK: %[[RES_BUF:.+]] = memref.alloca() : memref<4x16xf32>
// CHECK: amx.tile_store %[[RES_BUF]]{{.*}}, %[[RES]]
// CHECK: x86.amx.tile_store %[[RES_BUF]]{{.*}}, %[[RES]]
// CHECK: %[[RES_VEC:.+]] = vector.transfer_read %[[RES_BUF]]
// CHECK: return %[[RES_VEC]]
@@ -75,9 +75,9 @@ func.func @contract_vnni_bf16(%A: vector<4x8x2xbf16>, %B: vector<8x16x2xbf16>,
}
// CHECK-LABEL: @contract_vnni_bf16(
// CHECK-COUNT-3: amx.tile_load
// CHECK: amx.tile_mulf
// CHECK: amx.tile_store
// CHECK-COUNT-3: x86.amx.tile_load
// CHECK: x86.amx.tile_mulf
// CHECK: x86.amx.tile_store
// -----
@@ -95,9 +95,9 @@ func.func @contract_vnni_i8(%A: vector<4x16x4xi8>, %B: vector<16x8x4xi8>,
}
// CHECK-LABEL: @contract_vnni_i8(
// CHECK-COUNT-3: amx.tile_load
// CHECK: amx.tile_muli
// CHECK: amx.tile_store
// CHECK-COUNT-3: x86.amx.tile_load
// CHECK: x86.amx.tile_muli
// CHECK: x86.amx.tile_store
// -----
@@ -115,9 +115,9 @@ func.func @contract_shuffled_iterators(%A: vector<4x16x4xi8>, %B: vector<16x8x4x
}
// CHECK-LABEL: @contract_shuffled_iterators(
// CHECK-COUNT-3: amx.tile_load
// CHECK: amx.tile_muli
// CHECK: amx.tile_store
// CHECK-COUNT-3: x86.amx.tile_load
// CHECK: x86.amx.tile_muli
// CHECK: x86.amx.tile_store
// -----

View File

@@ -38,7 +38,7 @@ func.func @transfers_static_dims(%A: memref<64x32x16x2xf16>,
// CHECK-SAME: {{\[}}%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]{{\]}}
// CHECK: %[[A_PACKED_DIM_COLLAPSE:.+]] = memref.collapse_shape %[[A_SUBVIEW]]
// CHECK-SAME: {{\[}}[0], [1], [2, 3]] : memref<1x4x8x2xf16{{.*}}into memref<1x4x16xf16
// CHECK: %[[A_TILE:.+]] = amx.tile_load %[[A_PACKED_DIM_COLLAPSE]]
// CHECK: %[[A_TILE:.+]] = x86.amx.tile_load %[[A_PACKED_DIM_COLLAPSE]]
// CHECK-SAME: {{\[}}%[[C0]], %[[C0]], %[[C0]]{{\]}}
// CHECK-NOT: vector.transfer_read %[[A]]
@@ -47,25 +47,25 @@ func.func @transfers_static_dims(%A: memref<64x32x16x2xf16>,
// CHECK-SAME: {{\[}}%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]{{\]}}
// CHECK: %[[B_PACKED_DIM_COLLAPSE:.+]] = memref.collapse_shape %[[B_SUBVIEW]]
// CHECK-SAME: {{\[}}[0], [1], [2, 3]] : memref<1x8x16x2xf16{{.*}}into memref<1x8x32xf16
// CHECK: %[[B_TILE:.+]] = amx.tile_load %[[B_PACKED_DIM_COLLAPSE]]
// CHECK: %[[B_TILE:.+]] = x86.amx.tile_load %[[B_PACKED_DIM_COLLAPSE]]
// CHECK-SAME: {{\[}}%[[C0]], %[[C0]], %[[C0]]{{\]}}
// CHECK-NOT: vector.transfer_read %[[B]]
/// Load C into an AMX tile
// CHECK: %[[C_SUBVIEW:.+]] = memref.subview %[[C]]
// CHECK-SAME: {{\[}}%[[IDX]], %[[IDX]]{{\]}}
// CHECK: %[[C_TILE:.+]] = amx.tile_load %[[C_SUBVIEW]]
// CHECK: %[[C_TILE:.+]] = x86.amx.tile_load %[[C_SUBVIEW]]
// CHECK-SAME: {{\[}}%[[C0]], %[[C0]]{{\]}}
// CHECK-NOT: vector.transfer_read %[[C]]
/// Perform tile multiplication
// CHECK: %[[RES:.+]] = amx.tile_mulf
// CHECK: %[[RES:.+]] = x86.amx.tile_mulf
// CHECK-SAME: %[[A_TILE]], %[[B_TILE]], %[[C_TILE]]
/// Store the result back
// CHECK: %[[RES_SUBVIEW:.+]] = memref.subview %[[C]]
// CHECK-SAME: {{\[}}%[[IDX]], %[[IDX]]{{\]}}
// CHECK: amx.tile_store %[[RES_SUBVIEW]]{{\[}}%[[C0]], %[[C0]]{{\]}}, %[[RES]]
// CHECK: x86.amx.tile_store %[[RES_SUBVIEW]]{{\[}}%[[C0]], %[[C0]]{{\]}}, %[[RES]]
// CHECK-NOT: vector.transfer_write{{.*}}%[[C]]
// -----
@@ -130,17 +130,17 @@ func.func @transfer_read_multiple_users(%C: memref<64x64xf32>,
/// Load to AMX tile directly from buffer.
// CHECK: %[[C_SUBVIEW:.+]] = memref.subview %[[C]]
// CHECK: %[[C_TILE:.+]] = amx.tile_load %[[C_SUBVIEW]]
// CHECK: %[[C_TILE:.+]] = x86.amx.tile_load %[[C_SUBVIEW]]
/// Vector read remains to load data for the other non-AMX consumer.
// CHECK: %[[C_VEC:.+]] = vector.transfer_read %[[C]]
/// Contraction uses the directly loaded tile.
// CHECK: %[[TILE_MUL:.+]] = amx.tile_mulf{{.*}}%[[C_TILE]]
// CHECK: %[[TILE_MUL:.+]] = x86.amx.tile_mulf{{.*}}%[[C_TILE]]
/// Consumer uses original C value and the updated one after contraction.
// CHECK: %[[RES_BUF:.+]] = memref.alloca
// CHECK: amx.tile_store %[[RES_BUF]]
// CHECK: x86.amx.tile_store %[[RES_BUF]]
// CHECK: %[[RES_VEC:.+]] = vector.transfer_read %[[RES_BUF]]
// CHECK: %[[VEC_MUL:.+]] = arith.mulf %[[C_VEC]], %[[RES_VEC]]
@@ -168,7 +168,7 @@ func.func @negative_contract_multiple_users(%C: memref<64x64xf32>,
// CHECK-LABEL: @negative_contract_multiple_users(
// CHECK-SAME: %[[C:.+]]: memref<64x64xf32>
// CHECK: %[[TILE_MUL:.+]] = amx.tile_mulf
// CHECK: %[[TILE_MUL:.+]] = x86.amx.tile_mulf
// CHECK: vector.transfer_write{{.*}}%[[C]]
// -----

View File

@@ -18,7 +18,6 @@
// CHECK: builtin.module(
// CHECK-SAME: convert-vector-to-llvm{
// CHECK-SAME: enable-amx={{[aA-zZ0-9]+}}
// CHECK-SAME: enable-arm-neon={{[aA-zZ0-9]+}}
// CHECK-SAME: enable-arm-sve={{[aA-zZ0-9]+}}
// CHECK-SAME: enable-x86={{[aA-zZ0-9]+}}

View File

@@ -1,158 +0,0 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
func.func @tile_row_height() {
// expected-error@+1 {{'amx.tile_zero' op bad row height: 17}}
%0 = amx.tile_zero : !amx.tile<17x16xbf16>
return
}
// -----
func.func @tile_col_width() {
// expected-error@+1 {{'amx.tile_zero' op bad column width: 65}}
%0 = amx.tile_zero : !amx.tile<16x65xi8>
return
}
// -----
func.func @tile_element_type() {
// expected-error@+1 {{failed to verify 'elementType'}}
%0 = amx.tile_zero : !amx.tile<8x8xi16>
return
}
// -----
func.func @tile_rank() {
// expected-error@+1 {{'amx.tile_zero' op result #0 must be tile of}}
%0 = amx.tile_zero : !amx.tile<32xi8>
return
}
// -----
func.func @tile_col_4_byte_multiple() {
// expected-error@+1 {{'amx.tile_zero' op bad column width: 5}}
%0 = amx.tile_zero : !amx.tile<16x5xi8>
return
}
// -----
func.func @load_base_tile_size(%arg0: memref<?x?xf32>) {
%0 = arith.constant 0 : index
// expected-error@+1 {{'amx.tile_load' op bad column width: 68}}
%1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into !amx.tile<16x17xf32>
return
}
// -----
func.func @store_base_tile_size(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x17xf32>) {
%0 = arith.constant 0 : index
// expected-error@+1 {{'amx.tile_store' op bad column width: 68}}
amx.tile_store %arg0[%0, %0], %arg1 : memref<?x?xf32>, !amx.tile<16x17xf32>
return
}
// -----
func.func @load_base_index_size(%arg0: memref<?x?xf32>) {
%0 = arith.constant 0 : index
// expected-error@+1 {{'amx.tile_load' op requires 2 indices}}
%1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into !amx.tile<16x16xf32>
return
}
// -----
func.func @store_base_index_size(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x16xf32>) {
%0 = arith.constant 0 : index
// expected-error@+1 {{'amx.tile_store' op requires 2 indices}}
amx.tile_store %arg0[%0], %arg1 : memref<?x?xf32>, !amx.tile<16x16xf32>
return
}
// -----
func.func @load_base_rank(%arg0: memref<?xf32>) {
%0 = arith.constant 0 : index
// expected-error@+1 {{'amx.tile_load' op requires at least 2D memref}}
%1 = amx.tile_load %arg0[%0] : memref<?xf32> into !amx.tile<16x16xf32>
return
}
// -----
func.func @store_base_rank(%arg0: memref<?xf32>, %arg1: !amx.tile<16x16xf32>) {
%0 = arith.constant 0 : index
// expected-error@+1 {{'amx.tile_store' op requires at least 2D memref}}
amx.tile_store %arg0[%0], %arg1 : memref<?xf32>, !amx.tile<16x16xf32>
return
}
// -----
func.func @load_base_non_unit_stride(%arg0: memref<?x?xf32, strided<[?, ?]>>) {
%0 = arith.constant 0 : index
// expected-error@+1 {{'amx.tile_load' op requires memref with unit innermost stride}}
%1 = amx.tile_load %arg0[%0, %0]
: memref<?x?xf32, strided<[?, ?]>> into !amx.tile<16x16xf32>
return
}
// -----
func.func @store_base_non_unit_stride(%arg0: memref<?x?xf32, strided<[?, ?]>>,
%arg1: !amx.tile<16x16xf32>) {
%0 = arith.constant 0 : index
// expected-error@+1 {{'amx.tile_store' op requires memref with unit innermost stride}}
amx.tile_store %arg0[%0, %0], %arg1
: memref<?x?xf32, strided<[?, ?]>>, !amx.tile<16x16xf32>
return
}
// -----
func.func @mulf_shape() {
%0 = amx.tile_zero : !amx.tile<8x8xbf16>
%1 = amx.tile_zero : !amx.tile<8x8xbf16>
%2 = amx.tile_zero : !amx.tile<4x4xf32>
// expected-error@+1 {{'amx.tile_mulf' op bad mult shape: 4 x 4 x 4}}
%3 = amx.tile_mulf %0, %1, %2 : !amx.tile<8x8xbf16>, !amx.tile<8x8xbf16>, !amx.tile<4x4xf32>
return
}
// -----
func.func @mulf_type_combination() {
%0 = amx.tile_zero : !amx.tile<8x8xbf16>
%1 = amx.tile_zero : !amx.tile<4x8xf16>
%2 = amx.tile_zero : !amx.tile<8x4xf32>
// expected-error@+1 {{'amx.tile_mulf' op unsupported type combination}}
%3 = amx.tile_mulf %0, %1, %2 : !amx.tile<8x8xbf16>, !amx.tile<4x8xf16>, !amx.tile<8x4xf32>
return
}
// -----
func.func @muli_shape() {
%0 = amx.tile_zero : !amx.tile<8x8xi8>
%1 = amx.tile_zero : !amx.tile<8x8xi8>
%2 = amx.tile_zero : !amx.tile<4x4xi32>
// expected-error@+1 {{'amx.tile_muli' op bad mult shape: 4 x 4 x 2}}
%3 = amx.tile_muli %0, %1, %2 : !amx.tile<8x8xi8>, !amx.tile<8x8xi8>, !amx.tile<4x4xi32>
return
}
// -----
func.func @muli_type_combination() {
%0 = amx.tile_zero : !amx.tile<8x16xi8>
%1 = amx.tile_zero : !amx.tile<8x16xi32>
%2 = amx.tile_zero : !amx.tile<2x2xi32>
// expected-error@+1 {{'amx.tile_muli' op operand #1 must be tile of 8-bit signless integer values}}
%3 = amx.tile_muli %0, %1, %2 : !amx.tile<8x16xi8>, !amx.tile<8x16xi32>, !amx.tile<2x2xi32>
return
}

View File

@@ -1,77 +0,0 @@
// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
// CHECK-LABEL: tloadstore
// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}], %{{.*}} :
// CHECK-SAME: memref<?xbf16> into !amx.tile<16x32xbf16>
// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} :
// CHECK-SAME: memref<?x?xbf16> into !amx.tile<16x32xbf16>
// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] :
// CHECK-SAME: memref<?x?xbf16, strided<[64, 1]>> into !amx.tile<16x32xbf16>
// CHECK: amx.tile_store %{{.*}}[%{{.*}}], %[[z]], %{{.*}} :
// CHECK-SAME: memref<?xbf16>, !amx.tile<16x32xbf16>
// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[x]], %{{.*}} :
// CHECK-SAME: memref<?x?xbf16>, !amx.tile<16x32xbf16>
// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[y]] :
// CHECK-SAME: memref<?x?xbf16, strided<[64, 1]>>, !amx.tile<16x32xbf16>
func.func @tloadstore(%stride: index,
%arg0: memref<?xbf16>,
%arg1: memref<?x?xbf16>,
%arg2: memref<?x?xbf16, strided<[64, 1]>>) {
%0 = arith.constant 0 : index
%c64 = arith.constant 64 : index
%1 = amx.tile_load %arg0[%0], %stride : memref<?xbf16> into !amx.tile<16x32xbf16>
%2 = amx.tile_load %arg1[%0, %0], %stride : memref<?x?xbf16> into !amx.tile<16x32xbf16>
%3 = amx.tile_load %arg2[%0, %0] : memref<?x?xbf16, strided<[64, 1]>> into !amx.tile<16x32xbf16>
amx.tile_store %arg0[%0], %3, %stride : memref<?xbf16>, !amx.tile<16x32xbf16>
amx.tile_store %arg1[%0, %0], %1, %stride : memref<?x?xbf16>, !amx.tile<16x32xbf16>
amx.tile_store %arg2[%0, %0], %2 : memref<?x?xbf16, strided<[64, 1]>>, !amx.tile<16x32xbf16>
return
}
// CHECK-LABEL: tzero
// CHECK: amx.tile_zero : !amx.tile<16x16xbf16>
// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xbf16>, !amx.tile<16x16xbf16>
func.func @tzero(%arg0: memref<?x?xbf16>) {
%0 = arith.constant 0 : index
%1 = amx.tile_zero : !amx.tile<16x16xbf16>
amx.tile_store %arg0[%0, %0], %1 : memref<?x?xbf16>, !amx.tile<16x16xbf16>
return
}
// CHECK-LABEL: tmulf
// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32> into !amx.tile<16x16xf32>
// CHECK: %[[m:.*]] = amx.tile_mulf %[[x]], %[[x]], %[[z]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xf32>, !amx.tile<16x16xf32>
func.func @tmulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
%0 = arith.constant 0 : index
%1 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
%2 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !amx.tile<16x16xf32>
%3 = amx.tile_mulf %1, %1, %2 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
amx.tile_store %arg1[%0, %0], %3 : memref<?x?xf32>, !amx.tile<16x16xf32>
return
}
// CHECK-LABEL: tmuli
// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into !amx.tile<16x64xi8>
// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into !amx.tile<16x64xi8>
// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into !amx.tile<16x16xi32>
// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, !amx.tile<16x16xi32>
// Verify the parsing/printing of the sign-extension annotation.
// CHECK: amx.tile_muli %{{.*}}, %{{.*}} zext, %{{.*}}
// CHECK: amx.tile_muli %{{.*}} zext, %{{.*}}, %{{.*}}
// CHECK: amx.tile_muli %{{.*}}, %{{.*}}, %{{.*}}
func.func @tmuli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi8>, %arg2: memref<?x?xi32>) {
%0 = arith.constant 0 : index
%1 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into !amx.tile<16x64xi8>
%2 = amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into !amx.tile<16x64xi8>
%3 = amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into !amx.tile<16x16xi32>
%4 = amx.tile_muli %1 zext, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, !amx.tile<16x16xi32>
// Verify the various `zext` combinations.
%5 = amx.tile_muli %1, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
%6 = amx.tile_muli %1 zext, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
%7 = amx.tile_muli %1, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
return
}

View File

@@ -1,32 +0,0 @@
// RUN: mlir-opt %s -cse -convert-vector-to-llvm="enable-amx" | FileCheck %s
// With inclusion of memory side-effects, it is expected CSE not to fold multiple
// "tileload" and "tilezero".
// CHECK-LABEL: do_not_fold_tiles(
// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
func.func @do_not_fold_tiles(%arg0: memref<2x32x32xbf16>, %arg1: memref<2x16x32xbf16>) -> memref<16x32xf32> {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%c16 = arith.constant 16 : index
%alloca = memref.alloca() : memref<16x32xf32>
%0 = amx.tile_zero : !amx.tile<16x16xf32>
%1 = amx.tile_zero : !amx.tile<16x16xf32>
%2:2 = scf.for %arg2 = %c0 to %c2 step %c1 iter_args(%arg3 = %0, %arg4 = %1) -> (!amx.tile<16x16xf32>, !amx.tile<16x16xf32>) {
%3 = amx.tile_load %arg0[%arg2, %c0, %c0] : memref<2x32x32xbf16> into !amx.tile<16x32xbf16>
%4 = amx.tile_load %arg0[%arg2, %c16, %c0] : memref<2x32x32xbf16> into !amx.tile<16x32xbf16>
%5 = amx.tile_load %arg1[%arg2, %c0, %c0] : memref<2x16x32xbf16> into !amx.tile<16x32xbf16>
%6 = amx.tile_load %arg1[%arg2, %c0, %c0] : memref<2x16x32xbf16> into !amx.tile<16x32xbf16>
%7 = amx.tile_mulf %3, %5, %arg3 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
%8 = amx.tile_mulf %4, %6, %arg4 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
scf.yield %7, %8 : !amx.tile<16x16xf32>, !amx.tile<16x16xf32>
}
amx.tile_store %alloca[%c0, %c0], %2#0 : memref<16x32xf32>, !amx.tile<16x16xf32>
amx.tile_store %alloca[%c0, %c16], %2#1 : memref<16x32xf32>, !amx.tile<16x16xf32>
return %alloca : memref<16x32xf32>
}

View File

@@ -635,10 +635,10 @@ func.func @invalid_indexing_maps_placement_matmul(%lhs: tensor<4x1xf32>, %rhs: t
// -----
func.func @invalid_type_matmul(%arg0 : !amx.tile<16x16xbf16>)
func.func @invalid_type_matmul(%arg0 : !x86.amx.tile<16x16xbf16>)
{
// expected-error @below {{custom op 'linalg.matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got '!amx.tile<16x16xbf16>' and '!amx.tile<16x16xbf16>'}}
%0 = linalg.matmul ins(%arg0, %arg0 : !amx.tile<16x16xbf16>, !amx.tile<16x16xbf16>) outs(%arg0 : !amx.tile<16x16xbf16>) -> !amx.tile<16x16xbf16>
// expected-error @below {{custom op 'linalg.matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got '!x86.amx.tile<16x16xbf16>' and '!x86.amx.tile<16x16xbf16>'}}
%0 = linalg.matmul ins(%arg0, %arg0 : !x86.amx.tile<16x16xbf16>, !x86.amx.tile<16x16xbf16>) outs(%arg0 : !x86.amx.tile<16x16xbf16>) -> !x86.amx.tile<16x16xbf16>
return
}
@@ -1582,10 +1582,10 @@ func.func @invalid_C_map_result_dim_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1
// -----
func.func @invalid_type_batch_matmul(%arg0 : !amx.tile<16x16xbf16>)
func.func @invalid_type_batch_matmul(%arg0 : !x86.amx.tile<16x16xbf16>)
{
// expected-error @below {{custom op 'linalg.batch_matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got '!amx.tile<16x16xbf16>' and '!amx.tile<16x16xbf16>'}}
%0 = linalg.batch_matmul ins(%arg0, %arg0 : !amx.tile<16x16xbf16>, !amx.tile<16x16xbf16>) outs(%arg0 : !amx.tile<16x16xbf16>) -> !amx.tile<16x16xbf16>
// expected-error @below {{custom op 'linalg.batch_matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got '!x86.amx.tile<16x16xbf16>' and '!x86.amx.tile<16x16xbf16>'}}
%0 = linalg.batch_matmul ins(%arg0, %arg0 : !x86.amx.tile<16x16xbf16>, !x86.amx.tile<16x16xbf16>) outs(%arg0 : !x86.amx.tile<16x16xbf16>) -> !x86.amx.tile<16x16xbf16>
return
}
@@ -1790,10 +1790,10 @@ func.func @invalid_C_map_result_dim(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>
// -----
func.func @batch_reduce_matmul_invalid_type(%arg0 : !amx.tile<16x16xbf16>)
func.func @batch_reduce_matmul_invalid_type(%arg0 : !x86.amx.tile<16x16xbf16>)
{
// expected-error @below {{custom op 'linalg.batch_reduce_matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got '!amx.tile<16x16xbf16>' and '!amx.tile<16x16xbf16>'}}
%0 = linalg.batch_reduce_matmul ins(%arg0, %arg0 : !amx.tile<16x16xbf16>, !amx.tile<16x16xbf16>) outs(%arg0 : !amx.tile<16x16xbf16>) -> !amx.tile<16x16xbf16>
// expected-error @below {{custom op 'linalg.batch_reduce_matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got '!x86.amx.tile<16x16xbf16>' and '!x86.amx.tile<16x16xbf16>'}}
%0 = linalg.batch_reduce_matmul ins(%arg0, %arg0 : !x86.amx.tile<16x16xbf16>, !x86.amx.tile<16x16xbf16>) outs(%arg0 : !x86.amx.tile<16x16xbf16>) -> !x86.amx.tile<16x16xbf16>
return
}

View File

@@ -0,0 +1,158 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
func.func @tile_row_height() {
// expected-error@+1 {{'x86.amx.tile_zero' op bad row height: 17}}
%0 = x86.amx.tile_zero : !x86.amx.tile<17x16xbf16>
return
}
// -----
func.func @tile_col_width() {
// expected-error@+1 {{'x86.amx.tile_zero' op bad column width: 65}}
%0 = x86.amx.tile_zero : !x86.amx.tile<16x65xi8>
return
}
// -----
func.func @tile_element_type() {
// expected-error@+1 {{failed to verify 'elementType'}}
%0 = x86.amx.tile_zero : !x86.amx.tile<8x8xi16>
return
}
// -----
func.func @tile_rank() {
// expected-error@+1 {{'x86.amx.tile_zero' op result #0 must be tile of}}
%0 = x86.amx.tile_zero : !x86.amx.tile<32xi8>
return
}
// -----
func.func @tile_col_4_byte_multiple() {
// expected-error@+1 {{'x86.amx.tile_zero' op bad column width: 5}}
%0 = x86.amx.tile_zero : !x86.amx.tile<16x5xi8>
return
}
// -----
func.func @load_base_tile_size(%arg0: memref<?x?xf32>) {
%0 = arith.constant 0 : index
// expected-error@+1 {{'x86.amx.tile_load' op bad column width: 68}}
%1 = x86.amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into !x86.amx.tile<16x17xf32>
return
}
// -----
func.func @store_base_tile_size(%arg0: memref<?x?xf32>, %arg1: !x86.amx.tile<16x17xf32>) {
%0 = arith.constant 0 : index
// expected-error@+1 {{'x86.amx.tile_store' op bad column width: 68}}
x86.amx.tile_store %arg0[%0, %0], %arg1 : memref<?x?xf32>, !x86.amx.tile<16x17xf32>
return
}
// -----
func.func @load_base_index_size(%arg0: memref<?x?xf32>) {
%0 = arith.constant 0 : index
// expected-error@+1 {{'x86.amx.tile_load' op requires 2 indices}}
%1 = x86.amx.tile_load %arg0[%0] : memref<?x?xf32> into !x86.amx.tile<16x16xf32>
return
}
// -----
func.func @store_base_index_size(%arg0: memref<?x?xf32>, %arg1: !x86.amx.tile<16x16xf32>) {
%0 = arith.constant 0 : index
// expected-error@+1 {{'x86.amx.tile_store' op requires 2 indices}}
x86.amx.tile_store %arg0[%0], %arg1 : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
return
}
// -----
func.func @load_base_rank(%arg0: memref<?xf32>) {
%0 = arith.constant 0 : index
// expected-error@+1 {{'x86.amx.tile_load' op requires at least 2D memref}}
%1 = x86.amx.tile_load %arg0[%0] : memref<?xf32> into !x86.amx.tile<16x16xf32>
return
}
// -----
func.func @store_base_rank(%arg0: memref<?xf32>, %arg1: !x86.amx.tile<16x16xf32>) {
%0 = arith.constant 0 : index
// expected-error@+1 {{'x86.amx.tile_store' op requires at least 2D memref}}
x86.amx.tile_store %arg0[%0], %arg1 : memref<?xf32>, !x86.amx.tile<16x16xf32>
return
}
// -----
func.func @load_base_non_unit_stride(%arg0: memref<?x?xf32, strided<[?, ?]>>) {
%0 = arith.constant 0 : index
// expected-error@+1 {{'x86.amx.tile_load' op requires memref with unit innermost stride}}
%1 = x86.amx.tile_load %arg0[%0, %0]
: memref<?x?xf32, strided<[?, ?]>> into !x86.amx.tile<16x16xf32>
return
}
// -----
func.func @store_base_non_unit_stride(%arg0: memref<?x?xf32, strided<[?, ?]>>,
%arg1: !x86.amx.tile<16x16xf32>) {
%0 = arith.constant 0 : index
// expected-error@+1 {{'x86.amx.tile_store' op requires memref with unit innermost stride}}
x86.amx.tile_store %arg0[%0, %0], %arg1
: memref<?x?xf32, strided<[?, ?]>>, !x86.amx.tile<16x16xf32>
return
}
// -----
func.func @mulf_shape() {
%0 = x86.amx.tile_zero : !x86.amx.tile<8x8xbf16>
%1 = x86.amx.tile_zero : !x86.amx.tile<8x8xbf16>
%2 = x86.amx.tile_zero : !x86.amx.tile<4x4xf32>
// expected-error@+1 {{'x86.amx.tile_mulf' op bad mult shape: 4 x 4 x 4}}
%3 = x86.amx.tile_mulf %0, %1, %2 : !x86.amx.tile<8x8xbf16>, !x86.amx.tile<8x8xbf16>, !x86.amx.tile<4x4xf32>
return
}
// -----
func.func @mulf_type_combination() {
%0 = x86.amx.tile_zero : !x86.amx.tile<8x8xbf16>
%1 = x86.amx.tile_zero : !x86.amx.tile<4x8xf16>
%2 = x86.amx.tile_zero : !x86.amx.tile<8x4xf32>
// expected-error@+1 {{'x86.amx.tile_mulf' op unsupported type combination}}
%3 = x86.amx.tile_mulf %0, %1, %2 : !x86.amx.tile<8x8xbf16>, !x86.amx.tile<4x8xf16>, !x86.amx.tile<8x4xf32>
return
}
// -----
func.func @muli_shape() {
%0 = x86.amx.tile_zero : !x86.amx.tile<8x8xi8>
%1 = x86.amx.tile_zero : !x86.amx.tile<8x8xi8>
%2 = x86.amx.tile_zero : !x86.amx.tile<4x4xi32>
// expected-error@+1 {{'x86.amx.tile_muli' op bad mult shape: 4 x 4 x 2}}
%3 = x86.amx.tile_muli %0, %1, %2 : !x86.amx.tile<8x8xi8>, !x86.amx.tile<8x8xi8>, !x86.amx.tile<4x4xi32>
return
}
// -----
func.func @muli_type_combination() {
%0 = x86.amx.tile_zero : !x86.amx.tile<8x16xi8>
%1 = x86.amx.tile_zero : !x86.amx.tile<8x16xi32>
%2 = x86.amx.tile_zero : !x86.amx.tile<2x2xi32>
// expected-error@+1 {{'x86.amx.tile_muli' op operand #1 must be tile of 8-bit signless integer values}}
%3 = x86.amx.tile_muli %0, %1, %2 : !x86.amx.tile<8x16xi8>, !x86.amx.tile<8x16xi32>, !x86.amx.tile<2x2xi32>
return
}

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-amx" | mlir-opt | FileCheck %s
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86" | mlir-opt | FileCheck %s
// CHECK-LABEL: muli(
// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
@@ -14,17 +14,17 @@
// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
func.func @muli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi32>) {
%0 = arith.constant 0 : index
%1 = amx.tile_zero : !amx.tile<16x64xi8>
%2 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into !amx.tile<16x64xi8>
%3 = amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into !amx.tile<16x16xi32>
%4 = amx.tile_muli %1 zext, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, !amx.tile<16x16xi32>
%5 = amx.tile_muli %1, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, !amx.tile<16x16xi32>
%6 = amx.tile_muli %1 zext, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, !amx.tile<16x16xi32>
%7 = amx.tile_muli %1, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
amx.tile_store %arg1[%0, %0], %7 : memref<?x?xi32>, !amx.tile<16x16xi32>
%1 = x86.amx.tile_zero : !x86.amx.tile<16x64xi8>
%2 = x86.amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into !x86.amx.tile<16x64xi8>
%3 = x86.amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into !x86.amx.tile<16x16xi32>
%4 = x86.amx.tile_muli %1 zext, %2 zext, %3 : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
x86.amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, !x86.amx.tile<16x16xi32>
%5 = x86.amx.tile_muli %1, %2, %3 : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
x86.amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, !x86.amx.tile<16x16xi32>
%6 = x86.amx.tile_muli %1 zext, %2, %3 : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
x86.amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, !x86.amx.tile<16x16xi32>
%7 = x86.amx.tile_muli %1, %2 zext, %3 : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
x86.amx.tile_store %arg1[%0, %0], %7 : memref<?x?xi32>, !x86.amx.tile<16x16xi32>
return
}
@@ -36,11 +36,11 @@ func.func @muli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi32>) {
// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
func.func @mulbf16(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
%0 = arith.constant 0 : index
%1 = amx.tile_zero : !amx.tile<16x32xbf16>
%2 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
%3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !amx.tile<16x16xf32>
%4 = amx.tile_mulf %1, %2, %3 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !amx.tile<16x16xf32>
%1 = x86.amx.tile_zero : !x86.amx.tile<16x32xbf16>
%2 = x86.amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into !x86.amx.tile<16x32xbf16>
%3 = x86.amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !x86.amx.tile<16x16xf32>
%4 = x86.amx.tile_mulf %1, %2, %3 : !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x16xf32>
x86.amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
return
}
@@ -52,11 +52,11 @@ func.func @mulbf16(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
func.func @mulfp16(%arg0: memref<?x?xf16>, %arg1: memref<?x?xf32>) {
%0 = arith.constant 0 : index
%1 = amx.tile_zero : !amx.tile<16x32xf16>
%2 = amx.tile_load %arg0[%0, %0] : memref<?x?xf16> into !amx.tile<16x32xf16>
%3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !amx.tile<16x16xf32>
%4 = amx.tile_mulf %1, %2, %3 : !amx.tile<16x32xf16>, !amx.tile<16x32xf16>, !amx.tile<16x16xf32>
amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !amx.tile<16x16xf32>
%1 = x86.amx.tile_zero : !x86.amx.tile<16x32xf16>
%2 = x86.amx.tile_load %arg0[%0, %0] : memref<?x?xf16> into !x86.amx.tile<16x32xf16>
%3 = x86.amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !x86.amx.tile<16x16xf32>
%4 = x86.amx.tile_mulf %1, %2, %3 : !x86.amx.tile<16x32xf16>, !x86.amx.tile<16x32xf16>, !x86.amx.tile<16x16xf32>
x86.amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
return
}
@@ -84,12 +84,12 @@ func.func @strides_implicit(%arg0: memref<16x32xi8>,
%arg1: memref<32x32xbf16, strided<[64, 1]>>,
%arg2: memref<16x32xf32, strided<[?, 1]>>) {
%0 = arith.constant 0 : index
%1 = amx.tile_load %arg0[%0, %0] : memref<16x32xi8> into !amx.tile<16x32xi8>
%2 = amx.tile_load %arg1[%0, %0] : memref<32x32xbf16, strided<[64, 1]>> into !amx.tile<16x32xbf16>
%3 = amx.tile_load %arg2[%0, %0] : memref<16x32xf32, strided<[?, 1]>> into !amx.tile<16x16xf32>
amx.tile_store %arg0[%0, %0], %1 : memref<16x32xi8>, !amx.tile<16x32xi8>
amx.tile_store %arg1[%0, %0], %2 : memref<32x32xbf16, strided<[64, 1]>>, !amx.tile<16x32xbf16>
amx.tile_store %arg2[%0, %0], %3 : memref<16x32xf32, strided<[?, 1]>>, !amx.tile<16x16xf32>
%1 = x86.amx.tile_load %arg0[%0, %0] : memref<16x32xi8> into !x86.amx.tile<16x32xi8>
%2 = x86.amx.tile_load %arg1[%0, %0] : memref<32x32xbf16, strided<[64, 1]>> into !x86.amx.tile<16x32xbf16>
%3 = x86.amx.tile_load %arg2[%0, %0] : memref<16x32xf32, strided<[?, 1]>> into !x86.amx.tile<16x16xf32>
x86.amx.tile_store %arg0[%0, %0], %1 : memref<16x32xi8>, !x86.amx.tile<16x32xi8>
x86.amx.tile_store %arg1[%0, %0], %2 : memref<32x32xbf16, strided<[64, 1]>>, !x86.amx.tile<16x32xbf16>
x86.amx.tile_store %arg2[%0, %0], %3 : memref<16x32xf32, strided<[?, 1]>>, !x86.amx.tile<16x16xf32>
return
}
@@ -123,11 +123,11 @@ func.func @strides_explicit(%stride: index,
%arg2: memref<32x32xf32, strided<[64, 1]>>) {
%0 = arith.constant 0 : index
%c64 = arith.constant 64 : index
%1 = amx.tile_load %arg0[%0], %stride : memref<?xi8> into !amx.tile<16x32xi8>
%2 = amx.tile_load %arg1[%0, %0], %stride : memref<16x32xbf16> into !amx.tile<16x32xbf16>
%3 = amx.tile_load %arg2[%0, %0], %c64 : memref<32x32xf32, strided<[64, 1]>> into !amx.tile<16x16xf32>
amx.tile_store %arg0[%0], %1, %stride : memref<?xi8>, !amx.tile<16x32xi8>
amx.tile_store %arg1[%0, %0], %2, %stride : memref<16x32xbf16>, !amx.tile<16x32xbf16>
amx.tile_store %arg2[%0, %0], %3, %c64 : memref<32x32xf32, strided<[64, 1]>>, !amx.tile<16x16xf32>
%1 = x86.amx.tile_load %arg0[%0], %stride : memref<?xi8> into !x86.amx.tile<16x32xi8>
%2 = x86.amx.tile_load %arg1[%0, %0], %stride : memref<16x32xbf16> into !x86.amx.tile<16x32xbf16>
%3 = x86.amx.tile_load %arg2[%0, %0], %c64 : memref<32x32xf32, strided<[64, 1]>> into !x86.amx.tile<16x16xf32>
x86.amx.tile_store %arg0[%0], %1, %stride : memref<?xi8>, !x86.amx.tile<16x32xi8>
x86.amx.tile_store %arg1[%0, %0], %2, %stride : memref<16x32xbf16>, !x86.amx.tile<16x32xbf16>
x86.amx.tile_store %arg2[%0, %0], %3, %c64 : memref<32x32xf32, strided<[64, 1]>>, !x86.amx.tile<16x16xf32>
return
}

View File

@@ -0,0 +1,77 @@
// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
// CHECK-LABEL: tloadstore
// CHECK: %[[x:.*]] = x86.amx.tile_load %{{.*}}[%{{.*}}], %{{.*}} :
// CHECK-SAME: memref<?xbf16> into !x86.amx.tile<16x32xbf16>
// CHECK: %[[y:.*]] = x86.amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} :
// CHECK-SAME: memref<?x?xbf16> into !x86.amx.tile<16x32xbf16>
// CHECK: %[[z:.*]] = x86.amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] :
// CHECK-SAME: memref<?x?xbf16, strided<[64, 1]>> into !x86.amx.tile<16x32xbf16>
// CHECK: x86.amx.tile_store %{{.*}}[%{{.*}}], %[[z]], %{{.*}} :
// CHECK-SAME: memref<?xbf16>, !x86.amx.tile<16x32xbf16>
// CHECK: x86.amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[x]], %{{.*}} :
// CHECK-SAME: memref<?x?xbf16>, !x86.amx.tile<16x32xbf16>
// CHECK: x86.amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[y]] :
// CHECK-SAME: memref<?x?xbf16, strided<[64, 1]>>, !x86.amx.tile<16x32xbf16>
func.func @tloadstore(%stride: index,
%arg0: memref<?xbf16>,
%arg1: memref<?x?xbf16>,
%arg2: memref<?x?xbf16, strided<[64, 1]>>) {
%0 = arith.constant 0 : index
%c64 = arith.constant 64 : index
%1 = x86.amx.tile_load %arg0[%0], %stride : memref<?xbf16> into !x86.amx.tile<16x32xbf16>
%2 = x86.amx.tile_load %arg1[%0, %0], %stride : memref<?x?xbf16> into !x86.amx.tile<16x32xbf16>
%3 = x86.amx.tile_load %arg2[%0, %0] : memref<?x?xbf16, strided<[64, 1]>> into !x86.amx.tile<16x32xbf16>
x86.amx.tile_store %arg0[%0], %3, %stride : memref<?xbf16>, !x86.amx.tile<16x32xbf16>
x86.amx.tile_store %arg1[%0, %0], %1, %stride : memref<?x?xbf16>, !x86.amx.tile<16x32xbf16>
x86.amx.tile_store %arg2[%0, %0], %2 : memref<?x?xbf16, strided<[64, 1]>>, !x86.amx.tile<16x32xbf16>
return
}
// CHECK-LABEL: tzero
// CHECK: x86.amx.tile_zero : !x86.amx.tile<16x16xbf16>
// CHECK: x86.amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xbf16>, !x86.amx.tile<16x16xbf16>
func.func @tzero(%arg0: memref<?x?xbf16>) {
%0 = arith.constant 0 : index
%1 = x86.amx.tile_zero : !x86.amx.tile<16x16xbf16>
x86.amx.tile_store %arg0[%0, %0], %1 : memref<?x?xbf16>, !x86.amx.tile<16x16xbf16>
return
}
// CHECK-LABEL: tmulf
// CHECK: %[[x:.*]] = x86.amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xbf16> into !x86.amx.tile<16x32xbf16>
// CHECK: %[[z:.*]] = x86.amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32> into !x86.amx.tile<16x16xf32>
// CHECK: %[[m:.*]] = x86.amx.tile_mulf %[[x]], %[[x]], %[[z]] : !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x16xf32>
// CHECK: x86.amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
func.func @tmulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
%0 = arith.constant 0 : index
%1 = x86.amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into !x86.amx.tile<16x32xbf16>
%2 = x86.amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !x86.amx.tile<16x16xf32>
%3 = x86.amx.tile_mulf %1, %1, %2 : !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x16xf32>
x86.amx.tile_store %arg1[%0, %0], %3 : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
return
}
// CHECK-LABEL: tmuli
// CHECK: %[[x:.*]] = x86.amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into !x86.amx.tile<16x64xi8>
// CHECK: %[[y:.*]] = x86.amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into !x86.amx.tile<16x64xi8>
// CHECK: %[[z:.*]] = x86.amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into !x86.amx.tile<16x16xi32>
// CHECK: %[[m:.*]] = x86.amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
// CHECK: x86.amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, !x86.amx.tile<16x16xi32>
// Verify the parsing/printing of the sign-extension annotation.
// CHECK: x86.amx.tile_muli %{{.*}}, %{{.*}} zext, %{{.*}}
// CHECK: x86.amx.tile_muli %{{.*}} zext, %{{.*}}, %{{.*}}
// CHECK: x86.amx.tile_muli %{{.*}}, %{{.*}}, %{{.*}}
func.func @tmuli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi8>, %arg2: memref<?x?xi32>) {
%0 = arith.constant 0 : index
%1 = x86.amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into !x86.amx.tile<16x64xi8>
%2 = x86.amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into !x86.amx.tile<16x64xi8>
%3 = x86.amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into !x86.amx.tile<16x16xi32>
%4 = x86.amx.tile_muli %1 zext, %2 zext, %3 : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
x86.amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, !x86.amx.tile<16x16xi32>
// Verify the various `zext` combinations.
%5 = x86.amx.tile_muli %1, %2 zext, %3 : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
%6 = x86.amx.tile_muli %1 zext, %2, %3 : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
%7 = x86.amx.tile_muli %1, %2, %3 : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
return
}

View File

@@ -0,0 +1,32 @@
// RUN: mlir-opt %s -cse -convert-vector-to-llvm="enable-x86" | FileCheck %s
// Since the tile operations have memory side-effects, CSE is expected not to
// fold the repeated "tileload" and "tilezero" calls.
// CHECK-LABEL: do_not_fold_tiles(
// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
func.func @do_not_fold_tiles(%arg0: memref<2x32x32xbf16>, %arg1: memref<2x16x32xbf16>) -> memref<16x32xf32> {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%c16 = arith.constant 16 : index
%alloca = memref.alloca() : memref<16x32xf32>
%0 = x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
%1 = x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
%2:2 = scf.for %arg2 = %c0 to %c2 step %c1 iter_args(%arg3 = %0, %arg4 = %1) -> (!x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>) {
%3 = x86.amx.tile_load %arg0[%arg2, %c0, %c0] : memref<2x32x32xbf16> into !x86.amx.tile<16x32xbf16>
%4 = x86.amx.tile_load %arg0[%arg2, %c16, %c0] : memref<2x32x32xbf16> into !x86.amx.tile<16x32xbf16>
%5 = x86.amx.tile_load %arg1[%arg2, %c0, %c0] : memref<2x16x32xbf16> into !x86.amx.tile<16x32xbf16>
%6 = x86.amx.tile_load %arg1[%arg2, %c0, %c0] : memref<2x16x32xbf16> into !x86.amx.tile<16x32xbf16>
%7 = x86.amx.tile_mulf %3, %5, %arg3 : !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x16xf32>
%8 = x86.amx.tile_mulf %4, %6, %arg4 : !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x16xf32>
scf.yield %7, %8 : !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>
}
x86.amx.tile_store %alloca[%c0, %c0], %2#0 : memref<16x32xf32>, !x86.amx.tile<16x16xf32>
x86.amx.tile_store %alloca[%c0, %c16], %2#1 : memref<16x32xf32>, !x86.amx.tile<16x16xf32>
return %alloca : memref<16x32xf32>
}
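
// For contrast, a minimal sketch of the fold that this modeling forbids.
// The function below is hypothetical and not part of the test: collapsing the
// two zero-initializations would leave both accumulator chains sharing one
// tile value.
func.func @folded_tiles_sketch(%buf: memref<16x32xf32>) {
  %c0 = arith.constant 0 : index
  %c16 = arith.constant 16 : index
  // A single zeroed tile reused in both stores; CSE must not rewrite the
  // test above into this form, as tile_zero carries a memory side effect.
  %z = x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
  x86.amx.tile_store %buf[%c0, %c0], %z : memref<16x32xf32>, !x86.amx.tile<16x16xf32>
  x86.amx.tile_store %buf[%c0, %c16], %z : memref<16x32xf32>, !x86.amx.tile<16x16xf32>
  return
}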

View File

@@ -1,7 +1,7 @@
import sys
# AMX tests must be enabled via a build flag.
if not config.mlir_run_amx_tests:
if not config.mlir_run_x86_amx_tests:
config.unsupported = True
# No JIT on win32.
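
The renamed lit feature is wired up in the site configuration later in this
diff, which substitutes the matching CMake option:
'config.mlir_run_x86_amx_tests = @MLIR_RUN_X86_AMX_TESTS@'.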

View File

@@ -1,6 +1,6 @@
// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine \
// RUN: -one-shot-bufferize="bufferize-function-boundaries" \
// RUN: -convert-scf-to-cf -convert-vector-to-llvm="enable-amx" \
// RUN: -convert-scf-to-cf -convert-vector-to-llvm="enable-x86" \
// RUN: -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
// RUN: mlir-translate -mlir-to-llvmir | \
// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" \
@@ -14,11 +14,11 @@ func.func @kernel(%arg0: memref<16x32xbf16>,
%arg1: memref<16x32xbf16>,
%arg2: memref<16x16xf32>) {
%0 = arith.constant 0 : index
%1 = amx.tile_load %arg0[%0, %0] : memref<16x32xbf16> into vector<16x32xbf16>
%2 = amx.tile_load %arg1[%0, %0] : memref<16x32xbf16> into vector<16x32xbf16>
%3 = amx.tile_zero : vector<16x16xf32>
%4 = amx.tile_mulf %1, %2, %3 : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
amx.tile_store %arg2[%0, %0], %4 : memref<16x16xf32>, vector<16x16xf32>
%1 = x86.amx.tile_load %arg0[%0, %0] : memref<16x32xbf16> into vector<16x32xbf16>
%2 = x86.amx.tile_load %arg1[%0, %0] : memref<16x32xbf16> into vector<16x32xbf16>
%3 = x86.amx.tile_zero : vector<16x16xf32>
%4 = x86.amx.tile_mulf %1, %2, %3 : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
x86.amx.tile_store %arg2[%0, %0], %4 : memref<16x16xf32>, vector<16x16xf32>
return
}
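
Across these integration tests only the conversion flag changes, 'enable-amx'
becoming 'enable-x86'; the '--mattr="+amx-tile,+amx-int8,+amx-bf16"' flags that
request the AMX target features from the JIT stay the same, since the
generated LLVM intrinsics are unchanged.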

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-amx" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-x86" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
// RUN: mlir-translate -mlir-to-llvmir | \
// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_c_runner_utils | \
// RUN: FileCheck %s
@@ -10,11 +10,11 @@ func.func @kernel1(%arg0: memref<2x4xbf16>,
%arg1: memref<2x4xbf16>,
%arg2: memref<2x2xf32>) {
%0 = arith.constant 0 : index
%1 = amx.tile_load %arg0[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16>
%2 = amx.tile_load %arg1[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16>
%3 = amx.tile_zero : vector<2x2xf32>
%4 = amx.tile_mulf %1, %2, %3 : vector<2x4xbf16>, vector<2x4xbf16>, vector<2x2xf32>
amx.tile_store %arg2[%0, %0], %4 : memref<2x2xf32>, vector<2x2xf32>
%1 = x86.amx.tile_load %arg0[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16>
%2 = x86.amx.tile_load %arg1[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16>
%3 = x86.amx.tile_zero : vector<2x2xf32>
%4 = x86.amx.tile_mulf %1, %2, %3 : vector<2x4xbf16>, vector<2x4xbf16>, vector<2x2xf32>
x86.amx.tile_store %arg2[%0, %0], %4 : memref<2x2xf32>, vector<2x2xf32>
return
}
@@ -23,11 +23,11 @@ func.func @kernel2(%arg0: memref<2x4xbf16>,
%arg1: memref<2x4xbf16>,
%arg2: memref<2x2xf32>) {
%0 = arith.constant 0 : index
%1 = amx.tile_load %arg0[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16>
%2 = amx.tile_load %arg1[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16>
%3 = amx.tile_load %arg2[%0, %0] : memref<2x2xf32> into vector<2x2xf32>
%4 = amx.tile_mulf %1, %2, %3 : vector<2x4xbf16>, vector<2x4xbf16>, vector<2x2xf32>
amx.tile_store %arg2[%0, %0], %4 : memref<2x2xf32>, vector<2x2xf32>
%1 = x86.amx.tile_load %arg0[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16>
%2 = x86.amx.tile_load %arg1[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16>
%3 = x86.amx.tile_load %arg2[%0, %0] : memref<2x2xf32> into vector<2x2xf32>
%4 = x86.amx.tile_mulf %1, %2, %3 : vector<2x4xbf16>, vector<2x4xbf16>, vector<2x2xf32>
x86.amx.tile_store %arg2[%0, %0], %4 : memref<2x2xf32>, vector<2x2xf32>
return
}

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-amx" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-x86" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
// RUN: mlir-translate -mlir-to-llvmir | \
// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_c_runner_utils | \
// RUN: FileCheck %s
@@ -21,11 +21,11 @@ func.func @kernel1(%arg0: memref<16x16xi8>,
%arg1: memref<4x16xi8>,
%arg2: memref<16x4xi32>) {
%0 = arith.constant 0 : index
%1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8>
%2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8>
%3 = amx.tile_zero : vector<16x4xi32>
%4 = amx.tile_muli %1, %2, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
%1 = x86.amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8>
%2 = x86.amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8>
%3 = x86.amx.tile_zero : vector<16x4xi32>
%4 = x86.amx.tile_muli %1, %2, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
x86.amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
return
}
@@ -33,11 +33,11 @@ func.func @kernel2(%arg0: memref<16x16xi8>,
%arg1: memref<4x16xi8>,
%arg2: memref<16x4xi32>) {
%0 = arith.constant 0 : index
%1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8>
%2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8>
%3 = amx.tile_zero : vector<16x4xi32>
%4 = amx.tile_muli %1, %2 zext, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
%1 = x86.amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8>
%2 = x86.amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8>
%3 = x86.amx.tile_zero : vector<16x4xi32>
%4 = x86.amx.tile_muli %1, %2 zext, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
x86.amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
return
}
@@ -45,11 +45,11 @@ func.func @kernel3(%arg0: memref<16x16xi8>,
%arg1: memref<4x16xi8>,
%arg2: memref<16x4xi32>) {
%0 = arith.constant 0 : index
%1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8>
%2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8>
%3 = amx.tile_zero : vector<16x4xi32>
%4 = amx.tile_muli %1 zext, %2, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
%1 = x86.amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8>
%2 = x86.amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8>
%3 = x86.amx.tile_zero : vector<16x4xi32>
%4 = x86.amx.tile_muli %1 zext, %2, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
x86.amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
return
}
@@ -57,11 +57,11 @@ func.func @kernel4(%arg0: memref<16x16xi8>,
%arg1: memref<4x16xi8>,
%arg2: memref<16x4xi32>) {
%0 = arith.constant 0 : index
%1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8>
%2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8>
%3 = amx.tile_zero : vector<16x4xi32>
%4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
%1 = x86.amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8>
%2 = x86.amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8>
%3 = x86.amx.tile_zero : vector<16x4xi32>
%4 = x86.amx.tile_muli %1 zext, %2 zext, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
x86.amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
return
}

View File

@@ -1,7 +1,7 @@
// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine \
// RUN: -one-shot-bufferize="bufferize-function-boundaries" \
// RUN: -convert-scf-to-cf \
// RUN: -convert-vector-to-llvm="enable-amx" \
// RUN: -convert-vector-to-llvm="enable-x86" \
// RUN: -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
// RUN: mlir-translate -mlir-to-llvmir | \
// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" \
@@ -15,11 +15,11 @@ func.func @kernel(%arg0: memref<16x64xi8>,
%arg1: memref<16x64xi8>,
%arg2: memref<16x16xi32>) {
%0 = arith.constant 0 : index
%1 = amx.tile_load %arg0[%0, %0] : memref<16x64xi8> into vector<16x64xi8>
%2 = amx.tile_load %arg1[%0, %0] : memref<16x64xi8> into vector<16x64xi8>
%3 = amx.tile_zero : vector<16x16xi32>
%4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
amx.tile_store %arg2[%0, %0], %4 : memref<16x16xi32>, vector<16x16xi32>
%1 = x86.amx.tile_load %arg0[%0, %0] : memref<16x64xi8> into vector<16x64xi8>
%2 = x86.amx.tile_load %arg1[%0, %0] : memref<16x64xi8> into vector<16x64xi8>
%3 = x86.amx.tile_zero : vector<16x16xi32>
%4 = x86.amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
x86.amx.tile_store %arg2[%0, %0], %4 : memref<16x16xi32>, vector<16x16xi32>
return
}

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-amx" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-x86" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
// RUN: mlir-translate -mlir-to-llvmir | \
// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_c_runner_utils | \
// RUN: FileCheck %s
@@ -10,11 +10,11 @@ func.func @kernel1(%arg0: memref<2x8xi8>,
%arg1: memref<2x8xi8>,
%arg2: memref<2x2xi32>) {
%0 = arith.constant 0 : index
%1 = amx.tile_load %arg0[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
%2 = amx.tile_load %arg1[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
%3 = amx.tile_zero : vector<2x2xi32>
%4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
amx.tile_store %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32>
%1 = x86.amx.tile_load %arg0[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
%2 = x86.amx.tile_load %arg1[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
%3 = x86.amx.tile_zero : vector<2x2xi32>
%4 = x86.amx.tile_muli %1 zext, %2 zext, %3 : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
x86.amx.tile_store %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32>
return
}
@@ -23,11 +23,11 @@ func.func @kernel2(%arg0: memref<2x8xi8>,
%arg1: memref<2x8xi8>,
%arg2: memref<2x2xi32>) {
%0 = arith.constant 0 : index
%1 = amx.tile_load %arg0[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
%2 = amx.tile_load %arg1[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
%3 = amx.tile_load %arg2[%0, %0] : memref<2x2xi32> into vector<2x2xi32>
%4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
amx.tile_store %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32>
%1 = x86.amx.tile_load %arg0[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
%2 = x86.amx.tile_load %arg1[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
%3 = x86.amx.tile_load %arg2[%0, %0] : memref<2x2xi32> into vector<2x2xi32>
%4 = x86.amx.tile_muli %1 zext, %2 zext, %3 : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
x86.amx.tile_store %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32>
return
}

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-amx" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-x86" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
// RUN: mlir-translate -mlir-to-llvmir | \
// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_c_runner_utils | \
// RUN: FileCheck %s
@@ -25,8 +25,8 @@ func.func @kernel(%arg0: memref<4x32xf32>) {
%c32 = arith.constant 32 : index
scf.for %i = %c0 to %c4 step %c2 {
scf.for %j = %c0 to %c32 step %c16 {
%0 = amx.tile_zero : vector<2x16xf32>
amx.tile_store %arg0[%i, %j], %0 : memref<4x32xf32>, vector<2x16xf32>
%0 = x86.amx.tile_zero : vector<2x16xf32>
x86.amx.tile_store %arg0[%i, %j], %0 : memref<4x32xf32>, vector<2x16xf32>
func.call @print(%arg0) : (memref<4x32xf32>) -> ()
}
}

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-amx" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-x86" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
// RUN: mlir-translate -mlir-to-llvmir | \
// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_c_runner_utils | \
// RUN: FileCheck %s
@@ -6,8 +6,8 @@
// Note: To run this test, your CPU must support AMX.
func.func @tilezero(%arg0: memref<?x?xi32>, %i: index, %j: index) {
%1 = amx.tile_zero : vector<16x16xi32>
amx.tile_store %arg0[%i, %j], %1 : memref<?x?xi32>, vector<16x16xi32>
%1 = x86.amx.tile_zero : vector<16x16xi32>
x86.amx.tile_store %arg0[%i, %j], %1 : memref<?x?xi32>, vector<16x16xi32>
return
}

View File

@@ -5,14 +5,15 @@
func.func @entry() -> i32 {
%i0 = arith.constant 0 : i32
%i4 = arith.constant 4 : i32
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
%a = arith.constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]> : vector<8xf32>
%b = arith.constant dense<[9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : vector<8xf32>
%r = x86.avx.intr.dot %a, %b : vector<8xf32>
%1 = vector.extract %r[%i0] : f32 from vector<8xf32>
%2 = vector.extract %r[%i4] : f32 from vector<8xf32>
%1 = vector.extract %r[%c0] : f32 from vector<8xf32>
%2 = vector.extract %r[%c4] : f32 from vector<8xf32>
%d = arith.addf %1, %2 : f32
// CHECK: ( 110, 110, 110, 110, 382, 382, 382, 382 )
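This hunk, like the dot-product hunks below, switches the dynamic positions
passed to 'vector.extract' from i32 constants to index constants, matching
the index-typed position operands the op expects.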

View File

@@ -180,8 +180,7 @@ func.func @memref_dot_optimized(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
-> f64 {
// Helper constants for loops.
%c0 = arith.constant 0 : index
%i0 = arith.constant 0 : i32
%i7 = arith.constant 7 : i32
%c7 = arith.constant 7 : index
%c8 = arith.constant 8 : index
%data_zero = arith.constant 0.0 : f64
@@ -196,13 +195,13 @@ func.func @memref_dot_optimized(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
iter_args(%sum0 = %data_zero, %b_start0 = %c0) -> (f64, index) {
%v_A = vector.transfer_read %m_A[%a], %index_padding
: memref<?xi64>, vector<8xi64>
%segA_min = vector.extract %v_A[%i0] : i64 from vector<8xi64>
%segA_min = vector.extract %v_A[%c0] : i64 from vector<8xi64>
%r1, %next_b_start0 = scf.for %b = %b_start0 to %N step %c8
iter_args(%sum1 = %sum0, %b_start1 = %b_start0) -> (f64, index) {
%v_C = vector.transfer_read %m_C[%b], %index_padding
: memref<?xi64>, vector<8xi64>
%segB_max = vector.extract %v_C[%i7] : i64 from vector<8xi64>
%segB_max = vector.extract %v_C[%c7] : i64 from vector<8xi64>
%seg1_done = arith.cmpi "slt", %segB_max, %segA_min : i64
%r2, %next_b_start1 = scf.if %seg1_done -> (f64, index) {
@@ -251,8 +250,7 @@ func.func @memref_dot_while(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
-> f64 {
// Helper constants for loops.
%c0 = arith.constant 0 : index
%i0 = arith.constant 0 : i32
%i7 = arith.constant 7 : i32
%c7 = arith.constant 7 : index
%c8 = arith.constant 8 : index
%data_zero = arith.constant 0.0 : f64
@@ -273,10 +271,10 @@ func.func @memref_dot_while(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
%v_C = vector.transfer_read %m_C[%b1], %index_padding
: memref<?xi64>, vector<8xi64>
%segA_min = vector.extract %v_A[%i0] : i64 from vector<8xi64>
%segA_max = vector.extract %v_A[%i7] : i64 from vector<8xi64>
%segB_min = vector.extract %v_C[%i0] : i64 from vector<8xi64>
%segB_max = vector.extract %v_C[%i7] : i64 from vector<8xi64>
%segA_min = vector.extract %v_A[%c0] : i64 from vector<8xi64>
%segA_max = vector.extract %v_A[%c7] : i64 from vector<8xi64>
%segB_min = vector.extract %v_C[%c0] : i64 from vector<8xi64>
%segB_max = vector.extract %v_C[%c7] : i64 from vector<8xi64>
%seg1_done = arith.cmpi "slt", %segB_max, %segA_min : i64
%r2, %a2, %b2 = scf.if %seg1_done -> (f64, index, index) {
@@ -340,7 +338,7 @@ func.func @memref_dot_while_branchless(%m_A : memref<?xi64>, %m_B : memref<?xf64
-> f64 {
// Helper constants for loops.
%c0 = arith.constant 0 : index
%i7 = arith.constant 7 : i32
%c7 = arith.constant 7 : index
%c8 = arith.constant 8 : index
%data_zero = arith.constant 0.0 : f64
@@ -370,8 +368,8 @@ func.func @memref_dot_while_branchless(%m_A : memref<?xi64>, %m_B : memref<?xf64
-> f64
%r2 = arith.addf %r1, %subresult : f64
%segA_max = vector.extract %v_A[%i7] : i64 from vector<8xi64>
%segB_max = vector.extract %v_C[%i7] : i64 from vector<8xi64>
%segA_max = vector.extract %v_A[%c7] : i64 from vector<8xi64>
%segB_max = vector.extract %v_C[%c7] : i64 from vector<8xi64>
%cond_a = arith.cmpi "sle", %segA_max, %segB_max : i64
%cond_a_i64 = arith.extui %cond_a : i1 to i64

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt %s --convert-vector-to-llvm="enable-amx" --convert-to-llvm -reconcile-unrealized-casts \
// RUN: mlir-opt %s --convert-vector-to-llvm="enable-x86" --convert-to-llvm -reconcile-unrealized-casts \
// RUN: | mlir-translate --mlir-to-llvmir \
// RUN: | FileCheck %s
@@ -7,8 +7,8 @@ func.func @amx_tile_zero(%out: memref<?x?xf32>, %idx: index)
{
// CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
// CHECK: call void @llvm.x86.tilestored64.internal
%zero = amx.tile_zero : !amx.tile<16x16xf32>
amx.tile_store %out[%idx, %idx], %zero : memref<?x?xf32>, !amx.tile<16x16xf32>
%zero = x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
x86.amx.tile_store %out[%idx, %idx], %zero : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
return
}
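
The '(i16 16, i16 64)' arguments reflect how tile types lower: the internal
AMX intrinsics take the number of rows and the row size in bytes, so a
'!x86.amx.tile<16x16xf32>' tile becomes 16 rows of 16 x 4 = 64 bytes each.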
@@ -18,8 +18,8 @@ func.func @amx_tile_load_store(%base: memref<?x?xi8>, %out: memref<?x?xi8>,
{
// CHECK: call x86_amx @llvm.x86.tileloadd64.internal
// CHECK: call void @llvm.x86.tilestored64.internal
%val = amx.tile_load %base[%idx, %idx] : memref<?x?xi8> into !amx.tile<16x64xi8>
amx.tile_store %out[%idx, %idx], %val : memref<?x?xi8>, !amx.tile<16x64xi8>
%val = x86.amx.tile_load %base[%idx, %idx] : memref<?x?xi8> into !x86.amx.tile<16x64xi8>
x86.amx.tile_store %out[%idx, %idx], %val : memref<?x?xi8>, !x86.amx.tile<16x64xi8>
return
}
@@ -29,10 +29,10 @@ func.func @amx_tile_load_store_strided(%base: memref<?xi8>, %out: memref<?xi8>,
{
// CHECK: call x86_amx @llvm.x86.tileloadd64.internal
// CHECK: call void @llvm.x86.tilestored64.internal
%val = amx.tile_load %base[%idx], %stride
: memref<?xi8> into !amx.tile<16x64xi8>
amx.tile_store %out[%idx], %val, %stride
: memref<?xi8>, !amx.tile<16x64xi8>
%val = x86.amx.tile_load %base[%idx], %stride
: memref<?xi8> into !x86.amx.tile<16x64xi8>
x86.amx.tile_store %out[%idx], %val, %stride
: memref<?xi8>, !x86.amx.tile<16x64xi8>
return
}
@@ -42,15 +42,15 @@ func.func @amx_tile_mulf_bf16(
%out: memref<?x?xf32>)
{
// CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
%acc = amx.tile_zero : !amx.tile<16x16xf32>
%acc = x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
// CHECK-COUNT-2: call x86_amx @llvm.x86.tileloadd64.internal
%tA = amx.tile_load %matA[%idx, %idx] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
%tB = amx.tile_load %matB[%idx, %idx] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
%tA = x86.amx.tile_load %matA[%idx, %idx] : memref<?x?xbf16> into !x86.amx.tile<16x32xbf16>
%tB = x86.amx.tile_load %matB[%idx, %idx] : memref<?x?xbf16> into !x86.amx.tile<16x32xbf16>
// CHECK: call x86_amx @llvm.x86.tdpbf16ps.internal
%tRes = amx.tile_mulf %tA, %tB, %acc
: !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
%tRes = x86.amx.tile_mulf %tA, %tB, %acc
: !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x16xf32>
// CHECK: call void @llvm.x86.tilestored64.internal
amx.tile_store %out[%idx, %idx], %tRes : memref<?x?xf32>, !amx.tile<16x16xf32>
x86.amx.tile_store %out[%idx, %idx], %tRes : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
return
}
@@ -60,15 +60,15 @@ func.func @amx_tile_mulf_f16(
%out: memref<?x?xf32>)
{
// CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
%acc = amx.tile_zero : !amx.tile<16x16xf32>
%acc = x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
// CHECK-COUNT-2: call x86_amx @llvm.x86.tileloadd64.internal
%tA = amx.tile_load %matA[%idx, %idx] : memref<?x?xf16> into !amx.tile<16x32xf16>
%tB = amx.tile_load %matB[%idx, %idx] : memref<?x?xf16> into !amx.tile<16x32xf16>
%tA = x86.amx.tile_load %matA[%idx, %idx] : memref<?x?xf16> into !x86.amx.tile<16x32xf16>
%tB = x86.amx.tile_load %matB[%idx, %idx] : memref<?x?xf16> into !x86.amx.tile<16x32xf16>
// CHECK: call x86_amx @llvm.x86.tdpfp16ps.internal
%tRes = amx.tile_mulf %tA, %tB, %acc
: !amx.tile<16x32xf16>, !amx.tile<16x32xf16>, !amx.tile<16x16xf32>
%tRes = x86.amx.tile_mulf %tA, %tB, %acc
: !x86.amx.tile<16x32xf16>, !x86.amx.tile<16x32xf16>, !x86.amx.tile<16x16xf32>
// CHECK: call void @llvm.x86.tilestored64.internal
amx.tile_store %out[%idx, %idx], %tRes : memref<?x?xf32>, !amx.tile<16x16xf32>
x86.amx.tile_store %out[%idx, %idx], %tRes : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
return
}
@@ -79,26 +79,26 @@ func.func @amx_tile_muli(%matA: memref<?x?xi8>, %matB: memref<?x?xi8>,
%c0 = arith.constant 0 : index
%c16 = arith.constant 16 : index
// CHECK-COUNT-3: call x86_amx @llvm.x86.tileloadd64.internal
%tA = amx.tile_load %matA[%idx, %idx] : memref<?x?xi8> into !amx.tile<16x64xi8>
%tB = amx.tile_load %matB[%idx, %idx] : memref<?x?xi8> into !amx.tile<16x64xi8>
%acc = amx.tile_load %matC[%idx, %idx] : memref<?x?xi32> into !amx.tile<16x16xi32>
%tA = x86.amx.tile_load %matA[%idx, %idx] : memref<?x?xi8> into !x86.amx.tile<16x64xi8>
%tB = x86.amx.tile_load %matB[%idx, %idx] : memref<?x?xi8> into !x86.amx.tile<16x64xi8>
%acc = x86.amx.tile_load %matC[%idx, %idx] : memref<?x?xi32> into !x86.amx.tile<16x16xi32>
// CHECK: call x86_amx @llvm.x86.tdpbuud.internal
// CHECK: call x86_amx @llvm.x86.tdpbssd.internal
// CHECK: call x86_amx @llvm.x86.tdpbusd.internal
// CHECK: call x86_amx @llvm.x86.tdpbsud.internal
%res = amx.tile_muli %tA zext, %tB zext, %acc
: !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
%res1 = amx.tile_muli %tA, %tB, %acc
: !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
%res2 = amx.tile_muli %tA zext, %tB, %acc
: !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
%res3 = amx.tile_muli %tA, %tB zext, %acc
: !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
%res = x86.amx.tile_muli %tA zext, %tB zext, %acc
: !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
%res1 = x86.amx.tile_muli %tA, %tB, %acc
: !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
%res2 = x86.amx.tile_muli %tA zext, %tB, %acc
: !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
%res3 = x86.amx.tile_muli %tA, %tB zext, %acc
: !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
// CHECK-COUNT-4: call void @llvm.x86.tilestored64.internal
amx.tile_store %out[%c0, %c0], %res : memref<?x?xi8>, !amx.tile<16x16xi32>
amx.tile_store %out[%c0, %c16], %res1 : memref<?x?xi8>, !amx.tile<16x16xi32>
amx.tile_store %out[%c16, %c0], %res2 : memref<?x?xi8>, !amx.tile<16x16xi32>
amx.tile_store %out[%c16, %c16], %res3 : memref<?x?xi8>, !amx.tile<16x16xi32>
x86.amx.tile_store %out[%c0, %c0], %res : memref<?x?xi8>, !x86.amx.tile<16x16xi32>
x86.amx.tile_store %out[%c0, %c16], %res1 : memref<?x?xi8>, !x86.amx.tile<16x16xi32>
x86.amx.tile_store %out[%c16, %c0], %res2 : memref<?x?xi8>, !x86.amx.tile<16x16xi32>
x86.amx.tile_store %out[%c16, %c16], %res3 : memref<?x?xi8>, !x86.amx.tile<16x16xi32>
return
}
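
The 'zext' annotations select among the four integer dot-product intrinsics
checked above: 'zext' on both operands lowers to 'tdpbuud' (unsigned x
unsigned), no annotation to 'tdpbssd' (signed x signed), 'zext' on only the
left operand to 'tdpbusd', and 'zext' on only the right operand to 'tdpbsud'.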
@@ -108,16 +108,16 @@ func.func @amx_tile_type_through_cf(%src: memref<?x?xi8>, %out: memref<?x?xi8>,
cf.cond_br %cond, ^bb1, ^bb2
^bb1: // pred: ^bb0
// CHECK: call x86_amx @llvm.x86.tileloadd64.internal
%0 = amx.tile_load %src[%idx, %idx] : memref<?x?xi8> into !amx.tile<16x64xi8>
cf.br ^bb3(%0 : !amx.tile<16x64xi8>)
%0 = x86.amx.tile_load %src[%idx, %idx] : memref<?x?xi8> into !x86.amx.tile<16x64xi8>
cf.br ^bb3(%0 : !x86.amx.tile<16x64xi8>)
^bb2: // pred: ^bb0
// CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
%1 = amx.tile_zero : !amx.tile<16x64xi8>
cf.br ^bb3(%1 : !amx.tile<16x64xi8>)
^bb3(%2: !amx.tile<16x64xi8>): // 2 preds: ^bb1, ^bb2
%1 = x86.amx.tile_zero : !x86.amx.tile<16x64xi8>
cf.br ^bb3(%1 : !x86.amx.tile<16x64xi8>)
^bb3(%2: !x86.amx.tile<16x64xi8>): // 2 preds: ^bb1, ^bb2
cf.br ^bb4
^bb4: // pred: ^bb3
// CHECK: call void @llvm.x86.tilestored64.internal
amx.tile_store %out[%idx, %idx], %2 : memref<?x?xi8>, !amx.tile<16x64xi8>
x86.amx.tile_store %out[%idx, %idx], %2 : memref<?x?xi8>, !x86.amx.tile<16x64xi8>
return
}
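
This last case checks that '!x86.amx.tile' values can be carried through
block arguments in unstructured control flow and still lower cleanly to the
tile intrinsics.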

View File

@@ -46,7 +46,7 @@ config.enable_vulkan_runner = @MLIR_ENABLE_VULKAN_RUNNER@
config.enable_bindings_python = @MLIR_ENABLE_BINDINGS_PYTHON@
config.enable_python_stable_abi = @MLIR_ENABLE_PYTHON_STABLE_ABI@
config.intel_sde_executable = "@INTEL_SDE_EXECUTABLE@"
config.mlir_run_amx_tests = @MLIR_RUN_AMX_TESTS@
config.mlir_run_x86_amx_tests = @MLIR_RUN_X86_AMX_TESTS@
config.mlir_run_arm_sve_tests = @MLIR_RUN_ARM_SVE_TESTS@
# This is a workaround for the fact that LIT's:
# %if <cond>

View File

@@ -3,7 +3,6 @@
// CHECK-SAME: acc
// CHECK-SAME: affine
// CHECK-SAME: amdgpu
// CHECK-SAME: amx
// CHECK-SAME: arith
// CHECK-SAME: arm_neon
// CHECK-SAME: arm_sme