[ADT] Introduce bind_{front,back}, [not_]equal_to (#175056)

Introduce a llvm::bind_front and llvm::bind_back. To demonstrate its
utility, we pose the problem of shortening a common idiom where we use
an STL algorithm like all_of or any_of, and check the members of the
range against a value: we introduce llvm::{equal_to, not_equal_to} in
terms of llvm::bind_{front, back}.

---------

Co-authored-by: Jakub Kuderski <jakub@nod-labs.com>
Co-authored-by: Yanzuo Liu <zwuis@outlook.com>
This commit is contained in:
Ramkumar Ramachandra
2026-01-13 17:52:22 +00:00
committed by GitHub
parent 8784816a41
commit d2a521750a
4 changed files with 490 additions and 0 deletions

View File

@@ -2157,6 +2157,20 @@ template <typename T> bool all_equal(std::initializer_list<T> Values) {
return all_equal<std::initializer_list<T>>(std::move(Values));
}
/// Functor variant of std::equal_to that can be used as a UnaryPredicate in
/// functional algorithms like all_of. `Args` is forwarded and stored by value.
/// If you would like to pass by reference, use `std::ref` or `std::cref`.
template <typename T> constexpr auto equal_to(T &&Arg) {
return bind_front(std::equal_to<>{}, std::forward<T>(Arg));
}
/// Functor variant of std::not_equal_to that can be used as a UnaryPredicate in
/// functional algorithms like all_of. `Args` is forwarded and stored by value.
/// If you would like to pass by reference, use `std::ref` or `std::cref`.
template <typename T> constexpr auto not_equal_to(T &&Arg) {
return bind_front(std::not_equal_to<>{}, std::forward<T>(Arg));
}
/// Provide a container algorithm similar to C++ Library Fundamentals v2's
/// `erase_if` which is equivalent to:
///

View File

@@ -17,6 +17,7 @@
#ifndef LLVM_ADT_STLFORWARDCOMPAT_H
#define LLVM_ADT_STLFORWARDCOMPAT_H
#include <functional>
#include <optional>
#include <tuple>
#include <type_traits>
@@ -185,6 +186,140 @@ struct from_range_t {
explicit from_range_t() = default;
};
inline constexpr from_range_t from_range{};
//===----------------------------------------------------------------------===//
// Bind functions from C++20 / C++23 / C++26
//===----------------------------------------------------------------------===//
namespace detail {
// Tag for constructing with a runtime callable.
struct RuntimeFnTag {};
// Tag for constructing with a compile-time constant callable.
struct ConstantFnTag {};
/// Stores a callable as a data member.
template <typename FnT> struct FnHolder {
FnT Fn;
template <typename FnArgT>
constexpr explicit FnHolder(FnArgT &&F) : Fn(std::forward<FnArgT>(F)) {}
constexpr FnT &get() { return Fn; }
constexpr const FnT &get() const { return Fn; }
};
/// Holds a compile-time constant callable (empty storage).
template <auto ConstFn> struct FnConstant {
constexpr decltype(auto) get() const { return ConstFn; }
};
// Storage class for bind_front/bind_back that properly handles const/non-const
// qualification of the wrapper when invoking the stored callable.
// If BindFront is true, bound args are prepended; otherwise appended.
// FnStorageT is either FnHolder<FnT> (runtime) or FnConstant<ConstFn>.
template <bool BindFront, typename BoundArgsTupleT, typename FnStorageT,
typename IndicesT>
class BindStorage;
template <bool BindFront, typename BoundArgsTupleT, typename FnStorageT,
size_t... Indices>
class BindStorage<BindFront, BoundArgsTupleT, FnStorageT,
std::index_sequence<Indices...>> {
BoundArgsTupleT BoundArgs;
// This may be empty for const functions, hence the `no_unique_address`.
[[no_unique_address]] FnStorageT FnStorage;
public:
// Constructor for FnHolder (runtime callable).
template <typename FnArgT, typename... BoundArgsArgT>
constexpr BindStorage(RuntimeFnTag, FnArgT &&F, BoundArgsArgT &&...Args)
: BoundArgs(std::forward<BoundArgsArgT>(Args)...),
FnStorage(std::forward<FnArgT>(F)) {}
// Constructor for FnConstant (compile-time callable).
template <typename... BoundArgsArgT>
constexpr BindStorage(ConstantFnTag, BoundArgsArgT &&...Args)
: BoundArgs(std::forward<BoundArgsArgT>(Args)...), FnStorage() {}
template <typename... CallArgsT>
constexpr auto operator()(CallArgsT &&...CallArgs) {
if constexpr (BindFront)
return llvm::invoke(FnStorage.get(), std::get<Indices>(BoundArgs)...,
std::forward<CallArgsT>(CallArgs)...);
else
return llvm::invoke(FnStorage.get(), std::forward<CallArgsT>(CallArgs)...,
std::get<Indices>(BoundArgs)...);
}
template <typename... CallArgsT>
constexpr auto operator()(CallArgsT &&...CallArgs) const {
if constexpr (BindFront)
return llvm::invoke(FnStorage.get(), std::get<Indices>(BoundArgs)...,
std::forward<CallArgsT>(CallArgs)...);
else
return llvm::invoke(FnStorage.get(), std::forward<CallArgsT>(CallArgs)...,
std::get<Indices>(BoundArgs)...);
}
};
} // end namespace detail
/// C++20 bind_front. Prepends bound arguments to the callable. All bind
/// arguments and the callable are forwarded and *stored* by value. If you would
/// like to pass by reference, use `std::ref` or `std::cref`.
template <typename FnT, typename... BindArgsT>
constexpr auto bind_front(FnT &&Fn, // NOLINT(readability-identifier-naming)
BindArgsT &&...BindArgs) {
return detail::BindStorage</*BindFront=*/true,
std::tuple<std::decay_t<BindArgsT>...>,
detail::FnHolder<std::decay_t<FnT>>,
std::index_sequence_for<BindArgsT...>>(
detail::RuntimeFnTag{}, std::forward<FnT>(Fn),
std::forward<BindArgsT>(BindArgs)...);
}
/// C++26 bind_front with compile-time callable. Prepends bound arguments.
/// Bound arguments are forwarded and *stored* by value.
template <auto ConstFn, typename... BindArgsT>
constexpr auto
bind_front(BindArgsT &&...BindArgs) { // NOLINT(readability-identifier-naming)
if constexpr (std::is_pointer_v<decltype(ConstFn)> ||
std::is_member_pointer_v<decltype(ConstFn)>)
static_assert(ConstFn != nullptr);
return detail::BindStorage<
/*BindFront=*/true, std::tuple<std::decay_t<BindArgsT>...>,
detail::FnConstant<ConstFn>, std::index_sequence_for<BindArgsT...>>(
detail::ConstantFnTag{}, std::forward<BindArgsT>(BindArgs)...);
}
/// C++23 bind_back. Appends bound arguments to the callable. All bind
/// arguments and the callable are forwarded and *stored* by value. If you would
/// like to pass by reference, use `std::ref` or `std::cref`.
template <typename FnT, typename... BindArgsT>
constexpr auto bind_back(FnT &&Fn, // NOLINT(readability-identifier-naming)
BindArgsT &&...BindArgs) {
return detail::BindStorage</*BindFront=*/false,
std::tuple<std::decay_t<BindArgsT>...>,
detail::FnHolder<std::decay_t<FnT>>,
std::index_sequence_for<BindArgsT...>>(
detail::RuntimeFnTag{}, std::forward<FnT>(Fn),
std::forward<BindArgsT>(BindArgs)...);
}
/// C++26 bind_back with compile-time callable. Appends bound arguments.
/// Bound arguments are forwarded and *stored* by value.
template <auto ConstFn, typename... BindArgsT>
constexpr auto
bind_back(BindArgsT &&...BindArgs) { // NOLINT(readability-identifier-naming)
if constexpr (std::is_pointer_v<decltype(ConstFn)> ||
std::is_member_pointer_v<decltype(ConstFn)>)
static_assert(ConstFn != nullptr);
return detail::BindStorage<
/*BindFront=*/false, std::tuple<std::decay_t<BindArgsT>...>,
detail::FnConstant<ConstFn>, std::index_sequence_for<BindArgsT...>>(
detail::ConstantFnTag{}, std::forward<BindArgsT>(BindArgs)...);
}
} // namespace llvm
#endif // LLVM_ADT_STLFORWARDCOMPAT_H

