|
32 | 32 | PromptMetadata, |
33 | 33 | ) |
34 | 34 | from pydantic import BaseModel, ConfigDict |
| 35 | +from typing_extensions import Unpack |
35 | 36 |
|
36 | 37 | from genkit._ai._generate import ( |
37 | 38 | generate_action, |
@@ -295,60 +296,41 @@ async def _ensure_resolved(self) -> None: |
295 | 296 |
|
296 | 297 | async def __call__( |
297 | 298 | self, |
298 | | - input: InputT | None = None, |
299 | | - opts: PromptGenerateOptions | None = None, |
|  | 299 | +        input: InputT | dict[str, Any] | None = None, |
|  | 300 | +        **opts: Unpack[PromptGenerateOptions], |
300 | 301 | ) -> ModelResponse[OutputT]: |
301 | | - """Execute the prompt and return the response.""" |
302 | | - await self._ensure_resolved() |
303 | | - effective_opts: PromptGenerateOptions = opts if opts else {} |
| 302 | + """Execute the prompt and return the response. |
304 | 303 |
|
305 | | - # Extract streaming callback and middleware from opts |
306 | | - on_chunk = effective_opts.get('on_chunk') |
307 | | - middleware = effective_opts.get('use') or self._use |
308 | | - context = effective_opts.get('context') |
| 304 | + Args: |
|  | 305 | +            input: Template variables for rendering, as the typed input model or a plain dict. |
| 306 | + """ |
| 307 | + return await self._call_impl(input, opts) # ty: ignore[invalid-argument-type] # ty doesn't infer Unpack[TD] as TD in function body (PEP 692 gap) |
309 | 308 |
|
| 309 | + async def _call_impl( |
| 310 | + self, |
| 311 | + input: InputT | dict[str, Any] | None, |
| 312 | + opts: PromptGenerateOptions, |
| 313 | + ) -> ModelResponse[OutputT]: |
| 314 | + """Execute the prompt with resolved opts. Used by __call__ and stream.""" |
| 315 | + await self._ensure_resolved() |
| 316 | + on_chunk = opts.get('on_chunk') |
| 317 | + middleware = opts.get('use') or self._use |
| 318 | + context = opts.get('context') |
310 | 319 | result = await generate_action( |
311 | 320 | self._registry, |
312 | | - await self.render(input=input, opts=effective_opts), |
| 321 | + await self._render_impl(input, opts), |
313 | 322 | on_chunk=on_chunk, |
314 | 323 | middleware=middleware, |
315 | 324 | context=context if context else ActionRunContext._current_context(), # pyright: ignore[reportPrivateUsage] |
316 | 325 | ) |
317 | | - # Cast to preserve the generic type parameter |
318 | 326 | return cast(ModelResponse[OutputT], result) |
319 | 327 |
|
320 | | - def stream( |
321 | | - self, |
322 | | - input: InputT | None = None, |
323 | | - opts: PromptGenerateOptions | None = None, |
324 | | - *, |
325 | | - timeout: float | None = None, |
326 | | - ) -> ModelStreamResponse[OutputT]: |
327 | | - """Stream the prompt execution, returning (stream, response_future).""" |
328 | | - effective_opts: PromptGenerateOptions = opts if opts else {} |
329 | | - channel: Channel[ModelResponseChunk, ModelResponse[OutputT]] = Channel(timeout=timeout) |
330 | | - |
331 | | - # Create a copy of opts with the streaming callback |
332 | | - stream_opts: PromptGenerateOptions = { |
333 | | - **effective_opts, |
334 | | - 'on_chunk': lambda c: channel.send(cast(ModelResponseChunk, c)), |
335 | | - } |
336 | | - |
337 | | - resp = self.__call__(input=input, opts=stream_opts) |
338 | | - response_future: asyncio.Future[ModelResponse[OutputT]] = asyncio.create_task(resp) |
339 | | - channel.set_close_future(response_future) |
340 | | - |
341 | | - return ModelStreamResponse[OutputT](channel=channel, response_future=response_future) |
342 | | - |
343 | | - async def render( |
| 328 | + async def _render_impl( |
344 | 329 | self, |
345 | | - input: InputT | dict[str, Any] | None = None, |
346 | | - opts: PromptGenerateOptions | None = None, |
| 330 | + input: InputT | dict[str, Any] | None, |
| 331 | + opts: PromptGenerateOptions, |
347 | 332 | ) -> GenerateActionOptions: |
348 | | - """Render the prompt template without executing, returning GenerateActionOptions.""" |
349 | | - await self._ensure_resolved() |
350 | | - if opts is None: |
351 | | - opts = cast(PromptGenerateOptions, {}) |
| 333 | + """Render the prompt with resolved opts. Used by render() and _call_impl.""" |
352 | 334 | output_opts = opts.get('output') or {} |
353 | 335 | context = opts.get('context') |
354 | 336 |
|
@@ -499,6 +481,37 @@ def _or(opt_val: Any, default: Any) -> Any: # noqa: ANN401 |
499 | 481 | resume=resume, |
500 | 482 | ) |
501 | 483 |
|
| 484 | + def stream( |
| 485 | + self, |
| 486 | + input: InputT | dict[str, Any] | None = None, |
| 487 | + *, |
| 488 | + timeout: float | None = None, |
| 489 | + **opts: Unpack[PromptGenerateOptions], |
| 490 | + ) -> ModelStreamResponse[OutputT]: |
| 491 | + """Stream the prompt execution, returning (stream, response_future).""" |
| 492 | + channel: Channel[ModelResponseChunk, ModelResponse[OutputT]] = Channel(timeout=timeout) |
| 493 | + stream_opts: PromptGenerateOptions = { |
|  | 494 | +            **opts,  # valid at runtime (opts is PromptGenerateOptions); ty can't see this — Unpack[TD] isn't narrowed to TD in the body (PEP 692 gap) |
| 495 | + 'on_chunk': lambda c: channel.send(cast(ModelResponseChunk, c)), |
| 496 | + } |
| 497 | + resp = self._call_impl(input, stream_opts) |
| 498 | + response_future: asyncio.Future[ModelResponse[OutputT]] = asyncio.create_task(resp) |
| 499 | + channel.set_close_future(response_future) |
| 500 | + |
| 501 | + return ModelStreamResponse[OutputT](channel=channel, response_future=response_future) |
| 502 | + |
| 503 | + async def render( |
| 504 | + self, |
| 505 | + input: InputT | dict[str, Any] | None = None, |
| 506 | + **opts: Unpack[PromptGenerateOptions], |
| 507 | + ) -> GenerateActionOptions: |
| 508 | + """Render the prompt template without executing, returning GenerateActionOptions. |
| 509 | +
|
| 510 | + Same keyword options as ``__call__`` (see PromptGenerateOptions). |
| 511 | + """ |
| 512 | + await self._ensure_resolved() |
| 513 | + return await self._render_impl(input, opts) # ty: ignore[invalid-argument-type] # ty doesn't infer Unpack[TD] as TD in function body (PEP 692 gap) |
| 514 | + |
502 | 515 | async def as_tool(self) -> Action: |
503 | 516 | """Expose this prompt as a tool. |
504 | 517 |
|
|
0 commit comments