Skip to content

Commit 714ba8d

Browse files
inabaoclaude
andcommitted
feat(warp): add warp algorithm for multi-vector similarity search
- Add warp algorithm for multi-vector similarity search - Add warp_parameter for configuration - Add serialize/deserialize support - Add test cases for warp index - Optimize compute_maxsin_similarity by moving vec_indices and dists allocation outside the query loop - Implement binary tree merge for parallel heap merging in SearchWithRequest - Add parallel support for RangeSearch using atomic index distribution - Use BatchInsertVector in add_one_doc for batch insertion - Simplify serialization/deserialization using WriteVector/ReadVector for doc_offsets - Fix Train to use total vector count instead of document count - Pre-calculate and reserve vector capacity in Add to avoid frequent resizes - Remove unused variables (num_elements, query_num) - Fix nullptr pointer arithmetic for extra_info Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: jinjiabao.jjb <jinjiabao.jjb@antgroup.com>
1 parent cafb34b commit 714ba8d

21 files changed

Lines changed: 1516 additions & 36 deletions

examples/cpp/110_index_warp.cpp

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
2+
// Copyright 2024-present the vsag project
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
#include <vsag/vsag.h>
17+
18+
#include <iostream>
19+
#include <vector>
20+
21+
int
22+
main(int argc, char** argv) {
23+
vsag::init();
24+
25+
/******************* Prepare Multi-Vector Base Dataset *****************/
26+
// In ColBERT-style retrieval, each document has multiple vectors (one per token)
27+
// We'll create 100 documents, each with variable number of vectors (5-10)
28+
int64_t num_docs = 100;
29+
int64_t dim = 128;
30+
31+
std::vector<int64_t> ids(num_docs);
32+
std::vector<uint32_t> vector_counts(num_docs);
33+
std::vector<float> datas;
34+
std::mt19937 rng(47);
35+
std::uniform_real_distribution<float> distrib_real;
36+
std::uniform_int_distribution<int> vec_count_dist(5, 10);
37+
38+
// Generate document IDs and variable vector counts
39+
uint64_t total_vectors = 0;
40+
for (int64_t i = 0; i < num_docs; ++i) {
41+
ids[i] = i;
42+
vector_counts[i] = vec_count_dist(rng);
43+
total_vectors += vector_counts[i];
44+
}
45+
46+
// Generate all vectors
47+
datas.reserve(total_vectors * dim);
48+
for (uint64_t i = 0; i < total_vectors * dim; ++i) {
49+
datas.push_back(distrib_real(rng));
50+
}
51+
52+
// Create dataset with VectorCounts for multi-vector support
53+
auto base = vsag::Dataset::Make();
54+
base->NumElements(num_docs)
55+
->Dim(dim)
56+
->Ids(ids.data())
57+
->Float32Vectors(datas.data())
58+
->VectorCounts(vector_counts.data()) // Specify number of vectors per document
59+
->Owner(false);
60+
61+
std::cout << "Created multi-vector dataset with " << num_docs << " documents" << std::endl;
62+
std::cout << "Total vectors: " << total_vectors << " (avg " << (double)total_vectors / num_docs
63+
<< " vectors per doc)" << std::endl;
64+
65+
/******************* Create WARP Index *****************/
66+
// WARP index for ColBERT-style maxsin similarity
67+
std::string warp_build_parameters = R"(
68+
{
69+
"dtype": "float32",
70+
"metric_type": "ip",
71+
"dim": 128
72+
}
73+
)";
74+
auto index = vsag::Factory::CreateIndex("warp", warp_build_parameters).value();
75+
76+
/******************* Build WARP Index *****************/
77+
if (auto build_result = index->Build(base); build_result.has_value()) {
78+
std::cout << "After Build(), Index WARP contains: " << index->GetNumElements()
79+
<< " documents" << std::endl;
80+
} else {
81+
std::cerr << "Failed to build index: " << build_result.error().message << std::endl;
82+
exit(-1);
83+
}
84+
85+
/******************* Prepare Multi-Vector Query *****************/
86+
// Query also has multiple vectors (representing query tokens)
87+
uint32_t query_vec_count = 3;
88+
std::vector<float> query_vectors(query_vec_count * dim);
89+
for (uint32_t i = 0; i < query_vec_count * dim; ++i) {
90+
query_vectors[i] = distrib_real(rng);
91+
}
92+
93+
auto query = vsag::Dataset::Make();
94+
query->NumElements(1)
95+
->Dim(dim)
96+
->Float32Vectors(query_vectors.data())
97+
->VectorCounts(&query_vec_count) // Specify query has multiple vectors
98+
->Owner(false);
99+
100+
std::cout << "Query has " << query_vec_count << " vectors" << std::endl;
101+
102+
/******************* KnnSearch For WARP Index *****************/
103+
// WARP performs maxsin similarity: for each query vector, find max similarity
104+
// with any document vector, sum across query vectors
105+
auto warp_search_parameters = R"({})";
106+
int64_t topk = 5;
107+
auto result = index->KnnSearch(query, topk, warp_search_parameters).value();
108+
109+
/******************* Print Search Result *****************/
110+
std::cout << "Top-" << topk << " results (maxsin similarity): " << std::endl;
111+
for (int64_t i = 0; i < result->GetDim(); ++i) {
112+
std::cout << " Document " << result->GetIds()[i]
113+
<< ": score = " << result->GetDistances()[i] << std::endl;
114+
}
115+
116+
return 0;
117+
}

