Skip to content

Commit 09461c4

Browse files
authored
🐛 refactor wind_transfer calls in TransferFunctionTestGradient for compatibility with python 3.11-12 (#586)
* 🐛 refactor wind_transfer calls in TransferFunctionTestGradient for compatibility with python 3.11-12 * 🐛 fix variable name from 'result' to 'results' for consistency in TransferFunctionTestGradient
1 parent 8fda8d1 commit 09461c4

File tree

1 file changed

+56
-13
lines changed

1 file changed

+56
-13
lines changed

tests/transfer_test.py

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -612,48 +612,91 @@ def test_gradient_ekman_case(self):
612612

613613
for i in range(len(self.delta)):
614614
for j in range(len(self.bld)):
615-
wind_transfer_init[i, j], dG_ddelta[i, j], dG_dbld[i, j] = (
616-
wind_transfer(
617-
omega=omega,
618-
z=z,
619-
cor_freq=cor_freq,
620-
delta=self.delta[i],
621-
mu=mu,
622-
bld=self.bld[j],
623-
)
615+
results = wind_transfer(
616+
omega=omega,
617+
z=z,
618+
cor_freq=cor_freq,
619+
delta=self.delta[i],
620+
mu=mu,
621+
bld=self.bld[j],
622+
)
623+
wind_transfer_init[i, j] = (
624+
results[0].item()
625+
if isinstance(results[0], np.ndarray)
626+
else results[0]
624627
)
625-
wind_transfer_ddelta_plus[i, j], _, _ = wind_transfer(
628+
dG_ddelta[i, j] = (
629+
results[1].item()
630+
if isinstance(results[1], np.ndarray)
631+
else results[1]
632+
)
633+
dG_dbld[i, j] = (
634+
results[2].item()
635+
if isinstance(results[2], np.ndarray)
636+
else results[2]
637+
)
638+
# wind_transfer_init[i, j], dG_ddelta[i, j], dG_dbld[i, j] = (
639+
# wind_transfer(
640+
# omega=omega,
641+
# z=z,
642+
# cor_freq=cor_freq,
643+
# delta=self.delta[i],
644+
# mu=mu,
645+
# bld=self.bld[j],
646+
# )
647+
# )
648+
results = wind_transfer(
626649
omega=omega,
627650
z=z,
628651
cor_freq=cor_freq,
629652
delta=self.delta[i] + delta_delta / 2,
630653
mu=mu,
631654
bld=self.bld[j],
632655
)
633-
wind_transfer_ddelta_minus[i, j], _, _ = wind_transfer(
656+
wind_transfer_ddelta_plus[i, j] = (
657+
results[0].item()
658+
if isinstance(results[0], np.ndarray)
659+
else results[0]
660+
)
661+
results = wind_transfer(
634662
omega=omega,
635663
z=z,
636664
cor_freq=cor_freq,
637665
delta=self.delta[i] - delta_delta / 2,
638666
mu=mu,
639667
bld=self.bld[j],
640668
)
641-
wind_transfer_dbld_plus[i, j], _, _ = wind_transfer(
669+
wind_transfer_ddelta_minus[i, j] = (
670+
results[0].item()
671+
if isinstance(results[0], np.ndarray)
672+
else results[0]
673+
)
674+
results = wind_transfer(
642675
omega=omega,
643676
z=z,
644677
cor_freq=cor_freq,
645678
delta=self.delta[i],
646679
mu=mu,
647680
bld=self.bld[j] + delta_bld / 2,
648681
)
649-
wind_transfer_dbld_minus[i, j], _, _ = wind_transfer(
682+
wind_transfer_dbld_plus[i, j] = (
683+
results[0].item()
684+
if isinstance(results[0], np.ndarray)
685+
else results[0]
686+
)
687+
results = wind_transfer(
650688
omega=omega,
651689
z=z,
652690
cor_freq=cor_freq,
653691
delta=self.delta[i],
654692
mu=mu,
655693
bld=self.bld[j] - delta_bld / 2,
656694
)
695+
wind_transfer_dbld_minus[i, j] = (
696+
results[0].item()
697+
if isinstance(results[0], np.ndarray)
698+
else results[0]
699+
)
657700

658701
dG_ddelta_fd = (
659702
wind_transfer_ddelta_plus - wind_transfer_ddelta_minus

0 commit comments

Comments
 (0)