Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions faiss/IndexHNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ void hnsw_search(
ndis += stats.ndis;
nhops += stats.nhops;
res.end();
vt.advance();
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed? I'm asking just for my own understanding.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

VisitedTable vt is reused across queries inside the same parallel region for performance (to avoid reallocating/clearing a large visited array for every query). vt.advance() is what “resets” the visited state between queries in O(1) by bumping an epoch counter (or clearing the hashset variant), so each query starts with an empty visited set.

Without vt.advance(), nodes visited while processing query i would remain marked as visited for query i+1, which can cause the search to skip almost everything and return incorrect / empty results.

This is the same pattern used elsewhere in HNSW search code, e.g. in the main search loop that shares a VisitedTable across queries and calls advance() after each query.

}
}
InterruptCallback::check();
Expand Down Expand Up @@ -1080,6 +1081,70 @@ void IndexHNSWCagra::search(
}
}

/**
 * Range search for IndexHNSWCagra.
 *
 * When base_level_only is false the call is forwarded unchanged to
 * IndexHNSW::range_search. Otherwise each query is searched on the level-0
 * graph only, starting from the closest of
 * num_base_level_search_entrypoints randomly drawn database vectors.
 *
 * @param n       number of query vectors
 * @param x       query vectors, read as n consecutive blocks of d floats
 * @param radius  search radius; for similarity metrics the comparison is
 *                done on negated distances and the sign is restored in the
 *                returned distances
 * @param result  output container, filled through a RangeSearchPartialResult
 * @param params  optional search parameters, passed on to the HNSW search
 *
 * An empty index (ntotal == 0) produces an empty result list for every
 * query. Throws if no entry point could be selected for a query on a
 * non-empty index.
 */
void IndexHNSWCagra::range_search(
        idx_t n,
        const float* x,
        float radius,
        RangeSearchResult* result,
        const SearchParameters* params) const {
    if (!base_level_only) {
        IndexHNSW::range_search(n, x, radius, result, params);
        return;
    }

    const HNSW& hnsw = this->hnsw;
    size_t n1 = 0, n2 = 0, ndis = 0, nhops = 0;
    const float threshold =
            is_similarity_metric(metric_type) ? -radius : radius;
    RangeSearchPartialResult pres(result);

    if (ntotal == 0) {
        // Nothing to search. Sampling from [0, ntotal - 1] would be
        // undefined for an empty index, so only register an empty result
        // per query.
        for (idx_t i = 0; i < n; i++) {
            pres.new_result(i);
        }
    } else {
        // Per-call state shared by all queries: seeding the generator and
        // allocating the visited table are not repeated for every query.
        std::random_device rd;
        std::mt19937 gen(rd());
        std::uniform_int_distribution<idx_t> distrib(0, ntotal - 1);

        std::unique_ptr<DistanceComputer> dis(
                storage_distance_computer(storage));
        VisitedTable vt(ntotal, hnsw.use_visited_hashset);

        for (idx_t i = 0; i < n; i++) {
            dis->set_query(x + i * d);

            // Keep the closest of the randomly sampled candidates as the
            // single entry point of the level-0 search.
            storage_idx_t nearest = -1;
            float nearest_d = std::numeric_limits<float>::max();
            for (idx_t j = 0; j < num_base_level_search_entrypoints; j++) {
                auto idx = distrib(gen);
                auto distance = (*dis)(idx);
                if (distance < nearest_d) {
                    nearest = idx;
                    nearest_d = distance;
                }
            }
            FAISS_THROW_IF_NOT_MSG(
                    nearest >= 0, "Could not find a valid entrypoint.");

            RangeQueryResult& qres = pres.new_result(i);
            RangeResultHandler<HNSW::C> res(&qres, threshold);
            HNSWStats stats;
            hnsw.search_level_0(
                    *dis, res, 1, &nearest, &nearest_d, 1, stats, vt, params);
            n1 += stats.n1;
            n2 += stats.n2;
            ndis += stats.ndis;
            nhops += stats.nhops;

            // Reset the visited marks so the next query starts from an
            // empty visited set while reusing the same table.
            vt.advance();
        }
    }

    pres.set_lims();
    result->do_allocation();
    pres.copy_result();

    hnsw_stats.combine({n1, n2, ndis, nhops});

    // Distances were negated for similarity metrics; flip them back.
    if (is_similarity_metric(metric_type)) {
        for (size_t i = 0; i < result->lims[result->nq]; i++) {
            result->distances[i] = -result->distances[i];
        }
    }
}

