This is a fix for the `BufferizableOpInterface` implementation for
`ml_program.global_store`.
`bufferizesToMemoryRead` currently returns false for
`GlobalStoreOpInterface`, but I believe it should return true, because
`ml_program.global_store` needs to read its input buffer to know what
value to store to the global.
This manifested in a bug where `one-shot-bufferize` would produce MLIR
that copies uninitialized data to the global var instead of the intended
value to be stored.
For the following MLIR:
```
module {
ml_program.global private mutable @"state_tensor"(dense<0.0> : tensor<4x75xf32>) : tensor<4x75xf32>
func.func @main() -> tensor<4x75xf32> {
%c0 = arith.constant 0 : index
%cst_val = arith.constant 1.0 : f32
%initial_state = ml_program.global_load @"state_tensor" : tensor<4x75xf32>
%val = tensor.extract %initial_state[%c0, %c0] : tensor<4x75xf32>
%next_val = arith.addf %val, %cst_val : f32
%updated_tensor = tensor.insert %next_val into %initial_state[%c0, %c0] : tensor<4x75xf32>
ml_program.global_store @"state_tensor" = %updated_tensor : tensor<4x75xf32>
return %updated_tensor : tensor<4x75xf32>
}
}
```
`one-shot-bufferize` produces this incorrect MLIR:
```
module {
memref.global "private" @state_tensor : memref<4x75xf32> = dense<0.000000e+00>
func.func @main() -> tensor<4x75xf32> {
%c0 = arith.constant 0 : index
%cst = arith.constant 1.000000e+00 : f32
%0 = memref.get_global @state_tensor : memref<4x75xf32>
%1 = memref.load %0[%c0, %c0] : memref<4x75xf32>
%2 = arith.addf %1, %cst : f32
%alloc = memref.alloc() {alignment = 64 : i64} : memref<4x75xf32>
memref.copy %0, %alloc : memref<4x75xf32> to memref<4x75xf32>
memref.store %2, %alloc[%c0, %c0] : memref<4x75xf32>
%3 = bufferization.to_tensor %alloc : memref<4x75xf32> to tensor<4x75xf32>
%alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<4x75xf32>
%4 = memref.get_global @state_tensor : memref<4x75xf32>
memref.copy %alloc_0, %4 : memref<4x75xf32> to memref<4x75xf32>
return %3 : tensor<4x75xf32>
}
}
```
Note that `memref.copy` at the end copies an uninitialized `alloc_0` to
the global variable.
But after the change we see the following MLIR:
```
module {
memref.global "private" @state_tensor : memref<4x75xf32> = dense<0.000000e+00>
func.func @main() -> tensor<4x75xf32> {
%c0 = arith.constant 0 : index
%cst = arith.constant 1.000000e+00 : f32
%0 = memref.get_global @state_tensor : memref<4x75xf32>
%1 = memref.load %0[%c0, %c0] : memref<4x75xf32>
%2 = arith.addf %1, %cst : f32
%alloc = memref.alloc() {alignment = 64 : i64} : memref<4x75xf32>
memref.copy %0, %alloc : memref<4x75xf32> to memref<4x75xf32>
memref.store %2, %alloc[%c0, %c0] : memref<4x75xf32>
%3 = bufferization.to_tensor %alloc : memref<4x75xf32> to tensor<4x75xf32>
%alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<4x75xf32>
memref.copy %alloc, %alloc_0 : memref<4x75xf32> to memref<4x75xf32>
%4 = memref.get_global @state_tensor : memref<4x75xf32>
memref.copy %alloc_0, %4 : memref<4x75xf32> to memref<4x75xf32>
return %3 : tensor<4x75xf32>
}
}
```
We now see that the relevant data is copied to `alloc_0` before it is
stored in the global.
Co-authored-by: Nathan Malimban <nmalimba@ah-nmalimba-l.dhcp.mathworks.com>
169 lines
5.5 KiB
C++
169 lines
5.5 KiB
C++
//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
|
|
|
|
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
|
#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
|
|
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::bufferization;
|
|
using namespace mlir::ml_program;
|
|
|
|
namespace mlir {
|
|
namespace ml_program {
|
|
namespace {
|
|
|
|
template <typename Interface, typename Op>
|
|
struct ExternalModelBase
|
|
: public BufferizableOpInterface::ExternalModel<Interface, Op> {
|
|
|
|
AliasingValueList getAliasingValues(Operation *, OpOperand &,
|
|
const AnalysisState &) const {
|
|
return {};
|
|
}
|
|
|
|
BufferRelation bufferRelation(Operation *, OpResult,
|
|
const AnalysisState &) const {
|
|
return BufferRelation::Unknown;
|
|
}
|
|
};
|
|
|
|
/// Bufferization of ml_program.global into a memref.global
|
|
struct GlobalOpInterface
|
|
: public ExternalModelBase<GlobalOpInterface, GlobalOp> {
|
|
|
|
bool bufferizesToMemoryRead(Operation *, OpOperand &,
|
|
const AnalysisState &) const {
|
|
return false;
|
|
}
|
|
|
|
bool bufferizesToMemoryWrite(Operation *, OpOperand &,
|
|
const AnalysisState &) const {
|
|
return false;
|
|
}
|
|
|
|
bool hasTensorSemantics(Operation *) const { return true; }
|
|
|
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
|
const BufferizationOptions &,
|
|
BufferizationState &state) const {
|
|
auto globalOp = cast<GlobalOp>(op);
|
|
if (!globalOp.getValue().has_value())
|
|
return globalOp.emitError("global op must have a value");
|
|
|
|
bufferization::removeSymbol(globalOp, state);
|
|
|
|
auto tensorType = cast<TensorType>(globalOp.getType());
|
|
auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
|
|
|
|
auto replacement = replaceOpWithNewBufferizedOp<memref::GlobalOp>(
|
|
rewriter, globalOp, globalOp.getSymName(),
|
|
/*sym_visibility=*/globalOp.getSymVisibilityAttr(),
|
|
/*type=*/cast<MemRefType>(memrefType),
|
|
/*initial_value=*/globalOp.getValue().value(),
|
|
/*constant=*/!globalOp.getIsMutable(),
|
|
/*alignment=*/nullptr);
|
|
|
|
bufferization::insertSymbol(replacement, state);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Bufferization of ml_program.global_load into a memref.get_global
|
|
struct GlobalLoadOpInterface
|
|
: public ExternalModelBase<GlobalLoadOpInterface, GlobalLoadOp> {
|
|
|
|
bool bufferizesToMemoryRead(Operation *, OpOperand &,
|
|
const AnalysisState &) const {
|
|
return false;
|
|
}
|
|
|
|
bool bufferizesToMemoryWrite(Operation *, OpOperand &,
|
|
const AnalysisState &) const {
|
|
return false;
|
|
}
|
|
|
|
bool isWritable(Operation *, Value, const AnalysisState &) const {
|
|
return false;
|
|
}
|
|
|
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
|
const BufferizationOptions &,
|
|
BufferizationState &state) const {
|
|
auto globalLoadOp = cast<GlobalLoadOp>(op);
|
|
|
|
auto tensorType = cast<TensorType>(globalLoadOp.getType());
|
|
auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
|
|
|
|
replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
|
|
rewriter, globalLoadOp, memrefType,
|
|
globalLoadOp.getGlobalAttr().getLeafReference());
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Bufferization of ml_program.global_store into a memref.get_global and
|
|
/// memcpy
|
|
struct GlobalStoreOpInterface
|
|
: public ExternalModelBase<GlobalStoreOpInterface, GlobalStoreOp> {
|
|
|
|
bool bufferizesToMemoryRead(Operation *, OpOperand &,
|
|
const AnalysisState &) const {
|
|
return true;
|
|
}
|
|
|
|
bool bufferizesToMemoryWrite(Operation *, OpOperand &,
|
|
const AnalysisState &) const {
|
|
return true;
|
|
}
|
|
|
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
|
const BufferizationOptions &options,
|
|
BufferizationState &state) const {
|
|
auto globalStoreOp = cast<GlobalStoreOp>(op);
|
|
|
|
auto tensorType = cast<TensorType>(globalStoreOp.getValue().getType());
|
|
auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
|
|
|
|
auto loc = globalStoreOp.getLoc();
|
|
auto targetMemref = memref::GetGlobalOp::create(
|
|
rewriter, loc, memrefType,
|
|
globalStoreOp.getGlobalAttr().getLeafReference());
|
|
|
|
auto sourceMemref =
|
|
getBuffer(rewriter, globalStoreOp.getValue(), options, state);
|
|
if (failed(sourceMemref)) {
|
|
return failure();
|
|
}
|
|
|
|
auto memcpy =
|
|
options.createMemCpy(rewriter, loc, sourceMemref.value(), targetMemref);
|
|
if (failed(memcpy)) {
|
|
return failure();
|
|
}
|
|
rewriter.eraseOp(globalStoreOp);
|
|
|
|
return success();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) {
|
|
registry.addExtension(+[](MLIRContext *ctx, MLProgramDialect *) {
|
|
GlobalOp::attachInterface<GlobalOpInterface>(*ctx);
|
|
GlobalLoadOp::attachInterface<GlobalLoadOpInterface>(*ctx);
|
|
GlobalStoreOp::attachInterface<GlobalStoreOpInterface>(*ctx);
|
|
});
|
|
}
|
|
} // namespace ml_program
|
|
} // namespace mlir
|