Skip to content

Commit 9a6efea

Browse files
authored
feat: custom runner provider for adka2a executor (#680)
* custom runner provider for adka2a executor * remove the mandatory dependency on runnerConfig
1 parent b8eb8c5 commit 9a6efea

4 files changed

Lines changed: 153 additions & 32 deletions

File tree

server/adka2a/executor.go

Lines changed: 100 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"context"
1919
"errors"
2020
"fmt"
21+
"iter"
2122
"slices"
2223

2324
"github.com/a2aproject/a2a-go/a2a"
@@ -30,6 +31,7 @@ import (
3031

3132
"google.golang.org/adk/agent"
3233
iremoteagent "google.golang.org/adk/internal/agent/remoteagent"
34+
"google.golang.org/adk/plugin"
3335
"google.golang.org/adk/runner"
3436
"google.golang.org/adk/session"
3537
)
@@ -59,6 +61,17 @@ type A2AExecutionCleanupCallback func(ctx context.Context, reqCtx *a2asrv.Reques
5961
// OutputMode controls how artifacts are produced.
6062
type OutputMode string
6163

64+
// Runner is an interface matching [runner.Runner] API.
65+
// It exists to let users use custom runner implementations with A2A agent executor.
66+
type Runner interface {
67+
// Run runs the agent for the given user input, yielding events from agents.
68+
Run(ctx context.Context, userID, sessionID string, msg *genai.Content, cfg agent.RunConfig) iter.Seq2[*session.Event, error]
69+
}
70+
71+
// RunnerProvider is a [Runner] factory function. The provided plugin must be installed in the returned [Runner] for
72+
// callbacks taking [ExecutorContext] to work correctly.
73+
type RunnerProvider func(ctx context.Context, reqCtx *a2asrv.RequestContext, plugin *plugin.Plugin) (RunnerConfig, Runner, error)
74+
6275
const (
6376
// OutputArtifactPerRun produces a single artifact per [runner.Runner.Run].
6477
OutputArtifactPerRun OutputMode = "artifact-per-run"
@@ -69,10 +82,24 @@ const (
6982
OutputArtifactPerEvent OutputMode = "artifact-per-event"
7083
)
7184

85+
// RunnerConfig is part of the runner configuration executor code depends on.
86+
// Custom [RunnerProvider] needs to return it back to callers.
87+
type RunnerConfig struct {
88+
// AppName is the name of the application used in [session.Service] keys and A2A event metadata.
89+
AppName string
90+
// Agent is the root agent. It isued
91+
Agent agent.Agent
92+
// SessionService is the session service to use.
93+
SessionService session.Service
94+
}
95+
7296
// ExecutorConfig allows to configure Executor.
7397
type ExecutorConfig struct {
74-
// RunnerConfig is the configuration which will be used for [runner.New] during A2A Execute invocation.
98+
// RunnerConfig is used for creating a default RunnerProvider. The field is ignored when RunnerProvider is set.
7599
RunnerConfig runner.Config
100+
// RunnerProvider is a function which allows to control how a runner is created.
101+
// If not provided the default provider is used which calls [runner.New] with the RunnerConfig field.
102+
RunnerProvider RunnerProvider
76103

77104
// RunConfig is the configuration which will be passed to [runner.Runner.Run] during A2A Execute invocation.
78105
RunConfig agent.RunConfig
@@ -127,6 +154,9 @@ type Executor struct {
127154

128155
// NewExecutor creates an initialized [Executor] instance.
129156
func NewExecutor(config ExecutorConfig) *Executor {
157+
if config.RunnerProvider == nil {
158+
config.RunnerProvider = newDefaultRunnerProvider(config.RunnerConfig)
159+
}
130160
return &Executor{config: config}
131161
}
132162

@@ -140,15 +170,16 @@ func (e *Executor) Execute(ctx context.Context, reqCtx *a2asrv.RequestContext, q
140170
return fmt.Errorf("a2a message conversion failed: %w", err)
141171
}
142172

143-
runnerCfg, executorPlugin, err := withExecutorPlugin(e.config.RunnerConfig)
173+
executorPlugin, err := newExecutorPlugin()
144174
if err != nil {
145-
return fmt.Errorf("failed to install a2a-executor plugin: %w", err)
175+
return fmt.Errorf("failed to create a2a-executor plugin: %w", err)
146176
}
147177

148-
r, err := runner.New(runnerCfg)
178+
cfg, r, err := e.config.RunnerProvider(ctx, reqCtx, executorPlugin.plugin)
149179
if err != nil {
150180
return fmt.Errorf("failed to create a runner: %w", err)
151181
}
182+
152183
if e.config.BeforeExecuteCallback != nil {
153184
ctx, err = e.config.BeforeExecuteCallback(ctx, reqCtx)
154185
if err != nil {
@@ -170,9 +201,9 @@ func (e *Executor) Execute(ctx context.Context, reqCtx *a2asrv.RequestContext, q
170201
}
171202
}
172203

173-
invocationMeta := toInvocationMeta(ctx, e.config, reqCtx)
204+
invocationMeta := toInvocationMeta(ctx, cfg, reqCtx)
174205

175-
err = e.prepareSession(ctx, invocationMeta)
206+
err = e.prepareSession(ctx, cfg, invocationMeta)
176207
if err != nil {
177208
event := toTaskFailedUpdateEvent(reqCtx, err, invocationMeta.eventMeta)
178209
execCtx := newExecutorContext(ctx, invocationMeta, executorPlugin, content)
@@ -204,7 +235,13 @@ func (e *Executor) Cancel(ctx context.Context, reqCtx *a2asrv.RequestContext, qu
204235
}
205236

206237
func (e *Executor) Cleanup(ctx context.Context, reqCtx *a2asrv.RequestContext, result a2a.SendMessageResult, cause error) {
207-
remoteSubagents := findRemoteSubagents(e.config.RunnerConfig.Agent)
238+
cfg, err := e.createRunnerConfig(ctx, reqCtx)
239+
if err != nil {
240+
log.Error(ctx, "failed to create runner config", err)
241+
return
242+
}
243+
244+
remoteSubagents := findRemoteSubagents(cfg.Agent)
208245

209246
// If task was in input-required and got successfully cancelled - run the cleanup logic
210247
if reqCtx.StoredTask != nil && reqCtx.StoredTask.Status.State == a2a.TaskStateInputRequired {
@@ -235,9 +272,14 @@ func (e *Executor) cancelChildInputRequiredTasks(ctx context.Context, reqCtx *a2
235272
return nil
236273
}
237274

238-
meta := toInvocationMeta(ctx, e.config, reqCtx)
239-
getSessionResponse, err := e.config.RunnerConfig.SessionService.Get(ctx, &session.GetRequest{
240-
AppName: e.config.RunnerConfig.AppName,
275+
cfg, err := e.createRunnerConfig(ctx, reqCtx)
276+
if err != nil {
277+
return fmt.Errorf("failed to create runner config: %w", err)
278+
}
279+
280+
meta := toInvocationMeta(ctx, cfg, reqCtx)
281+
getSessionResponse, err := cfg.SessionService.Get(ctx, &session.GetRequest{
282+
AppName: cfg.AppName,
241283
UserID: meta.userID,
242284
SessionID: meta.sessionID,
243285
})
@@ -286,7 +328,7 @@ func (e *Executor) cancelChildInputRequiredTasks(ctx context.Context, reqCtx *a2
286328
}
287329

288330
// Processing failures should be delivered as Task failed events. An error is returned from this method if an event write fails.
289-
func (e *Executor) process(ctx ExecutorContext, r *runner.Runner, processor *eventProcessor, q eventqueue.Queue) error {
331+
func (e *Executor) process(ctx ExecutorContext, r Runner, processor *eventProcessor, q eventqueue.Queue) error {
290332
meta := processor.meta
291333
for adkEvent, adkErr := range r.Run(ctx, meta.userID, meta.sessionID, ctx.UserContent(), e.config.RunConfig) {
292334
if adkErr != nil {
@@ -338,11 +380,11 @@ func (e *Executor) writeFinalTaskStatus(
338380
return nil
339381
}
340382

341-
func (e *Executor) prepareSession(ctx context.Context, meta invocationMeta) error {
342-
service := e.config.RunnerConfig.SessionService
383+
func (e *Executor) prepareSession(ctx context.Context, cfg RunnerConfig, meta invocationMeta) error {
384+
service := cfg.SessionService
343385

344386
_, err := service.Get(ctx, &session.GetRequest{
345-
AppName: e.config.RunnerConfig.AppName,
387+
AppName: cfg.AppName,
346388
UserID: meta.userID,
347389
SessionID: meta.sessionID,
348390
})
@@ -351,7 +393,7 @@ func (e *Executor) prepareSession(ctx context.Context, meta invocationMeta) erro
351393
}
352394

353395
_, err = service.Create(ctx, &session.CreateRequest{
354-
AppName: e.config.RunnerConfig.AppName,
396+
AppName: cfg.AppName,
355397
UserID: meta.userID,
356398
SessionID: meta.sessionID,
357399
State: make(map[string]any),
@@ -362,3 +404,46 @@ func (e *Executor) prepareSession(ctx context.Context, meta invocationMeta) erro
362404

363405
return nil
364406
}
407+
408+
func (e *Executor) createRunnerConfig(ctx context.Context, reqCtx *a2asrv.RequestContext) (RunnerConfig, error) {
409+
executorPlugin, err := newExecutorPlugin()
410+
if err != nil {
411+
return RunnerConfig{}, fmt.Errorf("failed to create a2a-plugin: %w", err)
412+
}
413+
cfg, _, err := e.config.RunnerProvider(ctx, reqCtx, executorPlugin.plugin)
414+
if err != nil {
415+
return RunnerConfig{}, fmt.Errorf("runner provider failed: %w", err)
416+
}
417+
return cfg, nil
418+
}
419+
420+
func newDefaultRunnerProvider(baseConfig runner.Config) RunnerProvider {
421+
return func(ctx context.Context, reqCtx *a2asrv.RequestContext, plugin *plugin.Plugin) (RunnerConfig, Runner, error) {
422+
if baseConfig.Agent == nil {
423+
return RunnerConfig{}, nil, fmt.Errorf("runner.Config.Agent is not provided")
424+
}
425+
if baseConfig.Agent == nil {
426+
return RunnerConfig{}, nil, fmt.Errorf("runner.Config.SessionService is not provided")
427+
}
428+
429+
cfg := baseConfig
430+
cfg.PluginConfig.Plugins = append(slices.Clone(cfg.PluginConfig.Plugins), plugin)
431+
r, err := runner.New(cfg)
432+
if err != nil {
433+
return RunnerConfig{}, nil, err
434+
}
435+
return toInternalRunnerConfig(cfg), &defaultRunner{runner: r}, nil
436+
}
437+
}
438+
439+
type defaultRunner struct {
440+
runner *runner.Runner
441+
}
442+
443+
func (r *defaultRunner) Run(ctx context.Context, userID, sessionID string, msg *genai.Content, cfg agent.RunConfig) iter.Seq2[*session.Event, error] {
444+
return r.runner.Run(ctx, userID, sessionID, msg, cfg)
445+
}
446+
447+
func toInternalRunnerConfig(cfg runner.Config) RunnerConfig {
448+
return RunnerConfig{Agent: cfg.Agent, AppName: cfg.AppName, SessionService: cfg.SessionService}
449+
}

server/adka2a/executor_plugin.go

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,10 @@
1515
package adka2a
1616

1717
import (
18-
"slices"
19-
2018
"google.golang.org/genai"
2119

2220
"google.golang.org/adk/agent"
2321
"google.golang.org/adk/plugin"
24-
"google.golang.org/adk/runner"
2522
"google.golang.org/adk/session"
2623
)
2724

@@ -31,15 +28,6 @@ type executorPlugin struct {
3128
invocationSession session.Session
3229
}
3330

34-
func withExecutorPlugin(cfg runner.Config) (runner.Config, *executorPlugin, error) {
35-
executorPlugin, err := newExecutorPlugin()
36-
if err != nil {
37-
return cfg, nil, err
38-
}
39-
cfg.PluginConfig.Plugins = append(slices.Clone(cfg.PluginConfig.Plugins), executorPlugin.plugin)
40-
return cfg, executorPlugin, nil
41-
}
42-
4331
func newExecutorPlugin() (*executorPlugin, error) {
4432
execPlugin := &executorPlugin{}
4533
plugin, err := plugin.New(plugin.Config{

server/adka2a/executor_test.go

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@ import (
3131
"google.golang.org/genai"
3232

3333
"google.golang.org/adk/agent"
34+
"google.golang.org/adk/internal/utils"
3435
"google.golang.org/adk/model"
36+
"google.golang.org/adk/plugin"
3537
"google.golang.org/adk/runner"
3638
"google.golang.org/adk/session"
3739
)
@@ -300,7 +302,7 @@ func TestExecutor_SessionReuse(t *testing.T) {
300302
t.Fatalf("executor.Execute() error = %v, want nil", err)
301303
}
302304

303-
meta := toInvocationMeta(ctx, config, reqCtx)
305+
meta := toInvocationMeta(ctx, toInternalRunnerConfig(config.RunnerConfig), reqCtx)
304306
sessions, err := sessionService.List(ctx, &session.ListRequest{AppName: runnerConfig.AppName, UserID: meta.userID})
305307
if err != nil {
306308
t.Fatalf("sessionService.List() error = %v, want nil", err)
@@ -310,7 +312,7 @@ func TestExecutor_SessionReuse(t *testing.T) {
310312
}
311313

312314
reqCtx.ContextID = a2a.NewContextID()
313-
otherContextMeta := toInvocationMeta(ctx, config, reqCtx)
315+
otherContextMeta := toInvocationMeta(ctx, toInternalRunnerConfig(config.RunnerConfig), reqCtx)
314316
if meta.sessionID == otherContextMeta.sessionID {
315317
t.Fatal("want sessionID to be different for different contextIDs")
316318
}
@@ -935,3 +937,49 @@ func TestExecutor_OutputArtifactPerEvent(t *testing.T) {
935937
})
936938
}
937939
}
940+
941+
func TestExecutor_RunnerProvider(t *testing.T) {
942+
wantText := "Hello"
943+
ctx := t.Context()
944+
task := &a2a.Task{ID: a2a.NewTaskID(), ContextID: a2a.NewContextID()}
945+
hiMsg := a2a.NewMessageForTask(a2a.MessageRoleUser, task, a2a.TextPart{Text: "hi"})
946+
reqCtx := &a2asrv.RequestContext{TaskID: task.ID, ContextID: task.ContextID, Message: hiMsg, StoredTask: task}
947+
948+
runnerConfig := runner.Config{
949+
AppName: "test",
950+
SessionService: session.InMemoryService(),
951+
Agent: utils.Must(agent.New(agent.Config{Name: "agent"})),
952+
}
953+
executor := NewExecutor(ExecutorConfig{
954+
RunnerConfig: runnerConfig,
955+
RunnerProvider: func(pCtx context.Context, pReqCtx *a2asrv.RequestContext, plugin *plugin.Plugin) (RunnerConfig, Runner, error) {
956+
return toInternalRunnerConfig(runnerConfig), &testRunner{
957+
runFunc: func(ctx context.Context, userID, sessionID string, msg *genai.Content, cfg agent.RunConfig) iter.Seq2[*session.Event, error] {
958+
return func(yield func(*session.Event, error) bool) {
959+
yield(&session.Event{LLMResponse: modelResponseFromParts(genai.NewPartFromText(wantText))}, nil)
960+
}
961+
},
962+
}, nil
963+
},
964+
})
965+
966+
queue := &testQueue{Queue: newInMemoryQueue(t)}
967+
if err := executor.Execute(ctx, reqCtx, queue); err != nil {
968+
t.Fatalf("executor.Execute() error = %v", err)
969+
}
970+
ta, ok := queue.events[1].(*a2a.TaskArtifactUpdateEvent)
971+
if !ok {
972+
t.Fatalf("queue.events[1] = %T, want a2a.TaskArtifactUpdateEvent", queue.events[1])
973+
}
974+
if tp, ok := ta.Artifact.Parts[0].(a2a.TextPart); !ok || tp.Text != wantText {
975+
t.Fatalf("ta.Artifact.Parts[0] = %v, want text part with text = %q", tp, wantText)
976+
}
977+
}
978+
979+
type testRunner struct {
980+
runFunc func(ctx context.Context, userID, sessionID string, msg *genai.Content, cfg agent.RunConfig) iter.Seq2[*session.Event, error]
981+
}
982+
983+
func (r *testRunner) Run(ctx context.Context, userID, sessionID string, msg *genai.Content, cfg agent.RunConfig) iter.Seq2[*session.Event, error] {
984+
return r.runFunc(ctx, userID, sessionID, msg, cfg)
985+
}

server/adka2a/metadata.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ type invocationMeta struct {
5757
eventMeta map[string]any
5858
}
5959

60-
func toInvocationMeta(ctx context.Context, config ExecutorConfig, reqCtx *a2asrv.RequestContext) invocationMeta {
60+
func toInvocationMeta(ctx context.Context, config RunnerConfig, reqCtx *a2asrv.RequestContext) invocationMeta {
6161
userID, sessionID := "A2A_USER_"+reqCtx.ContextID, reqCtx.ContextID
6262

6363
// a2a sdk attaches authn info to the call context, use it when provided
@@ -68,15 +68,15 @@ func toInvocationMeta(ctx context.Context, config ExecutorConfig, reqCtx *a2asrv
6868
}
6969

7070
meta := map[string]any{
71-
ToA2AMetaKey("app_name"): config.RunnerConfig.AppName,
71+
ToA2AMetaKey("app_name"): config.AppName,
7272
ToA2AMetaKey("user_id"): userID,
7373
ToA2AMetaKey("session_id"): sessionID,
7474
}
7575

7676
return invocationMeta{
7777
userID: userID,
7878
sessionID: sessionID,
79-
agentName: config.RunnerConfig.Agent.Name(),
79+
agentName: config.Agent.Name(),
8080
eventMeta: meta,
8181
reqCtx: reqCtx,
8282
}

0 commit comments

Comments
 (0)