Skip to content

Commit 6c2c1c5

Browse files
committed
feat: naive token estimation via tiktoken
1 parent 7b4df8a commit 6c2c1c5

File tree

3 files changed

+246
-1
lines changed

3 files changed

+246
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ dependencies = [
4040
"opentelemetry-api>=1.30.0,<2.0.0",
4141
"opentelemetry-sdk>=1.30.0,<2.0.0",
4242
"opentelemetry-instrumentation-threading>=0.51b0,<1.00b0",
43+
"tiktoken>=0.7.0,<1.0.0",
4344
]
4445

4546

src/strands/models/model.py

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
"""Abstract base class for Agent model providers."""
22

33
import abc
4+
import json
45
import logging
56
from collections.abc import AsyncGenerator, AsyncIterable
67
from dataclasses import dataclass
78
from typing import TYPE_CHECKING, Any, Literal, TypeVar
89

10+
import tiktoken
911
from pydantic import BaseModel
1012

1113
from ..hooks.events import AfterInvocationEvent
1214
from ..plugins.plugin import Plugin
13-
from ..types.content import Messages, SystemContentBlock
15+
from ..types.content import ContentBlock, Messages, SystemContentBlock
1416
from ..types.streaming import StreamEvent
1517
from ..types.tools import ToolChoice, ToolSpec
1618

@@ -21,6 +23,79 @@
2123

2224
T = TypeVar("T", bound=BaseModel)
2325

26+
_DEFAULT_ENCODING = "cl100k_base"
27+
28+
29+
def _count_content_block_tokens(block: ContentBlock, encoding: tiktoken.Encoding) -> int:
    """Count tokens for a single content block.

    Plain-text fields are encoded directly; structured fields (tool use, tool
    result) are JSON-serialized first. Payloads that cannot be serialized to
    JSON (e.g. raw bytes) contribute zero tokens.
    """
    count = 0

    if "text" in block:
        count += len(encoding.encode(block["text"]))

    # toolUse and toolResult are structured dicts — serialize before counting.
    for structured_key in ("toolUse", "toolResult"):
        if structured_key not in block:
            continue
        try:
            serialized = json.dumps(block[structured_key])
        except (TypeError, ValueError):
            # Non-JSON-serializable content (e.g. embedded bytes) is skipped.
            continue
        count += len(encoding.encode(serialized))

    reasoning_text = block.get("reasoningContent", {}).get("reasoningText", {})
    if "text" in reasoning_text:
        count += len(encoding.encode(reasoning_text["text"]))

    if "guardContent" in block:
        guard_block = block["guardContent"]
        if "text" in guard_block:
            count += len(encoding.encode(guard_block["text"]["text"]))

    for citation_item in block.get("citationsContent", {}).get("content", []):
        if "text" in citation_item:
            count += len(encoding.encode(citation_item["text"]))

    return count
67+
68+
69+
def _estimate_tokens_with_tiktoken(
    messages: Messages,
    tool_specs: list[ToolSpec] | None = None,
    system_prompt: str | None = None,
    encoding_name: str = _DEFAULT_ENCODING,
) -> int:
    """Estimate tokens by serializing messages/tools to text and counting with tiktoken.

    This is a best-effort fallback for providers that don't expose native counting.
    Accuracy varies by model but is sufficient for threshold-based decisions.

    Args:
        messages: Messages whose text/tool content should be counted.
        tool_specs: Optional tool specifications, JSON-serialized before counting.
        system_prompt: Optional system prompt text to include in the estimate.
        encoding_name: Name of the tiktoken encoding to use.

    Returns:
        Estimated total input tokens across all provided inputs.
    """
    # get_encoding caches encodings internally, so repeated calls are cheap.
    encoding = tiktoken.get_encoding(encoding_name)
    total = 0

    if system_prompt:
        total += len(encoding.encode(system_prompt))

    for message in messages:
        # Tolerate messages without a "content" key instead of raising KeyError.
        for block in message.get("content") or []:
            total += _count_content_block_tokens(block, encoding)

    if tool_specs:
        for spec in tool_specs:
            try:
                total += len(encoding.encode(json.dumps(spec)))
            except (TypeError, ValueError):
                # Skip specs containing non-JSON-serializable values.
                pass

    return total
98+
2499

