@@ -33,19 +33,23 @@ def config_dir(tmp_path):
3333
3434
3535class TestGetEmbeddingFunctionDefault :
36- """When no model is configured, get_embedding_function returns None ."""
36+ """When no model is configured, get_embedding_function returns ChromaDB's default ."""
3737
38- def test_returns_none_no_config (self , tmp_path ):
38+ def test_returns_default_no_config (self , tmp_path ):
3939 config = MempalaceConfig (config_dir = str (tmp_path / "empty" ))
4040 result = get_embedding_function (config = config )
41- assert result is None
41+ # Must not be None — newer ChromaDB requires an explicit callable
42+ # so that `collection.add()` can compute embeddings.
43+ assert result is not None
44+ assert callable (result )
4245
43- def test_returns_none_empty_config (self , config_dir ):
46+ def test_returns_default_empty_config (self , config_dir ):
4447 config_file = config_dir / "config.json"
4548 config_file .write_text ("{}" )
4649 config = MempalaceConfig (config_dir = str (config_dir ))
4750 result = get_embedding_function (config = config )
48- assert result is None
51+ assert result is not None
52+ assert callable (result )
4953
5054
5155class TestGetEmbeddingFunctionEnvVar :
@@ -117,7 +121,7 @@ def test_config_file_model(self, config_dir):
117121class TestGetEmbeddingFunctionFallback :
118122 """Graceful fallback when sentence-transformers is not installed."""
119123
120- def test_import_error_returns_none (self , config_dir ):
124+ def test_import_error_falls_back_to_default (self , config_dir ):
121125 config_file = config_dir / "config.json"
122126 config_file .write_text (json .dumps ({"embedding_model" : "some-model" }))
123127 config = MempalaceConfig (config_dir = str (config_dir ))
@@ -128,7 +132,9 @@ def test_import_error_returns_none(self, config_dir):
128132 ):
129133 result = get_embedding_function (config = config )
130134
131- assert result is None
135+ # Falls back to ChromaDB's DefaultEmbeddingFunction, not None
136+ assert result is not None
137+ assert callable (result )
132138
133139
134140class TestGetEmbeddingFunctionCaching :
@@ -153,12 +159,14 @@ def test_caches_result(self, config_dir):
153159 # Constructor called only once due to caching
154160 assert mock_st_cls .call_count == 1
155161
156- def test_caches_none_result (self , tmp_path ):
162+ def test_caches_default_result (self , tmp_path ):
163+ """The default embedding function is also cached between calls."""
157164 config = MempalaceConfig (config_dir = str (tmp_path / "empty" ))
158165 result1 = get_embedding_function (config = config )
159166 result2 = get_embedding_function (config = config )
160- assert result1 is None
161- assert result2 is None
167+ # Same instance returned (cached), and never None
168+ assert result1 is result2
169+ assert result1 is not None
162170
163171
164172class TestEmbeddingModelProperty :
@@ -180,3 +188,130 @@ def test_env_var_overrides(self, config_dir):
180188 config = MempalaceConfig (config_dir = str (config_dir ))
181189 with patch .dict (os .environ , {"MEMPALACE_EMBEDDING_MODEL" : "env-model" }):
182190 assert config .embedding_model == "env-model"
191+
192+
193+ class TestEmbeddingDeviceProperty :
194+ """MempalaceConfig.embedding_device property."""
195+
196+ def test_returns_none_by_default (self , tmp_path ):
197+ config = MempalaceConfig (config_dir = str (tmp_path / "empty" ))
198+ assert config .embedding_device is None
199+
200+ def test_reads_from_config_file (self , config_dir ):
201+ config_file = config_dir / "config.json"
202+ config_file .write_text (json .dumps ({"embedding_device" : "mps" }))
203+ config = MempalaceConfig (config_dir = str (config_dir ))
204+ assert config .embedding_device == "mps"
205+
206+ def test_env_var_overrides (self , config_dir ):
207+ config_file = config_dir / "config.json"
208+ config_file .write_text (json .dumps ({"embedding_device" : "cpu" }))
209+ config = MempalaceConfig (config_dir = str (config_dir ))
210+ with patch .dict (os .environ , {"MEMPALACE_EMBEDDING_DEVICE" : "mps" }):
211+ assert config .embedding_device == "mps"
212+
213+
214+ class TestGetEmbeddingFunctionDevice :
215+ """MEMPALACE_EMBEDDING_DEVICE controls the device passed to the embedder."""
216+
217+ def test_device_passed_to_embedder_with_explicit_model (self , tmp_path ):
218+ """When both model and device are set, both are passed through."""
219+ mock_ef = MagicMock ()
220+ mock_st_cls = MagicMock (return_value = mock_ef )
221+ config = MempalaceConfig (config_dir = str (tmp_path / "empty" ))
222+
223+ with (
224+ patch .dict (
225+ os .environ ,
226+ {
227+ "MEMPALACE_EMBEDDING_MODEL" : "intfloat/multilingual-e5-base" ,
228+ "MEMPALACE_EMBEDDING_DEVICE" : "mps" ,
229+ },
230+ ),
231+ patch (
232+ "chromadb.utils.embedding_functions.SentenceTransformerEmbeddingFunction" ,
233+ mock_st_cls ,
234+ ),
235+ ):
236+ result = get_embedding_function (config = config )
237+
238+ assert result is mock_ef
239+ mock_st_cls .assert_called_once_with (
240+ model_name = "intfloat/multilingual-e5-base" , device = "mps"
241+ )
242+
243+ def test_device_alone_activates_default_model (self , tmp_path ):
244+ """Setting only the device should trigger the default model on that device.
245+
246+ This is the ergonomic path for Apple Silicon / CUDA users: they
247+ don't need to know the model name, just the device.
248+ """
249+ mock_ef = MagicMock ()
250+ mock_st_cls = MagicMock (return_value = mock_ef )
251+ config = MempalaceConfig (config_dir = str (tmp_path / "empty" ))
252+
253+ with (
254+ patch .dict (os .environ , {"MEMPALACE_EMBEDDING_DEVICE" : "mps" }),
255+ patch (
256+ "chromadb.utils.embedding_functions.SentenceTransformerEmbeddingFunction" ,
257+ mock_st_cls ,
258+ ),
259+ ):
260+ result = get_embedding_function (config = config )
261+
262+ assert result is mock_ef
263+ mock_st_cls .assert_called_once_with (
264+ model_name = "sentence-transformers/all-MiniLM-L6-v2" , device = "mps"
265+ )
266+
267+ def test_no_device_no_kwarg (self , tmp_path , monkeypatch ):
268+ """When no device is set, ``device`` is NOT passed as a kwarg.
269+
270+ This preserves backward compatibility with the original PR #442
271+ behavior where only ``model_name`` was passed.
272+ """
273+ mock_ef = MagicMock ()
274+ mock_st_cls = MagicMock (return_value = mock_ef )
275+ config = MempalaceConfig (config_dir = str (tmp_path / "empty" ))
276+
277+ monkeypatch .setenv ("MEMPALACE_EMBEDDING_MODEL" , "some-model" )
278+ monkeypatch .delenv ("MEMPALACE_EMBEDDING_DEVICE" , raising = False )
279+
280+ with patch (
281+ "chromadb.utils.embedding_functions.SentenceTransformerEmbeddingFunction" ,
282+ mock_st_cls ,
283+ ):
284+ result = get_embedding_function (config = config )
285+
286+ assert result is mock_ef
287+ mock_st_cls .assert_called_once_with (model_name = "some-model" )
288+
289+ def test_device_from_config_file (self , config_dir , monkeypatch ):
290+ """Device can be set via config.json instead of env var."""
291+ config_file = config_dir / "config.json"
292+ config_file .write_text (
293+ json .dumps (
294+ {
295+ "embedding_model" : "intfloat/multilingual-e5-base" ,
296+ "embedding_device" : "cuda" ,
297+ }
298+ )
299+ )
300+ config = MempalaceConfig (config_dir = str (config_dir ))
301+
302+ mock_ef = MagicMock ()
303+ mock_st_cls = MagicMock (return_value = mock_ef )
304+
305+ monkeypatch .delenv ("MEMPALACE_EMBEDDING_MODEL" , raising = False )
306+ monkeypatch .delenv ("MEMPALACE_EMBEDDING_DEVICE" , raising = False )
307+
308+ with patch (
309+ "chromadb.utils.embedding_functions.SentenceTransformerEmbeddingFunction" ,
310+ mock_st_cls ,
311+ ):
312+ result = get_embedding_function (config = config )
313+
314+ assert result is mock_ef
315+ mock_st_cls .assert_called_once_with (
316+ model_name = "intfloat/multilingual-e5-base" , device = "cuda"
317+ )
0 commit comments