diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp index 017dda615f..50141443a3 100644 --- a/faiss/IndexHNSW.cpp +++ b/faiss/IndexHNSW.cpp @@ -284,6 +284,7 @@ void hnsw_search( ndis += stats.ndis; nhops += stats.nhops; res.end(); + vt.advance(); } } InterruptCallback::check(); @@ -1080,6 +1081,70 @@ void IndexHNSWCagra::search( } } +void IndexHNSWCagra::range_search( + idx_t n, + const float* x, + float radius, + RangeSearchResult* result, + const SearchParameters* params) const { + if (!base_level_only) { + IndexHNSW::range_search(n, x, radius, result, params); + return; + } + + const HNSW& hnsw = this->hnsw; + size_t n1 = 0, n2 = 0, ndis = 0, nhops = 0; + float threshold = is_similarity_metric(metric_type) ? -radius : radius; + RangeSearchPartialResult pres(result); + + for (idx_t i = 0; i < n; i++) { + std::unique_ptr dis( + storage_distance_computer(storage)); + dis->set_query(x + i * d); + + storage_idx_t nearest = -1; + float nearest_d = std::numeric_limits::max(); + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution distrib(0, ntotal - 1); + + for (idx_t j = 0; j < num_base_level_search_entrypoints; j++) { + auto idx = distrib(gen); + auto distance = (*dis)(idx); + if (distance < nearest_d) { + nearest = idx; + nearest_d = distance; + } + } + FAISS_THROW_IF_NOT_MSG( + nearest >= 0, "Could not find a valid entrypoint."); + + RangeQueryResult& qres = pres.new_result(i); + RangeResultHandler res(&qres, threshold); + VisitedTable vt(ntotal, hnsw.use_visited_hashset); + HNSWStats stats; + hnsw.search_level_0( + *dis, res, 1, &nearest, &nearest_d, 1, stats, vt, params); + n1 += stats.n1; + n2 += stats.n2; + ndis += stats.ndis; + nhops += stats.nhops; + } + + pres.set_lims(); + result->do_allocation(); + pres.copy_result(); + + hnsw_stats.combine({n1, n2, ndis, nhops}); + + if (is_similarity_metric(metric_type)) { + for (size_t i = 0; i < result->lims[result->nq]; i++) { + result->distances[i] = -result->distances[i]; + } + } +} + faiss::NumericType IndexHNSWCagra::get_numeric_type() const { return numeric_type_; } diff --git a/faiss/IndexHNSW.h b/faiss/IndexHNSW.h index a43828d428..b1cbd8bf7e 100644 --- a/faiss/IndexHNSW.h +++ b/faiss/IndexHNSW.h @@ -259,6 +259,13 @@ struct IndexHNSWCagra : IndexHNSW { idx_t* labels, const SearchParameters* params = nullptr) const override; + void range_search( + idx_t n, + const float* x, + float radius, + RangeSearchResult* result, + const SearchParameters* params = nullptr) const override; + faiss::NumericType get_numeric_type() const; void set_numeric_type(faiss::NumericType numeric_type); NumericType numeric_type_; diff --git a/faiss/gpu/test/test_cagra.py b/faiss/gpu/test/test_cagra.py index 972ce5e293..d9e1b27a6c 100644 --- a/faiss/gpu/test/test_cagra.py +++ b/faiss/gpu/test/test_cagra.py @@ -150,6 +150,27 @@ def test_interop_L2_Int8(self): def test_interop_IP_Int8(self): self.do_interop(faiss.METRIC_INNER_PRODUCT, faiss.Int8) + def test_base_level_only_range_search(self): + d = 32 + nb = 1000 + nq = 10 + ds = datasets.SyntheticDataset(d, 0, nb, nq) + data_base = ds.get_database() + data_query = ds.get_queries() + + res = faiss.StandardGpuResources() + index = faiss.GpuIndexCagra(res, d, faiss.METRIC_L2) + index.train(data_base, numeric_type=faiss.Float32) + + cpu_index = faiss.index_gpu_to_cpu(index) + cpu_index.base_level_only = True + cpu_index.num_base_level_search_entrypoints = 8 + + radius = np.float32(1e9) + lims, _, _ = cpu_index.range_search(data_query, radius) + counts = lims[1:] - lims[:-1] + self.assertTrue(np.all(counts > 0)) + @unittest.skipIf( "CUVS" not in faiss.get_compile_options(), diff --git a/faiss/python/swigfaiss.swig b/faiss/python/swigfaiss.swig index 8ff744c056..7deeea8d4b 100644 --- a/faiss/python/swigfaiss.swig +++ b/faiss/python/swigfaiss.swig @@ -827,6 +827,7 @@ void gpu_sync_all_devices() DOWNCAST ( IndexHNSWFlat ) DOWNCAST ( IndexHNSWPQ ) DOWNCAST ( IndexHNSWSQ ) + DOWNCAST ( IndexHNSWCagra ) DOWNCAST ( IndexHNSW ) DOWNCAST ( IndexHNSW2Level ) DOWNCAST ( IndexNNDescentFlat ) diff --git a/tests/test_hnsw.cpp b/tests/test_hnsw.cpp index 9e2b46fbce..44cb774788 100644 --- a/tests/test_hnsw.cpp +++ b/tests/test_hnsw.cpp @@ -7,13 +7,16 @@ #include +#include #include #include #include #include #include +#include #include +#include #include #include #include @@ -169,6 +172,52 @@ void test_popmin_identical_distances( ASSERT_EQ(mm_heap.dis, cloned_mm_heap.dis); } +void copy_base_level_only( + const faiss::IndexHNSWCagra& src, + faiss::IndexHNSWCagra& dst) { + auto n = src.ntotal; + auto d = src.d; + auto M = src.hnsw.nb_neighbors(0) / 2; + auto graph_degree = src.hnsw.nb_neighbors(0); + + if (dst.storage && dst.own_fields) { + delete dst.storage; + } + dst.storage = new faiss::IndexFlatL2(d); + dst.own_fields = true; + dst.d = d; + dst.metric_type = src.metric_type; + dst.is_trained = true; + dst.keep_max_size_level0 = true; + + dst.hnsw.reset(); + dst.hnsw.assign_probas.clear(); + dst.hnsw.cum_nneighbor_per_level.clear(); + dst.hnsw.set_default_probas(M, 1.0 / std::log(M)); + + dst.hnsw.prepare_level_tab(n, false); + + auto src_flat = dynamic_cast(src.storage); + FAISS_THROW_IF_NOT(src_flat); + dst.storage->add(n, src_flat->get_xb()); + dst.ntotal = n; + + for (faiss::idx_t i = 0; i < n; i++) { + size_t src_begin, src_end; + src.hnsw.neighbor_range(i, 0, &src_begin, &src_end); + + size_t dst_begin, dst_end; + dst.hnsw.neighbor_range(i, 0, &dst_begin, &dst_end); + + for (size_t j = 0; j < graph_degree && j < (dst_end - dst_begin); j++) { + dst.hnsw.neighbors[dst_begin + j] = + src.hnsw.neighbors[src_begin + j]; + } + } + + dst.base_level_only = true; +} + TEST(HNSW, Test_popmin) { std::vector sizes = {1, 2, 3, 4, 5, 7, 9, 11, 16, 27, 32, 64, 128}; for (const size_t size : sizes) { @@ -218,6 +267,36 @@ TEST(HNSW, Test_IndexHNSW_METRIC_Lp) { EXPECT_EQ(label, 0); // Label should be 0 } +TEST(HNSW, Test_IndexHNSWCagra_BaseLevelOnly_RangeSearch) { + int d = 8; + int nb = 100; + int nq = 5; + int M = 4; + + std::vector xb(nb * d); + std::vector xq(nq * d); + faiss::float_rand(xb.data(), xb.size(), 1234); + faiss::float_rand(xq.data(), xq.size(), 4321); + + faiss::IndexHNSWCagra index(d, M, faiss::METRIC_L2); + index.add(nb, xb.data()); + index.base_level_only = true; + index.num_base_level_search_entrypoints = 8; + + faiss::IndexHNSWCagra dst_index; + copy_base_level_only(index, dst_index); + dst_index.num_base_level_search_entrypoints = 8; + + faiss::RangeSearchResult res(nq); + float radius = 1e9f; + dst_index.range_search(nq, xq.data(), radius, &res); + + for (int i = 0; i < nq; i++) { + auto count = res.lims[i + 1] - res.lims[i]; + EXPECT_GT(count, 0); + } +} + class HNSWTest : public testing::Test { protected: HNSWTest() {