Skip to content

Commit 732a559

Browse files
FabioLissikAI Shtefan
authored andcommitted
feat: detect embedding model mismatch on collection open
Stores the embedding model identity in ChromaDB collection metadata at creation time. On every collection open, compares stored model to current config. Raises EmbeddingModelMismatchError on mismatch with recovery instructions (mempalace re-mine or MEMPALACE_FORCE_EMBEDDING). - get_embedding_model_name() returns canonical model identity string - _resolve_model_and_device() shared by both name and function resolvers - force_embedding config (env var / config.json) to bypass check - Legacy palaces silently stamped on first open - All 7 direct ChromaDB callers routed through palace.get_collection() - Fixed pre-existing test env isolation bug (monkeypatch.delenv)
1 parent ff11e3a commit 732a559

13 files changed

Lines changed: 429 additions & 323 deletions

mempalace/cli.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@
3333
import argparse
3434
from pathlib import Path
3535

36-
from .config import MempalaceConfig, get_embedding_function
36+
from .config import MempalaceConfig
37+
from .palace import get_collection as _palace_get_collection
3738

3839

3940
def cmd_init(args):
@@ -183,9 +184,7 @@ def cmd_repair(args):
183184

184185
# Try to read existing drawers
185186
try:
186-
ef = get_embedding_function()
187-
client = chromadb.PersistentClient(path=palace_path)
188-
col = client.get_collection("mempalace_drawers", embedding_function=ef)
187+
col = _palace_get_collection(palace_path, force=True)
189188
total = col.count()
190189
print(f" Drawers found: {total}")
191190
except Exception as e:
@@ -221,8 +220,9 @@ def cmd_repair(args):
221220
shutil.copytree(palace_path, backup_path)
222221

223222
print(" Rebuilding collection...")
223+
client = chromadb.PersistentClient(path=palace_path)
224224
client.delete_collection("mempalace_drawers")
225-
new_col = client.create_collection("mempalace_drawers", embedding_function=ef)
225+
new_col = _palace_get_collection(palace_path)
226226

227227
filed = 0
228228
for i in range(0, len(all_ids), batch_size):
@@ -275,7 +275,6 @@ def cmd_mcp(args):
275275

276276
def cmd_compress(args):
277277
"""Compress drawers in a wing using AAAK Dialect."""
278-
import chromadb
279278
from .dialect import Dialect
280279

281280
palace_path = os.path.expanduser(args.palace) if args.palace else MempalaceConfig().palace_path
@@ -296,9 +295,7 @@ def cmd_compress(args):
296295

297296
# Connect to palace
298297
try:
299-
ef = get_embedding_function()
300-
client = chromadb.PersistentClient(path=palace_path)
301-
col = client.get_collection("mempalace_drawers", embedding_function=ef)
298+
col = _palace_get_collection(palace_path)
302299
except Exception:
303300
print(f"\n No palace found at {palace_path}")
304301
print(" Run: mempalace init <dir> then mempalace mine <dir>")
@@ -369,9 +366,7 @@ def cmd_compress(args):
369366
# Store compressed versions (unless dry-run)
370367
if not args.dry_run:
371368
try:
372-
comp_col = client.get_or_create_collection(
373-
"mempalace_compressed", embedding_function=ef
374-
)
369+
comp_col = _palace_get_collection(palace_path, "mempalace_compressed")
375370
for doc_id, compressed, meta, stats in compressed_entries:
376371
comp_meta = dict(meta)
377372
comp_meta["compression_ratio"] = round(stats["ratio"], 1)

mempalace/config.py

Lines changed: 66 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,20 @@ def embedding_device(self):
222222
return env_val
223223
return self._file_config.get("embedding_device", None)
224224

