@@ -493,36 +493,6 @@ def test_velocity_position_roundtrip_centered(self):
493493 self .assertTrue (np .allclose (lon , self .lon , atol = 1e-2 ))
494494 self .assertTrue (np .allclose (lat , self .lat , atol = 1e-2 ))
495495
496- def test_time_axis (self ):
497- uf = np .transpose (
498- np .reshape (np .tile (self .uf , 4 ), (2 , 2 , self .uf .size )), (0 , 2 , 1 )
499- )
500- vf = np .transpose (
501- np .reshape (np .tile (self .vf , 4 ), (2 , 2 , self .vf .size )), (0 , 2 , 1 )
502- )
503- time = np .transpose (
504- np .reshape (np .tile (self .time , 4 ), (2 , 2 , self .time .size )), (0 , 2 , 1 )
505- )
506- expected_lon = np .transpose (
507- np .reshape (np .tile (self .lon , 4 ), (2 , 2 , self .lon .size )), (0 , 2 , 1 )
508- )
509- expected_lat = np .transpose (
510- np .reshape (np .tile (self .lat , 4 ), (2 , 2 , self .lat .size )), (0 , 2 , 1 )
511- )
512- lon , lat = position_from_velocity (
513- uf ,
514- vf ,
515- time ,
516- self .lon [0 ],
517- self .lat [0 ],
518- integration_scheme = "forward" ,
519- time_axis = 1 ,
520- )
521- self .assertTrue (np .allclose (lon , expected_lon ))
522- self .assertTrue (np .allclose (lat , expected_lat ))
523- self .assertTrue (np .all (lon .shape == expected_lon .shape ))
524- self .assertTrue (np .all (lat .shape == expected_lat .shape ))
525-
526496 def test_works_with_xarray (self ):
527497 lon , lat = position_from_velocity (
528498 xr .DataArray (data = self .uf ),
@@ -554,6 +524,36 @@ def test_works_with_2d_array(self):
554524 self .assertTrue (np .allclose (lon .shape , expected_lon .shape ))
555525 self .assertTrue (np .allclose (lon .shape , expected_lat .shape ))
556526
527+ def test_time_axis (self ):
528+ uf = np .reshape (np .tile (self .uf , 6 ), (2 , 3 , self .uf .size ))
529+ vf = np .reshape (np .tile (self .vf , 6 ), (2 , 3 , self .vf .size ))
530+ time = np .reshape (np .tile (self .time , 6 ), (2 , 3 , self .time .size ))
531+ expected_lon = np .reshape (np .tile (self .lon , 6 ), (2 , 3 , self .lon .size ))
532+ expected_lat = np .reshape (np .tile (self .lat , 6 ), (2 , 3 , self .lat .size ))
533+
534+ for time_axis in [0 , 1 , 2 ]:
535+ # Pass inputs with swapped axes and differentiate along that time
536+ # axis.
537+ lon , lat = position_from_velocity (
538+ np .swapaxes (uf , time_axis , - 1 ),
539+ np .swapaxes (vf , time_axis , - 1 ),
540+ np .swapaxes (time , time_axis , - 1 ),
541+ self .lon [0 ],
542+ self .lat [0 ],
543+ integration_scheme = "forward" ,
544+ time_axis = time_axis ,
545+ )
546+
547+ # Swap axes back to compare with the expected result.
548+ self .assertTrue (np .allclose (np .swapaxes (lon , time_axis , - 1 ), expected_lon ))
549+ self .assertTrue (np .allclose (np .swapaxes (lat , time_axis , - 1 ), expected_lat ))
550+ self .assertTrue (
551+ np .all (np .swapaxes (lon , time_axis , - 1 ).shape == expected_lon .shape )
552+ )
553+ self .assertTrue (
554+ np .all (np .swapaxes (lat , time_axis , - 1 ).shape == expected_lat .shape )
555+ )
556+
557557
558558class velocity_from_position_tests (unittest .TestCase ):
559559 def setUp (self ):
@@ -621,26 +621,31 @@ def test_works_with_3d_array(self):
621621 self .assertTrue (np .all (vf .shape == expected_vf .shape ))
622622
623623 def test_time_axis (self ):
624- lon = np .transpose (
625- np .reshape (np .tile (self .lon , 4 ), (2 , 2 , self .lon .size )), (0 , 2 , 1 )
626- )
627- lat = np .transpose (
628- np .reshape (np .tile (self .lat , 4 ), (2 , 2 , self .lat .size )), (0 , 2 , 1 )
629- )
630- time = np .transpose (
631- np .reshape (np .tile (self .time , 4 ), (2 , 2 , self .time .size )), (0 , 2 , 1 )
632- )
633- expected_uf = np .transpose (
634- np .reshape (np .tile (self .uf , 4 ), (2 , 2 , self .uf .size )), (0 , 2 , 1 )
635- )
636- expected_vf = np .transpose (
637- np .reshape (np .tile (self .vf , 4 ), (2 , 2 , self .vf .size )), (0 , 2 , 1 )
638- )
639- uf , vf = velocity_from_position (lon , lat , time , time_axis = 1 )
640- self .assertTrue (np .all (uf == expected_uf ))
641- self .assertTrue (np .all (vf == expected_vf ))
642- self .assertTrue (np .all (uf .shape == expected_uf .shape ))
643- self .assertTrue (np .all (vf .shape == expected_vf .shape ))
624+ lon = np .reshape (np .tile (self .lon , 6 ), (2 , 3 , self .lon .size ))
625+ lat = np .reshape (np .tile (self .lat , 6 ), (2 , 3 , self .lat .size ))
626+ time = np .reshape (np .tile (self .time , 6 ), (2 , 3 , self .time .size ))
627+ expected_uf = np .reshape (np .tile (self .uf , 6 ), (2 , 3 , self .uf .size ))
628+ expected_vf = np .reshape (np .tile (self .vf , 6 ), (2 , 3 , self .vf .size ))
629+
630+ for time_axis in [0 , 1 , 2 ]:
631+ # Pass inputs with swapped axes and differentiate along that time
632+ # axis.
633+ uf , vf = velocity_from_position (
634+ np .swapaxes (lon , time_axis , - 1 ),
635+ np .swapaxes (lat , time_axis , - 1 ),
636+ np .swapaxes (time , time_axis , - 1 ),
637+ time_axis = time_axis ,
638+ )
639+
640+ # Swap axes back to compare with the expected result.
641+ self .assertTrue (np .all (np .swapaxes (uf , time_axis , - 1 ) == expected_uf ))
642+ self .assertTrue (np .all (np .swapaxes (vf , time_axis , - 1 ) == expected_vf ))
643+ self .assertTrue (
644+ np .all (np .swapaxes (uf , time_axis , - 1 ).shape == expected_uf .shape )
645+ )
646+ self .assertTrue (
647+ np .all (np .swapaxes (vf , time_axis , - 1 ).shape == expected_uf .shape )
648+ )
644649
645650
646651class apply_ragged_tests (unittest .TestCase ):
0 commit comments