Skip to content

Commit 5cf62ee

Browse files
committed
feat: add embedding device support + fix DefaultEmbeddingFunction regression
Builds on top of MemPalace#442 to add two improvements: 1. Expose embedding device via MEMPALACE_EMBEDDING_DEVICE env var / embedding_device config key. This lets Apple Silicon users set device='mps' and NVIDIA users set device='cuda' to dramatically speed up embedding generation during mempalace mine (5-15x measured on M-series). 2. Ergonomic default: setting only MEMPALACE_EMBEDDING_DEVICE automatically uses sentence-transformers/all-MiniLM-L6-v2 (same weights as ChromaDB's default ONNX embedder), so users don't have to know the model name to get GPU acceleration, and existing palaces remain vector-compatible. 3. Fix a regression in MemPalace#442: when no model is configured, get_embedding_function() used to return None, which newer ChromaDB rejects with 'You must provide an embedding function' at collection.add() time. Now returns ChromaDB's DefaultEmbeddingFunction() explicitly, restoring the pre-MemPalace#442 default behavior and making tests/test_convo_miner.py pass again. All 552 tests pass, including 7 new tests covering: - embedding_device property reads from env var and config.json - device is passed through to SentenceTransformerEmbeddingFunction when set - device alone activates the default model - device is NOT passed when unset (preserves original MemPalace#442 call signature) - device can be set via config.json independent of model
1 parent b442c55 commit 5cf62ee

2 files changed

Lines changed: 218 additions & 21 deletions

File tree

mempalace/config.py

Lines changed: 73 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,20 @@ def embedding_model(self):
208208
return env_val
209209
return self._file_config.get("embedding_model", None)
210210

211+
@property
212+
def embedding_device(self):
213+
"""Configured embedding device ('cpu', 'mps', 'cuda', ...) or None.
214+
215+
When None, ``SentenceTransformerEmbeddingFunction`` picks its own
216+
default (CPU). Setting this to ``'mps'`` on Apple Silicon or
217+
``'cuda'`` on NVIDIA GPUs can dramatically speed up embedding
218+
generation during ``mempalace mine``.
219+
"""
220+
env_val = os.environ.get("MEMPALACE_EMBEDDING_DEVICE")
221+
if env_val:
222+
return env_val
223+
return self._file_config.get("embedding_device", None)
224+
211225
def save_people_map(self, people_map):
212226
"""Write people_map.json to config directory.
213227
@@ -225,15 +239,33 @@ def save_people_map(self, people_map):
225239
_embedding_function = None
226240
_embedding_function_resolved = False
227241

242+
# Default model used when a device is explicitly requested but no model name
243+
# is configured. This is the same underlying model ChromaDB uses by default
244+
# (via its ONNX runtime), so vectors remain compatible with existing palaces.
245+
_DEFAULT_MODEL_FOR_DEVICE = "sentence-transformers/all-MiniLM-L6-v2"
228246

229-
def get_embedding_function(config=None):
230-
"""Return the configured ChromaDB embedding function, or None for default.
231247