View File

@@ -1055,6 +1055,28 @@ TEST(STLExtrasTest, to_address) {
EXPECT_EQ(V1, llvm::to_address(V3));
}
TEST(STLExtras, EqualToNotEqualTo) {
std::vector<int> V;
EXPECT_TRUE(all_of(V, equal_to(1)));
EXPECT_TRUE(all_of(V, not_equal_to(1)));
V.push_back(1);
EXPECT_TRUE(all_of(V, equal_to(1)));
EXPECT_TRUE(all_of(V, not_equal_to(2)));
V.push_back(1);
V.push_back(1);
EXPECT_TRUE(all_of(V, equal_to(1)));
EXPECT_TRUE(all_of(V, not_equal_to(2)));
EXPECT_TRUE(none_of(V, equal_to(2)));
V.push_back(2);
EXPECT_FALSE(all_of(V, equal_to(1)));
EXPECT_FALSE(all_of(V, not_equal_to(1)));
EXPECT_TRUE(any_of(V, equal_to(2)));
EXPECT_TRUE(any_of(V, not_equal_to(2)));
}
TEST(STLExtrasTest, partition_point) {
std::vector<int> V = {1, 3, 5, 7, 9};

View File

@@ -14,6 +14,7 @@
#include <type_traits>
#include <utility>
namespace llvm {
namespace {
template <typename T>
@@ -205,6 +206,10 @@ TEST(STLForwardCompatTest, IdentityCxx20) {
static_assert(std::is_same_v<int &&, decltype(identity(int(5)))>);
}
//===----------------------------------------------------------------------===//
// llvm::invoke tests
//===----------------------------------------------------------------------===//
TEST(STLForwardCompatTest, InvokePerfectForwarding) {
auto CheckArgs = [](auto &&A, auto &&B, auto &&C) {
static_assert(std::is_same_v<decltype(A), int &>);
@@ -263,4 +268,318 @@ TEST(STLForwardCompatTest, InvokeConstexpr) {
static_assert(C == 15);
}
TEST(STLForwardCompatTest, BindFrontReferences) {
// All bound arguments are forwarded (for ints, this is a copy) into the
// wrapper. Call arguments are forwarded with their original value category.
int A = 1;
const int B = 2;
int C = 3;
int D = 4;
const int E = 5;
int F = 6;
auto TestTypes = [](auto &&AArg, auto &&BArg, auto &&CArg, auto &&DArg,
auto &&EArg, auto &&FArg) {
// Bound args: all stored as values, passed as lvalue refs.
EXPECT_EQ(AArg, 1);
static_assert(std::is_same_v<decltype(AArg), int &>);
EXPECT_EQ(BArg, 2);
static_assert(std::is_same_v<decltype(BArg), int &>); // Const decayed away.
EXPECT_EQ(CArg, 3);
static_assert(std::is_same_v<decltype(CArg), int &>);
// Call args: forwarded with original value category.
EXPECT_EQ(DArg, 4);
static_assert(std::is_same_v<decltype(DArg), int &>);
EXPECT_EQ(EArg, 5);
static_assert(std::is_same_v<decltype(EArg), const int &>);
EXPECT_EQ(FArg, 6);
static_assert(std::is_same_v<decltype(FArg), int &&>);
++AArg;
++DArg;
};
bind_front(TestTypes, A, B, std::move(C))(D, E, std::move(F));
EXPECT_EQ(A, 1); // A was copied, original unchanged.
EXPECT_EQ(D, 5); // D was passed by reference and incremented.
}
TEST(STLForwardCompatTest, BindBackReferences) {
// With std::decay_t, all bound arguments are copied into the wrapper.
// Call arguments are forwarded with their original value category.
int A = 1;
const int B = 2;
int C = 3;
int D = 4;
const int E = 5;
int F = 6;
auto TestTypes = [](auto &&AArg, auto &&BArg, auto &&CArg, auto &&DArg,
auto &&EArg, auto &&FArg) {
// Call args: forwarded with original value category.
EXPECT_EQ(AArg, 1);
static_assert(std::is_same_v<decltype(AArg), int &>);
EXPECT_EQ(BArg, 2);
static_assert(std::is_same_v<decltype(BArg), const int &>);
EXPECT_EQ(CArg, 3);
static_assert(std::is_same_v<decltype(CArg), int &&>);
// Bound args: all stored as values, passed as lvalue refs.
EXPECT_EQ(DArg, 4);
static_assert(std::is_same_v<decltype(DArg), int &>);
EXPECT_EQ(EArg, 5);
static_assert(std::is_same_v<decltype(EArg), int &>); // Const decayed away.
EXPECT_EQ(FArg, 6);
static_assert(std::is_same_v<decltype(FArg), int &>);
++AArg;
++DArg;
};
bind_back(TestTypes, D, E, std::move(F))(A, B, std::move(C));
EXPECT_EQ(A, 2); // A was passed by reference and incremented.
EXPECT_EQ(D, 4); // D was copied, original unchanged.
}
// Check that bound args are copied once during bind, then passed by reference.
TEST(STLForwardCompatTest, BindBoundArgsForwarding) {
auto Fn = [](CountCopyAndMove &A) -> int { return A.val; };
CountCopyAndMove::ResetCounts();
CountCopyAndMove Arg(42);
EXPECT_EQ(CountCopyAndMove::TotalCopies(), 0);
EXPECT_EQ(CountCopyAndMove::TotalMoves(), 0);
// Creating the wrapper should copy the bound args once.
auto Bound = bind_front(Fn, Arg);
EXPECT_EQ(CountCopyAndMove::TotalCopies(), 1);
EXPECT_EQ(CountCopyAndMove::TotalMoves(), 0);
// Calling should not copy -- bound args are passed by reference.
EXPECT_EQ(Bound(), 42);
EXPECT_EQ(CountCopyAndMove::TotalCopies(), 1);
EXPECT_EQ(CountCopyAndMove::TotalMoves(), 0);
}
// Check that call args are forwarded without copies.
TEST(STLForwardCompatTest, BindCallArgsForwarding) {
auto Fn = [](int, CountCopyAndMove &&A) -> int { return A.val; };
CountCopyAndMove::ResetCounts();
auto Bound = bind_front(Fn, 1);
EXPECT_EQ(CountCopyAndMove::TotalCopies(), 0);
EXPECT_EQ(CountCopyAndMove::TotalMoves(), 0);
// Call arg should be forwarded as rvalue, no copies.
EXPECT_EQ(Bound(CountCopyAndMove(42)), 42);
EXPECT_EQ(CountCopyAndMove::TotalCopies(), 0);
EXPECT_EQ(CountCopyAndMove::TotalMoves(), 0);
}
// Check that the callable itself is moved, not copied excessively.
TEST(STLForwardCompatTest, BindCallableForwarding) {
CountCopyAndMove::ResetCounts();
CountCopyAndMove Capture(42);
EXPECT_EQ(CountCopyAndMove::TotalCopies(), 0);
// Lambda captures by value -- this copy is outside bind's control.
auto Fn = [Capture]() { return Capture.val; };
EXPECT_EQ(CountCopyAndMove::TotalCopies(), 1);
EXPECT_EQ(CountCopyAndMove::TotalMoves(), 0);
// Moving lambda into bind should move, not copy.
auto Bound = bind_front(std::move(Fn));
EXPECT_EQ(CountCopyAndMove::TotalCopies(), 1);
EXPECT_EQ(CountCopyAndMove::TotalMoves(), 1);
// Calling should not copy the callable.
EXPECT_EQ(Bound(), 42);
EXPECT_EQ(CountCopyAndMove::TotalCopies(), 1);
EXPECT_EQ(CountCopyAndMove::TotalMoves(), 1);
// The bind object should be copyable.
auto BoundCopy = Bound;
EXPECT_EQ(CountCopyAndMove::TotalCopies(), 2);
// The bind object should be movable.
auto BoundMove = std::move(BoundCopy);
EXPECT_EQ(CountCopyAndMove::TotalCopies(), 2);
EXPECT_EQ(CountCopyAndMove::TotalMoves(), 2);
}
// Check that moving bound args works correctly.
TEST(STLForwardCompatTest, BindMoveBoundArgs) {
auto Fn = [](CountCopyAndMove &A) -> int { return A.val; };
CountCopyAndMove::ResetCounts();
CountCopyAndMove Arg(42);
// Moving into bind should move, not copy.
auto Bound = bind_front(Fn, std::move(Arg));
EXPECT_EQ(CountCopyAndMove::TotalCopies(), 0);
EXPECT_EQ(CountCopyAndMove::TotalMoves(), 1);
EXPECT_EQ(Bound(), 42);
EXPECT_EQ(CountCopyAndMove::TotalCopies(), 0);
EXPECT_EQ(CountCopyAndMove::TotalMoves(), 1);
}
TEST(STLForwardCompatTest, BindFrontMutableStorage) {
// With std::decay_t, A is copied into the wrapper. The stored copy can be
// mutated across calls, but the original A is unchanged.
int A = 1;
auto TestMutation = [](int &AArg, int &BArg, auto ExtraCheckFn) {
++AArg;
++BArg;
ExtraCheckFn(AArg, BArg);
};
auto BoundA = bind_front(TestMutation, A, 42);
BoundA([](int AVal, int BVal) {
EXPECT_EQ(AVal, 2); // Stored copy incremented from 1.
EXPECT_EQ(BVal, 43);
});
EXPECT_EQ(A, 1); // Original unchanged.
BoundA([](int AVal, int BVal) {
EXPECT_EQ(AVal, 3); // Stored copy incremented again.
EXPECT_EQ(BVal, 44);
});
EXPECT_EQ(A, 1); // Original still unchanged.
}
TEST(STLForwardCompatTest, BindBackMutableStorage) {
// With std::decay_t, A is copied into the wrapper. The stored copy can be
// mutated across calls, but the original A is unchanged.
int A = 1;
auto TestMutation = [](auto ExtraCheckFn, int &AArg, int &BArg) {
++AArg;
++BArg;
ExtraCheckFn(AArg, BArg);
};
auto BoundA = bind_back(TestMutation, A, 42);
BoundA([](int AVal, int BVal) {
EXPECT_EQ(AVal, 2); // Stored copy incremented from 1.
EXPECT_EQ(BVal, 43);
});
EXPECT_EQ(A, 1); // Original unchanged.
BoundA([](int AVal, int BVal) {
EXPECT_EQ(AVal, 3); // Stored copy incremented again.
EXPECT_EQ(BVal, 44);
});
EXPECT_EQ(A, 1); // Original still unchanged.
}
// Free function for compile-time bind tests.
static int subtract(int A, int B) { return A - B; }
TEST(STLForwardCompatTest, BindFrontConstexprCallable) {
// Test compile-time callable with bind_front.
auto TimesFive = bind_front<subtract>(5);
EXPECT_EQ(TimesFive(3), 2);
// Test compile-time callable with bind_back.
auto FiveTimesX = bind_back<subtract>(5);
EXPECT_EQ(FiveTimesX(3), -2);
}
TEST(STLForwardCompatTest, BindFrontBackNoBoundArgs) {
auto Fn1 = bind_front([](int A, int B) { return A + B; });
EXPECT_EQ(Fn1(3, 4), 7);
auto Fn2 = bind_back([](int A, int B) { return A + B; });
EXPECT_EQ(Fn2(3, 4), 7);
}
TEST(STLForwardCompatTest, BindFrontBindBackConstexpr) {
static constexpr auto Fn1 = bind_front([](int A, int B) { return A + B; }, 1);
static_assert(Fn1(3) == 4);
static constexpr auto Fn2 = bind_back([](int A, int B) { return A + B; }, 1);
static_assert(Fn2(3) == 4);
}
// Use std::ref/std::cref to bind references (bound args are decay-copied).
TEST(STLForwardCompatTest, BindWithReferenceWrapper) {
int X = 1;
auto Increment = bind_front([](int &Val) { ++Val; }, std::ref(X));
Increment();
EXPECT_EQ(X, 2);
Increment();
EXPECT_EQ(X, 3);
}
// The callable itself can have mutable state.
TEST(STLForwardCompatTest, BindMutableCallable) {
auto Counter = bind_front([N = 0]() mutable { return ++N; });
EXPECT_EQ(Counter(), 1);
EXPECT_EQ(Counter(), 2);
EXPECT_EQ(Counter(), 3);
}
namespace {
struct MemberTest {
int Value;
int scale(int Factor) const { return Value * Factor; }
};
} // namespace
TEST(STLForwardCompatTest, BindMembers) {
// Member function pointer support via std::apply (with std::invoke used
// internally).
MemberTest Obj{10};
auto ScaleObj = bind_front(&MemberTest::scale, Obj);
EXPECT_EQ(ScaleObj(3), 30);
auto ScaleBy5 = bind_back(&MemberTest::scale, 5);
EXPECT_EQ(ScaleBy5(Obj), 50);
// Member data pointer support via std::apply (with std::invoke used
// internally).
auto GetValue = bind_front(&MemberTest::Value);
EXPECT_EQ(GetValue(Obj), 10);
// Make sure we can use member data pointers for constexpr callables.
static constexpr int MemberVal =
bind_front(&MemberTest::Value)(MemberTest{10});
EXPECT_EQ(MemberVal, 10);
}
TEST(STLForwardCompatTest, BindFrontBindBack) {
std::vector<int> V;
auto MulAdd = [](int A, int B, int C) { return A * (B + C) == 12; };
auto MulAdd1 = [](const int &A, const int &B, const int &C) {
return A * (B + C) == 12;
};
auto Mul0 = bind_back(MulAdd, 4, 2);
auto MulL = bind_front(MulAdd1, 2, 4);
auto Mul20 = bind_back(MulAdd, 4);
auto Mul21 = bind_front(MulAdd1, 2);
EXPECT_TRUE(all_of(V, Mul0));
EXPECT_TRUE(all_of(V, MulL));
V.push_back(2);
EXPECT_TRUE(all_of(V, Mul0));
EXPECT_TRUE(all_of(V, MulL));
V.push_back(2);
V.push_back(2);
EXPECT_TRUE(all_of(V, Mul0));
EXPECT_TRUE(all_of(V, MulL));
auto Spec0 = bind_front(Mul20, 2);
auto Spec1 = bind_back(Mul21, 4);
EXPECT_TRUE(all_of(V, Spec0));
EXPECT_TRUE(all_of(V, Spec1));
V.push_back(3);
EXPECT_FALSE(all_of(V, Mul0));
EXPECT_FALSE(all_of(V, MulL));
EXPECT_FALSE(all_of(V, Spec0));
EXPECT_FALSE(all_of(V, Spec1));
EXPECT_TRUE(any_of(V, Spec0));
EXPECT_TRUE(any_of(V, Spec1));
}
} // namespace
} // namespace llvm