Skip to content

Commit 69e903b

Browse files
authored
feat(go): added common middleware (e.g. tool approval, retry, fallback) (#4719)
1 parent 4ee4cdb commit 69e903b

39 files changed

Lines changed: 4183 additions & 59 deletions

.github/workflows/go.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ jobs:
2929
runs-on: ubuntu-latest
3030
strategy:
3131
matrix:
32-
go-version: ['1.24.x']
32+
go-version: ['1.25.x', '1.26.x']
3333
steps:
3434
- name: Checkout Repo
3535
uses: actions/checkout@main

go/README.md

Lines changed: 73 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ func main() {
5353
g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{}))
5454

5555
answer, err := genkit.GenerateText(ctx, g,
56-
ai.WithModelName("googleai/gemini-2.5-flash"),
56+
ai.WithModelName("googleai/gemini-flash-latest"),
5757
ai.WithPrompt("Why is Go a great language for AI applications?"),
5858
)
5959
if err != nil {
@@ -80,7 +80,7 @@ Call any model with a simple, unified API:
8080

8181
```go
8282
text, _ := genkit.GenerateText(ctx, g,
83-
ai.WithModelName("googleai/gemini-2.5-flash"),
83+
ai.WithModelName("googleai/gemini-flash-latest"),
8484
ai.WithPrompt("Explain quantum computing in simple terms."),
8585
)
8686
fmt.Println(text)
@@ -98,7 +98,7 @@ type Recipe struct {
9898
}
9999

100100
recipe, _ := genkit.GenerateData[Recipe](ctx, g,
101-
ai.WithModelName("googleai/gemini-2.5-flash"),
101+
ai.WithModelName("googleai/gemini-flash-latest"),
102102
ai.WithPrompt("Create a recipe for chocolate chip cookies."),
103103
)
104104
fmt.Printf("Recipe: %s\n", recipe.Title)
@@ -112,7 +112,7 @@ Stream text as it's generated for responsive user experiences:
112112

113113
```go
114114
stream := genkit.GenerateStream(ctx, g,
115-
ai.WithModelName("googleai/gemini-2.5-flash"),
115+
ai.WithModelName("googleai/gemini-flash-latest"),
116116
ai.WithPrompt("Write a short story about a robot learning to paint."),
117117
)
118118

@@ -145,7 +145,7 @@ type Recipe struct {
145145
}
146146

147147
stream := genkit.GenerateDataStream[*Recipe](ctx, g,
148-
ai.WithModelName("googleai/gemini-2.5-flash"),
148+
ai.WithModelName("googleai/gemini-flash-latest"),
149149
ai.WithPrompt("Create a recipe for spaghetti carbonara."),
150150
)
151151

@@ -184,7 +184,7 @@ weatherTool := genkit.DefineTool(g, "getWeather",
184184
)
185185

186186
response, _ := genkit.Generate(ctx, g,
187-
ai.WithModelName("googleai/gemini-2.5-flash"),
187+
ai.WithModelName("googleai/gemini-flash-latest"),
188188
ai.WithPrompt("What's the weather like in San Francisco?"),
189189
ai.WithTools(weatherTool),
190190
)
@@ -226,7 +226,7 @@ transferTool := genkit.DefineTool(g, "transfer",
226226

227227
// Handle interrupts in your flow
228228
resp, _ := genkit.Generate(ctx, g,
229-
ai.WithModelName("googleai/gemini-2.5-flash"),
229+
ai.WithModelName("googleai/gemini-flash-latest"),
230230
ai.WithPrompt("Transfer $5000 to account ABC123"),
231231
ai.WithTools(transferTool),
232232
)
@@ -248,6 +248,60 @@ if resp.FinishReason == ai.FinishReasonInterrupted {
248248

249249
[See full example](samples/intermediate-interrupts)
250250

251+
### Middleware
252+
253+
Middleware wraps generation, model calls, and tool execution to add cross-cutting behavior without touching your flows. Register the `middleware` plugin during `Init` to expose the built-ins in the Dev UI, then attach them per call with `ai.WithUse`:
254+
255+
```go
256+
import "github.com/firebase/genkit/go/plugins/middleware"
257+
258+
g := genkit.Init(ctx, genkit.WithPlugins(
259+
&googlegenai.GoogleAI{},
260+
&middleware.Middleware{},
261+
))
262+
263+
// Retry transient failures, then fall back to a secondary model if the primary
264+
// stays down. Middleware composes outer-to-inner: Retry { Fallback { model } }.
265+
response, _ := genkit.Generate(ctx, g,
266+
ai.WithModelName("googleai/gemini-flash-latest"),
267+
ai.WithPrompt("Explain quantum computing."),
268+
ai.WithUse(
269+
&middleware.Retry{MaxRetries: 3},
270+
&middleware.Fallback{Models: []ai.ModelRef{
271+
googlegenai.ModelRef("googleai/gemini-3.1-flash", nil),
272+
}},
273+
),
274+
)
275+
```
276+
277+
The `middleware` plugin also ships with [`ToolApproval`](plugins/middleware/tool_approval.go) for human-in-the-loop gating, [`Filesystem`](samples/basic-middleware/filesystem) for sandboxed file access, and [`Skills`](samples/basic-middleware/skills) for loadable `SKILL.md` skills. [See the retry + fallback sample](samples/basic-middleware/retry-fallback) for the full composition.
278+
279+
### Custom Middleware
280+
281+
Implement the `ai.Middleware` interface to build your own. Embed `ai.BaseMiddleware` to inherit pass-through defaults for the hooks you don't need, then override `WrapGenerate`, `WrapModel`, or `WrapTool`:
282+
283+
```go
284+
type Logger struct {
285+
ai.BaseMiddleware
286+
Prefix string `json:"prefix,omitempty"`
287+
}
288+
289+
func (l *Logger) Name() string { return "mine/logger" }
290+
func (l *Logger) New() ai.Middleware { return &Logger{Prefix: l.Prefix} }
291+
292+
func (l *Logger) WrapModel(ctx context.Context, params *ai.ModelParams, next ai.ModelNext) (*ai.ModelResponse, error) {
293+
start := time.Now()
294+
resp, err := next(ctx, params)
295+
log.Printf("%s model call took %s", l.Prefix, time.Since(start))
296+
return resp, err
297+
}
298+
299+
// Use it like any built-in middleware.
300+
ai.WithUse(&Logger{Prefix: "[trace]"})
301+
```
302+
303+
`New()` is called once per `Generate` invocation, so middleware can hold per-call state without worrying about concurrent use across calls. `Name()` must be unique and stable since it's the key used to register and reference the middleware from the Dev UI and across runtimes.
304+
251305
### Define Flows
252306

253307
Wrap your AI logic in flows for better observability, testing, and deployment:
@@ -256,7 +310,7 @@ Wrap your AI logic in flows for better observability, testing, and deployment:
256310
jokeFlow := genkit.DefineFlow(g, "tellJoke",
257311
func(ctx context.Context, topic string) (string, error) {
258312
return genkit.GenerateText(ctx, g,
259-
ai.WithModelName("googleai/gemini-2.5-flash"),
313+
ai.WithModelName("googleai/gemini-flash-latest"),
260314
ai.WithPrompt("Tell me a joke about %s", topic),
261315
)
262316
},
@@ -276,7 +330,7 @@ Stream data from your flows using Server-Sent Events (SSE):
276330
genkit.DefineStreamingFlow(g, "streamStory",
277331
func(ctx context.Context, topic string, send core.StreamCallback[string]) (string, error) {
278332
stream := genkit.GenerateStream(ctx, g,
279-
ai.WithModelName("googleai/gemini-2.5-flash"),
333+
ai.WithModelName("googleai/gemini-flash-latest"),
280334
ai.WithPrompt("Write a story about %s", topic),
281335
)
282336

@@ -306,14 +360,14 @@ genkit.DefineFlow(g, "processDocument",
306360
// Each Run call creates a traced step visible in the Dev UI
307361
summary, _ := genkit.Run(ctx, "summarize", func() (string, error) {
308362
return genkit.GenerateText(ctx, g,
309-
ai.WithModelName("googleai/gemini-2.5-flash"),
363+
ai.WithModelName("googleai/gemini-flash-latest"),
310364
ai.WithPrompt("Summarize: %s", doc),
311365
)
312366
})
313367

314368
keywords, _ := genkit.Run(ctx, "extractKeywords", func() ([]string, error) {
315369
return genkit.GenerateData[[]string](ctx, g,
316-
ai.WithModelName("googleai/gemini-2.5-flash"),
370+
ai.WithModelName("googleai/gemini-flash-latest"),
317371
ai.WithPrompt("Extract keywords from: %s", summary),
318372
)
319373
})
@@ -331,7 +385,7 @@ Create reusable prompts with Handlebars templating:
331385

332386
```go
333387
greetingPrompt := genkit.DefinePrompt(g, "greeting",
334-
ai.WithModelName("googleai/gemini-2.5-flash"),
388+
ai.WithModelName("googleai/gemini-flash-latest"),
335389
ai.WithPrompt("Write a {{style}} greeting for {{name}}."),
336390
)
337391

@@ -359,7 +413,7 @@ type Joke struct {
359413
}
360414

361415
jokePrompt := genkit.DefineDataPrompt[JokeRequest, *Joke](g, "joke",
362-
ai.WithModelName("googleai/gemini-2.5-flash"),
416+
ai.WithModelName("googleai/gemini-flash-latest"),
363417
ai.WithPrompt("Tell a joke about {{topic}}."),
364418
)
365419

@@ -387,7 +441,7 @@ Keep prompts separate from code using `.prompt` files with YAML frontmatter:
387441
```yaml
388442
# prompts/recipe.prompt
389443
---
390-
model: googleai/gemini-2.5-flash
444+
model: googleai/gemini-flash-latest
391445
input:
392446
schema: RecipeRequest
393447
output:
@@ -553,13 +607,13 @@ import "google.golang.org/genai"
553607

554608
// Simple: just the model name
555609
response, _ := genkit.Generate(ctx, g,
556-
ai.WithModelName("googleai/gemini-2.5-flash"),
610+
ai.WithModelName("googleai/gemini-flash-latest"),
557611
ai.WithPrompt("Hello!"),
558612
)
559613

560614
// Advanced: model name + provider-specific configuration
561615
response, _ := genkit.Generate(ctx, g,
562-
ai.WithModel(googlegenai.ModelRef("googleai/gemini-2.5-flash", &genai.GenerateContentConfig{
616+
ai.WithModel(googlegenai.ModelRef("googleai/gemini-flash-latest", &genai.GenerateContentConfig{
563617
Temperature: genai.Ptr(float32(0.7)),
564618
MaxOutputTokens: genai.Ptr(int32(1000)),
565619
TopP: genai.Ptr(float32(0.9)),
@@ -657,6 +711,9 @@ Explore working examples to see Genkit in action:
657711
| [basic-structured](samples/basic-structured) | Typed JSON output with `GenerateData` and `GenerateDataStream` |
658712
| [basic-prompts](samples/basic-prompts) | Prompt templates with Handlebars and `.prompt` files |
659713
| [intermediate-interrupts](samples/intermediate-interrupts) | Human-in-the-loop with tool interrupts |
714+
| [basic-middleware/retry-fallback](samples/basic-middleware/retry-fallback) | Composing `Retry` and `Fallback` middleware |
715+
| [basic-middleware/filesystem](samples/basic-middleware/filesystem) | Scoped filesystem tools for the model |
716+
| [basic-middleware/skills](samples/basic-middleware/skills) | On-demand loadable `SKILL.md` personas |
660717
| [prompts-embed](samples/prompts-embed) | Embed prompts in your binary |
661718
| [durable-streaming](samples/durable-streaming) | Reconnectable streams with replay |
662719
| [session](samples/session) | Stateful flows with typed session data |

go/ai/document.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ package ai
1919
import (
2020
"encoding/json"
2121
"fmt"
22+
"maps"
23+
"slices"
2224
"strings"
2325
)
2426

@@ -44,6 +46,31 @@ type Part struct {
4446
Metadata map[string]any `json:"metadata,omitempty"` // valid for all kinds
4547
}
4648

49+
// Clone returns a shallow copy of the Part with its own Metadata and Custom
50+
// maps. Callers can add or remove map keys without mutating the original.
51+
func (p *Part) Clone() *Part {
52+
if p == nil {
53+
return nil
54+
}
55+
cp := *p
56+
cp.Custom = maps.Clone(p.Custom)
57+
cp.Metadata = maps.Clone(p.Metadata)
58+
return &cp
59+
}
60+
61+
// Clone returns a shallow copy of the Message with its own Content slice
62+
// and Metadata map. Callers can replace parts or add metadata keys without
63+
// mutating the original.
64+
func (m *Message) Clone() *Message {
65+
if m == nil {
66+
return nil
67+
}
68+
cp := *m
69+
cp.Content = slices.Clone(m.Content)
70+
cp.Metadata = maps.Clone(m.Metadata)
71+
return &cp
72+
}
73+
4774
type PartKind int8
4875

4976
const (

go/ai/document_test.go

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package ai
1818

1919
import (
20+
"bytes"
2021
"encoding/json"
2122
"reflect"
2223
"testing"
@@ -411,3 +412,108 @@ func TestNewResponseForToolRequest(t *testing.T) {
411412
}
412413
})
413414
}
415+
416+
// TestPartClone verifies that Part.Clone produces an independent copy.
417+
// Every Part field is populated so that adding a new field without updating
418+
// this test (and Clone) causes a failure.
419+
func TestPartClone(t *testing.T) {
420+
orig := &Part{
421+
Kind: PartToolRequest,
422+
ContentType: "application/json",
423+
Text: "body",
424+
ToolRequest: &ToolRequest{Name: "tool", Input: map[string]any{"a": 1}},
425+
// Normally a Part wouldn't have both ToolRequest and ToolResponse,
426+
// but we populate everything to catch missing fields.
427+
ToolResponse: &ToolResponse{Name: "tool", Output: "ok"},
428+
Resource: &ResourcePart{Uri: "res://x"},
429+
Custom: map[string]any{"ck": "cv"},
430+
Metadata: map[string]any{"sig": []byte{1, 2, 3}, "key": "val"},
431+
}
432+
433+
// Guard: every field in the fixture must be non-zero.
434+
// If someone adds a new field to Part this will fail, forcing them to
435+
// add it here and verify Clone handles it.
436+
rv := reflect.ValueOf(orig).Elem()
437+
for i := range rv.NumField() {
438+
if rv.Field(i).IsZero() {
439+
t.Fatalf("Part field %q is zero in test fixture — populate it and verify Clone handles it", rv.Type().Field(i).Name)
440+
}
441+
}
442+
443+
cp := orig.Clone()
444+
445+
// Values must match.
446+
if !reflect.DeepEqual(orig, cp) {
447+
t.Fatal("Clone() values differ from original")
448+
}
449+
450+
// Mutating clone's maps must not affect the original.
451+
cp.Metadata["extra"] = true
452+
if _, ok := orig.Metadata["extra"]; ok {
453+
t.Error("mutating clone Metadata affected original")
454+
}
455+
456+
cp.Custom["extra"] = true
457+
if _, ok := orig.Custom["extra"]; ok {
458+
t.Error("mutating clone Custom affected original")
459+
}
460+
461+
// Go types in metadata (e.g. []byte) must be preserved, not string-ified.
462+
sig, ok := cp.Metadata["sig"].([]byte)
463+
if !ok {
464+
t.Fatalf("Metadata[sig] type = %T, want []byte", cp.Metadata["sig"])
465+
}
466+
if !bytes.Equal(sig, []byte{1, 2, 3}) {
467+
t.Errorf("Metadata[sig] = %v, want [1 2 3]", sig)
468+
}
469+
470+
// nil Part.Clone() should return nil.
471+
var nilPart *Part
472+
if nilPart.Clone() != nil {
473+
t.Error("nil Part.Clone() should return nil")
474+
}
475+
}
476+
477+
// TestMessageClone verifies that Message.Clone produces an independent copy.
478+
// Every Message field is populated so that adding a new field without updating
479+
// this test (and Clone) causes a failure.
480+
func TestMessageClone(t *testing.T) {
481+
orig := &Message{
482+
Role: RoleModel,
483+
Content: []*Part{NewTextPart("hello"), NewTextPart("world")},
484+
Metadata: map[string]any{"k": "v"},
485+
}
486+
487+
// Guard: every field must be non-zero.
488+
rv := reflect.ValueOf(orig).Elem()
489+
for i := range rv.NumField() {
490+
if rv.Field(i).IsZero() {
491+
t.Fatalf("Message field %q is zero in test fixture — populate it and verify Clone handles it", rv.Type().Field(i).Name)
492+
}
493+
}
494+
495+
cp := orig.Clone()
496+
497+
// Values must match.
498+
if !reflect.DeepEqual(orig, cp) {
499+
t.Fatal("Clone() values differ from original")
500+
}
501+
502+
// Mutating clone's Content slice must not affect the original.
503+
cp.Content[0] = NewTextPart("replaced")
504+
if orig.Content[0].Text != "hello" {
505+
t.Error("mutating clone Content affected original")
506+
}
507+
508+
// Mutating clone's Metadata must not affect the original.
509+
cp.Metadata["extra"] = true
510+
if _, ok := orig.Metadata["extra"]; ok {
511+
t.Error("mutating clone Metadata affected original")
512+
}
513+
514+
// nil Message.Clone() should return nil.
515+
var nilMsg *Message
516+
if nilMsg.Clone() != nil {
517+
t.Error("nil Message.Clone() should return nil")
518+
}
519+
}

0 commit comments

Comments
 (0)