Skip to content

Commit 958d908

Browse files
authored
Merge pull request #4294 from pipecat-ai/ac/fix-assistant-turn-stopped-event
Fix on_assistant_turn_stopped not firing for tool-call-only responses
2 parents f013d56 + 403235e commit 958d908

3 files changed

Lines changed: 29 additions & 12 deletions

File tree

changelog/4294.fixed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
- Fixed `on_assistant_turn_stopped` not resetting internal state when the LLM returned no text tokens. Added `interrupted` field to `AssistantTurnStoppedMessage` to indicate whether the assistant turn was interrupted.

src/pipecat/processors/aggregators/llm_response_universal.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -209,12 +209,16 @@ class AssistantTurnStoppedMessage:
209209
content. This is the aggregated transcript that is then used in the context.
210210
211211
Parameters:
212-
content: The message content/text.
212+
content: The message content/text. May be empty if the LLM
213+
returned zero tokens (e.g. turn was interrupted before any tokens
214+
were received or pushed)
215+
interrupted: Whether the assistant turn was interrupted.
213216
timestamp: When the assistant turn started.
214217
215218
"""
216219

217220
content: str
221+
interrupted: bool
218222
timestamp: str
219223

220224

@@ -1032,11 +1036,11 @@ async def _handle_llm_messages_transform(self, frame: LLMMessagesTransformFrame)
10321036
await self.push_context_frame(FrameDirection.UPSTREAM)
10331037

10341038
async def _handle_interruptions(self, frame: InterruptionFrame):
1035-
await self._trigger_assistant_turn_stopped()
1039+
await self._trigger_assistant_turn_stopped(interrupted=True)
10361040
await self.reset()
10371041

10381042
async def _handle_end_or_cancel(self, frame: Frame):
1039-
await self._trigger_assistant_turn_stopped()
1043+
await self._trigger_assistant_turn_stopped(interrupted=isinstance(frame, CancelFrame))
10401044
if self._summarizer:
10411045
await self._summarizer.cleanup()
10421046

@@ -1394,17 +1398,23 @@ async def _trigger_assistant_turn_started(self):
13941398

13951399
await self._call_event_handler("on_assistant_turn_started")
13961400

1397-
async def _trigger_assistant_turn_stopped(self):
1401+
async def _trigger_assistant_turn_stopped(self, *, interrupted: bool = False):
1402+
if not self._assistant_turn_start_timestamp:
1403+
return
1404+
13981405
aggregation = await self.push_aggregation()
13991406
if aggregation:
14001407
# Strip turn completion markers from the transcript
1401-
content = self._maybe_strip_turn_completion_markers(aggregation)
1402-
message = AssistantTurnStoppedMessage(
1403-
content=content, timestamp=self._assistant_turn_start_timestamp
1404-
)
1405-
await self._call_event_handler("on_assistant_turn_stopped", message)
1408+
aggregation = self._maybe_strip_turn_completion_markers(aggregation)
1409+
1410+
message = AssistantTurnStoppedMessage(
1411+
content=aggregation,
1412+
interrupted=interrupted,
1413+
timestamp=self._assistant_turn_start_timestamp,
1414+
)
1415+
await self._call_event_handler("on_assistant_turn_stopped", message)
14061416

1407-
self._assistant_turn_start_timestamp = ""
1417+
self._assistant_turn_start_timestamp = ""
14081418

14091419
def _maybe_strip_turn_completion_markers(self, text: str) -> str:
14101420
"""Strip turn completion markers from assistant transcript.

tests/test_context_aggregators_universal.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -580,8 +580,10 @@ async def on_assistant_turn_stopped(aggregator, message: AssistantTurnStoppedMes
580580
frames_to_send = [LLMFullResponseStartFrame(), LLMFullResponseEndFrame()]
581581
await run_test(aggregator, frames_to_send=frames_to_send)
582582
self.assertTrue(should_start)
583-
self.assertIsNone(should_stop)
584-
self.assertIsNone(stop_message)
583+
self.assertTrue(should_stop)
584+
self.assertIsNotNone(stop_message)
585+
self.assertFalse(stop_message.interrupted)
586+
self.assertEqual(stop_message.content, "")
585587

586588
async def test_simple(self):
587589
context = LLMContext()
@@ -616,6 +618,7 @@ async def on_assistant_turn_stopped(aggregator, message: AssistantTurnStoppedMes
616618
)
617619
self.assertTrue(should_start)
618620
self.assertTrue(should_stop)
621+
self.assertFalse(stop_message.interrupted)
619622
self.assertEqual(stop_message.content, "Hello from Pipecat!")
620623

621624
async def test_multiple(self):
@@ -653,6 +656,7 @@ async def on_assistant_turn_stopped(aggregator, message: AssistantTurnStoppedMes
653656
)
654657
self.assertTrue(should_start)
655658
self.assertTrue(should_stop)
659+
self.assertFalse(stop_message.interrupted)
656660
self.assertEqual(stop_message.content, "Hello from Pipecat!")
657661

658662
async def test_multiple_text_with_spaces(self):
@@ -858,7 +862,9 @@ async def on_assistant_turn_stopped(aggregator, message: AssistantTurnStoppedMes
858862
)
859863
self.assertEqual(should_start, 2)
860864
self.assertEqual(should_stop, 2)
865+
self.assertTrue(stop_messages[0].interrupted)
861866
self.assertEqual(stop_messages[0].content, "Hello")
867+
self.assertFalse(stop_messages[1].interrupted)
862868
self.assertEqual(stop_messages[1].content, "Hello there!")
863869

864870
async def test_function_call(self):

0 commit comments

Comments (0)