Merged
2 changes: 0 additions & 2 deletions src/anemoi/datasets/create/filters/transform.py
@@ -40,8 +40,6 @@ def execute(self, input: ekd.FieldList) -> ekd.FieldList:

        Parameters
        ----------
        context : Any
            The context in which the execution occurs.
        input : ekd.FieldList
            The input data to be transformed.

16 changes: 16 additions & 0 deletions src/anemoi/datasets/data/__init__.py
@@ -16,6 +16,7 @@
# from .dataset import Shape
# from .dataset import TupleIndex
from .misc import _open_dataset
from .misc import _save_dataset
from .misc import add_dataset_path
from .misc import add_named_dataset

@@ -92,6 +93,21 @@ def open_dataset(*args: Any, **kwargs: Any) -> "Dataset":
    return ds


def save_dataset(recipe: dict, zarr_path: str, n_workers: int = 1) -> None:
"""Open a dataset and save it to disk.

Parameters
----------
recipe : dict
Recipe used with open_dataset (not a dataset creation recipe).
zarr_path : str
Path to store the obtained anemoi dataset to disk.
n_workers : int
Number of workers to use for parallel processing. If none, sequential processing will be performed.
"""
_save_dataset(recipe, zarr_path, n_workers)


def list_dataset_names(*args: Any, **kwargs: Any) -> list[str]:
"""List the names of datasets.

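For context, a minimal usage sketch of the new public entry point. The recipe contents and the output path below are illustrative assumptions, not values defined by this change:

# Hypothetical example: the recipe is whatever you would normally pass to
# open_dataset, not a dataset-creation recipe.
from anemoi.datasets.data import save_dataset

recipe = {"dataset": "/path/to/existing-dataset.zarr", "frequency": "6h"}  # assumed keys
save_dataset(recipe, "subset.zarr", n_workers=4)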
159 changes: 159 additions & 0 deletions src/anemoi/datasets/data/misc.py
@@ -539,3 +539,162 @@ def _open_dataset(*args: Any, **kwargs: Any) -> "Dataset":
        return dataset._subset(**kwargs)

    return sets[0]._subset(**kwargs)


def append_to_zarr(new_data: np.ndarray, new_dates: np.ndarray, zarr_path: str) -> None:
"""Append data from a subset (for one date) to the Zarr store.

Parameters
----------
new_data : np.ndarray
The new data to append.
new_dates : np.ndarray
The new dates to append.
zarr_path : str
The path to the Zarr store.

Notes
-----
- "dates" dataset is created with chunks equal to len(big_dataset.dates).
- "data" dataset is created with chunk size 1 along the first (time) dimension.
"""
print("Appending data for", new_dates, flush=True)
# Re-open the zarr store to avoid root object accumulating memory.
root = zarr.open(zarr_path, mode="a")
# Convert new dates to strings (using str) regardless of input dtype.
new_dates = np.array(new_dates, dtype="datetime64[ns]")
dates_ds = root["dates"]
old_len = dates_ds.shape[0]
dates_ds.resize((old_len + len(new_dates),))
dates_ds[old_len:] = new_dates
# Append to "data" dataset.
data_ds = root["data"]
old_shape = data_ds.shape # (time, n_vars, ensembles, gridpoints)
new_shape = (old_shape[0] + len(new_dates),) + old_shape[1:]
data_ds.resize(new_shape)
data_ds[old_shape[0] :] = new_data


def process_date(date: Any, big_dataset: Any) -> Tuple[np.ndarray, np.ndarray]:
"""Open the subset corresponding to the given date and return (date, subset).

Parameters
----------
date : Any
The date to process.
big_dataset : Any
The dataset to process.

Returns
-------
Tuple[np.ndarray, np.ndarray]
The subset and the date.
"""
print("Processing:", date, flush=True)
subset = _open_dataset(big_dataset, start=date, end=date).mutate()
s = subset[:]
date = subset.dates
return s, date


def initialize_zarr_store(root: Any, big_dataset: Any, recipe: Dict[str, Any]) -> None:
"""Initialize the Zarr store with the given dataset and recipe.

Parameters
----------
root : Any
The root of the Zarr store.
big_dataset : Any
The dataset to initialize the store with.
recipe : Dict[str, Any]
The recipe for initializing the store.
"""
    ensembles = big_dataset.shape[1]
    # Create the "dates" dataset if it does not exist yet.
    if "dates" not in root:
        full_length = len(big_dataset.dates)
        root.create_dataset("dates", data=np.array([], dtype="datetime64[s]"), chunks=(full_length,))

if "data" not in root:
dims = (1, len(big_dataset.variables), ensembles, big_dataset.grids[0])
root.create_dataset(
"data",
shape=dims,
dtype=np.float64,
chunks=dims,
)

    for k, v in big_dataset.statistics.items():
        if k not in root:
            root.create_dataset(
                k,
                data=v,
                compressor=None,
            )

    # Create spatial coordinate datasets if missing.
    if "latitudes" not in root or "longitudes" not in root:
        root.create_dataset("latitudes", data=big_dataset.latitudes, compressor=None)
        root.create_dataset("longitudes", data=big_dataset.longitudes, compressor=None)

    # Set store-wide attributes if not already set.
    if "frequency" not in root.attrs:
        root.attrs["frequency"] = "10m"
        root.attrs["resolution"] = "1km"
        root.attrs["name_to_index"] = {k: i for i, k in enumerate(big_dataset.variables)}
        root.attrs["ensemble_dimension"] = 1
        root.attrs["field_shape"] = big_dataset.field_shape
        root.attrs["flatten_grid"] = True
        root.attrs["recipe"] = recipe


def _save_dataset(recipe: Dict[str, Any], zarr_path: str, n_workers: int = 1) -> None:
"""Incrementally create (or update) a Zarr store from an Anemoi dataset.

Parameters
----------
recipe : Dict[str, Any]
The recipe for creating the dataset.
zarr_path : str
The path to the Zarr store.
n_workers : int, optional
The number of worker processes to use, by default 1.

Notes
-----
Worker processes extract data for each date in parallel, but all writes
to the store happen sequentially in the main process (i.e. single-writer).

The "dates" dataset is created with chunking equal to the full length of
big_dataset.dates, while "data" is chunked with 1 in the time dimension.
"""
    from concurrent.futures import ProcessPoolExecutor

    full_ds = _open_dataset(recipe).mutate()
    print("Opened full dataset.", flush=True)

    # Use a ProcessPoolExecutor for parallel data extraction.
    # Workers return (data, dates) tuples.
    root = zarr.open(zarr_path, mode="a")
    initialize_zarr_store(root, full_ds, recipe)
    print("Zarr store initialized.", flush=True)

    existing_dates = np.array(sorted(root["dates"]), dtype="datetime64[s]")
    all_dates = full_ds.dates
    # Only process dates that are not already in the store, so that creation
    # of the Zarr store can be resumed if the job is interrupted.
    dates_to_process = np.array(sorted(set(all_dates).difference(existing_dates)), dtype="datetime64[s]")

    # Only use the process pool when more than one worker is requested.
    use_pool = n_workers > 1

    if use_pool:
        with ProcessPoolExecutor(n_workers) as pool:
            futures = [pool.submit(process_date, date, full_ds) for date in dates_to_process]
            for future in futures:
                subset, date = future.result()
                # All appends happen sequentially here to
                # avoid dates being added in a random order.
                append_to_zarr(subset, date, zarr_path)
    else:
        for date in dates_to_process:
            subset, date = process_date(date, full_ds)
            append_to_zarr(subset, date, zarr_path)
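Once _save_dataset has run, the resulting store can be inspected directly with zarr. A quick readback sketch, assuming the hypothetical output path used above (array and attribute names follow initialize_zarr_store):

import zarr

root = zarr.open("subset.zarr", mode="r")  # hypothetical path
print(root["dates"].shape, root["data"].shape)  # (time,) and (time, n_vars, ensembles, gridpoints)
print(root.attrs["frequency"], root.attrs["resolution"])
print(list(root.attrs["name_to_index"]))  # variable names in storage order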