
Commit 4759dd9

feat: save opened anemoi_dataset (#259)
1 parent 777fbac commit 4759dd9

2 files changed

Lines changed: 175 additions & 0 deletions


src/anemoi/datasets/data/__init__.py

Lines changed: 16 additions & 0 deletions
@@ -16,6 +16,7 @@
 # from .dataset import Shape
 # from .dataset import TupleIndex
 from .misc import _open_dataset
+from .misc import _save_dataset
 from .misc import add_dataset_path
 from .misc import add_named_dataset

@@ -92,6 +93,21 @@ def open_dataset(*args: Any, **kwargs: Any) -> "Dataset":
     return ds


+def save_dataset(recipe: dict, zarr_path: str, n_workers: int = 1) -> None:
+    """Open a dataset and save it to disk.
+
+    Parameters
+    ----------
+    recipe : dict
+        Recipe used with open_dataset (not a dataset creation recipe).
+    zarr_path : str
+        Path at which to store the opened anemoi dataset on disk.
+    n_workers : int
+        Number of workers to use for parallel processing. If 1, processing is sequential.
+    """
+    _save_dataset(recipe, zarr_path, n_workers)
+
+
 def list_dataset_names(*args: Any, **kwargs: Any) -> list[str]:
     """List the names of datasets.
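With this change, a dataset view produced by open_dataset can be materialized to disk and later re-opened like any other anemoi dataset. A minimal usage sketch, assuming save_dataset is imported from anemoi.datasets.data where the commit adds it; the recipe keys and paths below are illustrative, not taken from the commit:

    from anemoi.datasets.data import open_dataset, save_dataset

    # A recipe here is a dict of open_dataset arguments (not a dataset
    # creation recipe); the source path and variable selection are made up.
    recipe = {"dataset": "/path/to/source.zarr", "select": ["2t", "10u", "10v"]}

    # Materialize the opened view; n_workers=1 processes dates sequentially.
    save_dataset(recipe, "/path/to/saved.zarr", n_workers=1)

    # The saved store can then be opened directly.
    ds = open_dataset("/path/to/saved.zarr")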

src/anemoi/datasets/data/misc.py

Lines changed: 159 additions & 0 deletions
@@ -539,3 +539,162 @@ def _open_dataset(*args: Any, **kwargs: Any) -> "Dataset":
         return dataset._subset(**kwargs)

     return sets[0]._subset(**kwargs)
+
+
+def append_to_zarr(new_data: np.ndarray, new_dates: np.ndarray, zarr_path: str) -> None:
+    """Append data from a subset (for one date) to the Zarr store.
+
+    Parameters
+    ----------
+    new_data : np.ndarray
+        The new data to append.
+    new_dates : np.ndarray
+        The new dates to append.
+    zarr_path : str
+        The path to the Zarr store.
+
+    Notes
+    -----
+    - The "dates" dataset is created with chunks equal to len(big_dataset.dates).
+    - The "data" dataset is created with chunk size 1 along the first (time) dimension.
+    """
+    print("Appending data for", new_dates, flush=True)
+    # Re-open the zarr store to avoid the root object accumulating memory.
+    root = zarr.open(zarr_path, mode="a")
+    # Normalize the new dates to datetime64 regardless of input dtype.
+    new_dates = np.array(new_dates, dtype="datetime64[ns]")
+    dates_ds = root["dates"]
+    old_len = dates_ds.shape[0]
+    dates_ds.resize((old_len + len(new_dates),))
+    dates_ds[old_len:] = new_dates
+    # Append to the "data" dataset.
+    data_ds = root["data"]
+    old_shape = data_ds.shape  # (time, n_vars, ensembles, gridpoints)
+    new_shape = (old_shape[0] + len(new_dates),) + old_shape[1:]
+    data_ds.resize(new_shape)
+    data_ds[old_shape[0] :] = new_data
+
+
+def process_date(date: Any, big_dataset: Any) -> Tuple[np.ndarray, np.ndarray]:
+    """Open the subset corresponding to the given date and return (subset, dates).
+
+    Parameters
+    ----------
+    date : Any
+        The date to process.
+    big_dataset : Any
+        The dataset to process.
+
+    Returns
+    -------
+    Tuple[np.ndarray, np.ndarray]
+        The subset and its dates.
+    """
+    print("Processing:", date, flush=True)
+    subset = _open_dataset(big_dataset, start=date, end=date).mutate()
+    s = subset[:]
+    date = subset.dates
+    return s, date
+
+
+def initialize_zarr_store(root: Any, big_dataset: Any, recipe: Dict[str, Any]) -> None:
+    """Initialize the Zarr store with the given dataset and recipe.
+
+    Parameters
+    ----------
+    root : Any
+        The root of the Zarr store.
+    big_dataset : Any
+        The dataset to initialize the store with.
+    recipe : Dict[str, Any]
+        The recipe for initializing the store.
+    """
+    ensembles = big_dataset.shape[1]
+    # Create the "dates" dataset if missing, chunked to the full number of dates.
+    if "dates" not in root:
+        full_length = len(big_dataset.dates)
+        root.create_dataset("dates", data=np.array([], dtype="datetime64[s]"), chunks=(full_length,))
+
+    if "data" not in root:
+        dims = (1, len(big_dataset.variables), ensembles, big_dataset.grids[0])
+        root.create_dataset(
+            "data",
+            shape=dims,
+            dtype=np.float64,
+            chunks=dims,
+        )
+
+    for k, v in big_dataset.statistics.items():
+        if k not in root:
+            root.create_dataset(
+                k,
+                data=v,
+                compressor=None,
+            )
+
+    # Create spatial coordinate datasets if missing.
+    if "latitudes" not in root or "longitudes" not in root:
+        root.create_dataset("latitudes", data=big_dataset.latitudes, compressor=None)
+        root.create_dataset("longitudes", data=big_dataset.longitudes, compressor=None)
+
+    # Set store-wide attributes if not already set.
+    # NOTE: frequency and resolution are currently hardcoded.
+    if "frequency" not in root.attrs:
+        root.attrs["frequency"] = "10m"
+        root.attrs["resolution"] = "1km"
+        root.attrs["name_to_index"] = {k: i for i, k in enumerate(big_dataset.variables)}
+        root.attrs["ensemble_dimension"] = 1
+        root.attrs["field_shape"] = big_dataset.field_shape
+        root.attrs["flatten_grid"] = True
+        root.attrs["recipe"] = recipe
+
+
+def _save_dataset(recipe: Dict[str, Any], zarr_path: str, n_workers: int = 1) -> None:
+    """Incrementally create (or update) a Zarr store from an Anemoi dataset.
+
+    Parameters
+    ----------
+    recipe : Dict[str, Any]
+        The recipe for creating the dataset.
+    zarr_path : str
+        The path to the Zarr store.
+    n_workers : int, optional
+        The number of worker processes to use, by default 1.
+
+    Notes
+    -----
+    Worker processes extract data for each date in parallel, but all writes
+    to the store happen sequentially in the main process (i.e. single-writer).
+
+    The "dates" dataset is created with chunking equal to the full length of
+    the dataset's dates, while "data" is chunked with 1 in the time dimension.
+    """
+    from concurrent.futures import ProcessPoolExecutor
+
+    full_ds = _open_dataset(recipe).mutate()
+    print("Opened full dataset.", flush=True)
+
+    # Use ProcessPoolExecutor for parallel data extraction.
+    # Workers return (subset, dates) tuples.
+    root = zarr.open(zarr_path, mode="a")
+    initialize_zarr_store(root, full_ds, recipe)
+    print("Zarr store initialized.", flush=True)
+
+    existing_dates = np.array(sorted(root["dates"]), dtype="datetime64[s]")
+    all_dates = full_ds.dates
+    # Skip dates already in the store, so that creation of the Zarr
+    # store can resume if the job is interrupted.
+    dates_to_process = np.array(sorted(set(all_dates).difference(existing_dates)), dtype="datetime64[s]")
+
+    use_pool = n_workers > 1
+
+    if use_pool:
+        with ProcessPoolExecutor(n_workers) as pool:
+            futures = [pool.submit(process_date, date, full_ds) for date in dates_to_process]
+            for future in futures:
+                subset, date = future.result()
+                # All appends happen sequentially here to
+                # avoid dates being added in a random order.
+                append_to_zarr(subset, date, zarr_path)
+    else:
+        for date in dates_to_process:
+            subset, date = process_date(date, full_ds)
+            append_to_zarr(subset, date, zarr_path)
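Since the store written by _save_dataset is plain Zarr, it can be inspected without anemoi at all. A short sketch of reading it back, assuming a store created as in the earlier example (the path is illustrative):

    import numpy as np
    import zarr

    root = zarr.open("/path/to/saved.zarr", mode="r")

    # "data" has shape (time, n_vars, ensembles, gridpoints), chunked 1 along time.
    print(root["data"].shape)

    # "dates" holds the appended datetime64 values in write order.
    print(np.asarray(root["dates"][:5]))

    # Store-wide attributes written by initialize_zarr_store.
    print(root.attrs["name_to_index"], root.attrs["frequency"])

Because dates already present in "dates" are skipped via the set difference computed above, re-running _save_dataset with the same recipe and path resumes an interrupted job rather than duplicating data.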
