Skip to content

Commit 45712f1

Browse files
authored
Merge pull request #1480 from adichaudhary/fix/conversation-cache
[FEAT] Implemented conversation caching and invalidating on new messages
2 parents 8f3211d + a97681c commit 45712f1

File tree

2 files changed

+131
-2
lines changed

2 files changed

+131
-2
lines changed
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
"""
2+
Interactive agent to manually test conversation string caching.
3+
Run: python examples/conversation_cache_interactive.py
4+
Type messages, see cache stats update after each response.
5+
"""
6+
7+
from swarms import Agent
8+
9+
10+
def main() -> None:
    """Run an interactive chat loop that prints conversation-cache stats.

    Builds a single-loop agent, reads messages from stdin, sends each one to
    the agent, and then prints the statistics reported by
    ``agent.short_memory.get_cache_stats()``. The loop ends when the user
    types ``quit``, ``exit`` or ``q``, or when stdin is closed / interrupted.
    """
    agent = Agent(
        agent_name="CacheTestAgent",
        model_name="claude-sonnet-4-5",
        max_loops=1,
        verbose=False,
        temperature=1.0,
    )

    print("\n=== Conversation Cache Interactive Test ===")
    print("Type your messages. After each response, cache stats are shown.")
    print("Type 'quit' to exit.\n")

    while True:
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / closed stdin / Ctrl-C: leave the loop instead of
            # crashing with a traceback in an interactive tool.
            print()
            break

        if user_input.lower() in ("quit", "exit", "q"):
            break
        if not user_input:
            continue

        response = agent.run(user_input)
        print(f"\nAgent: {response}\n")

        # Call get_str() multiple times to exercise the cache
        agent.short_memory.get_str()
        agent.short_memory.get_str()
        agent.short_memory.get_str()

        stats = agent.short_memory.get_cache_stats()
        print("--- Cache Stats ---")
        print(f" Hits: {stats['hits']} (get_str() returned cached string)")
        print(f" Misses: {stats['misses']} (get_str() rebuilt the string)")
        print(f" Hit rate: {stats['hit_rate']:.0%}")
        print(f" Cached tokens: {stats['cached_tokens']}")
        print("-------------------\n")


if __name__ == "__main__":
    main()

swarms/structs/conversation.py

Lines changed: 83 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def __init__(
9494
export_method: str = "json",
9595
dynamic_context_window: bool = True,
9696
caching: bool = True,
97+
cache_enabled: bool = True,
9798
output_metadata: bool = False,
9899
):
99100

