Files
llvm-project/mlir/test/python/dialects/transform_xegpu_ext.py
Tuomas Kärnä 7cb57c6808 [MLIR][XeGPU][TransformOps] Remove obsolete transform ops (#187561)
Cleaning up XeGPU transform ops. Now that XeGPU layout propagation
works, it is sufficient to set the layouts for anchor ops (e.g.
load/store/dpas ops) only.

Changes:
* Remove `xegpu.get_desc_op` and `xegpu.set_desc_layout`. Users should
no longer change the layout of the descriptor op's return value.
* Add `xegpu.get_load_op(value)`, which finds either an `xegpu.load_nd` or
an `xegpu.load` op in the value's producer chain. This is a useful utility
because load ops often need to be annotated with a layout.
* The generic `xegpu.set_op_layout_attr(op, ...)` is now replaced by
`xegpu.set_anchor_layout(op, ...)`, which only sets the layout attribute of
anchor ops. It raises an error if the given op does not support anchor
layouts.
* `xegpu.insert_prefetch` takes a load op handle instead of a value.
2026-03-25 10:24:11 +02:00

253 lines
7.4 KiB
Python

# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects.transform import xegpu
from mlir.dialects.transform import structured, AnyValueType
def run(f):
    """Run test builder ``f`` inside a fresh module and print the result.

    A new context and an unknown location are active for the whole call.
    The test name is written to stdout, ``f`` is invoked with the module
    body as the insertion point, and the module is printed afterwards.

    Returns ``f`` unchanged so that this function can be used as a decorator.
    """
    with Context():
        with Location.unknown():
            test_module = Module.create()
            with InsertionPoint(test_module.body):
                print("\nTEST:", f.__name__)
                f()
            print(test_module)
    return f
@run
def getLoadOp():
    """Emit ``get_load_op`` on operand 0 of a sequence targeting ``xegpu.dpas``."""
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("xegpu.dpas"),
    )
    with InsertionPoint(sequence.body):
        operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
        # The resulting handle is not consumed here; only the printed op matters.
        xegpu.get_load_op(operand)
        transform.YieldOp()
    # CHECK-LABEL: TEST: getLoadOp
    # CHECK: transform.xegpu.get_load_op %
@run
def setAnchorLayout():
    """Emit ``set_anchor_layout`` on an ``xegpu.load_nd`` handle without an index."""
    load_nd_type = transform.OperationType.get("xegpu.load_nd")
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], load_nd_type
    )
    with InsertionPoint(seq.body):
        target = seq.bodyTarget
        xegpu.set_anchor_layout(
            target,
            sg_layout=[6, 4],
            sg_data=[32, 16],
            inst_data=[8, 16],
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: setAnchorLayout
    # CHECK: transform.xegpu.set_anchor_layout %
    # CHECK-NOT: index = 0
    # CHECK: sg_layout = [6, 4]
    # CHECK: sg_data = [32, 16]
    # CHECK: inst_data = [8, 16]
@run
def setAnchorLayoutDPAS():
    """Emit ``set_anchor_layout`` with ``index=1`` on an ``xegpu.dpas`` handle."""
    dpas_type = transform.OperationType.get("xegpu.dpas")
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], dpas_type
    )
    with InsertionPoint(seq.body):
        target = seq.bodyTarget
        xegpu.set_anchor_layout(
            target,
            index=1,
            sg_layout=[6, 4],
            sg_data=[32, 16],
            inst_data=[8, 16],
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: setAnchorLayoutDPAS
    # CHECK: transform.xegpu.set_anchor_layout %
    # CHECK: index = 1
    # CHECK: sg_layout = [6, 4]
    # CHECK: sg_data = [32, 16]
    # CHECK: inst_data = [8, 16]
@run
def setAnchorLayoutOrder():
    """Emit ``set_anchor_layout`` with an ``order`` on an ``xegpu.load_nd`` handle."""
    load_nd_type = transform.OperationType.get("xegpu.load_nd")
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], load_nd_type
    )
    with InsertionPoint(seq.body):
        target = seq.bodyTarget
        xegpu.set_anchor_layout(
            target,
            sg_layout=[6, 4],
            sg_data=[32, 16],
            inst_data=[8, 16],
            order=[1, 0],
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: setAnchorLayoutOrder
    # CHECK: transform.xegpu.set_anchor_layout %
    # CHECK-NOT: index = 0
    # CHECK: sg_layout = [6, 4]
    # CHECK: sg_data = [32, 16]
    # CHECK: inst_data = [8, 16]
    # CHECK: order = [1, 0]
@run
def setAnchorLayoutSlice():
    """Emit ``set_anchor_layout`` with ``slice_dims`` on an ``xegpu.load`` handle."""
    load_type = transform.OperationType.get("xegpu.load")
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], load_type
    )
    with InsertionPoint(seq.body):
        target = seq.bodyTarget
        xegpu.set_anchor_layout(
            target,
            sg_layout=[6, 4],
            sg_data=[32, 16],
            inst_data=[8, 16],
            slice_dims=[0],
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: setAnchorLayoutSlice
    # CHECK: transform.xegpu.set_anchor_layout %
    # CHECK-NOT: index = 0
    # CHECK: sg_layout = [6, 4]
    # CHECK: sg_data = [32, 16]
    # CHECK: inst_data = [8, 16]
    # CHECK: slice_dims = [0]
@run
def setGPULaunchThreadsOp():
    """Emit ``set_gpu_launch_threads`` with threads [8, 4, 1] on ``gpu.launch``."""
    launch_type = transform.OperationType.get("gpu.launch")
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], launch_type
    )
    with InsertionPoint(seq.body):
        launch_handle = seq.bodyTarget
        xegpu.set_gpu_launch_threads(launch_handle, threads=[8, 4, 1])
        transform.YieldOp()
    # CHECK-LABEL: TEST: setGPULaunchThreadsOp
    # CHECK: transform.xegpu.set_gpu_launch_threads
    # CHECK: threads = [8, 4, 1]
