Skip to content

Commit ec6f121

Browse files
authored
Fix N-d velocity and position functions with time_axis (#243)
* Testing swapaxes in velocity_from_position * Expand tests for time_axis in velocity_from_position * Fix time_axis in position_from_velocity * Bump patch version * Make linter happy
1 parent cf4f9cd commit ec6f121

File tree

3 files changed

+80
-81
lines changed

3 files changed

+80
-81
lines changed

clouddrift/analysis.py

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -612,17 +612,15 @@ def position_from_velocity(
612612
f" {len(x.shape) - 1}])."
613613
)
614614

615-
# Nominal order of axes on input, i.e. (0, 1, 2, ..., N-1)
616-
target_axes = list(range(len(u.shape)))
617-
618-
# If time_axis is not the last one, transpose the inputs
619-
if time_axis != -1 and time_axis < len(u.shape) - 1:
620-
target_axes.append(target_axes.pop(target_axes.index(time_axis)))
621-
622-
# Reshape the inputs to ensure the time axis is last (fast-varying)
623-
u_ = np.transpose(u, target_axes)
624-
v_ = np.transpose(v, target_axes)
625-
time_ = np.transpose(time, target_axes)
615+
# Swap axes so that we can differentiate along the last axis.
616+
# This is a syntax convenience rather than memory access optimization:
617+
# np.swapaxes returns a view of the array, not a copy, if the input is a
618+
# NumPy array. Otherwise, it returns a copy. For readability, introduce new
619+
# variable names so that we can more easily differentiate between the
620+
# original arrays and those with swapped axes.
621+
u_ = np.swapaxes(u, time_axis, -1)
622+
v_ = np.swapaxes(v, time_axis, -1)
623+
time_ = np.swapaxes(time, time_axis, -1)
626624

627625
x = np.zeros(u_.shape, dtype=u.dtype)
628626
y = np.zeros(v_.shape, dtype=v.dtype)
@@ -659,10 +657,7 @@ def position_from_velocity(
659657
else:
660658
raise ValueError('coord_system must be "spherical" or "cartesian".')
661659

662-
if target_axes == list(range(len(u.shape))):
663-
return x, y
664-
else:
665-
return np.transpose(x, target_axes), np.transpose(y, target_axes)
660+
return np.swapaxes(x, time_axis, -1), np.swapaxes(y, time_axis, -1)
666661

667662

668663
def velocity_from_position(
@@ -754,17 +749,15 @@ def velocity_from_position(
754749
f" {len(x.shape) - 1}])."
755750
)
756751

757-
# Nominal order of axes on input, i.e. (0, 1, 2, ..., N-1)
758-
target_axes = list(range(len(x.shape)))
759-
760-
# If time_axis is not the last one, transpose the inputs
761-
if time_axis != -1 and time_axis < len(x.shape) - 1:
762-
target_axes.append(target_axes.pop(target_axes.index(time_axis)))
763-
764-
# Reshape the inputs to ensure the time axis is last (fast-varying)
765-
x_ = np.transpose(x, target_axes)
766-
y_ = np.transpose(y, target_axes)
767-
time_ = np.transpose(time, target_axes)
752+
# Swap axes so that we can differentiate along the last axis.
753+
# This is a syntax convenience rather than memory access optimization:
754+
# np.swapaxes returns a view of the array, not a copy, if the input is a
755+
# NumPy array. Otherwise, it returns a copy. For readability, introduce new
756+
# variable names so that we can more easily differentiate between the
757+
# original arrays and those with swapped axes.
758+
x_ = np.swapaxes(x, time_axis, -1)
759+
y_ = np.swapaxes(y, time_axis, -1)
760+
time_ = np.swapaxes(time, time_axis, -1)
768761

769762
dx = np.empty(x_.shape)
770763
dy = np.empty(y_.shape)
@@ -873,10 +866,11 @@ def velocity_from_position(
873866
'difference_scheme must be "forward", "backward", or "centered".'
874867
)
875868

876-
if target_axes == list(range(len(x.shape))):
877-
return dx / dt, dy / dt
878-
else:
879-
return np.transpose(dx / dt, target_axes), np.transpose(dy / dt, target_axes)
869+
# This should avoid an array copy when returning the result
870+
dx /= dt
871+
dy /= dt
872+
873+
return np.swapaxes(dx, time_axis, -1), np.swapaxes(dy, time_axis, -1)
880874

881875

882876
def mask_var(

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "clouddrift"
7-
version = "0.20.0"
7+
version = "0.20.1"
88
authors = [
99
{ name="Shane Elipot", email="[email protected]" },
1010
{ name="Philippe Miron", email="[email protected]" },

tests/analysis_tests.py

Lines changed: 55 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -493,36 +493,6 @@ def test_velocity_position_roundtrip_centered(self):
493493
self.assertTrue(np.allclose(lon, self.lon, atol=1e-2))
494494
self.assertTrue(np.allclose(lat, self.lat, atol=1e-2))
495495

496-
def test_time_axis(self):
497-
uf = np.transpose(
498-
np.reshape(np.tile(self.uf, 4), (2, 2, self.uf.size)), (0, 2, 1)
499-
)
500-
vf = np.transpose(
501-
np.reshape(np.tile(self.vf, 4), (2, 2, self.vf.size)), (0, 2, 1)
502-
)
503-
time = np.transpose(
504-
np.reshape(np.tile(self.time, 4), (2, 2, self.time.size)), (0, 2, 1)
505-
)
506-
expected_lon = np.transpose(
507-
np.reshape(np.tile(self.lon, 4), (2, 2, self.lon.size)), (0, 2, 1)
508-
)
509-
expected_lat = np.transpose(
510-
np.reshape(np.tile(self.lat, 4), (2, 2, self.lat.size)), (0, 2, 1)
511-
)
512-
lon, lat = position_from_velocity(
513-
uf,
514-
vf,
515-
time,
516-
self.lon[0],
517-
self.lat[0],
518-
integration_scheme="forward",
519-
time_axis=1,
520-
)
521-
self.assertTrue(np.allclose(lon, expected_lon))
522-
self.assertTrue(np.allclose(lat, expected_lat))
523-
self.assertTrue(np.all(lon.shape == expected_lon.shape))
524-
self.assertTrue(np.all(lat.shape == expected_lat.shape))
525-
526496
def test_works_with_xarray(self):
527497
lon, lat = position_from_velocity(
528498
xr.DataArray(data=self.uf),
@@ -554,6 +524,36 @@ def test_works_with_2d_array(self):
554524
self.assertTrue(np.allclose(lon.shape, expected_lon.shape))
555525
self.assertTrue(np.allclose(lon.shape, expected_lat.shape))
556526

527+
def test_time_axis(self):
528+
uf = np.reshape(np.tile(self.uf, 6), (2, 3, self.uf.size))
529+
vf = np.reshape(np.tile(self.vf, 6), (2, 3, self.vf.size))
530+
time = np.reshape(np.tile(self.time, 6), (2, 3, self.time.size))
531+
expected_lon = np.reshape(np.tile(self.lon, 6), (2, 3, self.lon.size))
532+
expected_lat = np.reshape(np.tile(self.lat, 6), (2, 3, self.lat.size))
533+
534+
for time_axis in [0, 1, 2]:
535+
# Pass inputs with swapped axes and differentiate along that time
536+
# axis.
537+
lon, lat = position_from_velocity(
538+
np.swapaxes(uf, time_axis, -1),
539+
np.swapaxes(vf, time_axis, -1),
540+
np.swapaxes(time, time_axis, -1),
541+
self.lon[0],
542+
self.lat[0],
543+
integration_scheme="forward",
544+
time_axis=time_axis,
545+
)
546+
547+
# Swap axes back to compare with the expected result.
548+
self.assertTrue(np.allclose(np.swapaxes(lon, time_axis, -1), expected_lon))
549+
self.assertTrue(np.allclose(np.swapaxes(lat, time_axis, -1), expected_lat))
550+
self.assertTrue(
551+
np.all(np.swapaxes(lon, time_axis, -1).shape == expected_lon.shape)
552+
)
553+
self.assertTrue(
554+
np.all(np.swapaxes(lat, time_axis, -1).shape == expected_lat.shape)
555+
)
556+
557557

558558
class velocity_from_position_tests(unittest.TestCase):
559559
def setUp(self):
@@ -621,26 +621,31 @@ def test_works_with_3d_array(self):
621621
self.assertTrue(np.all(vf.shape == expected_vf.shape))
622622

623623
def test_time_axis(self):
624-
lon = np.transpose(
625-
np.reshape(np.tile(self.lon, 4), (2, 2, self.lon.size)), (0, 2, 1)
626-
)
627-
lat = np.transpose(
628-
np.reshape(np.tile(self.lat, 4), (2, 2, self.lat.size)), (0, 2, 1)
629-
)
630-
time = np.transpose(
631-
np.reshape(np.tile(self.time, 4), (2, 2, self.time.size)), (0, 2, 1)
632-
)
633-
expected_uf = np.transpose(
634-
np.reshape(np.tile(self.uf, 4), (2, 2, self.uf.size)), (0, 2, 1)
635-
)
636-
expected_vf = np.transpose(
637-
np.reshape(np.tile(self.vf, 4), (2, 2, self.vf.size)), (0, 2, 1)
638-
)
639-
uf, vf = velocity_from_position(lon, lat, time, time_axis=1)
640-
self.assertTrue(np.all(uf == expected_uf))
641-
self.assertTrue(np.all(vf == expected_vf))
642-
self.assertTrue(np.all(uf.shape == expected_uf.shape))
643-
self.assertTrue(np.all(vf.shape == expected_vf.shape))
624+
lon = np.reshape(np.tile(self.lon, 6), (2, 3, self.lon.size))
625+
lat = np.reshape(np.tile(self.lat, 6), (2, 3, self.lat.size))
626+
time = np.reshape(np.tile(self.time, 6), (2, 3, self.time.size))
627+
expected_uf = np.reshape(np.tile(self.uf, 6), (2, 3, self.uf.size))
628+
expected_vf = np.reshape(np.tile(self.vf, 6), (2, 3, self.vf.size))
629+
630+
for time_axis in [0, 1, 2]:
631+
# Pass inputs with swapped axes and differentiate along that time
632+
# axis.
633+
uf, vf = velocity_from_position(
634+
np.swapaxes(lon, time_axis, -1),
635+
np.swapaxes(lat, time_axis, -1),
636+
np.swapaxes(time, time_axis, -1),
637+
time_axis=time_axis,
638+
)
639+
640+
# Swap axes back to compare with the expected result.
641+
self.assertTrue(np.all(np.swapaxes(uf, time_axis, -1) == expected_uf))
642+
self.assertTrue(np.all(np.swapaxes(vf, time_axis, -1) == expected_vf))
643+
self.assertTrue(
644+
np.all(np.swapaxes(uf, time_axis, -1).shape == expected_uf.shape)
645+
)
646+
self.assertTrue(
647+
np.all(np.swapaxes(vf, time_axis, -1).shape == expected_uf.shape)
648+
)
644649

645650

646651
class apply_ragged_tests(unittest.TestCase):

0 commit comments

Comments
 (0)