225+
@property
226+
def force_embedding(self):
227+
"""Whether to bypass embedding model mismatch checks.
228+
229+
Priority: ``MEMPALACE_FORCE_EMBEDDING`` env var > ``config.json``
230+
``"force_embedding"`` key > ``False`` default.
231+
232+
Env var: ``"true"`` (case-insensitive) to enable, anything else is false.
233+
"""
234+
env_val = os.environ.get("MEMPALACE_FORCE_EMBEDDING")
235+
if env_val is not None:
236+
return env_val.lower() == "true"
237+
return bool(self._file_config.get("force_embedding", False))
238+
225239
def save_people_map(self, people_map):
226240
"""Write people_map.json to config directory.
227241
@@ -245,6 +259,57 @@ def save_people_map(self, people_map):
245259
_DEFAULT_MODEL_FOR_DEVICE = "sentence-transformers/all-MiniLM-L6-v2"
246260

247261

262+
class EmbeddingModelMismatchError(Exception):
263+
"""Raised when palace was created with a different embedding model."""
264+
265+
def __init__(self, stored_model: str, current_model: str):
266+
self.stored_model = stored_model
267+
self.current_model = current_model
268+
super().__init__(
269+
f"Embedding model mismatch.\n"
270+
f"Palace was created with: {stored_model}\n"
271+
f"Currently configured: {current_model}\n\n"
272+
f"To switch models, re-mine your palace:\n"
273+
f" mempalace re-mine\n\n"
274+
f"Or set MEMPALACE_FORCE_EMBEDDING=true to bypass this check."
275+
)
276+
277+
278+
def _resolve_model_and_device(config=None):
279+
"""Resolve the effective model name and device from config.
280+
281+
Returns (model_name, device) where model_name is None when no
282+
explicit model is configured (i.e. ChromaDB default should be used).
283+
"""
284+
cfg = config or MempalaceConfig()
285+
model_name = cfg.embedding_model
286+
device = cfg.embedding_device
287+
288+
# Ergonomic default: if the user asked for a device but didn't pick a
289+
# model, use the same model ChromaDB uses by default so vectors stay
290+
# compatible with existing palaces.
291+
if not model_name and device:
292+
model_name = _DEFAULT_MODEL_FOR_DEVICE
293+
294+
return model_name, device
295+
296+
297+
def get_embedding_model_name(config=None):
298+
"""Return the canonical identity string for the active embedding model.
299+
300+
Does **not** instantiate the model or touch the singleton cache.
301+
302+
Returns:
303+
``"chromadb-default"`` when no model/device is configured,
304+
``_DEFAULT_MODEL_FOR_DEVICE`` when only a device is set,
305+
or the explicit model name string.
306+
"""
307+
model_name, _ = _resolve_model_and_device(config)
308+
if not model_name:
309+
return "chromadb-default"
310+
return model_name
311+
312+
248313
def get_embedding_function(config=None):
249314
"""Return the configured ChromaDB embedding function.
250315
@@ -275,15 +340,7 @@ def get_embedding_function(config=None):
275340

276341
_embedding_function_resolved = True
277342

278-
cfg = config or MempalaceConfig()
279-
model_name = cfg.embedding_model
280-
device = cfg.embedding_device
281-
282-
# Ergonomic default: if the user asked for a device but didn't pick a
283-
# model, use the same model ChromaDB uses by default so vectors stay
284-
# compatible with existing palaces.
285-
if not model_name and device:
286-
model_name = _DEFAULT_MODEL_FOR_DEVICE
343+
model_name, device = _resolve_model_and_device(config)
287344

288345
if not model_name:
289346
# No explicit configuration — use ChromaDB's default embedder.

mempalace/layers.py

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,8 @@
2121
from pathlib import Path
2222
from collections import defaultdict
2323

24-
import chromadb
25-
26-
from .config import MempalaceConfig, get_embedding_function
24+
from .config import MempalaceConfig
25+
from .palace import get_collection as _palace_get_collection
2726

2827

