Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/basic_memory/api/v2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@


class EntityBatchLookup(Protocol):
    """Structural type for objects that can batch-load entities for hydration."""

    async def find_by_ids_for_hydration(
        self, ids: List[int], *, include_cross_project: bool = False
    ) -> Sequence[Any]:
        """Return the entities matching ``ids``.

        Args:
            ids: Entity primary keys to look up.
            include_cross_project: Keyword-only flag asking the implementation
                to also return entities outside its own project scope.
        """
        ...


class EntityServiceBatchLookup(Protocol):
Expand Down Expand Up @@ -88,7 +90,7 @@ async def to_graph_context(
result_count=len(entity_ids_needed),
):
entities = await entity_repository.find_by_ids_for_hydration(
list(entity_ids_needed)
list(entity_ids_needed), include_cross_project=True
)
for e in entities:
entity_title_lookup[e.id] = e.title
Expand Down
14 changes: 12 additions & 2 deletions src/basic_memory/repository/entity_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,21 +178,31 @@ async def get_all_permalinks(self) -> List[str]:
result = await self.execute_query(query, use_query_options=False)
return list(result.scalars().all())

async def find_by_ids_for_hydration(
    self, ids: List[int], *, include_cross_project: bool = False
) -> Sequence[Entity]:
    """Fetch minimal entity fields needed for context hydration.

    Context hydration only needs an entity's primary key, title, and external
    UUID. Keeping this separate from find_by_ids avoids the relationship eager
    loads that are useful for full entity reads but expensive for response shaping.

    Args:
        ids: Entity IDs to hydrate.
        include_cross_project: Include IDs outside this repository's project scope.
            Use only for IDs already reached through validated graph traversal.

    Returns:
        The matching entities with only id, title, and external_id loaded;
        an empty list when ``ids`` is empty.
    """
    if not ids:
        return []

    # Start from a plain select; project scoping is applied explicitly below
    # so callers can opt out of it.
    query = (
        select(Entity)
        .where(Entity.id.in_(ids))
        .options(load_only(Entity.id, Entity.title, Entity.external_id))
    )
    if not include_cross_project:
        query = self._add_project_filter(query)

    result = await self.execute_query(query, use_query_options=False)
    return list(result.scalars().all())

