Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/anemoi/transform/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,24 @@ def new_field_with_metadata(template: WrappedField, **metadata: Any) -> NewMetad
return NewMetadataField(template, **metadata)


def new_field_with_units(template: WrappedField, units: str) -> NewMetadataField:
"""Create a new field with units.

Parameters
----------
template : WrappedField
The template field to use.
units : str
The units for the new field.

Returns
-------
NewMetadataField
The new field with the provided units.
"""
return NewMetadataField(template, units=units)


def new_field_from_latitudes_longitudes(
template: WrappedField, latitudes: np.ndarray, longitudes: np.ndarray
) -> NewLatLonField:
Expand Down
4 changes: 3 additions & 1 deletion src/anemoi/transform/filters/fields/glacier_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,6 @@ def forward_transform(self, snow_depth: ekd.Field) -> ekd.Field:
"""
snow_depth_masked = mask_glaciers(snow_depth.to_numpy(), self.glacier_mask)

return self.new_field_from_numpy(snow_depth_masked, template=snow_depth, param=self.snow_depth_masked)
return self.new_field_from_numpy(
snow_depth_masked, template=snow_depth, param=self.snow_depth_masked, units="Fraction"
)
8 changes: 4 additions & 4 deletions src/anemoi/transform/filters/fields/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def _transform(
result.append(f)
return self.new_fieldlist_from_list(result)

def new_field_from_numpy(self, array: np.ndarray, *, template: ekd.Field, param: str) -> ekd.Field:
def new_field_from_numpy(self, array: np.ndarray, *, template: ekd.Field, **kwargs) -> ekd.Field:
"""Create a new field from a numpy array.

