Skip to content

Commit 24f6695

Browse files
authored
fix(agents): subgraphWithTask & subtask missing tool results in prompt if other tools were requested along with the finish tool (#1971)
In `subgraphWithTask`, models can sometimes request other tools along with the finish tool. Previously, the results of those other tools were ignored and not added to the prompt, which led to an error, since a prompt can't contain tool calls without matching tool results. Updated the `subgraphWithTask` logic to append all tool results when the finish tool is called.
1 parent 1cbae45 commit 24f6695

4 files changed

Lines changed: 262 additions & 21 deletions

File tree

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/context/AIAgentFunctionalContextBaseCommon.kt

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import ai.koog.agents.core.dsl.extension.HistoryCompressionStrategy
1111
import ai.koog.agents.core.environment.AIAgentEnvironment
1212
import ai.koog.agents.core.environment.ReceivedToolResult
1313
import ai.koog.agents.core.environment.SafeTool
14+
import ai.koog.agents.core.environment.ToolResultKind
1415
import ai.koog.agents.core.environment.result
1516
import ai.koog.agents.core.environment.toSafeResult
1617
import ai.koog.agents.core.feature.pipeline.AIAgentPipeline
@@ -648,7 +649,16 @@ public open class AIAgentFunctionalContextBaseCommon<Pipeline : AIAgentPipeline>
648649
response is Message.Tool.Call -> {
649650
val toolResult = executeToolHacked(response, finishTool)
650651

651-
if (toolResult.tool == finishTool.descriptor.name) {
652+
if (toolResult.tool == finishTool.descriptor.name && toolResult.resultKind is ToolResultKind.Success) {
653+
// Prompt must contain tool result
654+
llm.writeSession {
655+
appendPrompt {
656+
tool {
657+
result(toolResult)
658+
}
659+
}
660+
}
661+
652662
return toolResult.toSafeResult(finishTool, config.serializer).asSuccessful().result
653663
}
654664

@@ -722,9 +732,20 @@ public open class AIAgentFunctionalContextBaseCommon<Pipeline : AIAgentPipeline>
722732
val toolResults =
723733
executeMultipleToolsHacked(toolCalls, finishTool, parallelTools = runMode == ToolCalls.PARALLEL)
724734

725-
toolResults.firstOrNull { it.tool == finishTool.descriptor.name }
735+
toolResults.firstOrNull { it.tool == finishTool.descriptor.name && it.resultKind is ToolResultKind.Success }
726736
?.let { finishResult ->
727-
return finishResult.toSafeResult(finishTool, config.serializer).asSuccessful().result
737+
// Prompt must contain all tool results
738+
llm.writeSession {
739+
appendPrompt {
740+
tool {
741+
toolResults.forEach { result(it) }
742+
}
743+
}
744+
}
745+
746+
return finishResult
747+
.toSafeResult(finishTool, config.serializer)
748+
.asSuccessful().result
728749
}
729750

730751
responses = sendMultipleToolResults(toolResults)

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/ext/agent/AIAgentSubgraphExt.kt

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import ai.koog.agents.core.dsl.extension.nodeLLMSendMultipleToolResults
1818
import ai.koog.agents.core.dsl.extension.nodeLLMSendToolResult
1919
import ai.koog.agents.core.environment.ReceivedToolResult
2020
import ai.koog.agents.core.environment.ToolResultKind
21+
import ai.koog.agents.core.environment.result
2122
import ai.koog.agents.core.environment.toSafeResult
2223
import ai.koog.agents.core.tools.Tool
2324
import ai.koog.agents.core.tools.ToolDescriptor
@@ -351,7 +352,7 @@ public fun <Input : Any, OutputTransformed : Any> subgraphWithTask(
351352
defineTask: suspend AIAgentGraphContextBase.(input: Input) -> String
352353
): AIAgentSubgraphDelegate<Input, OutputTransformed> = subgraph<Input, OutputTransformed>(
353354
inputType = inputType,
354-
outputType = inputType,
355+
outputType = finishTool.resultType,
355356
name = name,
356357
toolSelectionStrategy = toolSelectionStrategy,
357358
llmModel = llmModel,
@@ -681,16 +682,28 @@ public fun <Input, Output, OutputTransformed> AIAgentSubgraphBuilderBase<Input,
681682
defineTask(input)
682683
}
683684

684-
val finalizeTask by node<ReceivedToolResult, OutputTransformed>(
685-
inputType = typeToken<ReceivedToolResult>(),
685+
val finalizeTask by node<List<ReceivedToolResult>, OutputTransformed>(
686+
inputType = typeToken<List<ReceivedToolResult>>(),
686687
outputType = outputTransformedType
687-
) { toolResult ->
688+
) { toolResults ->
688689
llm.writeSession {
690+
// Append all tool results to the prompt, otherwise there will be calls without results, which is invalid
691+
appendPrompt {
692+
tool {
693+
toolResults.forEach { result(it) }
694+
}
695+
}
696+
689697
// Restore original tools
690698
tools = storage.get(originalToolsKey)!!
691699
}
692700

693-
toolResult.toSafeResult(finishTool, config.serializer).asSuccessful().result
701+
// Take the first successful finish tool result and return it as the node output
702+
toolResults
703+
.first { it.tool == finishTool.name && it.resultKind is ToolResultKind.Success }
704+
.toSafeResult(finishTool, config.serializer)
705+
.asSuccessful()
706+
.result
694707
}
695708

696709
// Helper node to overcome problems of the current api and repeat less code when writing routing conditions
@@ -795,10 +808,10 @@ public fun <Input, Output, OutputTransformed> AIAgentSubgraphBuilderBase<Input,
795808
edge(
796809
callToolsHacked forwardTo finalizeTask
797810
onCondition { toolResults ->
798-
toolResults.firstOrNull()
799-
?.let { it.tool == finishTool.name && it.resultKind is ToolResultKind.Success } == true
811+
toolResults
812+
.any { it.tool == finishTool.name && it.resultKind is ToolResultKind.Success }
800813
}
801-
transformed { toolsResults -> toolsResults.first() }
814+
transformed { toolsResults -> toolsResults }
802815
)
803816

804817
if (runMode == ToolCalls.SINGLE_RUN_SEQUENTIAL) {
@@ -872,16 +885,6 @@ internal suspend fun <Output, OutputTransformed> AIAgentContext.executeFinishToo
872885
)
873886
}
874887

875-
// Append a final tool call result to the prompt for further LLM calls
876-
// to see it (otherwise they would fail)
877-
llm.writeSession {
878-
appendPrompt {
879-
tool {
880-
result(toolCall.id, toolCall.tool, toolCall.content)
881-
}
882-
}
883-
}
884-
885888
return ReceivedToolResult(
886889
id = toolCall.id,
887890
tool = finishTool.name,
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
package ai.koog.agents.core.agent.context
2+
3+
import ai.koog.agents.core.agent.AIAgent
4+
import ai.koog.agents.core.agent.ToolCalls
5+
import ai.koog.agents.core.agent.functionalStrategy
6+
import ai.koog.agents.core.tools.Tool
7+
import ai.koog.agents.core.tools.ToolRegistry
8+
import ai.koog.agents.features.eventHandler.feature.EventHandler
9+
import ai.koog.agents.testing.tools.TestBlankTool
10+
import ai.koog.agents.testing.tools.TestFinishTool
11+
import ai.koog.agents.testing.tools.getMockExecutor
12+
import ai.koog.prompt.dsl.Prompt
13+
import ai.koog.prompt.executor.clients.openai.OpenAIModels
14+
import ai.koog.prompt.executor.model.PromptExecutor
15+
import ai.koog.prompt.message.Message
16+
import ai.koog.serialization.kotlinx.KotlinxSerializer
17+
import ai.koog.utils.io.use
18+
import io.kotest.assertions.withClue
19+
import io.kotest.matchers.shouldBe
20+
import kotlinx.coroutines.test.runTest
21+
import kotlin.test.Test
22+
23+
class AIAgentFunctionalContextTest {
24+
25+
private val serializer = KotlinxSerializer()
26+
27+
/**
28+
* Verifies that `subtaskWithMultiToolMode` (`ToolCalls.SEQUENTIAL`) appends tool results for
29+
* ALL tool calls to the prompt when the LLM calls the finish tool together with another tool
30+
* in a single response.
31+
*/
32+
@Test
33+
fun testSubtaskWithMultiToolModeSequentialAllToolCallsHaveToolResults() = runTest {
34+
runAndAssertAllToolCallsHaveResults(ToolCalls.SEQUENTIAL)
35+
}
36+
37+
/**
38+
* Verifies that `subtaskWithMultiToolMode` (`ToolCalls.PARALLEL`) appends tool results for
39+
* ALL tool calls to the prompt when the LLM calls the finish tool together with another tool
40+
* in a single response.
41+
*/
42+
@Test
43+
fun testSubtaskWithMultiToolModeParallelAllToolCallsHaveToolResults() = runTest {
44+
runAndAssertAllToolCallsHaveResults(ToolCalls.PARALLEL)
45+
}
46+
47+
/**
48+
* Verifies that `subtaskWithSingleToolMode` (`ToolCalls.SINGLE_RUN_SEQUENTIAL`) appends tool
49+
* results for every tool call to the prompt: a regular tool call followed by the finish tool
50+
* call across two LLM round-trips (single-tool mode allows only one tool per LLM response).
51+
*/
52+
@Test
53+
fun testSubtaskWithSingleToolModeAllToolCallsHaveToolResults() = runTest {
54+
val blankTool = TestBlankTool()
55+
val finishTool = TestFinishTool
56+
57+
val toolRegistry = ToolRegistry { tool(blankTool) }
58+
59+
val inputRequest = "Test input"
60+
val blankToolResult = "Working on it"
61+
val finishToolResult = "Finished"
62+
63+
val mockExecutor = getMockExecutor(serializer) {
64+
mockLLMToolCall(blankTool, TestBlankTool.Args(blankToolResult)) onRequestEquals inputRequest
65+
mockLLMToolCall(finishTool, TestFinishTool.Args(finishToolResult)) onRequestContains blankToolResult
66+
}
67+
68+
val finalPrompt = runAgentAndCapturePrompt(
69+
mockExecutor = mockExecutor,
70+
toolRegistry = toolRegistry,
71+
inputRequest = inputRequest,
72+
blankTool = blankTool,
73+
finishTool = finishTool,
74+
runMode = ToolCalls.SINGLE_RUN_SEQUENTIAL,
75+
)
76+
77+
assertEqualToolCallAndResultCount(finalPrompt, expectedSize = 2)
78+
}
79+
80+
private suspend fun runAndAssertAllToolCallsHaveResults(runMode: ToolCalls) {
81+
val blankTool = TestBlankTool()
82+
val finishTool = TestFinishTool
83+
84+
val toolRegistry = ToolRegistry { tool(blankTool) }
85+
86+
val inputRequest = "Test input"
87+
val blankToolResult = "I'm done"
88+
val finishToolResult = "Finished"
89+
90+
val mockExecutor = getMockExecutor(serializer) {
91+
@Suppress("UNCHECKED_CAST")
92+
mockLLMToolCall(
93+
listOf(
94+
blankTool to TestBlankTool.Args(blankToolResult),
95+
finishTool to TestFinishTool.Args(finishToolResult),
96+
) as List<Pair<Tool<Any?, Any?>, Any?>>
97+
) onRequestEquals inputRequest
98+
}
99+
100+
val finalPrompt = runAgentAndCapturePrompt(
101+
mockExecutor = mockExecutor,
102+
toolRegistry = toolRegistry,
103+
inputRequest = inputRequest,
104+
blankTool = blankTool,
105+
finishTool = finishTool,
106+
runMode = runMode,
107+
)
108+
109+
assertEqualToolCallAndResultCount(finalPrompt, expectedSize = 2)
110+
}
111+
112+
private suspend fun runAgentAndCapturePrompt(
113+
mockExecutor: PromptExecutor,
114+
toolRegistry: ToolRegistry,
115+
inputRequest: String,
116+
blankTool: TestBlankTool,
117+
finishTool: Tool<TestFinishTool.Args, String>,
118+
runMode: ToolCalls,
119+
): Prompt {
120+
lateinit var finalPrompt: Prompt
121+
122+
AIAgent(
123+
promptExecutor = mockExecutor,
124+
llmModel = OpenAIModels.Chat.GPT4o,
125+
toolRegistry = toolRegistry,
126+
strategy = functionalStrategy<String, String> { input ->
127+
subtask(
128+
taskDescription = input,
129+
tools = listOf(blankTool),
130+
finishTool = finishTool,
131+
runMode = runMode,
132+
)
133+
},
134+
systemPrompt = "You are a test agent.",
135+
) {
136+
install(EventHandler) {
137+
onAgentCompleted { ctx ->
138+
finalPrompt = ctx.context.llm.prompt
139+
}
140+
}
141+
}.use { agent ->
142+
agent.run(inputRequest, null)
143+
}
144+
145+
return finalPrompt
146+
}
147+
148+
private fun assertEqualToolCallAndResultCount(prompt: Prompt, expectedSize: Int) {
149+
val toolCalls = prompt.messages.filterIsInstance<Message.Tool.Call>()
150+
val toolResults = prompt.messages.filterIsInstance<Message.Tool.Result>()
151+
152+
withClue("Equal number of tool calls and tool results") {
153+
toolCalls.size shouldBe expectedSize
154+
toolResults.size shouldBe expectedSize
155+
}
156+
}
157+
}

agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/SubgraphWithTaskTest.kt

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ import ai.koog.prompt.streaming.StreamFrame
2727
import ai.koog.serialization.kotlinx.KotlinxSerializer
2828
import ai.koog.utils.io.use
2929
import io.github.oshai.kotlinlogging.KotlinLogging
30+
import io.kotest.assertions.withClue
31+
import io.kotest.matchers.shouldBe
3032
import kotlinx.coroutines.flow.Flow
3133
import kotlinx.coroutines.flow.emptyFlow
3234
import kotlinx.coroutines.test.runTest
@@ -794,6 +796,64 @@ class SubgraphWithTaskTest {
794796

795797
//endregion
796798

799+
/**
800+
* If the model called the finish tool along with other tools, all tool results must be present in the prompt, not only the finish tool result.
801+
*/
802+
@Test
803+
fun testAllToolCallsHaveRespectiveToolResults() = runTest {
804+
val blankTool = TestBlankTool()
805+
val finishTool = TestFinishTool
806+
807+
val toolRegistry = ToolRegistry {
808+
tool(blankTool)
809+
}
810+
811+
val model = OpenAIModels.Chat.GPT4o
812+
813+
val inputRequest = "Test input"
814+
val blankToolResult = "I'm done"
815+
val finishToolResult = "Finished"
816+
817+
val mockExecutor = getMockExecutor(serializer) {
818+
@Suppress("UNCHECKED_CAST")
819+
mockLLMToolCall(
820+
listOf(
821+
blankTool to TestBlankTool.Args(blankToolResult),
822+
finishTool to TestFinishTool.Args(finishToolResult),
823+
) as List<Pair<Tool<Any?, Any?>, Any?>>
824+
) onRequestEquals inputRequest
825+
}
826+
827+
lateinit var finalPrompt: Prompt
828+
829+
createAgent(
830+
model = model,
831+
runMode = ToolCalls.SEQUENTIAL,
832+
toolRegistry = toolRegistry,
833+
executor = mockExecutor,
834+
finishTool = finishTool,
835+
installFeatures = {
836+
install(EventHandler) {
837+
onAgentCompleted { ctx ->
838+
finalPrompt = ctx.context.llm.prompt
839+
}
840+
}
841+
}
842+
).use { agent ->
843+
val agentResult = agent.run(inputRequest, null)
844+
logger.info { "Agent is finished with result: $agentResult" }
845+
}
846+
847+
val toolCalls = finalPrompt.messages.filterIsInstance<Message.Tool.Call>()
848+
val toolResults = finalPrompt.messages.filterIsInstance<Message.Tool.Result>()
849+
850+
withClue("Equal number of tool calls and tool results") {
851+
val expectedSize = 2
852+
toolCalls.size shouldBe expectedSize
853+
toolResults.size shouldBe expectedSize
854+
}
855+
}
856+
797857
//region Private Methods
798858

799859
fun createAgent(

0 commit comments

Comments
 (0)