Skip to content

Commit 55bd572

Browse files
committed
feat: naive token estimation via tiktoken
1 parent 5391794 commit 55bd572

File tree

3 files changed

+302
-1
lines changed

3 files changed

+302
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ dependencies = [
4040
"opentelemetry-api>=1.30.0,<2.0.0",
4141
"opentelemetry-sdk>=1.30.0,<2.0.0",
4242
"opentelemetry-instrumentation-threading>=0.51b0,<1.00b0",
43+
"tiktoken>=0.7.0,<1.0.0",
4344
]
4445

4546

src/strands/models/model.py

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
"""Abstract base class for Agent model providers."""
22

33
import abc
4+
import json
45
import logging
56
from collections.abc import AsyncGenerator, AsyncIterable
67
from dataclasses import dataclass
78
from typing import TYPE_CHECKING, Any, Literal, TypeVar
89

10+
import tiktoken
911
from pydantic import BaseModel
1012

1113
from ..hooks.events import AfterInvocationEvent
1214
from ..plugins.plugin import Plugin
13-
from ..types.content import Messages, SystemContentBlock
15+
from ..types.content import ContentBlock, Messages, SystemContentBlock
1416
from ..types.streaming import StreamEvent
1517
from ..types.tools import ToolChoice, ToolSpec
1618

@@ -21,6 +23,87 @@
2123

2224
T = TypeVar("T", bound=BaseModel)
2325

26+
_DEFAULT_ENCODING = "cl100k_base"
_cached_encoding: tiktoken.Encoding | None = None


def _get_encoding() -> tiktoken.Encoding:
    """Return the shared tiktoken encoding, loading it lazily on first use.

    The encoding object is cached at module level so repeated estimates do
    not pay for the registry lookup each time.

    Returns:
        The ``cl100k_base`` encoding instance.
    """
    global _cached_encoding
    if _cached_encoding is not None:
        return _cached_encoding
    _cached_encoding = tiktoken.get_encoding(_DEFAULT_ENCODING)
    return _cached_encoding
36+
37+
38+
def _count_content_block_tokens(block: ContentBlock, encoding: tiktoken.Encoding) -> int:
39+
"""Count tokens for a single content block."""
40+
total = 0
41+
42+
if "text" in block:
43+
total += len(encoding.encode(block["text"]))
44+
45+
if "toolUse" in block:
46+
try:
47+
total += len(encoding.encode(json.dumps(block["toolUse"])))
48+
except (TypeError, ValueError):
49+
pass
50+
51+
if "toolResult" in block:
52+
try:
53+
total += len(encoding.encode(json.dumps(block["toolResult"])))
54+
except (TypeError, ValueError):
55+
pass
56+
57+
if "reasoningContent" in block:
58+
reasoning = block["reasoningContent"]
59+
if "reasoningText" in reasoning:
60+
reasoning_text = reasoning["reasoningText"]
61+
if "text" in reasoning_text:
62+
total += len(encoding.encode(reasoning_text["text"]))
63+
64+
if "guardContent" in block:
65+
guard = block["guardContent"]
66+
if "text" in guard:
67+
total += len(encoding.encode(guard["text"]["text"]))
68+
69+
if "citationsContent" in block:
70+
citations = block["citationsContent"]
71+
for item in citations.get("content", []):
72+
if "text" in item:
73+
total += len(encoding.encode(item["text"]))
74+
75+
return total
76+
77+
78+
def _estimate_tokens_with_tiktoken(
    messages: Messages,
    tool_specs: list[ToolSpec] | None = None,
    system_prompt: str | None = None,
) -> int:
    """Estimate tokens by serializing messages/tools to text and counting with tiktoken.

    This is a best-effort fallback for providers that don't expose native counting.
    Accuracy varies by model but is sufficient for threshold-based decisions.
    """
    encoding = _get_encoding()

    total = len(encoding.encode(system_prompt)) if system_prompt else 0

    # Sum every content block across every message.
    total += sum(
        _count_content_block_tokens(block, encoding)
        for message in messages
        for block in message["content"]
    )

    # Tool specs are counted via their JSON serialization; specs that are not
    # JSON-serializable are skipped rather than failing the whole estimate.
    for spec in tool_specs or []:
        try:
            total += len(encoding.encode(json.dumps(spec)))
        except (TypeError, ValueError):
            pass

    return total
106+
24107

