From de707faf360f5dab87f25ce09879140e3663e32f Mon Sep 17 00:00:00 2001 From: sourrrish Date: Sun, 31 May 2026 13:58:20 +0530 Subject: [PATCH] fix cross-project context traversal Signed-off-by: sourrrish --- src/basic_memory/api/v2/utils.py | 6 +- .../repository/entity_repository.py | 14 +- src/basic_memory/services/context_service.py | 74 ++++-- tests/api/v2/test_memory_hydration.py | 24 +- tests/repository/test_entity_repository.py | 35 +++ tests/services/test_context_service.py | 221 ++++++++++++++++++ 6 files changed, 342 insertions(+), 32 deletions(-) diff --git a/src/basic_memory/api/v2/utils.py b/src/basic_memory/api/v2/utils.py index 61aac74ec..db96b8b10 100644 --- a/src/basic_memory/api/v2/utils.py +++ b/src/basic_memory/api/v2/utils.py @@ -18,7 +18,9 @@ class EntityBatchLookup(Protocol): - async def find_by_ids_for_hydration(self, ids: List[int]) -> Sequence[Any]: ... + async def find_by_ids_for_hydration( + self, ids: List[int], *, include_cross_project: bool = False + ) -> Sequence[Any]: ... class EntityServiceBatchLookup(Protocol): @@ -88,7 +90,7 @@ async def to_graph_context( result_count=len(entity_ids_needed), ): entities = await entity_repository.find_by_ids_for_hydration( - list(entity_ids_needed) + list(entity_ids_needed), include_cross_project=True ) for e in entities: entity_title_lookup[e.id] = e.title diff --git a/src/basic_memory/repository/entity_repository.py b/src/basic_memory/repository/entity_repository.py index f3fab47cd..706717855 100644 --- a/src/basic_memory/repository/entity_repository.py +++ b/src/basic_memory/repository/entity_repository.py @@ -178,21 +178,31 @@ async def get_all_permalinks(self) -> List[str]: result = await self.execute_query(query, use_query_options=False) return list(result.scalars().all()) - async def find_by_ids_for_hydration(self, ids: List[int]) -> Sequence[Entity]: + async def find_by_ids_for_hydration( + self, ids: List[int], *, include_cross_project: bool = False + ) -> Sequence[Entity]: """Fetch minimal entity fields needed for context hydration. Context hydration only needs an entity's primary key, title, and external UUID. Keeping this separate from find_by_ids avoids the relationship eager loads that are useful for full entity reads but expensive for response shaping. + + Args: + ids: Entity IDs to hydrate. + include_cross_project: Include IDs outside this repository's project scope. + Use only for IDs already reached through validated graph traversal. """ if not ids: return [] query = ( - self.select() + select(Entity) .where(Entity.id.in_(ids)) .options(load_only(Entity.id, Entity.title, Entity.external_id)) ) + if not include_cross_project: + query = self._add_project_filter(query) + result = await self.execute_query(query, use_query_options=False) return list(result.scalars().all()) diff --git a/src/basic_memory/services/context_service.py b/src/basic_memory/services/context_service.py index 995058fd2..678af8836 100644 --- a/src/basic_memory/services/context_service.py +++ b/src/basic_memory/services/context_service.py @@ -333,9 +333,14 @@ async def find_related( relation_date_filter = "" timeframe_condition = "" - # Add project filtering for security - ensure all entities and relations belong to the same project - project_filter = "AND e.project_id = :project_id" - relation_project_filter = "AND e_from.project_id = :project_id" + # Trigger: build_context starts from a project-scoped search result. + # Why: the seed entity must belong to the requested project, but an + # explicit relation edge may point at another project. + # Outcome: traversal follows only project-owned edges from reached + # entities, instead of forcing every reached entity into the seed project. + seed_project_filter = "AND e.project_id = :project_id" + connected_entity_project_filter = "" + relation_project_filter = "AND e_from.project_id = r.project_id" # Use a CTE that operates directly on entity and relation tables # This avoids the overhead of the search_index virtual table @@ -351,7 +356,8 @@ async def find_related( query = self._build_postgres_query( entity_id_values, date_filter, - project_filter, + seed_project_filter, + connected_entity_project_filter, relation_date_filter, relation_project_filter, timeframe_condition, @@ -362,7 +368,8 @@ async def find_related( query = self._build_sqlite_query( entity_id_values, date_filter, - project_filter, + seed_project_filter, + connected_entity_project_filter, relation_date_filter, relation_project_filter, timeframe_condition, @@ -397,7 +404,8 @@ def _build_postgres_query( # pragma: no cover self, entity_id_values: str, date_filter: str, - project_filter: str, + seed_project_filter: str, + connected_entity_project_filter: str, relation_date_filter: str, relation_project_filter: str, timeframe_condition: str, @@ -421,11 +429,13 @@ def _build_postgres_query( # pragma: no cover 0 as depth, e.id as root_id, e.created_at, - e.created_at as relation_date + e.created_at as relation_date, + e.project_id as project_id, + ',' || e.id::text || ',' as entity_path FROM entity e WHERE e.id IN ({entity_id_values}) {date_filter} - {project_filter} + {seed_project_filter} UNION ALL @@ -477,15 +487,25 @@ def _build_postgres_query( # pragma: no cover CASE WHEN step_type = 1 THEN e_from.created_at ELSE eg.relation_date - END as relation_date + END as relation_date, + CASE + WHEN step_type = 1 THEN eg.project_id + ELSE e.project_id + END as project_id, + CASE + WHEN step_type = 1 THEN eg.entity_path + ELSE eg.entity_path || e.id::text || ',' + END as entity_path FROM entity_graph eg CROSS JOIN LATERAL (VALUES (1), (2)) AS steps(step_type) JOIN relation r ON ( eg.type = 'entity' AND - (r.from_id = eg.id OR r.to_id = eg.id) + (r.from_id = eg.id OR r.to_id = eg.id) AND + r.project_id = eg.project_id ) JOIN entity e_from ON ( r.from_id = e_from.id + {relation_date_filter} {relation_project_filter} ) LEFT JOIN entity e ON ( @@ -495,10 +515,17 @@ def _build_postgres_query( # pragma: no cover ELSE r.from_id END {date_filter} - {project_filter} + {connected_entity_project_filter} ) WHERE eg.depth < :max_depth - AND (step_type = 1 OR (step_type = 2 AND e.id IS NOT NULL AND e.id != eg.id)) + AND ( + step_type = 1 OR ( + step_type = 2 + AND e.id IS NOT NULL + AND e.id != eg.id + AND position(',' || e.id::text || ',' in eg.entity_path) = 0 + ) + ) {timeframe_condition} ) -- Materialize and filter @@ -529,7 +556,8 @@ def _build_sqlite_query( self, entity_id_values: str, date_filter: str, - project_filter: str, + seed_project_filter: str, + connected_entity_project_filter: str, relation_date_filter: str, relation_project_filter: str, timeframe_condition: str, @@ -555,11 +583,13 @@ def _build_sqlite_query( e.id as root_id, e.created_at, e.created_at as relation_date, - 0 as is_incoming + 0 as is_incoming, + e.project_id as project_id, + ',' || e.id || ',' as entity_path FROM entity e WHERE e.id IN ({entity_id_values}) {date_filter} - {project_filter} + {seed_project_filter} UNION ALL @@ -580,11 +610,14 @@ def _build_sqlite_query( eg.root_id, e_from.created_at, e_from.created_at as relation_date, - CASE WHEN r.from_id = eg.id THEN 0 ELSE 1 END as is_incoming + CASE WHEN r.from_id = eg.id THEN 0 ELSE 1 END as is_incoming, + eg.project_id as project_id, + eg.entity_path as entity_path FROM entity_graph eg JOIN relation r ON ( eg.type = 'entity' AND - (r.from_id = eg.id OR r.to_id = eg.id) + (r.from_id = eg.id OR r.to_id = eg.id) AND + r.project_id = eg.project_id ) JOIN entity e_from ON ( r.from_id = e_from.id @@ -615,7 +648,9 @@ def _build_sqlite_query( eg.root_id, e.created_at, eg.relation_date, - eg.is_incoming + eg.is_incoming, + e.project_id as project_id, + eg.entity_path || e.id || ',' as entity_path FROM entity_graph eg JOIN entity e ON ( eg.type = 'relation' AND @@ -624,9 +659,10 @@ def _build_sqlite_query( ELSE eg.from_id END {date_filter} - {project_filter} + {connected_entity_project_filter} ) WHERE eg.depth < :max_depth + AND instr(eg.entity_path, ',' || e.id || ',') = 0 {timeframe_condition} ) SELECT DISTINCT diff --git a/tests/api/v2/test_memory_hydration.py b/tests/api/v2/test_memory_hydration.py index e52759825..187cc9902 100644 --- a/tests/api/v2/test_memory_hydration.py +++ b/tests/api/v2/test_memory_hydration.py @@ -49,14 +49,16 @@ class SpyEntityRepository: def __init__(self, entities_by_id: dict[int, SimpleNamespace]): self.entities_by_id = entities_by_id - self.calls: list[list[int]] = [] + self.calls: list[tuple[list[int], bool]] = [] async def find_by_ids(self, ids: list[int]): - self.calls.append(ids) + self.calls.append((ids, False)) return [self.entities_by_id[i] for i in ids if i in self.entities_by_id] - async def find_by_ids_for_hydration(self, ids: list[int]): - self.calls.append(ids) + async def find_by_ids_for_hydration( + self, ids: list[int], *, include_cross_project: bool = False + ): + self.calls.append((ids, include_cross_project)) return [self.entities_by_id[i] for i in ids if i in self.entities_by_id] @@ -65,13 +67,15 @@ class LightweightOnlyEntityRepository: def __init__(self, entities_by_id: dict[int, SimpleNamespace]): self.entities_by_id = entities_by_id - self.hydration_calls: list[list[int]] = [] + self.hydration_calls: list[tuple[list[int], bool]] = [] async def find_by_ids(self, ids: list[int]): raise AssertionError("graph hydration must use the lightweight hydration lookup") - async def find_by_ids_for_hydration(self, ids: list[int]): - self.hydration_calls.append(ids) + async def find_by_ids_for_hydration( + self, ids: list[int], *, include_cross_project: bool = False + ): + self.hydration_calls.append((ids, include_cross_project)) return [self.entities_by_id[i] for i in ids if i in self.entities_by_id] @@ -177,7 +181,8 @@ async def test_to_graph_context_batches_entity_hydration_for_recent_activity(): graph = await to_graph_context(context, entity_repository=repo, page=1, page_size=10) assert len(repo.calls) == 1, f"Expected 1 entity lookup, got {len(repo.calls)}" - assert set(repo.calls[0]) == {1, 2, 3} + assert set(repo.calls[0][0]) == {1, 2, 3} + assert repo.calls[0][1] is True first_result = graph.results[0] first_primary = first_result.primary_result @@ -272,7 +277,8 @@ async def test_to_graph_context_uses_lightweight_hydration_lookup(): graph = await to_graph_context(context, entity_repository=repo) assert len(repo.hydration_calls) == 1 - assert set(repo.hydration_calls[0]) == {1, 2} + assert set(repo.hydration_calls[0][0]) == {1, 2} + assert repo.hydration_calls[0][1] is True relation = graph.results[0].related_results[0] assert isinstance(relation, RelationSummary) assert relation.from_entity_external_id == "ext-root" diff --git a/tests/repository/test_entity_repository.py b/tests/repository/test_entity_repository.py index a74f9b340..194722af5 100644 --- a/tests/repository/test_entity_repository.py +++ b/tests/repository/test_entity_repository.py @@ -1069,6 +1069,41 @@ def fail_get_load_options(): assert found[0].external_id == sample_entity.external_id +@pytest.mark.asyncio +async def test_find_by_ids_for_hydration_can_include_cross_project_entities( + entity_repository: EntityRepository, sample_entity: Entity, session_maker +): + """Context hydration can opt into IDs reached through explicit graph edges.""" + async with db.scoped_session(session_maker) as session: + other_project = Project(name="other-project", path="/other") + session.add(other_project) + await session.flush() + + other_entity = Entity( + project_id=other_project.id, + title="Other Project Entity", + note_type="test", + permalink="other-project/entity", + file_path="other-project/entity.md", + content_type="text/markdown", + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + session.add(other_entity) + await session.flush() + other_entity_id = other_entity.id + + project_scoped = await entity_repository.find_by_ids_for_hydration( + [sample_entity.id, other_entity_id] + ) + cross_project = await entity_repository.find_by_ids_for_hydration( + [sample_entity.id, other_entity_id], include_cross_project=True + ) + + assert {entity.id for entity in project_scoped} == {sample_entity.id} + assert {entity.id for entity in cross_project} == {sample_entity.id, other_entity_id} + + @pytest.mark.asyncio async def test_get_permalink_to_file_path_map(entity_repository: EntityRepository, session_maker): """Test getting permalink -> file_path mapping for bulk operations.""" diff --git a/tests/services/test_context_service.py b/tests/services/test_context_service.py index d86413f56..0cb4d5d7e 100644 --- a/tests/services/test_context_service.py +++ b/tests/services/test_context_service.py @@ -339,6 +339,227 @@ async def test_project_isolation_in_find_related(session_maker, app_config): assert entity1_p2.project_id == project2.id +@pytest.mark.asyncio +async def test_find_related_expands_cross_project_relation_targets(session_maker, app_config): + """Explicit cross-project links should expand without exposing unrelated incoming links.""" + from basic_memory.repository.entity_repository import EntityRepository + from basic_memory.repository.observation_repository import ObservationRepository + from basic_memory.repository.sqlite_search_repository import SQLiteSearchRepository + from basic_memory.repository.postgres_search_repository import PostgresSearchRepository + from basic_memory.config import DatabaseBackend + from basic_memory import db + + async with db.scoped_session(session_maker) as db_session: + project1 = Project(name="project1", path="/test1") + project2 = Project(name="project2", path="/test2") + project3 = Project(name="project3", path="/test3") + db_session.add_all([project1, project2, project3]) + await db_session.flush() + + now = datetime.now(UTC) + source = Entity( + title="Source", + note_type="document", + content_type="text/markdown", + project_id=project1.id, + permalink="project1/source", + file_path="project1/source.md", + created_at=now, + updated_at=now, + ) + target = Entity( + title="Company Standards", + note_type="document", + content_type="text/markdown", + project_id=project2.id, + permalink="project2/company-standards", + file_path="project2/company-standards.md", + created_at=now, + updated_at=now, + ) + target_child = Entity( + title="Review Checklist", + note_type="document", + content_type="text/markdown", + project_id=project2.id, + permalink="project2/review-checklist", + file_path="project2/review-checklist.md", + created_at=now, + updated_at=now, + ) + unrelated_source = Entity( + title="Unrelated Source", + note_type="document", + content_type="text/markdown", + project_id=project3.id, + permalink="project3/unrelated-source", + file_path="project3/unrelated-source.md", + created_at=now, + updated_at=now, + ) + db_session.add_all([source, target, target_child, unrelated_source]) + await db_session.flush() + + cross_project_relation = Relation( + project_id=project1.id, + from_id=source.id, + to_id=target.id, + to_name="Company Standards", + relation_type="links_to", + ) + target_relation = Relation( + project_id=project2.id, + from_id=target.id, + to_id=target_child.id, + to_name="Review Checklist", + relation_type="links_to", + ) + unrelated_incoming_relation = Relation( + project_id=project3.id, + from_id=unrelated_source.id, + to_id=target.id, + to_name="Company Standards", + relation_type="links_to", + ) + db_session.add_all([cross_project_relation, target_relation, unrelated_incoming_relation]) + await db_session.commit() + + if app_config.database_backend == DatabaseBackend.POSTGRES: + search_repo_p1 = PostgresSearchRepository(session_maker, project1.id) + else: + search_repo_p1 = SQLiteSearchRepository(session_maker, project1.id) + + entity_repo_p1 = EntityRepository(session_maker, project1.id) + obs_repo_p1 = ObservationRepository(session_maker, project1.id) + context_service_p1 = ContextService(search_repo_p1, entity_repo_p1, obs_repo_p1) + + await search_repo_p1.index_item( + SearchIndexRow( + project_id=project1.id, + id=source.id, + title=source.title, + content_snippet="Source content", + permalink=source.permalink, + file_path=source.file_path, + type=SearchItemType.ENTITY, + metadata={"created_at": now.isoformat()}, + created_at=now, + updated_at=now, + ) + ) + + context = await context_service_p1.build_context( + memory_url.validate_strings("memory://project1/source"), + depth=2, + max_related=100, + ) + assert len(context.results) == 1 + + context_related_entity_ids = { + row.id for row in context.results[0].related_results if row.type == "entity" + } + context_related_relation_ids = { + row.id for row in context.results[0].related_results if row.type == "relation" + } + + assert target.id in context_related_entity_ids + assert target_child.id in context_related_entity_ids + assert unrelated_source.id not in context_related_entity_ids + assert cross_project_relation.id in context_related_relation_ids + assert target_relation.id in context_related_relation_ids + assert unrelated_incoming_relation.id not in context_related_relation_ids + + related = await context_service_p1.find_related( + [("entity", source.id)], max_depth=2, max_results=100 + ) + + related_entity_ids = {row.id for row in related if row.type == "entity"} + related_relation_ids = {row.id for row in related if row.type == "relation"} + + assert target.id in related_entity_ids + assert target_child.id in related_entity_ids + assert unrelated_source.id not in related_entity_ids + assert cross_project_relation.id in related_relation_ids + assert target_relation.id in related_relation_ids + assert unrelated_incoming_relation.id not in related_relation_ids + + +@pytest.mark.asyncio +async def test_find_related_does_not_revisit_entities_in_cycles(session_maker, app_config): + """Recursive graph expansion should stop when a path loops back to a visited entity.""" + from basic_memory.repository.entity_repository import EntityRepository + from basic_memory.repository.observation_repository import ObservationRepository + from basic_memory.repository.sqlite_search_repository import SQLiteSearchRepository + from basic_memory.repository.postgres_search_repository import PostgresSearchRepository + from basic_memory.config import DatabaseBackend + from basic_memory import db + + async with db.scoped_session(session_maker) as db_session: + project = Project(name="cycle-project", path="/cycle") + db_session.add(project) + await db_session.flush() + + now = datetime.now(UTC) + root = Entity( + title="Root", + note_type="document", + content_type="text/markdown", + project_id=project.id, + permalink="cycle/root", + file_path="cycle/root.md", + created_at=now, + updated_at=now, + ) + connected = Entity( + title="Connected", + note_type="document", + content_type="text/markdown", + project_id=project.id, + permalink="cycle/connected", + file_path="cycle/connected.md", + created_at=now, + updated_at=now, + ) + db_session.add_all([root, connected]) + await db_session.flush() + + root_to_connected = Relation( + project_id=project.id, + from_id=root.id, + to_id=connected.id, + to_name="Connected", + relation_type="links_to", + ) + connected_to_root = Relation( + project_id=project.id, + from_id=connected.id, + to_id=root.id, + to_name="Root", + relation_type="links_to", + ) + db_session.add_all([root_to_connected, connected_to_root]) + await db_session.commit() + + if app_config.database_backend == DatabaseBackend.POSTGRES: + search_repo = PostgresSearchRepository(session_maker, project.id) + else: + search_repo = SQLiteSearchRepository(session_maker, project.id) + + entity_repo = EntityRepository(session_maker, project.id) + obs_repo = ObservationRepository(session_maker, project.id) + context_service = ContextService(search_repo, entity_repo, obs_repo) + + related = await context_service.find_related( + [("entity", root.id)], max_depth=4, max_results=100 + ) + + related_entity_ids = [row.id for row in related if row.type == "entity"] + related_relation_ids = {row.id for row in related if row.type == "relation"} + + assert related_entity_ids == [connected.id] + assert related_relation_ids == {root_to_connected.id, connected_to_root.id} + + @pytest.mark.asyncio async def test_build_context_fallback_via_link_resolver(context_service, test_graph): """Test that build_context falls back to LinkResolver when exact permalink fails.