forked from facebookresearch/faiss
-
Notifications
You must be signed in to change notification settings - Fork 0
/
AuxIndexStructures.h
286 lines (206 loc) · 8 KB
/
AuxIndexStructures.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
// Auxiliary index structures, that are used in indexes but that can
// be forward-declared
#ifndef FAISS_AUX_INDEX_STRUCTURES_H
#define FAISS_AUX_INDEX_STRUCTURES_H
#include <stdint.h>
#include <vector>
#include <unordered_set>
#include <memory>
#include <mutex>
#include "Index.h"
namespace faiss {
/** The objective is to have a simple result structure while
 * minimizing the number of mem copies in the result. The virtual
 * method do_allocation can be overridden to allocate the result tables
 * in the matrix type of a scripting language like Lua or Python.
 *
 * NOTE(review): this struct declares a destructor but keeps the
 * implicit copy operations. If ~RangeSearchResult frees lims / labels /
 * distances (defined out of line -- confirm), copying an instance would
 * lead to a double free; consider deleting copy construction/assignment.
 */
struct RangeSearchResult {
    using idx_t = Index::idx_t;

    size_t nq;          ///< nb of queries
    size_t *lims;       ///< size (nq + 1)
    idx_t *labels;      ///< result for query i is labels[lims[i]:lims[i+1]]
    float *distances;   ///< corresponding distances (not sorted)
    size_t buffer_size; ///< size of the result buffers used

    /// lims must be allocated on input to range_search.
    /// NOTE(review): nq is taken as idx_t here but stored in the size_t
    /// member above; callers should pass a non-negative count.
    explicit RangeSearchResult (idx_t nq, bool alloc_lims=true);

    /// called when lims contains the nb of result entries for each
    /// query; allocates the result tables (labels, distances).
    /// Virtual so that wrappers can allocate them in their own matrix type.
    virtual void do_allocation ();

    virtual ~RangeSearchResult ();
};
/** Abstract selector describing a set of ids to remove.
 *
 * Concrete subclasses answer membership queries through is_member().
 */
struct IDSelector {
    using idx_t = Index::idx_t;

    /// @return true iff `id` belongs to the selected set
    virtual bool is_member(idx_t id) const = 0;

    /// polymorphic base: deleting through an IDSelector* is well-defined
    virtual ~IDSelector() = default;
};
/** Remove ids in the half-open range [imin, imax). */
struct IDSelectorRange: IDSelector {
    idx_t imin; ///< first selected id (inclusive)
    idx_t imax; ///< end of the selected range (exclusive)

    IDSelectorRange (idx_t imin, idx_t imax);

    /// true iff id lies in [imin, imax) (defined out of line)
    bool is_member(idx_t id) const override;

    ~IDSelectorRange() override {}
};
/** Remove ids from a set. Repetitions of ids in the indices set
 * passed to the constructor do not hurt performance. The hash
 * function used for the bloom filter and GCC's implementation of
 * unordered_set are just the least significant bits of the id. This
 * works fine for random ids or ids in sequences but will produce many
 * hash collisions if the LSBs are always the same. */
struct IDSelectorBatch: IDSelector {
    /// selected ids (presumably filled from `indices` by the constructor)
    std::unordered_set<idx_t> set;

    /// member alias for unsigned char; shadows ::uint8_t inside this struct
    using uint8_t = unsigned char;

    /// bloom filter; assumes low bits of id are a good hash value
    std::vector<uint8_t> bloom;

    /// presumably the number of low id bits used to index `bloom` and the
    /// matching bit mask -- both set by the out-of-line constructor, confirm
    int nbits;
    idx_t mask;

    IDSelectorBatch (size_t n, const idx_t *indices);
    bool is_member(idx_t id) const override;
    ~IDSelectorBatch() override {}
};
/****************************************************************
* Result structures for range search.
*
* The main constraint here is that we want to support parallel
* queries from different threads in various ways: 1 thread per query,
* several threads per query. We store the actual results in blocks of
* fixed size rather than exponentially increasing memory. At the end,
* we copy the block content to a linear result array.
*****************************************************************/
/** List of temporary buffers used to store results before they are
 * copied to the RangeSearchResult object.
 *
 * NOTE(review): ~BufferList is user-declared but the copy operations are
 * implicit. If the destructor frees the Buffer arrays (defined out of
 * line -- confirm), copying a BufferList would double-free; consider
 * deleting copy construction/assignment.
 */
struct BufferList {
    using idx_t = Index::idx_t;

    /// size of each buffer, in number of entries
    size_t buffer_size;

    /// one block of results: ids and their matching distances
    struct Buffer {
        idx_t *ids;
        float *dis;
    };

    std::vector<Buffer> buffers;
    size_t wp; ///< write pointer in the last buffer.

    explicit BufferList (size_t buffer_size);
    ~BufferList ();

    /// create a new buffer
    void append_buffer ();

    /// add one result, possibly appending a new buffer if needed
    void add (idx_t id, float dis);

    /// copy elements ofs .. ofs+n-1, viewing the buffers as one linear
    /// array, into the tables dest_ids and dest_dis
    void copy_range (size_t ofs, size_t n,
                     idx_t * dest_ids, float *dest_dis);
};
struct RangeSearchPartialResult;

/// result structure for a single query
struct RangeQueryResult {
    using idx_t = Index::idx_t;
    idx_t qno;    ///< id of the query
    size_t nres;  ///< nb of results for this query
    /// presumably the partial result that receives this query's entries
    RangeSearchPartialResult * pres;

