-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsse_contracts.py
More file actions
70 lines (52 loc) · 2.5 KB
/
sse_contracts.py
File metadata and controls
70 lines (52 loc) · 2.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""Small contract helpers for SSE-only route handlers.

Keep this module free of FastHTML imports so MyPy can analyze it in isolation.
The app imports these helpers and applies them only to routes that must
always return `EventSourceResponse`.
"""

# The docstring must precede every statement (a `__future__` import included)
# or Python does not store it in the module's `__doc__`.
from __future__ import annotations

import asyncio
import inspect
from functools import wraps
from typing import Any, Awaitable, Callable, ParamSpec, get_args, get_origin, get_type_hints

from sse_starlette.sse import EventSourceResponse

# Parameter specification of a decorated route, so the decorator's declared
# type keeps the route's own parameter list.
P = ParamSpec("P")

# A callable taking `P`'s parameters whose result is an awaitable producing an
# `EventSourceResponse`.
SSEHandler = Callable[P, Awaitable[EventSourceResponse]]

# (qualname, undecorated function) pairs appended by `sse_route_contract` and
# checked by `validate_sse_route_contracts`.
_SSE_ROUTE_CONTRACTS: list[tuple[str, Callable[..., Any]]] = []
def _is_eventsource_annotation(anno: Any) -> bool:
    """Return True when `anno` is, subclasses, or contains `EventSourceResponse`.

    Subclasses count as a match so this static check agrees with the runtime
    ``isinstance`` check performed by the ``sse_route_contract`` wrapper.
    Parameterized generics and unions (anything ``typing.get_origin``
    recognizes, e.g. ``X | None``) are searched recursively through their
    ``typing.get_args`` arguments. A missing annotation
    (``inspect.Signature.empty``) or an unresolved string annotation is never
    a match.
    """
    if anno is inspect.Signature.empty:
        return False
    if get_origin(anno) is not None:
        # Generic alias / union: match if any of its type arguments matches.
        return any(_is_eventsource_annotation(arg) for arg in get_args(anno))
    # Generics are handled above before `issubclass`: some Python versions
    # report `isinstance(list[int], type)` as True even though `issubclass`
    # rejects such aliases.
    return isinstance(anno, type) and issubclass(anno, EventSourceResponse)
def sse_route_contract(fn: SSEHandler[P]) -> SSEHandler[P]:
    """Mark an endpoint as SSE-only and assert its runtime return type.

    The undecorated ``fn`` is recorded in the module registry so that
    ``validate_sse_route_contracts`` can later check its return annotation.

    Args:
        fn: The ``async def`` route handler to wrap.

    Returns:
        An async wrapper (metadata copied with ``functools.wraps``) that awaits
        ``fn`` and returns its result unchanged. The wrapper carries
        ``__sse_route_contract__ = True`` and ``__sse_route_contract_name__``
        set to ``fn``'s qualified name.

    Raises:
        TypeError: At decoration time if ``fn`` is not a coroutine function;
            at call time if the awaited result is not an ``EventSourceResponse``.
    """
    # `inspect.iscoroutinefunction` replaces `asyncio.iscoroutinefunction`,
    # which is deprecated as of Python 3.14.
    if not inspect.iscoroutinefunction(fn):
        raise TypeError(f"SSE route {fn.__qualname__} must be async def")
    _SSE_ROUTE_CONTRACTS.append((fn.__qualname__, fn))

    @wraps(fn)
    async def _wrapped(*args: P.args, **kwargs: P.kwargs) -> EventSourceResponse:
        result = await fn(*args, **kwargs)
        if not isinstance(result, EventSourceResponse):
            raise TypeError(
                f"SSE route {fn.__qualname__} must return EventSourceResponse, got {type(result).__name__}"
            )
        return result

    _wrapped.__sse_route_contract__ = True  # type: ignore[attr-defined]
    _wrapped.__sse_route_contract_name__ = fn.__qualname__  # type: ignore[attr-defined]
    return _wrapped
def validate_sse_route_contracts(contracts: list[tuple[str, Callable[..., Any]]] | None = None) -> None:
    """Fail fast if a marked SSE route is not annotated as SSE-only.

    Args:
        contracts: ``(qualname, fn)`` pairs to check. ``None`` (the default)
            checks every route registered by ``sse_route_contract``; an
            explicitly passed list, even an empty one, is used as-is.

    Raises:
        RuntimeError: If a route's return annotation is missing or does not
            resolve to (or contain) ``EventSourceResponse``. If resolving the
            route's annotations failed, that exception is chained as the cause.
    """
    # `is None` rather than `or`: an explicit empty list must not fall back to
    # validating the whole registry.
    entries = _SSE_ROUTE_CONTRACTS if contracts is None else contracts
    for qualname, fn in entries:
        # Evaluate string annotations against the route's own module globals,
        # where the names used in its annotations are expected to live. These
        # are layered over this module's globals so `EventSourceResponse` still
        # resolves if the route's module does not bind it at runtime.
        namespace = {**globals(), **getattr(fn, "__globals__", {})}
        resolve_error: Exception | None = None
        try:
            hints = get_type_hints(fn, globalns=namespace)
        except Exception as exc:  # any annotation that cannot be evaluated
            resolve_error = exc
            hints = {}
        # Fall back to the raw annotation, which may be an unresolved string.
        return_anno = hints.get("return", inspect.signature(fn).return_annotation)
        if not _is_eventsource_annotation(return_anno):
            raise RuntimeError(
                f"SSE route {qualname} must be annotated to return EventSourceResponse, got {return_anno!r}"
            ) from resolve_error