|
7 | 7 | # granted to it by virtue of its status as an intergovernmental organisation |
8 | 8 | # nor does it submit to any jurisdiction. |
9 | 9 |
|
10 | | -"""Domain-named dispatch decorators for source execute() overloads. |
11 | | -
|
12 | | -Usage |
13 | | ------ |
14 | | -Decorate multiple ``execute`` overloads on a ``DispatchedSource`` subclass. |
15 | | -Each overload keeps the same method name but is tagged with the argument type |
16 | | -it handles. The class-body namespace trick (``sys._getframe``) allows the |
17 | | -decorators to accumulate all overloads into a single ``_MultiDispatch`` |
18 | | -descriptor despite the Python rule that each new ``def`` overwrites the |
19 | | -previous binding:: |
20 | | -
|
21 | | - from anemoi.datasets.create.dispatch import ( |
22 | | - DispatchedSource, |
23 | | - for_valid_dates, |
24 | | - for_forecast_dates, |
25 | | - for_intervals, |
26 | | - for_forecast_intervals, |
27 | | - ) |
28 | | -
|
29 | | - class MySource(DispatchedSource): |
30 | | -
|
31 | | - @for_valid_dates |
32 | | - def execute(self, argument: ValidDates) -> FieldList: |
33 | | - ... |
34 | | -
|
35 | | - @for_intervals |
36 | | - def execute(self, argument: Intervals) -> FieldList: |
37 | | - ... |
38 | | -
|
39 | | -At call time the argument's concrete type (with MRO fallback) selects the |
40 | | -right overload. Plain ``list[datetime]`` is wrapped in ``ValidDates`` |
41 | | -automatically for backward compatibility. |
42 | | -
|
43 | | -Note |
44 | | ----- |
45 | | -Sources that inherit from both ``LegacySource`` and ``DispatchedSource`` |
46 | | -must list ``DispatchedSource`` *after* ``LegacySource`` in the MRO so that |
47 | | -``_MultiDispatch`` is not shadowed by ``LegacySource.execute``. Alternatively |
48 | | -override ``execute`` explicitly as ``MarsSource`` does for Phase 0. |
49 | | -""" |
50 | | - |
51 | | -from __future__ import annotations |
52 | | - |
53 | | -import sys |
54 | | -from typing import Any |
55 | | - |
56 | | -# --------------------------------------------------------------------------- |
57 | | -# Dispatch descriptor |
58 | | -# --------------------------------------------------------------------------- |
59 | | - |
60 | | - |
61 | | -class _MultiDispatch: |
62 | | - """Method descriptor that dispatches ``execute()`` by argument type. |
63 | | -
|
64 | | - Each instance carries its own ``_registry`` mapping ``argument_type → |
65 | | - callable``. Overloads are registered by the ``@for_*`` decorators at |
66 | | - class-body time via frame inspection. |
67 | | -
|
68 | | - At call time the argument's MRO is walked to find the best match. |
69 | | - Plain ``list`` arguments are wrapped in ``ValidDates`` for backward compat. |
70 | | - """ |
71 | | - |
72 | | - def __init__(self, name: str = "execute") -> None: |
73 | | - self.name = name |
74 | | - self._registry: dict[type, Any] = {} |
75 | | - |
76 | | - def register(self, argument_type: type, fn: Any) -> "_MultiDispatch": |
77 | | - """Register an overload for the given argument type. |
78 | | -
|
79 | | - Parameters |
80 | | - ---------- |
81 | | - argument_type : type |
82 | | - The argument class this overload handles. |
83 | | - fn : callable |
84 | | - The unbound method to call. |
85 | | - """ |
86 | | - self._registry[argument_type] = fn |
87 | | - return self |
88 | | - |
89 | | - def __set_name__(self, owner: type, name: str) -> None: |
90 | | - self.name = name |
91 | | - |
92 | | - def __get__(self, obj: Any, objtype: type | None = None): |
93 | | - if obj is None: |
94 | | - return self |
95 | | - registry = self._registry |
96 | | - obj_class_name = type(obj).__name__ |
97 | | - |
98 | | - def dispatch(argument: Any) -> Any: |
99 | | - from anemoi.datasets.create.arguments import ValidDates |
100 | | - from anemoi.datasets.dates.groups import GroupOfDates |
101 | | - |
102 | | - # Backward compat: plain list[datetime] or GroupOfDates → ValidDates |
103 | | - if isinstance(argument, list): |
104 | | - argument = ValidDates(argument) |
105 | | - elif isinstance(argument, GroupOfDates): |
106 | | - argument = ValidDates(argument.dates) |
107 | | - |
108 | | - arg_type = type(argument) |
109 | | - |
110 | | - # Direct lookup |
111 | | - method = registry.get(arg_type) |
112 | | - if method is not None: |
113 | | - return method(obj, argument) |
114 | | - |
115 | | - # MRO fallback — handles subclasses of registered types |
116 | | - for arg_klass in arg_type.__mro__: |
117 | | - method = registry.get(arg_klass) |
118 | | - if method is not None: |
119 | | - return method(obj, argument) |
120 | | - |
121 | | - registered = [t.__name__ for t in registry] |
122 | | - |
123 | | - from anemoi.datasets.create.arguments import ForecastDates |
124 | | - from anemoi.datasets.create.arguments import ForecastIntervals |
125 | | - |
126 | | - if isinstance(argument, (ForecastDates, ForecastIntervals)): |
127 | | - raise NotImplementedError( |
128 | | - f"'{obj_class_name}' does not support the trajectory layout. " |
129 | | - f"Received {arg_type.__name__} but this source only handles: {registered}." |
130 | | - ) |
131 | | - |
132 | | - raise TypeError( |
133 | | - f"{obj_class_name}.execute() has no overload for argument type " |
134 | | - f"'{arg_type.__name__}'. Registered: {registered}" |
135 | | - ) |
136 | | - |
137 | | - return dispatch |
138 | | - |
139 | | - |
140 | | -# --------------------------------------------------------------------------- |
141 | | -# Decorator factory (frame-inspection approach) |
142 | | -# --------------------------------------------------------------------------- |
143 | | - |
144 | | - |
145 | | -def _make_for_decorator(argument_type: type): |
146 | | - """Return a decorator that registers an execute overload for argument_type. |
147 | | -
|
148 | | - Uses ``sys._getframe(1)`` to inspect the enclosing class body and |
149 | | - accumulate multiple ``@for_*``-decorated methods with the same name into a |
150 | | - single ``_MultiDispatch`` descriptor. |
151 | | -
|
152 | | - Parameters |
153 | | - ---------- |
154 | | - argument_type : type |
155 | | - The argument class this overload handles (e.g. ``ValidDates``). |
156 | | - """ |
157 | | - |
158 | | - def decorator(fn: Any) -> _MultiDispatch: |
159 | | - # Inspect the enclosing class body's local namespace. |
160 | | - frame = sys._getframe(1) |
161 | | - local_ns = frame.f_locals |
162 | | - |
163 | | - existing = local_ns.get(fn.__name__) |
164 | | - if isinstance(existing, _MultiDispatch): |
165 | | - dispatcher = existing |
166 | | - else: |
167 | | - dispatcher = _MultiDispatch(fn.__name__) |
168 | | - |
169 | | - dispatcher.register(argument_type, fn) |
170 | | - return dispatcher |
171 | | - |
172 | | - snake = "".join(f"_{c.lower()}" if c.isupper() else c for c in argument_type.__name__).lstrip("_") |
173 | | - decorator.__name__ = f"for_{snake}" |
174 | | - return decorator |
175 | | - |
176 | | - |
177 | | -# --------------------------------------------------------------------------- |
178 | | -# Module-level decorators (lazy to avoid circular imports at load time) |
179 | | -# --------------------------------------------------------------------------- |
180 | | - |
181 | | -_decorators: tuple | None = None |
182 | | - |
183 | | - |
184 | | -def _ensure_decorators() -> tuple: |
185 | | - global _decorators |
186 | | - if _decorators is None: |
187 | | - from anemoi.datasets.create.arguments import ForecastDates |
188 | | - from anemoi.datasets.create.arguments import ForecastIntervals |
189 | | - from anemoi.datasets.create.arguments import Intervals |
190 | | - from anemoi.datasets.create.arguments import ValidDates |
191 | | - |
192 | | - _decorators = ( |
193 | | - _make_for_decorator(ValidDates), |
194 | | - _make_for_decorator(ForecastDates), |
195 | | - _make_for_decorator(Intervals), |
196 | | - _make_for_decorator(ForecastIntervals), |
197 | | - ) |
198 | | - return _decorators |
199 | | - |
200 | | - |
201 | | -def __getattr__(name: str) -> Any: |
202 | | - _names = ("for_valid_dates", "for_forecast_dates", "for_intervals", "for_forecast_intervals") |
203 | | - if name in _names: |
204 | | - return _ensure_decorators()[_names.index(name)] |
205 | | - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") |
206 | | - |
207 | | - |
208 | | -# --------------------------------------------------------------------------- |
209 | | -# DispatchedSource mixin |
210 | | -# --------------------------------------------------------------------------- |
211 | | - |
212 | | - |
213 | | -class DispatchedSource: |
214 | | - """Mixin that enables ``@for_*`` dispatch on ``execute()``. |
215 | | -
|
216 | | - Subclasses define overloads by decorating methods named ``execute`` |
217 | | - with ``@for_valid_dates``, ``@for_intervals``, etc. Because the |
218 | | - decorators use frame inspection to accumulate overloads, multiple |
219 | | - ``@for_*``-decorated ``def execute`` statements in the same class body |
220 | | - all merge into a single ``_MultiDispatch`` descriptor:: |
221 | | -
|
222 | | - class MySource(Source, DispatchedSource): |
223 | | -
|
224 | | - @for_valid_dates |
225 | | - def execute(self, argument: ValidDates) -> FieldList: |
226 | | - ... |
227 | | -
|
228 | | - @for_intervals |
229 | | - def execute(self, argument: Intervals) -> FieldList: |
230 | | - ... |
231 | | -
|
232 | | - The descriptor is placed on ``DispatchedSource`` itself and is inherited |
233 | | - by all subclasses. Per-class registries are stored on each subclass's |
234 | | - own ``execute`` descriptor (placed there by the decorators). |
235 | | - """ |
236 | | - |
237 | | - pass |
| 10 | +# This module has been superseded by the execute_valid_dates / execute_forecast_dates / |
| 11 | +# execute_intervals / execute_forecast_intervals methods on Source. The frame-inspection |
| 12 | +# trick and _MultiDispatch descriptor are no longer needed. |
0 commit comments