Parameters
Expand All @@ -369,15 +369,15 @@ def new_field_from_numpy(self, array: np.ndarray, *, template: ekd.Field, param:
Numpy array containing the field data.
template : ekd.Field
Template field to use for metadata.
param : str
Parameter name for the new field.
**kwargs : Any
Additional keyword arguments for the new field.

Returns
-------
ekd.Field
New field created from the numpy array.
"""
return new_field_from_numpy(array, template=template, param=param)
return new_field_from_numpy(array, template=template, **kwargs)

def new_fieldlist_from_list(self, fields: list[ekd.Field]) -> ekd.FieldList:
"""Create a new field list from a list of fields.
Expand Down
20 changes: 13 additions & 7 deletions src/anemoi/transform/filters/fields/rescale.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ class RescaleMixin(ABC):
# intended to be inherited from SingleFieldFilter
new_field_from_numpy: Callable

forward_units = None
backward_units = None

@abstractmethod
def prepare_filter(self):
raise NotImplementedError("prepare_filter must be implemented by subclasses.")
Expand All @@ -46,7 +49,7 @@ def forward_select(self):
def forward_transform(self, param: ekd.Field) -> ekd.Field:
"""Apply the forward transformation (x to ax+b)."""
rescaled = self.rescaler.forward(param.to_numpy())
return self.new_field_from_numpy(rescaled, template=param, param=self.param)
return self.new_field_from_numpy(rescaled, template=param, param=self.param, units=self.forward_units)

def backward_transform(self, param: ekd.Field) -> ekd.Field:
"""Apply the backward transformation (ax+b to x)."""
Expand All @@ -66,8 +69,7 @@ def prepare_filter(self):
class Convert(RescaleMixin, SingleFieldFilter):
"""A filter to convert a parameter in a given unit to another unit, and back.

This filter uses :mod:`cfunits` (see the `cfunits documentation <https://ncas-cms.github.io/cfunits/>`_)
to compute the scale and offset.
This filter uses :mod:`pint` to compute the scale and offset.

Examples
--------
Expand All @@ -89,12 +91,16 @@ class Convert(RescaleMixin, SingleFieldFilter):
required_inputs = ("unit_in", "unit_out", "param")

def prepare_filter(self):
from cfunits import Units
import pint

ureg = pint.UnitRegistry()

self.forward_units = self.unit_out
self.backward_units = self.unit_in

u0 = Units(self.unit_in)
u1 = Units(self.unit_out)
x1, x2 = 0.0, 1.0
y1, y2 = Units.conform([x1, x2], u0, u1)
y1 = ureg.Quantity(x1, self.unit_in).to(self.unit_out).magnitude
y2 = ureg.Quantity(x2, self.unit_in).to(self.unit_out).magnitude
scale = (y2 - y1) / (x2 - x1)
offset = y1 - scale * x1

Expand Down
2 changes: 1 addition & 1 deletion src/anemoi/transform/filters/fields/snow_cover.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,4 +119,4 @@ def forward_transform(self, snow_depth: ekd.Field, snow_density: ekd.Field) -> I
"""
snow_cover = compute_snow_cover(snow_depth.to_numpy(), snow_density.to_numpy())

yield self.new_field_from_numpy(snow_cover, template=snow_depth, param=self.snow_cover)
yield self.new_field_from_numpy(snow_cover, template=snow_depth, param=self.snow_cover, units="Fraction")
39 changes: 39 additions & 0 deletions src/anemoi/transform/units.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# (C) Copyright 2026 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

UNITS_MAPPING = {
"Numeric": "dimensionless", # This is WMO, but Numeric will choke pint or cfunits
}


class Units:
def __init__(self, units: str) -> None:
self.units = units

def __str__(self) -> str:
return self.units

def __repr__(self) -> str:
return self.units

def __eq__(self, value):
if isinstance(value, Units):
return self.units == value.units
elif isinstance(value, str):
return self.units == value
else:
return NotImplemented

def __hash__(self):
return hash(self.units)

@classmethod
def to_canonical(cls, units: str) -> str:
"""Converts a unit to its canonical form."""
return UNITS_MAPPING.get(units, units)
64 changes: 64 additions & 0 deletions src/anemoi/transform/variables/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import logging
from abc import ABC
from abc import abstractmethod
from typing import TYPE_CHECKING
Expand All @@ -16,6 +17,8 @@
if TYPE_CHECKING:
from datetime import timedelta

LOG = logging.getLogger(__name__)


class Variable(ABC):
"""Variable is a class that represents a variable during
Expand Down Expand Up @@ -191,6 +194,12 @@ def is_from_input(self) -> bool:
"""Check if the variable is from input."""
pass

@property
@abstractmethod
def units(self):
"""Get the units of the variable."""
pass

def similarity(self, other: Any) -> int:
"""Compute the similarity between two variables. This is used when
encoding a variable in GRIB and we do not have a template for it.
Expand All @@ -207,3 +216,58 @@ def similarity(self, other: Any) -> int:
The similarity score.
"""
return 0

def compatible(self, other: Any, return_reason: bool = False, **options) -> bool:
if options is None:
options = {}

def _compare():

if self.units != other.units:
Comment thread
b8raoult marked this conversation as resolved.
if self.units is None or other.units is None:
LOG.warning(
f"{self}: one of the variables has missing units: {self.units} vs {other.units}. Assuming they are compatible."
)
else:
return f"Units are not compatible: {self.units} vs {other.units}"

if self.time_processing != other.time_processing:
return f"Time processinging types are not compatible: {self.time_processing} vs {other.time_processing}"

if self.period != other.period:
return f"Periods are not compatible: {self.period} vs {other.period}"

if self.is_pressure_level != other.is_pressure_level:
return f"Pressure level status is not compatible: {self.is_pressure_level} vs {other.is_pressure_level}"

if self.is_model_level != other.is_model_level:
return f"Model level status is not compatible: {self.is_model_level} vs {other.is_model_level}"

if self.is_surface_level != other.is_surface_level:
return f"Surface level status is not compatible: {self.is_surface_level} vs {other.is_surface_level}"

reason = _compare()
if reason:
return False, reason if return_reason else False

return True, None if return_reason else True

@classmethod
def check_compatibility(cls, variables1: dict, variables2: dict, **options) -> bool:
if options is None:
options = {}

keys1 = set(variables1.keys())
keys2 = set(variables2.keys())

if keys1 != keys2:
raise ValueError(f"Variable compatibility: missing={keys1-keys2}, added={keys2-keys1}")

reasons = []
for k in keys1:
compatible, reason = variables1[k].compatible(variables2[k], return_reason=True, **options)
if not compatible:
reasons.append(f"{k}: {reason}")

if reasons:
raise ValueError(f"Variables are not compatible: {'; '.join(reasons)}")
6 changes: 6 additions & 0 deletions src/anemoi/transform/variables/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from anemoi.utils.dates import as_timedelta

from anemoi.transform.units import Units
from anemoi.transform.variables import Variable

if TYPE_CHECKING:
Expand Down Expand Up @@ -106,6 +107,11 @@ def period(self) -> Union["timedelta", None]:

return as_timedelta(period[1]) - as_timedelta(period[0])

@property
def units(self):
units = self.data.get("units", None)
return Units(units) if units else None

@property
def grib_keys(self) -> dict[str, Any]:
"""Get the GRIB keys of the variable."""
Expand Down
15 changes: 0 additions & 15 deletions tests/field_filters/test_rescale.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,11 @@

import earthkit.data as ekd
import numpy.testing as npt
import pytest
from anemoi.utils.testing import skip_if_offline

from anemoi.transform.filters import create_filter_by_name as create_filter


def skip_missing_udunits2():
"""Skip tests if udunits2 package is not available."""
# Can't use utils.testing.skip_missing_packages because it only fails
# when cfunits.Units is imported...
# NB: cfunits depends on udunits2 which is a system library and not installed with pip
try:
from cfunits import Units # noqa: F401

return lambda f: f
except FileNotFoundError:
return pytest.mark.skip(reason="udunits2 not found")


@skip_if_offline
def test_rescale(fieldlist: ekd.FieldList) -> None:
"""Test rescaling temperature from Kelvin to Celsius and back.
Expand Down Expand Up @@ -58,7 +44,6 @@ def test_rescale(fieldlist: ekd.FieldList) -> None:
npt.assert_allclose(before_filter[param], after_forward[param])


@skip_missing_udunits2()
@skip_if_offline
def test_convert(fieldlist: ekd.FieldList) -> None:
"""Test converting temperature from Kelvin to Celsius and back.
Expand Down
Loading