    /// called by search function to report a new result.
    /// NOTE: the argument order is (dis, id), the reverse of
    /// BufferList::add(id, dis).
    void add (float dis, idx_t id);
};
/// the entries in the buffers are split per query
struct RangeSearchPartialResult: BufferList {
    /// result object this partial result belongs to (set from res_in)
    RangeSearchResult * res;

    /// eventually the result will be stored in res_in
    explicit RangeSearchPartialResult (RangeSearchResult * res_in);

    /// query ids + nb of results per query.
    std::vector<RangeQueryResult> queries;

    /// begin a new result for query qno
    RangeQueryResult & new_result (idx_t qno);

    /*****************************************
     * functions used at the end of the search to merge the result
     * lists */
    void finalize ();

    /// called by range_search before do_allocation
    void set_lims ();

    /// called by range_search after do_allocation
    void copy_result (bool incremental = false);

    /// merge a set of PartialResult's into one RangeSearchResult.
    /// On output the partial results are empty!
    /// NOTE(review): do_delete presumably makes merge delete the partial
    /// result objects after merging -- confirm in the implementation.
    static void merge (std::vector <RangeSearchPartialResult *> &
                       partial_results, bool do_delete=true);
};
/***********************************************************
* Abstract I/O objects
***********************************************************/
/** Abstract source of bytes.
 *
 * Concrete readers implement the call operator, which mirrors fread.
 */
struct IOReader {
    /// label for this source, meant for error messages
    std::string name;

    /** fread-like read of `nitems` items of `size` bytes each into `ptr`.
     *  Presumably returns the number of items actually read, as fread
     *  does -- confirm against the implementations. */
    virtual size_t operator()(void *ptr, size_t size, size_t nitems) = 0;

    /// a file number that can be memory-mapped (base version defined out of line)
    virtual int fileno();

    virtual ~IOReader() = default;
};
/** Abstract sink of bytes, the writing counterpart of IOReader.
 *
 * Concrete writers implement the call operator, which mirrors fwrite.
 */
struct IOWriter {
    /// label for this sink, meant for error messages
    std::string name;

    /** fwrite-like write of `nitems` items of `size` bytes each from `ptr`.
     *  Presumably returns the number of items actually written, as fwrite
     *  does -- confirm against the implementations. */
    virtual size_t operator()(const void *ptr, size_t size, size_t nitems) = 0;

    /// a file number that can be memory-mapped (base version defined out of line)
    virtual int fileno();

    virtual ~IOWriter() = default;
};
/// In-memory IOReader; the bytes to read are presumably taken from `data`.
struct VectorIOReader : IOReader {
    std::vector<uint8_t> data; ///< byte storage
    size_t rp{0};              ///< read position into `data` (starts at 0)

    /// fread-like read (defined out of line)
    size_t operator()(void *ptr, size_t size, size_t nitems) override;
};
/// In-memory IOWriter; written bytes presumably accumulate in `data`.
struct VectorIOWriter : IOWriter {
    std::vector<uint8_t> data; ///< byte storage

    /// fwrite-like write (defined out of line)
    size_t operator()(const void *ptr, size_t size, size_t nitems) override;
};
/***********************************************************
 * The distance computer maintains a current query and computes
 * distances to elements in an index that supports random access.
 *
 * The DistanceComputer is not intended to be thread-safe (e.g. because
 * it maintains counters), so the distance functions are not const;
 * instantiate one per thread if needed.
 ***********************************************************/
struct DistanceComputer {
    using idx_t = Index::idx_t;

    /// called before computing distances; sets the current query.
    /// NOTE(review): implementations may keep the pointer rather than copy
    /// the vector, so x should presumably stay valid until the next
    /// set_query -- confirm per subclass.
    virtual void set_query(const float *x) = 0;

    /// compute distance of stored vector i to the current query
    virtual float operator () (idx_t i) = 0;

    /// compute distance between two stored vectors i and j
    virtual float symmetric_dis (idx_t i, idx_t j) = 0;

    virtual ~DistanceComputer() {}
};
/***********************************************************
* Interrupt callback
***********************************************************/
struct InterruptCallback {
    /// return true to request that the ongoing computation be interrupted
    virtual bool want_interrupt () = 0;
    virtual ~InterruptCallback() {}

    /// lock that protects concurrent calls to is_interrupted
    static std::mutex lock;

    /// the installed callback; empty when no callback is set
    static std::unique_ptr<InterruptCallback> instance;

    /// presumably resets `instance` (defined out of line) -- confirm
    static void clear_instance ();

    /** check if:
     *  - an interrupt callback is set
     *  - the callback returns true
     * if this is the case, then throw an exception. Should not be called
     * from multiple threads.
     */
    static void check ();

    /// same as check() but return true if interrupted instead of
    /// throwing. Can be called from multiple threads.
    static bool is_interrupted ();

    /** assuming each iteration takes a certain number of flops, what
     * is a reasonable interval to check for interrupts?
     */
    static size_t get_period_hint (size_t flops);
};
}; // namespace faiss
#endif