/// Returns the value currently held in numeric_type_.
faiss::NumericType IndexHNSWCagra::get_numeric_type() const {
return numeric_type_;
}
Expand Down
7 changes: 7 additions & 0 deletions faiss/IndexHNSW.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,13 @@ struct IndexHNSWCagra : IndexHNSW {
idx_t* labels,
const SearchParameters* params = nullptr) const override;

/// Range search override.
///
/// Forwards to IndexHNSW::range_search unless base_level_only is set. In
/// base-level-only mode each query starts from the closest of
/// num_base_level_search_entrypoints randomly sampled database vectors
/// (the sampler is seeded from std::random_device) and only the level-0
/// graph is searched.
///
/// @param n       number of query vectors
/// @param x       query vectors, n blocks of d floats
/// @param radius  search radius
/// @param result  receives the per-query limits, labels and distances
/// @param params  optional search parameters forwarded to the search
void range_search(
idx_t n,
const float* x,
float radius,
RangeSearchResult* result,
const SearchParameters* params = nullptr) const override;

faiss::NumericType get_numeric_type() const;
void set_numeric_type(faiss::NumericType numeric_type);
NumericType numeric_type_;
Expand Down
21 changes: 21 additions & 0 deletions faiss/gpu/test/test_cagra.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,27 @@ def test_interop_L2_Int8(self):
def test_interop_IP_Int8(self):
# Delegates to do_interop with the inner-product metric and Int8 numeric type.
self.do_interop(faiss.METRIC_INNER_PRODUCT, faiss.Int8)

def test_base_level_only_range_search(self):
    """Range search on a base-level-only CPU copy of a GPU CAGRA index.

    Builds a GpuIndexCagra on synthetic data, converts it to a CPU index,
    switches it to base-level-only mode with 8 random entry points and
    checks that a very large radius yields at least one hit per query.
    """
    dim = 32
    n_base = 1000
    n_query = 10
    synthetic = datasets.SyntheticDataset(dim, 0, n_base, n_query)
    base_vectors = synthetic.get_database()
    query_vectors = synthetic.get_queries()

    gpu_resources = faiss.StandardGpuResources()
    gpu_index = faiss.GpuIndexCagra(gpu_resources, dim, faiss.METRIC_L2)
    gpu_index.train(base_vectors, numeric_type=faiss.Float32)

    cpu_index = faiss.index_gpu_to_cpu(gpu_index)
    cpu_index.base_level_only = True
    cpu_index.num_base_level_search_entrypoints = 8

    # lims holds cumulative offsets; consecutive differences are the
    # number of results returned for each query.
    lims, _, _ = cpu_index.range_search(query_vectors, np.float32(1e9))
    hits_per_query = np.diff(lims)
    self.assertTrue(np.all(hits_per_query > 0))


