Skip to content

Commit 0af0ae7

Browse files
rayhpeng and claude committed
perf: use SQL aggregation for feedback stats and thread token usage
Replace Python-side counting in FeedbackRepository.aggregate_by_run with a single SELECT COUNT/SUM query. Add RunStore.aggregate_tokens_by_thread abstract method with SQL GROUP BY implementation in RunRepository and Python fallback in MemoryRunStore. Simplify the thread_token_usage endpoint to delegate to the new method, eliminating the limit=10000 truncation risk. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 332fb18 commit 0af0ae7

5 files changed

Lines changed: 98 additions & 41 deletions

File tree

backend/app/gateway/routers/thread_runs.py

Lines changed: 2 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -310,32 +310,5 @@ async def list_run_events(
310310
async def thread_token_usage(thread_id: str, request: Request) -> dict:
    """Return thread-level token usage, aggregated by the run store.

    Delegates the heavy lifting to ``aggregate_tokens_by_thread`` so the
    store can aggregate database-side instead of paging runs into Python.
    """
    store = get_run_store(request)
    usage = await store.aggregate_tokens_by_thread(thread_id)
    return {"thread_id": thread_id, **usage}

backend/packages/harness/deerflow/persistence/repositories/feedback_repo.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import uuid
99
from datetime import UTC, datetime
1010

11-
from sqlalchemy import select
11+
from sqlalchemy import case, func, select
1212
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
1313

1414
from deerflow.persistence.models.feedback import FeedbackRow
@@ -82,13 +82,17 @@ async def delete(self, feedback_id: str) -> bool:
8282
return True
8383

8484
async def aggregate_by_run(self, thread_id: str, run_id: str) -> dict:
    """Aggregate feedback stats for a run with a single SQL query.

    Counts total rows plus positive (rating == 1) and negative
    (rating == -1) ratings database-side instead of fetching rows.
    """

    def _rating_count(value: int):
        # SUM over a 1/0 CASE; COALESCE maps the no-rows NULL to 0.
        return func.coalesce(
            func.sum(case((FeedbackRow.rating == value, 1), else_=0)), 0
        )

    stmt = select(
        func.count().label("total"),
        _rating_count(1).label("positive"),
        _rating_count(-1).label("negative"),
    ).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id)

    async with self._sf() as session:
        result = (await session.execute(stmt)).one()

    return {
        "run_id": run_id,
        "total": result.total,
        "positive": result.positive,
        "negative": result.negative,
    }

backend/packages/harness/deerflow/persistence/repositories/run_repo.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from datetime import UTC, datetime
1212
from typing import Any
1313

14-
from sqlalchemy import select, update
14+
from sqlalchemy import func, select, update
1515
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
1616

1717
from deerflow.persistence.models.run import RunRow
@@ -171,3 +171,52 @@ async def update_run_completion(
171171
async with self._sf() as session:
172172
await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
173173
await session.commit()
174+
175+
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
    """Aggregate token usage for a thread via one SQL GROUP BY query.

    Groups completed (success/error) runs by model name, then folds the
    per-model rows into thread-level totals and per-caller sums in Python.
    """
    # NULL model names collapse into a single "unknown" bucket.
    model_expr = func.coalesce(RunRow.model_name, "unknown")
    summed_columns = {
        "total_tokens": RunRow.total_tokens,
        "total_input_tokens": RunRow.total_input_tokens,
        "total_output_tokens": RunRow.total_output_tokens,
        "lead_agent": RunRow.lead_agent_tokens,
        "subagent": RunRow.subagent_tokens,
        "middleware": RunRow.middleware_tokens,
    }
    stmt = (
        select(
            model_expr.label("model"),
            func.count().label("runs"),
            *(
                func.coalesce(func.sum(col), 0).label(name)
                for name, col in summed_columns.items()
            ),
        )
        .where(RunRow.thread_id == thread_id, RunRow.status.in_(("success", "error")))
        .group_by(model_expr)
    )

    async with self._sf() as session:
        rows = (await session.execute(stmt)).all()

    totals = {
        "total_tokens": 0,
        "total_input_tokens": 0,
        "total_output_tokens": 0,
        "total_runs": 0,
    }
    callers = {"lead_agent": 0, "subagent": 0, "middleware": 0}
    by_model: dict[str, dict] = {}
    for row in rows:
        by_model[row.model] = {"tokens": row.total_tokens, "runs": row.runs}
        totals["total_tokens"] += row.total_tokens
        totals["total_input_tokens"] += row.total_input_tokens
        totals["total_output_tokens"] += row.total_output_tokens
        totals["total_runs"] += row.runs
        for caller in callers:
            callers[caller] += getattr(row, caller)

    return {**totals, "by_model": by_model, "by_caller": callers}

backend/packages/harness/deerflow/runtime/runs/store/base.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,13 @@ async def update_run_completion(
8484
@abc.abstractmethod
8585
async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]:
8686
pass
87+
88+
@abc.abstractmethod
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
    """Aggregate token usage across a thread's completed runs.

    Implementations return a dict with keys: ``total_tokens``,
    ``total_input_tokens``, ``total_output_tokens``, ``total_runs``,
    ``by_model`` (model_name -> {"tokens", "runs"}), and ``by_caller``
    ({"lead_agent", "subagent", "middleware"}).
    """
    pass

backend/packages/harness/deerflow/runtime/runs/store/memory.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,24 @@ async def list_pending(self, *, before=None):
7777
results = [r for r in self._runs.values() if r["status"] == "pending" and r["created_at"] <= now]
7878
results.sort(key=lambda r: r["created_at"])
7979
return results
80+
81+
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
    """In-memory aggregation of token usage for a thread's completed runs.

    Single pass over the run table: only runs in this thread with status
    "success" or "error" contribute; missing token fields count as 0.
    """
    totals = {"total_tokens": 0, "total_input_tokens": 0, "total_output_tokens": 0}
    callers = {"lead_agent": 0, "subagent": 0, "middleware": 0}
    by_model: dict[str, dict] = {}
    run_count = 0
    for run in self._runs.values():
        if run["thread_id"] != thread_id or run.get("status") not in ("success", "error"):
            continue
        run_count += 1
        # NULL/absent model names fall into one "unknown" bucket.
        model_stats = by_model.setdefault(
            run.get("model_name") or "unknown", {"tokens": 0, "runs": 0}
        )
        model_stats["tokens"] += run.get("total_tokens", 0)
        model_stats["runs"] += 1
        for key in totals:
            totals[key] += run.get(key, 0)
        for caller in callers:
            callers[caller] += run.get(f"{caller}_tokens", 0)
    return {**totals, "total_runs": run_count, "by_model": by_model, "by_caller": callers}

0 commit comments

Comments
 (0)