25100
@dataclass
26101
class CacheConfig:
@@ -130,6 +205,27 @@ def stream(
130205
"""
131206
pass
132207

208+
def estimate_tokens(
    self,
    messages: Messages,
    tool_specs: list[ToolSpec] | None = None,
    system_prompt: str | None = None,
) -> int:
    """Estimate token count for the given input before sending to the model.

    Used for proactive context management (e.g., triggering compression at a
    threshold). Accuracy within 5-10% is sufficient — this is not used for billing.

    Args:
        messages: List of message objects to estimate tokens for.
        tool_specs: List of tool specifications to include in the estimate.
        system_prompt: System prompt to include in the estimate.

    Returns:
        Estimated total input tokens.
    """
    # Default implementation: naive tiktoken-based estimate. Providers with
    # native counting APIs may override this.
    estimate = _estimate_tokens_with_tiktoken(
        messages=messages,
        tool_specs=tool_specs,
        system_prompt=system_prompt,
    )
    return estimate
228+
133229

134230
class _ModelPlugin(Plugin):
135231
"""Plugin that manages model-related lifecycle hooks."""

tests/strands/models/test_model.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,3 +213,151 @@ def test_model_plugin_preserves_messages_when_not_stateful(model_plugin):
213213
model_plugin._on_after_invocation(event)
214214

215215
assert len(agent.messages) == 1
216+
217+
218+
def test_estimate_tokens_empty_messages(model):
    # No input at all should estimate to exactly zero tokens.
    assert model.estimate_tokens(messages=[]) == 0


def test_estimate_tokens_system_prompt_only(model):
    tokens = model.estimate_tokens(messages=[], system_prompt="You are a helpful assistant.")
    assert tokens > 0


def test_estimate_tokens_text_messages(model, messages):
    tokens = model.estimate_tokens(messages=messages)
    assert tokens > 0


def test_estimate_tokens_with_tool_specs(model, messages, tool_specs):
    # Adding tool specs must strictly increase the estimate.
    baseline = model.estimate_tokens(messages=messages)
    augmented = model.estimate_tokens(messages=messages, tool_specs=tool_specs)
    assert augmented > baseline


def test_estimate_tokens_with_system_prompt(model, messages, system_prompt):
    # Adding a system prompt must strictly increase the estimate.
    baseline = model.estimate_tokens(messages=messages)
    augmented = model.estimate_tokens(messages=messages, system_prompt=system_prompt)
    assert augmented > baseline


def test_estimate_tokens_combined(model, messages, tool_specs, system_prompt):
    tokens = model.estimate_tokens(messages=messages, tool_specs=tool_specs, system_prompt=system_prompt)
    assert tokens > 0
247+
248+
249+
def test_estimate_tokens_tool_use_block(model):
    # A toolUse block is JSON-serialized and counted.
    tool_use_message = {
        "role": "assistant",
        "content": [{"toolUse": {"toolUseId": "123", "name": "my_tool", "input": {"query": "test"}}}],
    }
    assert model.estimate_tokens(messages=[tool_use_message]) > 0


def test_estimate_tokens_tool_result_block(model):
    # A toolResult block is JSON-serialized and counted.
    tool_result_message = {
        "role": "user",
        "content": [
            {
                "toolResult": {
                    "toolUseId": "123",
                    "content": [{"text": "tool output here"}],
                    "status": "success",
                }
            }
        ],
    }
    assert model.estimate_tokens(messages=[tool_result_message]) > 0


def test_estimate_tokens_reasoning_block(model):
    # Nested reasoningContent.reasoningText.text is counted.
    reasoning_message = {
        "role": "assistant",
        "content": [
            {"reasoningContent": {"reasoningText": {"text": "Let me think about this step by step."}}}
        ],
    }
    assert model.estimate_tokens(messages=[reasoning_message]) > 0
304+
305+
306+
def test_estimate_tokens_skips_binary_content(model):
    # Binary-only content blocks contribute no tokens.
    image_only = [
        {"role": "user", "content": [{"image": {"format": "png", "source": {"bytes": b"fake image data"}}}]}
    ]
    assert model.estimate_tokens(messages=image_only) == 0


def test_estimate_tokens_tool_result_with_bytes(model):
    # A toolResult whose content is not JSON-serializable (raw bytes) is skipped entirely.
    binary_tool_result = {
        "toolResult": {
            "toolUseId": "123",
            "content": [{"image": {"format": "png", "source": {"bytes": b"image data"}}}],
            "status": "success",
        }
    }
    tokens = model.estimate_tokens(messages=[{"role": "user", "content": [binary_tool_result]}])
    assert tokens == 0
333+
334+
335+
def test_estimate_tokens_citations_block(model):
    # Text entries inside citationsContent.content are counted.
    citations_message = {
        "role": "assistant",
        "content": [
            {
                "citationsContent": {
                    "content": [{"text": "According to the document, the answer is 42."}],
                    "citations": [],
                }
            }
        ],
    }
    assert model.estimate_tokens(messages=[citations_message]) > 0


def test_estimate_tokens_all_inputs(model):
    # Messages, tool specs, and system prompt together yield a positive estimate.
    conversation = [
        {"role": "user", "content": [{"text": "hello world"}]},
        {"role": "assistant", "content": [{"text": "hi there"}]},
    ]
    tokens = model.estimate_tokens(
        messages=conversation,
        tool_specs=[{"name": "test", "description": "a test tool", "inputSchema": {"json": {}}}],
        system_prompt="Be helpful.",
    )
    assert tokens > 0

0 commit comments

Comments
 (0)