Files
Nikolas Klauser 12d8360727 [libc++] Add segmented iterator optimization to std::equal (#179242)
```
Benchmark                                                 97fa3e5936    a820f8f10736    Difference    % Difference
------------------------------------------------------  --------------  --------------  ------------  --------------
std::equal(deque<int>)_(it,_it,_it)/1024                        510.92           82.64       -428.27         -83.82%
std::equal(deque<int>)_(it,_it,_it)/1048576                  518795.61        87141.29    -431654.32         -83.20%
std::equal(deque<int>)_(it,_it,_it)/50                           29.24            6.77        -22.46         -76.84%
std::equal(deque<int>)_(it,_it,_it)/8                             4.20            3.71         -0.49         -11.61%
std::equal(deque<int>)_(it,_it,_it)/8192                       3972.84          643.83      -3329.01         -83.79%
std::equal(deque<int>)_(it,_it,_it,_it)/1024                    417.45           81.52       -335.93         -80.47%
std::equal(deque<int>)_(it,_it,_it,_it)/1048576              539228.26        87480.92    -451747.34         -83.78%
std::equal(deque<int>)_(it,_it,_it,_it)/50                       22.25            7.25        -15.00         -67.41%
std::equal(deque<int>)_(it,_it,_it,_it)/8                         4.75            4.44         -0.31          -6.45%
std::equal(deque<int>)_(it,_it,_it,_it)/8192                   3259.01          641.31      -2617.70         -80.32%
std::equal(deque<int>)_(it,_it,_it,_it,_pred)/1024              532.68          327.58       -205.10         -38.50%
std::equal(deque<int>)_(it,_it,_it,_it,_pred)/1048576        600755.28       402988.04    -197767.24         -32.92%
std::equal(deque<int>)_(it,_it,_it,_it,_pred)/50                 27.26           25.29         -1.97          -7.22%
std::equal(deque<int>)_(it,_it,_it,_it,_pred)/8                   5.20            5.58          0.38           7.31%
std::equal(deque<int>)_(it,_it,_it,_it,_pred)/8192             4204.16         2847.30      -1356.86         -32.27%
std::equal(deque<int>)_(it,_it,_it,_pred)/1024                  531.32          329.03       -202.30         -38.07%
std::equal(deque<int>)_(it,_it,_it,_pred)/1048576            598948.55       403822.65    -195125.90         -32.58%
std::equal(deque<int>)_(it,_it,_it,_pred)/50                     26.28           16.18        -10.10         -38.43%
std::equal(deque<int>)_(it,_it,_it,_pred)/8                       4.44            3.70         -0.74         -16.67%
std::equal(deque<int>)_(it,_it,_it,_pred)/8192                 4184.03         2902.98      -1281.05         -30.62%
std::equal(list<int>)_(it,_it,_it)/1024                        1168.78         1168.51         -0.27          -0.02%
std::equal(list<int>)_(it,_it,_it)/1048576                  1283003.12      1281885.44      -1117.69          -0.09%
std::equal(list<int>)_(it,_it,_it)/50                            60.19           44.38        -15.81         -26.27%
std::equal(list<int>)_(it,_it,_it)/8                              3.07            3.07          0.00           0.15%
std::equal(list<int>)_(it,_it,_it)/8192                       10367.41        11075.24        707.83           6.83%
std::equal(list<int>)_(it,_it,_it,_it)/1024                     728.32          734.18          5.86           0.80%
std::equal(list<int>)_(it,_it,_it,_it)/1048576               951276.58       953928.39       2651.81           0.28%
std::equal(list<int>)_(it,_it,_it,_it)/50                        31.86           32.32          0.46           1.44%
std::equal(list<int>)_(it,_it,_it,_it)/8                          3.11            3.10         -0.01          -0.34%
std::equal(list<int>)_(it,_it,_it,_it)/8192                   14940.68        16058.91       1118.22           7.48%
std::equal(list<int>)_(it,_it,_it,_it,_pred)/1024               803.49          813.53         10.05           1.25%
std::equal(list<int>)_(it,_it,_it,_it,_pred)/1048576        1012708.15      1026207.55      13499.40           1.33%
std::equal(list<int>)_(it,_it,_it,_it,_pred)/50                  38.68           39.24          0.56           1.46%
std::equal(list<int>)_(it,_it,_it,_it,_pred)/8                    4.07            4.07         -0.00          -0.08%
std::equal(list<int>)_(it,_it,_it,_it,_pred)/8192             16632.08        18073.63       1441.55           8.67%
std::equal(list<int>)_(it,_it,_it,_pred)/1024                  1162.99         1162.48         -0.51          -0.04%
std::equal(list<int>)_(it,_it,_it,_pred)/1048576            1291522.30      1303819.01      12296.72           0.95%
std::equal(list<int>)_(it,_it,_it,_pred)/50                      45.73           46.32          0.59           1.29%
std::equal(list<int>)_(it,_it,_it,_pred)/8                        4.35            4.40          0.04           1.03%
std::equal(list<int>)_(it,_it,_it,_pred)/8192                 15035.93        14598.06       -437.87          -2.91%
std::equal(vector<bool>)_(aligned)/1024                           0.22            0.22          0.00           0.04%
std::equal(vector<bool>)_(aligned)/1048576                        0.22            0.22          0.00           0.12%
std::equal(vector<bool>)_(aligned)/50                             0.22            0.22          0.00           0.02%
std::equal(vector<bool>)_(aligned)/8                              0.22            0.22          0.00           0.03%
std::equal(vector<bool>)_(aligned)/8192                           0.22            0.22          0.00           0.05%
std::equal(vector<bool>)_(unaligned)/1024                         6.34            6.39          0.04           0.70%
std::equal(vector<bool>)_(unaligned)/1048576                   6809.31         6833.52         24.21           0.36%
std::equal(vector<bool>)_(unaligned)/50                           1.11            0.92         -0.19         -17.55%
std::equal(vector<bool>)_(unaligned)/8                            1.11            1.05         -0.06          -5.29%
std::equal(vector<bool>)_(unaligned)/8192                        59.27           59.92          0.65           1.10%
std::equal(vector<int>)_(it,_it,_it)/1024                        80.39           80.59          0.20           0.25%
std::equal(vector<int>)_(it,_it,_it)/1048576                  72546.36        73803.43       1257.07           1.73%
std::equal(vector<int>)_(it,_it,_it)/50                           3.92            4.43          0.51          12.92%
std::equal(vector<int>)_(it,_it,_it)/8                            1.46            1.47          0.01           0.75%
std::equal(vector<int>)_(it,_it,_it)/8192                       553.63          559.59          5.95           1.07%
std::equal(vector<int>)_(it,_it,_it,_it)/1024                    78.69           78.37         -0.32          -0.40%
std::equal(vector<int>)_(it,_it,_it,_it)/1048576              72238.51        73582.13       1343.62           1.86%
std::equal(vector<int>)_(it,_it,_it,_it)/50                       4.18            4.62          0.44          10.52%
std::equal(vector<int>)_(it,_it,_it,_it)/8                        1.68            1.66         -0.01          -0.87%
std::equal(vector<int>)_(it,_it,_it,_it)/8192                   549.35          555.24          5.89           1.07%
std::equal(vector<int>)_(it,_it,_it,_it,_pred)/1024             361.08          363.32          2.24           0.62%
std::equal(vector<int>)_(it,_it,_it,_it,_pred)/1048576       391367.63       394209.88       2842.25           0.73%
std::equal(vector<int>)_(it,_it,_it,_it,_pred)/50                15.24           15.83          0.59           3.87%
std::equal(vector<int>)_(it,_it,_it,_it,_pred)/8                  3.18            3.19          0.01           0.40%
std::equal(vector<int>)_(it,_it,_it,_it,_pred)/8192            2992.57         3026.90         34.32           1.15%
std::equal(vector<int>)_(it,_it,_it,_pred)/1024                 362.45          365.46          3.01           0.83%
std::equal(vector<int>)_(it,_it,_it,_pred)/1048576           399898.16       402718.88       2820.72           0.71%
std::equal(vector<int>)_(it,_it,_it,_pred)/50                    14.79           14.79         -0.01          -0.04%
std::equal(vector<int>)_(it,_it,_it,_pred)/8                      2.45            2.52          0.06           2.64%
std::equal(vector<int>)_(it,_it,_it,_pred)/8192                3062.16         3088.11         25.95           0.85%
Geomean                                                         253.49          200.79        -52.70         -20.79%
```
2026-02-24 10:23:15 +01:00

326 lines
13 KiB
C++

// -*- C++ -*-
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef _LIBCPP___ALGORITHM_EQUAL_H
#define _LIBCPP___ALGORITHM_EQUAL_H
#include <__algorithm/comp.h>
#include <__algorithm/find_segment_if.h>
#include <__algorithm/min.h>
#include <__algorithm/unwrap_iter.h>
#include <__config>
#include <__functional/identity.h>
#include <__fwd/bit_reference.h>
#include <__iterator/iterator_traits.h>
#include <__iterator/segmented_iterator.h>
#include <__string/constexpr_c_functions.h>
#include <__type_traits/common_type.h>
#include <__type_traits/desugars_to.h>
#include <__type_traits/enable_if.h>
#include <__type_traits/invoke.h>
#include <__type_traits/is_equality_comparable.h>
#include <__type_traits/is_same.h>
#include <__type_traits/is_volatile.h>
#include <__utility/move.h>
#include <__utility/unreachable.h>
#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
# pragma GCC system_header
#endif
_LIBCPP_PUSH_MACROS
#include <__undef_macros>
_LIBCPP_BEGIN_NAMESPACE_STD
// Compares the bit range [__first1, __last1) with the equally long bit range starting at __first2.
// std::__equal_iter_impl selects this variant when the two ranges start at different bit offsets
// (__first1.__ctz_ != __first2.__ctz_). Work proceeds over __first1's storage words:
//   1. an optional leading partial word of the first range,
//   2. whole words of the first range, each of which straddles two words of the second range,
//   3. an optional trailing partial word of the first range.
// Returns true iff every bit compares equal.
template <class _Cp, bool _IsConst1, bool _IsConst2>
[[__nodiscard__]] _LIBCPP_CONSTEXPR_SINCE_CXX20 _LIBCPP_HIDE_FROM_ABI bool
__equal_unaligned(__bit_iterator<_Cp, _IsConst1> __first1,
                  __bit_iterator<_Cp, _IsConst1> __last1,
                  __bit_iterator<_Cp, _IsConst2> __first2) {
  using _It                 = __bit_iterator<_Cp, _IsConst1>;
  using difference_type     = typename _It::difference_type;
  using __storage_type      = typename _It::__storage_type;
  const int __bits_per_word = _It::__bits_per_word;
  difference_type __n       = __last1 - __first1; // bits still to be compared
  if (__n > 0) {
    // do first word: __first1 starts part-way into a word.
    if (__first1.__ctz_ != 0) {
      unsigned __clz_f     = __bits_per_word - __first1.__ctz_;
      difference_type __dn = std::min(static_cast<difference_type>(__clz_f), __n); // bits taken from this word
      __n -= __dn;
      __storage_type __m = std::__middle_mask<__storage_type>(__clz_f - __dn, __first1.__ctz_);
      __storage_type __b = *__first1.__seg_ & __m; // participating bits, still at their original positions
      // The first __ddn of those bits fit into the current word of __first2.
      unsigned __clz_r     = __bits_per_word - __first2.__ctz_;
      __storage_type __ddn = std::min<__storage_type>(__dn, __clz_r);
      __m                  = std::__middle_mask<__storage_type>(__clz_r - __ddn, __first2.__ctz_);
      // Shift __b so its bits line up with __first2's offset before comparing.
      if (__first2.__ctz_ > __first1.__ctz_) {
        if (static_cast<__storage_type>(*__first2.__seg_ & __m) !=
            static_cast<__storage_type>(__b << (__first2.__ctz_ - __first1.__ctz_)))
          return false;
      } else {
        if (static_cast<__storage_type>(*__first2.__seg_ & __m) !=
            static_cast<__storage_type>(__b >> (__first1.__ctz_ - __first2.__ctz_)))
          return false;
      }
      __first2.__seg_ += (__ddn + __first2.__ctz_) / __bits_per_word;
      __first2.__ctz_ = static_cast<unsigned>((__ddn + __first2.__ctz_) % __bits_per_word);
      __dn -= __ddn;
      // Any remaining __dn bits spill into the low end of the next word of __first2.
      if (__dn > 0) {
        // The mask must cover exactly the __dn spilled bits. __n (the bits left *after* this first word)
        // is unrelated here and would select the wrong number of bits.
        __m = std::__trailing_mask<__storage_type>(__bits_per_word - __dn);
        if (static_cast<__storage_type>(*__first2.__seg_ & __m) !=
            static_cast<__storage_type>(__b >> (__first1.__ctz_ + __ddn)))
          return false;
        __first2.__ctz_ = static_cast<unsigned>(__dn);
      }
      ++__first1.__seg_;
      // __first1.__ctz_ = 0;
    }
    // __first1.__ctz_ == 0;
    // do middle words: each whole word __b of the first range is split across two words of the second range.
    // Its low __clz_r bits are compared (shifted left by __first2.__ctz_) with the current word under __m.
    // Its remaining high bits are compared (shifted right by __clz_r) with the next word under ~__m.
    unsigned __clz_r   = __bits_per_word - __first2.__ctz_;
    __storage_type __m = std::__leading_mask<__storage_type>(__first2.__ctz_);
    for (; __n >= __bits_per_word; __n -= __bits_per_word, ++__first1.__seg_) {
      __storage_type __b = *__first1.__seg_;
      if (static_cast<__storage_type>(*__first2.__seg_ & __m) != static_cast<__storage_type>(__b << __first2.__ctz_))
        return false;
      ++__first2.__seg_;
      if (static_cast<__storage_type>(*__first2.__seg_ & static_cast<__storage_type>(~__m)) !=
          static_cast<__storage_type>(__b >> __clz_r))
        return false;
    }
    // do last word: the first __n (< __bits_per_word) bits of the final word of the first range.
    if (__n > 0) {
      __m                 = std::__trailing_mask<__storage_type>(__bits_per_word - __n);
      __storage_type __b  = *__first1.__seg_ & __m;
      __storage_type __dn = std::min(__n, static_cast<difference_type>(__clz_r)); // bits that fit in __first2's word
      __m                 = std::__middle_mask<__storage_type>(__clz_r - __dn, __first2.__ctz_);
      if (static_cast<__storage_type>(*__first2.__seg_ & __m) != static_cast<__storage_type>(__b << __first2.__ctz_))
        return false;
      __first2.__seg_ += (__dn + __first2.__ctz_) / __bits_per_word;
      __first2.__ctz_ = static_cast<unsigned>((__dn + __first2.__ctz_) % __bits_per_word);
      __n -= __dn;
      // Remaining bits spill into the low end of the next word of __first2.
      if (__n > 0) {
        __m = std::__trailing_mask<__storage_type>(__bits_per_word - __n);
        if (static_cast<__storage_type>(*__first2.__seg_ & __m) != static_cast<__storage_type>(__b >> __dn))
          return false;
      }
    }
  }
  return true;
}
// Compares the bit range [__first1, __last1) with the equally long bit range starting at __first2.
// std::__equal_iter_impl selects this variant when both ranges start at the same bit offset
// (__first1.__ctz_ == __first2.__ctz_), so corresponding bits occupy the same positions in
// corresponding storage words. The comparison has three parts: an optional masked head word,
// then whole words compared directly, then an optional masked tail word.
template <class _Cp, bool _IsConst1, bool _IsConst2>
[[__nodiscard__]] _LIBCPP_CONSTEXPR_SINCE_CXX20 _LIBCPP_HIDE_FROM_ABI bool
__equal_aligned(__bit_iterator<_Cp, _IsConst1> __first1,
                __bit_iterator<_Cp, _IsConst1> __last1,
                __bit_iterator<_Cp, _IsConst2> __first2) {
  typedef __bit_iterator<_Cp, _IsConst1> _It;
  typedef typename _It::difference_type difference_type;
  typedef typename _It::__storage_type __storage_type;
  const int __bits_per_word = _It::__bits_per_word;

  difference_type __left = __last1 - __first1; // bits still to be compared
  if (__left <= 0)
    return true;

  // Head: both iterators start part-way into a word; only the bits from __ctz_ upward
  // (at most __left of them) take part.
  if (__first1.__ctz_ != 0) {
    const unsigned __head_room       = __bits_per_word - __first1.__ctz_;
    const difference_type __head_len = std::min(static_cast<difference_type>(__head_room), __left);
    const __storage_type __head_mask =
        std::__middle_mask<__storage_type>(__head_room - __head_len, __first1.__ctz_);
    if ((*__first1.__seg_ & __head_mask) != (*__first2.__seg_ & __head_mask))
      return false;
    __left -= __head_len;
    ++__first1.__seg_;
    ++__first2.__seg_;
  }

  // Body: both iterators now sit on word boundaries, so whole storage words can be compared as-is.
  while (__left >= __bits_per_word) {
    if (*__first1.__seg_ != *__first2.__seg_)
      return false;
    __left -= __bits_per_word;
    ++__first1.__seg_;
    ++__first2.__seg_;
  }

  // Tail: only the low __left bits of the final word take part.
  if (__left <= 0)
    return true;
  const __storage_type __tail_mask = std::__trailing_mask<__storage_type>(__bits_per_word - __left);
  return (*__first1.__seg_ & __tail_mask) == (*__first2.__seg_ & __tail_mask);
}
// Overload for __bit_iterator ranges, enabled when neither range is projected and the predicate
// desugars to plain equality. Instead of testing bit by bit, it dispatches to the word-at-a-time
// comparisons, choosing the aligned variant when both ranges begin at the same bit offset.
template <class _Cp,
          bool _IsConst1,
          bool _IsConst2,
          class _BinaryPredicate,
          class _Proj1,
          class _Proj2,
          __enable_if_t<__is_identity<_Proj1>::value && __is_identity<_Proj2>::value &&
                            __desugars_to_v<__equal_tag, _BinaryPredicate, bool, bool>,
                        int> = 0>
[[__nodiscard__]] inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool __equal_iter_impl(
    __bit_iterator<_Cp, _IsConst1> __first1,
    __bit_iterator<_Cp, _IsConst1> __last1,
    __bit_iterator<_Cp, _IsConst2> __first2,
    _BinaryPredicate,
    _Proj1&,
    _Proj2&) {
  const bool __same_bit_offset = __first1.__ctz_ == __first2.__ctz_;
  return __same_bit_offset ? std::__equal_aligned(__first1, __last1, __first2)
                           : std::__equal_unaligned(__first1, __last1, __first2);
}
// Overload for raw pointers, enabled when neither range is projected, the predicate desugars to
// plain equality, and the element types are non-volatile and trivially equality comparable.
// Under those conditions the whole comparison is handed to a single __constexpr_memcmp_equal call
// over the number of elements in [__first1, __last1).
template <class _Tp,
          class _Up,
          class _BinaryPredicate,
          class _Proj1,
          class _Proj2,
          __enable_if_t<__is_identity<_Proj1>::value && __is_identity<_Proj2>::value &&
                            __desugars_to_v<__equal_tag, _BinaryPredicate, _Tp, _Up> && !is_volatile<_Tp>::value &&
                            !is_volatile<_Up>::value && __is_trivially_equality_comparable_v<_Tp, _Up>,
                        int> = 0>
[[__nodiscard__]] inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool
__equal_iter_impl(_Tp* __first1, _Tp* __last1, _Up* __first2, _BinaryPredicate&, _Proj1&, _Proj2&) {
  const __element_count __count = __element_count(__last1 - __first1);
  return std::__constexpr_memcmp_equal(__first1, __first2, __count);
}
// Generic implementation: compares [__first1, __last1) element by element against the range that
// starts at __first2, applying __pred to the projected elements.
// Outside C++03 mode, when both iterator types are random access, segmented iterators are walked
// one segment at a time. Each per-segment comparison then runs on unwrapped local iterators, which
// may select one of the more specialized __equal_iter_impl overloads. Otherwise (or if neither
// side is segmented) a plain element-wise loop is used.
template <class _InIter1, class _Sent1, class _InIter2, class _Pred, class _Proj1, class _Proj2>
[[__nodiscard__]] inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool __equal_iter_impl(
    _InIter1 __first1, _Sent1 __last1, _InIter2 __first2, _Pred& __pred, _Proj1& __proj1, _Proj2& __proj2) {
#ifndef _LIBCPP_CXX03_LANG
  if constexpr (__has_random_access_iterator_category<_InIter1>::value &&
                __has_random_access_iterator_category<_InIter2>::value) {
    if constexpr (is_same<_InIter1, _Sent1>::value && __is_segmented_iterator_v<_InIter1>) {
      // The first range is segmented: compare one segment at a time, moving __first2 forward past
      // every segment that matched.
      using __local_iterator_t = typename __segmented_iterator_traits<_InIter1>::__local_iterator;
      bool __all_match         = true;
      std::__find_segment_if(__first1, __last1, [&](__local_iterator_t __seg_first, __local_iterator_t __seg_last) {
        if (!std::__equal_iter_impl(
                std::__unwrap_iter(__seg_first), std::__unwrap_iter(__seg_last), __first2, __pred, __proj1, __proj2)) {
          // Presumably __find_segment_if stops at the first segment whose callback does not return
          // that segment's end; returning the segment start signals the mismatch.
          __all_match = false;
          return __seg_first;
        }
        __first2 += __seg_last - __seg_first;
        return __seg_last;
      });
      return __all_match;
    } else if constexpr (__is_segmented_iterator_v<_InIter2>) {
      // The second range is segmented: split [__first1, __last1) into chunks that line up with
      // __first2's segments and compare each chunk against the matching local range.
      using _Traits = __segmented_iterator_traits<_InIter2>;
      using _DiffT =
          typename common_type<__iterator_difference_type<_InIter1>, __iterator_difference_type<_InIter2> >::type;
      if (__first1 == __last1)
        return true;
      auto __local_pos = _Traits::__local(__first2);
      auto __seg       = _Traits::__segment(__first2);
      for (;;) {
        auto __seg_end = _Traits::__end(__seg);
        auto __chunk   = std::min<_DiffT>(__seg_end - __local_pos, __last1 - __first1);
        if (!std::__equal_iter_impl(
                __first1, __first1 + __chunk, std::__unwrap_iter(__local_pos), __pred, __proj1, __proj2))
          return false;
        __first1 += __chunk;
        if (__first1 == __last1)
          return true;
        ++__seg;
        __local_pos = _Traits::__begin(__seg);
      }
    }
  }
#endif
  // Plain element-by-element comparison.
  while (__first1 != __last1) {
    if (!std::__invoke(__pred, std::__invoke(__proj1, *__first1), std::__invoke(__proj2, *__first2)))
      return false;
    ++__first1;
    ++__first2;
  }
  return true;
}
// std::equal(first1, last1, first2, pred): true iff every element of [__first1, __last1) satisfies
// __pred against the element at the same position in the range beginning at __first2.
// All iterators are unwrapped so that __equal_iter_impl can pick a specialized overload.
template <class _InputIterator1, class _InputIterator2, class _BinaryPredicate>
[[__nodiscard__]] inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool
equal(_InputIterator1 __first1, _InputIterator1 __last1, _InputIterator2 __first2, _BinaryPredicate __pred) {
  __identity __no_proj; // the same identity projection is applied to both ranges
  return std::__equal_iter_impl(std::__unwrap_iter(__first1),
                                std::__unwrap_iter(__last1),
                                std::__unwrap_iter(__first2),
                                __pred,
                                __no_proj,
                                __no_proj);
}
// std::equal(first1, last1, first2): the predicate overload with __equal_to as the predicate.
template <class _InputIterator1, class _InputIterator2>
[[__nodiscard__]] inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool
equal(_InputIterator1 __first1, _InputIterator1 __last1, _InputIterator2 __first2) {
  __equal_to __eq;
  return std::equal(__first1, __last1, __first2, __eq);
}
#if _LIBCPP_STD_VER >= 14
// Shared implementation of the four-iterator std::equal overloads.
// If __known_equal_length is false, both ranges are walked in lockstep until either one ends; they
// are equal only if no pair fails __comp and both ranges end together.
// If __known_equal_length is true, the caller has already checked that the lengths match, so
// __last2 is not needed and the three-iterator implementation is reused.
template <bool __known_equal_length,
          class _Iter1,
          class _Sent1,
          class _Iter2,
          class _Sent2,
          class _Pred,
          class _Proj1,
          class _Proj2>
[[__nodiscard__]] inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool __equal_impl(
    _Iter1 __first1, _Sent1 __last1, _Iter2 __first2, _Sent2 __last2, _Pred& __comp, _Proj1& __proj1, _Proj2& __proj2) {
  if constexpr (!__known_equal_length) {
    for (; __first1 != __last1 && __first2 != __last2; ++__first1, (void)++__first2) {
      if (!std::__invoke(__comp, std::__invoke(__proj1, *__first1), std::__invoke(__proj2, *__first2)))
        return false;
    }
    return __first1 == __last1 && __first2 == __last2;
  } else {
    return std::__equal_iter_impl(
        std::move(__first1), std::move(__last1), std::move(__first2), __comp, __proj1, __proj2);
  }
}
// std::equal(first1, last1, first2, last2, pred): true iff both ranges have the same length and
// every pair of corresponding elements satisfies __pred.
// With random-access iterators the lengths are compared up front. That lets __equal_impl skip
// tracking __last2 and reuse the three-iterator implementation.
template <class _InputIterator1, class _InputIterator2, class _BinaryPredicate>
[[__nodiscard__]] inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool
equal(_InputIterator1 __first1,
      _InputIterator1 __last1,
      _InputIterator2 __first2,
      _InputIterator2 __last2,
      _BinaryPredicate __pred) {
  constexpr bool __lengths_checkable = __has_random_access_iterator_category<_InputIterator1>::value &&
                                       __has_random_access_iterator_category<_InputIterator2>::value;
  if constexpr (__lengths_checkable) {
    if (__last1 - __first1 != __last2 - __first2)
      return false;
  }
  __identity __no_proj; // the same identity projection is applied to both ranges
  return std::__equal_impl<__lengths_checkable>(std::__unwrap_iter(__first1),
                                                std::__unwrap_iter(__last1),
                                                std::__unwrap_iter(__first2),
                                                std::__unwrap_iter(__last2),
                                                __pred,
                                                __no_proj,
                                                __no_proj);
}
// std::equal(first1, last1, first2, last2): the four-iterator predicate overload with __equal_to
// as the predicate.
template <class _InputIterator1, class _InputIterator2>
[[__nodiscard__]] inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool
equal(_InputIterator1 __first1, _InputIterator1 __last1, _InputIterator2 __first2, _InputIterator2 __last2) {
  __equal_to __eq;
  return std::equal(__first1, __last1, __first2, __last2, __eq);
}
#endif // _LIBCPP_STD_VER >= 14
_LIBCPP_END_NAMESPACE_STD
_LIBCPP_POP_MACROS
#endif // _LIBCPP___ALGORITHM_EQUAL_H