Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 38 additions & 8 deletions mempalace/backends/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,14 @@ def backend_version() -> str:
# Collection lifecycle
# ------------------------------------------------------------------

def get_collection(self, palace_path: str, collection_name: str, create: bool = False):
def get_collection(
self,
palace_path: str,
collection_name: str,
create: bool = False,
embedding_function=None,
embedding_model_name: str = None,
):
if not create and not os.path.isdir(palace_path):
raise FileNotFoundError(palace_path)

Expand All @@ -124,29 +131,52 @@ def get_collection(self, palace_path: str, collection_name: str, create: bool =
pass

client = self._client(palace_path)
ef_kwargs = {"embedding_function": embedding_function} if embedding_function else {}
if create:
metadata = {"hnsw:space": "cosine"}
if embedding_model_name:
metadata["embedding_model"] = embedding_model_name
collection = client.get_or_create_collection(
collection_name, metadata={"hnsw:space": "cosine"}
collection_name, metadata=metadata, **ef_kwargs
)
else:
collection = client.get_collection(collection_name)
collection = client.get_collection(collection_name, **ef_kwargs)
return ChromaCollection(collection)

def get_or_create_collection(
self, palace_path: str, collection_name: str
self,
palace_path: str,
collection_name: str,
embedding_function=None,
embedding_model_name: str = None,
) -> "ChromaCollection":
"""Shorthand for get_collection(..., create=True)."""
return self.get_collection(palace_path, collection_name, create=True)
return self.get_collection(
palace_path,
collection_name,
create=True,
embedding_function=embedding_function,
embedding_model_name=embedding_model_name,
)

def delete_collection(self, palace_path: str, collection_name: str) -> None:
    """Remove the collection named *collection_name* from the palace stored at *palace_path*."""
    # Obtain the client for this palace first, then issue the delete on it.
    client = self._client(palace_path)
    client.delete_collection(collection_name)

def create_collection(
self, palace_path: str, collection_name: str, hnsw_space: str = "cosine"
self,
palace_path: str,
collection_name: str,
hnsw_space: str = "cosine",
embedding_function=None,
embedding_model_name: str = None,
) -> "ChromaCollection":
"""Create (not get-or-create) *collection_name* with cosine HNSW space."""
"""Create (not get-or-create) *collection_name* with HNSW space and optional embedding config."""
metadata = {"hnsw:space": hnsw_space}
if embedding_model_name:
metadata["embedding_model"] = embedding_model_name
ef_kwargs = {"embedding_function": embedding_function} if embedding_function else {}
collection = self._client(palace_path).create_collection(
collection_name, metadata={"hnsw:space": hnsw_space}
collection_name, metadata=metadata, **ef_kwargs
)
return ChromaCollection(collection)
18 changes: 17 additions & 1 deletion mempalace/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,9 +277,25 @@ def cmd_repair(args):
print(f" Backing up to {backup_path}...")
shutil.copytree(palace_path, backup_path)

# Read the embedding model from the existing collection before deleting
from .embedding import get_embedding_function, resolve_model_from_metadata

try:
raw_client = backend._client(palace_path)
raw_col = raw_client.get_collection("mempalace_drawers")
repair_model = resolve_model_from_metadata(raw_col.metadata)
except Exception:
repair_model = resolve_model_from_metadata(None)
repair_ef = get_embedding_function(repair_model)

print(" Rebuilding collection...")
backend.delete_collection(palace_path, "mempalace_drawers")
new_col = backend.create_collection(palace_path, "mempalace_drawers")
new_col = backend.create_collection(
palace_path,
"mempalace_drawers",
embedding_function=repair_ef,
embedding_model_name=repair_model,
)

filed = 0
for i in range(0, len(all_ids), batch_size):
Expand Down
14 changes: 14 additions & 0 deletions mempalace/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,20 @@ def hook_desktop_toast(self):
"""Whether the stop hook shows a desktop notification via notify-send."""
return self._file_config.get("hooks", {}).get("desktop_toast", False)

@property
def embedding_model(self):
    """Preferred embedding model for new palaces.

    Used only when creating a new palace. Existing palaces read
    the model from their collection metadata (see embedding.py).

    Resolution order: the ``MEMPALACE_EMBEDDING_MODEL`` environment
    variable, then the ``embedding_model`` key of the file config, then
    ``NEW_PALACE_MODEL``. An empty or null value at either of the first
    two levels is skipped, so the result is always a usable model name.
    """
    from .embedding import NEW_PALACE_MODEL

    env = os.environ.get("MEMPALACE_EMBEDDING_MODEL")
    if env:
        return env
    # ``or`` instead of a ``.get`` default: a key that is present but set
    # to null / "" in the config file must also fall back to the constant.
    return self._file_config.get("embedding_model") or NEW_PALACE_MODEL

def set_hook_setting(self, key: str, value: bool):
"""Update a hook setting and write config to disk."""
if "hooks" not in self._file_config:
Expand Down
51 changes: 51 additions & 0 deletions mempalace/embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""
embedding.py — Centralized embedding model configuration.

Single source of truth for which embedding model the palace uses.
Resolves model from ChromaDB collection metadata (stamped at build time),
with fallback to legacy default for existing palaces.
"""

import os

# Legacy model — used by all palaces created before this feature.
# ChromaDB 0.6.3 uses this as its built-in default (384 dimensions).
DEFAULT_MODEL = "all-MiniLM-L6-v2"

# Model for newly created palaces — better search quality (768 dimensions).
NEW_PALACE_MODEL = "all-mpnet-base-v2"


def get_embedding_function(model_name: str):
    """Build the ChromaDB-compatible embedding function for *model_name*.

    Delegates to ChromaDB's built-in SentenceTransformerEmbeddingFunction,
    which auto-downloads and caches the model on first use.
    """
    # Function-scope import: chromadb is only pulled in when an embedding
    # function is actually requested.
    from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction

    embedder = SentenceTransformerEmbeddingFunction(model_name=model_name)
    return embedder


def resolve_model_from_metadata(collection_metadata: "dict | None") -> str:
    """Resolve which embedding model was used from collection metadata.

    *collection_metadata* is the collection's metadata mapping, or None when
    the collection has no metadata or could not be read.

    Returns the name stored under the "embedding_model" key when it is a
    non-blank string. Returns DEFAULT_MODEL if metadata is missing or does
    not carry a usable value — this means the palace was created before
    embedding model tracking existed.
    """
    stamped = (collection_metadata or {}).get("embedding_model")
    # A key that is present but null, blank, or not a string cannot name a
    # model; treat it like a missing stamp rather than returning it.
    if isinstance(stamped, str) and stamped.strip():
        return stamped
    return DEFAULT_MODEL


def new_palace_model(config=None) -> str:
    """Return the embedding model to use for new palace creation.

    Resolution: env var > config > NEW_PALACE_MODEL constant. A level that
    is unset or empty is skipped, so the result is always a model name.

    *config* is optional; any object exposing an ``embedding_model``
    attribute is accepted.
    """
    env = os.environ.get("MEMPALACE_EMBEDDING_MODEL")
    if env:
        return env
    # One getattr replaces the hasattr + attribute read, and the truthiness
    # check keeps a config value of None or "" from being returned as the
    # model name.
    configured = getattr(config, "embedding_model", None)
    if configured:
        return configured
    return NEW_PALACE_MODEL
25 changes: 23 additions & 2 deletions mempalace/mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,18 +214,30 @@ def _get_client():
def _get_collection(create=False):
"""Return the ChromaDB collection, caching the client between calls."""
global _collection_cache, _metadata_cache, _metadata_cache_time
from .embedding import get_embedding_function, new_palace_model, resolve_model_from_metadata

try:
client = _get_client()
if create:
model = new_palace_model(_config)
ef = get_embedding_function(model)
_collection_cache = ChromaCollection(
client.get_or_create_collection(
_config.collection_name, metadata={"hnsw:space": "cosine"}
_config.collection_name,
metadata={"hnsw:space": "cosine", "embedding_model": model},
embedding_function=ef,
)
)
_metadata_cache = None
_metadata_cache_time = 0
elif _collection_cache is None:
_collection_cache = ChromaCollection(client.get_collection(_config.collection_name))
# Read metadata to resolve model, then open with correct EF
raw_col = client.get_collection(_config.collection_name)
model = resolve_model_from_metadata(raw_col.metadata)
ef = get_embedding_function(model)
_collection_cache = ChromaCollection(
client.get_collection(_config.collection_name, embedding_function=ef)
)
_metadata_cache = None
_metadata_cache_time = 0
return _collection_cache
Expand Down Expand Up @@ -312,6 +324,15 @@ def tool_status():
"protocol": PALACE_PROTOCOL,
"aaak_dialect": AAAK_SPEC,
}
# Report which embedding model this palace uses
from .embedding import resolve_model_from_metadata

try:
raw_client = _get_client()
raw_col = raw_client.get_collection(_config.collection_name)
result["embedding_model"] = resolve_model_from_metadata(raw_col.metadata)
except Exception:
result["embedding_model"] = "unknown"
try:
all_meta = _get_cached_metadata(col)
for m in all_meta:
Expand Down
11 changes: 10 additions & 1 deletion mempalace/migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,17 @@ def migrate(palace_path: str, dry_run: bool = False, confirm: bool = False):

temp_palace = tempfile.mkdtemp(prefix="mempalace_migrate_")
print(f" Creating fresh palace in {temp_palace}...")
from .embedding import get_embedding_function, new_palace_model

migrate_model = new_palace_model()
migrate_ef = get_embedding_function(migrate_model)
fresh_backend = ChromaBackend()
col = fresh_backend.get_or_create_collection(temp_palace, "mempalace_drawers")
col = fresh_backend.get_or_create_collection(
temp_palace,
"mempalace_drawers",
embedding_function=migrate_ef,
embedding_model_name=migrate_model,
)

# Re-import in batches
batch_size = 500
Expand Down
43 changes: 41 additions & 2 deletions mempalace/palace.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,51 @@ def get_collection(
palace_path: str,
collection_name: str = "mempalace_drawers",
create: bool = True,
config=None,
):
"""Get the palace collection through the backend layer."""
"""Get the palace collection through the backend layer.

Resolves the correct embedding model:
- On create: uses new_palace_model(config) (mpnet by default)
- On read: reads model from collection metadata, falls back to MiniLM
"""
from .embedding import (
get_embedding_function,
new_palace_model,
resolve_model_from_metadata,
)

if create:
model = new_palace_model(config)
ef = get_embedding_function(model)
return _DEFAULT_BACKEND.get_collection(
palace_path,
collection_name=collection_name,
create=True,
embedding_function=ef,
embedding_model_name=model,
)

# For existing palaces: read metadata to determine model.
# Check path existence first — _client() would create the directory
# as a side effect of PersistentClient(), which breaks callers that
# expect FileNotFoundError for missing palaces (e.g. status()).
if not os.path.isdir(palace_path):
raise FileNotFoundError(palace_path)

try:
raw_client = _DEFAULT_BACKEND._client(palace_path)
raw_col = raw_client.get_collection(collection_name)
model = resolve_model_from_metadata(raw_col.metadata)
except Exception:
model = resolve_model_from_metadata(None)

ef = get_embedding_function(model)
return _DEFAULT_BACKEND.get_collection(
palace_path,
collection_name=collection_name,
create=create,
create=False,
embedding_function=ef,
)


Expand Down
19 changes: 17 additions & 2 deletions mempalace/repair.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,10 +260,25 @@ def rebuild_index(palace_path=None):
shutil.copy2(sqlite_path, backup_path)
print(f" Backup: {backup_path}")

# Rebuild with correct HNSW settings
# Read the embedding model from the existing collection before deleting
from .embedding import get_embedding_function, resolve_model_from_metadata

try:
raw_client = backend._client(palace_path)
raw_col = raw_client.get_collection(COLLECTION_NAME)
rebuild_model = resolve_model_from_metadata(raw_col.metadata)
except Exception:
rebuild_model = resolve_model_from_metadata(None)
rebuild_ef = get_embedding_function(rebuild_model)

print(" Rebuilding collection with hnsw:space=cosine...")
backend.delete_collection(palace_path, COLLECTION_NAME)
new_col = backend.create_collection(palace_path, COLLECTION_NAME)
new_col = backend.create_collection(
palace_path,
COLLECTION_NAME,
embedding_function=rebuild_ef,
embedding_model_name=rebuild_model,
)

filed = 0
for i in range(0, len(all_ids), batch_size):
Expand Down
54 changes: 54 additions & 0 deletions tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest

from mempalace.backends.chroma import ChromaBackend, ChromaCollection, _fix_blob_seq_ids
from mempalace.embedding import DEFAULT_MODEL, get_embedding_function


class _FakeCollection:
Expand Down Expand Up @@ -140,3 +141,56 @@ def test_fix_blob_seq_ids_noop_without_blobs(tmp_path):
def test_fix_blob_seq_ids_noop_without_database(tmp_path):
"""No error when palace has no chroma.sqlite3."""
_fix_blob_seq_ids(str(tmp_path)) # should not raise


def test_chroma_backend_stamps_embedding_model_on_create(tmp_path):
    """Creating a collection records the embedding model name in its metadata."""
    palace_dir = str(tmp_path / "palace")
    embedder = get_embedding_function(DEFAULT_MODEL)
    backend = ChromaBackend()

    backend.get_collection(
        palace_dir,
        collection_name="mempalace_drawers",
        create=True,
        embedding_function=embedder,
        embedding_model_name=DEFAULT_MODEL,
    )

    # Reopen through a raw chromadb client so the check reads the persisted
    # collection metadata directly.
    raw_client = chromadb.PersistentClient(path=palace_dir)
    stored_metadata = raw_client.get_collection("mempalace_drawers").metadata
    assert stored_metadata.get("embedding_model") == DEFAULT_MODEL


def test_chroma_backend_passes_embedding_function_on_get(tmp_path):
    """Opening an existing collection with an embedding function succeeds."""
    palace_dir = str(tmp_path / "palace")
    embedder = get_embedding_function(DEFAULT_MODEL)
    shared_args = {
        "collection_name": "mempalace_drawers",
        "embedding_function": embedder,
    }

    # Step 1: create the collection, stamping the model name.
    ChromaBackend().get_collection(
        palace_dir,
        create=True,
        embedding_model_name=DEFAULT_MODEL,
        **shared_args,
    )

    # Step 2: reopen it from a fresh backend without create — must not raise.
    reopened = ChromaBackend().get_collection(palace_dir, create=False, **shared_args)
    assert reopened.count() == 0


def test_chroma_backend_works_without_embedding_function(tmp_path):
    """Backwards compatibility: omitting embedding_function still yields a usable collection."""
    palace_dir = str(tmp_path / "palace")
    backend = ChromaBackend()

    # No embedding_function / embedding_model_name keywords are passed here.
    created = backend.get_collection(
        palace_dir,
        collection_name="mempalace_drawers",
        create=True,
    )
    assert created.count() == 0
Loading