Skip to content

Commit

Permalink
Use search_pool to control iterator->Next()
Browse files Browse the repository at this point in the history
Signed-off-by: min.tian <[email protected]>
  • Loading branch information
alwayslove2013 committed Dec 28, 2024
1 parent d448fd6 commit 98048fc
Show file tree
Hide file tree
Showing 14 changed files with 413 additions and 287 deletions.
2 changes: 1 addition & 1 deletion include/knowhere/comp/brute_force.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class BruteForce {
template <typename DataType>
static expected<std::vector<IndexNode::IteratorPtr>>
AnnIterator(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config,
const BitsetView& bitset);
const BitsetView& bitset, bool use_knowhere_search_pool = true);
};

} // namespace knowhere
Expand Down
3 changes: 3 additions & 0 deletions include/knowhere/expected.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ enum class Status {
internal_error = 27,
invalid_serialized_index_type = 28,
sparse_inner_error = 29,
brute_force_inner_error = 30,
};

inline std::string
Expand Down Expand Up @@ -104,6 +105,8 @@ Status2String(knowhere::Status status) {
return "the serialized index type is not recognized";
case knowhere::Status::sparse_inner_error:
return "sparse index inner error";
case knowhere::Status::brute_force_inner_error:
return "brute_force index inner error";
default:
return "unexpected status";
}
Expand Down
3 changes: 2 additions & 1 deletion include/knowhere/index/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ class Index {
Search(const DataSetPtr dataset, const Json& json, const BitsetView& bitset) const;

expected<std::vector<IndexNode::IteratorPtr>>
AnnIterator(const DataSetPtr dataset, const Json& json, const BitsetView& bitset) const;
AnnIterator(const DataSetPtr dataset, const Json& json, const BitsetView& bitset,
bool use_knowhere_search_pool = true) const;

expected<DataSetPtr>
RangeSearch(const DataSetPtr dataset, const Json& json, const BitsetView& bitset) const;
Expand Down
162 changes: 121 additions & 41 deletions include/knowhere/index/index_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ class IndexNode : public Object {
using IteratorPtr = std::shared_ptr<iterator>;

virtual expected<std::vector<IteratorPtr>>
AnnIterator(const DataSetPtr dataset, std::unique_ptr<Config> cfg, const BitsetView& bitset) const {
AnnIterator(const DataSetPtr dataset, std::unique_ptr<Config> cfg, const BitsetView& bitset,
bool use_knowhere_search_pool = true) const {
return expected<std::vector<std::shared_ptr<iterator>>>::Err(
Status::not_implemented, "annIterator not supported for current index type");
}
Expand Down Expand Up @@ -202,7 +203,10 @@ class IndexNode : public Object {
return GenResultDataSet(nq, std::move(range_search_result));
}

auto its_or = AnnIterator(dataset, std::move(cfg), bitset);
// The range_search function has utilized the search_pool to concurrently handle various queries.
// To prevent potential deadlocks, the iterator for a single query no longer requires additional thread
// control over the next() call.
auto its_or = AnnIterator(dataset, std::move(cfg), bitset, false);
if (!its_or.has_value()) {
return expected<DataSetPtr>::Err(its_or.error(),
"RangeSearch failed due to AnnIterator failure: " + its_or.what());
Expand Down Expand Up @@ -290,13 +294,17 @@ class IndexNode : public Object {
futs.reserve(nq);
if (retain_iterator_order) {
for (size_t i = 0; i < nq; i++) {
futs.emplace_back(
ThreadPool::GetGlobalSearchThreadPool()->push([&, idx = i]() { task_with_ordered_iterator(idx); }));
futs.emplace_back(ThreadPool::GetGlobalSearchThreadPool()->push([&, idx = i]() {
ThreadPool::ScopedSearchOmpSetter setter(1);
task_with_ordered_iterator(idx);
}));
}
} else {
for (size_t i = 0; i < nq; i++) {
futs.emplace_back(ThreadPool::GetGlobalSearchThreadPool()->push(
[&, idx = i]() { task_with_unordered_iterator(idx); }));
futs.emplace_back(ThreadPool::GetGlobalSearchThreadPool()->push([&, idx = i]() {
ThreadPool::ScopedSearchOmpSetter setter(1);
task_with_unordered_iterator(idx);
}));
}
}
WaitAllSuccess(futs);
Expand Down Expand Up @@ -437,6 +445,10 @@ class IndexNode : public Object {
// with quantization, override `raw_distance`.
// Internally, this structure uses the same priority queue class, but may multiply all
// incoming distances to (-1) value in order to turn max priority queue into a min one.
// If use_knowhere_search_pool is True (the default), the iterator->Next() will be scheduled by the
// knowhere_search_thread_pool.
// If False, will Not involve thread scheduling internally, so please take caution.
template <bool use_knowhere_search_pool = true>
class IndexIterator : public IndexNode::iterator {
public:
IndexIterator(bool larger_is_closer, float refine_ratio = 0.0f, bool retain_iterator_order = false)
Expand All @@ -449,36 +461,54 @@ class IndexIterator : public IndexNode::iterator {
std::pair<int64_t, float>
Next() override {
if (!initialized_) {
throw std::runtime_error("Next should not be called before initialization");
initialize();
}
auto& q = refined_res_.empty() ? res_ : refined_res_;
if (q.empty()) {
throw std::runtime_error("No more elements");
}
auto ret = q.top();
q.pop();
UpdateNext();
if (retain_iterator_order_) {
while (HasNext()) {
auto& q = refined_res_.empty() ? res_ : refined_res_;
auto next_ret = q.top();
// with the help of `sign_`, both `res_` and `refine_res` are min-heap.
// such as `COSINE`, `-dist` will be inserted to `res_` or `refine_res`.
// just make sure that the next value is greater than or equal to the current value.
if (next_ret.val >= ret.val) {
break;

auto update_next_func = [&]() {
UpdateNext();
if (retain_iterator_order_) {
while (HasNext()) {
auto& q = refined_res_.empty() ? res_ : refined_res_;
auto next_ret = q.top();
// with the help of `sign_`, both `res_` and `refine_res` are min-heap.
// such as `COSINE`, `-dist` will be inserted to `res_` or `refine_res`.
// just make sure that the next value is greater than or equal to the current value.
if (next_ret.val >= ret.val) {
break;
}
q.pop();
UpdateNext();
}
q.pop();
UpdateNext();
}
};
if constexpr (use_knowhere_search_pool) {
#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT)
std::vector<folly::Future<folly::Unit>> futs;
futs.emplace_back(ThreadPool::GetGlobalSearchThreadPool()->push([&]() {
ThreadPool::ScopedSearchOmpSetter setter(1);
update_next_func();
}));
WaitAllSuccess(futs);
#else
update_next_func();
#endif
} else {
update_next_func();
}

return std::make_pair(ret.id, ret.val * sign_);
}

[[nodiscard]] bool
HasNext() override {
if (!initialized_) {
throw std::runtime_error("HasNext should not be called before initialization");
initialize();
}
return !res_.empty() || !refined_res_.empty();
}
Expand Down Expand Up @@ -544,42 +574,90 @@ class IndexIterator : public IndexNode::iterator {
};

// An iterator implementation that accepts a list of distances and ids and returns them in order.
template <typename Compute_Dist_Func, bool use_knowhere_search_pool = true>
class PrecomputedDistanceIterator : public IndexNode::iterator {
public:
PrecomputedDistanceIterator(std::vector<DistId>&& distances_ids, bool larger_is_closer)
: larger_is_closer_(larger_is_closer), results_(std::move(distances_ids)) {
sort_size_ = get_sort_size(results_.size());
sort_next();
}

// Construct an iterator from a list of distances with index being id, filtering out zero distances.
PrecomputedDistanceIterator(const std::vector<float>& distances, bool larger_is_closer)
: larger_is_closer_(larger_is_closer) {
// 30% is a ratio guesstimate of non-zero distances: probability of 2 random sparse splade vectors(100 non zero
// dims out of 30000 total dims) sharing at least 1 common non-zero dimension.
results_.reserve(distances.size() * 0.3);
for (size_t i = 0; i < distances.size(); i++) {
if (distances[i] != 0) {
results_.emplace_back((int64_t)i, distances[i]);
}
}
sort_size_ = get_sort_size(results_.size());
sort_next();
PrecomputedDistanceIterator(Compute_Dist_Func compute_dist_func, bool larger_is_closer)
: compute_dist_func_(compute_dist_func), larger_is_closer_(larger_is_closer) {
}

std::pair<int64_t, float>
Next() override {
sort_next();
if (!initialized_) {
initialize();
}
if constexpr (use_knowhere_search_pool) {
#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT)
std::vector<folly::Future<folly::Unit>> futs;
futs.emplace_back(ThreadPool::GetGlobalSearchThreadPool()->push([&]() {
ThreadPool::ScopedSearchOmpSetter setter(1);
sort_next();
}));
WaitAllSuccess(futs);
#else
sort_next();
#endif
} else {
sort_next();
}
auto& result = results_[next_++];
return std::make_pair(result.id, result.val);
}

[[nodiscard]] bool
HasNext() override {
if (!initialized_) {
initialize();
}
return next_ < results_.size() && results_[next_].id != -1;
}

void
initialize() {
if (initialized_) {
throw std::runtime_error("initialize should not be called twice");
}
if constexpr (use_knowhere_search_pool) {
#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT)
std::vector<folly::Future<folly::Unit>> futs;
futs.emplace_back(ThreadPool::GetGlobalSearchThreadPool()->push([&]() {
ThreadPool::ScopedSearchOmpSetter setter(1);
compute_all_dist_ids();
}));
WaitAllSuccess(futs);
#else
compute_all_dist_ids();
#endif
} else {
compute_all_dist_ids();
}
sort_next();
initialized_ = true;
}

private:
void
compute_all_dist_ids() {
using ReturnType = std::invoke_result_t<Compute_Dist_Func>;
if constexpr (std::is_same_v<ReturnType, std::vector<DistId>>) {
results_ = compute_dist_func_();
} else if constexpr (std::is_same_v<ReturnType, std::vector<float>>) {
// From a list of distances with index being id, filtering out zero distances.
std::vector<float> dists = compute_dist_func_();
// 30% is a ratio guesstimate of non-zero distances: probability of 2 random sparse splade vectors(100 non
// zero dims out of 30000 total dims) sharing at least 1 common non-zero dimension.
results_.reserve(dists.size() * 0.3);
for (size_t i = 0; i < dists.size(); i++) {
if (dists[i] != 0) {
results_.emplace_back((int64_t)i, dists[i]);
}
}
} else {
throw std::runtime_error("unknown compute_dist_func");
}
sort_size_ = get_sort_size(results_.size());
}

static inline size_t
get_sort_size(size_t rows) {
return std::max((size_t)50000, rows / 10);
Expand All @@ -602,8 +680,10 @@ class PrecomputedDistanceIterator : public IndexNode::iterator {

sorted_ = current_end;
}
const bool larger_is_closer_;

Compute_Dist_Func compute_dist_func_;
const bool larger_is_closer_;
bool initialized_ = false;
std::vector<DistId> results_;
size_t next_ = 0;
size_t sorted_ = 0;
Expand Down
3 changes: 2 additions & 1 deletion include/knowhere/index/index_node_data_mock_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ class IndexNodeDataMockWrapper : public IndexNode {
RangeSearch(const DataSetPtr dataset, std::unique_ptr<Config> cfg, const BitsetView& bitset) const override;

expected<std::vector<IteratorPtr>>
AnnIterator(const DataSetPtr dataset, std::unique_ptr<Config> cfg, const BitsetView& bitset) const override;
AnnIterator(const DataSetPtr dataset, std::unique_ptr<Config> cfg, const BitsetView& bitset,
bool use_knowhere_search_pool) const override;

expected<DataSetPtr>
GetVectorByIds(const DataSetPtr dataset) const override;
Expand Down
Loading

0 comments on commit 98048fc

Please sign in to comment.