Skip to content

Commit 76bec0b

Browse files
authored
Facebook sync (#573)
Features: - automatic tracking of C++ references in Python - non-intel platforms supported -- some functions optimized for ARM - override nprobe for concurrent searches - support for floating-point quantizers in binary indexes Bug fixes: - no more segfaults in python (I know it's the same as the first feature but it's important!) - fix GpuIndexIVFFlat issues for float32 with 64 / 128 dims - fix sharding of flat indexes on GPU with index_cpu_to_gpu_multiple
1 parent 19cea3d commit 76bec0b

54 files changed

Lines changed: 9828 additions & 4801 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

AutoTune.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,21 @@ void ParameterSpace::set_index_parameter (
518518
return;
519519
}
520520
}
521+
522+
if (name == "efSearch") {
523+
if (DC (IndexHNSW)) {
524+
ix->hnsw.efSearch = int(val);
525+
return;
526+
}
527+
if (DC (IndexIVF)) {
528+
if (IndexHNSW *cq =
529+
dynamic_cast<IndexHNSW *>(ix->quantizer)) {
530+
cq->hnsw.efSearch = int(val);
531+
return;
532+
}
533+
}
534+
}
535+
521536
FAISS_THROW_FMT ("ParameterSpace::set_index_parameter:"
522537
"could not set parameter %s",
523538
name.c_str());
@@ -682,6 +697,7 @@ struct VTChain {
682697
char get_trains_alone(const Index *coarse_quantizer) {
683698
return
684699
dynamic_cast<const MultiIndexQuantizer*>(coarse_quantizer) ? 1 :
700+
dynamic_cast<const IndexHNSWFlat*>(coarse_quantizer) ? 2 :
685701
0;
686702
}
687703

@@ -738,6 +754,11 @@ Index *index_factory (int d, const char *description_in, MetricType metric)
738754
} else if (stok == "L2norm") {
739755
vt_1 = new NormalizationTransform (d, 2.0);
740756

757+
// coarse quantizers
758+
} else if (!coarse_quantizer &&
759+
sscanf (tok, "IVF%d_HNSW%d", &ncentroids, &M) == 2) {
760+
FAISS_THROW_IF_NOT (metric == METRIC_L2);
761+
coarse_quantizer_1 = new IndexHNSWFlat (d, M);
741762

742763
} else if (!coarse_quantizer &&
743764
sscanf (tok, "IVF%d", &ncentroids) == 1) {
@@ -935,4 +956,5 @@ IndexBinary *index_binary_factory(int d, const char *description)
935956
}
936957

937958

959+
938960
} // namespace faiss

AutoTune.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,8 +205,6 @@ Index *index_factory (int d, const char *description,
205205

206206
IndexBinary *index_binary_factory (int d, const char *description);
207207

208-
209-
210208
} // namespace faiss
211209

212210