examples/cpp/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ target_link_libraries (108_index_gno_imi vsag)
4141
add_executable (109_index_sindi 109_index_sindi.cpp)
4242
target_link_libraries (109_index_sindi vsag)
4343

44+
add_executable (110_index_warp 110_index_warp.cpp)
45+
target_link_libraries (110_index_warp vsag)
46+
4447
add_executable (201_custom_allocator 201_custom_allocator.cpp)
4548
target_link_libraries (201_custom_allocator vsag)
4649

include/vsag/constants.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ extern const char* const INDEX_SPARSE;
2525
extern const char* const INDEX_SINDI;
2626
extern const char* const INDEX_BRUTE_FORCE;
2727
extern const char* const INDEX_IVF;
28+
extern const char* const INDEX_WARP;
2829
extern const char* const DIM;
2930
extern const char* const NUM_ELEMENTS;
3031
extern const char* const IDS;
@@ -36,6 +37,7 @@ extern const char* const ATTRIBUTE_SETS;
3637
extern const char* const DATASET_PATHS;
3738
extern const char* const EXTRA_INFOS;
3839
extern const char* const EXTRA_INFO_SIZE;
40+
extern const char* const VECTOR_COUNTS;
3941

4042
extern const char* const HNSW_DATA;
4143
extern const char* const CONJUGATE_GRAPH_DATA;

include/vsag/dataset.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,23 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
354354
*/
355355
virtual std::vector<std::string>
356356
GetStatistics(const std::vector<std::string>& stat_keys) const = 0;
357+
358+
/**
359+
* @brief Sets the vector counts array for the dataset (for multi-vector documents).
360+
*
361+
* @param counts Pointer to the array of vector counts per document.
362+
* @return DatasetPtr A shared pointer to the dataset with updated vector counts.
363+
*/
364+
virtual DatasetPtr
365+
VectorCounts(const uint32_t* counts) = 0;
366+
367+
/**
368+
* @brief Retrieves the vector counts array of the dataset.
369+
*
370+
* @return const uint32_t* Pointer to the array of vector counts per document.
371+
*/
372+
virtual const uint32_t*
373+
GetVectorCounts() const = 0;
357374
};
358375

359376
}; // namespace vsag

include/vsag/index.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ struct MergeUnit {
4343
IdMapFunction id_map_func = nullptr;
4444
};
4545

46-
enum class IndexType { HNSW, DISKANN, HGRAPH, IVF, PYRAMID, BRUTEFORCE, SPARSE, SINDI };
46+
enum class IndexType { HNSW, DISKANN, HGRAPH, IVF, PYRAMID, BRUTEFORCE, SPARSE, SINDI, WARP };
4747

4848
#define DATA_FLAG_FLOAT32_VECTOR 0x01
4949
#define DATA_FLAG_INT8_VECTOR 0x02

0 commit comments

Comments
 (0)