[mlir][memref] Add non-atomic RMW option for emulated memref.store. (#178498)
This revision follows f0e1857c84 to add an option for non-atomic RMW emulation
of memref.store. The 0-D case takes the non-atomic path unconditionally because
the store overwrites the entire value.
Signed-off-by: hanhanW <hanhan0912@gmail.com>
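
At the IR level, the two store emulations look roughly like this. This is a
minimal sketch distilled from the CHECK lines of the new test below; the
function and SSA names are illustrative, and the write mask and shifted value
are assumed to be precomputed as in the emitted code.

// Non-atomic path (disable-atomic-rmw=true): a plain load/modify/store.
func.func @store_i4_non_atomic(%alloc: memref<3xi8>, %idx: index,
                               %mask: i8, %shifted_val: i8) {
  %orig = memref.load %alloc[%idx] : memref<3xi8>
  %cleared = arith.andi %orig, %mask : i8
  %inserted = arith.ori %cleared, %shifted_val : i8
  memref.store %inserted, %alloc[%idx] : memref<3xi8>
  return
}

// Atomic path (the default): the same clear/insert steps as two atomic RMWs.
func.func @store_i4_atomic(%alloc: memref<3xi8>, %idx: index,
                           %mask: i8, %shifted_val: i8) {
  %0 = memref.atomic_rmw andi %mask, %alloc[%idx] : (i8, memref<3xi8>) -> i8
  %1 = memref.atomic_rmw ori %shifted_val, %alloc[%idx] : (i8, memref<3xi8>) -> i8
  return
}
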
@@ -83,9 +83,11 @@ void populateMemRefWideIntEmulationConversions(
 
 /// Appends patterns for emulating memref operations over narrow types with ops
 /// over wider types.
+/// When `disableAtomicRMW` is true, the store patterns generate non-atomic
+/// read-modify-write sequences instead of atomic operations.
 void populateMemRefNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns);
+    RewritePatternSet &patterns, bool disableAtomicRMW = false);
 
 /// Appends type conversions for emulating memref operations over narrow types
 /// with ops over wider types.

@@ -408,9 +408,17 @@ struct ConvertMemRefReinterpretCast final
 // ConvertMemrefStore
 //===----------------------------------------------------------------------===//
 
+/// Emulate a narrow-type memref store with an atomic or non-atomic
+/// read-modify-write sequence. The `disableAtomicRMW` flag selects a plain
+/// read-modify-write sequence instead of `memref.atomic_rmw` ops for the
+/// sub-byte store.
 struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
+  ConvertMemrefStore(MLIRContext *context, bool disableAtomicRMW)
+      : OpConversionPattern<memref::StoreOp>(context),
+        disableAtomicRMW(disableAtomicRMW) {}
+
   LogicalResult
   matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {

@@ -437,11 +445,11 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
     Value extendedInput =
         arith::ExtUIOp::create(rewriter, loc, dstIntegerType, input);
 
-    // Special case 0-rank memref stores. No need for masking.
+    // Special case 0-rank memref stores. No need for masking. The non-atomic
+    // store is used because it overwrites the entire value.
     if (convertedType.getRank() == 0) {
-      memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::assign,
-                                  extendedInput, adaptor.getMemref(),
-                                  ValueRange{});
+      memref::StoreOp::create(rewriter, loc, extendedInput, adaptor.getMemref(),
+                              ValueRange{});
       rewriter.eraseOp(op);
       return success();
     }

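Concretely, a 0-D i4 store now lowers to a whole-byte store with no masking or
RMW at all. A sketch of the emitted IR, mirroring the updated
rank_zero_memref_store expectations further down (function name illustrative):

func.func @rank_zero_store_sketch(%v: i4) {
  %alloc = memref.alloc() : memref<i8>
  %ext = arith.extui %v : i4 to i8
  // The store overwrites the full byte, so a plain store is always correct
  // here, even without atomicity.
  memref.store %ext, %alloc[] : memref<i8>
  return
}
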
@@ -454,19 +462,39 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
                               dstBits, rewriter);
     Value writeMask = getSubByteWriteMask(loc, linearizedIndices, srcBits,
                                           dstBits, bitwidthOffset, rewriter);
-    // Align the value to write with the destination bits
+    // Align the value to write with the destination bits.
     Value alignedVal =
         arith::ShLIOp::create(rewriter, loc, extendedInput, bitwidthOffset);
 
-    // Clear destination bits
-    memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::andi,
-                                writeMask, adaptor.getMemref(), storeIndices);
-    // Write srcs bits to destination
-    memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::ori,
-                                alignedVal, adaptor.getMemref(), storeIndices);
+    if (disableAtomicRMW) {
+      // Load the original value.
+      Value origValue = memref::LoadOp::create(
+          rewriter, loc, adaptor.getMemref(), storeIndices);
+      // Clear destination bits (and with mask).
+      Value clearedValue =
+          arith::AndIOp::create(rewriter, loc, origValue, writeMask);
+      // Write src bits to destination (or with aligned value), and store the
+      // result.
+      Value newValue =
+          arith::OrIOp::create(rewriter, loc, clearedValue, alignedVal);
+      memref::StoreOp::create(rewriter, loc, newValue, adaptor.getMemref(),
+                              storeIndices);
+    } else {
+      // Atomic read-modify-write operations.
+      // Clear destination bits.
+      memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::andi,
+                                  writeMask, adaptor.getMemref(), storeIndices);
+      // Write src bits to destination.
+      memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::ori,
+                                  alignedVal, adaptor.getMemref(),
+                                  storeIndices);
+    }
     rewriter.eraseOp(op);
     return success();
   }
 
+private:
+  bool disableAtomicRMW;
 };
 
 //===----------------------------------------------------------------------===//

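For reference, the write mask produced by getSubByteWriteMask keeps every bit
except the sub-byte slot being stored. A worked sketch matching the
MASK_BASE/MASK CHECK lines in the new test (function name illustrative):

func.func @write_mask_sketch(%bitoffset: i8) -> i8 {
  %c15 = arith.constant 15 : i8   // 0b00001111: one i4 element's worth of bits
  %cneg1 = arith.constant -1 : i8 // all ones
  %shifted = arith.shli %c15, %bitoffset : i8
  // Complement via xori with all ones: offset 4 gives ~(0xF << 4) = 0x0F, so
  // the andi clears the high nibble (the target slot) and keeps the low one.
  %mask = arith.xori %shifted, %cneg1 : i8
  return %mask : i8
}

The subsequent ori can then insert the shifted value without any further
masking, in both the atomic and the non-atomic paths.
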
@@ -601,16 +629,17 @@ struct ConvertMemRefExpandShape final
 
 void memref::populateMemRefNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns, bool disableAtomicRMW) {
 
   // Populate `memref.*` conversion patterns.
   patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
               ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefCopy,
               ConvertMemRefDealloc, ConvertMemRefCollapseShape,
-              ConvertMemRefExpandShape, ConvertMemRefLoad, ConvertMemrefStore,
+              ConvertMemRefExpandShape, ConvertMemRefLoad,
               ConvertMemRefAssumeAlignment, ConvertMemRefMemorySpaceCast,
               ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
      typeConverter, patterns.getContext());
+  patterns.insert<ConvertMemrefStore>(patterns.getContext(), disableAtomicRMW);
   memref::populateResolveExtractStridedMetadataPatterns(patterns);
 }

mlir/test/Dialect/MemRef/emulate-narrow-type-non-atomic.mlir (new file, 116 lines)
@@ -0,0 +1,116 @@
+// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8 disable-atomic-rmw=true" --cse --split-input-file %s | FileCheck %s
+
+func.func @memref_store_i4(%arg0: index, %arg1: i4) -> () {
+  %0 = memref.alloc() : memref<5xi4>
+  memref.store %arg1, %0[%arg0] : memref<5xi4>
+  return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)>
+// CHECK: func @memref_store_i4(
+// CHECK-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: i4
+// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG1]] : i4 to i8
+// CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK-DAG: %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i8
+// CHECK-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
+// CHECK-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
+// CHECK-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
+// CHECK: %[[ORIG:.+]] = memref.load %[[ALLOC]][%[[INDEX]]] : memref<3xi8>
+// CHECK: %[[CLEARED:.+]] = arith.andi %[[ORIG]], %[[MASK]] : i8
+// CHECK: %[[INSERTED:.+]] = arith.ori %[[CLEARED]], %[[SHIFTED_VAL]] : i8
+// CHECK: memref.store %[[INSERTED]], %[[ALLOC]][%[[INDEX]]] : memref<3xi8>
+// CHECK: return
+
+// -----
+
+func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
+  %0 = memref.alloc() : memref<3x125xi4>
+  memref.store %arg2, %0[%arg0,%arg1] : memref<3x125xi4>
+  return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 2) * 8)>
+// CHECK: func @memref_store_i4_rank2(
+// CHECK-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
+// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<188xi8>
+// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i8
+// CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK-DAG: %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i8
+// CHECK-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
+// CHECK-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
+// CHECK-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
+// CHECK: %[[ORIG:.+]] = memref.load %[[ALLOC]][%[[INDEX]]] : memref<188xi8>
+// CHECK: %[[CLEARED:.+]] = arith.andi %[[ORIG]], %[[MASK]] : i8
+// CHECK: %[[INSERTED:.+]] = arith.ori %[[CLEARED]], %[[SHIFTED_VAL]] : i8
+// CHECK: memref.store %[[INSERTED]], %[[ALLOC]][%[[INDEX]]] : memref<188xi8>
+// CHECK: return
+
+// -----
+
+func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %arg3 : index, %arg4: i4) -> () {
+  %0 = memref.alloc(%arg0, %arg1) : memref<?x?xi4>
+  memref.store %arg4, %0[%arg2, %arg3] : memref<?x?xi4>
+  return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2, s0 floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 2) * 8)>
+// CHECK: func @memref_store_i4_dynamic(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: i4
+// CHECK-DAG: %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]]]
+// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
+// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG4]] : i4 to i8
+// CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+// CHECK-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP2]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+// CHECK-DAG: %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i8
+// CHECK-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
+// CHECK-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
+// CHECK-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
+// CHECK: %[[ORIG:.+]] = memref.load %[[ALLOC]][%[[INDEX]]] : memref<?xi8>
+// CHECK: %[[CLEARED:.+]] = arith.andi %[[ORIG]], %[[MASK]] : i8
+// CHECK: %[[INSERTED:.+]] = arith.ori %[[CLEARED]], %[[SHIFTED_VAL]] : i8
+// CHECK: memref.store %[[INSERTED]], %[[ALLOC]][%[[INDEX]]] : memref<?xi8>
+// CHECK: return
+
+// -----
+
+func.func @memref_store_f4(%arg0: f4E2M1FN) -> () {
+  %0 = memref.alloc() : memref<f4E2M1FN>
+  memref.store %arg0, %0[] : memref<f4E2M1FN>
+  return
+}
+// CHECK-LABEL: func @memref_store_f4(
+// CHECK-SAME: %[[ARG0:.+]]: f4E2M1FN
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<i8>
+// CHECK: %[[BC:.+]] = arith.bitcast %[[ARG0]] : f4E2M1FN to i4
+// CHECK: %[[EXTUI:.+]] = arith.extui %[[BC]] : i4 to i8
+// CHECK: memref.store %[[EXTUI]], %[[ALLOC]][] : memref<i8>
+// CHECK: return
+
+// -----
+
+// 0-rank memrefs don't need RMW since they store the entire element.
+func.func @rank_zero_memref_store(%arg0: i4) -> () {
+  %0 = memref.alloc() : memref<i4>
+  memref.store %arg0, %0[] : memref<i4>
+  return
+}
+// CHECK-LABEL: func @rank_zero_memref
+// CHECK-SAME: %[[ARG0:.+]]: i4
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<i8>
+// CHECK: %[[EXTUI:.+]] = arith.extui %[[ARG0]] : i4 to i8
+// CHECK: memref.store %[[EXTUI]], %[[ALLOC]][] : memref<i8>
+// CHECK: return

@@ -493,14 +493,14 @@ func.func @rank_zero_memref_store(%arg0: i4) -> () {
 // CHECK-SAME: %[[ARG0:.+]]: i4
 // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<i8>
 // CHECK: %[[EXTUI:.+]] = arith.extui %[[ARG0]] : i4 to i8
-// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i8, memref<i8>) -> i8
+// CHECK: memref.store %[[EXTUI]], %[[ALLOC]][] : memref<i8>
 // CHECK: return
 
 // CHECK32-LABEL: func @rank_zero_memref
 // CHECK32-SAME: %[[ARG0:.+]]: i4
 // CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<i32>
 // CHECK32: %[[EXTUI:.+]] = arith.extui %[[ARG0]] : i4 to i32
-// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i32, memref<i32>) -> i32
+// CHECK32: memref.store %[[EXTUI]], %[[ALLOC]][] : memref<i32>
 // CHECK32: return
 
 // -----
 
@@ -515,7 +515,7 @@ func.func @rank_zero_memref_store_f4(%arg0: f4E2M1FN) -> () {
 // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<i8>
 // CHECK: %[[BC:.+]] = arith.bitcast %[[ARG0]] : f4E2M1FN to i4
 // CHECK: %[[EXTUI:.+]] = arith.extui %[[BC]] : i4 to i8
-// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i8, memref<i8>) -> i8
+// CHECK: memref.store %[[EXTUI]], %[[ALLOC]][] : memref<i8>
 // CHECK: return
 
 // CHECK32-LABEL: func @rank_zero_memref
@@ -523,7 +523,7 @@ func.func @rank_zero_memref_store_f4(%arg0: f4E2M1FN) -> () {
 // CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<i32>
 // CHECK32: %[[BC:.+]] = arith.bitcast %[[ARG0]] : f4E2M1FN to i4
 // CHECK32: %[[EXTUI:.+]] = arith.extui %[[BC]] : i4 to i32
-// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i32, memref<i32>) -> i32
+// CHECK32: memref.store %[[EXTUI]], %[[ALLOC]][] : memref<i32>
 // CHECK32: return
 
 // -----

@@ -99,7 +99,8 @@ struct TestEmulateNarrowTypePass
     RewritePatternSet patterns(ctx);
 
    arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
-    memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
+    memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns,
+                                                      disableAtomicRMW);
     vector::populateVectorNarrowTypeEmulationPatterns(
         typeConverter, patterns, disableAtomicRMW, assumeAligned);