Skip to content

Commit 0ded992

Browse files
b8raoult and mcocdawc authored
feat: pydantic recipes (#575)
## Description Use Pydantic to load the recipe files, including models that may be defined in sources, filters and other plugins. To add a pydantic model to a filter: ```python class Schema(BaseModel): fermi: str = "paradox" kardashev_scale: int = 3 @filter_registry.register("the-great-filter") class TheGreatFilter(Filter): schema = Schema def __init__(self, fermi, kardashev_scale): .... ``` ## What problem does this change solve? <!-- Describe if it's a bugfix, new feature, doc update, or breaking change --> ## What issue or task does this change relate to? <!-- link to Issue Number --> ## Additional notes ## <!-- Include any additional information, caveats, or considerations that the reviewer should be aware of. --> ***As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/*** By opening this pull request, I affirm that all authors agree to the [Contributor License Agreement.](https://github.com/ecmwf/codex/blob/main/Legal/contributor_license_agreement.md) --------- Co-authored-by: Oskar Weser <oskar.weser@gmail.com>
1 parent c63243e commit 0ded992

14 files changed

Lines changed: 458 additions & 126 deletions

File tree

src/anemoi/datasets/commands/inspect.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,13 +740,18 @@ class Version0_14(Version0_13):
740740
pass
741741

742742

743+
class Version0_15(Version0_14):
744+
pass
745+
746+
743747
VERSIONS = {
744748
"0.0.0": NoVersion,
745749
"0.4.0": Version0_4,
746750
"0.6.0": Version0_6,
747751
"0.12.0": Version0_12,
748752
"0.13.0": Version0_13,
749753
"0.14.0": Version0_14,
754+
"0.15.0": Version0_15,
750755
}
751756

752757

src/anemoi/datasets/commands/recipe/dump.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def dump_recipe(config: dict, dumper=None) -> str:
4141
recipe = Recipe(**config)
4242
input = InputBuilder(
4343
recipe.input,
44-
data_sources=recipe.data_sources or {},
44+
data_sources=recipe.data_sources,
4545
)
4646
if dumper is None:
4747
dumper = Dumper()

src/anemoi/datasets/create/creator.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import tqdm
2222
import zarr
2323
from anemoi.utils.humanize import bytes_to_human
24+
from anemoi.utils.sanitise import sanitise
2425

2526
from anemoi.datasets.create.dataset import Dataset
2627
from anemoi.datasets.create.input.builder import InputBuilder
@@ -32,7 +33,7 @@
3233

3334
LOG = logging.getLogger(__name__)
3435

35-
VERSION = "0.14"
36+
VERSION = "0.15"
3637

3738
LOG = logging.getLogger(__name__)
3839

@@ -177,9 +178,10 @@ def fill_metadata(self, metadata: dict) -> None:
177178
model_dump = json.loads(model_dump)
178179
model_dump = self.recipe.strip_unknown_keys(model_dump)
179180

180-
# TODO: make it an option
181-
# recipe = sanitise(model_dump)
182-
recipe = model_dump
181+
if self.recipe.output.sanitise:
182+
recipe = sanitise(model_dump)
183+
else:
184+
recipe = model_dump
183185

184186
# Remove stuff added by prepml
185187
allow_keys = set(model_dump.keys())
@@ -348,7 +350,7 @@ def task_finalise_additions(self) -> None:
348350
@cached_property
349351
def groups(self) -> Groups:
350352
"""Return the date groups for the dataset."""
351-
return Groups(**self.recipe.dates, group_by=self.recipe.build.group_by)
353+
return Groups(self.recipe.dates, group_by=self.recipe.build.group_by)
352354

353355
@cached_property
354356
def minimal_input(self) -> Any:
@@ -362,7 +364,7 @@ def input(self) -> InputBuilder:
362364

363365
return InputBuilder(
364366
self.recipe.input,
365-
data_sources=self.recipe.data_sources or {},
367+
data_sources=self.recipe.data_sources,
366368
)
367369

368370
@cached_property

src/anemoi/datasets/create/input/action.py

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010
import logging
1111
from abc import ABC
1212
from abc import abstractmethod
13+
from functools import cache
1314

14-
from anemoi.datasets.dates import DatesProvider
15+
from anemoi.datasets.create.recipe.dates import StartEndDates
1516

1617
LOG = logging.getLogger(__name__)
1718

@@ -79,7 +80,7 @@ def __init__(self, config, *path):
7980
for i, item in enumerate(config):
8081

8182
dates = item["dates"]
82-
filtering_dates = DatesProvider.from_config(**dates)
83+
filtering_dates = StartEndDates(**dates)
8384
action = action_factory({k: v for k, v in item.items() if k != "dates"}, *self.path, str(i))
8485
self.choices.append((filtering_dates, action))
8586

@@ -308,48 +309,55 @@ def __call__(self, context, argument):
308309
return self.input(context, argument)
309310

310311

311-
KLASS = {
312-
"concat": Concat,
313-
"join": Join,
314-
"pipe": Pipe,
315-
"data-sources": DataSources,
316-
}
312+
@cache
313+
def _action_factories():
317314

318-
LEN_KLASS = len(KLASS)
315+
factories = {
316+
"concat": Concat,
317+
"join": Join,
318+
"pipe": Pipe,
319+
"data-sources": DataSources,
320+
}
319321

322+
# Load plugins
323+
from anemoi.transform.filters import filter_registry as transform_filter_registry
324+
from anemoi.transform.sources import source_registry as transform_source_registry
320325

321-
def make(key, config, *path):
326+
from anemoi.datasets.create.sources import source_registry as dataset_source_registry
327+
328+
# Register sources, local first
329+
for name in dataset_source_registry.registered:
330+
if name not in factories:
331+
factories[name.replace("_", "-")] = new_source(name, DatasetSourceMixin)
322332

323-
if LEN_KLASS == len(KLASS):
333+
for name in transform_source_registry.registered:
334+
if name not in factories:
335+
factories[name.replace("_", "-")] = new_source(name, TransformSourceMixin)
324336

325-
# Load pluggins
326-
from anemoi.transform.filters import filter_registry as transform_filter_registry
327-
from anemoi.transform.sources import source_registry as transform_source_registry
337+
# Register filters
338+
for name in transform_filter_registry.registered:
339+
if name not in factories:
340+
factories[name.replace("_", "-")] = new_filter(name, TransformFilterMixin)
328341

329-
from anemoi.datasets.create.sources import source_registry as dataset_source_registry
342+
return factories
330343

331-
# Register sources, local first
332-
for name in dataset_source_registry.registered:
333-
if name not in KLASS:
334-
KLASS[name.replace("_", "-")] = new_source(name, DatasetSourceMixin)
335344

336-
for name in transform_source_registry.registered:
337-
if name not in KLASS:
338-
KLASS[name.replace("_", "-")] = new_source(name, TransformSourceMixin)
345+
def make(key, config, *path):
339346

340-
# Register filters
341-
for name in transform_filter_registry.registered:
342-
if name not in KLASS:
343-
KLASS[name.replace("_", "-")] = new_filter(name, TransformFilterMixin)
347+
factories = _action_factories()
344348

345-
return KLASS[key.replace("_", "-")](config, *path)
349+
return factories[key.replace("_", "-")](config, *path)
346350

347351

348352
def action_factory(data, *path):
353+
from pydantic import BaseModel
349354

350355
assert len(path) > 0, f"Path must contain at least one element {path}"
351356
assert path[0] in ("input", "data_sources")
352357

358+
if isinstance(data, BaseModel):
359+
data = data.model_dump()
360+
353361
assert isinstance(data, dict), f"Input data must be a dictionary, got {type(data)}"
354362
assert len(data) == 1, f"Input data must contain exactly one key-value pair {data} {'.'.join(x for x in path)}"
355363

src/anemoi/datasets/create/input/builder.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,33 +7,37 @@
77
# granted to it by virtue of its status as an intergovernmental organisation
88
# nor does it submit to any jurisdiction.
99

10+
from __future__ import annotations
11+
1012
from copy import deepcopy
1113
from functools import cached_property
1214
from typing import TYPE_CHECKING
1315
from typing import Any
1416

1517
if TYPE_CHECKING:
18+
from anemoi.datasets.create.input.action import Action
1619
from anemoi.datasets.create.input.action import Recipe
1720

1821

1922
class InputBuilder:
2023
"""Builder class for creating input data from configuration and data sources."""
2124

22-
def __init__(self, config: dict, data_sources: dict | list, **kwargs: Any) -> None:
25+
def __init__(self, config: "Action" | None, data_sources: "Action" | None, **kwargs: Any) -> None:
2326
"""Initialize the InputBuilder.
2427
2528
Parameters
2629
----------
27-
config : dict
28-
Configuration dictionary.
29-
data_sources : dict
30-
Data sources.
30+
config : Action | None
31+
Input part of the recipe.
32+
data_sources : Action | None
33+
Data sources part of the recipe.
3134
**kwargs : Any
3235
Additional keyword arguments.
3336
"""
3437
self.kwargs = kwargs
35-
self.config = deepcopy(config)
36-
self.data_sources = deepcopy(dict(data_sources=data_sources))
38+
# self.config = deepcopy(config.model_dump() if config is not None else {"empty": {}})
39+
self.config = deepcopy(config.model_dump() if config is not None else {})
40+
self.data_sources = deepcopy(dict(data_sources=data_sources or {}))
3741

3842
@cached_property
3943
def action(self) -> "Recipe":

src/anemoi/datasets/create/recipe/__init__.py

Lines changed: 51 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,35 +9,27 @@
99

1010
import json
1111
import logging
12-
from typing import Annotated
1312

1413
import yaml
15-
from anemoi.utils.config import DotDict
1614
from pydantic import BaseModel
17-
from pydantic import BeforeValidator
1815
from pydantic import ConfigDict
1916
from pydantic import Field
2017
from pydantic import model_validator
2118

19+
from .action import Action
2220
from .build import Build
21+
from .dates import Dates
2322
from .output import GriddedOutput
2423
from .output import Output
2524
from .statistics import Statistics
2625

2726
LOG = logging.getLogger(__name__)
2827

2928

30-
def validate_dotdict(v):
31-
if isinstance(v, dict):
32-
return DotDict(v)
33-
return v
34-
35-
36-
DotDictField = Annotated[DotDict, BeforeValidator(validate_dotdict)]
37-
38-
3929
class Recipe(BaseModel):
4030

31+
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
32+
4133
@model_validator(mode="after")
4234
def _post_init(self) -> "Recipe":
4335
# We need to call _post_init on nested BaseModel members
@@ -48,27 +40,26 @@ def _post_init(self) -> "Recipe":
4840
member._post_init(self)
4941
return self
5042

51-
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
52-
5343
description: str = "No description provided."
5444
licence: str = "unknown"
5545
attribution: str = "unknown"
5646

57-
dates: DotDictField
47+
dates: Dates
5848
"""The date configuration for the dataset."""
5949

60-
input: DotDictField
50+
input: Action | None = None
6151
"""The input data sources configuration."""
6252

63-
data_sources: list[DotDictField] | DotDictField | None = None
53+
data_sources: dict[str, Action] | list[Action] | None = None
6454
"""The data sources configuration."""
6555

6656
output: Output = Field(default_factory=GriddedOutput)
6757
"""The output configuration."""
6858

6959
build: Build = Build()
7060
"""The build configuration."""
71-
additions: DotDictField | None = Field(
61+
62+
additions: dict | None = Field(
7263
default=None,
7364
deprecated="Top-level 'additions' is deprecated. Use 'statistics.tendencies' instead.",
7465
)
@@ -80,9 +71,48 @@ def _post_init(self) -> "Recipe":
8071
deprecated="Top-level 'env' is deprecated. Please use 'build.env' instead.",
8172
)
8273

74+
def only_non_defaults(self, data: dict) -> dict:
75+
"""Return a dictionary containing only non-default values from the recipe.
76+
77+
Parameters
78+
----------
79+
data : dict
80+
The recipe data as a dictionary.
81+
82+
Returns
83+
-------
84+
dict
85+
A dictionary containing only non-default values.
86+
"""
87+
88+
defaults = Recipe(dates={"values": []}).model_dump()
89+
90+
def _only_non_defaults(d, default_d):
91+
92+
if type(d) is not type(default_d):
93+
return d
94+
95+
if isinstance(d, dict):
96+
res = d.copy()
97+
for k, v in list(d.items()):
98+
if k not in default_d:
99+
del res[k]
100+
continue
101+
102+
if v == default_d[k]:
103+
del res[k]
104+
continue
105+
106+
res[k] = _only_non_defaults(v, default_d[k])
107+
return res
108+
109+
return d
110+
111+
return _only_non_defaults(data, defaults)
112+
83113
def strip_unknown_keys(self, data: dict) -> dict:
84114
assert isinstance(data, dict)
85-
defaults = Recipe(dates={}, input={}).model_dump()
115+
defaults = Recipe(input={"empty": {}}, dates={"values": []}).model_dump()
86116
return {key: data[key] for key in defaults.keys()}
87117

88118

@@ -99,8 +129,6 @@ def loader_recipe_from_yaml(path: str) -> dict:
99129
dict
100130
The dataset recipe.
101131
"""
102-
LOG.info(f"Loading recipe from YAML file at {path}")
103-
104132
with open(path) as f:
105133
recipe_yaml = f.read()
106134
recipe = yaml.safe_load(recipe_yaml)
@@ -122,12 +150,12 @@ def loader_recipe_from_zarr(path: str) -> dict:
122150
"""
123151
import zarr
124152

125-
LOG.info(f"Loading recipe from Zarr store at {path}")
126-
127153
z = zarr.open(path, mode="r")
128154

129155
for name in ("_recipe", "recipe"):
130156
if name not in z.attrs:
157+
# return None
158+
LOG.error(f"No '{name}' found in Zarr store at {path}")
131159
continue
132160

133161
recipe = z.attrs[name]

0 commit comments

Comments
 (0)