The TCH (libtorch) backend was producing degenerate audio (stuck semantic
codes, 11% prefill norm divergence from MLX). Three root causes fixed:
1. **Restore accidentally deleted RoPE application** — the line
`rotary_emb.forward_at_pos(&q, &k, pos, seq_len)` was removed during
a prior edit, causing the model to have zero positional information.
This was the primary cause of degenerate output.
2. **Interleaved RoPE convention** — Mistral native checkpoints use
interleaved pairs (x[2d], x[2d+1]), not split-half (x[d], x[d+dim/2]).
Rewrote apply_rotary_emb() to reshape→select→rotate→stack.
3. **Native BF16 computation** — libtorch 2.7+ supports BF16 matmul on
Apple Silicon CPU. Keep weights in BF16 (~180x faster than F32),
with Linear casting weight to match input dtype and softmax preserving
input dtype.
After fix: prefill norm=116.19 (MLX: 116.60), frame 0 semantic=7843 on
both backends. Generates real speech audio.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

```bash
LD_LIBRARY_PATH=$LIBTORCH/lib ./target/release/voxtral-tts ... # Linux

# Run tests
cargo test
```

MLX builds a computation graph lazily. The graph must be evaluated periodically with `eval()`.

**Symptom of too few eval():** Graph grows across iterations, causing exponential slowdowns.

**Symptom of too many eval():** Each eval() has fixed overhead for graph traversal, scheduling, and GPU synchronization. Reducing from ~130 to ~8 eval() calls per frame improved flow matching from 0.53s to 0.28s per frame.

### 2. RoPE Must Use Interleaved Pairs (traditional=true)

MLX `fast_rope` has a `traditional` parameter controlling how dimension pairs are formed:

- `traditional=true` (interleaved): pairs `2d` with `2d+1` — **correct for Mistral native checkpoints**
- `traditional=false` (split-half): pairs dimension `d` with `d + dim/2`

Mistral's native safetensors checkpoint uses the interleaved convention. HuggingFace/vLLM internally permute Q/K weights to use split-half, but our weights are in native format.

The tch backend must match this exactly — reshape to `[..., D/2, 2]`, select even/odd on the last dim, apply rotation, stack back and reshape. See `apply_rotary_emb()` in `layers.rs` (a sketch follows).
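
A minimal sketch of that interleaved rotation in tch-rs, assuming `[B, H, T, D]` inputs and precomputed `cos`/`sin` tables of shape `[T, D/2]`; the actual `apply_rotary_emb()` in `layers.rs` may organize this differently:

```rust
use tch::Tensor;

// Interleaved RoPE: rotate pairs (x[2d], x[2d+1]) rather than (x[d], x[d+D/2]).
fn apply_rotary_interleaved(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Tensor {
    let (b, h, t, d) = x.size4().unwrap();
    // [B, H, T, D] -> [B, H, T, D/2, 2]: the last dim holds (even, odd) pairs.
    let xr = x.reshape([b, h, t, d / 2, 2]);
    let x_even = xr.select(-1, 0); // x[..., 2d]
    let x_odd = xr.select(-1, 1);  // x[..., 2d+1]
    // Rotation: (e, o) -> (e*cos - o*sin, e*sin + o*cos); cos/sin broadcast over B, H.
    let out_even = &x_even * cos - &x_odd * sin;
    let out_odd = &x_even * sin + &x_odd * cos;
    // Stack the rotated pairs back and flatten to [B, H, T, D].
    Tensor::stack(&[out_even, out_odd], -1).reshape([b, h, t, d])
}
```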

**Symptom:** Backbone hidden states have wrong norms, semantic code stuck at one value (e.g., 855 or 10), END_AUDIO never predicted. All weights are correct but outputs diverge at Layer 0.

**Discovery method:** Compare layer-by-layer outputs against mlx-audio reference. Divergence appears immediately at Layer 0 attention output when RoPE convention is wrong.

### 3. KV Cache Must Replace, Not Concatenate
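
The idea, as a hedged tch-rs sketch — the cache layout `[B, H, MAX_T, Dh]` and the `pos` bookkeeping are assumptions, not the project's exact code:

```rust
use tch::Tensor;

// Write this step's keys into a preallocated cache slot instead of
// concatenating a new, ever-growing tensor every decode step.
fn cache_write(cache_k: &Tensor, new_k: &Tensor, pos: i64) {
    let seq_len = new_k.size()[2]; // new_k: [B, H, seq_len, Dh]
    // slice() returns a view into the cache; copy_ fills it in place.
    let mut slot = cache_k.slice(2, pos, pos + seq_len, 1);
    slot.copy_(new_k);
}
```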

The `tensor.rs` conv methods handle these transposes automatically.
`Device::best_available()` must call `init_mlx(true)` before any MLX operations. Without this, all MLX calls panic with "MLX not initialized".

## TCH Backend (libtorch) -- Critical Lessons

### 1. libtorch Version Must Match tch Crate

The `tch` Rust crate pins to a specific PyTorch major version. `tch 0.20` requires **libtorch 2.7.x**. Using an older version (e.g., 2.4, 2.6) causes C++ compilation errors in `torch-sys` (`no member named '_dyn_quant_matmul_4bit'`, etc.). Set `LIBTORCH_BYPASS_VERSION_CHECK=1` for patch version mismatches (e.g., 2.7.1 vs 2.7.0).

### 2. BF16 Matmul Works on CPU (libtorch 2.7+)

libtorch 2.7+ supports BF16 matmul on Apple Silicon ARM64 CPU. Weights can stay in BF16 (as loaded from safetensors) without converting to F32. This is ~180x faster than F32 matmul on CPU for this model.

Key implementation details (see the sketch below):

- `Linear::forward` casts weight to match input dtype, so BF16 input → BF16 matmul, F32 input → F32 matmul
- `RMSNorm::forward` computes variance in F32 for stability, then casts output back to input dtype
- `softmax()` computes in F32 internally, then casts back to input dtype (not hardcoded to F32 output)
- Voice embeddings and all model weights are kept in BF16 (no `need_f32` conversion)
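
A hedged sketch of that dtype-preserving pattern in tch-rs; the struct shape and helper name are illustrative, not the exact `layers.rs` code:

```rust
use tch::{Kind, Tensor};

struct Linear {
    weight: Tensor, // stored in BF16, as loaded from safetensors
    bias: Option<Tensor>,
}

impl Linear {
    fn forward(&self, x: &Tensor) -> Tensor {
        // Cast the weight to the input dtype: BF16 in -> BF16 matmul, F32 in -> F32 matmul.
        let w = self.weight.to_kind(x.kind());
        let mut out = x.matmul(&w.transpose(-2, -1));
        if let Some(b) = &self.bias {
            out = out + b.to_kind(x.kind());
        }
        out
    }
}

fn softmax_preserving(x: &Tensor, dim: i64) -> Tensor {
    // Compute in F32 for stability, then cast back to the input dtype.
    x.to_kind(Kind::Float).softmax(dim, Kind::Float).to_kind(x.kind())
}
```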

### 3. RoPE Application Must Not Be Accidentally Deleted

The single most critical line in `Attention::forward` is:

```rust
let (q, k) = rotary_emb.forward_at_pos(&q, &k, pos, seq_len);
```

This must appear after Q/K reshape to multi-head format and before KV cache concatenation. Without this line, the model has no positional information and produces degenerate output (stuck semantic codes, wrong hidden state norms).

**This line was accidentally deleted during an edit and took extensive debugging to find.** The symptom (11% prefill norm divergence, degenerate codes) looked like a precision issue but was actually a missing computation.

### 4. CUDA GPU Support

The TCH backend supports CUDA GPUs with zero code changes. `Device::best_available()` auto-detects CUDA via `tch::Cuda::is_available()`, and `Device::Gpu(i)` maps to `tch::Device::Cuda(i)`. All tensor operations (matmul, softmax, RoPE, etc.) work identically on CUDA; just download the CUDA variant of libtorch instead of the CPU-only build.
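
For illustration, the auto-detection amounts to something like this sketch; the project's own `Device` enum wraps this mapping:

```rust
// Pick CUDA when available, otherwise CPU; Device::Gpu(i) maps to Cuda(i).
fn best_available_tch_device() -> tch::Device {
    if tch::Cuda::is_available() {
        tch::Device::Cuda(0)
    } else {
        tch::Device::Cpu
    }
}
```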

On GPU (CUDA or MLX), `to_vec_f32()` calls in diagnostic logging trigger device→CPU copies. Gate expensive logging behind `tracing::enabled!(Level::DEBUG)` or keep it only at key checkpoints (prefill output, first decode step).
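
For example, a debug-gated norm dump might look like this; a hedged sketch, where `hidden` and the checkpoint name are placeholders:

```rust
use tracing::{enabled, Level};

fn log_checkpoint(hidden: &tch::Tensor) {
    // Only pay the device->CPU copy when DEBUG logging is actually enabled.
    if enabled!(Level::DEBUG) {
        let norm = hidden.to_kind(tch::Kind::Float).norm().double_value(&[]);
        tracing::debug!(norm, "prefill output norm");
    }
}
```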

## Special Token IDs

| Token | ID | Usage |
|-------|----|-------|

The codec transformer layers have three features not present in the backbone.

## Performance Notes

On Apple M4 Max (TCH backend, BF16 CPU):

- Model loading: ~18s (full tensor copy, no memory mapping)