@@ -612,48 +612,91 @@ def test_gradient_ekman_case(self):
612612
613613 for i in range (len (self .delta )):
614614 for j in range (len (self .bld )):
615- wind_transfer_init [i , j ], dG_ddelta [i , j ], dG_dbld [i , j ] = (
616- wind_transfer (
617- omega = omega ,
618- z = z ,
619- cor_freq = cor_freq ,
620- delta = self .delta [i ],
621- mu = mu ,
622- bld = self .bld [j ],
623- )
615+ results = wind_transfer (
616+ omega = omega ,
617+ z = z ,
618+ cor_freq = cor_freq ,
619+ delta = self .delta [i ],
620+ mu = mu ,
621+ bld = self .bld [j ],
622+ )
623+ wind_transfer_init [i , j ] = (
624+ results [0 ].item ()
625+ if isinstance (results [0 ], np .ndarray )
626+ else results [0 ]
624627 )
625- wind_transfer_ddelta_plus [i , j ], _ , _ = wind_transfer (
628+ dG_ddelta [i , j ] = (
629+ results [1 ].item ()
630+ if isinstance (results [1 ], np .ndarray )
631+ else results [1 ]
632+ )
633+ dG_dbld [i , j ] = (
634+ results [2 ].item ()
635+ if isinstance (results [2 ], np .ndarray )
636+ else results [2 ]
637+ )
638+ # wind_transfer_init[i, j], dG_ddelta[i, j], dG_dbld[i, j] = (
639+ # wind_transfer(
640+ # omega=omega,
641+ # z=z,
642+ # cor_freq=cor_freq,
643+ # delta=self.delta[i],
644+ # mu=mu,
645+ # bld=self.bld[j],
646+ # )
647+ # )
648+ results = wind_transfer (
626649 omega = omega ,
627650 z = z ,
628651 cor_freq = cor_freq ,
629652 delta = self .delta [i ] + delta_delta / 2 ,
630653 mu = mu ,
631654 bld = self .bld [j ],
632655 )
633- wind_transfer_ddelta_minus [i , j ], _ , _ = wind_transfer (
656+ wind_transfer_ddelta_plus [i , j ] = (
657+ results [0 ].item ()
658+ if isinstance (results [0 ], np .ndarray )
659+ else results [0 ]
660+ )
661+ results = wind_transfer (
634662 omega = omega ,
635663 z = z ,
636664 cor_freq = cor_freq ,
637665 delta = self .delta [i ] - delta_delta / 2 ,
638666 mu = mu ,
639667 bld = self .bld [j ],
640668 )
641- wind_transfer_dbld_plus [i , j ], _ , _ = wind_transfer (
669+ wind_transfer_ddelta_minus [i , j ] = (
670+ results [0 ].item ()
671+ if isinstance (results [0 ], np .ndarray )
672+ else results [0 ]
673+ )
674+ results = wind_transfer (
642675 omega = omega ,
643676 z = z ,
644677 cor_freq = cor_freq ,
645678 delta = self .delta [i ],
646679 mu = mu ,
647680 bld = self .bld [j ] + delta_bld / 2 ,
648681 )
649- wind_transfer_dbld_minus [i , j ], _ , _ = wind_transfer (
682+ wind_transfer_dbld_plus [i , j ] = (
683+ results [0 ].item ()
684+ if isinstance (results [0 ], np .ndarray )
685+ else results [0 ]
686+ )
687+ results = wind_transfer (
650688 omega = omega ,
651689 z = z ,
652690 cor_freq = cor_freq ,
653691 delta = self .delta [i ],
654692 mu = mu ,
655693 bld = self .bld [j ] - delta_bld / 2 ,
656694 )
695+ wind_transfer_dbld_minus [i , j ] = (
696+ results [0 ].item ()
697+ if isinstance (results [0 ], np .ndarray )
698+ else results [0 ]
699+ )
657700
658701 dG_ddelta_fd = (
659702 wind_transfer_ddelta_plus - wind_transfer_ddelta_minus
0 commit comments