AuxIndexStructures.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ struct IOWriter {
198198

199199

200200
struct VectorIOReader:IOReader {
201-
const std::vector<uint8_t> data;
201+
std::vector<uint8_t> data;
202202
size_t rp = 0;
203203
size_t operator()(void *ptr, size_t size, size_t nitems) override;
204204
};

IVFlib.cpp

Lines changed: 332 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,332 @@
1+
/**
2+
* Copyright (c) 2015-present, Facebook, Inc.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD+Patents license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// -*- c++ -*-
10+
11+
/*
12+
 * implementation of helper functions for IndexIVF-based indexes
 * (merging, centroid search, sliding windows, inverted-list ranges)
13+
*/
14+
15+
#include "IVFlib.h"
16+
17+
#include <memory>
18+
19+
#include "VectorTransform.h"
20+
#include "FaissAssert.h"
21+
22+
23+
24+
namespace faiss { namespace ivflib {
25+
26+
27+
void check_compatible_for_merge (const Index * index0,
28+
const Index * index1)
29+
{
30+
31+
const faiss::IndexPreTransform *pt0 =
32+
dynamic_cast<const faiss::IndexPreTransform *>(index0);
33+
34+
if (pt0) {
35+
const faiss::IndexPreTransform *pt1 =
36+
dynamic_cast<const faiss::IndexPreTransform *>(index1);
37+
FAISS_THROW_IF_NOT_MSG (pt1, "both indexes should be pretransforms");
38+
39+
FAISS_THROW_IF_NOT (pt0->chain.size() == pt1->chain.size());
40+
for (int i = 0; i < pt0->chain.size(); i++) {
41+
FAISS_THROW_IF_NOT (typeid(pt0->chain[i]) == typeid(pt1->chain[i]));
42+
}
43+
44+
index0 = pt0->index;
45+
index1 = pt1->index;
46+
}
47+
FAISS_THROW_IF_NOT (typeid(index0) == typeid(index1));
48+
FAISS_THROW_IF_NOT (index0->d == index1->d &&
49+
index0->metric_type == index1->metric_type);
50+
51+
const faiss::IndexIVF *ivf0 = dynamic_cast<const faiss::IndexIVF *>(index0);
52+
if (ivf0) {
53+
const faiss::IndexIVF *ivf1 =
54+
dynamic_cast<const faiss::IndexIVF *>(index1);
55+
FAISS_THROW_IF_NOT (ivf1);
56+
57+
ivf0->check_compatible_for_merge (*ivf1);
58+
}
59+
60+
// TODO: check as thoroughfully for other index types
61+
62+
}
63+
64+
/** Return the IndexIVF behind an index.
 *
 * An IndexPreTransform is unwrapped once; the resulting index must be an
 * IndexIVF, otherwise an exception is thrown.
 *
 * @param index  an IndexIVF, or an IndexPreTransform wrapping one
 * @return the IndexIVF (never null)
 */
const IndexIVF * extract_index_ivf (const Index * index)
{
    const auto *pretransform = dynamic_cast<const IndexPreTransform *>(index);
    const Index *candidate = pretransform ? pretransform->index : index;

    const IndexIVF *ivf = dynamic_cast<const IndexIVF *>(candidate);
    FAISS_THROW_IF_NOT (ivf);

    return ivf;
}
77+
78+
/** Non-const flavour of extract_index_ivf: delegates to the const overload
 * and strips the constness from its result.
 */
IndexIVF * extract_index_ivf (Index * index) {
    const Index *as_const = index;
    const IndexIVF *found = extract_index_ivf (as_const);
    return const_cast<IndexIVF*> (found);
}
81+
82+
/** Merge the IVF content of index1 into index0.
 *
 * Both indexes are first validated with check_compatible_for_merge.
 *
 * @param index0     destination index
 * @param index1     source index
 * @param shift_ids  when true, the destination's current ntotal is passed
 *                   to merge_from as the id offset; otherwise 0 is passed
 */
void merge_into(faiss::Index *index0, faiss::Index *index1, bool shift_ids) {

    check_compatible_for_merge (index0, index1);

    IndexIVF *dst_ivf = extract_index_ivf (index0);
    IndexIVF *src_ivf = extract_index_ivf (index1);

    if (shift_ids) {
        dst_ivf->merge_from (*src_ivf, dst_ivf->ntotal);
    } else {
        dst_ivf->merge_from (*src_ivf, 0);
    }

    // the outer index may be an IndexPreTransform wrapper: refresh its count
    index0->ntotal = dst_ivf->ntotal;
    index1->ntotal = src_ivf->ntotal;
}
94+
95+
96+
97+
/** Assign each query vector to a centroid of the index's coarse quantizer.
 *
 * If the index is an IndexPreTransform, the transform chain is applied to
 * the queries first and the wrapped index is used.
 *
 * @param index         an IndexIVF, possibly wrapped in an IndexPreTransform
 * @param x             query vectors
 * @param n             number of query vectors
 * @param centroid_ids  output, filled by the quantizer's assign()
 *
 * Throws if the (unwrapped) index is not an IndexIVF.
 */
void search_centroid(faiss::Index *index,
                     const float* x, int n,
                     idx_t* centroid_ids)
{
    std::unique_ptr<float[]> del;
    if (auto index_pre = dynamic_cast<faiss::IndexPreTransform*>(index)) {
        const float *xt = index_pre->apply_chain(n, x);
        // Only take ownership of a buffer that differs from the input
        // (same guard as search_with_parameters): otherwise the caller's
        // array would be delete[]d when `del` goes out of scope.
        if (xt != x) {
            del.reset(const_cast<float*>(xt));
        }
        x = xt;
        index = index_pre->index;
    }
    faiss::IndexIVF* index_ivf = dynamic_cast<faiss::IndexIVF*>(index);
    // a real check rather than assert(): must also hold in NDEBUG builds
    FAISS_THROW_IF_NOT (index_ivf);
    index_ivf->quantizer->assign(n, x, centroid_ids);
}
111+
112+
113+
114+
/** Search an IVF index and also report centroid information.
 *
 * If the index is an IndexPreTransform, the transform chain is applied to
 * the queries first and the wrapped index is used. The coarse quantizer is
 * searched with the index's nprobe, then search_preassigned is run with
 * its pair-storing flag set to true, so that each returned label encodes
 * (list number << 32 | offset in list). These are decoded here and
 * replaced with the ids read from the inverted lists.
 *
 * @param index                an IndexIVF, possibly wrapped in an
 *                             IndexPreTransform
 * @param n                    number of query vectors
 * @param xin                  query vectors
 * @param k                    number of results per query
 * @param distances            output distances, n * k entries
 * @param labels               output ids, n * k entries
 * @param query_centroid_ids   optional (may be null): first centroid
 *                             returned by the quantizer for each query
 * @param result_centroid_ids  optional (may be null): inverted list number
 *                             of each result, -1 for negative labels
 *
 * Throws if the (unwrapped) index is not an IndexIVF.
 */
void search_and_return_centroids(faiss::Index *index,
                                 size_t n,
                                 const float* xin,
                                 long k,
                                 float *distances,
                                 idx_t* labels,
                                 idx_t* query_centroid_ids,
                                 idx_t* result_centroid_ids)
{
    const float *x = xin;
    std::unique_ptr<float []> del;
    if (auto index_pre = dynamic_cast<faiss::IndexPreTransform*>(index)) {
        x = index_pre->apply_chain(n, xin);
        // Only take ownership of a buffer that differs from the input
        // (same guard as search_with_parameters): otherwise the caller's
        // array would be delete[]d when `del` goes out of scope.
        if (x != xin) {
            del.reset(const_cast<float*>(x));
        }
        index = index_pre->index;
    }
    faiss::IndexIVF* index_ivf = dynamic_cast<faiss::IndexIVF*>(index);
    // a real check rather than assert(): must also hold in NDEBUG builds
    FAISS_THROW_IF_NOT (index_ivf);

    size_t nprobe = index_ivf->nprobe;
    std::vector<idx_t> cent_nos (n * nprobe);
    std::vector<float> cent_dis (n * nprobe);
    index_ivf->quantizer->search(
        n, x, nprobe, cent_dis.data(), cent_nos.data());

    if (query_centroid_ids) {
        // first quantizer result of each query
        for (size_t i = 0; i < n; i++)
            query_centroid_ids[i] = cent_nos[i * nprobe];
    }

    index_ivf->search_preassigned (n, x, k,
                                   cent_nos.data(), cent_dis.data(),
                                   distances, labels, true);

    for (size_t i = 0; i < n * k; i++) {
        idx_t label = labels[i];
        if (label < 0) {
            // negative label: no list to report
            if (result_centroid_ids)
                result_centroid_ids[i] = -1;
        } else {
            // split the (list number, offset) pair and look up the real id
            long list_no = label >> 32;
            long list_index = label & 0xffffffff;
            if (result_centroid_ids)
                result_centroid_ids[i] = list_no;
            labels[i] = index_ivf->invlists->get_single_id(list_no, list_index);
        }
    }
}
162+
163+
164+
/** Set up a sliding window over an existing index.
 *
 * @param index  an IndexIVF (possibly wrapped in an IndexPreTransform)
 *               whose inverted lists are an ArrayInvertedLists; throws
 *               otherwise
 */
SlidingIndexWindow::SlidingIndexWindow (Index *index): index (index) {
    n_slice = 0;
    IndexIVF* index_ivf = extract_index_ivf (index);
    ils = dynamic_cast<ArrayInvertedLists *> (index_ivf->invlists);
    // validate the cast before the first dereference of ils
    FAISS_THROW_IF_NOT_MSG (ils,
                            "only supports indexes with ArrayInvertedLists");
    nlist = ils->nlist;
    sizes.resize(nlist);
}
173+
174+
/** Drop the first `remove` elements of dst, then append all of src.
 *
 * @param dst     vector modified in place
 * @param remove  number of leading elements to discard (<= dst.size())
 * @param src     elements appended after the surviving ones
 */
template<class T>
static void shift_and_add (std::vector<T> & dst,
                           size_t remove,
                           const std::vector<T> & src)
{
    dst.erase (dst.begin(), dst.begin() + remove);
    dst.insert (dst.end(), src.begin(), src.end());
}
186+
187+
/** Erase the first `remove` elements of v; no-op when remove is 0.
 *
 * @param v       vector modified in place
 * @param remove  number of leading elements to discard (<= v.size())
 */
template<class T>
static void remove_from_begin (std::vector<T> & v,
                               size_t remove)
{
    if (remove == 0) {
        return;
    }
    const auto first = v.begin();
    v.erase (first, first + remove);
}
194+
195+
void SlidingIndexWindow::step(const Index *sub_index, bool remove_oldest) {
196+
197+
FAISS_THROW_IF_NOT_MSG (!remove_oldest || n_slice > 0,
198+
"cannot remove slice: there is none");
199+
200+
const ArrayInvertedLists *ils2 = nullptr;
201+
if(sub_index) {
202+
check_compatible_for_merge (index, sub_index);
203+
ils2 = dynamic_cast<const ArrayInvertedLists*>(
204+
extract_index_ivf (sub_index)->invlists);
205+
FAISS_THROW_IF_NOT_MSG (ils2, "supports only ArrayInvertedLists");
206+
}
207+
IndexIVF *index_ivf = extract_index_ivf (index);
208+
209+
if (remove_oldest && ils2) {
210+
for (int i = 0; i < nlist; i++) {
211+
std::vector<size_t> & sizesi = sizes[i];
212+
size_t amount_to_remove = sizesi[0];
213+
index_ivf->ntotal += ils2->ids[i].size() - amount_to_remove;
214+
215+
shift_and_add (ils->ids[i], amount_to_remove, ils2->ids[i]);
216+
shift_and_add (ils->codes[i], amount_to_remove * ils->code_size,
217+
ils2->codes[i]);
218+
for (int j = 0; j + 1 < n_slice; j++) {
219+
sizesi[j] = sizesi[j + 1] - amount_to_remove;
220+
}
221+
sizesi[n_slice - 1] = ils->ids[i].size();
222+
}
223+
} else if (ils2) {
224+
for (int i = 0; i < nlist; i++) {
225+
index_ivf->ntotal += ils2->ids[i].size();
226+
shift_and_add (ils->ids[i], 0, ils2->ids[i]);
227+
shift_and_add (ils->codes[i], 0, ils2->codes[i]);
228+
sizes[i].push_back(ils->ids[i].size());
229+
}
230+
n_slice++;
231+
} else if (remove_oldest) {
232+
for (int i = 0; i < nlist; i++) {
233+
size_t amount_to_remove = sizes[i][0];
234+
index_ivf->ntotal -= amount_to_remove;
235+
remove_from_begin (ils->ids[i], amount_to_remove);
236+
remove_from_begin (ils->codes[i],
237+
amount_to_remove * ils->code_size);
238+
for (int j = 0; j + 1 < n_slice; j++) {
239+
sizes[i][j] = sizes[i][j + 1] - amount_to_remove;
240+
}
241+
sizes[i].resize(sizes[i].size() - 1);
242+
}
243+
n_slice--;
244+
} else {
245+
FAISS_THROW_MSG ("nothing to do???");
246+
}
247+
index->ntotal = index_ivf->ntotal;
248+
}
249+
250+
251+
252+
// Get a subset of inverted lists [i0, i1). Works on IndexIVF's and
253+
// IndexIVF's embedded in an IndexPreTransform
254+
255+
/** Copy the inverted lists [i0, i1) of an index into a new
 * ArrayInvertedLists.
 *
 * @param index  an IndexIVF, possibly wrapped in an IndexPreTransform
 * @param i0     first list to copy (inclusive)
 * @param i1     end of the range (exclusive); 0 <= i0 <= i1 <= nlist
 * @return a newly allocated ArrayInvertedLists with i1 - i0 lists, where
 *         list j holds the entries of source list i0 + j; ownership
 *         passes to the caller
 *
 * Throws if the index is not IVF-based or the range is invalid.
 */
ArrayInvertedLists *
get_invlist_range (const Index *index, long i0, long i1)
{
    const IndexIVF *ivf = extract_index_ivf (index);

    FAISS_THROW_IF_NOT (0 <= i0 && i0 <= i1 && i1 <= ivf->nlist);

    const InvertedLists *src = ivf->invlists;

    // owned by a unique_ptr while being filled, so that it is not leaked
    // if reading the source lists or adding entries throws
    std::unique_ptr<ArrayInvertedLists> il (
        new ArrayInvertedLists(i1 - i0, src->code_size));

    for (long i = i0; i < i1; i++) {
        il->add_entries(i - i0, src->list_size(i),
                        InvertedLists::ScopedIds (src, i).get(),
                        InvertedLists::ScopedCodes (src, i).get());
    }
    return il.release();
}
273+
274+
275+
276+
/** Exchange the inverted lists [i0, i1) of an index with those of src.
 *
 * The ids and codes of list i0 + j in the index are swapped with those of
 * list j in src, so after the call src holds the index's previous
 * content. ntotal of both the outer index and the IVF index is updated.
 *
 * @param index  an IndexIVF (possibly wrapped in an IndexPreTransform)
 *               whose inverted lists are an ArrayInvertedLists
 * @param i0     first list to replace (inclusive)
 * @param i1     end of the range (exclusive); 0 <= i0 <= i1 <= nlist
 * @param src    lists to swap in; must have i1 - i0 lists and the same
 *               code size as the destination
 */
void set_invlist_range (Index *index, long i0, long i1,
                        ArrayInvertedLists * src)
{
    IndexIVF *ivf = extract_index_ivf (index);

    FAISS_THROW_IF_NOT (0 <= i0 && i0 <= i1 && i1 <= ivf->nlist);

    ArrayInvertedLists *dst = dynamic_cast<ArrayInvertedLists *>(ivf->invlists);
    FAISS_THROW_IF_NOT_MSG (dst, "only ArrayInvertedLists supported");
    FAISS_THROW_IF_NOT (src->nlist == i1 - i0 &&
                        dst->code_size == src->code_size);

    size_t updated_total = index->ntotal;
    for (long dst_no = i0, src_no = 0; dst_no < i1; dst_no++, src_no++) {
        // account for the size change before the contents are exchanged
        updated_total -= dst->list_size (dst_no);
        updated_total += src->list_size (src_no);
        src->codes[src_no].swap (dst->codes[dst_no]);
        src->ids[src_no].swap (dst->ids[dst_no]);
    }
    index->ntotal = updated_total;
    ivf->ntotal = index->ntotal;
}
297+
298+
299+
void search_with_parameters (const Index *index,
300+
idx_t n, const float *x, idx_t k,
301+
float *distances, idx_t *labels,
302+
IVFSearchParameters *params)
303+
{
304+
FAISS_THROW_IF_NOT (params);
305+
const float *prev_x = x;
306+
ScopeDeleter<float> del;
307+
308+
if (auto ip = dynamic_cast<const IndexPreTransform *> (index)) {
309+
x = ip->apply_chain (n, x);
310+
if (x != prev_x) {
311+
del.set(x);
312+
}
313+
index = ip->index;
314+
}
315+
316+
std::vector<idx_t> Iq(params->nprobe * n);
317+
std::vector<float> Dq(params->nprobe * n);
318+
319+
const IndexIVF *index_ivf = dynamic_cast<const IndexIVF *>(index);
320+
FAISS_THROW_IF_NOT (index_ivf);
321+
322+
index_ivf->quantizer->search(n, x, params->nprobe,
323+
Dq.data(), Iq.data());
324+
325+
index_ivf->search_preassigned(n, x, k, Iq.data(), Dq.data(),
326+
distances, labels,
327+
false, params);
328+
}
329+
330+
331+
332+
} } // namespace faiss::ivflib

0 commit comments

Comments
 (0)