2928
# ---------------------------------------------------------------------------
@@ -91,10 +90,7 @@ def __init__(self, palace_path: str = None, wing: str = None):
9190
def generate(self) -> str:
9291
"""Pull top drawers from ChromaDB and format as compact L1 text."""
9392
try:
94-
client = chromadb.PersistentClient(path=self.palace_path)
95-
col = client.get_collection(
96-
"mempalace_drawers", embedding_function=get_embedding_function()
97-
)
93+
col = _palace_get_collection(self.palace_path)
9894
except Exception:
9995
return "## L1 — No palace found. Run: mempalace mine <dir>"
10096

@@ -198,10 +194,7 @@ def __init__(self, palace_path: str = None):
198194
def retrieve(self, wing: str = None, room: str = None, n_results: int = 10) -> str:
199195
"""Retrieve drawers filtered by wing and/or room."""
200196
try:
201-
client = chromadb.PersistentClient(path=self.palace_path)
202-
col = client.get_collection(
203-
"mempalace_drawers", embedding_function=get_embedding_function()
204-
)
197+
col = _palace_get_collection(self.palace_path)
205198
except Exception:
206199
return "No palace found."
207200

@@ -264,10 +257,7 @@ def __init__(self, palace_path: str = None):
264257
def search(self, query: str, wing: str = None, room: str = None, n_results: int = 5) -> str:
265258
"""Semantic search, returns compact result text."""
266259
try:
267-
client = chromadb.PersistentClient(path=self.palace_path)
268-
col = client.get_collection(
269-
"mempalace_drawers", embedding_function=get_embedding_function()
270-
)
260+
col = _palace_get_collection(self.palace_path)
271261
except Exception:
272262
return "No palace found."
273263

@@ -322,10 +312,7 @@ def search_raw(
322312
) -> list:
323313
"""Return raw dicts instead of formatted text."""
324314
try:
325-
client = chromadb.PersistentClient(path=self.palace_path)
326-
col = client.get_collection(
327-
"mempalace_drawers", embedding_function=get_embedding_function()
328-
)
315+
col = _palace_get_collection(self.palace_path)
329316
except Exception:
330317
return []
331318

@@ -445,10 +432,7 @@ def status(self) -> dict:
445432

446433
# Count drawers
447434
try:
448-
client = chromadb.PersistentClient(path=self.palace_path)
449-
col = client.get_collection(
450-
"mempalace_drawers", embedding_function=get_embedding_function()
451-
)
435+
col = _palace_get_collection(self.palace_path)
452436
count = col.count()
453437
result["total_drawers"] = count
454438
except Exception:

mempalace/mcp_server.py

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@
2626
from datetime import datetime
2727
from pathlib import Path
2828

29-
from .config import MempalaceConfig, sanitize_name, sanitize_content, get_embedding_function
29+
from .config import MempalaceConfig, sanitize_name, sanitize_content
3030
from .version import __version__
31+
from .palace import iter_all_metadatas, get_collection as _palace_get_collection
3132
from .searcher import search_memories
3233
from .palace_graph import traverse, find_tunnels, graph_stats
33-
import chromadb
3434

3535
from .knowledge_graph import KnowledgeGraph
3636

@@ -100,31 +100,16 @@ def _wal_log(operation: str, params: dict, result: dict = None):
100100
logger.error(f"WAL write failed: {e}")
101101

102102

103-
_client_cache = None
104103
_collection_cache = None
105104

106105

