[mlir][memref] Add non-atomic RMW option for emulated memref.store. (#178498)

This revision follows f0e1857c84 and adds an option to support non-atomic
RMW emulation. The 0-D case uses the non-atomic path unconditionally because
it writes the entire value, so no masking is needed.

Signed-off-by: hanhanW <hanhan0912@gmail.com>
Author: Han-Chung Wang (committed via GitHub)
Date: 2026-01-29 14:15:44 -08:00
parent 60bc9d15f5, commit 20b925a28a
5 changed files with 167 additions and 19 deletions
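The new `disableAtomicRMW` parameter defaults to false, so existing callers keep the atomic lowering. As a minimal sketch of how a conversion pass might opt in, mirroring the TestEmulateNarrowTypePass change at the end of this commit (the helper name and the 8-bit target width are illustrative assumptions, not part of the patch):

    // Sketch only: collect narrow-type emulation patterns with the non-atomic
    // store lowering enabled. Mirrors the TestEmulateNarrowTypePass wiring below.
    #include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
    #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
    #include "mlir/IR/MLIRContext.h"
    #include "mlir/IR/PatternMatch.h"

    using namespace mlir;

    static RewritePatternSet collectNarrowTypePatterns(MLIRContext *ctx) {
      RewritePatternSet patterns(ctx);
      // Emulate sub-byte element types on top of 8-bit loads and stores.
      arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
      // true selects the plain load/andi/ori/store sequence; the default
      // (false) keeps the memref.atomic_rmw based lowering.
      memref::populateMemRefNarrowTypeEmulationPatterns(
          typeConverter, patterns, /*disableAtomicRMW=*/true);
      return patterns;
    }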


@@ -83,9 +83,11 @@ void populateMemRefWideIntEmulationConversions(
 /// Appends patterns for emulating memref operations over narrow types with ops
 /// over wider types.
+/// When `disableAtomicRMW` is true, the store patterns generate non-atomic
+/// read-modify-write sequences instead of atomic operations.
 void populateMemRefNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns);
+    RewritePatternSet &patterns, bool disableAtomicRMW = false);
 /// Appends type conversions for emulating memref operations over narrow types
 /// with ops over wider types.


@@ -408,9 +408,17 @@ struct ConvertMemRefReinterpretCast final
 // ConvertMemrefStore
 //===----------------------------------------------------------------------===//
+/// Emulate a narrow type memref store with a non-atomic or atomic
+/// read-modify-write sequence. The `disableAtomicRMW` flag indicates whether
+/// to use a plain read-modify-write sequence instead of `memref.atomic_rmw`
+/// to perform the sub-byte store.
 struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
-  using OpConversionPattern::OpConversionPattern;
+  ConvertMemrefStore(MLIRContext *context, bool disableAtomicRMW)
+      : OpConversionPattern<memref::StoreOp>(context),
+        disableAtomicRMW(disableAtomicRMW) {}
   LogicalResult
   matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -437,11 +445,11 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
     Value extendedInput =
         arith::ExtUIOp::create(rewriter, loc, dstIntegerType, input);
-    // Special case 0-rank memref stores. No need for masking.
+    // Special case 0-rank memref stores. No need for masking. The non-atomic
+    // store is used because it operates on the entire value.
     if (convertedType.getRank() == 0) {
-      memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::assign,
-                                  extendedInput, adaptor.getMemref(),
-                                  ValueRange{});
+      memref::StoreOp::create(rewriter, loc, extendedInput, adaptor.getMemref(),
+                              ValueRange{});
       rewriter.eraseOp(op);
       return success();
     }
@@ -454,19 +462,39 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
                                                 dstBits, rewriter);
     Value writeMask = getSubByteWriteMask(loc, linearizedIndices, srcBits,
                                           dstBits, bitwidthOffset, rewriter);
-    // Align the value to write with the destination bits
+    // Align the value to write with the destination bits.
     Value alignedVal =
         arith::ShLIOp::create(rewriter, loc, extendedInput, bitwidthOffset);
-    // Clear destination bits
-    memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::andi,
-                                writeMask, adaptor.getMemref(), storeIndices);
-    // Write srcs bits to destination
-    memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::ori,
-                                alignedVal, adaptor.getMemref(), storeIndices);
+    if (disableAtomicRMW) {
+      // Load the original value.
+      Value origValue = memref::LoadOp::create(
+          rewriter, loc, adaptor.getMemref(), storeIndices);
+      // Clear destination bits (and with mask).
+      Value clearedValue =
+          arith::AndIOp::create(rewriter, loc, origValue, writeMask);
+      // Write src bits to destination (or with aligned value), and store the
+      // result.
+      Value newValue =
+          arith::OrIOp::create(rewriter, loc, clearedValue, alignedVal);
+      memref::StoreOp::create(rewriter, loc, newValue, adaptor.getMemref(),
+                              storeIndices);
+    } else {
+      // Atomic read-modify-write operations.
+      // Clear destination bits.
+      memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::andi,
+                                  writeMask, adaptor.getMemref(), storeIndices);
+      // Write src bits to destination.
+      memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::ori,
+                                  alignedVal, adaptor.getMemref(),
+                                  storeIndices);
+    }
     rewriter.eraseOp(op);
     return success();
   }
+
+private:
+  bool disableAtomicRMW;
 };
//===----------------------------------------------------------------------===//
@@ -601,16 +629,17 @@ struct ConvertMemRefExpandShape final
 void memref::populateMemRefNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns, bool disableAtomicRMW) {
   // Populate `memref.*` conversion patterns.
   patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
                ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefCopy,
                ConvertMemRefDealloc, ConvertMemRefCollapseShape,
-               ConvertMemRefExpandShape, ConvertMemRefLoad, ConvertMemrefStore,
+               ConvertMemRefExpandShape, ConvertMemRefLoad,
                ConvertMemRefAssumeAlignment, ConvertMemRefMemorySpaceCast,
                ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
       typeConverter, patterns.getContext());
+  patterns.insert<ConvertMemrefStore>(patterns.getContext(), disableAtomicRMW);
   memref::populateResolveExtractStridedMetadataPatterns(patterns);
 }
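To make the masking concrete: getSubByteWriteMask materializes ~(0xF << bitOffset) as an arith.shli followed by arith.xori with -1 (visible in the CHECK lines of the new test below), and the non-atomic path is then load, andi, ori, store. A standalone C++ rendering of that sequence for srcBits = 4 and dstBits = 8; an illustrative sketch by analogy, not code from the patch:

    // Hand-executed non-atomic read-modify-write for storing an i4 element
    // into byte-addressed storage, matching the emulated memref.store.
    #include <cassert>
    #include <cstdint>

    int main() {
      uint8_t mem[3] = {0xAB, 0xCD, 0xEF}; // backing memref<3xi8> of a memref<5xi4>
      uint8_t value = 0xB;                 // i4 store value, zero-extended to i8
      unsigned index = 5;                  // i4 element index

      unsigned byteIndex = index / 2;       // linearized index: 5 floordiv 2 = 2
      unsigned bitOffset = (index % 2) * 4; // bitwidthOffset: 4

      // writeMask = (0xF << offset) xor -1, i.e. ~(0xF << offset) = 0x0F.
      uint8_t writeMask = static_cast<uint8_t>(~(0x0F << bitOffset));
      // Align the value to write with the destination bits: 0xB0.
      uint8_t alignedVal = static_cast<uint8_t>(value << bitOffset);

      // Non-atomic RMW: load, clear destination bits, insert, store.
      uint8_t orig = mem[byteIndex];                    // 0xEF
      mem[byteIndex] = (orig & writeMask) | alignedVal; // 0x0F | 0xB0 = 0xBF
      assert(mem[byteIndex] == 0xBF);
      return 0;
    }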


@@ -0,0 +1,116 @@
+// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8 disable-atomic-rmw=true" --cse --split-input-file %s | FileCheck %s
+
+func.func @memref_store_i4(%arg0: index, %arg1: i4) -> () {
+  %0 = memref.alloc() : memref<5xi4>
+  memref.store %arg1, %0[%arg0] : memref<5xi4>
+  return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)>
+// CHECK: func @memref_store_i4(
+// CHECK-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: i4
+// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG1]] : i4 to i8
+// CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK-DAG: %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i8
+// CHECK-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
+// CHECK-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
+// CHECK-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
+// CHECK: %[[ORIG:.+]] = memref.load %[[ALLOC]][%[[INDEX]]] : memref<3xi8>
+// CHECK: %[[CLEARED:.+]] = arith.andi %[[ORIG]], %[[MASK]] : i8
+// CHECK: %[[INSERTED:.+]] = arith.ori %[[CLEARED]], %[[SHIFTED_VAL]] : i8
+// CHECK: memref.store %[[INSERTED]], %[[ALLOC]][%[[INDEX]]] : memref<3xi8>
+// CHECK: return
+
+// -----
+
+func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
+  %0 = memref.alloc() : memref<3x125xi4>
+  memref.store %arg2, %0[%arg0,%arg1] : memref<3x125xi4>
+  return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 2) * 8)>
+// CHECK: func @memref_store_i4_rank2(
+// CHECK-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
+// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<188xi8>
+// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i8
+// CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK-DAG: %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i8
+// CHECK-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
+// CHECK-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
+// CHECK-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
+// CHECK: %[[ORIG:.+]] = memref.load %[[ALLOC]][%[[INDEX]]] : memref<188xi8>
+// CHECK: %[[CLEARED:.+]] = arith.andi %[[ORIG]], %[[MASK]] : i8
+// CHECK: %[[INSERTED:.+]] = arith.ori %[[CLEARED]], %[[SHIFTED_VAL]] : i8
+// CHECK: memref.store %[[INSERTED]], %[[ALLOC]][%[[INDEX]]] : memref<188xi8>
+// CHECK: return
+
+// -----
+
+func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %arg3 : index, %arg4: i4) -> () {
+  %0 = memref.alloc(%arg0, %arg1) : memref<?x?xi4>
+  memref.store %arg4, %0[%arg2, %arg3] : memref<?x?xi4>
+  return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2, s0 floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 2) * 8)>
+// CHECK: func @memref_store_i4_dynamic(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: i4
+// CHECK-DAG: %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]]]
+// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
+// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG4]] : i4 to i8
+// CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+// CHECK-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP2]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+// CHECK-DAG: %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i8
+// CHECK-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
+// CHECK-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
+// CHECK-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
+// CHECK: %[[ORIG:.+]] = memref.load %[[ALLOC]][%[[INDEX]]] : memref<?xi8>
+// CHECK: %[[CLEARED:.+]] = arith.andi %[[ORIG]], %[[MASK]] : i8
+// CHECK: %[[INSERTED:.+]] = arith.ori %[[CLEARED]], %[[SHIFTED_VAL]] : i8
+// CHECK: memref.store %[[INSERTED]], %[[ALLOC]][%[[INDEX]]] : memref<?xi8>
+// CHECK: return
+
+// -----
+
+func.func @memref_store_f4(%arg0: f4E2M1FN) -> () {
+  %0 = memref.alloc() : memref<f4E2M1FN>
+  memref.store %arg0, %0[] : memref<f4E2M1FN>
+  return
+}
+// CHECK-LABEL: func @memref_store_f4(
+// CHECK-SAME: %[[ARG0:.+]]: f4E2M1FN
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<i8>
+// CHECK: %[[BC:.+]] = arith.bitcast %[[ARG0]] : f4E2M1FN to i4
+// CHECK: %[[EXTUI:.+]] = arith.extui %[[BC]] : i4 to i8
+// CHECK: memref.store %[[EXTUI]], %[[ALLOC]][] : memref<i8>
+// CHECK: return
+
+// -----
+
+// 0-rank memrefs don't need RMW since they store the entire element.
+func.func @rank_zero_memref_store(%arg0: i4) -> () {
+  %0 = memref.alloc() : memref<i4>
+  memref.store %arg0, %0[] : memref<i4>
+  return
+}
+// CHECK-LABEL: func @rank_zero_memref
+// CHECK-SAME: %[[ARG0:.+]]: i4
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<i8>
+// CHECK: %[[EXTUI:.+]] = arith.extui %[[ARG0]] : i4 to i8
+// CHECK: memref.store %[[EXTUI]], %[[ALLOC]][] : memref<i8>
+// CHECK: return


@@ -493,14 +493,14 @@ func.func @rank_zero_memref_store(%arg0: i4) -> () {
 // CHECK-SAME: %[[ARG0:.+]]: i4
 // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<i8>
 // CHECK: %[[EXTUI:.+]] = arith.extui %[[ARG0]] : i4 to i8
-// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i8, memref<i8>) -> i8
+// CHECK: memref.store %[[EXTUI]], %[[ALLOC]][] : memref<i8>
 // CHECK: return
 // CHECK32-LABEL: func @rank_zero_memref
 // CHECK32-SAME: %[[ARG0:.+]]: i4
 // CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<i32>
 // CHECK32: %[[EXTUI:.+]] = arith.extui %[[ARG0]] : i4 to i32
-// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i32, memref<i32>) -> i32
+// CHECK32: memref.store %[[EXTUI]], %[[ALLOC]][] : memref<i32>
 // CHECK32: return
 // -----
@@ -515,7 +515,7 @@ func.func @rank_zero_memref_store_f4(%arg0: f4E2M1FN) -> () {
 // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<i8>
 // CHECK: %[[BC:.+]] = arith.bitcast %[[ARG0]] : f4E2M1FN to i4
 // CHECK: %[[EXTUI:.+]] = arith.extui %[[BC]] : i4 to i8
-// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i8, memref<i8>) -> i8
+// CHECK: memref.store %[[EXTUI]], %[[ALLOC]][] : memref<i8>
 // CHECK: return
 // CHECK32-LABEL: func @rank_zero_memref
@@ -523,7 +523,7 @@ func.func @rank_zero_memref_store_f4(%arg0: f4E2M1FN) -> () {
 // CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<i32>
 // CHECK32: %[[BC:.+]] = arith.bitcast %[[ARG0]] : f4E2M1FN to i4
 // CHECK32: %[[EXTUI:.+]] = arith.extui %[[BC]] : i4 to i32
-// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i32, memref<i32>) -> i32
+// CHECK32: memref.store %[[EXTUI]], %[[ALLOC]][] : memref<i32>
 // CHECK32: return
 // -----


@@ -99,7 +99,8 @@ struct TestEmulateNarrowTypePass
     RewritePatternSet patterns(ctx);
     arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
-    memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
+    memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns,
+                                                      disableAtomicRMW);
     vector::populateVectorNarrowTypeEmulationPatterns(
         typeConverter, patterns, disableAtomicRMW, assumeAligned);