Unifies the two dialects that define x86 operations into a single one. The AMX dialect is moved into X86 in line with other x86 extensions. Following the dialect renaming, X86 dialect is now a suitable home for wider range of operations targeting specific hardware features. Moving AMX definitions to X86 dialect creates a single, centralized hub for defining all x86 intrinsic-like operations. The new grouping aims to eliminate the need for new dialects as new hardware extensions become available. The two dialects are simply merged together. X86 dialect refactoring will be addressed separately. List of changes: - operations: 'amx.tile_*' => 'x86.amx.tile_*' - types: '!amx.tile' => '!x86.amx.tile' - namespace: 'mlir::amx' => 'mlir::x86::amx' - test define: 'MLIR_RUN_AMX_TESTS' => 'MLIR_RUN_X86_AMX_TESTS' - vector lowering: AMX is enabled by default together with X86 The MLIR AMX tests are now nested under X86 directory. To enable AMX integration tests, 'MLIR_RUN_X86_TESTS' must also be defined.
431 lines
17 KiB
C++
431 lines
17 KiB
C++
//===- VectorToAMX.cpp - Convert vector to X86 dialect AMX ops --*- C++ -*-===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/VectorToAMX/VectorToAMX.h"
|
|
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
|
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
|
#include "mlir/Dialect/X86/X86Dialect.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/Support/DebugLog.h"
|
|
|
|
#include <numeric>
|
|
|
|
namespace mlir {
|
|
#define GEN_PASS_DEF_CONVERTVECTORTOAMX
|
|
#include "mlir/Conversion/Passes.h.inc"
|
|
} // namespace mlir
|
|
|
|
using namespace mlir;
|
|
|
|
#define DEBUG_TYPE "vector-to-amx"
|
|
|
|
namespace {
|
|
|
|
/// Return true if vector shape is compatible with AMX tiles.
|
|
/// The validation accounts for VNNI packing.
|
|
static bool verifyAmxShape(VectorType vec) {
|
|
// Check overall shape:
|
|
// - 2D for plain layout input or output
|
|
// - 3D for VNNI packed input
|
|
if (vec.getRank() != 2 && vec.getRank() != 3)
|
|
return false;
|
|
|
|
ArrayRef<int64_t> shape = vec.getShape();
|
|
int64_t rows = shape[0];
|
|
int64_t cols = shape[1];
|
|
unsigned elemBitWidth = vec.getElementType().getIntOrFloatBitWidth();
|
|
|
|
// 3D shape indicates VNNI packed layout.
|
|
if (vec.getRank() == 3) {
|
|
int64_t vnniFactor = 32 / elemBitWidth;
|
|
if (shape.back() != vnniFactor) {
|
|
LDBG() << "invalid VNNI packing factor";
|
|
return false;
|
|
}
|
|
cols *= vnniFactor;
|
|
}
|
|
|
|
// AMX tile supports up to 16 rows of 64 bytes each.
|
|
constexpr unsigned maxRows = 16;
|
|
constexpr unsigned maxBitsPerRow = 64 * 8;
|
|
return rows <= maxRows && (cols * elemBitWidth) <= maxBitsPerRow;
|
|
}
|
|
|
|
/// Check if contraction operands are in AMX-compatible packed VNNI layout.
|
|
static LogicalResult isAmxVnniLayout(PatternRewriter &rewriter,
|
|
vector::ContractionOp contractOp) {
|
|
VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
|
|
if (!accType || accType.getRank() != 2)
|
|
return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");
|
|
|
|
// Expect 3D inputs for VNNI packed data.
|
|
VectorType lhsType = contractOp.getLhs().getType();
|
|
VectorType rhsType = contractOp.getRhs().getType();
|
|
if (lhsType.getRank() != 3 || rhsType.getRank() != 3)
|
|
return rewriter.notifyMatchFailure(contractOp,
|
|
"Expects lhs and rhs 3D vectors");
|
|
|
|
// Check if shapes are compatible with AMX tile.
|
|
if (!verifyAmxShape(lhsType) || !verifyAmxShape(rhsType) ||
|
|
!verifyAmxShape(accType))
|
|
return rewriter.notifyMatchFailure(contractOp, "Invalid operand shape");
|
|
|
|
// Validate affine maps.
|
|
//
|
|
// Iterators can be ordered arbitrarily. Indexing map positions are based on
|
|
// operands' target shapes.
|
|
// The matrix layouts must match the following:
|
|
// - matrix A - [M]x[K/vnniFactor]x[vnniFactor]
|
|
// - matrix B - [K/vnniFactor]x[N]x[vnniFactor]
|
|
// - matrix C - [M]x[N]
|
|
SmallVector<AffineMap, 4> indexingMaps = contractOp.getIndexingMapsArray();
|
|
AffineMap mapA = indexingMaps[0];
|
|
AffineMap mapB = indexingMaps[1];
|
|
if (mapA.getNumInputs() != 4 || mapA.getNumResults() != 3 ||
|
|
mapB.getNumResults() != 3)
|
|
return rewriter.notifyMatchFailure(contractOp,
|
|
"Invalid input indexing maps");
|
|
FailureOr<linalg::ContractionDimensions> dims =
|
|
linalg::inferContractionDims(indexingMaps);
|
|
if (failed(dims))
|
|
return rewriter.notifyMatchFailure(contractOp,
|
|
"Failed to infer contraction dims");
|
|
// Two reduction dimensions are expected:
|
|
// - one for the K dimension
|
|
// - one for the VNNI factor
|
|
if (dims->k.size() != 2)
|
|
return rewriter.notifyMatchFailure(contractOp,
|
|
"Expected two reduction dims");
|
|
assert(dims->m.size() == 1 && dims->n.size() == 1 &&
|
|
"Invalid parallel contraction dims");
|
|
|
|
SmallVector<vector::IteratorType> iteratorTypes =
|
|
contractOp.getIteratorTypesArray();
|
|
// Check VNNI dim maps - the innermost dim for A and B inputs.
|
|
auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.getResult(2));
|
|
auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.getResult(2));
|
|
if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
|
|
iteratorTypes[vnniDimA.getPosition()] != vector::IteratorType::reduction)
|
|
return rewriter.notifyMatchFailure(contractOp, "Invalid VNNI dim map");
|
|
// Check K dim maps - non-transposed row-major layout.
|
|
auto redDimA = dyn_cast<AffineDimExpr>(mapA.getResult(1));
|
|
auto redDimB = dyn_cast<AffineDimExpr>(mapB.getResult(0));
|
|
if (!redDimA || !redDimB || redDimA != redDimB ||
|
|
iteratorTypes[redDimA.getPosition()] != vector::IteratorType::reduction)
|
|
return rewriter.notifyMatchFailure(contractOp, "Invalid K dim map");
|
|
// Check M and N dim maps - map to non-transposed output.
|
|
AffineMap mapC = indexingMaps[2];
|
|
auto mDimC = dyn_cast<AffineDimExpr>(mapC.getResult(0));
|
|
auto nDimC = dyn_cast<AffineDimExpr>(mapC.getResult(1));
|
|
if (!mDimC || !nDimC)
|
|
return rewriter.notifyMatchFailure(contractOp, "Invalid acc maps");
|
|
auto parallelDimA = dyn_cast<AffineDimExpr>(mapA.getResult(0));
|
|
if (!parallelDimA ||
|
|
iteratorTypes[parallelDimA.getPosition()] !=
|
|
vector::IteratorType::parallel ||
|
|
parallelDimA != mDimC)
|
|
return rewriter.notifyMatchFailure(contractOp, "Invalid M dim map");
|
|
auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.getResult(1));
|
|
if (!parallelDimB ||
|
|
iteratorTypes[parallelDimB.getPosition()] !=
|
|
vector::IteratorType::parallel ||
|
|
parallelDimB != nDimC)
|
|
return rewriter.notifyMatchFailure(contractOp, "Invalid N dim map");
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Validate contraction operands for AMX lowering.
|
|
static LogicalResult validateOperands(PatternRewriter &rewriter,
|
|
vector::ContractionOp contractOp) {
|
|
VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
|
|
if (!accType)
|
|
return rewriter.notifyMatchFailure(contractOp, "Expects vector acc");
|
|
|
|
// Check if operand types are compatible with AMX compute ops.
|
|
bool validElemTypes = false;
|
|
Type lhsElemType = contractOp.getLhs().getType().getElementType();
|
|
Type rhsElemType = contractOp.getRhs().getType().getElementType();
|
|
Type accElemType = accType.getElementType();
|
|
if (accElemType.isInteger(32)) {
|
|
validElemTypes = lhsElemType.isInteger(8) && rhsElemType.isInteger(8);
|
|
} else if (accElemType.isF32()) {
|
|
validElemTypes = (lhsElemType.isF16() && rhsElemType.isF16()) ||
|
|
(lhsElemType.isBF16() && rhsElemType.isBF16());
|
|
}
|
|
if (!validElemTypes)
|
|
return rewriter.notifyMatchFailure(contractOp,
|
|
"Invalid combination of operand types");
|
|
|
|
if (failed(isAmxVnniLayout(rewriter, contractOp)))
|
|
return failure();
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Collapse the two innermost dimensions together.
|
|
static TypedValue<MemRefType> collapseLastDim(PatternRewriter &rewriter,
|
|
TypedValue<MemRefType> memref) {
|
|
int64_t rank = memref.getType().getRank();
|
|
SmallVector<ReassociationIndices> reassocIndices;
|
|
for (auto i : llvm::seq<int64_t>(0, rank - 2))
|
|
reassocIndices.push_back({i});
|
|
reassocIndices.push_back({rank - 2, rank - 1});
|
|
return memref::CollapseShapeOp::create(rewriter, memref.getLoc(), memref,
|
|
reassocIndices);
|
|
}
|
|
|
|
/// Attempt to create an AMX tile load/store operation equivalent to the given
|
|
/// vector transfer `xfer` op.
|
|
/// This approach allows to skip longer route through registers and a temporary
|
|
/// buffer otherwise required to move data to/from an AMX tile.
|
|
static Operation *
|
|
loadStoreFromTransfer(PatternRewriter &rewriter,
|
|
VectorTransferOpInterface xferOp, bool isPacked,
|
|
TypedValue<x86::amx::TileType> tileToStore = nullptr) {
|
|
if (!xferOp || !isa<vector::TransferReadOp, vector::TransferWriteOp>(xferOp))
|
|
return nullptr;
|
|
if (xferOp.hasOutOfBoundsDim() ||
|
|
!xferOp.getPermutationMap().isMinorIdentity())
|
|
return nullptr;
|
|
|
|
// Extra checks in case of a write op.
|
|
// Stores must not be packed.
|
|
if (isa<vector::TransferWriteOp>(xferOp) &&
|
|
(!tileToStore || isPacked ||
|
|
tileToStore.getType().getShape() != xferOp.getVectorType().getShape()))
|
|
return nullptr;
|
|
|
|
// Check for a memref source buffer.
|
|
// AMX data transfer requires at least 2D shape to correctly
|
|
// infer stride between rows.
|
|
Value base = xferOp.getBase();
|
|
auto memTy = dyn_cast<MemRefType>(base.getType());
|
|
int64_t memRank = memTy.getRank();
|
|
if (!memTy || memRank < 2)
|
|
return nullptr;
|
|
|
|
// Check that the source buffer has enough contiguous elements to load whole
|
|
// AMX tile row.
|
|
//
|
|
// To ensure correctness, the validation is conservative and expects the
|
|
// buffer's innermost dimensions to be statically known, equal to or larger
|
|
// than the vector row length, and equal to the VNNI dimension if applicable.
|
|
//
|
|
// This check could be relaxed to accept more arbitrarily shaped buffers as
|
|
// long as there are enough contiguous elements to load a whole row.
|
|
if (!memTy.areTrailingDimsContiguous(isPacked ? 2 : 1))
|
|
return nullptr;
|
|
VectorType vecTy = xferOp.getVectorType();
|
|
ArrayRef<int64_t> vecShape = vecTy.getShape();
|
|
ArrayRef<int64_t> memShape = memTy.getShape();
|
|
if (memShape.back() == ShapedType::kDynamic ||
|
|
memShape.back() < vecShape.back())
|
|
return nullptr;
|
|
if (isPacked &&
|
|
(memShape.back() != vecShape.back() ||
|
|
memShape[memShape.size() - 2] == ShapedType::kDynamic ||
|
|
memShape[memShape.size() - 2] < vecShape[vecShape.size() - 2]))
|
|
return nullptr;
|
|
|
|
// Load values directly from the buffer to an AMX tile.
|
|
PatternRewriter::InsertionGuard g(rewriter);
|
|
rewriter.setInsertionPoint(xferOp);
|
|
Location loc = xferOp.getLoc();
|
|
|
|
// Create a subview of the source buffer based on the transfer op to resolve
|
|
// offsets.
|
|
SmallVector<OpFoldResult> strides(memRank, rewriter.getIndexAttr(1));
|
|
int64_t vecRank = vecTy.getRank();
|
|
assert(memRank >= vecRank &&
|
|
"Expects buffer to be the same or greater rank than vector");
|
|
SmallVector<int64_t> shape(memRank - vecRank, 1);
|
|
shape.append(vecShape.begin(), vecShape.end());
|
|
TypedValue<MemRefType> src =
|
|
memref::SubViewOp::create(
|
|
rewriter, loc, base, getAsOpFoldResult(xferOp.getIndices()),
|
|
getAsOpFoldResult(rewriter.getI64ArrayAttr(shape)), strides)
|
|
.getResult();
|
|
|
|
// Collapse the VNNI dimension in case of packing.
|
|
if (isPacked)
|
|
src = collapseLastDim(rewriter, src);
|
|
int64_t rows = vecShape[0];
|
|
int64_t cols = llvm::product_of(vecShape.drop_front());
|
|
auto tileType = x86::amx::TileType::get({rows, cols}, vecTy.getElementType());
|
|
|
|
Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
|
|
SmallVector<Value> tileIndicides(src.getType().getRank(), zeroIndex);
|
|
|
|
Operation *amxTileOp = nullptr;
|
|
if (isa<vector::TransferReadOp>(xferOp)) {
|
|
amxTileOp = x86::amx::TileLoadOp::create(rewriter, loc, tileType, src,
|
|
tileIndicides);
|
|
} else if (isa<vector::TransferWriteOp>(xferOp)) {
|
|
amxTileOp = x86::amx::TileStoreOp::create(rewriter, loc, src, tileIndicides,
|
|
tileToStore);
|
|
} else {
|
|
llvm_unreachable("unsupported vector transfer op");
|
|
}
|
|
|
|
return amxTileOp;
|
|
}
|
|
|
|
/// Attempt to create an AMX tile load operation equivalent to the given
|
|
/// vector transfer `readOp`.
|
|
/// Returns loaded AMX tile if successful.
|
|
static FailureOr<TypedValue<x86::amx::TileType>>
|
|
loadFromTransfer(PatternRewriter &rewriter, vector::TransferReadOp readOp,
|
|
bool isPacked) {
|
|
x86::amx::TileLoadOp loadOp = dyn_cast_if_present<x86::amx::TileLoadOp>(
|
|
loadStoreFromTransfer(rewriter, readOp, isPacked));
|
|
if (!loadOp)
|
|
return failure();
|
|
return loadOp.getRes();
|
|
}
|
|
|
|
/// Attempt to create an AMX tile store operation equivalent to the given
|
|
/// vector transfer `writeOp`.
|
|
static LogicalResult
|
|
storeFromTransfer(PatternRewriter &rewriter, vector::TransferWriteOp writeOp,
|
|
TypedValue<x86::amx::TileType> tileToStore) {
|
|
return success(loadStoreFromTransfer(rewriter, writeOp, /*isPacked=*/false,
|
|
tileToStore));
|
|
}
|
|
|
|
/// Load vector values to an AMX tile.
|
|
static TypedValue<x86::amx::TileType> loadTile(PatternRewriter &rewriter,
|
|
TypedValue<VectorType> vec) {
|
|
Location loc = vec.getLoc();
|
|
|
|
VectorType vecTy = vec.getType();
|
|
bool isPacked = vecTy.getRank() == 3;
|
|
|
|
// Try to load tile directly from vector producer's buffer.
|
|
auto readOp = vec.getDefiningOp<vector::TransferReadOp>();
|
|
FailureOr<TypedValue<x86::amx::TileType>> tile =
|
|
loadFromTransfer(rewriter, readOp, isPacked);
|
|
if (succeeded(tile))
|
|
return *tile;
|
|
|
|
// Transfer the vector to a tile through an intermediate buffer.
|
|
Value buf = memref::AllocaOp::create(
|
|
rewriter, loc, MemRefType::get(vecTy.getShape(), vecTy.getElementType()));
|
|
Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
|
|
SmallVector<Value> indices(vecTy.getRank(), zeroIndex);
|
|
vector::TransferWriteOp::create(rewriter, loc, vec, buf, indices);
|
|
|
|
// Collapse the VNNI dimension in case of packing.
|
|
if (isPacked)
|
|
buf = collapseLastDim(rewriter, cast<TypedValue<MemRefType>>(buf));
|
|
|
|
ArrayRef<int64_t> shape = vecTy.getShape();
|
|
int64_t rows = shape[0];
|
|
int64_t cols = llvm::product_of(shape.drop_front());
|
|
auto tileType = x86::amx::TileType::get({rows, cols}, vecTy.getElementType());
|
|
|
|
return x86::amx::TileLoadOp::create(rewriter, loc, tileType, buf,
|
|
{zeroIndex, zeroIndex});
|
|
}
|
|
|
|
/// Store an AMX tile in a vector.
|
|
static TypedValue<VectorType> storeTile(PatternRewriter &rewriter,
|
|
TypedValue<x86::amx::TileType> tile) {
|
|
Location loc = tile.getLoc();
|
|
|
|
// Transfer the tile to a vector through an intermediate buffer.
|
|
x86::amx::TileType tileTy = tile.getType();
|
|
Value buf = memref::AllocaOp::create(
|
|
rewriter, loc,
|
|
MemRefType::get(tileTy.getShape(), tileTy.getElementType()));
|
|
Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
|
|
SmallVector<Value> indices(2, zeroIndex);
|
|
x86::amx::TileStoreOp::create(rewriter, loc, buf, indices, tile);
|
|
|
|
auto vecTy = VectorType::get(tileTy.getShape(), tileTy.getElementType());
|
|
return vector::TransferReadOp::create(rewriter, loc, vecTy, buf, indices, {});
|
|
}
|
|
|
|
struct ContractionToAMX : public OpRewritePattern<vector::ContractionOp> {
|
|
using Base::Base;
|
|
|
|
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
|
|
PatternRewriter &rewriter) const override {
|
|
Location loc = contractOp.getLoc();
|
|
|
|
if (contractOp.getKind() != vector::CombiningKind::ADD)
|
|
return rewriter.notifyMatchFailure(contractOp,
|
|
"Expects add combining kind");
|
|
if (failed(validateOperands(rewriter, contractOp)))
|
|
return failure();
|
|
|
|
TypedValue<x86::amx::TileType> lhsTile =
|
|
loadTile(rewriter, contractOp.getLhs());
|
|
TypedValue<x86::amx::TileType> rhsTile =
|
|
loadTile(rewriter, contractOp.getRhs());
|
|
auto acc = dyn_cast<TypedValue<VectorType>>(contractOp.getAcc());
|
|
assert(acc && "Invalid accumulator type");
|
|
TypedValue<x86::amx::TileType> accTile = loadTile(rewriter, acc);
|
|
|
|
TypedValue<x86::amx::TileType> tileMul;
|
|
if (acc.getType().getElementType().isFloat()) {
|
|
tileMul = x86::amx::TileMulFOp::create(rewriter, loc, accTile.getType(),
|
|
lhsTile, rhsTile, accTile);
|
|
} else {
|
|
tileMul = x86::amx::TileMulIOp::create(rewriter, loc, accTile.getType(),
|
|
lhsTile, rhsTile, accTile);
|
|
}
|
|
|
|
// If the contraction result is only written back to memory, try to replace
|
|
// the vector op with an AMX store directly.
|
|
Value res = contractOp.getResult();
|
|
if (res.hasOneUse()) {
|
|
auto writeOp = dyn_cast<vector::TransferWriteOp>(*res.getUsers().begin());
|
|
LogicalResult storeRes = storeFromTransfer(rewriter, writeOp, tileMul);
|
|
if (succeeded(storeRes)) {
|
|
rewriter.eraseOp(writeOp);
|
|
rewriter.eraseOp(contractOp);
|
|
return success();
|
|
}
|
|
}
|
|
|
|
// Load the result back into a vector.
|
|
Value newResult = storeTile(rewriter, tileMul);
|
|
rewriter.replaceOp(contractOp, newResult);
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct ConvertVectorToAMXPass
|
|
: public impl::ConvertVectorToAMXBase<ConvertVectorToAMXPass> {
|
|
void runOnOperation() override {
|
|
MLIRContext &ctx = getContext();
|
|
RewritePatternSet patterns(&ctx);
|
|
populateVectorToAMXConversionPatterns(patterns);
|
|
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
|
|
return signalPassFailure();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
void mlir::populateVectorToAMXConversionPatterns(RewritePatternSet &patterns) {
|
|
patterns.add<ContractionToAMX>(patterns.getContext());
|
|
}
|