Skip to content

Commit fb904b9

Browse files
authored
chore: refactor hnsw index into separate files (#6134)
Signed-off-by: Roman Gershman <[email protected]>
1 parent 96afdef commit fb904b9

File tree

6 files changed

+236
-207
lines changed

6 files changed

+236
-207
lines changed

src/core/search/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ cur_gen_dir(gen_dir)
55

66
set_source_files_properties(${gen_dir}/parser.cc PROPERTIES
77
COMPILE_FLAGS "-Wno-maybe-uninitialized")
8-
add_library(dfly_search_core base.cc ast_expr.cc query_driver.cc search.cc indices.cc
9-
sort_indices.cc vector_utils.cc compressed_sorted_set.cc block_list.cc
8+
add_library(dfly_search_core ast_expr.cc base.cc hnsw_index.cc query_driver.cc search.cc
9+
indices.cc sort_indices.cc vector_utils.cc compressed_sorted_set.cc block_list.cc
1010
range_tree.cc synonyms.cc ${gen_dir}/parser.cc ${gen_dir}/lexer.cc)
1111

1212
target_link_libraries(dfly_search_core base redis_lib absl::strings

src/core/search/hnsw_index.cc

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
// Copyright 2023, DragonflyDB authors. All rights reserved.
2+
// See LICENSE for licensing terms.
3+
//
4+
5+
#include "core/search/hnsw_index.h"
6+
7+
#include <absl/strings/match.h>
8+
#include <absl/synchronization/mutex.h>
9+
#include <hnswlib/hnswalg.h>
10+
#include <hnswlib/hnswlib.h>
11+
#include <hnswlib/space_ip.h>
12+
#include <hnswlib/space_l2.h>
13+
14+
#include "base/logging.h"
15+
#include "core/search/vector_utils.h"
16+
17+
namespace dfly::search {
18+
19+
using namespace std;
20+
21+
namespace {
22+
23+
class HnswSpace : public hnswlib::SpaceInterface<float> {
24+
unsigned dim_;
25+
VectorSimilarity sim_;
26+
27+
static float L2DistanceStatic(const void* pVect1, const void* pVect2, const void* param) {
28+
return L2Distance(static_cast<const float*>(pVect1), static_cast<const float*>(pVect2),
29+
*static_cast<const unsigned*>(param));
30+
}
31+
32+
static float IPDistanceStatic(const void* pVect1, const void* pVect2, const void* param) {
33+
return IPDistance(static_cast<const float*>(pVect1), static_cast<const float*>(pVect2),
34+
*static_cast<const unsigned*>(param));
35+
}
36+
37+
public:
38+
explicit HnswSpace(size_t dim, VectorSimilarity sim) : dim_(dim), sim_(sim) {
39+
}
40+
41+
size_t get_data_size() {
42+
return dim_ * sizeof(float);
43+
}
44+
45+
hnswlib::DISTFUNC<float> get_dist_func() {
46+
if (sim_ == VectorSimilarity::L2) {
47+
return L2DistanceStatic;
48+
} else {
49+
return IPDistanceStatic;
50+
}
51+
}
52+
53+
void* get_dist_func_param() {
54+
return &dim_;
55+
}
56+
};
57+
} // namespace
58+
59+
struct HnswlibAdapter {
60+
// Default setting of hnswlib/hnswalg
61+
constexpr static size_t kDefaultEfRuntime = 10;
62+
63+
explicit HnswlibAdapter(const SchemaField::VectorParams& params)
64+
: space_{params.dim, params.sim},
65+
world_{&space_, params.capacity, params.hnsw_m, params.hnsw_ef_construction,
66+
100 /* seed*/} {
67+
}
68+
69+
void Add(const float* data, GlobalDocId id) {
70+
while (true) {
71+
try {
72+
absl::ReaderMutexLock lock(&resize_mutex_);
73+
world_.addPoint(data, id);
74+
return;
75+
} catch (const std::exception& e) {
76+
std::string error_msg = e.what();
77+
if (absl::StrContains(error_msg, "The number of elements exceeds the specified limit")) {
78+
ResizeIfFull();
79+
continue;
80+
}
81+
LOG(ERROR) << "HnswlibAdapter::Add exception: " << e.what();
82+
}
83+
}
84+
}
85+
86+
void Remove(GlobalDocId id) {
87+
try {
88+
world_.markDelete(id);
89+
} catch (const std::exception& e) {
90+
LOG(WARNING) << "HnswlibAdapter::Remove exception: " << e.what();
91+
}
92+
}
93+
94+
vector<pair<float, GlobalDocId>> Knn(float* target, size_t k, std::optional<size_t> ef) {
95+
world_.setEf(ef.value_or(kDefaultEfRuntime));
96+
return QueueToVec(world_.searchKnn(target, k));
97+
}
98+
99+
vector<pair<float, GlobalDocId>> Knn(float* target, size_t k, std::optional<size_t> ef,
100+
const vector<GlobalDocId>& allowed) {
101+
struct BinsearchFilter : hnswlib::BaseFilterFunctor {
102+
virtual bool operator()(hnswlib::labeltype id) {
103+
return binary_search(allowed->begin(), allowed->end(), id);
104+
}
105+
106+
BinsearchFilter(const vector<GlobalDocId>* allowed) : allowed{allowed} {
107+
}
108+
const vector<GlobalDocId>* allowed;
109+
};
110+
111+
world_.setEf(ef.value_or(kDefaultEfRuntime));
112+
BinsearchFilter filter{&allowed};
113+
return QueueToVec(world_.searchKnn(target, k, &filter));
114+
}
115+
116+
private:
117+
// Function requires that we hold mutex while resizing index. resizeIndex is not thread safe with
118+
// insertion (https://github.com/nmslib/hnswlib/issues/267)
119+
void ResizeIfFull() {
120+
{
121+
// First check with reader lock to avoid contention.
122+
absl::ReaderMutexLock lock(&resize_mutex_);
123+
if (world_.getCurrentElementCount() < world_.getMaxElements() ||
124+
(world_.allow_replace_deleted_ && world_.getDeletedCount() > 0)) {
125+
return;
126+
}
127+
}
128+
try {
129+
// Upgrade to writer lock.
130+
absl::WriterMutexLock lock(&resize_mutex_);
131+
if (world_.getCurrentElementCount() == world_.getMaxElements() &&
132+
(!world_.allow_replace_deleted_ || world_.getDeletedCount() == 0)) {
133+
auto max_elements = world_.getMaxElements();
134+
world_.resizeIndex(max_elements * 2);
135+
VLOG(1) << "Resizing HNSW Index from " << max_elements << " to " << max_elements * 2;
136+
}
137+
} catch (const std::exception& e) {
138+
LOG(FATAL) << "HnswlibAdapter::ResizeIfFull exception: " << e.what();
139+
}
140+
}
141+
142+
template <typename Q> static vector<pair<float, GlobalDocId>> QueueToVec(Q queue) {
143+
vector<pair<float, GlobalDocId>> out(queue.size());
144+
size_t idx = out.size();
145+
while (!queue.empty()) {
146+
out[--idx] = queue.top();
147+
queue.pop();
148+
}
149+
return out;
150+
}
151+
152+
HnswSpace space_;
153+
hnswlib::HierarchicalNSW<float> world_;
154+
absl::Mutex resize_mutex_;
155+
};
156+
157+
HnswVectorIndex::HnswVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource*)
158+
: dim_{params.dim}, sim_{params.sim}, adapter_{make_unique<HnswlibAdapter>(params)} {
159+
DCHECK(params.use_hnsw);
160+
// TODO: Patch hnsw to use MR
161+
}
162+
163+
HnswVectorIndex::~HnswVectorIndex() {
164+
}
165+
166+
bool HnswVectorIndex::Add(GlobalDocId id, const DocumentAccessor& doc, std::string_view field) {
167+
auto vector = doc.GetVector(field);
168+
169+
if (!vector) {
170+
return false;
171+
}
172+
173+
auto& [ptr, size] = vector.value();
174+
175+
if (ptr && size != dim_) {
176+
return false;
177+
}
178+
179+
if (ptr) {
180+
adapter_->Add(ptr.get(), id);
181+
}
182+
183+
return true;
184+
}
185+
186+
std::vector<std::pair<float, GlobalDocId>> HnswVectorIndex::Knn(float* target, size_t k,
187+
std::optional<size_t> ef) const {
188+
return adapter_->Knn(target, k, ef);
189+
}
190+
191+
std::vector<std::pair<float, GlobalDocId>> HnswVectorIndex::Knn(
192+
float* target, size_t k, std::optional<size_t> ef,
193+
const std::vector<GlobalDocId>& allowed) const {
194+
return adapter_->Knn(target, k, ef, allowed);
195+
}
196+
197+
void HnswVectorIndex::Remove(GlobalDocId id, const DocumentAccessor& doc, string_view field) {
198+
adapter_->Remove(id);
199+
}
200+
} // namespace dfly::search

src/core/search/hnsw_index.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Copyright 2023, DragonflyDB authors. All rights reserved.
2+
// See LICENSE for licensing terms.
3+
//
4+
5+
#pragma once
6+
7+
#include "core/search/search.h"
8+
9+
namespace dfly::search {
10+
11+
struct HnswlibAdapter;
12+
class HnswVectorIndex {
13+
public:
14+
explicit HnswVectorIndex(const search::SchemaField::VectorParams& params,
15+
PMR_NS::memory_resource* mr = PMR_NS::get_default_resource());
16+
17+
~HnswVectorIndex();
18+
19+
bool Add(search::GlobalDocId id, const search::DocumentAccessor& doc, std::string_view field);
20+
void Remove(search::GlobalDocId id, const search::DocumentAccessor& doc, std::string_view field);
21+
22+
std::vector<std::pair<float, GlobalDocId>> Knn(float* target, size_t k,
23+
std::optional<size_t> ef) const;
24+
std::vector<std::pair<float, GlobalDocId>> Knn(float* target, size_t k, std::optional<size_t> ef,
25+
const std::vector<GlobalDocId>& allowed) const;
26+
27+
private:
28+
size_t dim_;
29+
VectorSimilarity sim_;
30+
std::unique_ptr<HnswlibAdapter> adapter_;
31+
};
32+
33+
} // namespace dfly::search

0 commit comments

Comments
 (0)