diff --git a/src/core/dash.h b/src/core/dash.h index 3c34b16e5b99..a7fd01fc257b 100644 --- a/src/core/dash.h +++ b/src/core/dash.h @@ -149,6 +149,10 @@ class DashTable : public detail::DashTableBase { template const_iterator Find(U&& key) const; template iterator Find(U&& key); + // Find first entry with given key hash that evaulates to true on pred. + // Pred accepts either (const key&) or (const key&, const value&) + template iterator FindFirst(uint64_t key_hash, Pred&& pred); + // it must be valid. void Erase(iterator it); @@ -308,8 +312,8 @@ class DashTable : public detail::DashTableBase { // the same object. IterateDistinct goes over all distinct segments in the table. template void IterateDistinct(Cb&& cb); - auto EqPred() const { - return [p = &policy_](const auto& a, const auto& b) -> bool { return p->Equal(a, b); }; + template auto EqPred(const K& key) const { + return [p = &policy_, &key](const auto& probe) -> bool { return p->Equal(probe, key); }; } Policy policy_; @@ -669,32 +673,31 @@ template auto DashTable<_Key, _Value, Policy>::Find(U&& key) const -> const_iterator { uint64_t key_hash = DoHash(key); uint32_t seg_id = SegmentId(key_hash); // seg_id takes up global_depth_ high bits. - const auto* target = segment_[seg_id]; // Hash structure is like this: [SSUUUUBF], where S is segment id, U - unused, // B - bucket id and F is a fingerprint. Segment id is needed to identify the correct segment. // Once identified, the segment instance uses the lower part of hash to locate the key. // It uses 8 least significant bits for a fingerprint and few more bits for bucket id. - auto seg_it = target->FindIt(key, key_hash, EqPred()); - - if (seg_it.found()) { - return const_iterator{this, seg_id, seg_it.index, seg_it.slot}; + if (auto seg_it = segment_[seg_id]->FindIt(key_hash, EqPred(key)); seg_it.found()) { + return {this, seg_id, seg_it.index, seg_it.slot}; } - return const_iterator{}; + return {}; } template template auto DashTable<_Key, _Value, Policy>::Find(U&& key) -> iterator { - uint64_t key_hash = DoHash(key); - uint32_t segid = SegmentId(key_hash); - const auto* target = segment_[segid]; + return FindFirst(DoHash(key), EqPred(key)); +} - auto seg_it = target->FindIt(key, key_hash, EqPred()); - if (seg_it.found()) { - return iterator{this, segid, seg_it.index, seg_it.slot}; +template +template +auto DashTable<_Key, _Value, Policy>::FindFirst(uint64_t key_hash, Pred&& pred) -> iterator { + uint32_t seg_id = SegmentId(key_hash); + if (auto seg_it = segment_[seg_id]->FindIt(key_hash, pred); seg_it.found()) { + return {this, seg_id, seg_it.index, seg_it.slot}; } - return iterator{}; + return {}; } template @@ -702,7 +705,7 @@ size_t DashTable<_Key, _Value, Policy>::Erase(const Key_t& key) { uint64_t key_hash = DoHash(key); size_t x = SegmentId(key_hash); auto* target = segment_[x]; - auto it = target->FindIt(key, key_hash, EqPred()); + auto it = target->FindIt(key_hash, EqPred(key)); if (!it.found()) return 0; @@ -764,7 +767,7 @@ auto DashTable<_Key, _Value, Policy>::InsertInternal(U&& key, V&& value, Evictio res = it.found(); } else { std::tie(it, res) = - target->Insert(std::forward(key), std::forward(value), key_hash, EqPred()); + target->Insert(std::forward(key), std::forward(value), key_hash, EqPred(key)); } if (res) { // success diff --git a/src/core/dash_internal.h b/src/core/dash_internal.h index c335fedfce79..5a59c40f9787 100644 --- a/src/core/dash_internal.h +++ b/src/core/dash_internal.h @@ -11,7 +11,7 @@ #include #include #include -#include +#include #include "core/sse_port.h" @@ -328,8 +328,7 @@ template SetHash(slot, meta_hash, probe); } - template - SlotId FindByFp(uint8_t fp_hash, bool probe, U&& k, Pred&& pred) const; + template SlotId FindByFp(uint8_t fp_hash, bool probe, Pred&& pred) const; bool ShiftRight(); @@ -403,7 +402,7 @@ template - std::pair Insert(K&& key, V&& value, Hash_t key_hash, Pred&& cmp_fun); + std::pair Insert(K&& key, V&& value, Hash_t key_hash, Pred&& pred); template void Split(HashFn&& hfunc, Segment* dest); @@ -501,7 +500,8 @@ template Iterator FindIt(U&& key, Hash_t key_hash, Pred&& cf) const; + // Find item with given key hash and truthy predicate + template Iterator FindIt(Hash_t key_hash, Pred&& pred) const; // Returns valid iterator if succeeded or invalid if not (it's full). // Requires: key should be not present in the segment. @@ -1021,9 +1021,9 @@ ___] |___ |__] | | |___ | \| | */ template -template -auto Segment::Bucket::FindByFp(uint8_t fp_hash, bool probe, U&& k, - Pred&& pred) const -> SlotId { +template +auto Segment::Bucket::FindByFp(uint8_t fp_hash, bool probe, Pred&& pred) const + -> SlotId { unsigned mask = this->Find(fp_hash, probe); if (!mask) return kNanSlot; @@ -1031,9 +1031,18 @@ auto Segment::Bucket::FindByFp(uint8_t fp_hash, bool probe, unsigned delta = __builtin_ctz(mask); mask >>= delta; for (unsigned i = delta; i < kSlotNum; ++i) { - if ((mask & 1) && pred(key[i], k)) { - return i; + // Filterable just by key + if constexpr (std::is_invocable_v) { + if ((mask & 1) && pred(key[i])) + return i; } + + // Filterable by key and value + if constexpr (std::is_invocable_v) { + if ((mask & 1) && pred(key[i], value[i])) + return i; + } + mask >>= 1; }; @@ -1094,9 +1103,9 @@ auto Segment::TryMoveFromStash(unsigned stash_id, unsigned s template template -auto Segment::Insert(U&& key, V&& value, Hash_t key_hash, Pred&& cmp_fun) +auto Segment::Insert(U&& key, V&& value, Hash_t key_hash, Pred&& pred) -> std::pair { - Iterator it = FindIt(key, key_hash, std::forward(cmp_fun)); + Iterator it = FindIt(key_hash, pred); if (it.found()) { return std::make_pair(it, false); /* duplicate insert*/ } @@ -1107,8 +1116,8 @@ auto Segment::Insert(U&& key, V&& value, Hash_t key_hash, Pr } template -template -auto Segment::FindIt(U&& key, Hash_t key_hash, Pred&& cf) const -> Iterator { +template +auto Segment::FindIt(Hash_t key_hash, Pred&& pred) const -> Iterator { uint8_t bidx = BucketIndex(key_hash); const Bucket& target = bucket_[bidx]; @@ -1117,7 +1126,7 @@ auto Segment::FindIt(U&& key, Hash_t key_hash, Pred&& cf) co __builtin_prefetch(&target); uint8_t fp_hash = key_hash & kFpMask; - SlotId sid = target.FindByFp(fp_hash, false, key, cf); + SlotId sid = target.FindByFp(fp_hash, false, pred); if (sid != BucketType::kNanSlot) { return Iterator{bidx, sid}; } @@ -1125,7 +1134,7 @@ auto Segment::FindIt(U&& key, Hash_t key_hash, Pred&& cf) co uint8_t nid = NextBid(bidx); const Bucket& probe = bucket_[nid]; - sid = probe.FindByFp(fp_hash, true, key, cf); + sid = probe.FindByFp(fp_hash, true, pred); #ifdef ENABLE_DASH_STATS stats.neighbour_probes++; @@ -1144,7 +1153,7 @@ auto Segment::FindIt(U&& key, Hash_t key_hash, Pred&& cf) co pos += kBucketNum; const Bucket& bucket = bucket_[pos]; - return bucket.FindByFp(fp_hash, false, key, cf); + return bucket.FindByFp(fp_hash, false, pred); }; if (target.HasStashOverflow()) { diff --git a/src/core/dash_test.cc b/src/core/dash_test.cc index 3bfacee38975..628c2eb7b918 100644 --- a/src/core/dash_test.cc +++ b/src/core/dash_test.cc @@ -36,6 +36,10 @@ static uint64_t callbackHash(const void* key) { return XXH64(&key, sizeof(key), 0); } +template auto EqTo(const K& key) { + return [&key](const auto& probe) { return key == probe; }; +} + static dictType IntDict = {callbackHash, NULL, NULL, NULL, NULL, NULL, NULL}; static uint64_t dictSdsHash(const void* key) { @@ -136,8 +140,7 @@ class DashTest : public testing::Test { bool Find(Segment::Key_t key, Segment::Value_t* val) const { uint64_t hash = dt_.DoHash(key); - std::equal_to eq; - auto it = segment_.FindIt(key, hash, eq); + auto it = segment_.FindIt(hash, EqTo(key)); if (!it.found()) return false; *val = segment_.Value(it.index, it.slot); @@ -146,9 +149,7 @@ class DashTest : public testing::Test { bool Contains(Segment::Key_t key) const { uint64_t hash = dt_.DoHash(key); - - std::equal_to eq; - auto it = segment_.FindIt(key, hash, eq); + auto it = segment_.FindIt(hash, EqTo(key)); return it.found(); } @@ -161,7 +162,6 @@ class DashTest : public testing::Test { set DashTest::FillSegment(unsigned bid) { std::set keys; - std::equal_to eq; for (Segment::Key_t key = 0; key < 1000000u; ++key) { uint64_t hash = dt_.DoHash(key); unsigned bi = (hash >> 8) % Segment::kBucketNum; @@ -170,7 +170,7 @@ set DashTest::FillSegment(unsigned bid) { uint8_t fp = hash & 0xFF; if (fp > 2) // limit fps considerably to find interesting cases. continue; - auto [it, success] = segment_.Insert(key, 0, hash, eq); + auto [it, success] = segment_.Insert(key, 0, hash, EqTo(key)); if (!success) { LOG(INFO) << "Stopped at " << key; break; @@ -203,10 +203,9 @@ TEST_F(DashTest, Basic) { Segment::Key_t key = 0; Segment::Value_t val = 0; uint64_t hash = dt_.DoHash(key); - std::equal_to eq; - EXPECT_TRUE(segment_.Insert(key, val, hash, eq).second); - auto [it, res] = segment_.Insert(key, val, hash, eq); + EXPECT_TRUE(segment_.Insert(key, val, hash, EqTo(key)).second); + auto [it, res] = segment_.Insert(key, val, hash, EqTo(key)); EXPECT_TRUE(!res && it.found()); EXPECT_TRUE(Find(key, &val)); @@ -262,10 +261,10 @@ TEST_F(DashTest, Segment) { const auto* k = &segment_.Key(i, 0); next = std::copy(k, k + Segment::kSlotNum, next); } - std::equal_to eq; + for (auto k : arr) { auto hash = hfun(k); - auto it = segment_.FindIt(k, hash, eq); + auto it = segment_.FindIt(hash, [&k](const auto& probe) { return k == probe; }); ASSERT_TRUE(it.found()); segment_.Delete(it, hash); } @@ -319,10 +318,10 @@ TEST_F(DashTest, Split) { segment_.Split(&UInt64Policy::HashFn, &s2); unsigned sum[2] = {0}; - std::equal_to eq; for (auto key : keys) { - auto it1 = segment_.FindIt(key, dt_.DoHash(key), eq); - auto it2 = s2.FindIt(key, dt_.DoHash(key), eq); + auto eq = [key](const auto& probe) { return key == probe; }; + auto it1 = segment_.FindIt(dt_.DoHash(key), eq); + auto it2 = s2.FindIt(dt_.DoHash(key), eq); ASSERT_NE(it1.found(), it2.found()) << key; sum[0] += it1.found(); @@ -476,12 +475,27 @@ TEST_F(DashTest, Custom) { (void)kBuckSz; ItemSegment seg{2}; - auto cb = [](auto v, auto u) { return v.buf[0] == u.buf[0] && v.buf[1] == u.buf[1]; }; - auto it = seg.FindIt(Item{1, 1}, 42, cb); + auto eq = [v = Item{1, 1}](auto u) { return v.buf[0] == u.buf[0] && v.buf[1] == u.buf[1]; }; + auto it = seg.FindIt(42, eq); ASSERT_FALSE(it.found()); } +TEST_F(DashTest, FindByValue) { + using ItemSegment = detail::Segment; + + // Insert three different values with the same hash + ItemSegment segment{2}; + segment.Insert(Item{1}, 1, 42, [](const auto& pred) { return pred.buf[0] == 1; }); + segment.Insert(Item{2}, 2, 42, [](const auto& pred) { return pred.buf[0] == 2; }); + segment.Insert(Item{3}, 3, 42, [](const auto& pred) { return pred.buf[0] == 3; }); + + // We should be able to find the middle one by value + auto it = segment.FindIt(42, [](const auto& key, const auto& value) { return value == 2; }); + EXPECT_TRUE(it.found()); + EXPECT_EQ(segment.Value(it.index, it.slot), 2); +} + TEST_F(DashTest, Reserve) { unsigned bc = dt_.capacity(); for (unsigned i = 0; i <= bc * 2; ++i) {