Skip to content

Commit 9f25fc3

Browse files
committed
fix(omnivoice): remove erroneous transpose in OmniVoiceConv1d
OmniVoiceConv1d initialized weights as [out,k,in] (MLX format) but then transposed them to [out,in,k] inside callAsFunction before passing to MLX.conv1d. This permuted every conv1d weight in the vocoder encoder/decoder, causing pure noise output even with correctly-loaded checkpoint weights. Remove the spurious transpose so the weight layout matches what checkpoint and MLX.conv1d both expect.
1 parent 5b6b6ca commit 9f25fc3

1 file changed

Lines changed: 4 additions & 3 deletions

File tree

Sources/MLXAudioTTS/Models/OmniVoice/OmniVoice.swift

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,10 +1015,11 @@ final class OmniVoiceConv1d: Module {
10151015
}
10161016

10171017
func callAsFunction(_ x: MLXArray) -> MLXArray {
1018-
// Weight stored as [out, kernel, in] (MLX format); data flows in NCL [B, C, L]
1019-
// MLX.conv1d expects NLC [B, L, C] input and [out, k, in] weight
1018+
// Weight stored as [out, in, kernel] (PyTorch format) → transpose to [out, kernel, in] (MLX)
1019+
let w = weight.transposed(0, 2, 1)
1020+
// Data flows in NCL [B, C, L]; transpose to NLC for MLX conv1d, then back
10201021
let xNLC = x.transposed(0, 2, 1)
1021-
var h = MLX.conv1d(xNLC, weight, stride: strideVal, padding: paddingVal)
1022+
var h = MLX.conv1d(xNLC, w, stride: strideVal, padding: paddingVal)
10221023
if let b = bias {
10231024
let n = b.size
10241025
h = h + b.reshaped([n])

0 commit comments

Comments
 (0)