Skip to content

Commit a5541bf

Browse files
committed
implement also concat for trajectories
1 parent 854c3a9 commit a5541bf

3 files changed

Lines changed: 59 additions & 20 deletions

File tree

src/anemoi/datasets/create/gridded/context.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,14 @@ def create_result(self, argument: Any, data: Any) -> SimpleGriddedResult:
6767
"""
6868
return SimpleGriddedResult(self, argument, data)
6969

70-
def matching_dates(self, filtering_dates: Any, group_of_dates: Any) -> GroupOfDates:
71-
"""Find dates that match between filtering_dates and group_of_dates.
70+
def matching_dates(self, filters: dict, group_of_dates: Any) -> GroupOfDates:
71+
"""Find dates that match between filters and group_of_dates.
7272
7373
Parameters
7474
----------
75-
filtering_dates : Any
76-
The dates to filter by.
75+
filters : dict
76+
A dict mapping filter keys to DatesProvider objects.
77+
Gridded layouts only support ``'dates'``.
7778
group_of_dates : Any
7879
The group of dates to compare against.
7980
@@ -82,7 +83,14 @@ def matching_dates(self, filtering_dates: Any, group_of_dates: Any) -> GroupOfDa
8283
GroupOfDates
8384
A GroupOfDates object containing the intersection of the two sets.
8485
"""
85-
86+
unsupported = set(filters) - {"dates"}
87+
if unsupported:
88+
raise ValueError(
89+
f"Gridded layout does not support filtering by {unsupported}. "
90+
"Use 'dates' instead."
91+
)
92+
93+
filtering_dates = filters["dates"]
8694
return GroupOfDates(sorted(set(group_of_dates) & set(filtering_dates)), group_of_dates.provider)
8795

8896
def origin(self, data: Any, action: Any, action_arguments: Any) -> Any:

src/anemoi/datasets/create/input/action.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ class Concat(Action):
6969
7070
"""
7171

72+
FILTER_KEYS = ("dates", "base_dates", "steps")
73+
7274
def __init__(self, config, *path):
7375
super().__init__(config, *path, "concat")
7476

@@ -78,10 +80,24 @@ def __init__(self, config, *path):
7880

7981
for i, item in enumerate(config):
8082

81-
dates = item["dates"]
82-
filtering_dates = DatesProvider.from_config(**dates)
83-
action = action_factory({k: v for k, v in item.items() if k != "dates"}, *self.path, str(i))
84-
self.choices.append((filtering_dates, action))
83+
filters = {}
84+
for key in self.FILTER_KEYS:
85+
if key in item:
86+
if key == "steps":
87+
from anemoi.datasets.create.trajectories.context import Steps
88+
89+
filters[key] = Steps(item[key])
90+
else:
91+
filters[key] = DatesProvider.from_config(**item[key])
92+
93+
if not filters:
94+
raise ValueError(
95+
f"Concat entry must have at least one of {self.FILTER_KEYS}."
96+
)
97+
98+
action_config = {k: v for k, v in item.items() if k not in self.FILTER_KEYS}
99+
action = action_factory(action_config, *self.path, str(i))
100+
self.choices.append((filters, action))
85101

86102
def __repr__(self):
87103
return f"Concat({self.choices})"
@@ -90,8 +106,8 @@ def __call__(self, context, argument):
90106

91107
results = []
92108

93-
for filtering_dates, action in self.choices:
94-
dates = context.matching_dates(filtering_dates, argument)
109+
for filters, action in self.choices:
110+
dates = context.matching_dates(filters, argument)
95111
if len(dates) == 0:
96112
continue
97113
results.append(action(context, dates))
@@ -102,7 +118,7 @@ def __call__(self, context, argument):
102118

103119
def dump(self, dumper):
104120
return dumper.concat(
105-
{filtering_dates.dump(dumper): action.dump(dumper) for filtering_dates, action in self.choices}
121+
{list(filters.values())[0].dump(dumper): action.dump(dumper) for filters, action in self.choices}
106122
)
107123

108124

src/anemoi/datasets/create/trajectories/context.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -74,26 +74,41 @@ def __init__(self, recipe: Any) -> None:
7474
def create_result(self, argument: Any, data: Any) -> TrajectoryGriddedResult:
7575
return TrajectoryGriddedResult(self, argument, data)
7676

77-
def matching_dates(self, filtering_dates: Any, group_of_dates: Any) -> Any:
78-
"""Find dates that match between filtering_dates and group_of_dates.
77+
def matching_dates(self, filters: dict, group_of_dates: Any) -> Any:
78+
"""Find dates that match between filters and group_of_dates.
7979
8080
Parameters
8181
----------
82-
filtering_dates : Any
83-
The dates to filter by.
82+
filters : dict
83+
A dict mapping filter keys to filter objects: a DatesProvider for ``'base_dates'`` and a Steps object for ``'steps'``.
84+
Trajectory layouts support ``'base_dates'`` and ``'steps'``.
8485
group_of_dates : ForecastDates
8586
The ``(valid_time, basetime)`` pairs for this group.
8687
8788
Returns
8889
-------
8990
ForecastDates
90-
A ForecastDates containing only the pairs whose basetime is in
91-
``filtering_dates``.
91+
A ForecastDates containing only the pairs that satisfy every supplied filter.
9292
"""
93+
unsupported = set(filters) - {"base_dates", "steps"}
94+
if unsupported:
95+
raise ValueError(
96+
f"Trajectory layout does not support filtering by {unsupported}. "
97+
"Use 'base_dates' and/or 'steps' instead."
98+
)
99+
93100
from anemoi.datasets.create.arguments import ForecastDates
94101

95-
filter_set = set(filtering_dates)
96-
matched = [(vt, bt) for vt, bt in group_of_dates if bt in filter_set]
102+
matched = list(group_of_dates)
103+
104+
if "base_dates" in filters:
105+
base_dates_set = set(filters["base_dates"])
106+
matched = [(vt, bt) for vt, bt in matched if bt in base_dates_set]
107+
108+
if "steps" in filters:
109+
steps_set = set(filters["steps"])
110+
matched = [(vt, bt) for vt, bt in matched if (vt - bt) in steps_set]
111+
97112
return ForecastDates(matched)
98113

99114
def origin(self, data: Any, action: Any, action_arguments: Any) -> Any:

0 commit comments

Comments
 (0)