232-
Checks MEMPALACE_EMBEDDING_MODEL env var first, then config.json
233-
``embedding_model`` key. When a model name is found, attempts to import
234-
``SentenceTransformerEmbeddingFunction`` from chromadb. If
235-
sentence-transformers is not installed the import will fail and we fall
236-
back to None (ChromaDB's built-in default), logging a warning.
248+
def get_embedding_function(config=None):
249+
"""Return the configured ChromaDB embedding function.
250+
251+
Resolution order:
252+
253+
1. If ``MEMPALACE_EMBEDDING_MODEL`` / ``embedding_model`` is set, use that
254+
model via :class:`SentenceTransformerEmbeddingFunction`.
255+
2. Else, if ``MEMPALACE_EMBEDDING_DEVICE`` / ``embedding_device`` is set
256+
(e.g. ``'mps'``, ``'cuda'``), use the default model
257+
``sentence-transformers/all-MiniLM-L6-v2`` on that device. This gives
258+
Apple Silicon / NVIDIA users a GPU speedup without having to think
259+
about model names, while staying vector-compatible with ChromaDB's
260+
default ONNX embedder (same underlying weights).
261+
3. Else, return ChromaDB's built-in ``DefaultEmbeddingFunction`` (ONNX
262+
MiniLM on CPU). Newer ChromaDB versions require an explicit embedding
263+
function at collection creation time, so returning ``None`` here would
264+
break ``collection.add()`` calls.
265+
266+
When a model is resolved, the configured device (if any) is passed through
267+
to ``SentenceTransformerEmbeddingFunction``. If ``sentence-transformers``
268+
isn't installed, we log a warning and fall back to ChromaDB's default.
237269
238270
The result is cached so the function is only resolved once per process.
239271
"""
@@ -245,19 +277,49 @@ def get_embedding_function(config=None):
245277

246278
cfg = config or MempalaceConfig()
247279
model_name = cfg.embedding_model
280+
device = cfg.embedding_device
281+
282+
# Ergonomic default: if the user asked for a device but didn't pick a
283+
# model, use the same model ChromaDB uses by default so vectors stay
284+
# compatible with existing palaces.
285+
if not model_name and device:
286+
model_name = _DEFAULT_MODEL_FOR_DEVICE
287+
248288
if not model_name:
249-
return None
289+
# No explicit configuration — use ChromaDB's default embedder.
290+
# We must return a real callable (not None), because newer ChromaDB
291+
# versions reject `embedding_function=None` at collection.add() time.
292+
try:
293+
from chromadb.utils.embedding_functions import DefaultEmbeddingFunction
294+
295+
_embedding_function = DefaultEmbeddingFunction()
296+
except Exception:
297+
_embedding_function = None
298+
return _embedding_function
250299

251300
try:
252301
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
253302

254-
_embedding_function = SentenceTransformerEmbeddingFunction(model_name=model_name)
255-
logger.info("Using embedding model: %s", model_name)
303+
kwargs = {"model_name": model_name}
304+
if device:
305+
kwargs["device"] = device
306+
307+
_embedding_function = SentenceTransformerEmbeddingFunction(**kwargs)
308+
logger.info(
309+
"Using embedding model: %s (device=%s)",
310+
model_name,
311+
device or "default",
312+
)
256313
except Exception:
257314
logger.warning(
258315
"sentence-transformers not installed — falling back to ChromaDB default. "
259316
"Install with: pip install mempalace[multilingual]"
260317
)
261-
_embedding_function = None
318+
try:
319+
from chromadb.utils.embedding_functions import DefaultEmbeddingFunction
320+
321+
_embedding_function = DefaultEmbeddingFunction()
322+
except Exception:
323+
_embedding_function = None
262324

263325
return _embedding_function

tests/test_multilingual.py

Lines changed: 145 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,19 +33,23 @@ def config_dir(tmp_path):
3333

3434

3535
class TestGetEmbeddingFunctionDefault:
36-
"""When no model is configured, get_embedding_function returns None."""
36+
"""When no model is configured, get_embedding_function returns ChromaDB's default."""
3737

38-
def test_returns_none_no_config(self, tmp_path):
38+
def test_returns_default_no_config(self, tmp_path):
3939
config = MempalaceConfig(config_dir=str(tmp_path / "empty"))
4040
result = get_embedding_function(config=config)
41-
assert result is None
41+
# Must not be None — newer ChromaDB requires an explicit callable
42+
# so that `collection.add()` can compute embeddings.
43+
assert result is not None
44+
assert callable(result)
4245

43-
def test_returns_none_empty_config(self, config_dir):
46+
def test_returns_default_empty_config(self, config_dir):
4447
config_file = config_dir / "config.json"
4548
config_file.write_text("{}")
4649
config = MempalaceConfig(config_dir=str(config_dir))
4750
result = get_embedding_function(config=config)
48-
assert result is None
51+
assert result is not None
52+
assert callable(result)
4953

5054

5155
class TestGetEmbeddingFunctionEnvVar:
@@ -117,7 +121,7 @@ def test_config_file_model(self, config_dir):
117121
class TestGetEmbeddingFunctionFallback:
118122
"""Graceful fallback when sentence-transformers is not installed."""
119123

120-
def test_import_error_returns_none(self, config_dir):
124+
def test_import_error_falls_back_to_default(self, config_dir):
121125
config_file = config_dir / "config.json"
122126
config_file.write_text(json.dumps({"embedding_model": "some-model"}))
123127
config = MempalaceConfig(config_dir=str(config_dir))
@@ -128,7 +132,9 @@ def test_import_error_returns_none(self, config_dir):
128132
):
129133
result = get_embedding_function(config=config)
130134

