Skip to content

Commit 6793b38

Browse files
committed
simplify dispatch
1 parent fa65fd1 commit 6793b38

12 files changed

Lines changed: 134 additions & 310 deletions

File tree

src/anemoi/datasets/create/arguments.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
2626
Inheritance:
2727
Intervals ← ValidDates (MRO fallback: sources that only register
28-
ForecastIntervals ← ForecastDates @for_valid_dates / @for_forecast_dates
28+
ForecastIntervals ← ForecastDates execute_valid_dates / execute_forecast_dates
2929
automatically handle the interval subtypes)
3030
3131
Conversion helpers on each class let composite sources transform the caller's
@@ -166,8 +166,8 @@ class Intervals(ValidDates):
166166
"""Archive-resolved accumulation windows.
167167
168168
Subclass of ``ValidDates`` so that sources registering only
169-
``@for_valid_dates`` receive ``Intervals`` via MRO fallback. Sources that
170-
care about the period (e.g. ``MarsSource``) register ``@for_intervals``
169+
``execute_valid_dates`` receive ``Intervals`` via the default fallback. Sources that
170+
care about the period (e.g. ``MarsSource``) override ``execute_intervals``
171171
explicitly to build step-range requests.
172172
173173
Each ``SignedInterval`` carries its own ``base_time`` (the model-run
@@ -203,7 +203,7 @@ def adjust_request(self, interval: Any, request: dict) -> tuple[Any, dict, int]:
203203
Rewrites the request as a full ``date``/``time``/``step`` triplet
204204
anchored on the interval's model-run time. ``interval.base`` must be
205205
set; valid-time-indexed backends (grib_index) handle ``base=None``
206-
intervals in their own ``@for_intervals`` overload rather than going
206+
intervals in their own ``execute_intervals`` override rather than going
207207
through this helper.
208208
209209
Parameters
@@ -220,7 +220,7 @@ def adjust_request(self, interval: Any, request: dict) -> tuple[Any, dict, int]:
220220
"""
221221
assert interval.base is not None, (
222222
f"Intervals.adjust_request requires a basetime; got {interval!r}. "
223-
"Valid-time-indexed sources (e.g. grib_index) must override @for_intervals "
223+
"Valid-time-indexed sources (e.g. grib_index) must override execute_intervals "
224224
"and not call this helper."
225225
)
226226
r = request.copy()
@@ -256,8 +256,8 @@ class ForecastIntervals(ForecastDates):
256256
"""Forecast accumulations: ``(valid_time, basetime, period)`` items plus a flat list of intervals.
257257
258258
Subclass of ``ForecastDates`` so that sources registering only
259-
``@for_forecast_dates`` receive ``ForecastIntervals`` via MRO fallback.
260-
Sources that care about the period register ``@for_forecast_intervals``
259+
``execute_forecast_dates`` receive ``ForecastIntervals`` via the default fallback.
260+
Sources that care about the period override ``execute_forecast_intervals``
261261
explicitly.
262262
263263
Each ``SignedInterval`` already carries its ``base_time`` (the

src/anemoi/datasets/create/dispatch.py

Lines changed: 3 additions & 228 deletions
Original file line numberDiff line numberDiff line change
@@ -7,231 +7,6 @@
77
# granted to it by virtue of its status as an intergovernmental organisation
88
# nor does it submit to any jurisdiction.
99

10-
"""Domain-named dispatch decorators for source execute() overloads.
11-
12-
Usage
13-
-----
14-
Decorate multiple ``execute`` overloads on a ``DispatchedSource`` subclass.
15-
Each overload keeps the same method name but is tagged with the argument type
16-
it handles. The class-body namespace trick (``sys._getframe``) allows the
17-
decorators to accumulate all overloads into a single ``_MultiDispatch``
18-
descriptor despite the Python rule that each new ``def`` overwrites the
19-
previous binding::
20-
21-
from anemoi.datasets.create.dispatch import (
22-
DispatchedSource,
23-
for_valid_dates,
24-
for_forecast_dates,
25-
for_intervals,
26-
for_forecast_intervals,
27-
)
28-
29-
class MySource(DispatchedSource):
30-
31-
@for_valid_dates
32-
def execute(self, argument: ValidDates) -> FieldList:
33-
...
34-
35-
@for_intervals
36-
def execute(self, argument: Intervals) -> FieldList:
37-
...
38-
39-
At call time the argument's concrete type (with MRO fallback) selects the
40-
right overload. Plain ``list[datetime]`` is wrapped in ``ValidDates``
41-
automatically for backward compatibility.
42-
43-
Note
44-
----
45-
Sources that inherit from both ``LegacySource`` and ``DispatchedSource``
46-
must list ``DispatchedSource`` *after* ``LegacySource`` in the MRO so that
47-
``_MultiDispatch`` is not shadowed by ``LegacySource.execute``. Alternatively
48-
override ``execute`` explicitly as ``MarsSource`` does for Phase 0.
49-
"""
50-
51-
from __future__ import annotations
52-
53-
import sys
54-
from typing import Any
55-
56-
# ---------------------------------------------------------------------------
57-
# Dispatch descriptor
58-
# ---------------------------------------------------------------------------
59-
60-
61-
class _MultiDispatch:
62-
"""Method descriptor that dispatches ``execute()`` by argument type.
63-
64-
Each instance carries its own ``_registry`` mapping ``argument_type →
65-
callable``. Overloads are registered by the ``@for_*`` decorators at
66-
class-body time via frame inspection.
67-
68-
At call time the argument's MRO is walked to find the best match.
69-
Plain ``list`` arguments are wrapped in ``ValidDates`` for backward compat.
70-
"""
71-
72-
def __init__(self, name: str = "execute") -> None:
73-
self.name = name
74-
self._registry: dict[type, Any] = {}
75-
76-
def register(self, argument_type: type, fn: Any) -> "_MultiDispatch":
77-
"""Register an overload for the given argument type.
78-
79-
Parameters
80-
----------
81-
argument_type : type
82-
The argument class this overload handles.
83-
fn : callable
84-
The unbound method to call.
85-
"""
86-
self._registry[argument_type] = fn
87-
return self
88-
89-
def __set_name__(self, owner: type, name: str) -> None:
90-
self.name = name
91-
92-
def __get__(self, obj: Any, objtype: type | None = None):
93-
if obj is None:
94-
return self
95-
registry = self._registry
96-
obj_class_name = type(obj).__name__
97-
98-
def dispatch(argument: Any) -> Any:
99-
from anemoi.datasets.create.arguments import ValidDates
100-
from anemoi.datasets.dates.groups import GroupOfDates
101-
102-
# Backward compat: plain list[datetime] or GroupOfDates → ValidDates
103-
if isinstance(argument, list):
104-
argument = ValidDates(argument)
105-
elif isinstance(argument, GroupOfDates):
106-
argument = ValidDates(argument.dates)
107-
108-
arg_type = type(argument)
109-
110-
# Direct lookup
111-
method = registry.get(arg_type)
112-
if method is not None:
113-
return method(obj, argument)
114-
115-
# MRO fallback — handles subclasses of registered types
116-
for arg_klass in arg_type.__mro__:
117-
method = registry.get(arg_klass)
118-
if method is not None:
119-
return method(obj, argument)
120-
121-
registered = [t.__name__ for t in registry]
122-
123-
from anemoi.datasets.create.arguments import ForecastDates
124-
from anemoi.datasets.create.arguments import ForecastIntervals
125-
126-
if isinstance(argument, (ForecastDates, ForecastIntervals)):
127-
raise NotImplementedError(
128-
f"'{obj_class_name}' does not support the trajectory layout. "
129-
f"Received {arg_type.__name__} but this source only handles: {registered}."
130-
)
131-
132-
raise TypeError(
133-
f"{obj_class_name}.execute() has no overload for argument type "
134-
f"'{arg_type.__name__}'. Registered: {registered}"
135-
)
136-
137-
return dispatch
138-
139-
140-
# ---------------------------------------------------------------------------
141-
# Decorator factory (frame-inspection approach)
142-
# ---------------------------------------------------------------------------
143-
144-
145-
def _make_for_decorator(argument_type: type):
146-
"""Return a decorator that registers an execute overload for argument_type.
147-
148-
Uses ``sys._getframe(1)`` to inspect the enclosing class body and
149-
accumulate multiple ``@for_*``-decorated methods with the same name into a
150-
single ``_MultiDispatch`` descriptor.
151-
152-
Parameters
153-
----------
154-
argument_type : type
155-
The argument class this overload handles (e.g. ``ValidDates``).
156-
"""
157-
158-
def decorator(fn: Any) -> _MultiDispatch:
159-
# Inspect the enclosing class body's local namespace.
160-
frame = sys._getframe(1)
161-
local_ns = frame.f_locals
162-
163-
existing = local_ns.get(fn.__name__)
164-
if isinstance(existing, _MultiDispatch):
165-
dispatcher = existing
166-
else:
167-
dispatcher = _MultiDispatch(fn.__name__)
168-
169-
dispatcher.register(argument_type, fn)
170-
return dispatcher
171-
172-
snake = "".join(f"_{c.lower()}" if c.isupper() else c for c in argument_type.__name__).lstrip("_")
173-
decorator.__name__ = f"for_{snake}"
174-
return decorator
175-
176-
177-
# ---------------------------------------------------------------------------
178-
# Module-level decorators (lazy to avoid circular imports at load time)
179-
# ---------------------------------------------------------------------------
180-
181-
_decorators: tuple | None = None
182-
183-
184-
def _ensure_decorators() -> tuple:
185-
global _decorators
186-
if _decorators is None:
187-
from anemoi.datasets.create.arguments import ForecastDates
188-
from anemoi.datasets.create.arguments import ForecastIntervals
189-
from anemoi.datasets.create.arguments import Intervals
190-
from anemoi.datasets.create.arguments import ValidDates
191-
192-
_decorators = (
193-
_make_for_decorator(ValidDates),
194-
_make_for_decorator(ForecastDates),
195-
_make_for_decorator(Intervals),
196-
_make_for_decorator(ForecastIntervals),
197-
)
198-
return _decorators
199-
200-
201-
def __getattr__(name: str) -> Any:
202-
_names = ("for_valid_dates", "for_forecast_dates", "for_intervals", "for_forecast_intervals")
203-
if name in _names:
204-
return _ensure_decorators()[_names.index(name)]
205-
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
206-
207-
208-
# ---------------------------------------------------------------------------
209-
# DispatchedSource mixin
210-
# ---------------------------------------------------------------------------
211-
212-
213-
class DispatchedSource:
214-
"""Mixin that enables ``@for_*`` dispatch on ``execute()``.
215-
216-
Subclasses define overloads by decorating methods named ``execute``
217-
with ``@for_valid_dates``, ``@for_intervals``, etc. Because the
218-
decorators use frame inspection to accumulate overloads, multiple
219-
``@for_*``-decorated ``def execute`` statements in the same class body
220-
all merge into a single ``_MultiDispatch`` descriptor::
221-
222-
class MySource(Source, DispatchedSource):
223-
224-
@for_valid_dates
225-
def execute(self, argument: ValidDates) -> FieldList:
226-
...
227-
228-
@for_intervals
229-
def execute(self, argument: Intervals) -> FieldList:
230-
...
231-
232-
The descriptor is placed on ``DispatchedSource`` itself and is inherited
233-
by all subclasses. Per-class registries are stored on each subclass's
234-
own ``execute`` descriptor (placed there by the decorators).
235-
"""
236-
237-
pass
10+
# This module has been superseded by the execute_valid_dates / execute_forecast_dates /
11+
# execute_intervals / execute_forecast_intervals methods on Source. The frame-inspection
12+
# trick and _MultiDispatch descriptor are no longer needed.

0 commit comments

Comments
 (0)