Skip to content

Commit 33609b5

Browse files
🐛 fix binned_statistics issue when passing a list or array as bins (#584)
* passing a list/array as bins
* set default name for coord and data
1 parent 537411d commit 33609b5

File tree

2 files changed

+127
-93
lines changed

2 files changed

+127
-93
lines changed

clouddrift/binning.py

Lines changed: 22 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,8 @@
1010
import xarray as xr
1111

1212
DEFAULT_BINS_NUMBER = 10
13+
DEFAULT_COORD_NAME = "coord"
14+
DEFAULT_DATA_NAME = "data"
1315

1416

1517
def binned_statistics(
@@ -51,10 +53,10 @@ def binned_statistics(
5153
- a list containing any combination of the above, e.g., ['mean', np.nanmax, ('ke', lambda data: np.sqrt(np.mean(data[0] ** 2 + data[1] ** 2)))].
5254
dim_names : list of str, optional
5355
Names for the dimensions of the output xr.Dataset.
54-
If None, default names are "dim_0_bin", "dim_1_bin", etc.
56+
If None, default names are "coord_0", "coord_1", etc.
5557
output_names : list of str, optional
5658
Names for output variables in the xr.Dataset.
57-
If None, default names are "binned_0_{statistic}", "binned_1_{statistic}", etc.
59+
If None, default names are "data_0_{statistic}", "data_1_{statistic}", etc.
5860
5961
Returns
6062
-------
@@ -165,21 +167,22 @@ def binned_statistics(
165167

166168
# set default dimension names
167169
if dim_names is None:
168-
dim_names = [f"dim_{i}_bin" for i in range(len(coords))]
170+
dim_names = [f"{DEFAULT_COORD_NAME}_{i}" for i in range(len(coords))]
169171
else:
170172
dim_names = [
171-
name if name is not None else f"dim_{i}_bin"
173+
name if name is not None else f"{DEFAULT_COORD_NAME}_{i}"
172174
for i, name in enumerate(dim_names)
173175
]
174176

175177
# set default variable names
176178
if output_names is None:
177179
output_names = [
178-
f"binned_{i}" if data[0].size else "binned" for i in range(len(data))
180+
f"{DEFAULT_DATA_NAME}_{i}" if data[0].size else DEFAULT_DATA_NAME
181+
for i in range(len(data))
179182
]
180183
else:
181184
output_names = [
182-
name if name is not None else f"binned_{i}"
185+
name if name is not None else f"{DEFAULT_DATA_NAME}_{i}"
183186
for i, name in enumerate(output_names)
184187
]
185188

@@ -192,7 +195,10 @@ def binned_statistics(
192195
raise ValueError("`coords` and `data` must have the same number of data points")
193196

194197
# edges and bin centers
195-
edges = [np.linspace(r[0], r[1], b + 1) for r, b in zip(bins_range, bins)]
198+
if isinstance(bins, int) or isinstance(bins[0], int):
199+
edges = [np.linspace(r[0], r[1], b + 1) for r, b in zip(bins_range, bins)]
200+
else:
201+
edges = [np.asarray(b) for b in bins]
196202
edges_sz = [len(e) - 1 for e in edges]
197203
n_bins = int(np.prod(edges_sz))
198204
bin_centers = [0.5 * (e[:-1] + e[1:]) for e in edges]
@@ -206,7 +212,7 @@ def binned_statistics(
206212
# by adding a small tolerance to the last edge (1s for date coordinates)
207213
edges_with_tol = [e.copy() for e in edges]
208214
for i, e in enumerate(edges_with_tol):
209-
e[-1] += np.finfo(e.dtype).eps if i not in coords_datetime_index else 1
215+
e[-1] += np.finfo(float).eps if i not in coords_datetime_index else 1
210216
indices = [np.digitize(c, edges_with_tol[j]) - 1 for j, c in enumerate(coords)]
211217
valid = np.all(
212218
[(j >= 0) & (j < edges_sz[i]) for i, j in enumerate(indices)], axis=0
@@ -332,7 +338,7 @@ def _get_variable_name(
332338
str
333339
Name of the function or a custom function name for lambda function.
334340
"""
335-
default_name = "stat_0"
341+
default_name = "stat"
336342
if isinstance(func, partial):
337343
function_name = getattr(func.func, "__name__", default_name)
338344
else:
@@ -342,17 +348,15 @@ def _get_variable_name(
342348

343349
# avoid name collisions with existing variables
344350
# by adding a suffix if the name already exists
345-
if f"{output_name}_{function_name}" in ds_vars:
346-
i = 0
347-
if function_name.split("_")[-1].isdigit():
348-
function_name, num = function_name.split("_")
349-
i = int(num) + 1
351+
base_name = f"{output_name}_{function_name}"
352+
name = base_name
350353

351-
while f"{output_name}_{function_name}_{i}" in ds_vars:
352-
i += 1
353-
function_name = f"{function_name}_{i}"
354+
i = 1
355+
while name in ds_vars:
356+
name = f"{base_name}_{i}"
357+
i += 1
354358

355-
return f"{output_name}_{function_name}"
359+
return name
356360

357361

358362
def _filter_valid_and_finite(

0 commit comments

Comments
 (0)