Skip to content

Commit

Permalink
chore(dash): Replace comparator with predicate (#3025)
Browse files Browse the repository at this point in the history
Signed-off-by: Vladislav Oleshko <[email protected]>
  • Loading branch information
dranikpg authored May 8, 2024
1 parent 25e6930 commit d675e63
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 51 deletions.
37 changes: 20 additions & 17 deletions src/core/dash.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ class DashTable : public detail::DashTableBase {
template <typename U> const_iterator Find(U&& key) const;
template <typename U> iterator Find(U&& key);

// Find first entry with given key hash that evaluates to true on pred.
// Pred accepts either (const key&) or (const key&, const value&)
template <typename Pred> iterator FindFirst(uint64_t key_hash, Pred&& pred);

// it must be valid.
void Erase(iterator it);

Expand Down Expand Up @@ -308,8 +312,8 @@ class DashTable : public detail::DashTableBase {
// the same object. IterateDistinct goes over all distinct segments in the table.
template <typename Cb> void IterateDistinct(Cb&& cb);

auto EqPred() const {
return [p = &policy_](const auto& a, const auto& b) -> bool { return p->Equal(a, b); };
template <typename K> auto EqPred(const K& key) const {
return [p = &policy_, &key](const auto& probe) -> bool { return p->Equal(probe, key); };
}

Policy policy_;
Expand Down Expand Up @@ -669,40 +673,39 @@ template <typename U>
auto DashTable<_Key, _Value, Policy>::Find(U&& key) const -> const_iterator {
uint64_t key_hash = DoHash(key);
uint32_t seg_id = SegmentId(key_hash); // seg_id takes up global_depth_ high bits.
const auto* target = segment_[seg_id];

// Hash structure is like this: [SSUUUUBF], where S is segment id, U - unused,
// B - bucket id and F is a fingerprint. Segment id is needed to identify the correct segment.
// Once identified, the segment instance uses the lower part of hash to locate the key.
// It uses 8 least significant bits for a fingerprint and few more bits for bucket id.
auto seg_it = target->FindIt(key, key_hash, EqPred());

if (seg_it.found()) {
return const_iterator{this, seg_id, seg_it.index, seg_it.slot};
if (auto seg_it = segment_[seg_id]->FindIt(key_hash, EqPred(key)); seg_it.found()) {
return {this, seg_id, seg_it.index, seg_it.slot};
}
return const_iterator{};
return {};
}

template <typename _Key, typename _Value, typename Policy>
template <typename U>
auto DashTable<_Key, _Value, Policy>::Find(U&& key) -> iterator {
uint64_t key_hash = DoHash(key);
uint32_t segid = SegmentId(key_hash);
const auto* target = segment_[segid];
return FindFirst(DoHash(key), EqPred(key));
}

auto seg_it = target->FindIt(key, key_hash, EqPred());
if (seg_it.found()) {
return iterator{this, segid, seg_it.index, seg_it.slot};
template <typename _Key, typename _Value, typename Policy>
template <typename Pred>
auto DashTable<_Key, _Value, Policy>::FindFirst(uint64_t key_hash, Pred&& pred) -> iterator {
uint32_t seg_id = SegmentId(key_hash);
if (auto seg_it = segment_[seg_id]->FindIt(key_hash, pred); seg_it.found()) {
return {this, seg_id, seg_it.index, seg_it.slot};
}
return iterator{};
return {};
}

template <typename _Key, typename _Value, typename Policy>
size_t DashTable<_Key, _Value, Policy>::Erase(const Key_t& key) {
uint64_t key_hash = DoHash(key);
size_t x = SegmentId(key_hash);
auto* target = segment_[x];
auto it = target->FindIt(key, key_hash, EqPred());
auto it = target->FindIt(key_hash, EqPred(key));
if (!it.found())
return 0;

Expand Down Expand Up @@ -764,7 +767,7 @@ auto DashTable<_Key, _Value, Policy>::InsertInternal(U&& key, V&& value, Evictio
res = it.found();
} else {
std::tie(it, res) =
target->Insert(std::forward<U>(key), std::forward<V>(value), key_hash, EqPred());
target->Insert(std::forward<U>(key), std::forward<V>(value), key_hash, EqPred(key));
}

if (res) { // success
Expand Down
43 changes: 26 additions & 17 deletions src/core/dash_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <functional>
#include <type_traits>

#include "core/sse_port.h"

Expand Down Expand Up @@ -328,8 +328,7 @@ template <typename _Key, typename _Value, typename Policy = DefaultSegmentPolicy
this->SetHash(slot, meta_hash, probe);
}

template <typename U, typename Pred>
SlotId FindByFp(uint8_t fp_hash, bool probe, U&& k, Pred&& pred) const;
template <typename Pred> SlotId FindByFp(uint8_t fp_hash, bool probe, Pred&& pred) const;

bool ShiftRight();

Expand Down Expand Up @@ -403,7 +402,7 @@ template <typename _Key, typename _Value, typename Policy = DefaultSegmentPolicy
// Returns (iterator, true) if insert succeeds,
// (iterator, false) for duplicate and (invalid-iterator, false) if it's full
template <typename K, typename V, typename Pred>
std::pair<Iterator, bool> Insert(K&& key, V&& value, Hash_t key_hash, Pred&& cmp_fun);
std::pair<Iterator, bool> Insert(K&& key, V&& value, Hash_t key_hash, Pred&& pred);

template <typename HashFn> void Split(HashFn&& hfunc, Segment* dest);

Expand Down Expand Up @@ -501,7 +500,8 @@ template <typename _Key, typename _Value, typename Policy = DefaultSegmentPolicy
dest[3] = NextBid(dest[2]);
}

template <typename U, typename Pred> Iterator FindIt(U&& key, Hash_t key_hash, Pred&& cf) const;
// Find item with given key hash and truthy predicate
template <typename Pred> Iterator FindIt(Hash_t key_hash, Pred&& pred) const;

// Returns valid iterator if succeeded or invalid if not (it's full).
// Requires: key should be not present in the segment.
Expand Down Expand Up @@ -1021,19 +1021,28 @@ ___] |___ |__] | | |___ | \| |
*/

template <typename Key, typename Value, typename Policy>
template <typename U, typename Pred>
auto Segment<Key, Value, Policy>::Bucket::FindByFp(uint8_t fp_hash, bool probe, U&& k,
Pred&& pred) const -> SlotId {
template <typename Pred>
auto Segment<Key, Value, Policy>::Bucket::FindByFp(uint8_t fp_hash, bool probe, Pred&& pred) const
-> SlotId {
unsigned mask = this->Find(fp_hash, probe);
if (!mask)
return kNanSlot;

unsigned delta = __builtin_ctz(mask);
mask >>= delta;
for (unsigned i = delta; i < kSlotNum; ++i) {
if ((mask & 1) && pred(key[i], k)) {
return i;
// Filterable just by key
if constexpr (std::is_invocable_v<Pred, const Key_t&>) {
if ((mask & 1) && pred(key[i]))
return i;
}

// Filterable by key and value
if constexpr (std::is_invocable_v<Pred, const Key_t&, const Value_t&>) {
if ((mask & 1) && pred(key[i], value[i]))
return i;
}

mask >>= 1;
};

Expand Down Expand Up @@ -1094,9 +1103,9 @@ auto Segment<Key, Value, Policy>::TryMoveFromStash(unsigned stash_id, unsigned s

template <typename Key, typename Value, typename Policy>
template <typename U, typename V, typename Pred>
auto Segment<Key, Value, Policy>::Insert(U&& key, V&& value, Hash_t key_hash, Pred&& cmp_fun)
auto Segment<Key, Value, Policy>::Insert(U&& key, V&& value, Hash_t key_hash, Pred&& pred)
-> std::pair<Iterator, bool> {
Iterator it = FindIt(key, key_hash, std::forward<Pred>(cmp_fun));
Iterator it = FindIt(key_hash, pred);
if (it.found()) {
return std::make_pair(it, false); /* duplicate insert*/
}
Expand All @@ -1107,8 +1116,8 @@ auto Segment<Key, Value, Policy>::Insert(U&& key, V&& value, Hash_t key_hash, Pr
}

template <typename Key, typename Value, typename Policy>
template <typename U, typename Pred>
auto Segment<Key, Value, Policy>::FindIt(U&& key, Hash_t key_hash, Pred&& cf) const -> Iterator {
template <typename Pred>
auto Segment<Key, Value, Policy>::FindIt(Hash_t key_hash, Pred&& pred) const -> Iterator {
uint8_t bidx = BucketIndex(key_hash);
const Bucket& target = bucket_[bidx];

Expand All @@ -1117,15 +1126,15 @@ auto Segment<Key, Value, Policy>::FindIt(U&& key, Hash_t key_hash, Pred&& cf) co
__builtin_prefetch(&target);

uint8_t fp_hash = key_hash & kFpMask;
SlotId sid = target.FindByFp(fp_hash, false, key, cf);
SlotId sid = target.FindByFp(fp_hash, false, pred);
if (sid != BucketType::kNanSlot) {
return Iterator{bidx, sid};
}

uint8_t nid = NextBid(bidx);
const Bucket& probe = bucket_[nid];

sid = probe.FindByFp(fp_hash, true, key, cf);
sid = probe.FindByFp(fp_hash, true, pred);

#ifdef ENABLE_DASH_STATS
stats.neighbour_probes++;
Expand All @@ -1144,7 +1153,7 @@ auto Segment<Key, Value, Policy>::FindIt(U&& key, Hash_t key_hash, Pred&& cf) co

pos += kBucketNum;
const Bucket& bucket = bucket_[pos];
return bucket.FindByFp(fp_hash, false, key, cf);
return bucket.FindByFp(fp_hash, false, pred);
};

if (target.HasStashOverflow()) {
Expand Down
48 changes: 31 additions & 17 deletions src/core/dash_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ static uint64_t callbackHash(const void* key) {
return XXH64(&key, sizeof(key), 0);
}

// Builds a unary predicate that reports whether the probed element equals `key`.
// The predicate stores only the address of `key` (no copy is made), so `key` must stay alive
// for as long as the returned predicate is used.
template <typename K> auto EqTo(const K& key) {
  const K* expected = &key;
  return [expected](const auto& candidate) { return *expected == candidate; };
}

static dictType IntDict = {callbackHash, NULL, NULL, NULL, NULL, NULL, NULL};

static uint64_t dictSdsHash(const void* key) {
Expand Down Expand Up @@ -136,8 +140,7 @@ class DashTest : public testing::Test {
bool Find(Segment::Key_t key, Segment::Value_t* val) const {
uint64_t hash = dt_.DoHash(key);

std::equal_to<Segment::Key_t> eq;
auto it = segment_.FindIt(key, hash, eq);
auto it = segment_.FindIt(hash, EqTo(key));
if (!it.found())
return false;
*val = segment_.Value(it.index, it.slot);
Expand All @@ -146,9 +149,7 @@ class DashTest : public testing::Test {

bool Contains(Segment::Key_t key) const {
uint64_t hash = dt_.DoHash(key);

std::equal_to<Segment::Key_t> eq;
auto it = segment_.FindIt(key, hash, eq);
auto it = segment_.FindIt(hash, EqTo(key));
return it.found();
}

Expand All @@ -161,7 +162,6 @@ class DashTest : public testing::Test {
set<Segment::Key_t> DashTest::FillSegment(unsigned bid) {
std::set<Segment::Key_t> keys;

std::equal_to<Segment::Key_t> eq;
for (Segment::Key_t key = 0; key < 1000000u; ++key) {
uint64_t hash = dt_.DoHash(key);
unsigned bi = (hash >> 8) % Segment::kBucketNum;
Expand All @@ -170,7 +170,7 @@ set<Segment::Key_t> DashTest::FillSegment(unsigned bid) {
uint8_t fp = hash & 0xFF;
if (fp > 2) // limit fps considerably to find interesting cases.
continue;
auto [it, success] = segment_.Insert(key, 0, hash, eq);
auto [it, success] = segment_.Insert(key, 0, hash, EqTo(key));
if (!success) {
LOG(INFO) << "Stopped at " << key;
break;
Expand Down Expand Up @@ -203,10 +203,9 @@ TEST_F(DashTest, Basic) {
Segment::Key_t key = 0;
Segment::Value_t val = 0;
uint64_t hash = dt_.DoHash(key);
std::equal_to<Segment::Key_t> eq;

EXPECT_TRUE(segment_.Insert(key, val, hash, eq).second);
auto [it, res] = segment_.Insert(key, val, hash, eq);
EXPECT_TRUE(segment_.Insert(key, val, hash, EqTo(key)).second);
auto [it, res] = segment_.Insert(key, val, hash, EqTo(key));
EXPECT_TRUE(!res && it.found());

EXPECT_TRUE(Find(key, &val));
Expand Down Expand Up @@ -262,10 +261,10 @@ TEST_F(DashTest, Segment) {
const auto* k = &segment_.Key(i, 0);
next = std::copy(k, k + Segment::kSlotNum, next);
}
std::equal_to<Segment::Key_t> eq;

for (auto k : arr) {
auto hash = hfun(k);
auto it = segment_.FindIt(k, hash, eq);
auto it = segment_.FindIt(hash, [&k](const auto& probe) { return k == probe; });
ASSERT_TRUE(it.found());
segment_.Delete(it, hash);
}
Expand Down Expand Up @@ -319,10 +318,10 @@ TEST_F(DashTest, Split) {

segment_.Split(&UInt64Policy::HashFn, &s2);
unsigned sum[2] = {0};
std::equal_to<Segment::Key_t> eq;
for (auto key : keys) {
auto it1 = segment_.FindIt(key, dt_.DoHash(key), eq);
auto it2 = s2.FindIt(key, dt_.DoHash(key), eq);
auto eq = [key](const auto& probe) { return key == probe; };
auto it1 = segment_.FindIt(dt_.DoHash(key), eq);
auto it2 = s2.FindIt(dt_.DoHash(key), eq);
ASSERT_NE(it1.found(), it2.found()) << key;

sum[0] += it1.found();
Expand Down Expand Up @@ -476,12 +475,27 @@ TEST_F(DashTest, Custom) {
(void)kBuckSz;

ItemSegment seg{2};
auto cb = [](auto v, auto u) { return v.buf[0] == u.buf[0] && v.buf[1] == u.buf[1]; };

auto it = seg.FindIt(Item{1, 1}, 42, cb);
auto eq = [v = Item{1, 1}](auto u) { return v.buf[0] == u.buf[0] && v.buf[1] == u.buf[1]; };
auto it = seg.FindIt(42, eq);
ASSERT_FALSE(it.found());
}

// Verifies that Segment::FindIt accepts a two-argument (key, value) predicate and can pick out
// one entry by its value among several entries stored under the same key hash.
TEST_F(DashTest, FindByValue) {
  using ItemSegment = detail::Segment<Item, uint64_t>;

  // Insert three different values with the same hash. Each insert passes a key-only predicate
  // describing the key being inserted, so none of them is treated as a duplicate of another.
  ItemSegment segment{2};
  ASSERT_TRUE(segment.Insert(Item{1}, 1, 42, [](const auto& probe) { return probe.buf[0] == 1; })
                  .second);
  ASSERT_TRUE(segment.Insert(Item{2}, 2, 42, [](const auto& probe) { return probe.buf[0] == 2; })
                  .second);
  ASSERT_TRUE(segment.Insert(Item{3}, 3, 42, [](const auto& probe) { return probe.buf[0] == 3; })
                  .second);

  // We should be able to find the middle one by value; the key argument is deliberately ignored.
  auto it = segment.FindIt(42, [](const auto& /*key*/, const auto& value) { return value == 2u; });

  // ASSERT rather than EXPECT: the iterator is dereferenced below, which is only meaningful
  // when the lookup succeeded.
  ASSERT_TRUE(it.found());
  EXPECT_EQ(segment.Value(it.index, it.slot), 2u);
}

TEST_F(DashTest, Reserve) {
unsigned bc = dt_.capacity();
for (unsigned i = 0; i <= bc * 2; ++i) {
Expand Down

0 comments on commit d675e63

Please sign in to comment.