@@ -117,13 +118,19 @@ def __init__(
117118
self.token_count = token_count
118119
self.export_method = export_method
119120
self.dynamic_context_window = dynamic_context_window
120-
self.caching = caching
121+
# Ensure an explicit `caching=False` is respected; only fall back to
122+
# `cache_enabled` when `caching` is not explicitly set (i.e., is None).
123+
self.caching = cache_enabled if caching is None else caching
121124
self.output_metadata = output_metadata
122125

123126
if self.name is None:
124127
self.name = id
125128

126129
self.conversation_history = []
130+
self._str_cache: Optional[str] = None
131+
self._cache_hits: int = 0
132+
self._cache_misses: int = 0
133+
self._last_cached_tokens: int = 0
127134

128135
self.setup_file_path()
129136
self.setup()
@@ -264,6 +271,7 @@ def add_in_memory(
264271

265272
# Add message to conversation history
266273
self.conversation_history.append(message)
274+
self._str_cache = None
267275

268276
# Handle token counting in a separate thread if enabled
269277
if self.token_count is True:
@@ -413,6 +421,7 @@ def add_multiple(
413421
def delete(self, index: str):
    """Remove one message from the conversation history.

    Args:
        index (str): Position of the message to remove; it is converted
            with ``int()`` before use.
    """
    position = int(index)
    self.conversation_history.pop(position)
    # The history changed, so any previously cached string is stale.
    self._str_cache = None
416425

417426
def update(self, index: str, role, content):
418427
"""Update a message in the conversation history.
@@ -425,6 +434,7 @@ def update(self, index: str, role, content):
425434
if 0 <= int(index) < len(self.conversation_history):
426435
self.conversation_history[int(index)]["role"] = role
427436
self.conversation_history[int(index)]["content"] = content
437+
self._str_cache = None
428438
else:
429439
logger.warning(f"Invalid index: {index}")
430440

@@ -533,7 +543,42 @@ def get_str(self) -> str:
533543
Returns:
534544
str: The conversation history.
535545
"""
536-
return self.return_history_as_string()
546+
if not self.caching:
547+
return self.return_history_as_string()
548+
if self._str_cache is None:
549+
self._cache_misses += 1
550+
self._str_cache = self.return_history_as_string()
551+
self._last_cached_tokens = count_tokens(
552+
self._str_cache, self.tokenizer_model_name
553+
)
554+
else:
555+
self._cache_hits += 1
556+
return self._str_cache
557+
558+
def get_cache_stats(self) -> Dict[str, Any]:
    """Summarize how the ``get_str()`` string cache has performed.

    Returns:
        Dict[str, Any]: A dictionary with the keys:
            - ``hits``: value of the hit counter.
            - ``misses``: value of the miss counter.
            - ``cached_tokens``: token count recorded for the most
              recently cached string.
            - ``total_tokens``: ``cached_tokens`` multiplied by the
              number of counted calls (hits + misses).
            - ``hit_rate``: hits divided by counted calls, or ``0.0``
              when nothing has been counted yet.
    """
    hits = self._cache_hits
    misses = self._cache_misses
    calls = hits + misses
    last_tokens = self._last_cached_tokens

    # Guard the division: no counted calls means a hit rate of zero.
    if calls > 0:
        rate = hits / calls
    else:
        rate = 0.0

    return {
        "hits": hits,
        "misses": misses,
        "cached_tokens": last_tokens,
        "total_tokens": calls * last_tokens,
        "hit_rate": rate,
    }
537582

538583
def to_dict(self) -> Dict[Any, Any]:
539584
"""
@@ -693,6 +738,7 @@ def load_from_json(self, filename: str):
693738
self.conversation_history = data.get(
694739
"conversation_history", []
695740
)
741+
self._str_cache = None
696742

697743
logger.info(
698744
f"Successfully loaded conversation from {filename}"
@@ -725,6 +771,7 @@ def load_from_yaml(self, filename: str):
725771
self.conversation_history = data.get(
726772
"conversation_history", []
727773
)
774+
self._str_cache = None
728775

729776
logger.info(
730777
f"Successfully loaded conversation from {filename}"
@@ -841,6 +888,7 @@ def truncate_memory_with_tokenizer(self):
841888

842889
# Update conversation history
843890
self.conversation_history = truncated_history
891+
self._str_cache = None
844892

845893
def _binary_search_truncate(
846894
self, text, target_tokens, model_name
@@ -902,6 +950,7 @@ def _binary_search_truncate(
902950
def clear(self):
    """Reset the conversation to an empty history."""
    # Drop the cached string together with the messages it was built from.
    self._str_cache = None
    self.conversation_history = []
905954

906955
def to_json(self):
907956
"""Convert the conversation history to a JSON string.
@@ -911,6 +960,14 @@ def to_json(self):
911960
"""
912961
return json.dumps(self.conversation_history)
913962

963+
def to_yaml(self):
    """Serialize the conversation history to YAML text.

    Returns:
        str: The result of ``yaml.dump`` applied to the history list.
    """
    history = self.conversation_history
    return yaml.dump(history)
970+
914971
def to_list(self):
915972
"""Convert the conversation history to a list.
916973
@@ -1050,6 +1107,7 @@ def batch_add(self, messages: List[dict]):
10501107
messages (List[dict]): List of messages to add.
10511108
"""
10521109
self.conversation_history.extend(messages)
1110+
self._str_cache = None
10531111

10541112
@classmethod
10551113
def load_conversation(
@@ -1174,9 +1232,32 @@ def list_conversations(
11741232
conversations, key=lambda x: x["created_at"], reverse=True
11751233
)
11761234

1235+
@classmethod
def list_cached_conversations(
    cls, conversations_dir: Optional[str] = None
) -> List[str]:
    """List names of all saved conversations (JSON and YAML).

    A conversation stored in more than one format (for example both
    ``chat.json`` and ``chat.yaml``) is reported once.

    Args:
        conversations_dir (Optional[str]): Directory containing
            conversations. When omitted or empty, the directory returned
            by ``get_conversation_dir()`` is used.

    Returns:
        List[str]: Sorted, de-duplicated conversation names (file names
        without their extension). Empty if the path is missing or is not
        a directory.
    """
    conv_dir = conversations_dir or get_conversation_dir()
    # isdir rather than exists: a path that points at a regular file would
    # otherwise reach os.listdir and raise NotADirectoryError.
    if not os.path.isdir(conv_dir):
        return []
    # A set collapses the same name saved under several extensions.
    names = {
        os.path.splitext(filename)[0]
        for filename in os.listdir(conv_dir)
        if filename.endswith((".json", ".yaml", ".yml"))
    }
    return sorted(names)
1256+
11771257
def clear_memory(self):
    """Forget every stored message in this conversation."""
    self.conversation_history = []
    # The history is now empty, so the cached string must be rebuilt on
    # the next read.
    self._str_cache = None
11801261

11811262
def _dynamic_auto_chunking_worker(self):
11821263
"""

0 commit comments

Comments
 (0)