Renames the 'x86vector' dialect to 'x86'. This is the first PR in a series of cleanups around dialects targeting x86 platforms. The new naming scheme is shorter and cleaner, and it opens the possibility of integrating other x86-specific operations that do not strictly fit a pure vector representation. For example, the generalization will allow a future merger of the AMX dialect into the x86 dialect, creating a one-stop collection of x86 operations and improving discoverability.
75 lines
2.2 KiB
Python
75 lines
2.2 KiB
Python
# RUN: %PYTHON %s | FileCheck %s
|
|
|
|
from mlir.ir import *
|
|
import mlir.dialects.builtin as builtin
|
|
import mlir.dialects.func as func
|
|
import mlir.dialects.x86 as x86
|
|
|
|
|
|
def run(f):
    """Execute test function ``f`` immediately and hand it back unchanged.

    A ``TEST: <name>`` banner is printed first so the expected-output
    directives in this file can anchor on it. ``f`` is then called with a
    newly created ``Context`` and an unknown ``Location`` active. Returning
    ``f`` allows this helper to be applied as a decorator.
    """
    print("\nTEST:", f.__name__)
    with Context():
        with Location.unknown():
            f()
    return f
|
|
|
|
|
|
# CHECK-LABEL: TEST: testAvxOp
|
|
@run
def testAvxOp():
    """Build a function around ``x86.BcstToPackedF32Op`` and print the module."""
    mod = Module.create()
    with InsertionPoint(mod.body):
        src_type = MemRefType.get((1,), BF16Type.get())
        dst_type = VectorType.get((8,), F32Type.get())

        @func.FuncOp.from_py_func(src_type)
        def avx_op(src):
            bcst = x86.BcstToPackedF32Op(a=src, dst=dst_type)
            return bcst

    # CHECK-LABEL: func @avx_op(
    # CHECK-SAME: %[[ARG:.+]]: memref<1xbf16>) -> vector<8xf32> {
    # CHECK: %[[VAL:.+]] = x86.avx.bcst_to_f32.packed %[[ARG]]
    # CHECK: return %[[VAL]] : vector<8xf32>
    # CHECK: }
    print(mod)
|
|
|
|
|
|
# CHECK-LABEL: TEST: testAvx512Op
|
|
@run
def testAvx512Op():
    """Build a function around ``x86.CvtPackedF32ToBF16Op`` and print the module."""
    mod = Module.create()
    with InsertionPoint(mod.body):
        f32x8 = VectorType.get((8,), F32Type.get())
        bf16x8 = VectorType.get((8,), BF16Type.get())

        @func.FuncOp.from_py_func(f32x8)
        def avx512_op(src):
            cvt = x86.CvtPackedF32ToBF16Op(a=src, dst=bf16x8)
            return cvt

    # CHECK-LABEL: func @avx512_op(
    # CHECK-SAME: %[[ARG:.+]]: vector<8xf32>) -> vector<8xbf16> {
    # CHECK: %[[VAL:.+]] = x86.avx512.cvt.packed.f32_to_bf16 %[[ARG]]
    # CHECK: return %[[VAL]] : vector<8xbf16>
    # CHECK: }
    print(mod)
|
|
|
|
|
|
# CHECK-LABEL: TEST: testAvx10Op
|
|
@run
def testAvx10Op():
    """Build a function around ``x86.AVX10DotInt8Op`` and print the module."""
    mod = Module.create()
    with InsertionPoint(mod.body):
        i32x16 = VectorType.get((16,), IntegerType.get(32))
        # Both 8-bit operands share the same vector type.
        i8x64 = VectorType.get((64,), IntegerType.get(8))

        @func.FuncOp.from_py_func(i32x16, i8x64, i8x64)
        def avx10_op(*operands):
            acc, lhs, rhs = operands[0], operands[1], operands[2]
            return x86.AVX10DotInt8Op(w=acc, a=lhs, b=rhs)

    # CHECK-LABEL: func @avx10_op(
    # CHECK-SAME: %[[W:.+]]: vector<16xi32>, %[[A:.+]]: vector<64xi8>,
    # CHECK-SAME: %[[B:.+]]: vector<64xi8>) -> vector<16xi32> {
    # CHECK: %[[VAL:.+]] = x86.avx10.dot.i8 %[[W]], %[[A]], %[[B]]
    # CHECK: return %[[VAL]] : vector<16xi32>
    # CHECK: }
    print(mod)
|