Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions internal/llminternal/base_flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,10 @@ func (f *Flow) Run(ctx agent.InvocationContext) iter.Seq2[*session.Event, error]
return
}
if lastEvent.LLMResponse.Partial {
// We may have reached max token limit during streaming mode.
// TODO: handle Partial response in model level. CL 781377328
yield(nil, fmt.Errorf("TODO: last event is not final"))
// The last event is a partial streaming response (e.g., reached
// max token limit during streaming, or a sub-agent emitted
// partial events). The turn is complete so we simply return
// instead of looping again.
return
}
}
Expand Down
118 changes: 118 additions & 0 deletions internal/llminternal/base_flow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@
package llminternal

import (
"context"
"errors"
"iter"
"testing"

"github.com/google/go-cmp/cmp"
"google.golang.org/genai"

"google.golang.org/adk/agent"
"google.golang.org/adk/internal/agent/runconfig"
icontext "google.golang.org/adk/internal/context"
"google.golang.org/adk/internal/toolinternal"
"google.golang.org/adk/model"
Expand Down Expand Up @@ -575,3 +579,117 @@ func TestMergeEventActions(t *testing.T) {
})
}
}

// mockStreamingLLM is a mock LLM that returns a canned sequence of responses.
// This simulates the scenario where the model's last streaming chunk has Partial=true,
// for example when reaching a max token limit.
type mockStreamingLLM struct {
responses []*model.LLMResponse
}

func (m *mockStreamingLLM) Name() string { return "mock-llm" }
func (m *mockStreamingLLM) GenerateContent(_ context.Context, _ *model.LLMRequest, _ bool) iter.Seq2[*model.LLMResponse, error] {
return func(yield func(*model.LLMResponse, error) bool) {
for _, r := range m.responses {
if !yield(r, nil) {
return
}
}
}
}

// TestFlowRunPartialLastEvent verifies that Flow.Run does not return an error
// when the last event from runOneStep has Partial=true.
// This is a regression test for https://github.com/google/adk-go/issues/600.
func TestFlowRunPartialLastEvent(t *testing.T) {
tests := []struct {
name string
responses []*model.LLMResponse
wantTexts []string
}{
{
name: "single partial response completes without error",
responses: []*model.LLMResponse{
{
Content: genai.NewContentFromText("Hello", genai.RoleModel),
Partial: true,
},
},
wantTexts: []string{"Hello"},
},
{
name: "multiple partial responses complete without error",
responses: []*model.LLMResponse{
{
Content: genai.NewContentFromText("Hello", genai.RoleModel),
Partial: true,
},
{
Content: genai.NewContentFromText(" World", genai.RoleModel),
Partial: true,
},
},
wantTexts: []string{"Hello", " World"},
},
{
name: "partial followed by non-partial completes without error",
responses: []*model.LLMResponse{
{
Content: genai.NewContentFromText("Hello", genai.RoleModel),
Partial: true,
},
{
Content: genai.NewContentFromText("Hello World", genai.RoleModel),
Partial: false,
},
},
wantTexts: []string{"Hello", "Hello World"},
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
mockAgent, err := agent.New(agent.Config{Name: "test-agent"})
if err != nil {
t.Fatal(err)
}

ctx := runconfig.ToContext(t.Context(), &runconfig.RunConfig{
StreamingMode: runconfig.StreamingModeSSE,
})

invCtx := icontext.NewInvocationContext(ctx, icontext.InvocationContextParams{
Agent: mockAgent,
})

f := &Flow{
Model: &mockStreamingLLM{responses: tc.responses},
RequestProcessors: nil, // no preprocessors needed
}

var gotTexts []string
var gotErr error
for ev, err := range f.Run(invCtx) {
if err != nil {
gotErr = err
break
}
if ev != nil && ev.Content != nil {
for _, p := range ev.Content.Parts {
if p.Text != "" {
gotTexts = append(gotTexts, p.Text)
}
}
}
}

if gotErr != nil {
t.Errorf("Flow.Run() returned unexpected error: %v", gotErr)
}

if diff := cmp.Diff(tc.wantTexts, gotTexts); diff != "" {
t.Errorf("Flow.Run() text mismatch (-want +got):\n%s", diff)
}
})
}
}