Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions cmd/launcher/launcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package launcher

import (
"context"
"time"

"github.com/a2aproject/a2a-go/a2asrv"

Expand Down Expand Up @@ -54,6 +55,18 @@ type SubLauncher interface {
Run(ctx context.Context, config *Config) error
}

// TriggerConfig contains configuration options for triggers.
type TriggerConfig struct {
// MaxRetries is the maximum number of times to retry a failed agent execution.
MaxRetries int
// BaseDelay is the base delay between retries.
BaseDelay time.Duration
// MaxDelay is the maximum delay between retries.
MaxDelay time.Duration
// MaxConcurrentRuns is the maximum number of concurrent runs.
MaxConcurrentRuns int
}

// Config contains parameters for web & console execution: sessions, artifacts, agents etc
type Config struct {
SessionService session.Service
Expand All @@ -63,4 +76,6 @@ type Config struct {
A2AOptions []a2asrv.RequestHandlerOption
PluginConfig runner.PluginConfig
TelemetryOptions []telemetry.Option
TriggerSources []string
TriggerConfig TriggerConfig
}
53 changes: 50 additions & 3 deletions cmd/launcher/web/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"flag"
"fmt"
"net/http"
"slices"
"strings"
"time"

Expand All @@ -31,11 +32,19 @@ import (
"google.golang.org/adk/telemetry"
)

// SupportedTriggers defines the allowed trigger sources for the ADK REST API.
var SupportedTriggers = []string{"pubsub"}

// apiConfig contains parametres for lauching ADK REST API
type apiConfig struct {
frontendAddress string
pathPrefix string
sseWriteTimeout time.Duration
frontendAddress string
pathPrefix string
sseWriteTimeout time.Duration
triggerSources string
triggerMaxRetries int
triggerBaseDelay time.Duration
triggerMaxDelay time.Duration
triggerMaxRuns int
}

// apiLauncher can launch ADK REST API
Expand Down Expand Up @@ -73,6 +82,25 @@ func (a *apiLauncher) UserMessage(webURL string, printer func(v ...any)) {

// SetupSubrouters adds the API router to the parent router.
func (a *apiLauncher) SetupSubrouters(router *mux.Router, config *launcher.Config) error {
if a.config.triggerSources != "" {
sources := strings.Split(a.config.triggerSources, ",")
for _, source := range sources {
if !slices.Contains(SupportedTriggers, source) {
return fmt.Errorf("invalid trigger source: %q. Any subset of %s is allowed. Values should be comma-separated", source, strings.Join(SupportedTriggers, ", "))
}
}
// De-duplicate the input sources.
slices.Sort(sources)
config.TriggerSources = slices.Compact(sources)
}

config.TriggerConfig = launcher.TriggerConfig{
MaxRetries: a.config.triggerMaxRetries,
BaseDelay: a.config.triggerBaseDelay,
MaxDelay: a.config.triggerMaxDelay,
MaxConcurrentRuns: a.config.triggerMaxRuns,
}

// Create the ADK REST API handler
restServer, err := adkrest.NewServer(adkrest.ServerConfig{
SessionService: config.SessionService,
Expand All @@ -81,6 +109,7 @@ func (a *apiLauncher) SetupSubrouters(router *mux.Router, config *launcher.Confi
ArtifactService: config.ArtifactService,
SSEWriteTimeout: a.config.sseWriteTimeout,
PluginConfig: config.PluginConfig,
TriggerSources: config.TriggerSources,
})
if err != nil {
return fmt.Errorf("failed to create REST server: %w", err)
Expand Down Expand Up @@ -115,6 +144,19 @@ func (a *apiLauncher) Parse(args []string) ([]string, error) {
if err != nil || !a.flags.Parsed() {
return nil, fmt.Errorf("failed to parse api flags: %v", err)
}
if a.config.triggerMaxRetries < 0 {
return nil, fmt.Errorf("trigger_max_retries must be >= 0")
}
if a.config.triggerBaseDelay < 0 {
return nil, fmt.Errorf("trigger_base_delay must be >= 0")
}
if a.config.triggerMaxDelay < 0 {
return nil, fmt.Errorf("trigger_max_delay must be >= 0")
}
if a.config.triggerMaxRuns < 0 {
return nil, fmt.Errorf("trigger_max_concurrent_runs must be >= 0")
}

p := a.config.pathPrefix
if !strings.HasPrefix(p, "/") {
p = "/" + p
Expand All @@ -138,6 +180,11 @@ func NewLauncher() weblauncher.Sublauncher {
fs.StringVar(&config.frontendAddress, "webui_address", "localhost:8080", "ADK WebUI address as seen from the user browser. It's used to allow CORS requests. Please specify only hostname and (optionally) port.")
fs.StringVar(&config.pathPrefix, "path_prefix", "/api", "ADK REST API path prefix. Default is '/api'.")
fs.DurationVar(&config.sseWriteTimeout, "sse-write-timeout", 120*time.Second, "SSE server write timeout (i.e. '10s', '2m' - see time.ParseDuration for details) - for writing the SSE response after reading the headers & body")
fs.IntVar(&config.triggerMaxRetries, "trigger_max_retries", 3, "Maximum retries for HTTP 429 errors from triggers")
fs.DurationVar(&config.triggerBaseDelay, "trigger_base_delay", 1*time.Second, "Base delay for trigger retry exponential backoff")
fs.DurationVar(&config.triggerMaxDelay, "trigger_max_delay", 10*time.Second, "Maximum delay for trigger retry exponential backoff")
fs.IntVar(&config.triggerMaxRuns, "trigger_max_concurrent_runs", 100, "Maximum concurrent trigger runs")
fs.StringVar(&config.triggerSources, "trigger_sources", "", fmt.Sprintf("Comma-separated list of trigger sources to enable (any subset of %s)", strings.Join(SupportedTriggers, ", ")))

return &apiLauncher{
config: config,
Expand Down
94 changes: 94 additions & 0 deletions cmd/launcher/web/api/api_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package api

import (
"testing"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/gorilla/mux"

"google.golang.org/adk/cmd/launcher"
)

func TestSetupSubrouters_TriggerSourcesValidation(t *testing.T) {
tests := []struct {
name string
triggerSources string
wantErr bool
wantSources []string
}{
{
name: "empty trigger sources",
triggerSources: "",
wantErr: false,
wantSources: nil,
},
{
name: "valid trigger sources single",
triggerSources: "pubsub",
wantErr: false,
wantSources: []string{"pubsub"},
},
{
name: "deduplicatedd trigger sources",
triggerSources: "pubsub,pubsub,pubsub",
wantErr: false,
wantSources: []string{"pubsub"},
},
{
name: "invalid trigger source",
triggerSources: "invalid",
wantErr: true,
wantSources: nil,
},
{
name: "mixed valid and invalid",
triggerSources: "pubsub,invalid,bq",
wantErr: true,
wantSources: nil,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
a := &apiLauncher{
config: &apiConfig{
triggerSources: tc.triggerSources,
},
}
router := mux.NewRouter()
config := &launcher.Config{}

err := a.SetupSubrouters(router, config)
if tc.wantErr {
if err == nil {
t.Errorf("SetupSubrouters() error = nil, wantErr %v", tc.wantErr)
}
} else {
if err != nil {
t.Errorf("SetupSubrouters() error = %v, wantErr %v", err, tc.wantErr)
}
diff := cmp.Diff(tc.wantSources, config.TriggerSources, cmpopts.SortSlices(func(a, b string) bool {
return a < b
}))
if diff != "" {
t.Errorf("SetupSubrouters() config.TriggerSources mismatch (-want +got):\n%s", diff)
}
}
})
}
}
Loading
Loading