From da0314842e5604e50ce2d1bb1aedda5cd7269c63 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Fri, 13 Mar 2026 10:35:37 -0500 Subject: [PATCH] [mlir][bytecode] Unpack i1 splats to 0x01 (#186221) Previously the arith folder test would emit `dense<255>` (`0xFF` zero extended). In-memory without bytecode is `0x01`, so this change ensures in-memory formats match. Also changes `0xFF` to `~0x00` since compilation on machines with signed chars was causing issues, this should ensure it is set to all ones regardless of char interpretation: ``` [1083/5044] Building CXX object tools/mlir/lib/IR/CMakeFiles/obj.MLIRIR.dir/BuiltinDialectBytecode.cpp.o /.../BuiltinDialectBytecode.cpp:184:35: warning: result of comparison of constant 255 with expression of type 'const char' is always false [-Wtautological-constant-out-of-range-compare] 184 | if (blob.size() == 1 && blob[0] == 0xFF) { | ~~~~~~~ ^ ~~~~ 1 warning generated. ``` Fixes llvm/llvm-project#186178 --- mlir/lib/IR/BuiltinDialectBytecode.cpp | 16 ++++++++++++++-- mlir/test/Bytecode/i1_roundtrip.mlir | 12 ++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/mlir/lib/IR/BuiltinDialectBytecode.cpp b/mlir/lib/IR/BuiltinDialectBytecode.cpp index fe194991a6b2..c55fe64d781b 100644 --- a/mlir/lib/IR/BuiltinDialectBytecode.cpp +++ b/mlir/lib/IR/BuiltinDialectBytecode.cpp @@ -179,8 +179,18 @@ readDenseIntOrFPElementsAttr(DialectBytecodeReader &reader, ShapedType type, // cheap. size_t numElements = type.getNumElements(); size_t packedSize = llvm::divideCeil(numElements, 8); + + // Unpack splats to single element 0x01 to match unpacked splat format. + if (blob.size() == 1 && blob[0] == ~0x00) { + rawData.resize(1); + rawData[0] = 0x01; + return success(); + } + + // Unpack the blob if it's packed. + // Splat and blob.size() == packedSize for all N<=8 elements are ambiguous, + // non 0xFF means not splat so must be unpacked. if (blob.size() == packedSize && blob.size() != numElements) { - // Unpack the blob. rawData.resize(numElements); for (size_t i = 0; i < numElements; ++i) rawData[i] = (blob[i / 8] & (1 << (i % 8))) ? 1 : 0; @@ -200,9 +210,11 @@ static void writeDenseIntOrFPElementsAttr(DialectBytecodeWriter &writer, ArrayRef rawData = attr.getRawData(); // If the attribute is a splat, we can just splat the value directly. + // Use 0xFF to avoid ambiguity with packed format of <=8 elements, + // written ~0x00 to ensure proper compilation with signed chars. if (attr.isSplat()) { data.resize(1); - data[0] = rawData[0] ? 0xFF : 0x00; + data[0] = rawData[0] ? ~0x00 : 0x00; writer.writeUnownedBlob(data); return; } diff --git a/mlir/test/Bytecode/i1_roundtrip.mlir b/mlir/test/Bytecode/i1_roundtrip.mlir index dc2529e62430..aa11b66b3c07 100644 --- a/mlir/test/Bytecode/i1_roundtrip.mlir +++ b/mlir/test/Bytecode/i1_roundtrip.mlir @@ -1,4 +1,6 @@ // RUN: mlir-opt %s -emit-bytecode | mlir-opt | FileCheck %s +// RUN: mlir-opt %s -canonicalize | FileCheck %s --check-prefix=CHECK-FOLD +// RUN: mlir-opt %s -emit-bytecode | mlir-opt -canonicalize | FileCheck %s --check-prefix=CHECK-FOLD // CHECK-LABEL: func.func @test_i1_splat_true func.func @test_i1_splat_true() -> tensor<100xi1> { @@ -43,3 +45,13 @@ func.func @test_i9_mixed() { %0 = arith.constant dense<[true, false, true, false, true, false, true, false, true]> : tensor<9xi1> return } + +// Test that the in-memory representation of i1 values is correctly handled +// during bytecode roundtrip (must be unpacked to 0x01 not 0xFF). +// See llvm/llvm-project#186178. +func.func public @test_in_memory_repr() -> (tensor<32xi32> {jax.result_info = "result"}) { + // CHECK-FOLD: dense<1> : tensor<32xi32> + %cst = arith.constant dense : tensor<32xi1> + %0 = arith.extui %cst : tensor<32xi1> to tensor<32xi32> + return %0 : tensor<32xi32> +}