@run
def insertPrefetch():
    """Emit ``insert_prefetch`` on an ``xegpu.load_nd`` handle with default options."""
    load_nd_type = transform.OperationType.get("xegpu.load_nd")
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], load_nd_type
    )
    with InsertionPoint(seq.body):
        load_handle = seq.bodyTarget
        xegpu.insert_prefetch(load_handle)
        transform.YieldOp()
    # CHECK-LABEL: TEST: insertPrefetch
    # CHECK: transform.xegpu.insert_prefetch
@run
def insertPrefetchNbPrefetch():
    """Emit ``insert_prefetch`` with a constant ``nb_prefetch`` of 2."""
    load_nd_type = transform.OperationType.get("xegpu.load_nd")
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], load_nd_type
    )
    with InsertionPoint(seq.body):
        load_handle = seq.bodyTarget
        xegpu.insert_prefetch(load_handle, nb_prefetch=2)
        transform.YieldOp()
    # CHECK-LABEL: TEST: insertPrefetchNbPrefetch
    # CHECK: transform.xegpu.insert_prefetch
    # CHECK-SAME: nb_prefetch = 2
@run
def insertPrefetchNbPrefetchParam():
    """Emit ``insert_prefetch`` whose ``nb_prefetch`` comes from a param constant.

    A 32-bit signless integer param holding 2 is created with
    ``transform.ParamConstantOp`` and passed as ``nb_prefetch``.
    """
    load_nd_type = transform.OperationType.get("xegpu.load_nd")
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], load_nd_type
    )
    with InsertionPoint(seq.body):
        i32 = IntegerType.get_signless(32)
        i32_param_type = transform.ParamType.get(i32)
        two_attr = IntegerAttr.get(i32, 2)
        nb_prefetch_param = transform.ParamConstantOp(i32_param_type, two_attr)
        xegpu.insert_prefetch(seq.bodyTarget, nb_prefetch=nb_prefetch_param)
        transform.YieldOp()
    # CHECK-LABEL: TEST: insertPrefetchNbPrefetchParam
    # CHECK: %[[PARAM_OP:.*]] = transform.param.constant 2
    # CHECK: transform.xegpu.insert_prefetch
    # CHECK-SAME: nb_prefetch = %[[PARAM_OP]]
@run
def ConvertLayoutMinimal():
    """Emit ``convert_layout`` on ``xegpu.dpas`` operand 0 with sg layout/data only."""
    dpas_type = transform.OperationType.get("xegpu.dpas")
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], dpas_type
    )
    with InsertionPoint(seq.body):
        any_value = AnyValueType.get()
        operand_handle = transform.GetOperandOp(any_value, seq.bodyTarget, [0])
        xegpu.convert_layout(
            operand_handle,
            input_sg_layout=[6, 4],
            input_sg_data=[32, 16],
            target_sg_layout=[6, 4],
            target_sg_data=[8, 16],
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: ConvertLayoutMinimal
    # CHECK: transform.xegpu.convert_layout %
    # CHECK: input_sg_layout = [6, 4]
    # CHECK: input_sg_data = [32, 16]
    # CHECK: target_sg_layout = [6, 4]
    # CHECK: target_sg_data = [8, 16]
@run
def ConvertLayout():
    """Emit ``convert_layout`` on ``xegpu.dpas`` operand 1 with every layout field.

    Both the input and the target layout specify sg layout, sg data,
    inst data and order.
    """
    dpas_type = transform.OperationType.get("xegpu.dpas")
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], dpas_type
    )
    with InsertionPoint(seq.body):
        any_value = AnyValueType.get()
        operand_handle = transform.GetOperandOp(any_value, seq.bodyTarget, [1])
        xegpu.convert_layout(
            operand_handle,
            input_sg_layout=[6, 4],
            input_sg_data=[32, 32],
            input_inst_data=[32, 16],
            input_order=[1, 0],
            target_sg_layout=[6, 4],
            target_sg_data=[32, 32],
            target_inst_data=[8, 16],
            target_order=[0, 1],
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: ConvertLayout
    # CHECK: transform.xegpu.convert_layout %
    # CHECK: input_sg_layout = [6, 4]
    # CHECK: input_sg_data = [32, 32]
    # CHECK: input_inst_data = [32, 16]
    # CHECK: input_order = [1, 0]
    # CHECK: target_sg_layout = [6, 4]
    # CHECK: target_sg_data = [32, 32]
    # CHECK: target_inst_data = [8, 16]
    # CHECK: target_order = [0, 1]