-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsync_hybrid.py
More file actions
201 lines (163 loc) · 6.47 KB
/
sync_hybrid.py
File metadata and controls
201 lines (163 loc) · 6.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
#!/usr/bin/env python3
"""Sync missing messages from conversations to conversations_hybrid with sparse vectors."""
import asyncio
import json
import sys
from datetime import UTC, datetime
from pathlib import Path

# Add src to path so the sibling claude_kb package resolves before the local import below.
sys.path.insert(0, str(Path(__file__).parent / "src"))

from qdrant_client import models
from rich.console import Console
from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn

from claude_kb.db import AsyncQdrantDB, QdrantDB
# Shared Rich console for all status and progress output in this script.
console = Console()
# Config
QDRANT_URL = "http://localhost:6333"  # local Qdrant HTTP endpoint
EMBEDDING_MODEL = "BAAI/bge-base-en-v1.5"  # embedding model id handed to QdrantDB/AsyncQdrantDB
async def sync_collections():
"""Find and migrate missing messages to hybrid collection."""
console.print("[bold cyan]=== Syncing Missing Messages to Hybrid Collection ===[/]")
console.print()
# Initialize clients
db = QdrantDB(QDRANT_URL, None, EMBEDDING_MODEL)
async_db = AsyncQdrantDB(QDRANT_URL, None, EMBEDDING_MODEL)
# Load sparse model
console.print("Loading sparse embedding model...")
db.sparse_model.load()
console.print("[green]✓ Sparse model ready[/]")
console.print()
# Get all IDs from both collections
console.print("Fetching message IDs from both collections...")
conversations_ids = set()
hybrid_ids = set()
# Scroll conversations
offset = None
while True:
results = db.client.scroll(
collection_name="conversations",
limit=1000,
offset=offset,
with_payload=False,
with_vectors=False,
)
points, offset = results
if not points:
break
conversations_ids.update(p.id for p in points)
if offset is None:
break
console.print(f" conversations: {len(conversations_ids):,} IDs")
# Scroll hybrid
offset = None
while True:
results = db.client.scroll(
collection_name="conversations_hybrid",
limit=1000,
offset=offset,
with_payload=False,
with_vectors=False,
)
points, offset = results
if not points:
break
hybrid_ids.update(p.id for p in points)
if offset is None:
break
console.print(f" conversations_hybrid: {len(hybrid_ids):,} IDs")
console.print()
# Find missing IDs
missing_ids = conversations_ids - hybrid_ids
if not missing_ids:
console.print("[green]✓ No missing messages - collections are in sync![/]")
return
console.print(f"[yellow]Found {len(missing_ids):,} missing messages[/]")
console.print()
console.print(f"[cyan]Migrating {len(missing_ids):,} messages to hybrid collection...[/]")
console.print()
# Fetch and migrate in batches
missing_list = list(missing_ids)
batch_size = 100
migrated = 0
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
console=console,
) as progress:
task = progress.add_task("Migrating...", total=len(missing_list))
for i in range(0, len(missing_list), batch_size):
batch_ids = missing_list[i : i + batch_size]
# Retrieve points from conversations (with vectors and payload)
points = db.client.retrieve(
collection_name="conversations",
ids=batch_ids,
with_payload=True,
with_vectors=True,
)
# Extract content for sparse embeddings
contents = []
for point in points:
if not point.payload:
contents.append("")
continue
content = point.payload.get("content", "")
if isinstance(content, list | dict):
import json
content = json.dumps(content)
contents.append(str(content)[:8000])
# Generate sparse embeddings
sparse_embeddings = db.sparse_model.encode(contents)
# Build new points with both vectors
new_points = []
for j, point in enumerate(points):
# Get existing dense vector
dense_vec = point.vector
# Get sparse vector
sparse = sparse_embeddings[j]
# Ensure timestamp_unix is in payload
payload = dict(point.payload) if point.payload else {}
if "timestamp_unix" not in payload:
timestamp_str = payload.get("timestamp", "")
try:
if "+" in timestamp_str or timestamp_str.endswith("Z"):
ts = datetime.fromisoformat(timestamp_str.replace("Z", "+00:00"))
else:
ts = datetime.fromisoformat(timestamp_str)
if ts.tzinfo is None:
ts = ts.replace(tzinfo=UTC)
payload["timestamp_unix"] = int(ts.timestamp())
payload["schema_version"] = 2
except (ValueError, TypeError, AttributeError):
# Use current time as fallback
payload["timestamp_unix"] = int(datetime.now(UTC).timestamp())
payload["schema_version"] = 2
new_points.append(
models.PointStruct(
id=point.id,
payload=payload,
vector={
"dense": dense_vec,
"sparse": models.SparseVector(
indices=sparse.indices.tolist(),
values=sparse.values.tolist(),
),
},
)
)
# Upsert to hybrid collection
await async_db.client.upsert(
collection_name="conversations_hybrid",
points=new_points,
)
migrated += len(points)
progress.update(task, completed=migrated)
console.print()
console.print(f"[bold green]✓ Successfully migrated {migrated:,} messages![/]")
console.print()
# Verify
hybrid_info = db.client.get_collection("conversations_hybrid")
console.print(f"conversations_hybrid now has: {hybrid_info.points_count:,} points")
await async_db.close()
if __name__ == "__main__":
    # Script entry point: run the async sync to completion.
    asyncio.run(sync_collections())