Skip to content

Commit 1c7124a

Browse files
committed
KG-178. Add contextual environment instance that adds an agent context into environment
- Add an agent contextual environment that wraps the original environment with a context from AIAgent and provides all data from an agent context into environment execution logic; - Update Tool Call handlers in pipeline and all features; - Update logic in all features related to tool calls; - Update related tests.
1 parent 9fa1df1 commit 1c7124a

56 files changed

Lines changed: 1080 additions & 462 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/FunctionalAIAgent.kt

Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import ai.koog.agents.core.agent.context.AIAgentLLMContext
66
import ai.koog.agents.core.agent.entity.AIAgentStateManager
77
import ai.koog.agents.core.agent.entity.AIAgentStorage
88
import ai.koog.agents.core.annotation.InternalAgentsApi
9+
import ai.koog.agents.core.environment.AIAgentEnvironment
10+
import ai.koog.agents.core.environment.ContextualAgentEnvironment
911
import ai.koog.agents.core.environment.GenericAgentEnvironment
1012
import ai.koog.agents.core.feature.AIAgentFeature
1113
import ai.koog.agents.core.feature.AIAgentFunctionalFeature
@@ -52,13 +54,6 @@ public class FunctionalAIAgent<Input, Output>(
5254

5355
override val pipeline: AIAgentFunctionalPipeline = AIAgentFunctionalPipeline(clock)
5456

55-
private val environment = GenericAgentEnvironment(
56-
agentId = this.id,
57-
logger = logger,
58-
toolRegistry = toolRegistry,
59-
pipeline = pipeline
60-
)
61-
6257
/**
6358
* Represents a context for managing and configuring features in an AI agent.
6459
* Provides functionality to install and configure features into a specific instance of an AI agent.
@@ -83,7 +78,13 @@ public class FunctionalAIAgent<Input, Output>(
8378
}
8479

8580
override suspend fun prepareContext(agentInput: Input, runId: String): AIAgentFunctionalContext {
86-
val llm = AIAgentLLMContext(
81+
val environment = GenericAgentEnvironment(
82+
agentId = id,
83+
logger = logger,
84+
toolRegistry = toolRegistry,
85+
)
86+
87+
val initialAgentLLMContext = AIAgentLLMContext(
8788
tools = toolRegistry.tools.map { it.descriptor },
8889
toolRegistry = toolRegistry,
8990
prompt = agentConfig.prompt,
@@ -99,17 +100,51 @@ public class FunctionalAIAgent<Input, Output>(
99100
clock = clock
100101
)
101102

102-
return AIAgentFunctionalContext(
103-
environment = environment,
103+
val preparedEnvironment = prepareEnvironment()
104+
105+
// Context
106+
val agentContext = AIAgentFunctionalContext(
107+
environment = preparedEnvironment,
104108
agentId = id,
105109
runId = runId,
106110
agentInput = agentInput,
107111
config = agentConfig,
108-
llm = llm,
112+
llm = initialAgentLLMContext,
109113
stateManager = AIAgentStateManager(),
110114
storage = AIAgentStorage(),
111115
strategyName = strategy.name,
112116
pipeline = pipeline
113117
)
118+
119+
// Updated environment
120+
val environmentProxy = ContextualAgentEnvironment(
121+
environment = preparedEnvironment,
122+
context = agentContext,
123+
)
124+
125+
val updatedLLMContext = agentContext.llm.copy(
126+
environment = environmentProxy
127+
)
128+
129+
// Update the environment and llm with a created context instance
130+
return agentContext.copy(
131+
parentRootContext = agentContext.parentContext,
132+
llm = updatedLLMContext,
133+
environment = environmentProxy
134+
)
135+
}
136+
137+
//region Private Methods
138+
139+
private fun prepareEnvironment(): AIAgentEnvironment {
140+
val baseEnvironment = GenericAgentEnvironment(
141+
agentId = id,
142+
logger = logger,
143+
toolRegistry = toolRegistry,
144+
)
145+
146+
return baseEnvironment
114147
}
148+
149+
//endregion Private Methods
115150
}

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/GraphAIAgent.kt

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import ai.koog.agents.core.agent.entity.AIAgentGraphStrategy
88
import ai.koog.agents.core.agent.entity.AIAgentStateManager
99
import ai.koog.agents.core.agent.entity.AIAgentStorage
1010
import ai.koog.agents.core.annotation.InternalAgentsApi
11+
import ai.koog.agents.core.environment.AIAgentEnvironment
12+
import ai.koog.agents.core.environment.ContextualAgentEnvironment
1113
import ai.koog.agents.core.environment.GenericAgentEnvironment
1214
import ai.koog.agents.core.feature.AIAgentFeature
1315
import ai.koog.agents.core.feature.AIAgentGraphFeature
@@ -66,13 +68,6 @@ public open class GraphAIAgent<Input, Output>(
6668

6769
override val pipeline: AIAgentGraphPipeline = AIAgentGraphPipeline(clock)
6870

69-
private val environment = GenericAgentEnvironment(
70-
agentId = this.id,
71-
logger = logger,
72-
toolRegistry = toolRegistry,
73-
pipeline = pipeline
74-
)
75-
7671
/**
7772
* The context for adding and configuring features in a Kotlin AI Agent instance.
7873
*
@@ -103,16 +98,9 @@ public open class GraphAIAgent<Input, Output>(
10398
val stateManager = AIAgentStateManager()
10499
val storage = AIAgentStorage()
105100

106-
// Environment (initially equal to the current agent), transformed by some features
107-
// (ex: testing feature transforms it into a MockEnvironment with mocked tools)
108-
val preparedEnvironment =
109-
pipeline.onAgentEnvironmentTransforming(
110-
strategy = strategy,
111-
agent = this,
112-
baseEnvironment = environment
113-
)
101+
val preparedEnvironment = prepareAgentEnvironment()
114102

115-
return AIAgentGraphContext(
103+
val context = AIAgentGraphContext(
116104
environment = preparedEnvironment,
117105
agentId = id,
118106
agentInput = agentInput,
@@ -139,5 +127,48 @@ public open class GraphAIAgent<Input, Output>(
139127
strategyName = strategy.name,
140128
pipeline = pipeline,
141129
)
130+
131+
// Update Environment
132+
val contextualEnvironment = ContextualAgentEnvironment(
133+
environment = preparedEnvironment,
134+
context = context,
135+
)
136+
137+
val updatedLLMContext = context.llm.copy(
138+
environment = contextualEnvironment
139+
)
140+
141+
// Update the environment and llm with a created context instance
142+
return context.copy(
143+
parentContext = context.parentContext,
144+
llm = updatedLLMContext,
145+
environment = contextualEnvironment
146+
)
147+
}
148+
149+
/**
150+
* Prepares the environment for the AI agent by initializing a base environment
151+
* and applying any registered environment transformations defined in the pipeline.
152+
*
153+
* Environment (initially equal to the current agent), transformed by some features
154+
* (ex: testing feature transforms it into a MockEnvironment with mocked tools
155+
*
156+
* @return An instance of `AIAgentEnvironment` that represents the finalized environment
157+
* for the AI agent after applying all transformations.
158+
*/
159+
private suspend fun prepareAgentEnvironment(): AIAgentEnvironment {
160+
// Create a base environment implementation
161+
val environment = GenericAgentEnvironment(
162+
agentId = id,
163+
logger = logger,
164+
toolRegistry = toolRegistry,
165+
)
166+
167+
val preparedEnvironment = pipeline.onAgentEnvironmentTransforming(
168+
agent = this,
169+
baseEnvironment = environment
170+
)
171+
172+
return preparedEnvironment
142173
}
143174
}

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/context/AIAgentGraphContext.kt

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ public interface AIAgentGraphContextBase : AIAgentContext {
6060
runId: String = this.runId,
6161
strategyName: String = this.strategyName,
6262
pipeline: AIAgentGraphPipeline = this.pipeline,
63+
parentContext: AIAgentGraphContextBase? = this,
6364
): AIAgentGraphContextBase {
6465
val clone = AIAgentGraphContext(
6566
environment = environment,
@@ -73,7 +74,7 @@ public interface AIAgentGraphContextBase : AIAgentContext {
7374
runId = runId,
7475
strategyName = strategyName,
7576
pipeline = pipeline,
76-
parentContext = this
77+
parentContext = parentContext,
7778
)
7879

7980
return clone
@@ -126,7 +127,7 @@ public class AIAgentGraphContext(
126127
override val runId: String,
127128
override val strategyName: String,
128129
override val pipeline: AIAgentGraphPipeline,
129-
override val parentContext: AIAgentGraphContextBase? = null
130+
override val parentContext: AIAgentGraphContextBase? = null,
130131
) : AIAgentGraphContextBase {
131132
private val mutableAIAgentContext = MutableAIAgentContext(llm, stateManager, storage)
132133

@@ -160,12 +161,16 @@ public class AIAgentGraphContext(
160161
}
161162

162163
/**
163-
* Replaces the current contxt with the provided context.
164+
* Replaces the current context with the provided context.
164165
* @param llm The LLM context to replace the current context with.
165166
* @param stateManager The state manager to replace the current context with.
166167
* @param storage The storage to replace the current context with.
167168
*/
168-
suspend fun replace(llm: AIAgentLLMContext?, stateManager: AIAgentStateManager?, storage: AIAgentStorage?) {
169+
suspend fun replace(
170+
llm: AIAgentLLMContext?,
171+
stateManager: AIAgentStateManager?,
172+
storage: AIAgentStorage?,
173+
) {
169174
rwLock.withWriteLock {
170175
llm?.let { this.llm = llm }
171176
stateManager?.let { this.stateManager = stateManager }
@@ -217,7 +222,7 @@ public class AIAgentGraphContext(
217222
mutableAIAgentContext.replace(
218223
context.llm,
219224
context.stateManager,
220-
context.storage
225+
context.storage,
221226
)
222227
}
223228
}

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/builder/AIAgentParallelNodesMergeContext.kt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ public class AIAgentParallelNodesMergeContext<Input, Output>(
7474
runId: String,
7575
strategyName: String,
7676
pipeline: AIAgentGraphPipeline,
77+
parentContext: AIAgentGraphContextBase?,
7778
): AIAgentGraphContextBase = underlyingContextBase.copy(
7879
environment = environment,
7980
agentInput = agentInput,
@@ -84,7 +85,8 @@ public class AIAgentParallelNodesMergeContext<Input, Output>(
8485
storage = storage,
8586
runId = runId,
8687
strategyName = strategyName,
87-
pipeline = pipeline
88+
pipeline = pipeline,
89+
parentContext = parentContext
8890
)
8991

9092
/**
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
package ai.koog.agents.core.environment
2+
3+
import ai.koog.agents.core.agent.context.AIAgentContext
4+
import ai.koog.prompt.message.Message
5+
import io.github.oshai.kotlinlogging.KotlinLogging
6+
7+
internal class ContextualAgentEnvironment(
8+
private val environment: AIAgentEnvironment,
9+
private val context: AIAgentContext,
10+
) : AIAgentEnvironment {
11+
12+
companion object {
13+
private val logger = KotlinLogging.logger { }
14+
}
15+
16+
override suspend fun executeTool(toolCall: Message.Tool.Call): ReceivedToolResult {
17+
logger.trace { "Executing tool call (run id: ${context.runId}, tool call id: ${toolCall.id}, tool: ${toolCall.tool}, args: ${toolCall.contentJson})" }
18+
19+
context.pipeline.onToolCallStarting(
20+
runId = context.runId,
21+
toolCallId = toolCall.id,
22+
toolName = toolCall.tool,
23+
toolArgs = toolCall.contentJson
24+
)
25+
26+
val toolResult = environment.executeTool(toolCall)
27+
processToolResult(toolResult)
28+
29+
logger.trace { "Tool call completed (run id: ${context.runId}, tool call id: ${toolCall.id}, tool: ${toolCall.tool}, args: ${toolCall.contentJson}) with result: $toolResult" }
30+
return toolResult
31+
}
32+
33+
override suspend fun reportProblem(exception: Throwable) {
34+
environment.reportProblem(exception)
35+
}
36+
37+
//region Private Methods
38+
39+
private suspend fun processToolResult(
40+
toolResult: ReceivedToolResult
41+
) {
42+
when (val toolResultKind = toolResult.resultKind) {
43+
is ToolResultKind.Success -> {
44+
context.pipeline.onToolCallCompleted(
45+
runId = context.runId,
46+
toolCallId = toolResult.id,
47+
toolName = toolResult.tool,
48+
toolArgs = toolResult.toolArgs,
49+
toolDescription = toolResult.toolDescription,
50+
toolResult = toolResult.result
51+
)
52+
}
53+
54+
is ToolResultKind.Failure -> {
55+
context.pipeline.onToolCallFailed(
56+
runId = context.runId,
57+
toolCallId = toolResult.id,
58+
toolName = toolResult.tool,
59+
toolArgs = toolResult.toolArgs,
60+
toolDescription = toolResult.toolDescription,
61+
message = toolResult.content,
62+
error = toolResultKind.error
63+
)
64+
}
65+
66+
is ToolResultKind.ValidationError -> {
67+
context.pipeline.onToolValidationFailed(
68+
runId = context.runId,
69+
toolCallId = toolResult.id,
70+
toolName = toolResult.tool,
71+
toolArgs = toolResult.toolArgs,
72+
toolDescription = toolResult.toolDescription,
73+
message = toolResult.content,
74+
error = toolResultKind.error
75+
)
76+
}
77+
}
78+
}
79+
80+
//endregion Private Methods
81+
}

0 commit comments

Comments
 (0)