From ab828fe944a5f8228a37ce938b51a17d4d3ebb16 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Wed, 19 Mar 2025 12:33:09 +0000 Subject: [PATCH 1/4] add padded=True in open_dataset --- src/anemoi/datasets/data/dataset.py | 14 +++ src/anemoi/datasets/data/padded.py | 165 ++++++++++++++++++++++++++++ tests/test_data.py | 26 ++++- 3 files changed, 204 insertions(+), 1 deletion(-) create mode 100644 src/anemoi/datasets/data/padded.py diff --git a/src/anemoi/datasets/data/dataset.py b/src/anemoi/datasets/data/dataset.py index 20ff6a898..75171d33d 100644 --- a/src/anemoi/datasets/data/dataset.py +++ b/src/anemoi/datasets/data/dataset.py @@ -179,6 +179,17 @@ def __subset(self, **kwargs: Any) -> "Dataset": if "start" in kwargs or "end" in kwargs: start = kwargs.pop("start", None) end = kwargs.pop("end", None) + padding = kwargs.pop("padding", None) + + if padding: + from .padded import Padded + + frequency = kwargs.pop("frequency", self.frequency) + return ( + Padded(self, start, end, frequency, dict(start=start, end=end, frequency=frequency)) + ._subset(**kwargs) + .mutate() + ) from .subset import Subset @@ -705,6 +716,9 @@ def grids(self) -> TupleIndex: """Return the grid shape of the dataset.""" return (self.shape[-1],) + def empty_item(self) -> NDArray[Any]: + return np.zeros((*self.shape[1:-1], 0), dtype=self.dtype) + def _check(self) -> None: """Check for overridden private methods in the dataset.""" common = Dataset.__dict__.keys() & self.__class__.__dict__.keys() diff --git a/src/anemoi/datasets/data/padded.py b/src/anemoi/datasets/data/padded.py new file mode 100644 index 000000000..8d1a1ed48 --- /dev/null +++ b/src/anemoi/datasets/data/padded.py @@ -0,0 +1,165 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import datetime +import logging +from functools import cached_property +from typing import Any +from typing import Dict +from typing import Set + +import numpy as np +from anemoi.utils.dates import frequency_to_timedelta +from numpy.typing import NDArray + +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.dataset import Shape +from anemoi.datasets.data.dataset import TupleIndex +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.debug import debug_indexing +from anemoi.datasets.data.forwards import Forwards +from anemoi.datasets.data.indexing import expand_list_indexing + +LOG = logging.getLogger(__name__) + + +def _normalise_date(date, default): + if date is None: + date = default + + if isinstance(date, str): + try: + date = datetime.datetime.fromisoformat(date) + except ValueError: + raise ValueError(f"Invalid date {date}, only isoformat is supported with padding") + + if isinstance(date, datetime.datetime): + date = np.datetime64(date, "s") + + assert isinstance(date, np.datetime64), (date, type(date)) + + return date + + +class Padded(Forwards): + _before: int = 0 + _after: int = 0 + _inside: int = 0 + + def __init__(self, dataset, start, end, frequency, reason): + self.reason = {k: v for k, v in reason.items() if v is not None} + + if frequency is None: + frequency = dataset.frequency + + self._frequency = frequency_to_timedelta(frequency) + + start = _normalise_date(start, dataset.dates[0]) + end = _normalise_date(end, dataset.dates[-1]) + + assert isinstance(dataset.dates[0], np.datetime64), (dataset.dates[0], type(dataset.dates[0])) + + timedelta = np.array([frequency], dtype="timedelta64[s]")[0] + + dates_parts = [] + + if start < dataset.dates[0]: + dates_parts.append(np.arange(start, dataset.dates[0], timedelta)) + self._before = len(dates_parts[-1]) + + dates_parts.append(dataset.dates) + self._inside = len(dates_parts[-1]) + + if end > dataset.dates[-1]: + dates_parts.append(np.arange(dataset.dates[-1] + timedelta, end + timedelta, timedelta)) + self._after = len(dates_parts[-1]) + + self._dates = np.hstack(dates_parts) + assert len(self._dates) == self._before + self._inside + self._after, ( + len(self._dates), + self._before, + self._inside, + self._after, + ) + + # Forward other properties to the super dataset + super().__init__(dataset) + + @debug_indexing + def __getitem__(self, n: FullIndex) -> NDArray[Any]: + if isinstance(n, tuple): + return self._get_tuple(n) + + if isinstance(n, slice): + return self._get_slice(n) + + if 0 <= n < self._before: + return self.empty_item() + + if (self._before + self._inside) <= n < (self._before + self._inside + self._after): + return self.empty_item() + + return self.forward[n - self._before] + + @debug_indexing + def _get_slice(self, s: slice) -> NDArray[Any]: + LOG.warning("Padded subset does not support slice indexing, returning a list") + return [self[i] for i in range(*s.indices(self._len))] + + @debug_indexing + @expand_list_indexing + def _get_tuple(self, n: TupleIndex) -> NDArray[Any]: + LOG.warning("Padded subset does not support tuple indexing, returning a list") + return [self[i] for i in n] + + def empty_item(self): + return self.forward.empty_item() + + def __len__(self) -> int: + print("len", len(self._dates)) + return len(self._dates) + + @property + def frequency(self) -> datetime.timedelta: + """Get the frequency of the subset.""" + return self._frequency + + @property + def dates(self) -> NDArray[np.datetime64]: + return self._dates + + @property + def shape(self) -> Shape: + return (len(self.dates),) + self.forward.shape[1:] + + @cached_property + def missing(self) -> Set[int]: + raise NotImplementedError + + return self.forward.missing + + def tree(self) -> Node: + """Get the tree representation of the subset. + + Returns: + Node: The tree representation of the subset. + """ + return Node(self, [self.forward.tree()], **self.reason) + + def forwards_subclass_metadata_specific(self) -> Dict[str, Any]: + """Get the metadata specific to the forwards subclass. + + Returns: + Dict[str, Any]: The metadata specific to the forwards subclass. + """ + return { + # "indices": self.indices, + "reason": self.reason, + } diff --git a/tests/test_data.py b/tests/test_data.py index c2fba512d..e65511ceb 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -30,6 +30,7 @@ from anemoi.datasets.data.join import Join from anemoi.datasets.data.misc import as_first_date from anemoi.datasets.data.misc import as_last_date +from anemoi.datasets.data.padded import Padded from anemoi.datasets.data.select import Rename from anemoi.datasets.data.select import Select from anemoi.datasets.data.statistics import Statistics @@ -388,6 +389,7 @@ def run( time_increment: datetime.timedelta, statistics_reference_dataset: Optional[Union[str, list]], statistics_reference_variables: Optional[Union[str, list]], + regular_shape: bool = True, ) -> None: """Run the dataset tests. @@ -413,6 +415,8 @@ def run( Reference dataset for statistics. statistics_reference_variables : Optional[Union[str, list]] Reference variables for statistics. + regular_shape : bool, optional + Whether the dataset has a regular shape, by default True. """ if isinstance(expected_variables, str): expected_variables = [v for v in expected_variables] @@ -451,7 +455,8 @@ def run( statistics_reference_variables, ) - self.indexing(self.ds) + if regular_shape: + self.indexing(self.ds) self.metadata(self.ds) self.ds.tree() @@ -704,6 +709,25 @@ def test_subset_2() -> None: ) +@mockup_open_zarr +def test_subset_2_padding() -> None: + """Test subsetting a dataset (case 2).""" + test = DatasetTester("test-2022-2022-1h-o96-abcd", start="2021-01-01", end="2023-12-31 23:00", padding=True) + test.run( + expected_class=Padded, + expected_length=365 * 24 * 3, + expected_shape=(365 * 24 * 3, 4, 1, VALUES), + expected_variables="abcd", + expected_name_to_index="abcd", + date_to_row=lambda date: simple_row(date, "abcd") if date.year == 2022 else np.zeros((4, 1, 0)), + start_date=datetime.datetime(2021, 1, 1), + time_increment=datetime.timedelta(hours=1), + statistics_reference_dataset="test-2022-2022-1h-o96-abcd", + statistics_reference_variables="abcd", + regular_shape=False, + ) + + @mockup_open_zarr def test_subset_3() -> None: """Test subsetting a dataset (case 3).""" From d7126b49699fafbc8ddd77970450a482f42b5f0e Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Wed, 19 Mar 2025 13:25:41 +0000 Subject: [PATCH 2/4] better error msg --- src/anemoi/datasets/data/padded.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anemoi/datasets/data/padded.py b/src/anemoi/datasets/data/padded.py index 8d1a1ed48..9576905ad 100644 --- a/src/anemoi/datasets/data/padded.py +++ b/src/anemoi/datasets/data/padded.py @@ -141,7 +141,7 @@ def shape(self) -> Shape: @cached_property def missing(self) -> Set[int]: - raise NotImplementedError + raise NotImplementedError("Need to decide whether to include the added dates as missing or not") return self.forward.missing From b5c41c22d013b3f1790c56317424b1f3370ef3e0 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Wed, 19 Mar 2025 13:26:08 +0000 Subject: [PATCH 3/4] better error msg --- src/anemoi/datasets/data/padded.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/anemoi/datasets/data/padded.py b/src/anemoi/datasets/data/padded.py index 9576905ad..df24d3d1f 100644 --- a/src/anemoi/datasets/data/padded.py +++ b/src/anemoi/datasets/data/padded.py @@ -142,8 +142,7 @@ def shape(self) -> Shape: @cached_property def missing(self) -> Set[int]: raise NotImplementedError("Need to decide whether to include the added dates as missing or not") - - return self.forward.missing + # return self.forward.missing def tree(self) -> Node: """Get the tree representation of the subset. From aaa64bebe40de3f6595ceb2bdb57f3356b6154b0 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Wed, 19 Mar 2025 13:26:30 +0000 Subject: [PATCH 4/4] clean --- src/anemoi/datasets/data/padded.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/anemoi/datasets/data/padded.py b/src/anemoi/datasets/data/padded.py index df24d3d1f..e155dd963 100644 --- a/src/anemoi/datasets/data/padded.py +++ b/src/anemoi/datasets/data/padded.py @@ -123,7 +123,6 @@ def empty_item(self): return self.forward.empty_item() def __len__(self) -> int: - print("len", len(self._dates)) return len(self._dates) @property