Expand Down
74 changes: 55 additions & 19 deletions src/basic_memory/services/context_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,9 +333,14 @@ async def find_related(
relation_date_filter = ""
timeframe_condition = ""

# Add project filtering for security - ensure all entities and relations belong to the same project
project_filter = "AND e.project_id = :project_id"
relation_project_filter = "AND e_from.project_id = :project_id"
# Trigger: build_context starts from a project-scoped search result.
# Why: the seed entity must belong to the requested project, but an
# explicit relation edge may point at another project.
# Outcome: traversal follows only project-owned edges from reached
# entities, instead of forcing every reached entity into the seed project.
seed_project_filter = "AND e.project_id = :project_id"
connected_entity_project_filter = ""
relation_project_filter = "AND e_from.project_id = r.project_id"

# Use a CTE that operates directly on entity and relation tables
# This avoids the overhead of the search_index virtual table
Expand All @@ -351,7 +356,8 @@ async def find_related(
query = self._build_postgres_query(
entity_id_values,
date_filter,
project_filter,
seed_project_filter,
connected_entity_project_filter,
relation_date_filter,
relation_project_filter,
timeframe_condition,
Expand All @@ -362,7 +368,8 @@ async def find_related(
query = self._build_sqlite_query(
entity_id_values,
date_filter,
project_filter,
seed_project_filter,
connected_entity_project_filter,
relation_date_filter,
relation_project_filter,
timeframe_condition,
Expand Down Expand Up @@ -397,7 +404,8 @@ def _build_postgres_query( # pragma: no cover
self,
entity_id_values: str,
date_filter: str,
project_filter: str,
seed_project_filter: str,
connected_entity_project_filter: str,
relation_date_filter: str,
relation_project_filter: str,
timeframe_condition: str,
Expand All @@ -421,11 +429,13 @@ def _build_postgres_query( # pragma: no cover
0 as depth,
e.id as root_id,
e.created_at,
e.created_at as relation_date
e.created_at as relation_date,
e.project_id as project_id,
',' || e.id::text || ',' as entity_path
FROM entity e
WHERE e.id IN ({entity_id_values})
{date_filter}
{project_filter}
{seed_project_filter}

UNION ALL

Expand Down Expand Up @@ -477,15 +487,25 @@ def _build_postgres_query( # pragma: no cover
CASE
WHEN step_type = 1 THEN e_from.created_at
ELSE eg.relation_date
END as relation_date
END as relation_date,
CASE
WHEN step_type = 1 THEN eg.project_id
ELSE e.project_id
END as project_id,
CASE
WHEN step_type = 1 THEN eg.entity_path
ELSE eg.entity_path || e.id::text || ','
END as entity_path
FROM entity_graph eg
CROSS JOIN LATERAL (VALUES (1), (2)) AS steps(step_type)
JOIN relation r ON (
eg.type = 'entity' AND
(r.from_id = eg.id OR r.to_id = eg.id)
(r.from_id = eg.id OR r.to_id = eg.id) AND
r.project_id = eg.project_id
)
JOIN entity e_from ON (
r.from_id = e_from.id
{relation_date_filter}
{relation_project_filter}
)
LEFT JOIN entity e ON (
Expand All @@ -495,10 +515,17 @@ def _build_postgres_query( # pragma: no cover
ELSE r.from_id
END
{date_filter}
{project_filter}
{connected_entity_project_filter}
)
WHERE eg.depth < :max_depth
AND (step_type = 1 OR (step_type = 2 AND e.id IS NOT NULL AND e.id != eg.id))
AND (
step_type = 1 OR (
step_type = 2
AND e.id IS NOT NULL
AND e.id != eg.id
AND position(',' || e.id::text || ',' in eg.entity_path) = 0
)
)
{timeframe_condition}
)
-- Materialize and filter
Expand Down Expand Up @@ -529,7 +556,8 @@ def _build_sqlite_query(
self,
entity_id_values: str,
date_filter: str,
project_filter: str,
seed_project_filter: str,
connected_entity_project_filter: str,
relation_date_filter: str,
relation_project_filter: str,
timeframe_condition: str,
Expand All @@ -555,11 +583,13 @@ def _build_sqlite_query(
e.id as root_id,
e.created_at,
e.created_at as relation_date,
0 as is_incoming
0 as is_incoming,
e.project_id as project_id,
',' || e.id || ',' as entity_path
FROM entity e
WHERE e.id IN ({entity_id_values})
{date_filter}
{project_filter}
{seed_project_filter}

UNION ALL

Expand All @@ -580,11 +610,14 @@ def _build_sqlite_query(
eg.root_id,
e_from.created_at,
e_from.created_at as relation_date,
CASE WHEN r.from_id = eg.id THEN 0 ELSE 1 END as is_incoming
CASE WHEN r.from_id = eg.id THEN 0 ELSE 1 END as is_incoming,
eg.project_id as project_id,
eg.entity_path as entity_path
FROM entity_graph eg
JOIN relation r ON (
eg.type = 'entity' AND
(r.from_id = eg.id OR r.to_id = eg.id)
(r.from_id = eg.id OR r.to_id = eg.id) AND
r.project_id = eg.project_id
)
JOIN entity e_from ON (
r.from_id = e_from.id
Expand Down Expand Up @@ -615,7 +648,9 @@ def _build_sqlite_query(
eg.root_id,
e.created_at,
eg.relation_date,
eg.is_incoming
eg.is_incoming,
e.project_id as project_id,
eg.entity_path || e.id || ',' as entity_path
FROM entity_graph eg
JOIN entity e ON (
eg.type = 'relation' AND
Expand All @@ -624,9 +659,10 @@ def _build_sqlite_query(
ELSE eg.from_id
END
{date_filter}
{project_filter}
{connected_entity_project_filter}
)
WHERE eg.depth < :max_depth
AND instr(eg.entity_path, ',' || e.id || ',') = 0
{timeframe_condition}
)
SELECT DISTINCT
Expand Down
24 changes: 15 additions & 9 deletions tests/api/v2/test_memory_hydration.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,16 @@ class SpyEntityRepository:

def __init__(self, entities_by_id: dict[int, SimpleNamespace]):
    """Store the canned entities and start with an empty call log."""
    self.entities_by_id = entities_by_id
    # Each entry is (requested ids, include_cross_project flag).
    self.calls: list[tuple[list[int], bool]] = []

async def find_by_ids(self, ids: list[int]):
    """Record a full-entity lookup and return the canned entities for ``ids``.

    The call is logged once as ``(ids, False)``; ids with no canned entity
    are skipped in the result.
    """
    self.calls.append((ids, False))
    return [self.entities_by_id[i] for i in ids if i in self.entities_by_id]

async def find_by_ids_for_hydration(
    self, ids: list[int], *, include_cross_project: bool = False
):
    """Record a hydration lookup and return the canned entities for ``ids``.

    The call is logged once as ``(ids, include_cross_project)`` so tests can
    assert on the flag; ids with no canned entity are skipped in the result.
    """
    self.calls.append((ids, include_cross_project))
    return [self.entities_by_id[i] for i in ids if i in self.entities_by_id]


Expand All @@ -65,13 +67,15 @@ class LightweightOnlyEntityRepository:

def __init__(self, entities_by_id: dict[int, SimpleNamespace]):
    """Store the canned entities and start with an empty hydration call log."""
    self.entities_by_id = entities_by_id
    # Each entry is (requested ids, include_cross_project flag).
    self.hydration_calls: list[tuple[list[int], bool]] = []

async def find_by_ids(self, ids: list[int]):
    """Always fail: this double only permits the hydration lookup."""
    failure_message = "graph hydration must use the lightweight hydration lookup"
    raise AssertionError(failure_message)

async def find_by_ids_for_hydration(
    self, ids: list[int], *, include_cross_project: bool = False
):
    """Record a hydration lookup and return the canned entities for ``ids``.

    The call is logged once as ``(ids, include_cross_project)``; ids with no
    canned entity are skipped in the result.
    """
    self.hydration_calls.append((ids, include_cross_project))
    return [self.entities_by_id[i] for i in ids if i in self.entities_by_id]


Expand Down Expand Up @@ -177,7 +181,8 @@ async def test_to_graph_context_batches_entity_hydration_for_recent_activity():
graph = await to_graph_context(context, entity_repository=repo, page=1, page_size=10)

assert len(repo.calls) == 1, f"Expected 1 entity lookup, got {len(repo.calls)}"
assert set(repo.calls[0]) == {1, 2, 3}
assert set(repo.calls[0][0]) == {1, 2, 3}
assert repo.calls[0][1] is True

first_result = graph.results[0]
first_primary = first_result.primary_result
Expand Down Expand Up @@ -272,7 +277,8 @@ async def test_to_graph_context_uses_lightweight_hydration_lookup():
graph = await to_graph_context(context, entity_repository=repo)

assert len(repo.hydration_calls) == 1
assert set(repo.hydration_calls[0]) == {1, 2}
assert set(repo.hydration_calls[0][0]) == {1, 2}
assert repo.hydration_calls[0][1] is True
relation = graph.results[0].related_results[0]
assert isinstance(relation, RelationSummary)
assert relation.from_entity_external_id == "ext-root"
Expand Down
35 changes: 35 additions & 0 deletions tests/repository/test_entity_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -1069,6 +1069,41 @@ def fail_get_load_options():
assert found[0].external_id == sample_entity.external_id


@pytest.mark.asyncio
async def test_find_by_ids_for_hydration_can_include_cross_project_entities(
    entity_repository: EntityRepository, sample_entity: Entity, session_maker
):
    """Context hydration can opt into IDs reached through explicit graph edges."""
    # Seed a second project that holds a single entity.
    async with db.scoped_session(session_maker) as session:
        foreign_project = Project(name="other-project", path="/other")
        session.add(foreign_project)
        # Flush before reading foreign_project.id below.
        await session.flush()

        foreign_entity = Entity(
            project_id=foreign_project.id,
            title="Other Project Entity",
            note_type="test",
            permalink="other-project/entity",
            file_path="other-project/entity.md",
            content_type="text/markdown",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        session.add(foreign_entity)
        await session.flush()
        foreign_id = foreign_entity.id

    local_id = sample_entity.id

    # Default lookup stays project-scoped; the opt-in flag widens it.
    scoped_rows = await entity_repository.find_by_ids_for_hydration([local_id, foreign_id])
    unscoped_rows = await entity_repository.find_by_ids_for_hydration(
        [local_id, foreign_id], include_cross_project=True
    )

    scoped_ids = {row.id for row in scoped_rows}
    unscoped_ids = {row.id for row in unscoped_rows}
    assert scoped_ids == {local_id}
    assert unscoped_ids == {local_id, foreign_id}


@pytest.mark.asyncio
async def test_get_permalink_to_file_path_map(entity_repository: EntityRepository, session_maker):
"""Test getting permalink -> file_path mapping for bulk operations."""
Expand Down
Loading