Skip to content

Commit 3fc07b5

Browse files
authored
Fix torchscript related test failures. (#4069)
1 parent 190c3ea commit 3fc07b5

5 files changed

Lines changed: 41 additions & 11 deletions

File tree

.github/scripts/unittest-linux/run_test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,5 @@ fi
2929
export TORCHAUDIO_TEST_ALLOW_SKIP_IF_NO_MOD_pytorch_lightning=true
3030
export TORCHAUDIO_TEST_ALLOW_SKIP_IF_NO_MULTIGPU_CUDA=true
3131
cd test
32-
pytest torchaudio_unittest -k "not torchscript and not fairseq and not demucs ${PYTEST_K_EXTRA}"
32+
pytest torchaudio_unittest -k "not fairseq and not demucs ${PYTEST_K_EXTRA}"
3333
)

.github/scripts/unittest-windows/run_test.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ env | grep TORCHAUDIO || true
1313

1414
cd test
1515
if [ -z "${CUDA_VERSION:-}" ] ; then
16-
pytest --continue-on-collection-errors --cov=torchaudio --junitxml=${RUNNER_TEST_RESULTS_DIR}/junit.xml -v --durations 20 torchaudio_unittest -k "not torchscript and not fairseq and not demucs and not librosa"
16+
pytest --continue-on-collection-errors --cov=torchaudio --junitxml=${RUNNER_TEST_RESULTS_DIR}/junit.xml -v --durations 20 torchaudio_unittest -k "not fairseq and not demucs and not librosa"
1717
else
18-
pytest --continue-on-collection-errors --cov=torchaudio --junitxml=${RUNNER_TEST_RESULTS_DIR}/junit.xml -v --durations 20 torchaudio_unittest -k "not cpu and (cuda or gpu) and not torchscript and not fairseq and not demucs and not librosa"
18+
pytest --continue-on-collection-errors --cov=torchaudio --junitxml=${RUNNER_TEST_RESULTS_DIR}/junit.xml -v --durations 20 torchaudio_unittest -k "not cpu and (cuda or gpu) and not fairseq and not demucs and not librosa"
1919
fi
2020
coverage html

src/torchaudio/functional/filtering.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -946,7 +946,8 @@ def forward(ctx, waveform, b_coeffs):
946946
b_coeff_flipped = b_coeffs.flip(1).contiguous()
947947
padded_waveform = F.pad(waveform, (n_order - 1, 0))
948948
output = F.conv1d(padded_waveform, b_coeff_flipped.unsqueeze(1), groups=n_channel)
949-
ctx.save_for_backward(waveform, b_coeffs, output)
949+
if not torch.jit.is_scripting():
950+
ctx.save_for_backward(waveform, b_coeffs, output)
950951
return output
951952

952953
@staticmethod
@@ -955,6 +956,7 @@ def backward(ctx, dy):
955956
n_batch = x.size(0)
956957
n_channel = x.size(1)
957958
n_order = b_coeffs.size(1)
959+
958960
db = (
959961
F.conv1d(
960962
F.pad(x, (n_order - 1, 0)).view(1, n_batch * n_channel, -1),
@@ -970,6 +972,13 @@ def backward(ctx, dy):
970972
dx = F.conv1d(F.pad(dy, (0, n_order - 1)), b_coeffs.unsqueeze(1), groups=n_channel) if x.requires_grad else None
971973
return (dx, db)
972974

975+
@staticmethod
976+
def ts_apply(waveform, b_coeffs):
977+
if torch.jit.is_scripting():
978+
return DifferentiableFIR.forward(torch.empty(0), waveform, b_coeffs)
979+
else:
980+
return DifferentiableFIR.apply(waveform, b_coeffs)
981+
973982

974983
class DifferentiableIIR(torch.autograd.Function):
975984
@staticmethod
@@ -984,7 +993,8 @@ def forward(ctx, waveform, a_coeffs_normalized):
984993
)
985994
_lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform)
986995
output = padded_output_waveform[:, :, n_order - 1 :]
987-
ctx.save_for_backward(waveform, a_coeffs_normalized, output)
996+
if not torch.jit.is_scripting():
997+
ctx.save_for_backward(waveform, a_coeffs_normalized, output)
988998
return output
989999

9901000
@staticmethod
@@ -1006,10 +1016,17 @@ def backward(ctx, dy):
10061016
)
10071017
return (dx, da)
10081018

