|
10 | 10 | import logging |
11 | 11 | from abc import ABC |
12 | 12 | from abc import abstractmethod |
| 13 | +from functools import cache |
13 | 14 |
|
14 | | -from anemoi.datasets.dates import DatesProvider |
| 15 | +from anemoi.datasets.create.recipe.dates import StartEndDates |
15 | 16 |
|
16 | 17 | LOG = logging.getLogger(__name__) |
17 | 18 |
|
@@ -79,7 +80,7 @@ def __init__(self, config, *path): |
79 | 80 | for i, item in enumerate(config): |
80 | 81 |
|
81 | 82 | dates = item["dates"] |
82 | | - filtering_dates = DatesProvider.from_config(**dates) |
| 83 | + filtering_dates = StartEndDates(**dates) |
83 | 84 | action = action_factory({k: v for k, v in item.items() if k != "dates"}, *self.path, str(i)) |
84 | 85 | self.choices.append((filtering_dates, action)) |
85 | 86 |
|
@@ -308,48 +309,55 @@ def __call__(self, context, argument): |
308 | 309 | return self.input(context, argument) |
309 | 310 |
|
310 | 311 |
|
311 | | -KLASS = { |
312 | | - "concat": Concat, |
313 | | - "join": Join, |
314 | | - "pipe": Pipe, |
315 | | - "data-sources": DataSources, |
316 | | -} |
| 312 | +@cache |
| 313 | +def _action_factories(): |
317 | 314 |
|
318 | | -LEN_KLASS = len(KLASS) |
| 315 | + factories = { |
| 316 | + "concat": Concat, |
| 317 | + "join": Join, |
| 318 | + "pipe": Pipe, |
| 319 | + "data-sources": DataSources, |
| 320 | + } |
319 | 321 |
|
| 322 | + # Load pluggins |
| 323 | + from anemoi.transform.filters import filter_registry as transform_filter_registry |
| 324 | + from anemoi.transform.sources import source_registry as transform_source_registry |
320 | 325 |
|
321 | | -def make(key, config, *path): |
| 326 | + from anemoi.datasets.create.sources import source_registry as dataset_source_registry |
| 327 | + |
| 328 | + # Register sources, local first |
| 329 | + for name in dataset_source_registry.registered: |
| 330 | + if name not in factories: |
| 331 | + factories[name.replace("_", "-")] = new_source(name, DatasetSourceMixin) |
322 | 332 |
|
323 | | - if LEN_KLASS == len(KLASS): |
| 333 | + for name in transform_source_registry.registered: |
| 334 | + if name not in factories: |
| 335 | + factories[name.replace("_", "-")] = new_source(name, TransformSourceMixin) |
324 | 336 |
|
325 | | - # Load pluggins |
326 | | - from anemoi.transform.filters import filter_registry as transform_filter_registry |
327 | | - from anemoi.transform.sources import source_registry as transform_source_registry |
| 337 | + # Register filters |
| 338 | + for name in transform_filter_registry.registered: |
| 339 | + if name not in factories: |
| 340 | + factories[name.replace("_", "-")] = new_filter(name, TransformFilterMixin) |
328 | 341 |
|
329 | | - from anemoi.datasets.create.sources import source_registry as dataset_source_registry |
| 342 | + return factories |
330 | 343 |
|
331 | | - # Register sources, local first |
332 | | - for name in dataset_source_registry.registered: |
333 | | - if name not in KLASS: |
334 | | - KLASS[name.replace("_", "-")] = new_source(name, DatasetSourceMixin) |
335 | 344 |
|
336 | | - for name in transform_source_registry.registered: |
337 | | - if name not in KLASS: |
338 | | - KLASS[name.replace("_", "-")] = new_source(name, TransformSourceMixin) |
| 345 | +def make(key, config, *path): |
339 | 346 |
|
340 | | - # Register filters |
341 | | - for name in transform_filter_registry.registered: |
342 | | - if name not in KLASS: |
343 | | - KLASS[name.replace("_", "-")] = new_filter(name, TransformFilterMixin) |
| 347 | + factories = _action_factories() |
344 | 348 |
|
345 | | - return KLASS[key.replace("_", "-")](config, *path) |
| 349 | + return factories[key.replace("_", "-")](config, *path) |
346 | 350 |
|
347 | 351 |
|
348 | 352 | def action_factory(data, *path): |
| 353 | + from pydantic import BaseModel |
349 | 354 |
|
350 | 355 | assert len(path) > 0, f"Path must contain at least one element {path}" |
351 | 356 | assert path[0] in ("input", "data_sources") |
352 | 357 |
|
| 358 | + if isinstance(data, BaseModel): |
| 359 | + data = data.model_dump() |
| 360 | + |
353 | 361 | assert isinstance(data, dict), f"Input data must be a dictionary, got {type(data)}" |
354 | 362 | assert len(data) == 1, f"Input data must contain exactly one key-value pair {data} {'.'.join(x for x in path)}" |
355 | 363 |
|
|
0 commit comments