@unittest.skipIf(
"CUVS" not in faiss.get_compile_options(),
Expand Down
1 change: 1 addition & 0 deletions faiss/python/swigfaiss.swig
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,7 @@ void gpu_sync_all_devices()
DOWNCAST ( IndexHNSWFlat )
DOWNCAST ( IndexHNSWPQ )
DOWNCAST ( IndexHNSWSQ )
DOWNCAST ( IndexHNSWCagra )
DOWNCAST ( IndexHNSW )
DOWNCAST ( IndexHNSW2Level )
DOWNCAST ( IndexNNDescentFlat )
Expand Down
79 changes: 79 additions & 0 deletions tests/test_hnsw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@

#include <gtest/gtest.h>

#include <cmath>
#include <cstddef>
#include <limits>
#include <random>
#include <unordered_set>
#include <vector>

#include <faiss/IndexFlat.h>
#include <faiss/IndexHNSW.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/HNSW.h>
#include <faiss/impl/ResultHandler.h>
#include <faiss/impl/VisitedTable.h>
Expand Down Expand Up @@ -169,6 +172,52 @@ void test_popmin_identical_distances(
ASSERT_EQ(mm_heap.dis, cloned_mm_heap.dis);
}

/**
 * Test helper: rebuild `dst` as a base-level-only copy of `src`.
 *
 * The vectors of `src` (which must use a flat storage) are added to a fresh
 * flat storage owned by `dst`, the HNSW structure of `dst` is re-initialised
 * and the level-0 neighbor lists of `src` are copied over. Upper levels are
 * not copied. On return `dst.base_level_only` is true.
 *
 * All preconditions are checked before `dst` is modified, so a failed check
 * leaves `dst` untouched.
 *
 * @param src  source index; its storage must be a faiss::IndexFlat
 * @param dst  destination index, overwritten; must not be the same object
 *             as `src`
 */
void copy_base_level_only(
        const faiss::IndexHNSWCagra& src,
        faiss::IndexHNSWCagra& dst) {
    // Copying onto itself would free the storage we are about to read.
    FAISS_THROW_IF_NOT_MSG(
            &src != &dst, "src and dst must be different indexes");

    auto src_flat = dynamic_cast<faiss::IndexFlat*>(src.storage);
    FAISS_THROW_IF_NOT(src_flat);

    const auto n = src.ntotal;
    const auto d = src.d;
    const auto graph_degree = src.hnsw.nb_neighbors(0);
    const auto M = graph_degree / 2;
    // 1 / log(M) is the level multiplier below; it is not finite for M <= 1.
    FAISS_THROW_IF_NOT_MSG(M > 1, "level-0 graph degree must be at least 4");

    if (dst.storage && dst.own_fields) {
        delete dst.storage;
    }
    // Create a storage whose metric matches the metric_type copied below.
    if (src.metric_type == faiss::METRIC_L2) {
        dst.storage = new faiss::IndexFlatL2(d);
    } else if (src.metric_type == faiss::METRIC_INNER_PRODUCT) {
        dst.storage = new faiss::IndexFlatIP(d);
    } else {
        dst.storage = new faiss::IndexFlat(d, src.metric_type);
    }
    dst.own_fields = true;
    dst.d = d;
    dst.metric_type = src.metric_type;
    dst.is_trained = true;
    dst.keep_max_size_level0 = true;

    dst.hnsw.reset();
    dst.hnsw.assign_probas.clear();
    dst.hnsw.cum_nneighbor_per_level.clear();
    dst.hnsw.set_default_probas(M, 1.0 / std::log(M));

    dst.hnsw.prepare_level_tab(n, false);

    dst.storage->add(n, src_flat->get_xb());
    dst.ntotal = n;

    const size_t max_links = static_cast<size_t>(graph_degree);
    for (faiss::idx_t i = 0; i < n; i++) {
        size_t src_begin, src_end;
        src.hnsw.neighbor_range(i, 0, &src_begin, &src_end);

        size_t dst_begin, dst_end;
        dst.hnsw.neighbor_range(i, 0, &dst_begin, &dst_end);

        // Never write past the destination's level-0 slot range.
        const size_t dst_capacity = dst_end - dst_begin;
        for (size_t j = 0; j < max_links && j < dst_capacity; j++) {
            dst.hnsw.neighbors[dst_begin + j] =
                    src.hnsw.neighbors[src_begin + j];
        }
    }

    dst.base_level_only = true;
}

TEST(HNSW, Test_popmin) {
std::vector<size_t> sizes = {1, 2, 3, 4, 5, 7, 9, 11, 16, 27, 32, 64, 128};
for (const size_t size : sizes) {
Expand Down Expand Up @@ -218,6 +267,36 @@ TEST(HNSW, Test_IndexHNSW_METRIC_Lp) {
EXPECT_EQ(label, 0); // Label should be 0
}

// Base-level-only range search on a copied IndexHNSWCagra: with a radius far
// larger than any distance in the data set, every query must return at least
// one hit and every returned label must refer to a database vector.
TEST(HNSW, Test_IndexHNSWCagra_BaseLevelOnly_RangeSearch) {
    const int d = 8;
    const int nb = 100;
    const int nq = 5;
    const int M = 4;

    std::vector<float> xb(nb * d);
    std::vector<float> xq(nq * d);
    faiss::float_rand(xb.data(), xb.size(), 1234);
    faiss::float_rand(xq.data(), xq.size(), 4321);

    faiss::IndexHNSWCagra index(d, M, faiss::METRIC_L2);
    index.add(nb, xb.data());
    index.base_level_only = true;
    index.num_base_level_search_entrypoints = 8;

    faiss::IndexHNSWCagra dst_index;
    copy_base_level_only(index, dst_index);
    dst_index.num_base_level_search_entrypoints = 8;

    faiss::RangeSearchResult res(nq);
    const float radius = 1e9f;
    dst_index.range_search(nq, xq.data(), radius, &res);

    for (int i = 0; i < nq; i++) {
        const size_t begin = res.lims[i];
        const size_t end = res.lims[i + 1];
        ASSERT_LE(begin, end);
        // Unsigned literal avoids a signed/unsigned comparison with size_t.
        EXPECT_GT(end - begin, 0u);
        for (size_t j = begin; j < end; j++) {
            EXPECT_GE(res.labels[j], 0);
            EXPECT_LT(res.labels[j], nb);
        }
    }
}

class HNSWTest : public testing::Test {
protected:
HNSWTest() {
Expand Down