[mlir][bytecode] Unpack i1 splats to 0x01 (#186221)
Previously the arith folder test would emit `dense<255>` (`0xFF` zero
extended). In-memory without bytecode is `0x01`, so this change ensures
in-memory formats match.
Also changes `0xFF` to `~0x00` since compilation on machines with signed
chars was causing issues, this should ensure it is set to all ones
regardless of char interpretation:
```
[1083/5044] Building CXX object tools/mlir/lib/IR/CMakeFiles/obj.MLIRIR.dir/BuiltinDialectBytecode.cpp.o
/.../BuiltinDialectBytecode.cpp:184:35: warning: result of comparison of constant 255 with expression of type 'const char' is always false [-Wtautological-constant-out-of-range-compare]
184 | if (blob.size() == 1 && blob[0] == 0xFF) {
| ~~~~~~~ ^ ~~~~
1 warning generated.
```
Fixes llvm/llvm-project#186178
This commit is contained in:
@@ -179,8 +179,18 @@ readDenseIntOrFPElementsAttr(DialectBytecodeReader &reader, ShapedType type,
|
||||
// cheap.
|
||||
size_t numElements = type.getNumElements();
|
||||
size_t packedSize = llvm::divideCeil(numElements, 8);
|
||||
|
||||
// Unpack splats to single element 0x01 to match unpacked splat format.
|
||||
if (blob.size() == 1 && blob[0] == ~0x00) {
|
||||
rawData.resize(1);
|
||||
rawData[0] = 0x01;
|
||||
return success();
|
||||
}
|
||||
|
||||
// Unpack the blob if it's packed.
|
||||
// Splat and blob.size() == packedSize for all N<=8 elements are ambiguous,
|
||||
// non 0xFF means not splat so must be unpacked.
|
||||
if (blob.size() == packedSize && blob.size() != numElements) {
|
||||
// Unpack the blob.
|
||||
rawData.resize(numElements);
|
||||
for (size_t i = 0; i < numElements; ++i)
|
||||
rawData[i] = (blob[i / 8] & (1 << (i % 8))) ? 1 : 0;
|
||||
@@ -200,9 +210,11 @@ static void writeDenseIntOrFPElementsAttr(DialectBytecodeWriter &writer,
|
||||
ArrayRef<char> rawData = attr.getRawData();
|
||||
|
||||
// If the attribute is a splat, we can just splat the value directly.
|
||||
// Use 0xFF to avoid ambiguity with packed format of <=8 elements,
|
||||
// written ~0x00 to ensure proper compilation with signed chars.
|
||||
if (attr.isSplat()) {
|
||||
data.resize(1);
|
||||
data[0] = rawData[0] ? 0xFF : 0x00;
|
||||
data[0] = rawData[0] ? ~0x00 : 0x00;
|
||||
writer.writeUnownedBlob(data);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
// RUN: mlir-opt %s -emit-bytecode | mlir-opt | FileCheck %s
|
||||
// RUN: mlir-opt %s -canonicalize | FileCheck %s --check-prefix=CHECK-FOLD
|
||||
// RUN: mlir-opt %s -emit-bytecode | mlir-opt -canonicalize | FileCheck %s --check-prefix=CHECK-FOLD
|
||||
|
||||
// CHECK-LABEL: func.func @test_i1_splat_true
|
||||
func.func @test_i1_splat_true() -> tensor<100xi1> {
|
||||
@@ -43,3 +45,13 @@ func.func @test_i9_mixed() {
|
||||
%0 = arith.constant dense<[true, false, true, false, true, false, true, false, true]> : tensor<9xi1>
|
||||
return
|
||||
}
|
||||
|
||||
// Test that the in-memory representation of i1 values is correctly handled
|
||||
// during bytecode roundtrip (must be unpacked to 0x01 not 0xFF).
|
||||
// See llvm/llvm-project#186178.
|
||||
func.func public @test_in_memory_repr() -> (tensor<32xi32> {jax.result_info = "result"}) {
|
||||
// CHECK-FOLD: dense<1> : tensor<32xi32>
|
||||
%cst = arith.constant dense<true> : tensor<32xi1>
|
||||
%0 = arith.extui %cst : tensor<32xi1> to tensor<32xi32>
|
||||
return %0 : tensor<32xi32>
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user