This commit adds functionality to bubble down memory-space cast
operations, allowing consumer operations to use the original
memory space rather than first casting to a different memory space.
Changes:
- Introduce `MemorySpaceCastOpInterface` to handle memory-space cast
operations
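- Introduce `MemorySpaceCastConsumerOpInterface` for operations that can
absorb memory-space casts, and a `bubble-down-memory-space-casts` pass
(`BubbleDownMemorySpaceCasts`) that identifies and bubbles down eligible casts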
- Add implementation for memref and vector operations to handle
memory-space cast propagation
- Implement the `bubbleDownCasts` interface method on the relevant
operations; this method performs the actual rewrite
Note that the current implementation only bubbles down memory-space casts
into the default memory space.
Example:
```mlir
func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
%memspacecast = memref.memory_space_cast %arg0 : memref<4x4xf32, 1> to memref<4x4xf32>
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
%expanded = memref.expand_shape %memspacecast [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32> into memref<4x2x2xf32>
%collapsed = memref.collapse_shape %expanded [[0, 1, 2]] : memref<4x2x2xf32> into memref<16xf32>
%loaded = memref.load %collapsed[%c0] : memref<16xf32>
%added = arith.addf %loaded, %arg2 : f32
memref.store %added, %collapsed[%c0] : memref<16xf32>
%atomic_result = memref.atomic_rmw addf %arg2, %collapsed[%c4] : (f32, memref<16xf32>) -> f32
return %collapsed : memref<16xf32>
}
// mlir-opt --bubble-down-memory-space-casts
func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
%c4 = arith.constant 4 : index
%c0 = arith.constant 0 : index
%expand_shape = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32, 1> into memref<4x2x2xf32, 1>
%collapse_shape = memref.collapse_shape %expand_shape [[0, 1, 2]] : memref<4x2x2xf32, 1> into memref<16xf32, 1>
%memspacecast = memref.memory_space_cast %collapse_shape : memref<16xf32, 1> to memref<16xf32>
%0 = memref.load %collapse_shape[%c0] : memref<16xf32, 1>
%1 = arith.addf %0, %arg2 : f32
memref.store %1, %collapse_shape[%c0] : memref<16xf32, 1>
%2 = memref.atomic_rmw addf %arg2, %collapse_shape[%c4] : (f32, memref<16xf32, 1>) -> f32
return %memspacecast : memref<16xf32>
}
```
---------
Signed-off-by: Fabian Mora <fabian.mora-cordero@amd.com>
Co-authored-by: Mehdi Amini <joker.eph@gmail.com>
70 lines
2.5 KiB
C++
70 lines
2.5 KiB
C++
//===- BubbleDownMemorySpaceCasts.cpp - Bubble down casts transform -------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Transforms/BubbleDownMemorySpaceCasts.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Interfaces/MemOpInterfaces.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
#include "mlir/Transforms/Passes.h"
|
|
#include "llvm/Support/Debug.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace mlir {
|
|
#define GEN_PASS_DEF_BUBBLEDOWNMEMORYSPACECASTS
|
|
#include "mlir/Transforms/Passes.h.inc"
|
|
} // namespace mlir
|
|
|
|
namespace {
|
|
//===----------------------------------------------------------------------===//
|
|
// BubbleDownCastsPattern pattern
|
|
//===----------------------------------------------------------------------===//
|
|
/// Pattern to bubble down casts into consumer operations.
|
|
struct BubbleDownCastsPattern
|
|
: public OpInterfaceRewritePattern<MemorySpaceCastConsumerOpInterface> {
|
|
using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(MemorySpaceCastConsumerOpInterface op,
|
|
PatternRewriter &rewriter) const override {
|
|
FailureOr<std::optional<SmallVector<Value>>> results =
|
|
op.bubbleDownCasts(rewriter);
|
|
if (failed(results))
|
|
return failure();
|
|
if (!results->has_value()) {
|
|
rewriter.modifyOpInPlace(op, []() {});
|
|
return success();
|
|
}
|
|
rewriter.replaceOp(op, **results);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// BubbleDownMemorySpaceCasts pass
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
struct BubbleDownMemorySpaceCasts
|
|
: public impl::BubbleDownMemorySpaceCastsBase<BubbleDownMemorySpaceCasts> {
|
|
using impl::BubbleDownMemorySpaceCastsBase<
|
|
BubbleDownMemorySpaceCasts>::BubbleDownMemorySpaceCastsBase;
|
|
|
|
void runOnOperation() override {
|
|
RewritePatternSet patterns(&getContext());
|
|
populateBubbleDownMemorySpaceCastPatterns(patterns, PatternBenefit(1));
|
|
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
|
|
signalPassFailure();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
/// Registers the pattern that bubbles memory-space casts down into their
/// `MemorySpaceCastConsumerOpInterface` users.
///
/// \param patterns  set the pattern is appended to; its context is used to
///                  construct the pattern.
/// \param benefit   benefit assigned to the registered pattern.
void mlir::populateBubbleDownMemorySpaceCastPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  MLIRContext *context = patterns.getContext();
  patterns.add<BubbleDownCastsPattern>(context, benefit);
}
|