Files
llvm-project/mlir/lib/Interfaces/MemOpInterfaces.cpp
Fabian Mora 077a796c0d [mlir] Implement a memory-space cast bubbling-down transform (#159454)
This commit adds functionality to bubble down memory-space cast
operations, allowing consumer operations to use the original
memory space rather than first casting to a different memory space.

Changes:
- Introduce `MemorySpaceCastOpInterface` to handle memory-space cast
operations
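- Introduce `MemorySpaceCastConsumerOpInterface` for operations that can
consume bubbled-down casts, and add a pass
(`--bubble-down-memory-space-casts`) that identifies and bubbles down
eligible casts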
- Add implementations for memref and vector operations to handle
memory-space cast propagation
- Add a `bubbleDownCasts` method to relevant operations so that casts
feeding them can be bubbled down past them

In particular, in the current implementation only memory-space casts
into the default memory space can be bubbled down.

Example:

```mlir
func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
    %memspacecast = memref.memory_space_cast %arg0 : memref<4x4xf32, 1> to memref<4x4xf32>
    %c0 = arith.constant 0 : index
    %c4 = arith.constant 4 : index
    %expanded = memref.expand_shape %memspacecast [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32> into memref<4x2x2xf32>
    %collapsed = memref.collapse_shape %expanded [[0, 1, 2]] : memref<4x2x2xf32> into memref<16xf32>
    %loaded = memref.load %collapsed[%c0] : memref<16xf32>
    %added = arith.addf %loaded, %arg2 : f32
    memref.store %added, %collapsed[%c0] : memref<16xf32>
    %atomic_result = memref.atomic_rmw addf %arg2, %collapsed[%c4] : (f32, memref<16xf32>) -> f32
    return %collapsed : memref<16xf32>
}
// mlir-opt --bubble-down-memory-space-casts
func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
    %c4 = arith.constant 4 : index
    %c0 = arith.constant 0 : index
    %expand_shape = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32, 1> into memref<4x2x2xf32, 1>
    %collapse_shape = memref.collapse_shape %expand_shape [[0, 1, 2]] : memref<4x2x2xf32, 1> into memref<16xf32, 1>
    %memspacecast = memref.memory_space_cast %collapse_shape : memref<16xf32, 1> to memref<16xf32>
    %0 = memref.load %collapse_shape[%c0] : memref<16xf32, 1>
    %1 = arith.addf %0, %arg2 : f32
    memref.store %1, %collapse_shape[%c0] : memref<16xf32, 1>
    %2 = memref.atomic_rmw addf %arg2, %collapse_shape[%c4] : (f32, memref<16xf32, 1>) -> f32
    return %memspacecast : memref<16xf32>
}
```

---------

Signed-off-by: Fabian Mora <fabian.mora-cordero@amd.com>
Co-authored-by: Mehdi Amini <joker.eph@gmail.com>
2025-09-24 09:11:43 -04:00

74 lines
2.4 KiB
C++

//===- MemOpInterfaces.cpp - Memory operation interfaces ---------*- C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Interfaces/MemOpInterfaces.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
using namespace mlir;
/// Verifies the invariants shared by every operation implementing
/// `MemorySpaceCastOpInterface`:
///  - the op has exactly one result,
///  - the source and target pointer values are non-null,
///  - the source and target types are of the same kind (same TypeID), and
///  - both types implement `PtrLikeTypeInterface`.
/// Emits an error on `op` and returns failure for the first violated
/// invariant; returns success otherwise.
LogicalResult mlir::detail::verifyMemorySpaceCastOpInterface(Operation *op) {
  // Check the result count before touching any interface accessor: an
  // implementation of `getTargetPtr` may be defined in terms of the op's
  // result, so it should not be queried on an op with a malformed result
  // list.
  if (op->getNumResults() != 1) {
    return op->emitError()
           << "memory space cast op must have exactly one result";
  }

  auto memCastOp = cast<MemorySpaceCastOpInterface>(op);

  // Verify that the source and target pointers are valid.
  Value sourcePtr = memCastOp.getSourcePtr();
  Value targetPtr = memCastOp.getTargetPtr();
  if (!sourcePtr || !targetPtr) {
    return op->emitError()
           << "memory space cast op must have valid source and target pointers";
  }

  // A memory-space cast may only change the memory space, not the kind of
  // type (e.g. it must not turn one pointer-like type into another).
  if (sourcePtr.getType().getTypeID() != targetPtr.getType().getTypeID()) {
    return op->emitError()
           << "expected source and target types of the same kind";
  }

  // Verify that both types implement `PtrLikeTypeInterface`.
  auto sourceType = dyn_cast<PtrLikeTypeInterface>(sourcePtr.getType());
  if (!sourceType) {
    return op->emitError()
           << "source type must implement `PtrLikeTypeInterface`, but got: "
           << sourcePtr.getType();
  }
  auto targetType = dyn_cast<PtrLikeTypeInterface>(targetPtr.getType());
  if (!targetType) {
    return op->emitError()
           << "target type must implement `PtrLikeTypeInterface`, but got: "
           << targetPtr.getType();
  }

  return success();
}
/// Default in-place implementation for bubbling a memory-space cast down
/// through `operand`.
///
/// When `MemorySpaceCastOpInterface::getIfPromotableCast` yields a cast op for
/// the operand's current value, the operand is rewired to that cast's source
/// pointer. The owning operation is modified in place.
///
/// Returns failure when no such cast is found. On success it returns an empty
/// `std::optional`, meaning no replacement values are produced. The `results`
/// argument is not used by this implementation.
FailureOr<std::optional<SmallVector<Value>>>
mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand,
                                                   ValueRange results) {
  if (MemorySpaceCastOpInterface promotable =
          MemorySpaceCastOpInterface::getIfPromotableCast(operand.get())) {
    // Use the pre-cast value directly, skipping the cast.
    operand.set(promotable.getSourcePtr());
    // The optional type is spelled out explicitly so the result is a
    // successful FailureOr holding an empty optional ("updated in place").
    return std::optional<SmallVector<Value>>();
  }
  // No usable cast feeds this operand; leave it untouched.
  return failure();
}
#include "mlir/Interfaces/MemOpInterfaces.cpp.inc"