Commit f9fb20a

Pass kwargs to BM25
Pass `find_similar` kwargs to `bm25` function
2 parents efdefaf + 041bab4 commit f9fb20a

6 files changed

Lines changed: 34 additions & 9 deletions

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -7,3 +7,6 @@
 
 .DS_Store
 .vscode/
+
+# Scratch files
+_*.jl

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
@@ -10,6 +10,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Fixed
 
+## [0.2.1]
+
+### Fixed
+- Fixed `find_closest` to pass kwargs to `bm25` to allow for normalization of scores
+- Fixed a bug in `ChunkEmbeddingsIndex` where users couldn't create a bitpacked index with `embeddings` of type `BitMatrix` (to use `finder=BitPackedCosineSimilarity()`)
+
 ## [0.2.0]
 
 ### Added

Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 name = "RAGTools"
 uuid = "16ddad29-bbe8-45a7-857d-3d9514eb0023"
 authors = ["J S <49557684+svilupp@users.noreply.github.com> and contributors"]
-version = "0.2.0"
+version = "0.2.1"
 
 [deps]
 AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"

src/bm25.jl

Lines changed: 3 additions & 3 deletions
@@ -103,7 +103,7 @@ end
     bm25(
         dtm::AbstractDocumentTermMatrix, query::AbstractVector{<:AbstractString};
         k1::Float32 = 1.2f0, b::Float32 = 0.75f0, normalize::Bool = false, normalize_max_tf::Real = 3,
-        normalize_min_doc_rel_length::Float32 = 1.0f0)
+        normalize_min_doc_rel_length::Float32 = 1.0f0, kwargs...)
 
 Scores all documents in `dtm` based on the `query`.
 
@@ -120,6 +120,7 @@ Theoretically, if you choose `normalize_max_tf` and `normalize_min_doc_rel_lengt
 - `normalize_min_doc_rel_length`: The minimum document relative length to normalize to. 0.5 is a good default.
 Ideally, pick the minimum document relative length of the corpus that is non-zero
 `min_doc_rel_length = minimum(x for x in doc_rel_length(chunkdata(key_index)) if x > 0) |> Float32`
+
 # Example
 ```
 documents = [["this", "is", "a", "test"], ["this", "is", "another", "test"], ["foo", "bar", "baz"]]
@@ -132,7 +133,6 @@ scores = bm25(dtm, query)
 Normalization is done by dividing the score by the maximum possible score (given some assumptions).
 It's useful to be get results in the same range as cosine similarity scores and when comparing different queries or documents.
 
-# Example
 ```
 documents = [["this", "is", "a", "test"], ["this", "is", "another", "test"], ["foo", "bar", "baz"]]
 dtm = document_term_matrix(documents)
@@ -149,7 +149,7 @@ scores_norm = bm25(dtm, query; normalize = true, normalize_max_tf, normalize_min
 function bm25(
         dtm::AbstractDocumentTermMatrix, query::AbstractVector{<:AbstractString};
         k1::Float32 = 1.2f0, b::Float32 = 0.75f0, normalize::Bool = false, normalize_max_tf::Real = 3,
-        normalize_min_doc_rel_length::Float32 = 0.5f0)
+        normalize_min_doc_rel_length::Float32 = 0.5f0, kwargs...)
     @assert normalize_max_tf>0 "normalize_max_tf term frequency must be positive (got $normalize_max_tf)"
     @assert normalize_min_doc_rel_length>0 "normalize_min_doc_rel_length must be positive (got $normalize_min_doc_rel_length)"
 
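
The trailing `kwargs...` added to `bm25` matters because `find_closest` now forwards its own keyword arguments, which may include names `bm25` does not declare. A minimal sketch of that behaviour follows; `toy_score`, `strict_score`, and the `verbose` keyword are hypothetical stand-ins for illustration, not RAGTools functions.

```julia
# Hypothetical stand-in for a scorer; `toy_score` is not part of RAGTools.
# The trailing `kwargs...` collects and ignores keywords the function does not name.
function toy_score(tokens::Vector{String}; normalize::Bool = false, kwargs...)
    raw = Float32(length(tokens))
    return normalize ? raw / 10.0f0 : raw
end

# The same signature without the catch-all slurp.
function strict_score(tokens::Vector{String}; normalize::Bool = false)
    return Float32(length(tokens))
end

tokens = ["this", "is", "a", "test"]

# An unrelated keyword (the made-up `verbose`) is silently absorbed.
lenient = toy_score(tokens; normalize = true, verbose = true)

# Without `kwargs...`, the same call fails with a MethodError.
strict_failed = try
    strict_score(tokens; normalize = true, verbose = true)
    false
catch err
    err isa MethodError
end
```

The trade-off of a catch-all slurp is that a misspelled keyword is ignored rather than reported.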

src/retrieval.jl

Lines changed: 20 additions & 4 deletions
@@ -65,8 +65,23 @@ Finds the closest chunks to a query embedding by measuring the BM25 similarity b
 
 Reference: [Wikipedia: BM25](https://en.wikipedia.org/wiki/Okapi_BM25).
 Implementation follows: [The Next Generation of Lucene Relevance](https://opensourceconnections.com/blog/2015/10/16/bm25-the-next-generation-of-lucene-relevation/).
-"""
-struct BM25Similarity <: AbstractSimilarityFinder end
+
+Fields mimic the arguments of `bm25`.
+
+# Fields
+- `k1`: The k1 parameter for BM25. Default is 1.2.
+- `b`: The b parameter for BM25. Default is 0.75.
+- `normalize`: Whether to normalize the scores. Default is false.
+- `normalize_max_tf`: The maximum term frequency to normalize to. Default is 3.
+- `normalize_min_doc_rel_length`: The minimum document relative length to normalize to. Default is 1.0.
+"""
+@kwdef struct BM25Similarity <: AbstractSimilarityFinder
+    k1::Float32 = 1.2f0
+    b::Float32 = 0.75f0
+    normalize::Bool = false
+    normalize_max_tf::Real = 3
+    normalize_min_doc_rel_length::Float32 = 1.0f0
+end
 
 """
     MultiFinder <: AbstractSimilarityFinder
@@ -452,7 +467,6 @@ function find_closest(
     return positions[new_positions], scores
 end
 
-function max_bm25_score end
 """
     find_closest(
         finder::BM25Similarity, dtm::AbstractDocumentTermMatrix,
@@ -468,7 +482,9 @@ function find_closest(
         finder::BM25Similarity, dtm::AbstractDocumentTermMatrix,
         query_emb::AbstractVector{<:Real}, query_tokens::AbstractVector{<:AbstractString} = String[];
         top_k::Int = 100, minimum_similarity::AbstractFloat = -1.0, kwargs...)
-    scores = bm25(dtm, query_tokens)
+    ## unroll finder kwargs, but let it be overwritten by kwargs if provided
+    finder_kwargs = [f => getfield(finder, f) for f in fieldnames(BM25Similarity)]
+    scores = bm25(dtm, query_tokens; finder_kwargs..., kwargs...)
     top_k_min = min(top_k, length(scores))
     ## Take the top_k largest because higher is better in BM25
     ## BM25 score are non-negative but unbounded (grows with number of keywords)
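
The change to `find_closest` relies on two Julia behaviours: a vector of `Symbol => value` pairs can be splatted as keyword arguments, and when the same keyword is splatted more than once the rightmost occurrence wins. The sketch below reproduces that pattern in isolation. `ToyFinder`, `toy_bm25`, and `toy_find_closest` are hypothetical stand-ins that only mirror the shape of the diff; they are not the RAGTools definitions.

```julia
# Hypothetical stand-ins that mirror the shape of the change; not RAGTools code.
Base.@kwdef struct ToyFinder
    k1::Float32 = 1.2f0
    normalize::Bool = false
end

# Returns the keywords it ended up with, so the precedence is visible.
toy_bm25(; k1::Float32 = 1.2f0, normalize::Bool = false, kwargs...) = (; k1, normalize)

function toy_find_closest(finder::ToyFinder; kwargs...)
    # Build field-name => field-value pairs from the finder struct.
    finder_kwargs = [f => getfield(finder, f) for f in fieldnames(ToyFinder)]
    # Finder settings come first, so call-site keywords on the right override them.
    return toy_bm25(; finder_kwargs..., kwargs...)
end

finder = ToyFinder(normalize = true)

# Only the finder's settings apply.
from_finder = toy_find_closest(finder)

# Call-site keywords take precedence over the finder's fields.
overridden = toy_find_closest(finder; normalize = false, k1 = 2.0f0)
```

One consequence visible in the diff itself: because every field of `BM25Similarity` is forwarded, `bm25` receives the finder's `normalize_min_doc_rel_length` (default `1.0f0`) rather than falling back to its own `0.5f0` default, unless the caller overrides it.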

src/types.jl

Lines changed: 1 addition & 1 deletion
@@ -110,7 +110,7 @@ Previously, this struct was called `ChunkIndex`.
 """
 @kwdef struct ChunkEmbeddingsIndex{
     T1 <: AbstractString,
-    T2 <: Union{Nothing, Matrix{<:Real}},
+    T2 <: Union{Nothing, AbstractMatrix{<:Real}},
     T3 <: Union{Nothing, AbstractMatrix{<:Bool}},
     T4 <: Union{Nothing, AbstractVector}
 } <: AbstractChunkIndex
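
The widened bound on `T2` is what admits `BitMatrix` embeddings: a `BitMatrix` is an `AbstractMatrix{Bool}` but not a dense `Matrix`. The sketch below checks this with two hypothetical parametric structs, `OldIndex` and `NewIndex`, which stand in for the old and new bounds and are not the RAGTools type.

```julia
# Hypothetical stand-ins for the old and new type bounds; not the RAGTools struct.
struct OldIndex{T <: Union{Nothing, Matrix{<:Real}}}
    embeddings::T
end

struct NewIndex{T <: Union{Nothing, AbstractMatrix{<:Real}}}
    embeddings::T
end

# Broadcasting a comparison over a dense matrix yields a BitMatrix.
bits = Float32[0.3 -0.1; -0.7 0.9] .> 0

fits_old_bound = bits isa Matrix{<:Real}          # BitMatrix is not a dense Matrix
fits_new_bound = bits isa AbstractMatrix{<:Real}  # Bool <: Real, so this holds

# The old bound rejects the BitMatrix at construction time.
old_rejected = try
    OldIndex(bits)
    false
catch
    true
end

# The new bound accepts it.
new_index = NewIndex(bits)
```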