107-
def _get_client():
108-
"""Return a singleton ChromaDB PersistentClient."""
109-
global _client_cache
110-
if _client_cache is None:
111-
_client_cache = chromadb.PersistentClient(path=_config.palace_path)
112-
return _client_cache
113-
114-
115106
def _get_collection(create=False):
116107
"""Return the ChromaDB collection, caching the client between calls."""
117108
global _collection_cache
118109
try:
119-
ef = get_embedding_function()
120-
client = _get_client()
121-
if create:
122-
_collection_cache = client.get_or_create_collection(
123-
_config.collection_name, embedding_function=ef
124-
)
125-
elif _collection_cache is None:
126-
_collection_cache = client.get_collection(
127-
_config.collection_name, embedding_function=ef
110+
if create or _collection_cache is None:
111+
_collection_cache = _palace_get_collection(
112+
_config.palace_path, _config.collection_name
128113
)
129114
return _collection_cache
130115
except Exception:

mempalace/miner.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,6 @@
1515
from datetime import datetime
1616
from collections import defaultdict
1717

18-
import chromadb
19-
20-
from .config import get_embedding_function
2118
from .palace import SKIP_DIRS, get_collection, file_already_mined
2219

2320
READABLE_EXTENSIONS = {
@@ -625,10 +622,10 @@ def mine(
625622

626623
def status(palace_path: str):
627624
"""Show what's been filed in the palace."""
625+
from .palace import iter_all_metadatas, get_collection as _palace_get_collection
626+
628627
try:
629-
ef = get_embedding_function()
630-
client = chromadb.PersistentClient(path=palace_path)
631-
col = client.get_collection("mempalace_drawers", embedding_function=ef)
628+
col = _palace_get_collection(palace_path)
632629
except Exception:
633630
print(f"\n No palace found at {palace_path}")
634631
print(" Run: mempalace init <dir> then mempalace mine <dir>")

mempalace/palace.py

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,19 @@
44
Consolidates ChromaDB access patterns used by both miners and the MCP server.
55
"""
66

7+
import logging
78
import os
9+
810
import chromadb
911

10-
from .config import get_embedding_function
12+
from .config import (
13+
EmbeddingModelMismatchError,
14+
MempalaceConfig,
15+
get_embedding_function,
16+
get_embedding_model_name,
17+
)
18+
19+
logger = logging.getLogger("mempalace")
1120

1221
SKIP_DIRS = {
1322
".git",
@@ -36,19 +45,51 @@
3645
}
3746

3847

39-
def get_collection(palace_path: str, collection_name: str = "mempalace_drawers"):
40-
"""Get or create the palace ChromaDB collection."""
48+
def get_collection(palace_path: str, collection_name: str = "mempalace_drawers", force: bool = None):
49+
"""Get or create the palace ChromaDB collection.
50+
51+
Verifies that the collection's embedding model matches the currently
52+
configured model. Raises EmbeddingModelMismatchError on mismatch
53+
unless force=True or MEMPALACE_FORCE_EMBEDDING=true.
54+
"""
4155
os.makedirs(palace_path, exist_ok=True)
4256
try:
4357
os.chmod(palace_path, 0o700)
4458
except (OSError, NotImplementedError):
4559
pass
60+
4661
ef = get_embedding_function()
62+
current_model = get_embedding_model_name()
63+
if force is None:
64+
force = MempalaceConfig().force_embedding
65+
4766
client = chromadb.PersistentClient(path=palace_path)
4867
try:
49-
return client.get_collection(collection_name, embedding_function=ef)
68+
col = client.get_collection(collection_name, embedding_function=ef)
69+
stored_model = (col.metadata or {}).get("embedding_model")
70+
71+
if stored_model is None:
72+
# Legacy palace — silent stamp
73+
col.modify(metadata={**(col.metadata or {}), "embedding_model": current_model})
74+
elif stored_model != current_model:
75+
if force:
76+
logger.warning(
77+
"Embedding model mismatch (forced): %s -> %s",
78+
stored_model, current_model,
79+
)
80+
col.modify(metadata={**(col.metadata or {}), "embedding_model": current_model})
81+
else:
82+
raise EmbeddingModelMismatchError(stored_model, current_model)
83+
84+
return col
85+
except EmbeddingModelMismatchError:
86+
raise
5087
except Exception:
51-
return client.create_collection(collection_name, embedding_function=ef)
88+
return client.create_collection(
89+
collection_name,
90+
embedding_function=ef,
91+
metadata={"embedding_model": current_model},
92+
)
5293

5394

5495
def file_already_mined(collection, source_file: str, check_mtime: bool = False) -> bool:

0 commit comments

Comments
 (0)