[mlir] Add Repeated<T> constructors for TypeRange and ValueRange (#186923)
Many MLIR APIs end up using a range of the same Type / Value repeated N times, where N is (a function of) the dimensionality of the problem. Allocating a vector of N identical elements is wasteful. Add `Repeated<T>` as PointerUnion variants in TypeRange and ValueRange, enabling O(1) storage for repeated elements. Size remains 2 pointers (16 bytes on 64-bit) for both range types. On 32-bit systems, this requires the variable-width `PointerUnion` encoding added in https://github.com/llvm/llvm-project/pull/188167. Also update several MLIR dialects and conversions to exercise the new code. --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
|
||||
#include "llvm/ADT/iterator.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
#include <utility>
|
||||
@@ -72,21 +73,19 @@ public:
|
||||
///
|
||||
/// `Repeated<T>` is also a proper random-access range: `begin()`/`end()`
|
||||
/// return iterators that always dereference to the same stored value.
|
||||
template <typename T> struct [[nodiscard]] Repeated {
|
||||
/// Wrapper for the stored value used as a PointerUnion target in range
|
||||
/// types (e.g., TypeRange, ValueRange).
|
||||
struct Storage {
|
||||
T value;
|
||||
};
|
||||
|
||||
Storage storage;
|
||||
// At least 16-byte aligned so that Repeated<T>* has more low bits available
|
||||
// than a plain pointer. The primary use case is pointer-like types (e.g. MLIR
|
||||
// Type, Value) where Repeated<T>* appears in a PointerUnion alongside them.
|
||||
template <typename T>
|
||||
struct [[nodiscard]] alignas(std::max(size_t{16}, alignof(T))) Repeated {
|
||||
T storage;
|
||||
size_t count;
|
||||
|
||||
/// Create a `value` repeated `count` times.
|
||||
/// Uses the same argument order like STD container constructors.
|
||||
/// Uses the same argument order as std container constructors.
|
||||
template <typename U>
|
||||
Repeated(size_t count, U &&value)
|
||||
: storage{std::forward<U>(value)}, count(count) {}
|
||||
: storage(std::forward<U>(value)), count(count) {}
|
||||
|
||||
using iterator = RepeatedIterator<T>;
|
||||
using const_iterator = iterator;
|
||||
@@ -95,21 +94,19 @@ template <typename T> struct [[nodiscard]] Repeated {
|
||||
using value_type = T;
|
||||
using size_type = size_t;
|
||||
|
||||
iterator begin() const { return {&storage.value, 0}; }
|
||||
iterator end() const {
|
||||
return {&storage.value, static_cast<ptrdiff_t>(count)};
|
||||
}
|
||||
iterator begin() const { return {&storage, 0}; }
|
||||
iterator end() const { return {&storage, static_cast<ptrdiff_t>(count)}; }
|
||||
reverse_iterator rbegin() const { return reverse_iterator(end()); }
|
||||
reverse_iterator rend() const { return reverse_iterator(begin()); }
|
||||
|
||||
size_t size() const { return count; }
|
||||
bool empty() const { return count == 0; }
|
||||
|
||||
const T &value() const { return storage.value; }
|
||||
const T &value() const { return storage; }
|
||||
const T &operator[](size_t idx) const {
|
||||
assert(idx < size() && "Out of bounds");
|
||||
(void)idx;
|
||||
return storage.value;
|
||||
return storage;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
#include "llvm/ADT/PointerUnion.h"
|
||||
#include "llvm/ADT/Repeated.h"
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
|
||||
namespace mlir {
|
||||
@@ -30,11 +31,13 @@ namespace mlir {
|
||||
/// a SmallVector/std::vector. This class should be used in places that are not
|
||||
/// suitable for a more derived type (e.g. ArrayRef) or a template range
|
||||
/// parameter.
|
||||
class TypeRange : public llvm::detail::indexed_accessor_range_base<
|
||||
TypeRange,
|
||||
llvm::PointerUnion<const Value *, const Type *,
|
||||
OpOperand *, detail::OpResultImpl *>,
|
||||
Type, Type, Type> {
|
||||
class TypeRange
|
||||
: public llvm::detail::indexed_accessor_range_base<
|
||||
TypeRange,
|
||||
llvm::PointerUnion<const Value *, const Type *, OpOperand *,
|
||||
detail::OpResultImpl *, const Repeated<Type> *,
|
||||
const Repeated<Value> *>,
|
||||
Type, Type, Type> {
|
||||
public:
|
||||
using RangeBaseT::RangeBaseT;
|
||||
TypeRange(ArrayRef<Type> types = {});
|
||||
@@ -51,6 +54,10 @@ public:
|
||||
: TypeRange(ArrayRef<Type>(std::forward<Arg>(arg))) {}
|
||||
TypeRange(std::initializer_list<Type> types LLVM_LIFETIME_BOUND)
|
||||
: TypeRange(ArrayRef<Type>(types)) {}
|
||||
/// Constructs a range from a repeated type. The Repeated object must outlive
|
||||
/// this range.
|
||||
TypeRange(const Repeated<Type> &repeatedValue LLVM_LIFETIME_BOUND)
|
||||
: RangeBaseT(&repeatedValue, repeatedValue.count) {}
|
||||
|
||||
private:
|
||||
/// The owner of the range is either:
|
||||
@@ -58,8 +65,13 @@ private:
|
||||
/// * A pointer to the first element of an array of types.
|
||||
/// * A pointer to the first element of an array of operands.
|
||||
/// * A pointer to the first element of an array of results.
|
||||
using OwnerT = llvm::PointerUnion<const Value *, const Type *, OpOperand *,
|
||||
detail::OpResultImpl *>;
|
||||
/// * A pointer to a Repeated<Type> (single type repeated N times).
|
||||
/// * A pointer to a Repeated<Value> (single value repeated N times,
|
||||
/// dereferenced via getType()).
|
||||
using OwnerT =
|
||||
llvm::PointerUnion<const Value *, const Type *, OpOperand *,
|
||||
detail::OpResultImpl *, const Repeated<Type> *,
|
||||
const Repeated<Value> *>;
|
||||
|
||||
/// See `llvm::detail::indexed_accessor_range_base` for details.
|
||||
static OwnerT offset_base(OwnerT object, ptrdiff_t index);
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "llvm/ADT/PointerUnion.h"
|
||||
#include "llvm/ADT/Repeated.h"
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
#include <optional>
|
||||
|
||||
@@ -383,13 +384,14 @@ private:
|
||||
class ValueRange final
|
||||
: public llvm::detail::indexed_accessor_range_base<
|
||||
ValueRange,
|
||||
PointerUnion<const Value *, OpOperand *, detail::OpResultImpl *>,
|
||||
PointerUnion<const Value *, OpOperand *, detail::OpResultImpl *,
|
||||
const Repeated<Value> *>,
|
||||
Value, Value, Value> {
|
||||
public:
|
||||
/// The type representing the owner of a ValueRange. This is either a list of
|
||||
/// values, operands, or results.
|
||||
using OwnerT =
|
||||
PointerUnion<const Value *, OpOperand *, detail::OpResultImpl *>;
|
||||
/// values, operands, results, or a repeated single value.
|
||||
using OwnerT = PointerUnion<const Value *, OpOperand *,
|
||||
detail::OpResultImpl *, const Repeated<Value> *>;
|
||||
|
||||
using RangeBaseT::RangeBaseT;
|
||||
|
||||
@@ -412,6 +414,10 @@ public:
|
||||
ValueRange(ArrayRef<Value> values = {});
|
||||
ValueRange(OperandRange values);
|
||||
ValueRange(ResultRange values);
|
||||
/// Constructs a range from a repeated value. The Repeated object must outlive
|
||||
/// this range.
|
||||
ValueRange(const Repeated<Value> &repeatedValue LLVM_LIFETIME_BOUND)
|
||||
: RangeBaseT(&repeatedValue, repeatedValue.count) {}
|
||||
|
||||
/// Returns the types of the values within this range.
|
||||
using type_iterator = ValueTypeIterator<iterator>;
|
||||
|
||||
@@ -69,6 +69,8 @@ class StringSet;
|
||||
template <typename T, typename R>
|
||||
class StringSwitch;
|
||||
template <typename T>
|
||||
struct Repeated;
|
||||
template <typename T>
|
||||
class TinyPtrVector;
|
||||
template <typename T, typename ResultT>
|
||||
class TypeSwitch;
|
||||
@@ -125,6 +127,7 @@ template <typename AllocatorTy = llvm::MallocAllocator>
|
||||
using StringSet = llvm::StringSet<AllocatorTy>;
|
||||
using llvm::MutableArrayRef;
|
||||
using llvm::PointerUnion;
|
||||
using llvm::Repeated;
|
||||
using llvm::SmallPtrSet;
|
||||
using llvm::SmallPtrSetImpl;
|
||||
using llvm::SmallVector;
|
||||
|
||||
@@ -95,7 +95,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
|
||||
// New arguments will simply be `llvm.ptr` with the correct address space
|
||||
Type workgroupPtrType =
|
||||
rewriter.getType<LLVM::LLVMPointerType>(workgroupAddrSpace);
|
||||
SmallVector<Type> argTypes(numAttributions, workgroupPtrType);
|
||||
Repeated<Type> argTypes(numAttributions, workgroupPtrType);
|
||||
|
||||
// Attributes: noalias, llvm.mlir.workgroup_attribution(<size>, <type>)
|
||||
std::array attrs{
|
||||
|
||||
@@ -155,11 +155,11 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
|
||||
int count = vectorType.getNumElements();
|
||||
intType = VectorType::get(count, intType);
|
||||
|
||||
SmallVector<Value> signSplat(count, signMask);
|
||||
Repeated<Value> signSplat(count, signMask);
|
||||
signMask = spirv::CompositeConstructOp::create(rewriter, loc, intType,
|
||||
signSplat);
|
||||
|
||||
SmallVector<Value> valueSplat(count, valueMask);
|
||||
Repeated<Value> valueSplat(count, valueMask);
|
||||
valueMask = spirv::CompositeConstructOp::create(rewriter, loc, intType,
|
||||
valueSplat);
|
||||
}
|
||||
|
||||
@@ -117,8 +117,8 @@ public:
|
||||
auto one = createIndexConst(rewriter, loc, 1);
|
||||
|
||||
// Loop bounds
|
||||
auto lbs = llvm::SmallVector<Value>(2, zero);
|
||||
auto steps = llvm::SmallVector<Value>(2, one);
|
||||
auto lbs = Repeated<Value>(2, zero);
|
||||
auto steps = Repeated<Value>(2, one);
|
||||
auto ubs = llvm::SmallVector<Value>{{dimN, dimW}};
|
||||
|
||||
auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
|
||||
|
||||
@@ -354,7 +354,7 @@ static TypedValue<VectorType> storeTile(PatternRewriter &rewriter,
|
||||
rewriter, loc,
|
||||
MemRefType::get(tileTy.getShape(), tileTy.getElementType()));
|
||||
Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
|
||||
SmallVector<Value> indices(2, zeroIndex);
|
||||
Repeated<Value> indices(2, zeroIndex);
|
||||
x86::amx::TileStoreOp::create(rewriter, loc, buf, indices, tile);
|
||||
|
||||
auto vecTy = VectorType::get(tileTy.getShape(), tileTy.getElementType());
|
||||
|
||||
@@ -723,7 +723,7 @@ BufferDeallocation::handleInterface(RegionBranchOpInterface op) {
|
||||
// the outside.
|
||||
Value falseVal = buildBoolValue(builder, op.getLoc(), false);
|
||||
op->insertOperands(op->getNumOperands(),
|
||||
SmallVector<Value>(numMemrefOperands, falseVal));
|
||||
Repeated<Value>(numMemrefOperands, falseVal));
|
||||
|
||||
int counter = op->getNumResults();
|
||||
unsigned numMemrefResults = llvm::count_if(op->getResults(), isMemref);
|
||||
|
||||
@@ -544,7 +544,7 @@ class TransferReadDropUnitDimsPattern
|
||||
Value reducedShapeSource =
|
||||
rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
|
||||
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||
SmallVector<Value> zeros(reducedRank, c0);
|
||||
Repeated<Value> zeros(reducedRank, c0);
|
||||
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
|
||||
SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
|
||||
Operation *newTransferReadOp = vector::TransferReadOp::create(
|
||||
@@ -658,7 +658,7 @@ class TransferWriteDropUnitDimsPattern
|
||||
Value reducedShapeSource =
|
||||
rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
|
||||
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||
SmallVector<Value> zeros(reducedRank, c0);
|
||||
Repeated<Value> zeros(reducedRank, c0);
|
||||
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
|
||||
SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
|
||||
auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
|
||||
|
||||
@@ -357,7 +357,7 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
|
||||
builder, loc,
|
||||
/*vectorType=*/vecToReadTy,
|
||||
/*source=*/source,
|
||||
/*indices=*/SmallVector<Value>(vecToReadRank, zero),
|
||||
/*indices=*/Repeated<Value>(vecToReadRank, zero),
|
||||
/*padding=*/padValue,
|
||||
/*inBounds=*/inBoundsVal);
|
||||
|
||||
|
||||
@@ -654,6 +654,9 @@ ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
|
||||
return {value + index};
|
||||
if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
|
||||
return {operand + index};
|
||||
// All elements are identical; the owner pointer never advances.
|
||||
if (llvm::isa<const Repeated<Value> *>(owner))
|
||||
return owner;
|
||||
return cast<detail::OpResultImpl *>(owner)->getNextResultAtOffset(index);
|
||||
}
|
||||
/// See `llvm::detail::indexed_accessor_range_base` for details.
|
||||
@@ -662,6 +665,9 @@ Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
|
||||
return value[index];
|
||||
if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
|
||||
return operand[index].get();
|
||||
if (auto *repeated =
|
||||
llvm::dyn_cast_if_present<const Repeated<Value> *>(owner))
|
||||
return repeated->value();
|
||||
return cast<detail::OpResultImpl *>(owner)->getNextResultAtOffset(index);
|
||||
}
|
||||
|
||||
|
||||
@@ -31,6 +31,9 @@ TypeRange::TypeRange(ValueRange values) : TypeRange(OwnerT(), values.size()) {
|
||||
this->base = result;
|
||||
else if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
|
||||
this->base = operand;
|
||||
else if (auto *repeated =
|
||||
llvm::dyn_cast_if_present<const Repeated<Value> *>(owner))
|
||||
this->base = repeated;
|
||||
else
|
||||
this->base = cast<const Value *>(owner);
|
||||
}
|
||||
@@ -43,6 +46,9 @@ TypeRange::OwnerT TypeRange::offset_base(OwnerT object, ptrdiff_t index) {
|
||||
return {operand + index};
|
||||
if (auto *result = llvm::dyn_cast_if_present<detail::OpResultImpl *>(object))
|
||||
return {result->getNextResultAtOffset(index)};
|
||||
// All elements are identical; the owner pointer never advances.
|
||||
if (llvm::isa<const Repeated<Type> *, const Repeated<Value> *>(object))
|
||||
return object;
|
||||
return {llvm::dyn_cast_if_present<const Type *>(object) + index};
|
||||
}
|
||||
|
||||
@@ -54,5 +60,11 @@ Type TypeRange::dereference_iterator(OwnerT object, ptrdiff_t index) {
|
||||
return (operand + index)->get().getType();
|
||||
if (auto *result = llvm::dyn_cast_if_present<detail::OpResultImpl *>(object))
|
||||
return result->getNextResultAtOffset(index)->getType();
|
||||
if (auto *repeated =
|
||||
llvm::dyn_cast_if_present<const Repeated<Type> *>(object))
|
||||
return repeated->value();
|
||||
if (auto *repeated =
|
||||
llvm::dyn_cast_if_present<const Repeated<Value> *>(object))
|
||||
return repeated->value().getType();
|
||||
return llvm::dyn_cast_if_present<const Type *>(object)[index];
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "llvm/ADT/BitVector.h"
|
||||
#include "llvm/ADT/Repeated.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
@@ -375,4 +376,66 @@ TEST(OperationCloneTest, CloneWithDifferentResults) {
|
||||
cloneOp->destroy();
|
||||
}
|
||||
|
||||
/// Checks that a TypeRange built from a Repeated<Type> reports the repeat
/// count as its size, yields the stored type when iterated, indexed, and
/// sliced with drop_front, and that a zero-count Repeated gives an empty range.
TEST(RepeatedRangeTest, TypeRangeFromRepeatedType) {
|
||||
MLIRContext context;
|
||||
Builder builder(&context);
|
||||
Type i32 = builder.getI32Type();
|
||||
|
||||
// `rep` is a named local that stays in scope for every use of `range` below.
llvm::Repeated<Type> rep(3, i32);
|
||||
TypeRange range(rep);
|
||||
|
||||
// The range size equals the repeat count passed to Repeated.
EXPECT_EQ(range.size(), 3u);
|
||||
EXPECT_FALSE(range.empty());
|
||||
for (Type t : range)
|
||||
EXPECT_EQ(t, i32);
|
||||
|
||||
// Indexing and slicing exercise offset_base (which must not advance).
|
||||
EXPECT_EQ(range[0], i32);
|
||||
EXPECT_EQ(range[2], i32);
|
||||
TypeRange sliced = range.drop_front(1);
|
||||
EXPECT_EQ(sliced.size(), 2u);
|
||||
EXPECT_EQ(sliced[0], i32);
|
||||
|
||||
// Edge case: a repeat count of zero with a default-constructed Type.
llvm::Repeated<Type> emptyRep(0, Type{});
|
||||
TypeRange emptyTypeRange(emptyRep);
|
||||
|
||||
EXPECT_EQ(emptyTypeRange.size(), 0u);
|
||||
EXPECT_TRUE(emptyTypeRange.empty());
|
||||
}
|
||||
|
||||
/// Checks that a ValueRange built from a Repeated<Value> reports the repeat
/// count as its size and yields the stored value on iteration, and that a
/// zero-count Repeated gives an empty range.
TEST(RepeatedRangeTest, ValueRangeFromRepeatedValue) {
|
||||
// A default-constructed Value is used as the repeated element, so this test
// creates no context or operation.
Value nullVal;
|
||||
llvm::Repeated<Value> rep(4, nullVal);
|
||||
ValueRange range(rep);
|
||||
|
||||
EXPECT_EQ(range.size(), 4u);
|
||||
EXPECT_FALSE(range.empty());
|
||||
for (Value v : range)
|
||||
EXPECT_EQ(v, nullVal);
|
||||
|
||||
// Edge case: a repeat count of zero.
llvm::Repeated<Value> emptyRep(0, nullVal);
|
||||
ValueRange emptyRange(emptyRep);
|
||||
EXPECT_EQ(emptyRange.size(), 0u);
|
||||
EXPECT_TRUE(emptyRange.empty());
|
||||
}
|
||||
|
||||
/// Checks the Repeated<Value> -> ValueRange -> TypeRange conversion chain: the
/// resulting TypeRange keeps the repeat count as its size and every element is
/// the type of the repeated value.
TEST(RepeatedRangeTest, TypeRangeFromRepeatedValueViaValueRange) {
|
||||
MLIRContext context;
|
||||
Builder builder(&context);
|
||||
|
||||
// `createOp` is a helper defined outside this view; it is called with no
// operands and an i16 type, presumably the op's result type (TODO confirm).
Operation *useOp =
|
||||
createOp(&context, /*operands=*/{}, builder.getIntegerType(16));
|
||||
Value operand = useOp->getResult(0);
|
||||
|
||||
llvm::Repeated<Value> rep(3, operand);
|
||||
ValueRange vr(rep);
|
||||
TypeRange tr(vr);
|
||||
|
||||
EXPECT_EQ(tr.size(), 3u);
|
||||
// Each element of the TypeRange is expected to be the i16 type.
for (Type t : tr)
|
||||
EXPECT_EQ(t, builder.getIntegerType(16));
|
||||
|
||||
// The op is destroyed explicitly at the end of the test.
useOp->destroy();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Reference in New Issue
Block a user