Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions datasketch/lsh.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,16 +463,24 @@ def collect_query_buffer(self) -> list[Hashable]:
list: a list of unique keys.

"""
collected_result_sets = [
set(collected_result_lists)
for hashtable in self.hashtables
for collected_result_lists in hashtable.collect_select_buffer()
collected_result_lists = [hashtable.collect_select_buffer() for hashtable in self.hashtables]
if not any(collected_result_lists):
return []

# Each buffered query contributes one result list per hashtable. We first
# union candidates across bands for each query, then intersect across the
# buffered queries to match repeated calls to `query()`.
per_query_result_sets = [
set().union(*query_result_lists)
for query_result_lists in zip(*collected_result_lists)
]
if not collected_result_sets:
if not per_query_result_sets:
return []

candidates = set.intersection(*per_query_result_sets)
if self.prepickle:
return [pickle.loads(key) for key in set.intersection(*collected_result_sets)]
return list(set.intersection(*collected_result_sets))
return [pickle.loads(key) for key in candidates]
return list(candidates)

def __contains__(self, key: Hashable) -> bool:
"""Args:
Expand Down
2 changes: 1 addition & 1 deletion datasketch/lshensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def index(self, entries: Iterable[tuple[Hashable, MinHash, int]]) -> None:
if not self.is_empty():
raise ValueError("Cannot call index again on a non-empty index")
if not isinstance(entries, list):
queue = deque([])
queue = deque()
for key, minhash, size in entries:
if size <= 0:
raise ValueError("Set size must be positive")
Expand Down
10 changes: 6 additions & 4 deletions datasketch/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,12 +603,14 @@ def collect_select_buffer(self):
del self._select_statements_and_parameters_with_decoders[:]
statements_and_parameters, decoders = zip(*buffer)

ret = collections.defaultdict(list)
query_results = self._select(statements_and_parameters)
for rows, (key_decoder, val_decoder) in zip(query_results, decoders):
ret = []
for rows, (_key_decoder, val_decoder) in zip(query_results, decoders):
values = []
for row in rows:
ret[key_decoder(row.key)].append((val_decoder(row.value), row.ts))
return [[x[0] for x in sorted(v, key=operator.itemgetter(1))] for v in ret.values()]
values.append((val_decoder(row.value), row.ts))
ret.append([x[0] for x in sorted(values, key=operator.itemgetter(1))])
return ret

def select(self, keys):
"""Select all values for the given keys.
Expand Down
4 changes: 2 additions & 2 deletions docs/lshforest.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ MinHash LSH Forest

:ref:`minhash_lsh` is useful for radius (or threshold) queries. However,
**top-k** queries are more useful in many cases. `LSH
Forest <http://ilpubs.stanford.edu:8090/678/1/2005-14.pdf>`__ by Bawa et
Forest <https://dblp.org/rec/conf/www/BawaCG05>`__ by Bawa et
al. is a general LSH data structure that makes top-k queries possible for
many different types of LSH indexes, which include MinHash LSH. I
implemented the MinHash LSH Forest, which takes a :ref:`minhash` data sketch of
Expand Down Expand Up @@ -70,7 +70,7 @@ for details.
:alt: MinHashLSHForest Benchmark

(Optional) If you have read the LSH Forest
`paper <http://ilpubs.stanford.edu:8090/678/1/2005-14.pdf>`__, and
`paper <https://dblp.org/rec/conf/www/BawaCG05>`__, and
understand the data structure, you may want to customize another
parameter for :class:`datasketch.MinHashLSHForest` -- ``l``, the number of prefix trees
(or "LSH Trees" as in the paper) in the LSH Forest index. Different from
Expand Down
18 changes: 18 additions & 0 deletions test/test_lsh.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,24 @@ def test_query_buffer(self):
m3 = MinHash(18)
self.assertRaises(ValueError, lsh.add_to_query_buffer, m3)

def test_query_buffer_matches_query_candidates(self):
    """Check that a buffered query yields the same keys as a direct ``query()``.

    Three small token sets are sketched and indexed under the keys 0, 1
    and 2. The first sketch is then queried through the query buffer and
    through ``query()``; both results must be equal and must be exactly
    the keys ``{0, 1}``.
    """
    lsh = MinHashLSH(threshold=0.5, num_perm=32)
    token_groups = (
        [b"a", b"b", b"c"],
        [b"a", b"b", b"d"],
        [b"x", b"y", b"z"],
    )

    # Build one sketch per token group, keeping them in input order so the
    # position of each sketch doubles as its key in the index.
    sketches = []
    for group in token_groups:
        sketch = MinHash(num_perm=32)
        for token in group:
            sketch.update(token)
        sketches.append(sketch)

    key = 0
    for sketch in sketches:
        lsh.insert(key, sketch)
        key += 1

    probe = sketches[0]

    # Buffered path first, then the direct path, for the same probe sketch.
    lsh.add_to_query_buffer(probe)
    via_buffer = set(lsh.collect_query_buffer())
    via_query = set(lsh.query(probe))

    self.assertEqual(via_buffer, via_query)
    self.assertEqual(via_buffer, {0, 1})

def test_remove(self):
lsh = MinHashLSH(threshold=0.5, num_perm=16)
m1 = MinHash(16)
Expand Down
Loading