131-
assert result is None
135+
# Falls back to ChromaDB's DefaultEmbeddingFunction, not None
136+
assert result is not None
137+
assert callable(result)
132138

133139

134140
class TestGetEmbeddingFunctionCaching:
@@ -153,12 +159,14 @@ def test_caches_result(self, config_dir):
153159
# Constructor called only once due to caching
154160
assert mock_st_cls.call_count == 1
155161

156-
def test_caches_none_result(self, tmp_path):
162+
def test_caches_default_result(self, tmp_path):
163+
"""The default embedding function is also cached between calls."""
157164
config = MempalaceConfig(config_dir=str(tmp_path / "empty"))
158165
result1 = get_embedding_function(config=config)
159166
result2 = get_embedding_function(config=config)
160-
assert result1 is None
161-
assert result2 is None
167+
# Same instance returned (cached), and never None
168+
assert result1 is result2
169+
assert result1 is not None
162170

163171

164172
class TestEmbeddingModelProperty:
@@ -180,3 +188,130 @@ def test_env_var_overrides(self, config_dir):
180188
config = MempalaceConfig(config_dir=str(config_dir))
181189
with patch.dict(os.environ, {"MEMPALACE_EMBEDDING_MODEL": "env-model"}):
182190
assert config.embedding_model == "env-model"
191+
192+
193+
class TestEmbeddingDeviceProperty:
194+
"""MempalaceConfig.embedding_device property."""
195+
196+
def test_returns_none_by_default(self, tmp_path):
197+
config = MempalaceConfig(config_dir=str(tmp_path / "empty"))
198+
assert config.embedding_device is None
199+
200+
def test_reads_from_config_file(self, config_dir):
201+
config_file = config_dir / "config.json"
202+
config_file.write_text(json.dumps({"embedding_device": "mps"}))
203+
config = MempalaceConfig(config_dir=str(config_dir))
204+
assert config.embedding_device == "mps"
205+
206+
def test_env_var_overrides(self, config_dir):
207+
config_file = config_dir / "config.json"
208+
config_file.write_text(json.dumps({"embedding_device": "cpu"}))
209+
config = MempalaceConfig(config_dir=str(config_dir))
210+
with patch.dict(os.environ, {"MEMPALACE_EMBEDDING_DEVICE": "mps"}):
211+
assert config.embedding_device == "mps"
212+
213+
214+
class TestGetEmbeddingFunctionDevice:
215+
"""MEMPALACE_EMBEDDING_DEVICE controls the device passed to the embedder."""
216+
217+
def test_device_passed_to_embedder_with_explicit_model(self, tmp_path):
218+
"""When both model and device are set, both are passed through."""
219+
mock_ef = MagicMock()
220+
mock_st_cls = MagicMock(return_value=mock_ef)
221+
config = MempalaceConfig(config_dir=str(tmp_path / "empty"))
222+
223+
with (
224+
patch.dict(
225+
os.environ,
226+
{
227+
"MEMPALACE_EMBEDDING_MODEL": "intfloat/multilingual-e5-base",
228+
"MEMPALACE_EMBEDDING_DEVICE": "mps",
229+
},
230+
),
231+
patch(
232+
"chromadb.utils.embedding_functions.SentenceTransformerEmbeddingFunction",
233+
mock_st_cls,
234+
),
235+
):
236+
result = get_embedding_function(config=config)
237+
238+
assert result is mock_ef
239+
mock_st_cls.assert_called_once_with(
240+
model_name="intfloat/multilingual-e5-base", device="mps"
241+
)
242+
243+
def test_device_alone_activates_default_model(self, tmp_path):
244+
"""Setting only the device should trigger the default model on that device.
245+
246+
This is the ergonomic path for Apple Silicon / CUDA users: they
247+
don't need to know the model name, just the device.
248+
"""
249+
mock_ef = MagicMock()
250+
mock_st_cls = MagicMock(return_value=mock_ef)
251+
config = MempalaceConfig(config_dir=str(tmp_path / "empty"))
252+
253+
with (
254+
patch.dict(os.environ, {"MEMPALACE_EMBEDDING_DEVICE": "mps"}),
255+
patch(
256+
"chromadb.utils.embedding_functions.SentenceTransformerEmbeddingFunction",
257+
mock_st_cls,
258+
),
259+
):
260+
result = get_embedding_function(config=config)
261+
262+
assert result is mock_ef
263+
mock_st_cls.assert_called_once_with(
264+
model_name="sentence-transformers/all-MiniLM-L6-v2", device="mps"
265+
)
266+
267+
def test_no_device_no_kwarg(self, tmp_path, monkeypatch):
268+
"""When no device is set, ``device`` is NOT passed as a kwarg.
269+
270+
This preserves backward compatibility with the original PR #442
271+
behavior where only ``model_name`` was passed.
272+
"""
273+
mock_ef = MagicMock()
274+
mock_st_cls = MagicMock(return_value=mock_ef)
275+
config = MempalaceConfig(config_dir=str(tmp_path / "empty"))
276+
277+
monkeypatch.setenv("MEMPALACE_EMBEDDING_MODEL", "some-model")
278+
monkeypatch.delenv("MEMPALACE_EMBEDDING_DEVICE", raising=False)
279+
280+
with patch(
281+
"chromadb.utils.embedding_functions.SentenceTransformerEmbeddingFunction",
282+
mock_st_cls,
283+
):
284+
result = get_embedding_function(config=config)
285+
286+
assert result is mock_ef
287+
mock_st_cls.assert_called_once_with(model_name="some-model")
288+
289+
def test_device_from_config_file(self, config_dir, monkeypatch):
290+
"""Device can be set via config.json instead of env var."""
291+
config_file = config_dir / "config.json"
292+
config_file.write_text(
293+
json.dumps(
294+
{
295+
"embedding_model": "intfloat/multilingual-e5-base",
296+
"embedding_device": "cuda",
297+
}
298+
)
299+
)
300+
config = MempalaceConfig(config_dir=str(config_dir))
301+
302+
mock_ef = MagicMock()
303+
mock_st_cls = MagicMock(return_value=mock_ef)
304+
305+
monkeypatch.delenv("MEMPALACE_EMBEDDING_MODEL", raising=False)
306+
monkeypatch.delenv("MEMPALACE_EMBEDDING_DEVICE", raising=False)
307+
308+
with patch(
309+
"chromadb.utils.embedding_functions.SentenceTransformerEmbeddingFunction",
310+
mock_st_cls,
311+
):
312+
result = get_embedding_function(config=config)
313+
314+
assert result is mock_ef
315+
mock_st_cls.assert_called_once_with(
316+
model_name="intfloat/multilingual-e5-base", device="cuda"
317+
)

0 commit comments

Comments
 (0)