1019+
@staticmethod
1020+
def ts_apply(waveform, a_coeffs_normalized):
1021+
if torch.jit.is_scripting():
1022+
return DifferentiableIIR.forward(torch.empty(0), waveform, a_coeffs_normalized)
1023+
else:
1024+
return DifferentiableIIR.apply(waveform, a_coeffs_normalized)
1025+
10091026

10101027
def _lfilter(waveform, a_coeffs, b_coeffs):
1011-
filtered_waveform = DifferentiableFIR.apply(waveform, b_coeffs / a_coeffs[:, 0:1])
1012-
return DifferentiableIIR.apply(filtered_waveform, a_coeffs / a_coeffs[:, 0:1])
1028+
filtered_waveform = DifferentiableFIR.ts_apply(waveform, b_coeffs / a_coeffs[:, 0:1])
1029+
return DifferentiableIIR.ts_apply(filtered_waveform, a_coeffs / a_coeffs[:, 0:1])
10131030

10141031

10151032
def lfilter(waveform: Tensor, a_coeffs: Tensor, b_coeffs: Tensor, clamp: bool = True, batching: bool = True) -> Tensor:

src/torchaudio/functional/functional.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -847,7 +847,8 @@ def mask_along_axis_iid(
847847

848848
if axis not in [dim - 2, dim - 1]:
849849
raise ValueError(
850-
f"Only Frequency and Time masking are supported (axis {dim-2} and axis {dim-1} supported; {axis} given)."
850+
"Only Frequency and Time masking are supported"
851+
f" (axis {dim - 2} and axis {dim - 1} supported; {axis} given)."
851852
)
852853

853854
if not 0.0 <= p <= 1.0:
@@ -919,7 +920,8 @@ def mask_along_axis(
919920

920921
if axis not in [dim - 2, dim - 1]:
921922
raise ValueError(
922-
f"Only Frequency and Time masking are supported (axis {dim-2} and axis {dim-1} supported; {axis} given)."
923+
"Only Frequency and Time masking are supported"
924+
f" (axis {dim - 2} and axis {dim - 1} supported; {axis} given)."
923925
)
924926

925927
if not 0.0 <= p <= 1.0:
@@ -1731,6 +1733,16 @@ def backward(ctx, dy):
17311733
result = grad * grad_out
17321734
return (result, None, None, None, None, None, None, None)
17331735

1736+
@staticmethod
1737+
def ts_apply(logits, targets, logit_lengths, target_lengths, blank: int, clamp: float, fused_log_softmax: bool):
1738+
if torch.jit.is_scripting():
1739+
output, saved = torch.ops.torchaudio.rnnt_loss_forward(
1740+
logits, targets, logit_lengths, target_lengths, blank, clamp, fused_log_softmax
1741+
)
1742+
return output
1743+
else:
1744+
return RnntLoss.apply(logits, targets, logit_lengths, target_lengths, blank, clamp, fused_log_softmax)
1745+
17341746

17351747
def rnnt_loss(
17361748
logits: Tensor,
@@ -1774,7 +1786,7 @@ def rnnt_loss(
17741786
if blank < 0: # reinterpret blank index if blank < 0.
17751787
blank = logits.shape[-1] + blank
17761788

1777-
costs = RnntLoss.apply(logits, targets, logit_lengths, target_lengths, blank, clamp, fused_log_softmax)
1789+
costs = RnntLoss.ts_apply(logits, targets, logit_lengths, target_lengths, blank, clamp, fused_log_softmax)
17781790

17791791
if reduction == "mean":
17801792
return costs.mean()

src/torchaudio/transforms/_transforms.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1202,7 +1202,8 @@ def forward(self, specgram: Tensor, mask_value: Union[float, torch.Tensor] = 0.0
12021202
specgram, self.mask_param, mask_value, self.axis + specgram.dim() - 3, p=self.p
12031203
)
12041204
else:
1205-
return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis + specgram.dim() - 3, p=self.p)
1205+
mask_value_ = float(mask_value) if isinstance(mask_value, Tensor) else mask_value
1206+
return F.mask_along_axis(specgram, self.mask_param, mask_value_, self.axis + specgram.dim() - 3, p=self.p)
12061207

12071208

12081209
class FrequencyMasking(_AxisMasking):

0 commit comments

Comments (0)