Commit 301db01

chore(py): flatten prompt opts to kwargs (#5062)
1 parent 1baa4a4 commit 301db01

2 files changed: 94 additions & 58 deletions

py/packages/genkit/src/genkit/_ai/_prompt.py

Lines changed: 54 additions & 41 deletions
@@ -32,6 +32,7 @@
     PromptMetadata,
 )
 from pydantic import BaseModel, ConfigDict
+from typing_extensions import Unpack

 from genkit._ai._generate import (
     generate_action,
@@ -295,60 +296,41 @@ async def _ensure_resolved(self) -> None:

     async def __call__(
         self,
-        input: InputT | None = None,
-        opts: PromptGenerateOptions | None = None,
+        input: InputT | dict[str, Any] | None = None,
+        **opts: Unpack[PromptGenerateOptions],
     ) -> ModelResponse[OutputT]:
-        """Execute the prompt and return the response."""
-        await self._ensure_resolved()
-        effective_opts: PromptGenerateOptions = opts if opts else {}
+        """Execute the prompt and return the response.

-        # Extract streaming callback and middleware from opts
-        on_chunk = effective_opts.get('on_chunk')
-        middleware = effective_opts.get('use') or self._use
-        context = effective_opts.get('context')
+        Args:
+            input: Template variables for rendering.
+        """
+        return await self._call_impl(input, opts)  # ty: ignore[invalid-argument-type]  # ty doesn't infer Unpack[TD] as TD in function body (PEP 692 gap)

+    async def _call_impl(
+        self,
+        input: InputT | dict[str, Any] | None,
+        opts: PromptGenerateOptions,
+    ) -> ModelResponse[OutputT]:
+        """Execute the prompt with resolved opts. Used by __call__ and stream."""
+        await self._ensure_resolved()
+        on_chunk = opts.get('on_chunk')
+        middleware = opts.get('use') or self._use
+        context = opts.get('context')
         result = await generate_action(
             self._registry,
-            await self.render(input=input, opts=effective_opts),
+            await self._render_impl(input, opts),
             on_chunk=on_chunk,
             middleware=middleware,
             context=context if context else ActionRunContext._current_context(),  # pyright: ignore[reportPrivateUsage]
         )
-        # Cast to preserve the generic type parameter
         return cast(ModelResponse[OutputT], result)

-    def stream(
-        self,
-        input: InputT | None = None,
-        opts: PromptGenerateOptions | None = None,
-        *,
-        timeout: float | None = None,
-    ) -> ModelStreamResponse[OutputT]:
-        """Stream the prompt execution, returning (stream, response_future)."""
-        effective_opts: PromptGenerateOptions = opts if opts else {}
-        channel: Channel[ModelResponseChunk, ModelResponse[OutputT]] = Channel(timeout=timeout)
-
-        # Create a copy of opts with the streaming callback
-        stream_opts: PromptGenerateOptions = {
-            **effective_opts,
-            'on_chunk': lambda c: channel.send(cast(ModelResponseChunk, c)),
-        }
-
-        resp = self.__call__(input=input, opts=stream_opts)
-        response_future: asyncio.Future[ModelResponse[OutputT]] = asyncio.create_task(resp)
-        channel.set_close_future(response_future)
-
-        return ModelStreamResponse[OutputT](channel=channel, response_future=response_future)
-
-    async def render(
+    async def _render_impl(
         self,
-        input: InputT | dict[str, Any] | None = None,
-        opts: PromptGenerateOptions | None = None,
+        input: InputT | dict[str, Any] | None,
+        opts: PromptGenerateOptions,
     ) -> GenerateActionOptions:
-        """Render the prompt template without executing, returning GenerateActionOptions."""
-        await self._ensure_resolved()
-        if opts is None:
-            opts = cast(PromptGenerateOptions, {})
+        """Render the prompt with resolved opts. Used by render() and _call_impl."""
         output_opts = opts.get('output') or {}
         context = opts.get('context')

@@ -499,6 +481,37 @@ def _or(opt_val: Any, default: Any) -> Any:  # noqa: ANN401
             resume=resume,
         )

+    def stream(
+        self,
+        input: InputT | dict[str, Any] | None = None,
+        *,
+        timeout: float | None = None,
+        **opts: Unpack[PromptGenerateOptions],
+    ) -> ModelStreamResponse[OutputT]:
+        """Stream the prompt execution, returning (stream, response_future)."""
+        channel: Channel[ModelResponseChunk, ModelResponse[OutputT]] = Channel(timeout=timeout)
+        stream_opts: PromptGenerateOptions = {
+            **opts,  # ty doesn't infer Unpack[TD] as TD in function body (PEP 692 gap)
+            'on_chunk': lambda c: channel.send(cast(ModelResponseChunk, c)),
+        }
+        resp = self._call_impl(input, stream_opts)
+        response_future: asyncio.Future[ModelResponse[OutputT]] = asyncio.create_task(resp)
+        channel.set_close_future(response_future)
+
+        return ModelStreamResponse[OutputT](channel=channel, response_future=response_future)
+
+    async def render(
+        self,
+        input: InputT | dict[str, Any] | None = None,
+        **opts: Unpack[PromptGenerateOptions],
+    ) -> GenerateActionOptions:
+        """Render the prompt template without executing, returning GenerateActionOptions.
+
+        Same keyword options as ``__call__`` (see PromptGenerateOptions).
+        """
+        await self._ensure_resolved()
+        return await self._render_impl(input, opts)  # ty: ignore[invalid-argument-type]  # ty doesn't infer Unpack[TD] as TD in function body (PEP 692 gap)
+
     async def as_tool(self) -> Action:
         """Expose this prompt as a tool.
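The refactor hinges on PEP 692: declaring `**opts: Unpack[PromptGenerateOptions]` lets callers pass flattened, individually type-checked keyword arguments while the body still handles a single options dict, which `__call__` forwards to `_call_impl`. A minimal standalone sketch of the pattern, where `GenerateOptions` and `forward` are illustrative stand-ins for `PromptGenerateOptions` and `_call_impl`:

from typing import Any, TypedDict

from typing_extensions import Unpack  # stdlib typing.Unpack on Python 3.12+


class GenerateOptions(TypedDict, total=False):
    """Illustrative stand-in for PromptGenerateOptions."""

    model: str
    config: dict[str, Any]


def forward(opts: GenerateOptions) -> None:
    """Plays the role of _call_impl: receives the options dict whole."""
    print(opts.get('model'), opts.get('config'))


def call(**opts: Unpack[GenerateOptions]) -> None:
    # Callers write call(model=..., config=...) and each key is checked
    # against the TypedDict. Inside the body, though, some checkers (ty,
    # per the inline notes in the diff) do not yet narrow `opts` to
    # GenerateOptions, which is why the commit carries
    # `# ty: ignore[invalid-argument-type]` at the forwarding call sites.
    forward(opts)


call(model='echoModel', config={'temperature': 0.5})

Splitting `__call__` and `render` into thin typed wrappers over `_call_impl` and `_render_impl` also lets `stream()` inject its `on_chunk` callback into the options dict without re-entering the kwargs signature.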

py/packages/genkit/tests/genkit/ai/prompt_test.py

Lines changed: 40 additions & 17 deletions
@@ -85,13 +85,13 @@ async def test_simple_prompt_with_override_config() -> None:

     my_prompt = ai.define_prompt(prompt='hi', config={'banana': True})

-    # New API: pass config via opts parameter - this MERGES with prompt config
-    response = await my_prompt(opts={'config': {'temperature': 12}})
+    # New API: pass config via kwargs — this MERGES with prompt config
+    response = await my_prompt(config={'temperature': 12})

     assert response.text == want_txt

-    # New API: stream also uses opts
-    result = my_prompt.stream(opts={'config': {'temperature': 12}})
+    # New API: stream also uses kwargs
+    result = my_prompt.stream(config={'temperature': 12})

     assert (await result.response).text == want_txt

@@ -243,8 +243,8 @@ async def test_prompt_rendering_dotprompt(

     my_prompt = ai.define_prompt(**prompt)

-    # New API: use opts parameter to pass config and context
-    response = await my_prompt(input, opts={'config': input_option, 'context': context})
+    # New API: use kwargs to pass config and context
+    response = await my_prompt(input, config=input_option, context=context)

     assert response.text == want_rendered

@@ -488,7 +488,7 @@ async def test_config_merge_priority() -> None:
     # New API: runtime config is MERGED with prompt config
     # - temperature: 0.9 (from opts, overrides 0.5)
     # - banana: 'yellow' (from prompt, preserved)
-    rendered = await my_prompt.render(opts={'config': {'temperature': 0.9}})
+    rendered = await my_prompt.render(config={'temperature': 0.9})

     assert rendered.config is not None
     # Config is now a dict after merging
@@ -509,8 +509,8 @@ async def test_opts_can_override_model() -> None:
         prompt='hello',
     )

-    # Override model via opts
-    response = await my_prompt(opts={'model': 'programmableModel'})
+    # Override model via kwargs
+    response = await my_prompt(model='programmableModel')

     # Should use programmableModel, not echoModel
     assert response.text == 'pm response'
@@ -531,8 +531,8 @@ async def test_opts_can_append_messages() -> None:
         Message(role=Role.MODEL, content=[Part(root=TextPart(text='Previous answer'))]),
     ]

-    # Append conversation history via opts
-    rendered = await my_prompt.render(opts={'messages': history_messages})
+    # Append conversation history via kwargs
+    rendered = await my_prompt.render(messages=history_messages)

     # Should have: system + history (2) + user prompt = 4 messages
     assert len(rendered.messages) == 4
@@ -583,13 +583,11 @@ class OutputSchema(BaseModel):
         output_format='text',  # Default to text
     )

-    # Override output via opts
+    # Override output via kwargs
     rendered = await my_prompt.render(
-        opts={
-            'output': {
-                'format': 'json',
-                'schema': OutputSchema,
-            }
+        output={
+            'format': 'json',
+            'schema': OutputSchema,
         }
     )

@@ -599,6 +597,31 @@ class OutputSchema(BaseModel):
     assert rendered.output.json_schema is not None


+@pytest.mark.asyncio
+async def test_executable_prompt_input_positional_opts_as_kwargs() -> None:
+    """ExecutablePrompt: input is positional, opts via kwargs after *."""
+    ai, *_ = setup_test()
+
+    my_prompt = ai.define_prompt(
+        prompt='Recipe for {{cuisine}} {{dish}}',
+        output_format='text',
+    )
+
+    # input = positional (template vars), output = kwarg (opts)
+    rendered = await my_prompt.render(
+        {'cuisine': 'Italian', 'dish': 'pasta'},
+        output={'format': 'text'},
+    )
+
+    # Template vars from input should be in the rendered prompt
+    assert any('Italian' in str(m) for m in rendered.messages)
+    assert any('pasta' in str(m) for m in rendered.messages)
+
+    # output kwarg should be respected
+    assert rendered.output is not None
+    assert rendered.output.format == 'text'
+
+
 # Tests for file-based prompt loading and two-action structure
 @pytest.mark.asyncio
 async def test_file_based_prompt_registers_two_actions() -> None:
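Read together, the test updates double as a migration guide for callers. A condensed before/after sketch, reusing the `ai` fixture and echo model from `test_simple_prompt_with_override_config` above (not standalone; it assumes that test's setup):

my_prompt = ai.define_prompt(prompt='hi', config={'banana': True})

# Before #5062: every option traveled inside a single `opts` dict.
#   response = await my_prompt(opts={'config': {'temperature': 12}})
#   result = my_prompt.stream(opts={'config': {'temperature': 12}})

# After #5062: options are plain keyword arguments; runtime config still
# merges with (rather than replaces) the prompt-level config.
response = await my_prompt(config={'temperature': 12})
result = my_prompt.stream(config={'temperature': 12})

# stream() keeps `timeout` keyword-only and still exposes the final
# response as an awaitable.
final_text = (await result.response).text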
