[mlir][x86] Move AMX dialect into X86 dialect (#183717)
Unifies the two dialects that define x86 operations into a single one. The AMX dialect is moved into X86 in line with other x86 extensions. Following the dialect renaming, X86 dialect is now a suitable home for wider range of operations targeting specific hardware features. Moving AMX definitions to X86 dialect creates a single, centralized hub for defining all x86 intrinsic-like operations. The new grouping aims to eliminate the need for new dialects as new hardware extensions become available. The two dialects are simply merged together. X86 dialect refactoring will be addressed separately. List of changes: - operations: 'amx.tile_*' => 'x86.amx.tile_*' - types: '!amx.tile' => '!x86.amx.tile' - namespace: 'mlir::amx' => 'mlir::x86::amx' - test define: 'MLIR_RUN_AMX_TESTS' => 'MLIR_RUN_X86_AMX_TESTS' - vector lowering: AMX is enabled by default together with X86 The MLIR AMX tests are now nested under X86 directory. To enable AMX integration tests, 'MLIR_RUN_X86_TESTS' must also be defined.
This commit is contained in:
@@ -104,7 +104,6 @@ available, should be contacted first, as they're more active in those areas.
|
||||
* ‘arm_neon’ Dialect ([@banach-space](https://github.com/banach-space))
|
||||
* ‘arm_sve’ Dialect ([@banach-space](https://github.com/banach-space))
|
||||
* ‘ArmSME’ Dialect ([@banach-space](https://github.com/banach-space))
|
||||
* ‘amx’ Dialect ([@adam-smnk](https://github.com/adam-smnk))
|
||||
* ‘x86’ Dialect ([@adam-smnk](https://github.com/adam-smnk))
|
||||
* ‘vcix’ Dialect ([@mshockwave](https://github.com/mshockwave))
|
||||
|
||||
|
||||
@@ -5,8 +5,8 @@ overall flow is two-stage:
|
||||
|
||||
1. **conversion** of the IR to a set of dialects translatable to LLVM IR, for
|
||||
example [LLVM Dialect](Dialects/LLVM.md) or one of the hardware-specific
|
||||
dialects derived from LLVM IR intrinsics such as [AMX](Dialects/AMX.md),
|
||||
[X86](Dialects/X86.md) or [ArmNeon](Dialects/ArmNeon.md);
|
||||
dialects derived from LLVM IR intrinsics such as [X86](Dialects/X86.md)
|
||||
or [ArmNeon](Dialects/ArmNeon.md);
|
||||
2. **translation** of MLIR dialects to LLVM IR.
|
||||
|
||||
This flow allows the non-trivial transformation to be performed within MLIR
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
//===-- mlir-c/Dialect/AMX.h - C API for AMX Dialect --------*- C -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
|
||||
// Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_C_DIALECT_AMX_H
|
||||
#define MLIR_C_DIALECT_AMX_H
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(AMX, amx);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MLIR_C_DIALECT_AMX_H
|
||||
@@ -1521,8 +1521,8 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
|
||||
operations. The lowering pass provides several options to control
|
||||
the kinds of optimizations that are allowed. It also provides options
|
||||
that enable the use of one or more architectural-specific dialects
|
||||
(AMX, X86, ArmNeon, ArmSVE, etc.) in combination with the
|
||||
architectural-neutral vector dialect lowering.
|
||||
(X86, ArmNeon, ArmSVE, etc.) in combination with the architectural-neutral
|
||||
vector dialect lowering.
|
||||
|
||||
}];
|
||||
// Override explicitly in C++ to allow conditional dialect dependence.
|
||||
@@ -1544,10 +1544,6 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
|
||||
"vector access are naturally aligned. If operations have an "
|
||||
"alignment attribute set, the alignment attribute takes priority "
|
||||
"over this option ">,
|
||||
Option<"amx", "enable-amx",
|
||||
"bool", /*default=*/"false",
|
||||
"Enables the use of AMX dialect while lowering the vector "
|
||||
"dialect.">,
|
||||
Option<"armNeon", "enable-arm-neon",
|
||||
"bool", /*default=*/"false",
|
||||
"Enables the use of ArmNeon dialect while lowering the vector "
|
||||
@@ -1626,10 +1622,10 @@ def ConvertVectorToXeGPU : Pass<"convert-vector-to-xegpu"> {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ConvertVectorToAMX : Pass<"convert-vector-to-amx"> {
|
||||
let summary = "Lower the operations from the vector dialect into the AMX "
|
||||
"dialect";
|
||||
let summary = "Lower the operations from the vector dialect into the X86 "
|
||||
"dialect AMX operations";
|
||||
let dependentDialects = [
|
||||
"affine::AffineDialect", "amx::AMXDialect", "arith::ArithDialect",
|
||||
"affine::AffineDialect", "x86::X86Dialect", "arith::ArithDialect",
|
||||
"memref::MemRefDialect", "scf::SCFDialect", "vector::VectorDialect"
|
||||
];
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//===- VectorToAMX.h - Convert vector to AMX dialect ------------*- C++ -*-===//
|
||||
//===- VectorToAMX.h - Convert vector to X86 dialect AMX ops ----*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
@@ -18,7 +18,7 @@ class RewritePatternSet;
|
||||
#define GEN_PASS_DECL_CONVERTVECTORTOAMX
|
||||
#include "mlir/Conversion/Passes.h.inc"
|
||||
|
||||
/// Collect a set of patterns to convert from the vector to AMX ops.
|
||||
/// Collect a set of patterns to convert from the vector to X86 AMX ops.
|
||||
void populateVectorToAMXConversionPatterns(RewritePatternSet &patterns);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
@@ -1,440 +0,0 @@
|
||||
//===-- AMX.td - AMX dialect operation definitions *- tablegen -*----------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file defines the basic operations for the AMX dialect.
|
||||
//
|
||||
// The Intel Advanced Matrix Extensions (AMX) provide a tile matrix
|
||||
// multiply unit (TMUL), a tile control register (TILECFG), and eight
|
||||
// tile registers TMM0 through TMM7 (TILEDATA).
|
||||
//
|
||||
// The AMX dialect provides a bridge between MLIR concepts, such as
|
||||
// 2-d vector, operations, and memrefs, and the lower level details
|
||||
// of Intel AMX, such as configuration setup, tile sizes, instructions,
|
||||
// and tile release.
|
||||
//
|
||||
// Note that since configuration changes (implicit at dialect level) are
|
||||
// costly, it is highly recommended to use the AMX dialect on same-shaped
|
||||
// vectors, at least within a single method.
|
||||
//
|
||||
// https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef AMX
|
||||
#define AMX
|
||||
|
||||
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
|
||||
include "mlir/Dialect/AMX/AMXInterfaces.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/IR/AttrTypeBase.td"
|
||||
include "mlir/IR/BuiltinTypes.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AMX dialect definition.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def AMX_Dialect : Dialect {
|
||||
let name = "amx";
|
||||
let cppNamespace = "::mlir::amx";
|
||||
let description = [{
|
||||
The Intel Advanced Matrix Extensions (AMX) provide a tile matrix
|
||||
multiply unit (TMUL), a tile control register (TILECFG), and eight
|
||||
tile registers TMM0 through TMM7 (TILEDATA).
|
||||
|
||||
This `AMX` dialect provides a bridge between MLIR concepts such as
|
||||
vectors and memrefs and the lower level LLVM IR support of AMX.
|
||||
|
||||
Note that since configuration changes (implicit at dialect level) are
|
||||
costly, it is highly recommended to use the AMX dialect on same-shaped
|
||||
vectors, at least within a single method.
|
||||
|
||||
For details, see the Intel documentation:
|
||||
https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
|
||||
}];
|
||||
let useDefaultTypePrinterParser = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AMX Tile definition.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class AMX_Type<string typeName, string typeMnemonic, list<Trait> traits = []>
|
||||
: TypeDef<AMX_Dialect, typeName, traits> {
|
||||
let mnemonic = typeMnemonic;
|
||||
}
|
||||
|
||||
def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
|
||||
let cppFunctionName = "isValidTileTypeElementType";
|
||||
}
|
||||
|
||||
def AMX_TileType : AMX_Type<"Tile", "tile", [ShapedTypeInterface, ValueSemantics]> {
|
||||
let summary = "AMX 2D tile to be used by AMX opertaions.";
|
||||
|
||||
let description = [{
|
||||
This type is used to represent values in AMX tile registers. All AMX operations
|
||||
work on AMX tiles and these tiles cannot be used in other operations directly.
|
||||
LLVM IR type for AMX tile is a primitive type, but in MLIR we provide shape and
|
||||
element type for IR verification and lowering to LLVMIR dialect.
|
||||
}];
|
||||
|
||||
let parameters = (ins
|
||||
ArrayRefParameter<"int64_t">:$shape,
|
||||
AMX_TileTypeElementType:$elementType
|
||||
);
|
||||
|
||||
let builders = [
|
||||
TypeBuilderWithInferredContext<(ins
|
||||
"ArrayRef<int64_t>":$shape, "Type":$elementType), [{
|
||||
return $_get(elementType.getContext(), shape, elementType);
|
||||
}]>
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Returns if this type is ranked (always true).
|
||||
bool hasRank() const { return true; }
|
||||
|
||||
/// Clone this tile type with the given shape and element type. If the
|
||||
/// provided shape is `std::nullopt`, the current shape of the type is used.
|
||||
TileType cloneWith(std::optional<ArrayRef<int64_t>> shape,
|
||||
Type elementType) const {
|
||||
return get(shape.value_or(getShape()), elementType);
|
||||
}
|
||||
}];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let skipDefaultBuilders = 1;
|
||||
}
|
||||
|
||||
def IsAMXTilePred : And<[CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">,
|
||||
CPred<[{::llvm::cast<::mlir::amx::TileType>($_self).getRank() == 2}]>]>;
|
||||
|
||||
class AMXTileOf<list<Type> allowedTypes> :
|
||||
ShapedContainerType<allowedTypes, IsAMXTilePred, "tile",
|
||||
"::mlir::amx::TileType">;
|
||||
|
||||
def AnyAMXTile : AMXTileOf<[F32, F16, BF16, I32, I8]>;
|
||||
|
||||
def AMXTileF32 : AMXTileOf<[F32]>;
|
||||
|
||||
def AMXTileF16OrBF16 : AMXTileOf<[F16, BF16]>;
|
||||
|
||||
def AMXTileI32 : AMXTileOf<[I32]>;
|
||||
|
||||
def AMXTileI8 : AMXTileOf<[I8]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AMX Op and IntrOp definitions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class AMX_Op<string mnemonic, list<Trait> traits = []> :
|
||||
Op<AMX_Dialect, mnemonic, traits> {}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AMX Op definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
//
|
||||
// Tile reset.
|
||||
//
|
||||
|
||||
def TileZeroOp : AMX_Op<"tile_zero", [
|
||||
AMXIntrinsicOpInterface,
|
||||
MemoryEffects<[MemWrite]>
|
||||
]> {
|
||||
let summary = "tile zero operation";
|
||||
let description = [{
|
||||
Zeroes the destination tile, with the shape defined by the 2-dim
|
||||
vector type of the result.
|
||||
|
||||
The operation is eventually lowered into the "tilezero" instruction
|
||||
with the corresponding tile configuration.
|
||||
|
||||
With the write memory effect, each `amx.tile_zero` operation serves as
|
||||
a compilation hint to use a separate tile register.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%0 = amx.tile_zero : !amx.tile<16x16xbf16>
|
||||
```
|
||||
}];
|
||||
let results = (outs AnyAMXTile:$res);
|
||||
let extraClassDeclaration = [{
|
||||
TileType getTileType() {
|
||||
return ::llvm::cast<TileType>(getRes().getType());
|
||||
}
|
||||
|
||||
std::string getIntrinsicName() {
|
||||
return "llvm.x86.tilezero.internal";
|
||||
}
|
||||
SmallVector<Value> getIntrinsicOperands(
|
||||
::mlir::ArrayRef<Value> operands,
|
||||
const ::mlir::LLVMTypeConverter &typeConverter,
|
||||
::mlir::RewriterBase &rewriter);
|
||||
}];
|
||||
let assemblyFormat = "attr-dict `:` qualified(type($res))";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
//
|
||||
// Tile memory operations.
|
||||
//
|
||||
|
||||
def TileLoadOp : AMX_Op<"tile_load", [
|
||||
AMXIntrinsicOpInterface,
|
||||
MemoryEffects<[MemWrite]>,
|
||||
AttrSizedOperandSegments
|
||||
]> {
|
||||
let summary = "tile load operation";
|
||||
let description = [{
|
||||
Loads a tile from memory defined by a `base` and `indices`, with the
|
||||
shape defined by the 2-dim vector type of the result.
|
||||
The tile's rows are populated by reading contiguous elements starting
|
||||
at the `base`. For each tile row, the `base` is incremented by `stride`
|
||||
number of elements.
|
||||
|
||||
The tile is loaded using the following indexing scheme:
|
||||
|
||||
```
|
||||
for row in enumerate(tile_rows):
|
||||
mem_row = base[i0, i1, ..., iN + row * stride]
|
||||
for col in enumerate(tile_cols):
|
||||
tile[row, col] = mem_row[col]
|
||||
```
|
||||
|
||||
If the `stride` is not provided, then the `base` buffer must be at least
|
||||
2-dimensional, and the `stride` is automatically inferred and corresponds
|
||||
to the stride of the buffer's second innermost dimension.
|
||||
|
||||
The operation is eventually lowered into the "tileloadd" instruction
|
||||
with the corresponding tile configuration.
|
||||
|
||||
With the write memory effect, each `amx.tile_load` operation serves as
|
||||
a compilation hint to use a separate tile register.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
// Tile load from a 2-D memref with implicit stride.
|
||||
%0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into !amx.tile<16x64xi8>
|
||||
|
||||
// Tile load from a 1-D memref with explicit stride.
|
||||
%0 = amx.tile_load %arg0[%c0], %stride : memref<?xi8> into !amx.tile<16x64xi8>
|
||||
```
|
||||
}];
|
||||
let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
|
||||
Variadic<Index>:$indices,
|
||||
Optional<Index>:$stride);
|
||||
let results = (outs AnyAMXTile:$res);
|
||||
let builders = [
|
||||
OpBuilder<(ins "Type":$res, "Value":$base, "ValueRange":$indices)>
|
||||
];
|
||||
let extraClassDeclaration = [{
|
||||
MemRefType getMemRefType() {
|
||||
return ::llvm::cast<MemRefType>(getBase().getType());
|
||||
}
|
||||
TileType getTileType() {
|
||||
return ::llvm::cast<TileType>(getRes().getType());
|
||||
}
|
||||
|
||||
std::string getIntrinsicName() {
|
||||
return "llvm.x86.tileloadd64.internal";
|
||||
}
|
||||
SmallVector<Value> getIntrinsicOperands(
|
||||
::mlir::ArrayRef<Value> operands,
|
||||
const ::mlir::LLVMTypeConverter &typeConverter,
|
||||
::mlir::RewriterBase &rewriter);
|
||||
}];
|
||||
let assemblyFormat = "$base `[` $indices `]` (`,` $stride^ )? attr-dict"
|
||||
"`:` type($base) `into` qualified(type($res))";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def TileStoreOp : AMX_Op<"tile_store", [
|
||||
AMXIntrinsicOpInterface,
|
||||
AttrSizedOperandSegments
|
||||
]> {
|
||||
let summary = "tile store operation";
|
||||
let description = [{
|
||||
Stores a tile to memory defined by a `base` and `indices`, with the
|
||||
shape defined by the 2-dim vector type of the value.
|
||||
The tile's rows are written contiguously to the buffer starting at
|
||||
the `base`. For each tile row, the `base` is incremented by `stride`
|
||||
number of elements.
|
||||
|
||||
The tile is stored using the following indexing scheme:
|
||||
|
||||
```
|
||||
for row in enumerate(tile_rows):
|
||||
mem_row = base[i0, i1, ..., iN + row * stride]
|
||||
for col in enumerate(tile_cols):
|
||||
mem_row[col] = tile[row, col]
|
||||
```
|
||||
|
||||
If the `stride` is not provided, then the `base` buffer must be at least
|
||||
2-dimensional, and the `stride` is automatically inferred and corresponds
|
||||
to the stride of the buffer's second innermost dimension.
|
||||
|
||||
The operation is eventually lowered into the "tilestored" instruction
|
||||
with the corresponding tile configuration.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
// Tile store to a 2-D memref with implicit stride.
|
||||
amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, !amx.tile<16x64xi8>
|
||||
|
||||
// Tile store to a 1-D memref with explicit stride.
|
||||
amx.tile_store %arg1[%c0], %0, %stride : memref<?xi8>, !amx.tile<16x64xi8>
|
||||
```
|
||||
}];
|
||||
let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
|
||||
Variadic<Index>:$indices,
|
||||
AnyAMXTile:$val,
|
||||
Optional<Index>:$stride);
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$base, "ValueRange":$indices, "Value":$val)>
|
||||
];
|
||||
let extraClassDeclaration = [{
|
||||
MemRefType getMemRefType() {
|
||||
return ::llvm::cast<MemRefType>(getBase().getType());
|
||||
}
|
||||
TileType getTileType() {
|
||||
return ::llvm::cast<TileType>(getVal().getType());
|
||||
}
|
||||
|
||||
std::string getIntrinsicName() {
|
||||
return "llvm.x86.tilestored64.internal";
|
||||
}
|
||||
SmallVector<Value> getIntrinsicOperands(
|
||||
::mlir::ArrayRef<Value> operands,
|
||||
const ::mlir::LLVMTypeConverter &typeConverter,
|
||||
::mlir::RewriterBase &rewriter);
|
||||
}];
|
||||
let assemblyFormat = "$base `[` $indices `]` `,` $val (`,` $stride^ )?"
|
||||
"attr-dict `:` type($base) `,` qualified(type($val))";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
//
|
||||
// Tile arithmetic operations.
|
||||
//
|
||||
|
||||
def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
|
||||
AMXIntrinsicOpInterface,
|
||||
AllTypesMatch<["acc", "res"]>
|
||||
]> {
|
||||
let summary = "tile multiplication operation (floating-point)";
|
||||
let description = [{
|
||||
Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
|
||||
into a "m x n" destination tile. Supports "f32 <- bf16 x bf16" (with
|
||||
pairs of "bf16").
|
||||
|
||||
The operation is eventually lowered into the "tdpbf16ps" instruction with
|
||||
the corresponding tile configuration.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%0 = amx.tile_mulf %a, %b, %c
|
||||
: !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
|
||||
```
|
||||
}];
|
||||
let arguments = (ins AMXTileF16OrBF16:$lhs,
|
||||
AMXTileF16OrBF16:$rhs,
|
||||
AMXTileF32:$acc);
|
||||
let results = (outs AMXTileF32:$res);
|
||||
let extraClassDeclaration = [{
|
||||
TileType getLhsTileType() {
|
||||
return ::llvm::cast<TileType>(getLhs().getType());
|
||||
}
|
||||
TileType getRhsTileType() {
|
||||
return ::llvm::cast<TileType>(getRhs().getType());
|
||||
}
|
||||
TileType getTileType() {
|
||||
return ::llvm::cast<TileType>(getRes().getType());
|
||||
}
|
||||
|
||||
std::string getIntrinsicName() {
|
||||
std::string intr = "llvm.x86.tdp";
|
||||
auto elementType =
|
||||
getLhsTileType().getElementType();
|
||||
intr += elementType.isF16() ? "fp16" : "bf16";
|
||||
intr += "ps.internal";
|
||||
return intr;
|
||||
}
|
||||
SmallVector<Value> getIntrinsicOperands(
|
||||
::mlir::ArrayRef<Value> operands,
|
||||
const ::mlir::LLVMTypeConverter &typeConverter,
|
||||
::mlir::RewriterBase &rewriter);
|
||||
}];
|
||||
let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
|
||||
"qualified(type($lhs)) `,` qualified(type($rhs))"
|
||||
" `,` qualified(type($acc)) ";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def TileMulIOp : AMX_Op<"tile_muli", [Pure,
|
||||
AMXIntrinsicOpInterface,
|
||||
AllTypesMatch<["acc", "res"]>
|
||||
]> {
|
||||
let summary = "tile multiplication operation (integer)";
|
||||
let description = [{
|
||||
Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
|
||||
into a "m x n" destination tile. Supports all "si32 <- s/ui8 x s/ui8"
|
||||
combinations (4 bytes packed into dwords in the columns of both the
|
||||
source operand tiles; the zero or sign extension is specified with
|
||||
the attributes and default to sign extended).
|
||||
|
||||
The operation is eventually lowered into one of the "tdpbssd",
|
||||
"tdpbsud", "tdpbusd", or "tdpbuud" instructions with the corresponding
|
||||
tile configuration.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%0 = amx.tile_muli %a zext, %b zext, %c
|
||||
: !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
|
||||
```
|
||||
}];
|
||||
let arguments = (ins AMXTileI8:$lhs,
|
||||
AMXTileI8:$rhs,
|
||||
AMXTileI32:$acc,
|
||||
UnitAttr:$isZextLhs,
|
||||
UnitAttr:$isZextRhs
|
||||
);
|
||||
let results = (outs AMXTileI32:$res);
|
||||
let extraClassDeclaration = [{
|
||||
TileType getLhsTileType() {
|
||||
return ::llvm::cast<TileType>(getLhs().getType());
|
||||
}
|
||||
TileType getRhsTileType() {
|
||||
return ::llvm::cast<TileType>(getRhs().getType());
|
||||
}
|
||||
TileType getTileType() {
|
||||
return ::llvm::cast<TileType>(getRes().getType());
|
||||
}
|
||||
|
||||
std::string getIntrinsicName() {
|
||||
std::string intr = "llvm.x86.tdpb";
|
||||
intr += getIsZextLhs() ? "u" : "s";
|
||||
intr += getIsZextRhs() ? "u" : "s";
|
||||
intr += "d.internal";
|
||||
return intr;
|
||||
}
|
||||
SmallVector<Value> getIntrinsicOperands(
|
||||
::mlir::ArrayRef<Value> operands,
|
||||
const ::mlir::LLVMTypeConverter &typeConverter,
|
||||
::mlir::RewriterBase &rewriter);
|
||||
}];
|
||||
let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
|
||||
"qualified(type($lhs)) `,` qualified(type($rhs)) `,` qualified(type($acc)) ";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
#endif // AMX
|
||||
@@ -1,34 +0,0 @@
|
||||
//===- AMXDialect.h - MLIR Dialect for AMX ----------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file declares the Target dialect for AMX in MLIR.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_AMX_AMXDIALECT_H_
|
||||
#define MLIR_DIALECT_AMX_AMXDIALECT_H_
|
||||
|
||||
#include "mlir/Bytecode/BytecodeOpInterface.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
||||
/// Include the generated interface declarations.
|
||||
#include "mlir/Dialect/AMX/AMXInterfaces.h.inc"
|
||||
|
||||
#include "mlir/Dialect/AMX/AMXDialect.h.inc"
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "mlir/Dialect/AMX/AMXTypes.h.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/AMX/AMX.h.inc"
|
||||
|
||||
#endif // MLIR_DIALECT_AMX_AMXDIALECT_H_
|
||||
@@ -1,31 +0,0 @@
|
||||
//===- AMXInterfaces.td - AMX interfaces -------------------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file defines interfaces for the AMX dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef AMX_INTERFACES
|
||||
#define AMX_INTERFACES
|
||||
|
||||
include "mlir/IR/Interfaces.td"
|
||||
include "mlir/Dialect/LLVMIR/LLVMInterfaces.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AMX Intrinsic Interface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def AMXIntrinsicOpInterface
|
||||
: OpInterface<"AMXIntrinsicOp", [OneToOneIntrinsicOpInterface]> {
|
||||
let description = [{
|
||||
A wrapper interface for operations representing AMX LLVM intrinsics.
|
||||
}];
|
||||
let cppNamespace = "::mlir::amx";
|
||||
}
|
||||
|
||||
#endif // AMX_INTERFACES
|
||||
@@ -1,5 +0,0 @@
|
||||
add_mlir_dialect(AMX amx)
|
||||
add_mlir_doc(AMX AMX Dialects/ -gen-dialect-doc -dialect=amx)
|
||||
|
||||
add_mlir_interface(AMXInterfaces)
|
||||
add_dependencies(MLIRAMXIncGen MLIRAMXInterfacesIncGen)
|
||||
@@ -1,33 +0,0 @@
|
||||
//===- Transforms.h - AMX Dialect Transformation Entrypoints ----*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_AMX_TRANSFORMS_H
|
||||
#define MLIR_DIALECT_AMX_TRANSFORMS_H
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class LLVMConversionTarget;
|
||||
class LLVMTypeConverter;
|
||||
class RewritePatternSet;
|
||||
class DialectRegistry;
|
||||
|
||||
/// Collect a set of patterns to lower AMX ops to ops that map to LLVM
|
||||
/// intrinsics.
|
||||
void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Configure the target to support lowering AMX ops to ops that map to LLVM
|
||||
/// intrinsics.
|
||||
void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target);
|
||||
|
||||
/// Register LLVM conversion interface for AMX dialect.
|
||||
void registerConvertAMXToLLVMInterface(DialectRegistry ®istry);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_AMX_TRANSFORMS_H
|
||||
@@ -1,6 +1,5 @@
|
||||
add_subdirectory(Affine)
|
||||
add_subdirectory(AMDGPU)
|
||||
add_subdirectory(AMX)
|
||||
add_subdirectory(Arith)
|
||||
add_subdirectory(ArmNeon)
|
||||
add_subdirectory(ArmSME)
|
||||
|
||||
@@ -104,10 +104,6 @@ struct SparsifierOptions : public PassPipelineOptions<SparsifierOptions> {
|
||||
desc("Allows compiler to assume indices fit in 32-bit if that yields "
|
||||
"faster code"),
|
||||
init(true)};
|
||||
PassOptions::Option<bool> amx{
|
||||
*this, "enable-amx",
|
||||
desc("Enables the use of AMX dialect while lowering the vector dialect"),
|
||||
init(false)};
|
||||
PassOptions::Option<bool> armNeon{
|
||||
*this, "enable-arm-neon",
|
||||
desc("Enables the use of ArmNeon dialect while lowering the vector "
|
||||
@@ -168,7 +164,6 @@ struct SparsifierOptions : public PassPipelineOptions<SparsifierOptions> {
|
||||
opts.force32BitVectorIndices = force32BitVectorIndices;
|
||||
opts.armNeon = armNeon;
|
||||
opts.armSVE = armSVE;
|
||||
opts.amx = amx;
|
||||
opts.x86 = x86;
|
||||
return opts;
|
||||
}
|
||||
|
||||
@@ -200,13 +200,16 @@ void populateSpecializedTransposeLoweringPatterns(
|
||||
|
||||
/// Collect a set of patterns to lower X86 ops to ops that map to LLVM
|
||||
/// intrinsics.
|
||||
void populateX86LegalizeForLLVMExportPatterns(
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
|
||||
void populateX86LegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Configure the target to support lowering X86 ops to ops that map to
|
||||
/// LLVM intrinsics.
|
||||
void configureX86LegalizeForExportTarget(LLVMConversionTarget &target);
|
||||
|
||||
/// Register LLVM conversion interface for X86 dialect.
|
||||
void registerConvertX86ToLLVMInterface(DialectRegistry ®istry);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_X86_TRANSFORMS_H
|
||||
|
||||
@@ -17,6 +17,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
|
||||
include "mlir/Dialect/X86/X86Interfaces.td"
|
||||
include "mlir/IR/BuiltinTypes.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// X86 dialect definition
|
||||
@@ -25,6 +26,8 @@ include "mlir/Dialect/X86/X86Interfaces.td"
|
||||
def X86_Dialect : Dialect {
|
||||
let name = "x86";
|
||||
let cppNamespace = "::mlir::x86";
|
||||
|
||||
let useDefaultTypePrinterParser = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -673,4 +676,385 @@ def CvtPackedOddIndexedToF32Op
|
||||
::mlir::RewriterBase &rewriter);
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AMX Tile definition
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class AMX_Type<string typeName, string typeMnemonic, list<Trait> traits = []>
|
||||
: TypeDef<X86_Dialect, "AMX" # typeName, traits> {
|
||||
let mnemonic = "amx." # typeMnemonic;
|
||||
}
|
||||
|
||||
def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
|
||||
let cppFunctionName = "isValidTileTypeElementType";
|
||||
}
|
||||
|
||||
def AMX_TileType : AMX_Type<"Tile", "tile", [ShapedTypeInterface, ValueSemantics]> {
|
||||
let summary = "AMX 2D tile to be used by AMX opertaions.";
|
||||
|
||||
let description = [{
|
||||
This type is used to represent values in AMX tile registers. All AMX operations
|
||||
work on AMX tiles and these tiles cannot be used in other operations directly.
|
||||
LLVM IR type for AMX tile is a primitive type, but in MLIR we provide shape and
|
||||
element type for IR verification and lowering to LLVMIR dialect.
|
||||
}];
|
||||
|
||||
let parameters = (ins
|
||||
ArrayRefParameter<"int64_t">:$shape,
|
||||
AMX_TileTypeElementType:$elementType
|
||||
);
|
||||
|
||||
let builders = [
|
||||
TypeBuilderWithInferredContext<(ins
|
||||
"ArrayRef<int64_t>":$shape, "Type":$elementType), [{
|
||||
return $_get(elementType.getContext(), shape, elementType);
|
||||
}]>
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Returns if this type is ranked (always true).
|
||||
bool hasRank() const { return true; }
|
||||
|
||||
/// Clone this tile type with the given shape and element type. If the
|
||||
/// provided shape is `std::nullopt`, the current shape of the type is used.
|
||||
AMXTileType cloneWith(std::optional<ArrayRef<int64_t>> shape,
|
||||
Type elementType) const {
|
||||
return get(shape.value_or(getShape()), elementType);
|
||||
}
|
||||
}];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let skipDefaultBuilders = 1;
|
||||
}
|
||||
|
||||
def IsAMXTilePred : And<[CPred<"::llvm::isa<::mlir::x86::AMXTileType>($_self)">,
|
||||
CPred<[{::llvm::cast<::mlir::x86::AMXTileType>($_self).getRank() == 2}]>]>;
|
||||
|
||||
class AMXTileOf<list<Type> allowedTypes> :
|
||||
ShapedContainerType<allowedTypes, IsAMXTilePred, "tile",
|
||||
"::mlir::x86::AMXTileType">;
|
||||
|
||||
def AnyAMXTile : AMXTileOf<[F32, F16, BF16, I32, I8]>;
|
||||
|
||||
def AMXTileF32 : AMXTileOf<[F32]>;
|
||||
|
||||
def AMXTileF16OrBF16 : AMXTileOf<[F16, BF16]>;
|
||||
|
||||
def AMXTileI32 : AMXTileOf<[I32]>;
|
||||
|
||||
def AMXTileI8 : AMXTileOf<[I8]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AMX Op definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class AMX_Op<string mnemonic, list<Trait> traits = []>
|
||||
: Op<X86_Dialect, "amx." # mnemonic, traits> {
|
||||
let cppNamespace = X86_Dialect.cppNamespace # "::amx";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AMX Tile Zero
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def TileZeroOp : AMX_Op<"tile_zero", [
|
||||
X86IntrinsicOpInterface,
|
||||
MemoryEffects<[MemWrite]>
|
||||
]> {
|
||||
let summary = "tile zero operation";
|
||||
let description = [{
|
||||
Zeroes the destination tile, with the shape defined by the 2-dim
|
||||
vector type of the result.
|
||||
|
||||
The operation is eventually lowered into the "tilezero" instruction
|
||||
with the corresponding tile configuration.
|
||||
|
||||
With the write memory effect, each `x86.amx.tile_zero` operation serves as
|
||||
a compilation hint to use a separate tile register.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%0 = x86.amx.tile_zero : !x86.amx.tile<16x16xbf16>
|
||||
```
|
||||
}];
|
||||
let results = (outs AnyAMXTile:$res);
|
||||
let extraClassDeclaration = [{
|
||||
AMXTileType getTileType() {
|
||||
return ::llvm::cast<AMXTileType>(getRes().getType());
|
||||
}
|
||||
|
||||
std::string getIntrinsicName() {
|
||||
return "llvm.x86.tilezero.internal";
|
||||
}
|
||||
SmallVector<Value> getIntrinsicOperands(
|
||||
::mlir::ArrayRef<Value> operands,
|
||||
const ::mlir::LLVMTypeConverter &typeConverter,
|
||||
::mlir::RewriterBase &rewriter);
|
||||
}];
|
||||
let assemblyFormat = "attr-dict `:` qualified(type($res))";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AMX Tile Load
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def TileLoadOp : AMX_Op<"tile_load", [
|
||||
X86IntrinsicOpInterface,
|
||||
MemoryEffects<[MemWrite]>,
|
||||
AttrSizedOperandSegments
|
||||
]> {
|
||||
let summary = "tile load operation";
|
||||
let description = [{
|
||||
Loads a tile from memory defined by a `base` and `indices`, with the
|
||||
shape defined by the 2-dim vector type of the result.
|
||||
The tile's rows are populated by reading contiguous elements starting
|
||||
at the `base`. For each tile row, the `base` is incremented by `stride`
|
||||
number of elements.
|
||||
|
||||
The tile is loaded using the following indexing scheme:
|
||||
|
||||
```
|
||||
for row in enumerate(tile_rows):
|
||||
mem_row = base[i0, i1, ..., iN + row * stride]
|
||||
for col in enumerate(tile_cols):
|
||||
tile[row, col] = mem_row[col]
|
||||
```
|
||||
|
||||
If the `stride` is not provided, then the `base` buffer must be at least
|
||||
2-dimensional, and the `stride` is automatically inferred and corresponds
|
||||
to the stride of the buffer's second innermost dimension.
|
||||
|
||||
The operation is eventually lowered into the "tileloadd" instruction
|
||||
with the corresponding tile configuration.
|
||||
|
||||
With the write memory effect, each `x86.amx.tile_load` operation serves as
|
||||
a compilation hint to use a separate tile register.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
// Tile load from a 2-D memref with implicit stride.
|
||||
%0 = x86.amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into !x86.amx.tile<16x64xi8>
|
||||
|
||||
// Tile load from a 1-D memref with explicit stride.
|
||||
%0 = x86.amx.tile_load %arg0[%c0], %stride : memref<?xi8> into !x86.amx.tile<16x64xi8>
|
||||
```
|
||||
}];
|
||||
let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
|
||||
Variadic<Index>:$indices,
|
||||
Optional<Index>:$stride);
|
||||
let results = (outs AnyAMXTile:$res);
|
||||
let builders = [
|
||||
OpBuilder<(ins "Type":$res, "Value":$base, "ValueRange":$indices)>
|
||||
];
|
||||
let extraClassDeclaration = [{
|
||||
MemRefType getMemRefType() {
|
||||
return ::llvm::cast<MemRefType>(getBase().getType());
|
||||
}
|
||||
AMXTileType getTileType() {
|
||||
return ::llvm::cast<AMXTileType>(getRes().getType());
|
||||
}
|
||||
|
||||
std::string getIntrinsicName() {
|
||||
return "llvm.x86.tileloadd64.internal";
|
||||
}
|
||||
SmallVector<Value> getIntrinsicOperands(
|
||||
::mlir::ArrayRef<Value> operands,
|
||||
const ::mlir::LLVMTypeConverter &typeConverter,
|
||||
::mlir::RewriterBase &rewriter);
|
||||
}];
|
||||
let assemblyFormat = "$base `[` $indices `]` (`,` $stride^ )? attr-dict"
|
||||
"`:` type($base) `into` qualified(type($res))";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AMX Tile Store
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def TileStoreOp : AMX_Op<"tile_store", [
|
||||
X86IntrinsicOpInterface,
|
||||
AttrSizedOperandSegments
|
||||
]> {
|
||||
let summary = "tile store operation";
|
||||
let description = [{
|
||||
Stores a tile to memory defined by a `base` and `indices`, with the
|
||||
shape defined by the 2-dim vector type of the value.
|
||||
The tile's rows are written contiguously to the buffer starting at
|
||||
the `base`. For each tile row, the `base` is incremented by `stride`
|
||||
number of elements.
|
||||
|
||||
The tile is stored using the following indexing scheme:
|
||||
|
||||
```
|
||||
for row in enumerate(tile_rows):
|
||||
mem_row = base[i0, i1, ..., iN + row * stride]
|
||||
for col in enumerate(tile_cols):
|
||||
mem_row[col] = tile[row, col]
|
||||
```
|
||||
|
||||
If the `stride` is not provided, then the `base` buffer must be at least
|
||||
2-dimensional, and the `stride` is automatically inferred and corresponds
|
||||
to the stride of the buffer's second innermost dimension.
|
||||
|
||||
The operation is eventually lowered into the "tilestored" instruction
|
||||
with the corresponding tile configuration.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
// Tile store to a 2-D memref with implicit stride.
|
||||
x86.amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, !x86.amx.tile<16x64xi8>
|
||||
|
||||
// Tile store to a 1-D memref with explicit stride.
|
||||
x86.amx.tile_store %arg1[%c0], %0, %stride : memref<?xi8>, !x86.amx.tile<16x64xi8>
|
||||
```
|
||||
}];
|
||||
let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
|
||||
Variadic<Index>:$indices,
|
||||
AnyAMXTile:$val,
|
||||
Optional<Index>:$stride);
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$base, "ValueRange":$indices, "Value":$val)>
|
||||
];
|
||||
let extraClassDeclaration = [{
|
||||
MemRefType getMemRefType() {
|
||||
return ::llvm::cast<MemRefType>(getBase().getType());
|
||||
}
|
||||
AMXTileType getTileType() {
|
||||
return ::llvm::cast<AMXTileType>(getVal().getType());
|
||||
}
|
||||
|
||||
std::string getIntrinsicName() {
|
||||
return "llvm.x86.tilestored64.internal";
|
||||
}
|
||||
SmallVector<Value> getIntrinsicOperands(
|
||||
::mlir::ArrayRef<Value> operands,
|
||||
const ::mlir::LLVMTypeConverter &typeConverter,
|
||||
::mlir::RewriterBase &rewriter);
|
||||
}];
|
||||
let assemblyFormat = "$base `[` $indices `]` `,` $val (`,` $stride^ )?"
|
||||
"attr-dict `:` type($base) `,` qualified(type($val))";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AMX Tile Multiply
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
|
||||
X86IntrinsicOpInterface,
|
||||
AllTypesMatch<["acc", "res"]>
|
||||
]> {
|
||||
let summary = "tile multiplication operation (floating-point)";
|
||||
let description = [{
|
||||
Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
|
||||
into a "m x n" destination tile. Supports "f32 <- bf16 x bf16" (with
|
||||
pairs of "bf16").
|
||||
|
||||
The operation is eventually lowered into the "tdpbf16ps" instruction with
|
||||
the corresponding tile configuration.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%0 = x86.amx.tile_mulf %a, %b, %c
|
||||
: !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x16xf32>
|
||||
```
|
||||
}];
|
||||
let arguments = (ins AMXTileF16OrBF16:$lhs,
|
||||
AMXTileF16OrBF16:$rhs,
|
||||
AMXTileF32:$acc);
|
||||
let results = (outs AMXTileF32:$res);
|
||||
let extraClassDeclaration = [{
|
||||
AMXTileType getLhsTileType() {
|
||||
return ::llvm::cast<AMXTileType>(getLhs().getType());
|
||||
}
|
||||
AMXTileType getRhsTileType() {
|
||||
return ::llvm::cast<AMXTileType>(getRhs().getType());
|
||||
}
|
||||
AMXTileType getTileType() {
|
||||
return ::llvm::cast<AMXTileType>(getRes().getType());
|
||||
}
|
||||
|
||||
std::string getIntrinsicName() {
|
||||
std::string intr = "llvm.x86.tdp";
|
||||
auto elementType =
|
||||
getLhsTileType().getElementType();
|
||||
intr += elementType.isF16() ? "fp16" : "bf16";
|
||||
intr += "ps.internal";
|
||||
return intr;
|
||||
}
|
||||
SmallVector<Value> getIntrinsicOperands(
|
||||
::mlir::ArrayRef<Value> operands,
|
||||
const ::mlir::LLVMTypeConverter &typeConverter,
|
||||
::mlir::RewriterBase &rewriter);
|
||||
}];
|
||||
let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
|
||||
"qualified(type($lhs)) `,` qualified(type($rhs))"
|
||||
" `,` qualified(type($acc)) ";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def TileMulIOp : AMX_Op<"tile_muli", [Pure,
|
||||
X86IntrinsicOpInterface,
|
||||
AllTypesMatch<["acc", "res"]>
|
||||
]> {
|
||||
let summary = "tile multiplication operation (integer)";
|
||||
let description = [{
|
||||
Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
|
||||
into a "m x n" destination tile. Supports all "si32 <- s/ui8 x s/ui8"
|
||||
combinations (4 bytes packed into dwords in the columns of both the
|
||||
source operand tiles; the zero or sign extension is specified with
|
||||
the attributes and default to sign extended).
|
||||
|
||||
The operation is eventually lowered into one of the "tdpbssd",
|
||||
"tdpbsud", "tdpbusd", or "tdpbuud" instructions with the corresponding
|
||||
tile configuration.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%0 = x86.amx.tile_muli %a zext, %b zext, %c
|
||||
: !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
|
||||
```
|
||||
}];
|
||||
let arguments = (ins AMXTileI8:$lhs,
|
||||
AMXTileI8:$rhs,
|
||||
AMXTileI32:$acc,
|
||||
UnitAttr:$isZextLhs,
|
||||
UnitAttr:$isZextRhs
|
||||
);
|
||||
let results = (outs AMXTileI32:$res);
|
||||
let extraClassDeclaration = [{
|
||||
AMXTileType getLhsTileType() {
|
||||
return ::llvm::cast<AMXTileType>(getLhs().getType());
|
||||
}
|
||||
AMXTileType getRhsTileType() {
|
||||
return ::llvm::cast<AMXTileType>(getRhs().getType());
|
||||
}
|
||||
AMXTileType getTileType() {
|
||||
return ::llvm::cast<AMXTileType>(getRes().getType());
|
||||
}
|
||||
|
||||
std::string getIntrinsicName() {
|
||||
std::string intr = "llvm.x86.tdpb";
|
||||
intr += getIsZextLhs() ? "u" : "s";
|
||||
intr += getIsZextRhs() ? "u" : "s";
|
||||
intr += "d.internal";
|
||||
return intr;
|
||||
}
|
||||
SmallVector<Value> getIntrinsicOperands(
|
||||
::mlir::ArrayRef<Value> operands,
|
||||
const ::mlir::LLVMTypeConverter &typeConverter,
|
||||
::mlir::RewriterBase &rewriter);
|
||||
}];
|
||||
let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
|
||||
"qualified(type($lhs)) `,` qualified(type($rhs)) `,` qualified(type($acc)) ";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
#endif // X86_OPS
|
||||
|
||||
@@ -29,6 +29,19 @@
|
||||
|
||||
#include "mlir/Dialect/X86/X86Dialect.h.inc"
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "mlir/Dialect/X86/X86Types.h.inc"
|
||||
|
||||
namespace mlir {
|
||||
namespace x86 {
|
||||
namespace amx {
|
||||
// Alias to allow access to AMX type through nested namespaces
|
||||
// analogously to AMX operations.
|
||||
using TileType = mlir::x86::AMXTileType;
|
||||
} // namespace amx
|
||||
} // namespace x86
|
||||
} // namespace mlir
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/X86/X86.h.inc"
|
||||
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
//===- AMX.cpp - C Interface for AMX dialect ------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir-c/Dialect/AMX.h"
|
||||
#include "mlir/CAPI/Registration.h"
|
||||
#include "mlir/Dialect/AMX/AMXDialect.h"
|
||||
|
||||
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(AMX, amx, mlir::amx::AMXDialect)
|
||||
@@ -26,15 +26,6 @@ add_mlir_upstream_c_api_library(MLIRCAPIAMDGPU
|
||||
MLIRAMDGPUTransforms
|
||||
)
|
||||
|
||||
add_mlir_upstream_c_api_library(MLIRCAPIAMX
|
||||
AMX.cpp
|
||||
|
||||
PARTIAL_SOURCES_INTENDED
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRCAPIIR
|
||||
MLIRAMXDialect
|
||||
)
|
||||
|
||||
add_mlir_upstream_c_api_library(MLIRCAPIArith
|
||||
Arith.cpp
|
||||
ArithPasses.cpp
|
||||
|
||||
@@ -8,7 +8,6 @@ add_mlir_conversion_library(MLIRVectorToAMX
|
||||
MLIRConversionPassIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRAMXDialect
|
||||
MLIRAffineUtils
|
||||
MLIRArithDialect
|
||||
MLIRLinalgUtils
|
||||
@@ -16,4 +15,5 @@ add_mlir_conversion_library(MLIRVectorToAMX
|
||||
MLIRSCFDialect
|
||||
MLIRTransforms
|
||||
MLIRVectorDialect
|
||||
MLIRX86Dialect
|
||||
)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//===- VectorToAMX.cpp - Convert vector to AMX dialect ----------*- C++ -*-===//
|
||||
//===- VectorToAMX.cpp - Convert vector to X86 dialect AMX ops --*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
@@ -8,7 +8,6 @@
|
||||
|
||||
#include "mlir/Conversion/VectorToAMX/VectorToAMX.h"
|
||||
|
||||
#include "mlir/Dialect/AMX/AMXDialect.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
|
||||
@@ -16,6 +15,7 @@
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
#include "mlir/Dialect/X86/X86Dialect.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
@@ -197,7 +197,7 @@ static TypedValue<MemRefType> collapseLastDim(PatternRewriter &rewriter,
|
||||
static Operation *
|
||||
loadStoreFromTransfer(PatternRewriter &rewriter,
|
||||
VectorTransferOpInterface xferOp, bool isPacked,
|
||||
TypedValue<amx::TileType> tileToStore = nullptr) {
|
||||
TypedValue<x86::amx::TileType> tileToStore = nullptr) {
|
||||
if (!xferOp || !isa<vector::TransferReadOp, vector::TransferWriteOp>(xferOp))
|
||||
return nullptr;
|
||||
if (xferOp.hasOutOfBoundsDim() ||
|
||||
@@ -267,18 +267,18 @@ loadStoreFromTransfer(PatternRewriter &rewriter,
|
||||
src = collapseLastDim(rewriter, src);
|
||||
int64_t rows = vecShape[0];
|
||||
int64_t cols = llvm::product_of(vecShape.drop_front());
|
||||
auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());
|
||||
auto tileType = x86::amx::TileType::get({rows, cols}, vecTy.getElementType());
|
||||
|
||||
Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
|
||||
SmallVector<Value> tileIndicides(src.getType().getRank(), zeroIndex);
|
||||
|
||||
Operation *amxTileOp = nullptr;
|
||||
if (isa<vector::TransferReadOp>(xferOp)) {
|
||||
amxTileOp =
|
||||
amx::TileLoadOp::create(rewriter, loc, tileType, src, tileIndicides);
|
||||
amxTileOp = x86::amx::TileLoadOp::create(rewriter, loc, tileType, src,
|
||||
tileIndicides);
|
||||
} else if (isa<vector::TransferWriteOp>(xferOp)) {
|
||||
amxTileOp = amx::TileStoreOp::create(rewriter, loc, src, tileIndicides,
|
||||
tileToStore);
|
||||
amxTileOp = x86::amx::TileStoreOp::create(rewriter, loc, src, tileIndicides,
|
||||
tileToStore);
|
||||
} else {
|
||||
llvm_unreachable("unsupported vector transfer op");
|
||||
}
|
||||
@@ -289,10 +289,10 @@ loadStoreFromTransfer(PatternRewriter &rewriter,
|
||||
/// Attempt to create an AMX tile load operation equivalent to the given
|
||||
/// vector transfer `readOp`.
|
||||
/// Returns loaded AMX tile if successful.
|
||||
static FailureOr<TypedValue<amx::TileType>>
|
||||
static FailureOr<TypedValue<x86::amx::TileType>>
|
||||
loadFromTransfer(PatternRewriter &rewriter, vector::TransferReadOp readOp,
|
||||
bool isPacked) {
|
||||
amx::TileLoadOp loadOp = dyn_cast_if_present<amx::TileLoadOp>(
|
||||
x86::amx::TileLoadOp loadOp = dyn_cast_if_present<x86::amx::TileLoadOp>(
|
||||
loadStoreFromTransfer(rewriter, readOp, isPacked));
|
||||
if (!loadOp)
|
||||
return failure();
|
||||
@@ -301,16 +301,16 @@ loadFromTransfer(PatternRewriter &rewriter, vector::TransferReadOp readOp,
|
||||
|
||||
/// Attempt to create an AMX tile store operation equivalent to the given
|
||||
/// vector transfer `writeOp`.
|
||||
static LogicalResult storeFromTransfer(PatternRewriter &rewriter,
|
||||
vector::TransferWriteOp writeOp,
|
||||
TypedValue<amx::TileType> tileToStore) {
|
||||
static LogicalResult
|
||||
storeFromTransfer(PatternRewriter &rewriter, vector::TransferWriteOp writeOp,
|
||||
TypedValue<x86::amx::TileType> tileToStore) {
|
||||
return success(loadStoreFromTransfer(rewriter, writeOp, /*isPacked=*/false,
|
||||
tileToStore));
|
||||
}
|
||||
|
||||
/// Load vector values to an AMX tile.
|
||||
static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter,
|
||||
TypedValue<VectorType> vec) {
|
||||
static TypedValue<x86::amx::TileType> loadTile(PatternRewriter &rewriter,
|
||||
TypedValue<VectorType> vec) {
|
||||
Location loc = vec.getLoc();
|
||||
|
||||
VectorType vecTy = vec.getType();
|
||||
@@ -318,7 +318,7 @@ static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter,
|
||||
|
||||
// Try to load tile directly from vector producer's buffer.
|
||||
auto readOp = vec.getDefiningOp<vector::TransferReadOp>();
|
||||
FailureOr<TypedValue<amx::TileType>> tile =
|
||||
FailureOr<TypedValue<x86::amx::TileType>> tile =
|
||||
loadFromTransfer(rewriter, readOp, isPacked);
|
||||
if (succeeded(tile))
|
||||
return *tile;
|
||||
@@ -337,25 +337,25 @@ static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter,
|
||||
ArrayRef<int64_t> shape = vecTy.getShape();
|
||||
int64_t rows = shape[0];
|
||||
int64_t cols = llvm::product_of(shape.drop_front());
|
||||
auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());
|
||||
auto tileType = x86::amx::TileType::get({rows, cols}, vecTy.getElementType());
|
||||
|
||||
return amx::TileLoadOp::create(rewriter, loc, tileType, buf,
|
||||
{zeroIndex, zeroIndex});
|
||||
return x86::amx::TileLoadOp::create(rewriter, loc, tileType, buf,
|
||||
{zeroIndex, zeroIndex});
|
||||
}
|
||||
|
||||
/// Store an AMX tile in a vector.
|
||||
static TypedValue<VectorType> storeTile(PatternRewriter &rewriter,
|
||||
TypedValue<amx::TileType> tile) {
|
||||
TypedValue<x86::amx::TileType> tile) {
|
||||
Location loc = tile.getLoc();
|
||||
|
||||
// Transfer the tile to a vector through an intermediate buffer.
|
||||
amx::TileType tileTy = tile.getType();
|
||||
x86::amx::TileType tileTy = tile.getType();
|
||||
Value buf = memref::AllocaOp::create(
|
||||
rewriter, loc,
|
||||
MemRefType::get(tileTy.getShape(), tileTy.getElementType()));
|
||||
Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
|
||||
SmallVector<Value> indices(2, zeroIndex);
|
||||
amx::TileStoreOp::create(rewriter, loc, buf, indices, tile);
|
||||
x86::amx::TileStoreOp::create(rewriter, loc, buf, indices, tile);
|
||||
|
||||
auto vecTy = VectorType::get(tileTy.getShape(), tileTy.getElementType());
|
||||
return vector::TransferReadOp::create(rewriter, loc, vecTy, buf, indices, {});
|
||||
@@ -374,19 +374,21 @@ struct ContractionToAMX : public OpRewritePattern<vector::ContractionOp> {
|
||||
if (failed(validateOperands(rewriter, contractOp)))
|
||||
return failure();
|
||||
|
||||
TypedValue<amx::TileType> lhsTile = loadTile(rewriter, contractOp.getLhs());
|
||||
TypedValue<amx::TileType> rhsTile = loadTile(rewriter, contractOp.getRhs());
|
||||
TypedValue<x86::amx::TileType> lhsTile =
|
||||
loadTile(rewriter, contractOp.getLhs());
|
||||
TypedValue<x86::amx::TileType> rhsTile =
|
||||
loadTile(rewriter, contractOp.getRhs());
|
||||
auto acc = dyn_cast<TypedValue<VectorType>>(contractOp.getAcc());
|
||||
assert(acc && "Invalid accumulator type");
|
||||
TypedValue<amx::TileType> accTile = loadTile(rewriter, acc);
|
||||
TypedValue<x86::amx::TileType> accTile = loadTile(rewriter, acc);
|
||||
|
||||
TypedValue<amx::TileType> tileMul;
|
||||
TypedValue<x86::amx::TileType> tileMul;
|
||||
if (acc.getType().getElementType().isFloat()) {
|
||||
tileMul = amx::TileMulFOp::create(rewriter, loc, accTile.getType(),
|
||||
lhsTile, rhsTile, accTile);
|
||||
tileMul = x86::amx::TileMulFOp::create(rewriter, loc, accTile.getType(),
|
||||
lhsTile, rhsTile, accTile);
|
||||
} else {
|
||||
tileMul = amx::TileMulIOp::create(rewriter, loc, accTile.getType(),
|
||||
lhsTile, rhsTile, accTile);
|
||||
tileMul = x86::amx::TileMulIOp::create(rewriter, loc, accTile.getType(),
|
||||
lhsTile, rhsTile, accTile);
|
||||
}
|
||||
|
||||
// If the contraction result is only written back to memory, try to replace
|
||||
|
||||
@@ -38,8 +38,6 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
|
||||
MLIRArmNeonTransforms
|
||||
MLIRArmSVEDialect
|
||||
MLIRArmSVETransforms
|
||||
MLIRAMXDialect
|
||||
MLIRAMXTransforms
|
||||
MLIRX86Dialect
|
||||
MLIRX86Transforms
|
||||
)
|
||||
|
||||
@@ -10,8 +10,6 @@
|
||||
|
||||
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
||||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||
#include "mlir/Dialect/AMX/AMXDialect.h"
|
||||
#include "mlir/Dialect/AMX/Transforms.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
|
||||
#include "mlir/Dialect/ArmNeon/Transforms.h"
|
||||
@@ -51,8 +49,6 @@ struct ConvertVectorToLLVMPass
|
||||
registry.insert<arm_neon::ArmNeonDialect>();
|
||||
if (armSVE)
|
||||
registry.insert<arm_sve::ArmSVEDialect>();
|
||||
if (amx)
|
||||
registry.insert<amx::AMXDialect>();
|
||||
if (x86)
|
||||
registry.insert<x86::X86Dialect>();
|
||||
}
|
||||
@@ -136,10 +132,6 @@ void ConvertVectorToLLVMPass::runOnOperation() {
|
||||
configureArmSVELegalizeForExportTarget(target);
|
||||
populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
|
||||
}
|
||||
if (amx) {
|
||||
configureAMXLegalizeForExportTarget(target);
|
||||
populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
|
||||
}
|
||||
if (x86) {
|
||||
configureX86LegalizeForExportTarget(target);
|
||||
populateX86LegalizeForLLVMExportPatterns(converter, patterns);
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
||||
@@ -1,318 +0,0 @@
|
||||
//===- AMXDialect.cpp - MLIR AMX ops implementation -----------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements the AMX dialect and its operations.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/AMX/AMXDialect.h"
|
||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
#include "mlir/Dialect/AMX/AMXInterfaces.cpp.inc"
|
||||
|
||||
#include "mlir/Dialect/AMX/AMXDialect.cpp.inc"
|
||||
|
||||
void amx::AMXDialect::initialize() {
|
||||
addTypes<
|
||||
#define GET_TYPEDEF_LIST
|
||||
#include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
|
||||
>();
|
||||
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "mlir/Dialect/AMX/AMX.cpp.inc"
|
||||
>();
|
||||
}
|
||||
|
||||
/// Verify that AMX supports the implied tile shape.
|
||||
static LogicalResult verifyTileSize(Operation *op, amx::TileType tp) {
|
||||
const unsigned kMaxRows = 16;
|
||||
const unsigned kBitsPerRow = 64 * 8;
|
||||
unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
|
||||
if (tp.getDimSize(0) > kMaxRows)
|
||||
return op->emitOpError("bad row height: ") << tp.getDimSize(0);
|
||||
if (col > kBitsPerRow || col & 0x1f)
|
||||
return op->emitOpError("bad column width: ") << (col >> 3);
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Verify that AMX supports the multiplication.
|
||||
static LogicalResult verifyMultShape(Operation *op, amx::TileType atp,
|
||||
amx::TileType btp, amx::TileType ctp,
|
||||
unsigned scale) {
|
||||
unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
|
||||
unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
|
||||
unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1);
|
||||
if (cm != am || cn != bn || ak != bk)
|
||||
return op->emitOpError("bad mult shape: ")
|
||||
<< cm << " x " << cn << " x " << ak;
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first
|
||||
/// dimension directly translates into the number of rows of the tiles.
|
||||
/// The second dimensions needs to be scaled by the number of bytes.
|
||||
static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType,
|
||||
RewriterBase &rewriter) {
|
||||
Type llvmInt16Type = rewriter.getIntegerType(16);
|
||||
unsigned width = tType.getElementType().getIntOrFloatBitWidth();
|
||||
assert(llvm::isPowerOf2_64(width) && width >= 8);
|
||||
unsigned bytes = width >> 3;
|
||||
auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0));
|
||||
auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes);
|
||||
return SmallVector<Value>{
|
||||
LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, mattr),
|
||||
LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, nattr)};
|
||||
}
|
||||
|
||||
/// Returns stride expressed in number of bytes for the given `elementStride`
|
||||
/// stride encoded in number of elements of the type `mType`.
|
||||
static Value computeStrideInBytes(Location loc, MemRefType mType,
|
||||
Value elementStride, RewriterBase &rewriter) {
|
||||
Type llvmInt64Type = rewriter.getIntegerType(64);
|
||||
unsigned bytes = mType.getElementType().getIntOrFloatBitWidth() / 8;
|
||||
auto attr = rewriter.getI64IntegerAttr(bytes);
|
||||
Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr);
|
||||
return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale, elementStride)
|
||||
.getResult();
|
||||
}
|
||||
|
||||
/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
|
||||
/// shape may "envelop" the actual tile shape, and may be dynamically sized.
|
||||
static Value inferStride(Location loc, MemRefType mType, Value base,
|
||||
RewriterBase &rewriter) {
|
||||
assert(mType.getRank() >= 2 && "Invalid shape for AMX strides");
|
||||
int64_t preLast = mType.getRank() - 2;
|
||||
Type llvmInt64Type = rewriter.getIntegerType(64);
|
||||
unsigned width = mType.getElementType().getIntOrFloatBitWidth();
|
||||
assert(llvm::isPowerOf2_64(width) && width >= 8);
|
||||
unsigned bytes = width >> 3;
|
||||
auto [strides, offset] = mType.getStridesAndOffset();
|
||||
if (strides[preLast] == ShapedType::kDynamic) {
|
||||
// Dynamic stride needs code to compute the stride at runtime.
|
||||
MemRefDescriptor memrefDescriptor(base);
|
||||
return computeStrideInBytes(
|
||||
loc, mType, memrefDescriptor.stride(rewriter, loc, preLast), rewriter);
|
||||
}
|
||||
// Use direct constant for static stride.
|
||||
auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
|
||||
return LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr)
|
||||
.getResult();
|
||||
}
|
||||
|
||||
LogicalResult amx::TileZeroOp::verify() {
|
||||
return verifyTileSize(*this, getTileType());
|
||||
}
|
||||
|
||||
SmallVector<Value>
|
||||
amx::TileZeroOp::getIntrinsicOperands(ArrayRef<Value> operands,
|
||||
const LLVMTypeConverter &typeConverter,
|
||||
RewriterBase &rewriter) {
|
||||
return getTileSizes(getLoc(), getTileType(), rewriter);
|
||||
}
|
||||
|
||||
template <typename OpTy,
|
||||
typename = std::enable_if_t<std::is_same_v<OpTy, amx::TileLoadOp> ||
|
||||
std::is_same_v<OpTy, amx::TileStoreOp>>>
|
||||
static LogicalResult tileTransferVerifier(OpTy op) {
|
||||
MemRefType memrefTy = op.getMemRefType();
|
||||
unsigned rank = memrefTy.getRank();
|
||||
if (op.getIndices().size() != rank)
|
||||
return op.emitOpError("requires ") << rank << " indices";
|
||||
|
||||
if (failed(verifyTileSize(op, op.getTileType())))
|
||||
return failure();
|
||||
|
||||
// Validate basic buffer properties when the stride is implicit.
|
||||
if (!op.getStride()) {
|
||||
if (rank < 2)
|
||||
return op.emitOpError("requires at least 2D memref");
|
||||
SmallVector<int64_t> strides;
|
||||
int64_t offset;
|
||||
if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
|
||||
strides.back() != 1)
|
||||
return op.emitOpError("requires memref with unit innermost stride");
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
void amx::TileLoadOp::build(OpBuilder &builder, OperationState &state, Type res,
|
||||
Value base, ValueRange indices) {
|
||||
build(builder, state, res, base, indices, /*stride=*/nullptr);
|
||||
}
|
||||
|
||||
LogicalResult amx::TileLoadOp::verify() { return tileTransferVerifier(*this); }
|
||||
|
||||
SmallVector<Value>
|
||||
amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands,
|
||||
const LLVMTypeConverter &typeConverter,
|
||||
RewriterBase &rewriter) {
|
||||
auto loc = getLoc();
|
||||
Adaptor adaptor(operands, *this);
|
||||
|
||||
SmallVector<Value> intrinsicOperands;
|
||||
intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter));
|
||||
intrinsicOperands.push_back(
|
||||
LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
|
||||
adaptor.getBase(), adaptor.getIndices()));
|
||||
if (Value stride = adaptor.getStride())
|
||||
intrinsicOperands.push_back(
|
||||
computeStrideInBytes(loc, getMemRefType(), stride, rewriter));
|
||||
else
|
||||
intrinsicOperands.push_back(
|
||||
inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
|
||||
|
||||
return intrinsicOperands;
|
||||
}
|
||||
|
||||
void amx::TileStoreOp::build(OpBuilder &builder, OperationState &state,
|
||||
Value base, ValueRange indices, Value val) {
|
||||
build(builder, state, base, indices, val, /*stride=*/nullptr);
|
||||
}
|
||||
|
||||
LogicalResult amx::TileStoreOp::verify() { return tileTransferVerifier(*this); }
|
||||
|
||||
SmallVector<Value>
|
||||
amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands,
|
||||
const LLVMTypeConverter &typeConverter,
|
||||
RewriterBase &rewriter) {
|
||||
auto loc = getLoc();
|
||||
Adaptor adaptor(operands, *this);
|
||||
|
||||
SmallVector<Value> intrinsicOperands;
|
||||
intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter));
|
||||
intrinsicOperands.push_back(
|
||||
LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
|
||||
adaptor.getBase(), adaptor.getIndices()));
|
||||
if (Value stride = adaptor.getStride())
|
||||
intrinsicOperands.push_back(
|
||||
computeStrideInBytes(loc, getMemRefType(), stride, rewriter));
|
||||
else
|
||||
intrinsicOperands.push_back(
|
||||
inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
|
||||
intrinsicOperands.push_back(adaptor.getVal());
|
||||
|
||||
return intrinsicOperands;
|
||||
}
|
||||
|
||||
LogicalResult amx::TileMulFOp::verify() {
|
||||
amx::TileType aType = getLhsTileType();
|
||||
amx::TileType bType = getRhsTileType();
|
||||
amx::TileType cType = getTileType();
|
||||
if (failed(verifyTileSize(*this, aType)) ||
|
||||
failed(verifyTileSize(*this, bType)) ||
|
||||
failed(verifyTileSize(*this, cType)) ||
|
||||
failed(verifyMultShape(*this, aType, bType, cType, 1)))
|
||||
return failure();
|
||||
Type ta = aType.getElementType();
|
||||
Type tb = bType.getElementType();
|
||||
Type tc = cType.getElementType();
|
||||
if ((!ta.isBF16() && !ta.isF16()) || (ta != tb) || !tc.isF32())
|
||||
return emitOpError("unsupported type combination");
|
||||
return success();
|
||||
}
|
||||
|
||||
SmallVector<Value>
|
||||
amx::TileMulFOp::getIntrinsicOperands(ArrayRef<Value> operands,
|
||||
const LLVMTypeConverter &typeConverter,
|
||||
RewriterBase &rewriter) {
|
||||
auto loc = getLoc();
|
||||
Adaptor adaptor(operands, *this);
|
||||
|
||||
amx::TileType aType = getLhsTileType();
|
||||
amx::TileType bType = getRhsTileType();
|
||||
SmallVector<Value> tsza = getTileSizes(loc, aType, rewriter);
|
||||
SmallVector<Value> tszb = getTileSizes(loc, bType, rewriter);
|
||||
|
||||
SmallVector<Value> intrinsicOperands = {tsza[0], tszb[1],
|
||||
tsza[1], adaptor.getAcc(),
|
||||
adaptor.getLhs(), adaptor.getRhs()};
|
||||
|
||||
return intrinsicOperands;
|
||||
}
|
||||
|
||||
LogicalResult amx::TileMulIOp::verify() {
|
||||
amx::TileType aType = getLhsTileType();
|
||||
amx::TileType bType = getRhsTileType();
|
||||
amx::TileType cType = getTileType();
|
||||
if (failed(verifyTileSize(*this, aType)) ||
|
||||
failed(verifyTileSize(*this, bType)) ||
|
||||
failed(verifyTileSize(*this, cType)) ||
|
||||
failed(verifyMultShape(*this, aType, bType, cType, 2)))
|
||||
return failure();
|
||||
Type ta = aType.getElementType();
|
||||
Type tb = bType.getElementType();
|
||||
Type tc = cType.getElementType();
|
||||
if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32))
|
||||
return emitOpError("unsupported type combination");
|
||||
return success();
|
||||
}
|
||||
|
||||
SmallVector<Value>
|
||||
amx::TileMulIOp::getIntrinsicOperands(ArrayRef<Value> operands,
|
||||
const LLVMTypeConverter &typeConverter,
|
||||
RewriterBase &rewriter) {
|
||||
auto loc = getLoc();
|
||||
Adaptor adaptor(operands, *this);
|
||||
|
||||
amx::TileType aType = getLhsTileType();
|
||||
amx::TileType bType = getRhsTileType();
|
||||
SmallVector<Value> tsza = getTileSizes(loc, aType, rewriter);
|
||||
SmallVector<Value> tszb = getTileSizes(loc, bType, rewriter);
|
||||
|
||||
SmallVector<Value> intrinsicOperands = {tsza[0], tszb[1],
|
||||
tsza[1], adaptor.getAcc(),
|
||||
adaptor.getLhs(), adaptor.getRhs()};
|
||||
|
||||
return intrinsicOperands;
|
||||
}
|
||||
|
||||
Type amx::TileType::parse(AsmParser &parser) {
|
||||
if (parser.parseLess())
|
||||
return nullptr;
|
||||
|
||||
SmallVector<int64_t, 2> shape;
|
||||
if (parser.parseDimensionList(shape, false, true))
|
||||
return nullptr;
|
||||
|
||||
Type elementType;
|
||||
if (parser.parseType(elementType))
|
||||
return nullptr;
|
||||
|
||||
if (parser.parseGreater())
|
||||
return nullptr;
|
||||
|
||||
return TileType::getChecked(
|
||||
[&] { return parser.emitError(parser.getNameLoc()); }, shape,
|
||||
elementType);
|
||||
}
|
||||
|
||||
void amx::TileType::print(AsmPrinter &os) const {
|
||||
os << "<";
|
||||
os.printDimensionList(getShape());
|
||||
os << 'x';
|
||||
os.printType(getElementType());
|
||||
os << '>';
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/AMX/AMX.cpp.inc"
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
|
||||
@@ -1,15 +0,0 @@
|
||||
add_mlir_dialect_library(MLIRAMXDialect
|
||||
AMXDialect.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMX
|
||||
|
||||
DEPENDS
|
||||
MLIRAMXIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRLLVMCommonConversion
|
||||
MLIRLLVMDialect
|
||||
MLIRSideEffectInterfaces
|
||||
)
|
||||
@@ -1,9 +0,0 @@
|
||||
add_mlir_dialect_library(MLIRAMXTransforms
|
||||
LegalizeForLLVMExport.cpp
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRAMXDialect
|
||||
MLIRIR
|
||||
MLIRLLVMCommonConversion
|
||||
MLIRLLVMDialect
|
||||
)
|
||||
@@ -1,70 +0,0 @@
|
||||
//===- LegalizeForLLVMExport.cpp - Prepare AMX for LLVM translation ----===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/AMX/Transforms.h"
|
||||
|
||||
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
|
||||
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||
#include "mlir/Dialect/AMX/AMXDialect.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::amx;
|
||||
|
||||
namespace {
|
||||
|
||||
/// Generic one-to-one conversion of simply mappable operations into calls
|
||||
/// to their respective LLVM intrinsics.
|
||||
struct AMXIntrinsicOpConversion
|
||||
: public ConvertOpInterfaceToLLVMPattern<amx::AMXIntrinsicOp> {
|
||||
using ConvertOpInterfaceToLLVMPattern::ConvertOpInterfaceToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(amx::AMXIntrinsicOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
const LLVMTypeConverter &typeConverter = *getTypeConverter();
|
||||
return LLVM::detail::intrinsicRewrite(
|
||||
op, rewriter.getStringAttr(op.getIntrinsicName()),
|
||||
op.getIntrinsicOperands(operands, typeConverter, rewriter),
|
||||
typeConverter, rewriter);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlir::populateAMXLegalizeForLLVMExportPatterns(
|
||||
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
patterns.add<AMXIntrinsicOpConversion>(converter);
|
||||
converter.addConversion([&](amx::TileType type) {
|
||||
return LLVM::LLVMX86AMXType::get(&converter.getContext());
|
||||
});
|
||||
}
|
||||
|
||||
void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
|
||||
target.addIllegalDialect<AMXDialect>();
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// Implement the interface to convert AMX to LLVM.
|
||||
struct AMXToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
|
||||
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
|
||||
|
||||
void populateConvertToLLVMConversionPatterns(
|
||||
ConversionTarget &target, LLVMTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) const final {
|
||||
populateAMXLegalizeForLLVMExportPatterns(typeConverter, patterns);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::registerConvertAMXToLLVMInterface(DialectRegistry ®istry) {
|
||||
registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) {
|
||||
dialect->addInterfaces<AMXToLLVMDialectInterface>();
|
||||
});
|
||||
}
|
||||
@@ -1,6 +1,5 @@
|
||||
add_subdirectory(Affine)
|
||||
add_subdirectory(AMDGPU)
|
||||
add_subdirectory(AMX)
|
||||
add_subdirectory(Arith)
|
||||
add_subdirectory(ArmNeon)
|
||||
add_subdirectory(ArmSME)
|
||||
|
||||
@@ -11,10 +11,16 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/X86/X86Dialect.h"
|
||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
#include "mlir/Dialect/X86/X86Interfaces.cpp.inc"
|
||||
@@ -22,6 +28,11 @@ using namespace mlir;
|
||||
#include "mlir/Dialect/X86/X86Dialect.cpp.inc"
|
||||
|
||||
void x86::X86Dialect::initialize() {
|
||||
addTypes<
|
||||
#define GET_TYPEDEF_LIST
|
||||
#include "mlir/Dialect/X86/X86Types.cpp.inc"
|
||||
>();
|
||||
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "mlir/Dialect/X86/X86.cpp.inc"
|
||||
@@ -107,5 +118,279 @@ SmallVector<Value> x86::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
|
||||
typeConverter, rewriter)};
|
||||
}
|
||||
|
||||
/// Verify that AMX supports the implied tile shape.
|
||||
static LogicalResult verifyTileSize(Operation *op, x86::amx::TileType tp) {
|
||||
const unsigned kMaxRows = 16;
|
||||
const unsigned kBitsPerRow = 64 * 8;
|
||||
unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
|
||||
if (tp.getDimSize(0) > kMaxRows)
|
||||
return op->emitOpError("bad row height: ") << tp.getDimSize(0);
|
||||
if (col > kBitsPerRow || col & 0x1f)
|
||||
return op->emitOpError("bad column width: ") << (col >> 3);
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Verify that AMX supports the multiplication.
|
||||
static LogicalResult verifyMultShape(Operation *op, x86::amx::TileType atp,
|
||||
x86::amx::TileType btp,
|
||||
x86::amx::TileType ctp, unsigned scale) {
|
||||
unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
|
||||
unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
|
||||
unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1);
|
||||
if (cm != am || cn != bn || ak != bk)
|
||||
return op->emitOpError("bad mult shape: ")
|
||||
<< cm << " x " << cn << " x " << ak;
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first
|
||||
/// dimension directly translates into the number of rows of the tiles.
|
||||
/// The second dimensions needs to be scaled by the number of bytes.
|
||||
static SmallVector<Value> getTileSizes(Location loc, x86::amx::TileType tType,
|
||||
RewriterBase &rewriter) {
|
||||
Type llvmInt16Type = rewriter.getIntegerType(16);
|
||||
unsigned width = tType.getElementType().getIntOrFloatBitWidth();
|
||||
assert(llvm::isPowerOf2_64(width) && width >= 8);
|
||||
unsigned bytes = width >> 3;
|
||||
auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0));
|
||||
auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes);
|
||||
return SmallVector<Value>{
|
||||
LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, mattr),
|
||||
LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, nattr)};
|
||||
}
|
||||
|
||||
/// Returns stride expressed in number of bytes for the given `elementStride`
|
||||
/// stride encoded in number of elements of the type `mType`.
|
||||
static Value computeStrideInBytes(Location loc, MemRefType mType,
|
||||
Value elementStride, RewriterBase &rewriter) {
|
||||
Type llvmInt64Type = rewriter.getIntegerType(64);
|
||||
unsigned bytes = mType.getElementType().getIntOrFloatBitWidth() / 8;
|
||||
auto attr = rewriter.getI64IntegerAttr(bytes);
|
||||
Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr);
|
||||
return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale, elementStride)
|
||||
.getResult();
|
||||
}
|
||||
|
||||
/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
|
||||
/// shape may "envelop" the actual tile shape, and may be dynamically sized.
|
||||
static Value inferStride(Location loc, MemRefType mType, Value base,
|
||||
RewriterBase &rewriter) {
|
||||
assert(mType.getRank() >= 2 && "Invalid shape for AMX strides");
|
||||
int64_t preLast = mType.getRank() - 2;
|
||||
Type llvmInt64Type = rewriter.getIntegerType(64);
|
||||
unsigned width = mType.getElementType().getIntOrFloatBitWidth();
|
||||
assert(llvm::isPowerOf2_64(width) && width >= 8);
|
||||
unsigned bytes = width >> 3;
|
||||
auto [strides, offset] = mType.getStridesAndOffset();
|
||||
if (strides[preLast] == ShapedType::kDynamic) {
|
||||
// Dynamic stride needs code to compute the stride at runtime.
|
||||
MemRefDescriptor memrefDescriptor(base);
|
||||
return computeStrideInBytes(
|
||||
loc, mType, memrefDescriptor.stride(rewriter, loc, preLast), rewriter);
|
||||
}
|
||||
// Use direct constant for static stride.
|
||||
auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
|
||||
return LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr)
|
||||
.getResult();
|
||||
}
|
||||
|
||||
LogicalResult x86::amx::TileZeroOp::verify() {
|
||||
return verifyTileSize(*this, getTileType());
|
||||
}
|
||||
|
||||
SmallVector<Value> x86::amx::TileZeroOp::getIntrinsicOperands(
|
||||
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
|
||||
RewriterBase &rewriter) {
|
||||
return getTileSizes(getLoc(), getTileType(), rewriter);
|
||||
}
|
||||
|
||||
template <typename OpTy, typename = std::enable_if_t<
|
||||
std::is_same_v<OpTy, x86::amx::TileLoadOp> ||
|
||||
std::is_same_v<OpTy, x86::amx::TileStoreOp>>>
|
||||
static LogicalResult tileTransferVerifier(OpTy op) {
|
||||
MemRefType memrefTy = op.getMemRefType();
|
||||
unsigned rank = memrefTy.getRank();
|
||||
if (op.getIndices().size() != rank)
|
||||
return op.emitOpError("requires ") << rank << " indices";
|
||||
|
||||
if (failed(verifyTileSize(op, op.getTileType())))
|
||||
return failure();
|
||||
|
||||
// Validate basic buffer properties when the stride is implicit.
|
||||
if (!op.getStride()) {
|
||||
if (rank < 2)
|
||||
return op.emitOpError("requires at least 2D memref");
|
||||
SmallVector<int64_t> strides;
|
||||
int64_t offset;
|
||||
if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
|
||||
strides.back() != 1)
|
||||
return op.emitOpError("requires memref with unit innermost stride");
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
void x86::amx::TileLoadOp::build(OpBuilder &builder, OperationState &state,
|
||||
Type res, Value base, ValueRange indices) {
|
||||
build(builder, state, res, base, indices, /*stride=*/nullptr);
|
||||
}
|
||||
|
||||
LogicalResult x86::amx::TileLoadOp::verify() {
|
||||
return tileTransferVerifier(*this);
|
||||
}
|
||||
|
||||
SmallVector<Value> x86::amx::TileLoadOp::getIntrinsicOperands(
|
||||
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
|
||||
RewriterBase &rewriter) {
|
||||
auto loc = getLoc();
|
||||
Adaptor adaptor(operands, *this);
|
||||
|
||||
SmallVector<Value> intrinsicOperands;
|
||||
intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter));
|
||||
intrinsicOperands.push_back(
|
||||
LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
|
||||
adaptor.getBase(), adaptor.getIndices()));
|
||||
if (Value stride = adaptor.getStride())
|
||||
intrinsicOperands.push_back(
|
||||
computeStrideInBytes(loc, getMemRefType(), stride, rewriter));
|
||||
else
|
||||
intrinsicOperands.push_back(
|
||||
inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
|
||||
|
||||
return intrinsicOperands;
|
||||
}
|
||||
|
||||
void x86::amx::TileStoreOp::build(OpBuilder &builder, OperationState &state,
|
||||
Value base, ValueRange indices, Value val) {
|
||||
build(builder, state, base, indices, val, /*stride=*/nullptr);
|
||||
}
|
||||
|
||||
LogicalResult x86::amx::TileStoreOp::verify() {
|
||||
return tileTransferVerifier(*this);
|
||||
}
|
||||
|
||||
SmallVector<Value> x86::amx::TileStoreOp::getIntrinsicOperands(
|
||||
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
|
||||
RewriterBase &rewriter) {
|
||||
auto loc = getLoc();
|
||||
Adaptor adaptor(operands, *this);
|
||||
|
||||
SmallVector<Value> intrinsicOperands;
|
||||
intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter));
|
||||
intrinsicOperands.push_back(
|
||||
LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
|
||||
adaptor.getBase(), adaptor.getIndices()));
|
||||
if (Value stride = adaptor.getStride())
|
||||
intrinsicOperands.push_back(
|
||||
computeStrideInBytes(loc, getMemRefType(), stride, rewriter));
|
||||
else
|
||||
intrinsicOperands.push_back(
|
||||
inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
|
||||
intrinsicOperands.push_back(adaptor.getVal());
|
||||
|
||||
return intrinsicOperands;
|
||||
}
|
||||
|
||||
LogicalResult x86::amx::TileMulFOp::verify() {
|
||||
x86::amx::TileType aType = getLhsTileType();
|
||||
x86::amx::TileType bType = getRhsTileType();
|
||||
x86::amx::TileType cType = getTileType();
|
||||
if (failed(verifyTileSize(*this, aType)) ||
|
||||
failed(verifyTileSize(*this, bType)) ||
|
||||
failed(verifyTileSize(*this, cType)) ||
|
||||
failed(verifyMultShape(*this, aType, bType, cType, 1)))
|
||||
return failure();
|
||||
Type ta = aType.getElementType();
|
||||
Type tb = bType.getElementType();
|
||||
Type tc = cType.getElementType();
|
||||
if ((!ta.isBF16() && !ta.isF16()) || (ta != tb) || !tc.isF32())
|
||||
return emitOpError("unsupported type combination");
|
||||
return success();
|
||||
}
|
||||
|
||||
SmallVector<Value> x86::amx::TileMulFOp::getIntrinsicOperands(
|
||||
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
|
||||
RewriterBase &rewriter) {
|
||||
auto loc = getLoc();
|
||||
Adaptor adaptor(operands, *this);
|
||||
|
||||
x86::amx::TileType aType = getLhsTileType();
|
||||
x86::amx::TileType bType = getRhsTileType();
|
||||
SmallVector<Value> tsza = getTileSizes(loc, aType, rewriter);
|
||||
SmallVector<Value> tszb = getTileSizes(loc, bType, rewriter);
|
||||
|
||||
SmallVector<Value> intrinsicOperands = {tsza[0], tszb[1],
|
||||
tsza[1], adaptor.getAcc(),
|
||||
adaptor.getLhs(), adaptor.getRhs()};
|
||||
|
||||
return intrinsicOperands;
|
||||
}
|
||||
|
||||
LogicalResult x86::amx::TileMulIOp::verify() {
|
||||
x86::amx::TileType aType = getLhsTileType();
|
||||
x86::amx::TileType bType = getRhsTileType();
|
||||
x86::amx::TileType cType = getTileType();
|
||||
if (failed(verifyTileSize(*this, aType)) ||
|
||||
failed(verifyTileSize(*this, bType)) ||
|
||||
failed(verifyTileSize(*this, cType)) ||
|
||||
failed(verifyMultShape(*this, aType, bType, cType, 2)))
|
||||
return failure();
|
||||
Type ta = aType.getElementType();
|
||||
Type tb = bType.getElementType();
|
||||
Type tc = cType.getElementType();
|
||||
if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32))
|
||||
return emitOpError("unsupported type combination");
|
||||
return success();
|
||||
}
|
||||
|
||||
SmallVector<Value> x86::amx::TileMulIOp::getIntrinsicOperands(
|
||||
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
|
||||
RewriterBase &rewriter) {
|
||||
auto loc = getLoc();
|
||||
Adaptor adaptor(operands, *this);
|
||||
|
||||
x86::amx::TileType aType = getLhsTileType();
|
||||
x86::amx::TileType bType = getRhsTileType();
|
||||
SmallVector<Value> tsza = getTileSizes(loc, aType, rewriter);
|
||||
SmallVector<Value> tszb = getTileSizes(loc, bType, rewriter);
|
||||
|
||||
SmallVector<Value> intrinsicOperands = {tsza[0], tszb[1],
|
||||
tsza[1], adaptor.getAcc(),
|
||||
adaptor.getLhs(), adaptor.getRhs()};
|
||||
|
||||
return intrinsicOperands;
|
||||
}
|
||||
|
||||
Type x86::amx::TileType::parse(AsmParser &parser) {
|
||||
if (parser.parseLess())
|
||||
return nullptr;
|
||||
|
||||
SmallVector<int64_t, 2> shape;
|
||||
if (parser.parseDimensionList(shape, false, true))
|
||||
return nullptr;
|
||||
|
||||
Type elementType;
|
||||
if (parser.parseType(elementType))
|
||||
return nullptr;
|
||||
|
||||
if (parser.parseGreater())
|
||||
return nullptr;
|
||||
|
||||
return AMXTileType::getChecked(
|
||||
[&] { return parser.emitError(parser.getNameLoc()); }, shape,
|
||||
elementType);
|
||||
}
|
||||
|
||||
void x86::amx::TileType::print(AsmPrinter &os) const {
|
||||
os << "<";
|
||||
os.printDimensionList(getShape());
|
||||
os << 'x';
|
||||
os.printType(getElementType());
|
||||
os << '>';
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/X86/X86.cpp.inc"
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "mlir/Dialect/X86/X86Types.cpp.inc"
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
#include "mlir/Dialect/X86/Transforms.h"
|
||||
|
||||
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
|
||||
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||
#include "mlir/Dialect/X86/X86Dialect.h"
|
||||
@@ -39,10 +40,32 @@ struct X86IntrinsicOpConversion
|
||||
|
||||
/// Populate the given list with patterns that convert from X86 to LLVM.
|
||||
void mlir::populateX86LegalizeForLLVMExportPatterns(
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
patterns.add<X86IntrinsicOpConversion>(converter);
|
||||
converter.addConversion([&](x86::amx::TileType type) {
|
||||
return LLVM::LLVMX86AMXType::get(&converter.getContext());
|
||||
});
|
||||
}
|
||||
|
||||
void mlir::configureX86LegalizeForExportTarget(LLVMConversionTarget &target) {
|
||||
target.addIllegalDialect<X86Dialect>();
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// Implement the interface to convert X86 to LLVM.
|
||||
struct X86ToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
|
||||
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
|
||||
|
||||
void populateConvertToLLVMConversionPatterns(
|
||||
ConversionTarget &target, LLVMTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) const final {
|
||||
populateX86LegalizeForLLVMExportPatterns(typeConverter, patterns);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::registerConvertX86ToLLVMInterface(DialectRegistry ®istry) {
|
||||
registry.addExtension(+[](MLIRContext *ctx, x86::X86Dialect *dialect) {
|
||||
dialect->addInterfaces<X86ToLLVMDialectInterface>();
|
||||
});
|
||||
}
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
#include "mlir/InitAllDialects.h"
|
||||
|
||||
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
|
||||
#include "mlir/Dialect/AMX/AMXDialect.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
@@ -111,7 +110,6 @@ void mlir::registerAllDialects(DialectRegistry ®istry) {
|
||||
registry.insert<acc::OpenACCDialect,
|
||||
affine::AffineDialect,
|
||||
amdgpu::AMDGPUDialect,
|
||||
amx::AMXDialect,
|
||||
arith::ArithDialect,
|
||||
arm_neon::ArmNeonDialect,
|
||||
arm_sme::ArmSMEDialect,
|
||||
|
||||
@@ -32,7 +32,6 @@
|
||||
#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
|
||||
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
|
||||
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
|
||||
#include "mlir/Dialect/AMX/Transforms.h"
|
||||
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
|
||||
#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
|
||||
#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"
|
||||
@@ -57,6 +56,7 @@
|
||||
#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
|
||||
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
|
||||
#include "mlir/Dialect/X86/TransformOps/X86TransformOps.h"
|
||||
#include "mlir/Dialect/X86/Transforms.h"
|
||||
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
|
||||
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
|
||||
#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
|
||||
@@ -90,10 +90,10 @@ void mlir::registerAllExtensions(DialectRegistry ®istry) {
|
||||
registerConvertOpenMPToLLVMInterface(registry);
|
||||
registerConvertSCFToEmitCInterface(registry);
|
||||
ub::registerConvertUBToLLVMInterface(registry);
|
||||
registerConvertAMXToLLVMInterface(registry);
|
||||
gpu::registerConvertGpuToLLVMInterface(registry);
|
||||
NVVM::registerConvertGpuToNVVMInterface(registry);
|
||||
vector::registerConvertVectorToLLVMInterface(registry);
|
||||
registerConvertX86ToLLVMInterface(registry);
|
||||
|
||||
// Register all transform dialect extensions.
|
||||
affine::registerTransformDialectExtension(registry);
|
||||
|
||||
@@ -32,8 +32,8 @@ if (MLIR_INCLUDE_INTEGRATION_TESTS)
|
||||
"The GPU compilation format used by the tests.")
|
||||
set(ARM_SME_ABI_ROUTINES_SHLIB "" CACHE STRING
|
||||
"Path to a shared library containing Arm SME ABI routines, required for Arm SME integration tests.")
|
||||
option(MLIR_RUN_AMX_TESTS "Run AMX tests.")
|
||||
option(MLIR_RUN_X86_TESTS "Run X86 tests.")
|
||||
option(MLIR_RUN_X86_AMX_TESTS "Run X86 AMX tests.")
|
||||
option(MLIR_RUN_CUDA_TENSOR_CORE_TESTS "Run CUDA Tensor core WMMA tests.")
|
||||
option(MLIR_RUN_CUDA_SM80_TESTS "Run CUDA A100 tests.")
|
||||
option(MLIR_RUN_CUDA_SM80_LT_TESTS "Run CUDA A100 structured sparsity tests.")
|
||||
@@ -77,7 +77,7 @@ llvm_canonicalize_cmake_booleans(
|
||||
MLIR_ENABLE_SPIRV_CPU_RUNNER
|
||||
MLIR_ENABLE_VULKAN_RUNNER
|
||||
MLIR_INCLUDE_INTEGRATION_TESTS
|
||||
MLIR_RUN_AMX_TESTS
|
||||
MLIR_RUN_X86_AMX_TESTS
|
||||
MLIR_RUN_CUDA_TENSOR_CORE_TESTS
|
||||
MLIR_RUN_X86_TESTS
|
||||
MLIR_RUN_ARM_SVE_TESTS
|
||||
|
||||
@@ -34,27 +34,27 @@ func.func @contract_vnni_f16(%A: vector<4x8x2xf16>, %B: vector<8x16x2xf16>,
|
||||
// CHECK: vector.transfer_write %[[A]], %[[A_BUF]]
|
||||
// CHECK: %[[A_BUF_2D:.+]] = memref.collapse_shape %[[A_BUF]]
|
||||
// CHECK-SAME: {{\[}}[0], [1, 2]] : memref<4x8x2xf16> into memref<4x16xf16>
|
||||
// CHECK: %[[A_TILE:.+]] = amx.tile_load %[[A_BUF_2D]]
|
||||
// CHECK: %[[A_TILE:.+]] = x86.amx.tile_load %[[A_BUF_2D]]
|
||||
|
||||
/// Load B vector into an AMX tile
|
||||
// CHECK: %[[B_BUF:.+]] = memref.alloca() : memref<8x16x2xf16>
|
||||
// CHECK: vector.transfer_write %[[B]], %[[B_BUF]]
|
||||
// CHECK: %[[B_BUF_2D:.+]] = memref.collapse_shape %[[B_BUF]]
|
||||
// CHECK-SAME: {{\[}}[0], [1, 2]] : memref<8x16x2xf16> into memref<8x32xf16>
|
||||
// CHECK: %[[B_TILE:.+]] = amx.tile_load %[[B_BUF_2D]]
|
||||
// CHECK: %[[B_TILE:.+]] = x86.amx.tile_load %[[B_BUF_2D]]
|
||||
|
||||
/// Load C vector into an AMX tile
|
||||
// CHECK: %[[C_BUF:.+]] = memref.alloca() : memref<4x16xf32>
|
||||
// CHECK: vector.transfer_write %[[C]], %[[C_BUF]]
|
||||
// CHECK: %[[C_TILE:.+]] = amx.tile_load %[[C_BUF]]
|
||||
// CHECK: %[[C_TILE:.+]] = x86.amx.tile_load %[[C_BUF]]
|
||||
|
||||
/// Perform tile multiplication
|
||||
// CHECK: %[[RES:.+]] = amx.tile_mulf
|
||||
// CHECK: %[[RES:.+]] = x86.amx.tile_mulf
|
||||
// CHECK-SAME: %[[A_TILE]], %[[B_TILE]], %[[C_TILE]]
|
||||
|
||||
/// Load the result back into a vector
|
||||
// CHECK: %[[RES_BUF:.+]] = memref.alloca() : memref<4x16xf32>
|
||||
// CHECK: amx.tile_store %[[RES_BUF]]{{.*}}, %[[RES]]
|
||||
// CHECK: x86.amx.tile_store %[[RES_BUF]]{{.*}}, %[[RES]]
|
||||
// CHECK: %[[RES_VEC:.+]] = vector.transfer_read %[[RES_BUF]]
|
||||
|
||||
// CHECK: return %[[RES_VEC]]
|
||||
@@ -75,9 +75,9 @@ func.func @contract_vnni_bf16(%A: vector<4x8x2xbf16>, %B: vector<8x16x2xbf16>,
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @contract_vnni_bf16(
|
||||
// CHECK-COUNT-3: amx.tile_load
|
||||
// CHECK: amx.tile_mulf
|
||||
// CHECK: amx.tile_store
|
||||
// CHECK-COUNT-3: x86.amx.tile_load
|
||||
// CHECK: x86.amx.tile_mulf
|
||||
// CHECK: x86.amx.tile_store
|
||||
|
||||
// -----
|
||||
|
||||
@@ -95,9 +95,9 @@ func.func @contract_vnni_i8(%A: vector<4x16x4xi8>, %B: vector<16x8x4xi8>,
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @contract_vnni_i8(
|
||||
// CHECK-COUNT-3: amx.tile_load
|
||||
// CHECK: amx.tile_muli
|
||||
// CHECK: amx.tile_store
|
||||
// CHECK-COUNT-3: x86.amx.tile_load
|
||||
// CHECK: x86.amx.tile_muli
|
||||
// CHECK: x86.amx.tile_store
|
||||
|
||||
// -----
|
||||
|
||||
@@ -115,9 +115,9 @@ func.func @contract_shuffled_iterators(%A: vector<4x16x4xi8>, %B: vector<16x8x4x
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @contract_shuffled_iterators(
|
||||
// CHECK-COUNT-3: amx.tile_load
|
||||
// CHECK: amx.tile_muli
|
||||
// CHECK: amx.tile_store
|
||||
// CHECK-COUNT-3: x86.amx.tile_load
|
||||
// CHECK: x86.amx.tile_muli
|
||||
// CHECK: x86.amx.tile_store
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ func.func @transfers_static_dims(%A: memref<64x32x16x2xf16>,
|
||||
// CHECK-SAME: {{\[}}%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]{{\]}}
|
||||
// CHECK: %[[A_PACKED_DIM_COLLAPSE:.+]] = memref.collapse_shape %[[A_SUBVIEW]]
|
||||
// CHECK-SAME: {{\[}}[0], [1], [2, 3]] : memref<1x4x8x2xf16{{.*}}into memref<1x4x16xf16
|
||||
// CHECK: %[[A_TILE:.+]] = amx.tile_load %[[A_PACKED_DIM_COLLAPSE]]
|
||||
// CHECK: %[[A_TILE:.+]] = x86.amx.tile_load %[[A_PACKED_DIM_COLLAPSE]]
|
||||
// CHECK-SAME: {{\[}}%[[C0]], %[[C0]], %[[C0]]{{\]}}
|
||||
// CHECK-NOT: vector.transfer_read %[[A]]
|
||||
|
||||
@@ -47,25 +47,25 @@ func.func @transfers_static_dims(%A: memref<64x32x16x2xf16>,
|
||||
// CHECK-SAME: {{\[}}%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]{{\]}}
|
||||
// CHECK: %[[B_PACKED_DIM_COLLAPSE:.+]] = memref.collapse_shape %[[B_SUBVIEW]]
|
||||
// CHECK-SAME: {{\[}}[0], [1], [2, 3]] : memref<1x8x16x2xf16{{.*}}into memref<1x8x32xf16
|
||||
// CHECK: %[[B_TILE:.+]] = amx.tile_load %[[B_PACKED_DIM_COLLAPSE]]
|
||||
// CHECK: %[[B_TILE:.+]] = x86.amx.tile_load %[[B_PACKED_DIM_COLLAPSE]]
|
||||
// CHECK-SAME: {{\[}}%[[C0]], %[[C0]], %[[C0]]{{\]}}
|
||||
// CHECK-NOT: vector.transfer_read %[[B]]
|
||||
|
||||
/// Load C into an AMX tile
|
||||
// CHECK: %[[C_SUBVIEW:.+]] = memref.subview %[[C]]
|
||||
// CHECK-SAME: {{\[}}%[[IDX]], %[[IDX]]{{\]}}
|
||||
// CHECK: %[[C_TILE:.+]] = amx.tile_load %[[C_SUBVIEW]]
|
||||
// CHECK: %[[C_TILE:.+]] = x86.amx.tile_load %[[C_SUBVIEW]]
|
||||
// CHECK-SAME: {{\[}}%[[C0]], %[[C0]]{{\]}}
|
||||
// CHECK-NOT: vector.transfer_read %[[C]]
|
||||
|
||||
/// Perform tile multiplication
|
||||
// CHECK: %[[RES:.+]] = amx.tile_mulf
|
||||
// CHECK: %[[RES:.+]] = x86.amx.tile_mulf
|
||||
// CHECK-SAME: %[[A_TILE]], %[[B_TILE]], %[[C_TILE]]
|
||||
|
||||
/// Store the result back
|
||||
// CHECK: %[[RES_SUBVIEW:.+]] = memref.subview %[[C]]
|
||||
// CHECK-SAME: {{\[}}%[[IDX]], %[[IDX]]{{\]}}
|
||||
// CHECK: amx.tile_store %[[RES_SUBVIEW]]{{\[}}%[[C0]], %[[C0]]{{\]}}, %[[RES]]
|
||||
// CHECK: x86.amx.tile_store %[[RES_SUBVIEW]]{{\[}}%[[C0]], %[[C0]]{{\]}}, %[[RES]]
|
||||
// CHECK-NOT: vector.transfer_write{{.*}}%[[C]]
|
||||
|
||||
// -----
|
||||
@@ -130,17 +130,17 @@ func.func @transfer_read_multiple_users(%C: memref<64x64xf32>,
|
||||
|
||||
/// Load to AMX tile directly from buffer.
|
||||
// CHECK: %[[C_SUBVIEW:.+]] = memref.subview %[[C]]
|
||||
// CHECK: %[[C_TILE:.+]] = amx.tile_load %[[C_SUBVIEW]]
|
||||
// CHECK: %[[C_TILE:.+]] = x86.amx.tile_load %[[C_SUBVIEW]]
|
||||
|
||||
/// Vector read remains to load data for the other non-AMX consumer.
|
||||
// CHECK: %[[C_VEC:.+]] = vector.transfer_read %[[C]]
|
||||
|
||||
/// Contraction uses the directly loaded tile.
|
||||
// CHECK: %[[TILE_MUL:.+]] = amx.tile_mulf{{.*}}%[[C_TILE]]
|
||||
// CHECK: %[[TILE_MUL:.+]] = x86.amx.tile_mulf{{.*}}%[[C_TILE]]
|
||||
|
||||
/// Consumer uses original C value and the updated one after contraction.
|
||||
// CHECK: %[[RES_BUF:.+]] = memref.alloca
|
||||
// CHECK: amx.tile_store %[[RES_BUF]]
|
||||
// CHECK: x86.amx.tile_store %[[RES_BUF]]
|
||||
// CHECK: %[[RES_VEC:.+]] = vector.transfer_read %[[RES_BUF]]
|
||||
// CHECK: %[[VEC_MUL:.+]] = arith.mulf %[[C_VEC]], %[[RES_VEC]]
|
||||
|
||||
@@ -168,7 +168,7 @@ func.func @negative_contract_multiple_users(%C: memref<64x64xf32>,
|
||||
|
||||
// CHECK-LABEL: @negative_contract_multiple_users(
|
||||
// CHECK-SAME: %[[C:.+]]: memref<64x64xf32>
|
||||
// CHECK: %[[TILE_MUL:.+]] = amx.tile_mulf
|
||||
// CHECK: %[[TILE_MUL:.+]] = x86.amx.tile_mulf
|
||||
// CHECK: vector.transfer_write{{.*}}%[[C]]
|
||||
|
||||
// -----
|
||||
|
||||
@@ -18,7 +18,6 @@
|
||||
|
||||
// CHECK: builtin.module(
|
||||
// CHECK-SAME: convert-vector-to-llvm{
|
||||
// CHECK-SAME: enable-amx={{[aA-zZ0-9]+}}
|
||||
// CHECK-SAME: enable-arm-neon={{[aA-zZ0-9]+}}
|
||||
// CHECK-SAME: enable-arm-sve={{[aA-zZ0-9]+}}
|
||||
// CHECK-SAME: enable-x86={{[aA-zZ0-9]+}}
|
||||
|
||||
@@ -1,158 +0,0 @@
|
||||
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
|
||||
|
||||
func.func @tile_row_height() {
|
||||
// expected-error@+1 {{'amx.tile_zero' op bad row height: 17}}
|
||||
%0 = amx.tile_zero : !amx.tile<17x16xbf16>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @tile_col_width() {
|
||||
// expected-error@+1 {{'amx.tile_zero' op bad column width: 65}}
|
||||
%0 = amx.tile_zero : !amx.tile<16x65xi8>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @tile_element_type() {
|
||||
// expected-error@+1 {{failed to verify 'elementType'}}
|
||||
%0 = amx.tile_zero : !amx.tile<8x8xi16>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @tile_rank() {
|
||||
// expected-error@+1 {{'amx.tile_zero' op result #0 must be tile of}}
|
||||
%0 = amx.tile_zero : !amx.tile<32xi8>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @tile_col_4_byte_multiple() {
|
||||
// expected-error@+1 {{'amx.tile_zero' op bad column width: 5}}
|
||||
%0 = amx.tile_zero : !amx.tile<16x5xi8>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @load_base_tile_size(%arg0: memref<?x?xf32>) {
|
||||
%0 = arith.constant 0 : index
|
||||
// expected-error@+1 {{'amx.tile_load' op bad column width: 68}}
|
||||
%1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into !amx.tile<16x17xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @store_base_tile_size(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x17xf32>) {
|
||||
%0 = arith.constant 0 : index
|
||||
// expected-error@+1 {{'amx.tile_store' op bad column width: 68}}
|
||||
amx.tile_store %arg0[%0, %0], %arg1 : memref<?x?xf32>, !amx.tile<16x17xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @load_base_index_size(%arg0: memref<?x?xf32>) {
|
||||
%0 = arith.constant 0 : index
|
||||
// expected-error@+1 {{'amx.tile_load' op requires 2 indices}}
|
||||
%1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into !amx.tile<16x16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @store_base_index_size(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x16xf32>) {
|
||||
%0 = arith.constant 0 : index
|
||||
// expected-error@+1 {{'amx.tile_store' op requires 2 indices}}
|
||||
amx.tile_store %arg0[%0], %arg1 : memref<?x?xf32>, !amx.tile<16x16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @load_base_rank(%arg0: memref<?xf32>) {
|
||||
%0 = arith.constant 0 : index
|
||||
// expected-error@+1 {{'amx.tile_load' op requires at least 2D memref}}
|
||||
%1 = amx.tile_load %arg0[%0] : memref<?xf32> into !amx.tile<16x16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @store_base_rank(%arg0: memref<?xf32>, %arg1: !amx.tile<16x16xf32>) {
|
||||
%0 = arith.constant 0 : index
|
||||
// expected-error@+1 {{'amx.tile_store' op requires at least 2D memref}}
|
||||
amx.tile_store %arg0[%0], %arg1 : memref<?xf32>, !amx.tile<16x16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @load_base_non_unit_stride(%arg0: memref<?x?xf32, strided<[?, ?]>>) {
|
||||
%0 = arith.constant 0 : index
|
||||
// expected-error@+1 {{'amx.tile_load' op requires memref with unit innermost stride}}
|
||||
%1 = amx.tile_load %arg0[%0, %0]
|
||||
: memref<?x?xf32, strided<[?, ?]>> into !amx.tile<16x16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @store_base_non_unit_stride(%arg0: memref<?x?xf32, strided<[?, ?]>>,
|
||||
%arg1: !amx.tile<16x16xf32>) {
|
||||
%0 = arith.constant 0 : index
|
||||
// expected-error@+1 {{'amx.tile_store' op requires memref with unit innermost stride}}
|
||||
amx.tile_store %arg0[%0, %0], %arg1
|
||||
: memref<?x?xf32, strided<[?, ?]>>, !amx.tile<16x16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @mulf_shape() {
|
||||
%0 = amx.tile_zero : !amx.tile<8x8xbf16>
|
||||
%1 = amx.tile_zero : !amx.tile<8x8xbf16>
|
||||
%2 = amx.tile_zero : !amx.tile<4x4xf32>
|
||||
// expected-error@+1 {{'amx.tile_mulf' op bad mult shape: 4 x 4 x 4}}
|
||||
%3 = amx.tile_mulf %0, %1, %2 : !amx.tile<8x8xbf16>, !amx.tile<8x8xbf16>, !amx.tile<4x4xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @mulf_type_combination() {
|
||||
%0 = amx.tile_zero : !amx.tile<8x8xbf16>
|
||||
%1 = amx.tile_zero : !amx.tile<4x8xf16>
|
||||
%2 = amx.tile_zero : !amx.tile<8x4xf32>
|
||||
// expected-error@+1 {{'amx.tile_mulf' op unsupported type combination}}
|
||||
%3 = amx.tile_mulf %0, %1, %2 : !amx.tile<8x8xbf16>, !amx.tile<4x8xf16>, !amx.tile<8x4xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @muli_shape() {
|
||||
%0 = amx.tile_zero : !amx.tile<8x8xi8>
|
||||
%1 = amx.tile_zero : !amx.tile<8x8xi8>
|
||||
%2 = amx.tile_zero : !amx.tile<4x4xi32>
|
||||
// expected-error@+1 {{'amx.tile_muli' op bad mult shape: 4 x 4 x 2}}
|
||||
%3 = amx.tile_muli %0, %1, %2 : !amx.tile<8x8xi8>, !amx.tile<8x8xi8>, !amx.tile<4x4xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @muli_type_combination() {
|
||||
%0 = amx.tile_zero : !amx.tile<8x16xi8>
|
||||
%1 = amx.tile_zero : !amx.tile<8x16xi32>
|
||||
%2 = amx.tile_zero : !amx.tile<2x2xi32>
|
||||
// expected-error@+1 {{'amx.tile_muli' op operand #1 must be tile of 8-bit signless integer values}}
|
||||
%3 = amx.tile_muli %0, %1, %2 : !amx.tile<8x16xi8>, !amx.tile<8x16xi32>, !amx.tile<2x2xi32>
|
||||
return
|
||||
}
|
||||
@@ -1,77 +0,0 @@
|
||||
// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: tloadstore
|
||||
// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}], %{{.*}} :
|
||||
// CHECK-SAME: memref<?xbf16> into !amx.tile<16x32xbf16>
|
||||
// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} :
|
||||
// CHECK-SAME: memref<?x?xbf16> into !amx.tile<16x32xbf16>
|
||||
// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] :
|
||||
// CHECK-SAME: memref<?x?xbf16, strided<[64, 1]>> into !amx.tile<16x32xbf16>
|
||||
// CHECK: amx.tile_store %{{.*}}[%{{.*}}], %[[z]], %{{.*}} :
|
||||
// CHECK-SAME: memref<?xbf16>, !amx.tile<16x32xbf16>
|
||||
// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[x]], %{{.*}} :
|
||||
// CHECK-SAME: memref<?x?xbf16>, !amx.tile<16x32xbf16>
|
||||
// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[y]] :
|
||||
// CHECK-SAME: memref<?x?xbf16, strided<[64, 1]>>, !amx.tile<16x32xbf16>
|
||||
func.func @tloadstore(%stride: index,
|
||||
%arg0: memref<?xbf16>,
|
||||
%arg1: memref<?x?xbf16>,
|
||||
%arg2: memref<?x?xbf16, strided<[64, 1]>>) {
|
||||
%0 = arith.constant 0 : index
|
||||
%c64 = arith.constant 64 : index
|
||||
%1 = amx.tile_load %arg0[%0], %stride : memref<?xbf16> into !amx.tile<16x32xbf16>
|
||||
%2 = amx.tile_load %arg1[%0, %0], %stride : memref<?x?xbf16> into !amx.tile<16x32xbf16>
|
||||
%3 = amx.tile_load %arg2[%0, %0] : memref<?x?xbf16, strided<[64, 1]>> into !amx.tile<16x32xbf16>
|
||||
amx.tile_store %arg0[%0], %3, %stride : memref<?xbf16>, !amx.tile<16x32xbf16>
|
||||
amx.tile_store %arg1[%0, %0], %1, %stride : memref<?x?xbf16>, !amx.tile<16x32xbf16>
|
||||
amx.tile_store %arg2[%0, %0], %2 : memref<?x?xbf16, strided<[64, 1]>>, !amx.tile<16x32xbf16>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: tzero
|
||||
// CHECK: amx.tile_zero : !amx.tile<16x16xbf16>
|
||||
// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xbf16>, !amx.tile<16x16xbf16>
|
||||
func.func @tzero(%arg0: memref<?x?xbf16>) {
|
||||
%0 = arith.constant 0 : index
|
||||
%1 = amx.tile_zero : !amx.tile<16x16xbf16>
|
||||
amx.tile_store %arg0[%0, %0], %1 : memref<?x?xbf16>, !amx.tile<16x16xbf16>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: tmulf
|
||||
// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
|
||||
// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32> into !amx.tile<16x16xf32>
|
||||
// CHECK: %[[m:.*]] = amx.tile_mulf %[[x]], %[[x]], %[[z]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
|
||||
// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xf32>, !amx.tile<16x16xf32>
|
||||
func.func @tmulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
|
||||
%0 = arith.constant 0 : index
|
||||
%1 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
|
||||
%2 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !amx.tile<16x16xf32>
|
||||
%3 = amx.tile_mulf %1, %1, %2 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
|
||||
amx.tile_store %arg1[%0, %0], %3 : memref<?x?xf32>, !amx.tile<16x16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: tmuli
|
||||
// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into !amx.tile<16x64xi8>
|
||||
// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into !amx.tile<16x64xi8>
|
||||
// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into !amx.tile<16x16xi32>
|
||||
// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
|
||||
// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, !amx.tile<16x16xi32>
|
||||
// Verify the parsing/printing of the sign-extension annotation.
|
||||
// CHECK: amx.tile_muli %{{.*}}, %{{.*}} zext, %{{.*}}
|
||||
// CHECK: amx.tile_muli %{{.*}} zext, %{{.*}}, %{{.*}}
|
||||
// CHECK: amx.tile_muli %{{.*}}, %{{.*}}, %{{.*}}
|
||||
func.func @tmuli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi8>, %arg2: memref<?x?xi32>) {
|
||||
%0 = arith.constant 0 : index
|
||||
%1 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into !amx.tile<16x64xi8>
|
||||
%2 = amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into !amx.tile<16x64xi8>
|
||||
%3 = amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into !amx.tile<16x16xi32>
|
||||
%4 = amx.tile_muli %1 zext, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
|
||||
amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, !amx.tile<16x16xi32>
|
||||
// Verify the various `zext` combinations.
|
||||
%5 = amx.tile_muli %1, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
|
||||
%6 = amx.tile_muli %1 zext, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
|
||||
%7 = amx.tile_muli %1, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
|
||||
return
|
||||
}
|
||||
@@ -1,32 +0,0 @@
|
||||
// RUN: mlir-opt %s -cse -convert-vector-to-llvm="enable-amx" | FileCheck %s
|
||||
|
||||
// With inclusion of memory side-effects, it is expected CSE not to fold multiple
|
||||
// "tileload" and "tilezero".
|
||||
// CHECK-LABEL: do_not_fold_tiles(
|
||||
// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
|
||||
// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
|
||||
// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
|
||||
// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
|
||||
// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
|
||||
// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
|
||||
func.func @do_not_fold_tiles(%arg0: memref<2x32x32xbf16>, %arg1: memref<2x16x32xbf16>) -> memref<16x32xf32> {
|
||||
%c1 = arith.constant 1 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
%c2 = arith.constant 2 : index
|
||||
%c16 = arith.constant 16 : index
|
||||
%alloca = memref.alloca() : memref<16x32xf32>
|
||||
%0 = amx.tile_zero : !amx.tile<16x16xf32>
|
||||
%1 = amx.tile_zero : !amx.tile<16x16xf32>
|
||||
%2:2 = scf.for %arg2 = %c0 to %c2 step %c1 iter_args(%arg3 = %0, %arg4 = %1) -> (!amx.tile<16x16xf32>, !amx.tile<16x16xf32>) {
|
||||
%3 = amx.tile_load %arg0[%arg2, %c0, %c0] : memref<2x32x32xbf16> into !amx.tile<16x32xbf16>
|
||||
%4 = amx.tile_load %arg0[%arg2, %c16, %c0] : memref<2x32x32xbf16> into !amx.tile<16x32xbf16>
|
||||
%5 = amx.tile_load %arg1[%arg2, %c0, %c0] : memref<2x16x32xbf16> into !amx.tile<16x32xbf16>
|
||||
%6 = amx.tile_load %arg1[%arg2, %c0, %c0] : memref<2x16x32xbf16> into !amx.tile<16x32xbf16>
|
||||
%7 = amx.tile_mulf %3, %5, %arg3 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
|
||||
%8 = amx.tile_mulf %4, %6, %arg4 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
|
||||
scf.yield %7, %8 : !amx.tile<16x16xf32>, !amx.tile<16x16xf32>
|
||||
}
|
||||
amx.tile_store %alloca[%c0, %c0], %2#0 : memref<16x32xf32>, !amx.tile<16x16xf32>
|
||||
amx.tile_store %alloca[%c0, %c16], %2#1 : memref<16x32xf32>, !amx.tile<16x16xf32>
|
||||
return %alloca : memref<16x32xf32>
|
||||
}
|
||||
@@ -635,10 +635,10 @@ func.func @invalid_indexing_maps_placement_matmul(%lhs: tensor<4x1xf32>, %rhs: t
|
||||
|
||||
// -----
|
||||
|
||||
func.func @invalid_type_matmul(%arg0 : !amx.tile<16x16xbf16>)
|
||||
func.func @invalid_type_matmul(%arg0 : !x86.amx.tile<16x16xbf16>)
|
||||
{
|
||||
// expected-error @below {{custom op 'linalg.matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got '!amx.tile<16x16xbf16>' and '!amx.tile<16x16xbf16>'}}
|
||||
%0 = linalg.matmul ins(%arg0, %arg0 : !amx.tile<16x16xbf16>, !amx.tile<16x16xbf16>) outs(%arg0 : !amx.tile<16x16xbf16>) -> !amx.tile<16x16xbf16>
|
||||
// expected-error @below {{custom op 'linalg.matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got '!x86.amx.tile<16x16xbf16>' and '!x86.amx.tile<16x16xbf16>'}}
|
||||
%0 = linalg.matmul ins(%arg0, %arg0 : !x86.amx.tile<16x16xbf16>, !x86.amx.tile<16x16xbf16>) outs(%arg0 : !x86.amx.tile<16x16xbf16>) -> !x86.amx.tile<16x16xbf16>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1582,10 +1582,10 @@ func.func @invalid_C_map_result_dim_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1
|
||||
|
||||
// -----
|
||||
|
||||
func.func @invalid_type_batch_matmul(%arg0 : !amx.tile<16x16xbf16>)
|
||||
func.func @invalid_type_batch_matmul(%arg0 : !x86.amx.tile<16x16xbf16>)
|
||||
{
|
||||
// expected-error @below {{custom op 'linalg.batch_matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got '!amx.tile<16x16xbf16>' and '!amx.tile<16x16xbf16>'}}
|
||||
%0 = linalg.batch_matmul ins(%arg0, %arg0 : !amx.tile<16x16xbf16>, !amx.tile<16x16xbf16>) outs(%arg0 : !amx.tile<16x16xbf16>) -> !amx.tile<16x16xbf16>
|
||||
// expected-error @below {{custom op 'linalg.batch_matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got '!x86.amx.tile<16x16xbf16>' and '!x86.amx.tile<16x16xbf16>'}}
|
||||
%0 = linalg.batch_matmul ins(%arg0, %arg0 : !x86.amx.tile<16x16xbf16>, !x86.amx.tile<16x16xbf16>) outs(%arg0 : !x86.amx.tile<16x16xbf16>) -> !x86.amx.tile<16x16xbf16>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1790,10 +1790,10 @@ func.func @invalid_C_map_result_dim(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>
|
||||
|
||||
// -----
|
||||
|
||||
func.func @batch_reduce_matmul_invalid_type(%arg0 : !amx.tile<16x16xbf16>)
|
||||
func.func @batch_reduce_matmul_invalid_type(%arg0 : !x86.amx.tile<16x16xbf16>)
|
||||
{
|
||||
// expected-error @below {{custom op 'linalg.batch_reduce_matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got '!amx.tile<16x16xbf16>' and '!amx.tile<16x16xbf16>'}}
|
||||
%0 = linalg.batch_reduce_matmul ins(%arg0, %arg0 : !amx.tile<16x16xbf16>, !amx.tile<16x16xbf16>) outs(%arg0 : !amx.tile<16x16xbf16>) -> !amx.tile<16x16xbf16>
|
||||
// expected-error @below {{custom op 'linalg.batch_reduce_matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got '!x86.amx.tile<16x16xbf16>' and '!x86.amx.tile<16x16xbf16>'}}
|
||||
%0 = linalg.batch_reduce_matmul ins(%arg0, %arg0 : !x86.amx.tile<16x16xbf16>, !x86.amx.tile<16x16xbf16>) outs(%arg0 : !x86.amx.tile<16x16xbf16>) -> !x86.amx.tile<16x16xbf16>
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
158
mlir/test/Dialect/X86/AMX/invalid.mlir
Normal file
158
mlir/test/Dialect/X86/AMX/invalid.mlir
Normal file
@@ -0,0 +1,158 @@
|
||||
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
|
||||
|
||||
func.func @tile_row_height() {
|
||||
// expected-error@+1 {{'x86.amx.tile_zero' op bad row height: 17}}
|
||||
%0 = x86.amx.tile_zero : !x86.amx.tile<17x16xbf16>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @tile_col_width() {
|
||||
// expected-error@+1 {{'x86.amx.tile_zero' op bad column width: 65}}
|
||||
%0 = x86.amx.tile_zero : !x86.amx.tile<16x65xi8>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @tile_element_type() {
|
||||
// expected-error@+1 {{failed to verify 'elementType'}}
|
||||
%0 = x86.amx.tile_zero : !x86.amx.tile<8x8xi16>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @tile_rank() {
|
||||
// expected-error@+1 {{'x86.amx.tile_zero' op result #0 must be tile of}}
|
||||
%0 = x86.amx.tile_zero : !x86.amx.tile<32xi8>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @tile_col_4_byte_multiple() {
|
||||
// expected-error@+1 {{'x86.amx.tile_zero' op bad column width: 5}}
|
||||
%0 = x86.amx.tile_zero : !x86.amx.tile<16x5xi8>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @load_base_tile_size(%arg0: memref<?x?xf32>) {
|
||||
%0 = arith.constant 0 : index
|
||||
// expected-error@+1 {{'x86.amx.tile_load' op bad column width: 68}}
|
||||
%1 = x86.amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into !x86.amx.tile<16x17xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @store_base_tile_size(%arg0: memref<?x?xf32>, %arg1: !x86.amx.tile<16x17xf32>) {
|
||||
%0 = arith.constant 0 : index
|
||||
// expected-error@+1 {{'x86.amx.tile_store' op bad column width: 68}}
|
||||
x86.amx.tile_store %arg0[%0, %0], %arg1 : memref<?x?xf32>, !x86.amx.tile<16x17xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @load_base_index_size(%arg0: memref<?x?xf32>) {
|
||||
%0 = arith.constant 0 : index
|
||||
// expected-error@+1 {{'x86.amx.tile_load' op requires 2 indices}}
|
||||
%1 = x86.amx.tile_load %arg0[%0] : memref<?x?xf32> into !x86.amx.tile<16x16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @store_base_index_size(%arg0: memref<?x?xf32>, %arg1: !x86.amx.tile<16x16xf32>) {
|
||||
%0 = arith.constant 0 : index
|
||||
// expected-error@+1 {{'x86.amx.tile_store' op requires 2 indices}}
|
||||
x86.amx.tile_store %arg0[%0], %arg1 : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @load_base_rank(%arg0: memref<?xf32>) {
|
||||
%0 = arith.constant 0 : index
|
||||
// expected-error@+1 {{'x86.amx.tile_load' op requires at least 2D memref}}
|
||||
%1 = x86.amx.tile_load %arg0[%0] : memref<?xf32> into !x86.amx.tile<16x16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @store_base_rank(%arg0: memref<?xf32>, %arg1: !x86.amx.tile<16x16xf32>) {
|
||||
%0 = arith.constant 0 : index
|
||||
// expected-error@+1 {{'x86.amx.tile_store' op requires at least 2D memref}}
|
||||
x86.amx.tile_store %arg0[%0], %arg1 : memref<?xf32>, !x86.amx.tile<16x16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @load_base_non_unit_stride(%arg0: memref<?x?xf32, strided<[?, ?]>>) {
|
||||
%0 = arith.constant 0 : index
|
||||
// expected-error@+1 {{'x86.amx.tile_load' op requires memref with unit innermost stride}}
|
||||
%1 = x86.amx.tile_load %arg0[%0, %0]
|
||||
: memref<?x?xf32, strided<[?, ?]>> into !x86.amx.tile<16x16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @store_base_non_unit_stride(%arg0: memref<?x?xf32, strided<[?, ?]>>,
|
||||
%arg1: !x86.amx.tile<16x16xf32>) {
|
||||
%0 = arith.constant 0 : index
|
||||
// expected-error@+1 {{'x86.amx.tile_store' op requires memref with unit innermost stride}}
|
||||
x86.amx.tile_store %arg0[%0, %0], %arg1
|
||||
: memref<?x?xf32, strided<[?, ?]>>, !x86.amx.tile<16x16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @mulf_shape() {
|
||||
%0 = x86.amx.tile_zero : !x86.amx.tile<8x8xbf16>
|
||||
%1 = x86.amx.tile_zero : !x86.amx.tile<8x8xbf16>
|
||||
%2 = x86.amx.tile_zero : !x86.amx.tile<4x4xf32>
|
||||
// expected-error@+1 {{'x86.amx.tile_mulf' op bad mult shape: 4 x 4 x 4}}
|
||||
%3 = x86.amx.tile_mulf %0, %1, %2 : !x86.amx.tile<8x8xbf16>, !x86.amx.tile<8x8xbf16>, !x86.amx.tile<4x4xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @mulf_type_combination() {
|
||||
%0 = x86.amx.tile_zero : !x86.amx.tile<8x8xbf16>
|
||||
%1 = x86.amx.tile_zero : !x86.amx.tile<4x8xf16>
|
||||
%2 = x86.amx.tile_zero : !x86.amx.tile<8x4xf32>
|
||||
// expected-error@+1 {{'x86.amx.tile_mulf' op unsupported type combination}}
|
||||
%3 = x86.amx.tile_mulf %0, %1, %2 : !x86.amx.tile<8x8xbf16>, !x86.amx.tile<4x8xf16>, !x86.amx.tile<8x4xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @muli_shape() {
|
||||
%0 = x86.amx.tile_zero : !x86.amx.tile<8x8xi8>
|
||||
%1 = x86.amx.tile_zero : !x86.amx.tile<8x8xi8>
|
||||
%2 = x86.amx.tile_zero : !x86.amx.tile<4x4xi32>
|
||||
// expected-error@+1 {{'x86.amx.tile_muli' op bad mult shape: 4 x 4 x 2}}
|
||||
%3 = x86.amx.tile_muli %0, %1, %2 : !x86.amx.tile<8x8xi8>, !x86.amx.tile<8x8xi8>, !x86.amx.tile<4x4xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @muli_type_combination() {
|
||||
%0 = x86.amx.tile_zero : !x86.amx.tile<8x16xi8>
|
||||
%1 = x86.amx.tile_zero : !x86.amx.tile<8x16xi32>
|
||||
%2 = x86.amx.tile_zero : !x86.amx.tile<2x2xi32>
|
||||
// expected-error@+1 {{'x86.amx.tile_muli' op operand #1 must be tile of 8-bit signless integer values}}
|
||||
%3 = x86.amx.tile_muli %0, %1, %2 : !x86.amx.tile<8x16xi8>, !x86.amx.tile<8x16xi32>, !x86.amx.tile<2x2xi32>
|
||||
return
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-amx" | mlir-opt | FileCheck %s
|
||||
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86" | mlir-opt | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: muli(
|
||||
// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
|
||||
@@ -14,17 +14,17 @@
|
||||
// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
|
||||
func.func @muli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi32>) {
|
||||
%0 = arith.constant 0 : index
|
||||
%1 = amx.tile_zero : !amx.tile<16x64xi8>
|
||||
%2 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into !amx.tile<16x64xi8>
|
||||
%3 = amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into !amx.tile<16x16xi32>
|
||||
%4 = amx.tile_muli %1 zext, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
|
||||
amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, !amx.tile<16x16xi32>
|
||||
%5 = amx.tile_muli %1, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
|
||||
amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, !amx.tile<16x16xi32>
|
||||
%6 = amx.tile_muli %1 zext, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
|
||||
amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, !amx.tile<16x16xi32>
|
||||
%7 = amx.tile_muli %1, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
|
||||
amx.tile_store %arg1[%0, %0], %7 : memref<?x?xi32>, !amx.tile<16x16xi32>
|
||||
%1 = x86.amx.tile_zero : !x86.amx.tile<16x64xi8>
|
||||
%2 = x86.amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into !x86.amx.tile<16x64xi8>
|
||||
%3 = x86.amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into !x86.amx.tile<16x16xi32>
|
||||
%4 = x86.amx.tile_muli %1 zext, %2 zext, %3 : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
|
||||
x86.amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, !x86.amx.tile<16x16xi32>
|
||||
%5 = x86.amx.tile_muli %1, %2, %3 : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
|
||||
x86.amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, !x86.amx.tile<16x16xi32>
|
||||
%6 = x86.amx.tile_muli %1 zext, %2, %3 : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
|
||||
x86.amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, !x86.amx.tile<16x16xi32>
|
||||
%7 = x86.amx.tile_muli %1, %2 zext, %3 : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
|
||||
x86.amx.tile_store %arg1[%0, %0], %7 : memref<?x?xi32>, !x86.amx.tile<16x16xi32>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -36,11 +36,11 @@ func.func @muli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi32>) {
|
||||
// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
|
||||
func.func @mulbf16(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
|
||||
%0 = arith.constant 0 : index
|
||||
%1 = amx.tile_zero : !amx.tile<16x32xbf16>
|
||||
%2 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
|
||||
%3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !amx.tile<16x16xf32>
|
||||
%4 = amx.tile_mulf %1, %2, %3 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
|
||||
amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !amx.tile<16x16xf32>
|
||||
%1 = x86.amx.tile_zero : !x86.amx.tile<16x32xbf16>
|
||||
%2 = x86.amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into !x86.amx.tile<16x32xbf16>
|
||||
%3 = x86.amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !x86.amx.tile<16x16xf32>
|
||||
%4 = x86.amx.tile_mulf %1, %2, %3 : !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x16xf32>
|
||||
x86.amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -52,11 +52,11 @@ func.func @mulbf16(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
|
||||
// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
|
||||
func.func @mulfp16(%arg0: memref<?x?xf16>, %arg1: memref<?x?xf32>) {
|
||||
%0 = arith.constant 0 : index
|
||||
%1 = amx.tile_zero : !amx.tile<16x32xf16>
|
||||
%2 = amx.tile_load %arg0[%0, %0] : memref<?x?xf16> into !amx.tile<16x32xf16>
|
||||
%3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !amx.tile<16x16xf32>
|
||||
%4 = amx.tile_mulf %1, %2, %3 : !amx.tile<16x32xf16>, !amx.tile<16x32xf16>, !amx.tile<16x16xf32>
|
||||
amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !amx.tile<16x16xf32>
|
||||
%1 = x86.amx.tile_zero : !x86.amx.tile<16x32xf16>
|
||||
%2 = x86.amx.tile_load %arg0[%0, %0] : memref<?x?xf16> into !x86.amx.tile<16x32xf16>
|
||||
%3 = x86.amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !x86.amx.tile<16x16xf32>
|
||||
%4 = x86.amx.tile_mulf %1, %2, %3 : !x86.amx.tile<16x32xf16>, !x86.amx.tile<16x32xf16>, !x86.amx.tile<16x16xf32>
|
||||
x86.amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -84,12 +84,12 @@ func.func @strides_implicit(%arg0: memref<16x32xi8>,
|
||||
%arg1: memref<32x32xbf16, strided<[64, 1]>>,
|
||||
%arg2: memref<16x32xf32, strided<[?, 1]>>) {
|
||||
%0 = arith.constant 0 : index
|
||||
%1 = amx.tile_load %arg0[%0, %0] : memref<16x32xi8> into !amx.tile<16x32xi8>
|
||||
%2 = amx.tile_load %arg1[%0, %0] : memref<32x32xbf16, strided<[64, 1]>> into !amx.tile<16x32xbf16>
|
||||
%3 = amx.tile_load %arg2[%0, %0] : memref<16x32xf32, strided<[?, 1]>> into !amx.tile<16x16xf32>
|
||||
amx.tile_store %arg0[%0, %0], %1 : memref<16x32xi8>, !amx.tile<16x32xi8>
|
||||
amx.tile_store %arg1[%0, %0], %2 : memref<32x32xbf16, strided<[64, 1]>>, !amx.tile<16x32xbf16>
|
||||
amx.tile_store %arg2[%0, %0], %3 : memref<16x32xf32, strided<[?, 1]>>, !amx.tile<16x16xf32>
|
||||
%1 = x86.amx.tile_load %arg0[%0, %0] : memref<16x32xi8> into !x86.amx.tile<16x32xi8>
|
||||
%2 = x86.amx.tile_load %arg1[%0, %0] : memref<32x32xbf16, strided<[64, 1]>> into !x86.amx.tile<16x32xbf16>
|
||||
%3 = x86.amx.tile_load %arg2[%0, %0] : memref<16x32xf32, strided<[?, 1]>> into !x86.amx.tile<16x16xf32>
|
||||
x86.amx.tile_store %arg0[%0, %0], %1 : memref<16x32xi8>, !x86.amx.tile<16x32xi8>
|
||||
x86.amx.tile_store %arg1[%0, %0], %2 : memref<32x32xbf16, strided<[64, 1]>>, !x86.amx.tile<16x32xbf16>
|
||||
x86.amx.tile_store %arg2[%0, %0], %3 : memref<16x32xf32, strided<[?, 1]>>, !x86.amx.tile<16x16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -123,11 +123,11 @@ func.func @strides_explicit(%stride: index,
|
||||
%arg2: memref<32x32xf32, strided<[64, 1]>>) {
|
||||
%0 = arith.constant 0 : index
|
||||
%c64 = arith.constant 64 : index
|
||||
%1 = amx.tile_load %arg0[%0], %stride : memref<?xi8> into !amx.tile<16x32xi8>
|
||||
%2 = amx.tile_load %arg1[%0, %0], %stride : memref<16x32xbf16> into !amx.tile<16x32xbf16>
|
||||
%3 = amx.tile_load %arg2[%0, %0], %c64 : memref<32x32xf32, strided<[64, 1]>> into !amx.tile<16x16xf32>
|
||||
amx.tile_store %arg0[%0], %1, %stride : memref<?xi8>, !amx.tile<16x32xi8>
|
||||
amx.tile_store %arg1[%0, %0], %2, %stride : memref<16x32xbf16>, !amx.tile<16x32xbf16>
|
||||
amx.tile_store %arg2[%0, %0], %3, %c64 : memref<32x32xf32, strided<[64, 1]>>, !amx.tile<16x16xf32>
|
||||
%1 = x86.amx.tile_load %arg0[%0], %stride : memref<?xi8> into !x86.amx.tile<16x32xi8>
|
||||
%2 = x86.amx.tile_load %arg1[%0, %0], %stride : memref<16x32xbf16> into !x86.amx.tile<16x32xbf16>
|
||||
%3 = x86.amx.tile_load %arg2[%0, %0], %c64 : memref<32x32xf32, strided<[64, 1]>> into !x86.amx.tile<16x16xf32>
|
||||
x86.amx.tile_store %arg0[%0], %1, %stride : memref<?xi8>, !x86.amx.tile<16x32xi8>
|
||||
x86.amx.tile_store %arg1[%0, %0], %2, %stride : memref<16x32xbf16>, !x86.amx.tile<16x32xbf16>
|
||||
x86.amx.tile_store %arg2[%0, %0], %3, %c64 : memref<32x32xf32, strided<[64, 1]>>, !x86.amx.tile<16x16xf32>
|
||||
return
|
||||
}
|
||||
77
mlir/test/Dialect/X86/AMX/roundtrip.mlir
Normal file
77
mlir/test/Dialect/X86/AMX/roundtrip.mlir
Normal file
@@ -0,0 +1,77 @@
|
||||
// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: tloadstore
|
||||
// CHECK: %[[x:.*]] = x86.amx.tile_load %{{.*}}[%{{.*}}], %{{.*}} :
|
||||
// CHECK-SAME: memref<?xbf16> into !x86.amx.tile<16x32xbf16>
|
||||
// CHECK: %[[y:.*]] = x86.amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} :
|
||||
// CHECK-SAME: memref<?x?xbf16> into !x86.amx.tile<16x32xbf16>
|
||||
// CHECK: %[[z:.*]] = x86.amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] :
|
||||
// CHECK-SAME: memref<?x?xbf16, strided<[64, 1]>> into !x86.amx.tile<16x32xbf16>
|
||||
// CHECK: x86.amx.tile_store %{{.*}}[%{{.*}}], %[[z]], %{{.*}} :
|
||||
// CHECK-SAME: memref<?xbf16>, !x86.amx.tile<16x32xbf16>
|
||||
// CHECK: x86.amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[x]], %{{.*}} :
|
||||
// CHECK-SAME: memref<?x?xbf16>, !x86.amx.tile<16x32xbf16>
|
||||
// CHECK: x86.amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[y]] :
|
||||
// CHECK-SAME: memref<?x?xbf16, strided<[64, 1]>>, !x86.amx.tile<16x32xbf16>
|
||||
func.func @tloadstore(%stride: index,
|
||||
%arg0: memref<?xbf16>,
|
||||
%arg1: memref<?x?xbf16>,
|
||||
%arg2: memref<?x?xbf16, strided<[64, 1]>>) {
|
||||
%0 = arith.constant 0 : index
|
||||
%c64 = arith.constant 64 : index
|
||||
%1 = x86.amx.tile_load %arg0[%0], %stride : memref<?xbf16> into !x86.amx.tile<16x32xbf16>
|
||||
%2 = x86.amx.tile_load %arg1[%0, %0], %stride : memref<?x?xbf16> into !x86.amx.tile<16x32xbf16>
|
||||
%3 = x86.amx.tile_load %arg2[%0, %0] : memref<?x?xbf16, strided<[64, 1]>> into !x86.amx.tile<16x32xbf16>
|
||||
x86.amx.tile_store %arg0[%0], %3, %stride : memref<?xbf16>, !x86.amx.tile<16x32xbf16>
|
||||
x86.amx.tile_store %arg1[%0, %0], %1, %stride : memref<?x?xbf16>, !x86.amx.tile<16x32xbf16>
|
||||
x86.amx.tile_store %arg2[%0, %0], %2 : memref<?x?xbf16, strided<[64, 1]>>, !x86.amx.tile<16x32xbf16>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: tzero
|
||||
// CHECK: x86.amx.tile_zero : !x86.amx.tile<16x16xbf16>
|
||||
// CHECK: x86.amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xbf16>, !x86.amx.tile<16x16xbf16>
|
||||
func.func @tzero(%arg0: memref<?x?xbf16>) {
|
||||
%0 = arith.constant 0 : index
|
||||
%1 = x86.amx.tile_zero : !x86.amx.tile<16x16xbf16>
|
||||
x86.amx.tile_store %arg0[%0, %0], %1 : memref<?x?xbf16>, !x86.amx.tile<16x16xbf16>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: tmulf
|
||||
// CHECK: %[[x:.*]] = x86.amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xbf16> into !x86.amx.tile<16x32xbf16>
|
||||
// CHECK: %[[z:.*]] = x86.amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32> into !x86.amx.tile<16x16xf32>
|
||||
// CHECK: %[[m:.*]] = x86.amx.tile_mulf %[[x]], %[[x]], %[[z]] : !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x16xf32>
|
||||
// CHECK: x86.amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
|
||||
func.func @tmulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
|
||||
%0 = arith.constant 0 : index
|
||||
%1 = x86.amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into !x86.amx.tile<16x32xbf16>
|
||||
%2 = x86.amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !x86.amx.tile<16x16xf32>
|
||||
%3 = x86.amx.tile_mulf %1, %1, %2 : !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x16xf32>
|
||||
x86.amx.tile_store %arg1[%0, %0], %3 : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: tmuli
|
||||
// CHECK: %[[x:.*]] = x86.amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into !x86.amx.tile<16x64xi8>
|
||||
// CHECK: %[[y:.*]] = x86.amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into !x86.amx.tile<16x64xi8>
|
||||
// CHECK: %[[z:.*]] = x86.amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into !x86.amx.tile<16x16xi32>
|
||||
// CHECK: %[[m:.*]] = x86.amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
|
||||
// CHECK: x86.amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, !x86.amx.tile<16x16xi32>
|
||||
// Verify the parsing/printing of the sign-extension annotation.
|
||||
// CHECK: x86.amx.tile_muli %{{.*}}, %{{.*}} zext, %{{.*}}
|
||||
// CHECK: x86.amx.tile_muli %{{.*}} zext, %{{.*}}, %{{.*}}
|
||||
// CHECK: x86.amx.tile_muli %{{.*}}, %{{.*}}, %{{.*}}
|
||||
func.func @tmuli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi8>, %arg2: memref<?x?xi32>) {
|
||||
%0 = arith.constant 0 : index
|
||||
%1 = x86.amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into !x86.amx.tile<16x64xi8>
|
||||
%2 = x86.amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into !x86.amx.tile<16x64xi8>
|
||||
%3 = x86.amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into !x86.amx.tile<16x16xi32>
|
||||
%4 = x86.amx.tile_muli %1 zext, %2 zext, %3 : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
|
||||
x86.amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, !x86.amx.tile<16x16xi32>
|
||||
// Verify the various `zext` combinations.
|
||||
%5 = x86.amx.tile_muli %1, %2 zext, %3 : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
|
||||
%6 = x86.amx.tile_muli %1 zext, %2, %3 : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
|
||||
%7 = x86.amx.tile_muli %1, %2, %3 : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
|
||||
return
|
||||
}
|
||||
32
mlir/test/Dialect/X86/AMX/side-effects.mlir
Normal file
32
mlir/test/Dialect/X86/AMX/side-effects.mlir
Normal file
@@ -0,0 +1,32 @@
|
||||
// RUN: mlir-opt %s -cse -convert-vector-to-llvm="enable-x86" | FileCheck %s
|
||||
|
||||
// With inclusion of memory side-effects, it is expected CSE not to fold multiple
|
||||
// "tileload" and "tilezero".
|
||||
// CHECK-LABEL: do_not_fold_tiles(
|
||||
// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
|
||||
// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
|
||||
// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
|
||||
// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
|
||||
// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
|
||||
// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
|
||||
func.func @do_not_fold_tiles(%arg0: memref<2x32x32xbf16>, %arg1: memref<2x16x32xbf16>) -> memref<16x32xf32> {
|
||||
%c1 = arith.constant 1 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
%c2 = arith.constant 2 : index
|
||||
%c16 = arith.constant 16 : index
|
||||
%alloca = memref.alloca() : memref<16x32xf32>
|
||||
%0 = x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
|
||||
%1 = x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
|
||||
%2:2 = scf.for %arg2 = %c0 to %c2 step %c1 iter_args(%arg3 = %0, %arg4 = %1) -> (!x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>) {
|
||||
%3 = x86.amx.tile_load %arg0[%arg2, %c0, %c0] : memref<2x32x32xbf16> into !x86.amx.tile<16x32xbf16>
|
||||
%4 = x86.amx.tile_load %arg0[%arg2, %c16, %c0] : memref<2x32x32xbf16> into !x86.amx.tile<16x32xbf16>
|
||||
%5 = x86.amx.tile_load %arg1[%arg2, %c0, %c0] : memref<2x16x32xbf16> into !x86.amx.tile<16x32xbf16>
|
||||
%6 = x86.amx.tile_load %arg1[%arg2, %c0, %c0] : memref<2x16x32xbf16> into !x86.amx.tile<16x32xbf16>
|
||||
%7 = x86.amx.tile_mulf %3, %5, %arg3 : !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x16xf32>
|
||||
%8 = x86.amx.tile_mulf %4, %6, %arg4 : !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x16xf32>
|
||||
scf.yield %7, %8 : !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>
|
||||
}
|
||||
x86.amx.tile_store %alloca[%c0, %c0], %2#0 : memref<16x32xf32>, !x86.amx.tile<16x16xf32>
|
||||
x86.amx.tile_store %alloca[%c0, %c16], %2#1 : memref<16x32xf32>, !x86.amx.tile<16x16xf32>
|
||||
return %alloca : memref<16x32xf32>
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
import sys
|
||||
|
||||
# AMX tests must be enabled via build flag.
|
||||
if not config.mlir_run_amx_tests:
|
||||
if not config.mlir_run_x86_amx_tests:
|
||||
config.unsupported = True
|
||||
|
||||
# No JIT on win32.
|
||||
@@ -1,6 +1,6 @@
|
||||
// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine \
|
||||
// RUN: -one-shot-bufferize="bufferize-function-boundaries" \
|
||||
// RUN: -convert-scf-to-cf -convert-vector-to-llvm="enable-amx" \
|
||||
// RUN: -convert-scf-to-cf -convert-vector-to-llvm="enable-x86" \
|
||||
// RUN: -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
|
||||
// RUN: mlir-translate -mlir-to-llvmir | \
|
||||
// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" \
|
||||
@@ -14,11 +14,11 @@ func.func @kernel(%arg0: memref<16x32xbf16>,
|
||||
%arg1: memref<16x32xbf16>,
|
||||
%arg2: memref<16x16xf32>) {
|
||||
%0 = arith.constant 0 : index
|
||||
%1 = amx.tile_load %arg0[%0, %0] : memref<16x32xbf16> into vector<16x32xbf16>
|
||||
%2 = amx.tile_load %arg1[%0, %0] : memref<16x32xbf16> into vector<16x32xbf16>
|
||||
%3 = amx.tile_zero : vector<16x16xf32>
|
||||
%4 = amx.tile_mulf %1, %2, %3 : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
|
||||
amx.tile_store %arg2[%0, %0], %4 : memref<16x16xf32>, vector<16x16xf32>
|
||||
%1 = x86.amx.tile_load %arg0[%0, %0] : memref<16x32xbf16> into vector<16x32xbf16>
|
||||
%2 = x86.amx.tile_load %arg1[%0, %0] : memref<16x32xbf16> into vector<16x32xbf16>
|
||||
%3 = x86.amx.tile_zero : vector<16x16xf32>
|
||||
%4 = x86.amx.tile_mulf %1, %2, %3 : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
|
||||
x86.amx.tile_store %arg2[%0, %0], %4 : memref<16x16xf32>, vector<16x16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-amx" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
|
||||
// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-x86" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
|
||||
// RUN: mlir-translate -mlir-to-llvmir | \
|
||||
// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_c_runner_utils | \
|
||||
// RUN: FileCheck %s
|
||||
@@ -10,11 +10,11 @@ func.func @kernel1(%arg0: memref<2x4xbf16>,
|
||||
%arg1: memref<2x4xbf16>,
|
||||
%arg2: memref<2x2xf32>) {
|
||||
%0 = arith.constant 0 : index
|
||||
%1 = amx.tile_load %arg0[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16>
|
||||
%2 = amx.tile_load %arg1[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16>
|
||||
%3 = amx.tile_zero : vector<2x2xf32>
|
||||
%4 = amx.tile_mulf %1, %2, %3 : vector<2x4xbf16>, vector<2x4xbf16>, vector<2x2xf32>
|
||||
amx.tile_store %arg2[%0, %0], %4 : memref<2x2xf32>, vector<2x2xf32>
|
||||
%1 = x86.amx.tile_load %arg0[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16>
|
||||
%2 = x86.amx.tile_load %arg1[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16>
|
||||
%3 = x86.amx.tile_zero : vector<2x2xf32>
|
||||
%4 = x86.amx.tile_mulf %1, %2, %3 : vector<2x4xbf16>, vector<2x4xbf16>, vector<2x2xf32>
|
||||
x86.amx.tile_store %arg2[%0, %0], %4 : memref<2x2xf32>, vector<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -23,11 +23,11 @@ func.func @kernel2(%arg0: memref<2x4xbf16>,
|
||||
%arg1: memref<2x4xbf16>,
|
||||
%arg2: memref<2x2xf32>) {
|
||||
%0 = arith.constant 0 : index
|
||||
%1 = amx.tile_load %arg0[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16>
|
||||
%2 = amx.tile_load %arg1[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16>
|
||||
%3 = amx.tile_load %arg2[%0, %0] : memref<2x2xf32> into vector<2x2xf32>
|
||||
%4 = amx.tile_mulf %1, %2, %3 : vector<2x4xbf16>, vector<2x4xbf16>, vector<2x2xf32>
|
||||
amx.tile_store %arg2[%0, %0], %4 : memref<2x2xf32>, vector<2x2xf32>
|
||||
%1 = x86.amx.tile_load %arg0[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16>
|
||||
%2 = x86.amx.tile_load %arg1[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16>
|
||||
%3 = x86.amx.tile_load %arg2[%0, %0] : memref<2x2xf32> into vector<2x2xf32>
|
||||
%4 = x86.amx.tile_mulf %1, %2, %3 : vector<2x4xbf16>, vector<2x4xbf16>, vector<2x2xf32>
|
||||
x86.amx.tile_store %arg2[%0, %0], %4 : memref<2x2xf32>, vector<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-amx" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
|
||||
// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-x86" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
|
||||
// RUN: mlir-translate -mlir-to-llvmir | \
|
||||
// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_c_runner_utils | \
|
||||
// RUN: FileCheck %s
|
||||
@@ -21,11 +21,11 @@ func.func @kernel1(%arg0: memref<16x16xi8>,
|
||||
%arg1: memref<4x16xi8>,
|
||||
%arg2: memref<16x4xi32>) {
|
||||
%0 = arith.constant 0 : index
|
||||
%1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8>
|
||||
%2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8>
|
||||
%3 = amx.tile_zero : vector<16x4xi32>
|
||||
%4 = amx.tile_muli %1, %2, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
|
||||
amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
|
||||
%1 = x86.amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8>
|
||||
%2 = x86.amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8>
|
||||
%3 = x86.amx.tile_zero : vector<16x4xi32>
|
||||
%4 = x86.amx.tile_muli %1, %2, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
|
||||
x86.amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -33,11 +33,11 @@ func.func @kernel2(%arg0: memref<16x16xi8>,
|
||||
%arg1: memref<4x16xi8>,
|
||||
%arg2: memref<16x4xi32>) {
|
||||
%0 = arith.constant 0 : index
|
||||
%1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8>
|
||||
%2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8>
|
||||
%3 = amx.tile_zero : vector<16x4xi32>
|
||||
%4 = amx.tile_muli %1, %2 zext, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
|
||||
amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
|
||||
%1 = x86.amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8>
|
||||
%2 = x86.amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8>
|
||||
%3 = x86.amx.tile_zero : vector<16x4xi32>
|
||||
%4 = x86.amx.tile_muli %1, %2 zext, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
|
||||
x86.amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -45,11 +45,11 @@ func.func @kernel3(%arg0: memref<16x16xi8>,
|
||||
%arg1: memref<4x16xi8>,
|
||||
%arg2: memref<16x4xi32>) {
|
||||
%0 = arith.constant 0 : index
|
||||
%1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8>
|
||||
%2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8>
|
||||
%3 = amx.tile_zero : vector<16x4xi32>
|
||||
%4 = amx.tile_muli %1 zext, %2, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
|
||||
amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
|
||||
%1 = x86.amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8>
|
||||
%2 = x86.amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8>
|
||||
%3 = x86.amx.tile_zero : vector<16x4xi32>
|
||||
%4 = x86.amx.tile_muli %1 zext, %2, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
|
||||
x86.amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -57,11 +57,11 @@ func.func @kernel4(%arg0: memref<16x16xi8>,
|
||||
%arg1: memref<4x16xi8>,
|
||||
%arg2: memref<16x4xi32>) {
|
||||
%0 = arith.constant 0 : index
|
||||
%1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8>
|
||||
%2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8>
|
||||
%3 = amx.tile_zero : vector<16x4xi32>
|
||||
%4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
|
||||
amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
|
||||
%1 = x86.amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8>
|
||||
%2 = x86.amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8>
|
||||
%3 = x86.amx.tile_zero : vector<16x4xi32>
|
||||
%4 = x86.amx.tile_muli %1 zext, %2 zext, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
|
||||
x86.amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine \
|
||||
// RUN: -one-shot-bufferize="bufferize-function-boundaries" \
|
||||
// RUN: -convert-scf-to-cf \
|
||||
// RUN: -convert-vector-to-llvm="enable-amx" \
|
||||
// RUN: -convert-vector-to-llvm="enable-x86" \
|
||||
// RUN: -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
|
||||
// RUN: mlir-translate -mlir-to-llvmir | \
|
||||
// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" \
|
||||
@@ -15,11 +15,11 @@ func.func @kernel(%arg0: memref<16x64xi8>,
|
||||
%arg1: memref<16x64xi8>,
|
||||
%arg2: memref<16x16xi32>) {
|
||||
%0 = arith.constant 0 : index
|
||||
%1 = amx.tile_load %arg0[%0, %0] : memref<16x64xi8> into vector<16x64xi8>
|
||||
%2 = amx.tile_load %arg1[%0, %0] : memref<16x64xi8> into vector<16x64xi8>
|
||||
%3 = amx.tile_zero : vector<16x16xi32>
|
||||
%4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
|
||||
amx.tile_store %arg2[%0, %0], %4 : memref<16x16xi32>, vector<16x16xi32>
|
||||
%1 = x86.amx.tile_load %arg0[%0, %0] : memref<16x64xi8> into vector<16x64xi8>
|
||||
%2 = x86.amx.tile_load %arg1[%0, %0] : memref<16x64xi8> into vector<16x64xi8>
|
||||
%3 = x86.amx.tile_zero : vector<16x16xi32>
|
||||
%4 = x86.amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
|
||||
x86.amx.tile_store %arg2[%0, %0], %4 : memref<16x16xi32>, vector<16x16xi32>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-amx" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
|
||||
// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-x86" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
|
||||
// RUN: mlir-translate -mlir-to-llvmir | \
|
||||
// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_c_runner_utils | \
|
||||
// RUN: FileCheck %s
|
||||
@@ -10,11 +10,11 @@ func.func @kernel1(%arg0: memref<2x8xi8>,
|
||||
%arg1: memref<2x8xi8>,
|
||||
%arg2: memref<2x2xi32>) {
|
||||
%0 = arith.constant 0 : index
|
||||
%1 = amx.tile_load %arg0[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
|
||||
%2 = amx.tile_load %arg1[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
|
||||
%3 = amx.tile_zero : vector<2x2xi32>
|
||||
%4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
|
||||
amx.tile_store %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32>
|
||||
%1 = x86.amx.tile_load %arg0[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
|
||||
%2 = x86.amx.tile_load %arg1[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
|
||||
%3 = x86.amx.tile_zero : vector<2x2xi32>
|
||||
%4 = x86.amx.tile_muli %1 zext, %2 zext, %3 : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
|
||||
x86.amx.tile_store %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -23,11 +23,11 @@ func.func @kernel2(%arg0: memref<2x8xi8>,
|
||||
%arg1: memref<2x8xi8>,
|
||||
%arg2: memref<2x2xi32>) {
|
||||
%0 = arith.constant 0 : index
|
||||
%1 = amx.tile_load %arg0[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
|
||||
%2 = amx.tile_load %arg1[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
|
||||
%3 = amx.tile_load %arg2[%0, %0] : memref<2x2xi32> into vector<2x2xi32>
|
||||
%4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
|
||||
amx.tile_store %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32>
|
||||
%1 = x86.amx.tile_load %arg0[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
|
||||
%2 = x86.amx.tile_load %arg1[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
|
||||
%3 = x86.amx.tile_load %arg2[%0, %0] : memref<2x2xi32> into vector<2x2xi32>
|
||||
%4 = x86.amx.tile_muli %1 zext, %2 zext, %3 : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
|
||||
x86.amx.tile_store %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-amx" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
|
||||
// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-x86" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
|
||||
// RUN: mlir-translate -mlir-to-llvmir | \
|
||||
// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_c_runner_utils | \
|
||||
// RUN: FileCheck %s
|
||||
@@ -25,8 +25,8 @@ func.func @kernel(%arg0: memref<4x32xf32>) {
|
||||
%c32 = arith.constant 32 : index
|
||||
scf.for %i = %c0 to %c4 step %c2 {
|
||||
scf.for %j = %c0 to %c32 step %c16 {
|
||||
%0 = amx.tile_zero : vector<2x16xf32>
|
||||
amx.tile_store %arg0[%i, %j], %0 : memref<4x32xf32>, vector<2x16xf32>
|
||||
%0 = x86.amx.tile_zero : vector<2x16xf32>
|
||||
x86.amx.tile_store %arg0[%i, %j], %0 : memref<4x32xf32>, vector<2x16xf32>
|
||||
func.call @print(%arg0) : (memref<4x32xf32>) -> ()
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-amx" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
|
||||
// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-x86" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
|
||||
// RUN: mlir-translate -mlir-to-llvmir | \
|
||||
// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_c_runner_utils | \
|
||||
// RUN: FileCheck %s
|
||||
@@ -6,8 +6,8 @@
|
||||
// Note: To run this test, your CPU must support AMX.
|
||||
|
||||
func.func @tilezero(%arg0: memref<?x?xi32>, %i: index, %j: index) {
|
||||
%1 = amx.tile_zero : vector<16x16xi32>
|
||||
amx.tile_store %arg0[%i, %j], %1 : memref<?x?xi32>, vector<16x16xi32>
|
||||
%1 = x86.amx.tile_zero : vector<16x16xi32>
|
||||
x86.amx.tile_store %arg0[%i, %j], %1 : memref<?x?xi32>, vector<16x16xi32>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -5,14 +5,15 @@
|
||||
|
||||
func.func @entry() -> i32 {
|
||||
%i0 = arith.constant 0 : i32
|
||||
%i4 = arith.constant 4 : i32
|
||||
%c0 = arith.constant 0 : index
|
||||
%c4 = arith.constant 4 : index
|
||||
|
||||
%a = arith.constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]> : vector<8xf32>
|
||||
%b = arith.constant dense<[9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : vector<8xf32>
|
||||
%r = x86.avx.intr.dot %a, %b : vector<8xf32>
|
||||
|
||||
%1 = vector.extract %r[%i0] : f32 from vector<8xf32>
|
||||
%2 = vector.extract %r[%i4] : f32 from vector<8xf32>
|
||||
%1 = vector.extract %r[%c0] : f32 from vector<8xf32>
|
||||
%2 = vector.extract %r[%c4] : f32 from vector<8xf32>
|
||||
%d = arith.addf %1, %2 : f32
|
||||
|
||||
// CHECK: ( 110, 110, 110, 110, 382, 382, 382, 382 )
|
||||
|
||||
@@ -180,8 +180,7 @@ func.func @memref_dot_optimized(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
|
||||
-> f64 {
|
||||
// Helper constants for loops.
|
||||
%c0 = arith.constant 0 : index
|
||||
%i0 = arith.constant 0 : i32
|
||||
%i7 = arith.constant 7 : i32
|
||||
%c7 = arith.constant 7 : index
|
||||
%c8 = arith.constant 8 : index
|
||||
|
||||
%data_zero = arith.constant 0.0 : f64
|
||||
@@ -196,13 +195,13 @@ func.func @memref_dot_optimized(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
|
||||
iter_args(%sum0 = %data_zero, %b_start0 = %c0) -> (f64, index) {
|
||||
%v_A = vector.transfer_read %m_A[%a], %index_padding
|
||||
: memref<?xi64>, vector<8xi64>
|
||||
%segA_min = vector.extract %v_A[%i0] : i64 from vector<8xi64>
|
||||
%segA_min = vector.extract %v_A[%c0] : i64 from vector<8xi64>
|
||||
|
||||
%r1, %next_b_start0 = scf.for %b = %b_start0 to %N step %c8
|
||||
iter_args(%sum1 = %sum0, %b_start1 = %b_start0) -> (f64, index) {
|
||||
%v_C = vector.transfer_read %m_C[%b], %index_padding
|
||||
: memref<?xi64>, vector<8xi64>
|
||||
%segB_max = vector.extract %v_C[%i7] : i64 from vector<8xi64>
|
||||
%segB_max = vector.extract %v_C[%c7] : i64 from vector<8xi64>
|
||||
%seg1_done = arith.cmpi "slt", %segB_max, %segA_min : i64
|
||||
|
||||
%r2, %next_b_start1 = scf.if %seg1_done -> (f64, index) {
|
||||
@@ -251,8 +250,7 @@ func.func @memref_dot_while(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
|
||||
-> f64 {
|
||||
// Helper constants for loops.
|
||||
%c0 = arith.constant 0 : index
|
||||
%i0 = arith.constant 0 : i32
|
||||
%i7 = arith.constant 7 : i32
|
||||
%c7 = arith.constant 7 : index
|
||||
%c8 = arith.constant 8 : index
|
||||
|
||||
%data_zero = arith.constant 0.0 : f64
|
||||
@@ -273,10 +271,10 @@ func.func @memref_dot_while(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
|
||||
%v_C = vector.transfer_read %m_C[%b1], %index_padding
|
||||
: memref<?xi64>, vector<8xi64>
|
||||
|
||||
%segA_min = vector.extract %v_A[%i0] : i64 from vector<8xi64>
|
||||
%segA_max = vector.extract %v_A[%i7] : i64 from vector<8xi64>
|
||||
%segB_min = vector.extract %v_C[%i0] : i64 from vector<8xi64>
|
||||
%segB_max = vector.extract %v_C[%i7] : i64 from vector<8xi64>
|
||||
%segA_min = vector.extract %v_A[%c0] : i64 from vector<8xi64>
|
||||
%segA_max = vector.extract %v_A[%c7] : i64 from vector<8xi64>
|
||||
%segB_min = vector.extract %v_C[%c0] : i64 from vector<8xi64>
|
||||
%segB_max = vector.extract %v_C[%c7] : i64 from vector<8xi64>
|
||||
|
||||
%seg1_done = arith.cmpi "slt", %segB_max, %segA_min : i64
|
||||
%r2, %a2, %b2 = scf.if %seg1_done -> (f64, index, index) {
|
||||
@@ -340,7 +338,7 @@ func.func @memref_dot_while_branchless(%m_A : memref<?xi64>, %m_B : memref<?xf64
|
||||
-> f64 {
|
||||
// Helper constants for loops.
|
||||
%c0 = arith.constant 0 : index
|
||||
%i7 = arith.constant 7 : i32
|
||||
%c7 = arith.constant 7 : index
|
||||
%c8 = arith.constant 8 : index
|
||||
|
||||
%data_zero = arith.constant 0.0 : f64
|
||||
@@ -370,8 +368,8 @@ func.func @memref_dot_while_branchless(%m_A : memref<?xi64>, %m_B : memref<?xf64
|
||||
-> f64
|
||||
%r2 = arith.addf %r1, %subresult : f64
|
||||
|
||||
%segA_max = vector.extract %v_A[%i7] : i64 from vector<8xi64>
|
||||
%segB_max = vector.extract %v_C[%i7] : i64 from vector<8xi64>
|
||||
%segA_max = vector.extract %v_A[%c7] : i64 from vector<8xi64>
|
||||
%segB_max = vector.extract %v_C[%c7] : i64 from vector<8xi64>
|
||||
|
||||
%cond_a = arith.cmpi "sle", %segA_max, %segB_max : i64
|
||||
%cond_a_i64 = arith.extui %cond_a : i1 to i64
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt %s --convert-vector-to-llvm="enable-amx" --convert-to-llvm -reconcile-unrealized-casts \
|
||||
// RUN: mlir-opt %s --convert-vector-to-llvm="enable-x86" --convert-to-llvm -reconcile-unrealized-casts \
|
||||
// RUN: | mlir-translate --mlir-to-llvmir \
|
||||
// RUN: | FileCheck %s
|
||||
|
||||
@@ -7,8 +7,8 @@ func.func @amx_tile_zero(%out: memref<?x?xf32>, %idx: index)
|
||||
{
|
||||
// CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
|
||||
// CHECK: call void @llvm.x86.tilestored64.internal
|
||||
%zero = amx.tile_zero : !amx.tile<16x16xf32>
|
||||
amx.tile_store %out[%idx, %idx], %zero : memref<?x?xf32>, !amx.tile<16x16xf32>
|
||||
%zero = x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
|
||||
x86.amx.tile_store %out[%idx, %idx], %zero : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -18,8 +18,8 @@ func.func @amx_tile_load_store(%base: memref<?x?xi8>, %out: memref<?x?xi8>,
|
||||
{
|
||||
// CHECK: call x86_amx @llvm.x86.tileloadd64.internal
|
||||
// CHECK: call void @llvm.x86.tilestored64.internal
|
||||
%val = amx.tile_load %base[%idx, %idx] : memref<?x?xi8> into !amx.tile<16x64xi8>
|
||||
amx.tile_store %out[%idx, %idx], %val : memref<?x?xi8>, !amx.tile<16x64xi8>
|
||||
%val = x86.amx.tile_load %base[%idx, %idx] : memref<?x?xi8> into !x86.amx.tile<16x64xi8>
|
||||
x86.amx.tile_store %out[%idx, %idx], %val : memref<?x?xi8>, !x86.amx.tile<16x64xi8>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -29,10 +29,10 @@ func.func @amx_tile_load_store_strided(%base: memref<?xi8>, %out: memref<?xi8>,
|
||||
{
|
||||
// CHECK: call x86_amx @llvm.x86.tileloadd64.internal
|
||||
// CHECK: call void @llvm.x86.tilestored64.internal
|
||||
%val = amx.tile_load %base[%idx], %stride
|
||||
: memref<?xi8> into !amx.tile<16x64xi8>
|
||||
amx.tile_store %out[%idx], %val, %stride
|
||||
: memref<?xi8>, !amx.tile<16x64xi8>
|
||||
%val = x86.amx.tile_load %base[%idx], %stride
|
||||
: memref<?xi8> into !x86.amx.tile<16x64xi8>
|
||||
x86.amx.tile_store %out[%idx], %val, %stride
|
||||
: memref<?xi8>, !x86.amx.tile<16x64xi8>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -42,15 +42,15 @@ func.func @amx_tile_mulf_bf16(
|
||||
%out: memref<?x?xf32>)
|
||||
{
|
||||
// CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
|
||||
%acc = amx.tile_zero : !amx.tile<16x16xf32>
|
||||
%acc = x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
|
||||
// CHECK-COUNT-2: call x86_amx @llvm.x86.tileloadd64.internal
|
||||
%tA = amx.tile_load %matA[%idx, %idx] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
|
||||
%tB = amx.tile_load %matB[%idx, %idx] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
|
||||
%tA = x86.amx.tile_load %matA[%idx, %idx] : memref<?x?xbf16> into !x86.amx.tile<16x32xbf16>
|
||||
%tB = x86.amx.tile_load %matB[%idx, %idx] : memref<?x?xbf16> into !x86.amx.tile<16x32xbf16>
|
||||
// CHECK: call x86_amx @llvm.x86.tdpbf16ps.internal
|
||||
%tRes = amx.tile_mulf %tA, %tB, %acc
|
||||
: !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
|
||||
%tRes = x86.amx.tile_mulf %tA, %tB, %acc
|
||||
: !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x16xf32>
|
||||
// CHECK: call void @llvm.x86.tilestored64.internal
|
||||
amx.tile_store %out[%idx, %idx], %tRes : memref<?x?xf32>, !amx.tile<16x16xf32>
|
||||
x86.amx.tile_store %out[%idx, %idx], %tRes : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -60,15 +60,15 @@ func.func @amx_tile_mulf_f16(
|
||||
%out: memref<?x?xf32>)
|
||||
{
|
||||
// CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
|
||||
%acc = amx.tile_zero : !amx.tile<16x16xf32>
|
||||
%acc = x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
|
||||
// CHECK-COUNT-2: call x86_amx @llvm.x86.tileloadd64.internal
|
||||
%tA = amx.tile_load %matA[%idx, %idx] : memref<?x?xf16> into !amx.tile<16x32xf16>
|
||||
%tB = amx.tile_load %matB[%idx, %idx] : memref<?x?xf16> into !amx.tile<16x32xf16>
|
||||
%tA = x86.amx.tile_load %matA[%idx, %idx] : memref<?x?xf16> into !x86.amx.tile<16x32xf16>
|
||||
%tB = x86.amx.tile_load %matB[%idx, %idx] : memref<?x?xf16> into !x86.amx.tile<16x32xf16>
|
||||
// CHECK: call x86_amx @llvm.x86.tdpfp16ps.internal
|
||||
%tRes = amx.tile_mulf %tA, %tB, %acc
|
||||
: !amx.tile<16x32xf16>, !amx.tile<16x32xf16>, !amx.tile<16x16xf32>
|
||||
%tRes = x86.amx.tile_mulf %tA, %tB, %acc
|
||||
: !x86.amx.tile<16x32xf16>, !x86.amx.tile<16x32xf16>, !x86.amx.tile<16x16xf32>
|
||||
// CHECK: call void @llvm.x86.tilestored64.internal
|
||||
amx.tile_store %out[%idx, %idx], %tRes : memref<?x?xf32>, !amx.tile<16x16xf32>
|
||||
x86.amx.tile_store %out[%idx, %idx], %tRes : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -79,26 +79,26 @@ func.func @amx_tile_muli(%matA: memref<?x?xi8>, %matB: memref<?x?xi8>,
|
||||
%c0 = arith.constant 0 : index
|
||||
%c16 = arith.constant 16 : index
|
||||
// CHECK-COUNT-3: call x86_amx @llvm.x86.tileloadd64.internal
|
||||
%tA = amx.tile_load %matA[%idx, %idx] : memref<?x?xi8> into !amx.tile<16x64xi8>
|
||||
%tB = amx.tile_load %matB[%idx, %idx] : memref<?x?xi8> into !amx.tile<16x64xi8>
|
||||
%acc = amx.tile_load %matC[%idx, %idx] : memref<?x?xi32> into !amx.tile<16x16xi32>
|
||||
%tA = x86.amx.tile_load %matA[%idx, %idx] : memref<?x?xi8> into !x86.amx.tile<16x64xi8>
|
||||
%tB = x86.amx.tile_load %matB[%idx, %idx] : memref<?x?xi8> into !x86.amx.tile<16x64xi8>
|
||||
%acc = x86.amx.tile_load %matC[%idx, %idx] : memref<?x?xi32> into !x86.amx.tile<16x16xi32>
|
||||
// CHECK: call x86_amx @llvm.x86.tdpbuud.internal
|
||||
// CHECK: call x86_amx @llvm.x86.tdpbssd.internal
|
||||
// CHECK: call x86_amx @llvm.x86.tdpbusd.internal
|
||||
// CHECK: call x86_amx @llvm.x86.tdpbsud.internal
|
||||
%res = amx.tile_muli %tA zext, %tB zext, %acc
|
||||
: !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
|
||||
%res1 = amx.tile_muli %tA, %tB, %acc
|
||||
: !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
|
||||
%res2 = amx.tile_muli %tA zext, %tB, %acc
|
||||
: !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
|
||||
%res3 = amx.tile_muli %tA, %tB zext, %acc
|
||||
: !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
|
||||
%res = x86.amx.tile_muli %tA zext, %tB zext, %acc
|
||||
: !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
|
||||
%res1 = x86.amx.tile_muli %tA, %tB, %acc
|
||||
: !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
|
||||
%res2 = x86.amx.tile_muli %tA zext, %tB, %acc
|
||||
: !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
|
||||
%res3 = x86.amx.tile_muli %tA, %tB zext, %acc
|
||||
: !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
|
||||
// CHECK-COUNT-4: call void @llvm.x86.tilestored64.internal
|
||||
amx.tile_store %out[%c0, %c0], %res : memref<?x?xi8>, !amx.tile<16x16xi32>
|
||||
amx.tile_store %out[%c0, %c16], %res1 : memref<?x?xi8>, !amx.tile<16x16xi32>
|
||||
amx.tile_store %out[%c16, %c0], %res2 : memref<?x?xi8>, !amx.tile<16x16xi32>
|
||||
amx.tile_store %out[%c16, %c16], %res3 : memref<?x?xi8>, !amx.tile<16x16xi32>
|
||||
x86.amx.tile_store %out[%c0, %c0], %res : memref<?x?xi8>, !x86.amx.tile<16x16xi32>
|
||||
x86.amx.tile_store %out[%c0, %c16], %res1 : memref<?x?xi8>, !x86.amx.tile<16x16xi32>
|
||||
x86.amx.tile_store %out[%c16, %c0], %res2 : memref<?x?xi8>, !x86.amx.tile<16x16xi32>
|
||||
x86.amx.tile_store %out[%c16, %c16], %res3 : memref<?x?xi8>, !x86.amx.tile<16x16xi32>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -108,16 +108,16 @@ func.func @amx_tile_type_through_cf(%src: memref<?x?xi8>, %out: memref<?x?xi8>,
|
||||
cf.cond_br %cond, ^bb1, ^bb2
|
||||
^bb1: // pred: ^bb0
|
||||
// CHECK: call x86_amx @llvm.x86.tileloadd64.internal
|
||||
%0 = amx.tile_load %src[%idx, %idx] : memref<?x?xi8> into !amx.tile<16x64xi8>
|
||||
cf.br ^bb3(%0 : !amx.tile<16x64xi8>)
|
||||
%0 = x86.amx.tile_load %src[%idx, %idx] : memref<?x?xi8> into !x86.amx.tile<16x64xi8>
|
||||
cf.br ^bb3(%0 : !x86.amx.tile<16x64xi8>)
|
||||
^bb2: // pred: ^bb0
|
||||
// CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
|
||||
%1 = amx.tile_zero : !amx.tile<16x64xi8>
|
||||
cf.br ^bb3(%1 : !amx.tile<16x64xi8>)
|
||||
^bb3(%2: !amx.tile<16x64xi8>): // 2 preds: ^bb1, ^bb2
|
||||
%1 = x86.amx.tile_zero : !x86.amx.tile<16x64xi8>
|
||||
cf.br ^bb3(%1 : !x86.amx.tile<16x64xi8>)
|
||||
^bb3(%2: !x86.amx.tile<16x64xi8>): // 2 preds: ^bb1, ^bb2
|
||||
cf.br ^bb4
|
||||
^bb4: // pred: ^bb3
|
||||
// CHECK: call void @llvm.x86.tilestored64.internal
|
||||
amx.tile_store %out[%idx, %idx], %2 : memref<?x?xi8>, !amx.tile<16x64xi8>
|
||||
x86.amx.tile_store %out[%idx, %idx], %2 : memref<?x?xi8>, !x86.amx.tile<16x64xi8>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ config.enable_vulkan_runner = @MLIR_ENABLE_VULKAN_RUNNER@
|
||||
config.enable_bindings_python = @MLIR_ENABLE_BINDINGS_PYTHON@
|
||||
config.enable_python_stable_abi = @MLIR_ENABLE_PYTHON_STABLE_ABI@
|
||||
config.intel_sde_executable = "@INTEL_SDE_EXECUTABLE@"
|
||||
config.mlir_run_amx_tests = @MLIR_RUN_AMX_TESTS@
|
||||
config.mlir_run_x86_amx_tests = @MLIR_RUN_X86_AMX_TESTS@
|
||||
config.mlir_run_arm_sve_tests = @MLIR_RUN_ARM_SVE_TESTS@
|
||||
# This is a workaround for the fact that LIT's:
|
||||
# %if <cond>
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
// CHECK-SAME: acc
|
||||
// CHECK-SAME: affine
|
||||
// CHECK-SAME: amdgpu
|
||||
// CHECK-SAME: amx
|
||||
// CHECK-SAME: arith
|
||||
// CHECK-SAME: arm_neon
|
||||
// CHECK-SAME: arm_sme
|
||||
|
||||
Reference in New Issue
Block a user