25108
@dataclass
26109
class CacheConfig:
@@ -130,6 +213,32 @@ def stream(
130213
"""
131214
pass
132215

216+
def _estimate_tokens(
    self,
    messages: Messages,
    tool_specs: list[ToolSpec] | None = None,
    system_prompt: str | None = None,
) -> int:
    """Estimate the input token count before a request is sent to the model.

    Intended for proactive context management (e.g., deciding when to compress
    history against a threshold). The default implementation is a naive
    approximation based on tiktoken's cl100k_base encoding — typically within
    5-10% for most providers, but not suitable for billing or precise quotas.

    Subclasses may override to use a provider's native token-counting API for
    better accuracy.

    Args:
        messages: List of message objects to estimate tokens for.
        tool_specs: List of tool specifications to include in the estimate.
        system_prompt: System prompt to include in the estimate.

    Returns:
        Estimated total input tokens.
    """
    estimate = _estimate_tokens_with_tiktoken(messages, tool_specs, system_prompt)
    return estimate
241+
133242

134243
class _ModelPlugin(Plugin):
135244
"""Plugin that manages model-related lifecycle hooks."""

tests/strands/models/test_model.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,3 +213,194 @@ def test_model_plugin_preserves_messages_when_not_stateful(model_plugin):
213213
model_plugin._on_after_invocation(event)
214214

215215
assert len(agent.messages) == 1
216+
217+
218+
def test_estimate_tokens_empty_messages(model):
    assert model._estimate_tokens(messages=[]) == 0


def test_estimate_tokens_system_prompt_only(model):
    estimate = model._estimate_tokens(messages=[], system_prompt="You are a helpful assistant.")
    assert estimate > 0


def test_estimate_tokens_text_messages(model, messages):
    estimate = model._estimate_tokens(messages=messages)
    assert estimate > 0


def test_estimate_tokens_with_tool_specs(model, messages, tool_specs):
    baseline = model._estimate_tokens(messages=messages)
    augmented = model._estimate_tokens(messages=messages, tool_specs=tool_specs)
    assert augmented > baseline


def test_estimate_tokens_with_system_prompt(model, messages, system_prompt):
    baseline = model._estimate_tokens(messages=messages)
    augmented = model._estimate_tokens(messages=messages, system_prompt=system_prompt)
    assert augmented > baseline


def test_estimate_tokens_combined(model, messages, tool_specs, system_prompt):
    estimate = model._estimate_tokens(messages=messages, tool_specs=tool_specs, system_prompt=system_prompt)
    assert estimate > 0
247+
248+
249+
def test_estimate_tokens_tool_use_block(model):
    tool_use = {"toolUseId": "123", "name": "my_tool", "input": {"query": "test"}}
    conversation = [{"role": "assistant", "content": [{"toolUse": tool_use}]}]
    assert model._estimate_tokens(messages=conversation) > 0


def test_estimate_tokens_tool_result_block(model):
    tool_result = {"toolUseId": "123", "content": [{"text": "tool output here"}], "status": "success"}
    conversation = [{"role": "user", "content": [{"toolResult": tool_result}]}]
    assert model._estimate_tokens(messages=conversation) > 0


def test_estimate_tokens_reasoning_block(model):
    reasoning = {"reasoningText": {"text": "Let me think about this step by step."}}
    conversation = [{"role": "assistant", "content": [{"reasoningContent": reasoning}]}]
    assert model._estimate_tokens(messages=conversation) > 0
304+
305+
306+
def test_estimate_tokens_skips_binary_content(model):
    image_block = {"image": {"format": "png", "source": {"bytes": b"fake image data"}}}
    conversation = [{"role": "user", "content": [image_block]}]
    assert model._estimate_tokens(messages=conversation) == 0


def test_estimate_tokens_tool_result_with_bytes(model):
    tool_result = {
        "toolUseId": "123",
        "content": [{"image": {"format": "png", "source": {"bytes": b"image data"}}}],
        "status": "success",
    }
    conversation = [{"role": "user", "content": [{"toolResult": tool_result}]}]
    assert model._estimate_tokens(messages=conversation) == 0


def test_estimate_tokens_guard_content_block(model):
    guard_block = {"guardContent": {"text": {"text": "This content was filtered by guardrails."}}}
    conversation = [{"role": "assistant", "content": [guard_block]}]
    assert model._estimate_tokens(messages=conversation) > 0


def test_estimate_tokens_tool_use_with_bytes(model):
    tool_use = {"toolUseId": "123", "name": "my_tool", "input": {"data": b"binary data"}}
    conversation = [{"role": "assistant", "content": [{"toolUse": tool_use}]}]
    assert model._estimate_tokens(messages=conversation) == 0
363+
364+
365+
def test_estimate_tokens_non_serializable_tool_spec(model, messages):
    bad_specs = [
        {
            "name": "test",
            "description": "a tool",
            "inputSchema": {"json": {"default": b"bytes"}},
        }
    ]
    # Message tokens are still counted even though the tool spec fails to serialize.
    assert model._estimate_tokens(messages=messages, tool_specs=bad_specs) > 0


def test_estimate_tokens_citations_block(model):
    citations = {
        "content": [{"text": "According to the document, the answer is 42."}],
        "citations": [],
    }
    conversation = [{"role": "assistant", "content": [{"citationsContent": citations}]}]
    assert model._estimate_tokens(messages=conversation) > 0


def test_estimate_tokens_all_inputs(model):
    conversation = [
        {"role": "user", "content": [{"text": "hello world"}]},
        {"role": "assistant", "content": [{"text": "hi there"}]},
    ]
    estimate = model._estimate_tokens(
        messages=conversation,
        tool_specs=[{"name": "test", "description": "a test tool", "inputSchema": {"json": {}}}],
        system_prompt="Be helpful.",
    )
    assert estimate > 0

0 commit comments

Comments
 (0)