
Commit eab3338

juntaoclaude committed

Fix TCH backend: restore RoPE, interleaved pairs, BF16 computation

The TCH (libtorch) backend was producing degenerate audio (stuck semantic codes, 11% prefill norm divergence from MLX). Three root causes fixed:

1. **Restore accidentally deleted RoPE application** — the line `rotary_emb.forward_at_pos(&q, &k, pos, seq_len)` was removed during a prior edit, causing the model to have zero positional information. This was the primary cause of degenerate output.
2. **Interleaved RoPE convention** — Mistral native checkpoints use interleaved pairs (x[2d], x[2d+1]), not split-half (x[d], x[d+dim/2]). Rewrote apply_rotary_emb() to reshape→select→rotate→stack.
3. **Native BF16 computation** — libtorch 2.7+ supports BF16 matmul on Apple Silicon CPU. Keep weights in BF16 (~180x faster than F32), with Linear casting the weight to match the input dtype and softmax preserving the input dtype.

After the fix: prefill norm=116.19 (MLX: 116.60), frame 0 semantic=7843 on both backends. Generates real speech audio.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

1 parent b2105c9 commit eab3338

6 files changed: 232 additions & 43 deletions


CLAUDE.md

Lines changed: 68 additions & 8 deletions
@@ -5,14 +5,26 @@ Development notes, architecture decisions, and lessons learned during the port o
 ## Build Commands
 
 ```bash
-# macOS (MLX)
+# macOS (MLX — Apple Silicon GPU)
 git submodule update --init --recursive
 cargo build --release --no-default-features --features mlx
 
-# Linux (libtorch)
+# macOS (libtorch — CPU, for testing/development)
+# Download libtorch 2.7.x (must match tch crate version — tch 0.20 requires 2.7)
+curl -Lo libtorch.zip https://download.pytorch.org/libtorch/cpu/libtorch-macos-arm64-2.7.1.zip
+unzip libtorch.zip
 export LIBTORCH=$(pwd)/libtorch
+export LIBTORCH_BYPASS_VERSION_CHECK=1
 cargo build --release
 
+# Linux (libtorch — CPU or CUDA GPU)
+export LIBTORCH=$(pwd)/libtorch
+cargo build --release
+
+# Run (libtorch needs library path)
+DYLD_LIBRARY_PATH=$LIBTORCH/lib ./target/release/voxtral-tts ...  # macOS
+LD_LIBRARY_PATH=$LIBTORCH/lib ./target/release/voxtral-tts ...    # Linux
+
 # Run tests
 cargo test
 
@@ -104,17 +116,19 @@ MLX builds a computation graph lazily. The graph must be evaluated periodically
 **Symptom of too few eval():** Graph grows across iterations, causing exponential slowdowns.
 **Symptom of too many eval():** Each eval() has fixed overhead for graph traversal, scheduling, and GPU synchronization. Reducing from ~130 to ~8 eval() calls per frame improved flow matching from 0.53s to 0.28s per frame.
 
-### 2. RoPE Must Use Split-Half (traditional=true)
+### 2. RoPE Must Use Interleaved Pairs (traditional=true)
 
 MLX `fast_rope` has a `traditional` parameter controlling how dimension pairs are formed:
-- `traditional=true` (split-half): pairs dimension `d` with `d + dim/2` — **correct for Llama/Mistral**
-- `traditional=false` (interleaved): pairs `2d` with `2d+1`
+- `traditional=true` (**interleaved**): pairs consecutive dimensions `(x[2d], x[2d+1])` — **correct for Mistral native checkpoints**
+- `traditional=false` (split-half): pairs `(x[d], x[d + dim/2])`
+
+Mistral's native safetensors checkpoint uses the interleaved convention. HuggingFace/vLLM internally permute Q/K weights to use split-half, but our weights are in native format.
 
-Llama/Mistral models require split-half format. Using interleaved format corrupts all attention computations, causing the backbone hidden states to be completely wrong.
+The tch backend must match this exactly — reshape to `[..., D/2, 2]`, select even/odd on the last dim, apply rotation, stack back and reshape. See `apply_rotary_emb()` in `layers.rs`.
 
-**Symptom:** Backbone hidden states have wrong norms, semantic code is always the same value (e.g., 10), END_AUDIO is never predicted. All weights are correct but outputs diverge at Layer 0.
+**Symptom:** Backbone hidden states have wrong norms, semantic code stuck at one value (e.g., 855 or 10), END_AUDIO never predicted. All weights are correct but outputs diverge at Layer 0.
 
-**Discovery method:** Compare layer-by-layer outputs against mlx-audio reference implementation. The divergence appears immediately at Layer 0 attention output when RoPE is wrong.
+**Discovery method:** Compare layer-by-layer outputs against mlx-audio reference. Divergence appears immediately at Layer 0 attention output when RoPE convention is wrong.
 
 ### 3. KV Cache Must Replace, Not Concatenate
 
@@ -147,6 +161,44 @@ The `tensor.rs` conv methods handle these transposes automatically. The weights
 
 `Device::best_available()` must call `init_mlx(true)` before any MLX operations. Without this, all MLX calls panic with "MLX not initialized".
 
+## TCH Backend (libtorch) -- Critical Lessons
+
+### 1. libtorch Version Must Match tch Crate
+
+The `tch` Rust crate pins to a specific PyTorch major version. `tch 0.20` requires **libtorch 2.7.x**. Using an older version (e.g., 2.4, 2.6) causes C++ compilation errors in `torch-sys` (`no member named '_dyn_quant_matmul_4bit'`, etc.). Set `LIBTORCH_BYPASS_VERSION_CHECK=1` for patch version mismatches (e.g., 2.7.1 vs 2.7.0).
+
+### 2. BF16 Matmul Works on CPU (libtorch 2.7+)
+
+libtorch 2.7+ supports BF16 matmul on Apple Silicon ARM64 CPU. Weights can stay in BF16 (as loaded from safetensors) without converting to F32. This is ~180x faster than F32 matmul on CPU for this model.
+
+Key implementation details:
+- `Linear::forward` casts weight to match input dtype, so BF16 input → BF16 matmul, F32 input → F32 matmul
+- `RMSNorm::forward` computes variance in F32 for stability, then casts output back to input dtype
+- `softmax()` computes in F32 internally, then casts back to input dtype (not hardcoded to F32 output)
+- Voice embeddings and all model weights are kept in BF16 (no `need_f32` conversion)
+
+### 3. RoPE Application Must Not Be Accidentally Deleted
+
+The single most critical line in `Attention::forward` is:
+```rust
+let (q, k) = rotary_emb.forward_at_pos(&q, &k, pos, seq_len);
+```
+This must appear after Q/K reshape to multi-head format and before KV cache concatenation. Without this line, the model has no positional information and produces degenerate output (stuck semantic codes, wrong hidden state norms).
+
+**This line was accidentally deleted during an Edit operation and took extensive debugging to find.** The symptom (11% prefill norm divergence, degenerate codes) looked like a precision issue but was actually a missing computation.
+
+### 4. CUDA GPU Support
+
+The TCH backend supports CUDA GPUs with zero code changes. `Device::best_available()` auto-detects CUDA via `tch::Cuda::is_available()`, and `Device::Gpu(i)` maps to `tch::Device::Cuda(i)`. All tensor operations (matmul, softmax, RoPE, etc.) work identically on CUDA. Download the CUDA variant of libtorch:
+```
+# CUDA 12.6 example
+curl -Lo libtorch.zip https://download.pytorch.org/libtorch/cu126/libtorch-cxx11-abi-shared-with-deps-2.7.1%2Bcu126-linux-x86_64.zip
+```
+
+### 5. Diagnostic Logging Considerations
+
+On GPU (CUDA or MLX), `to_vec_f32()` calls in diagnostic logging trigger device→CPU copies. Gate expensive logging behind `tracing::enabled!(Level::DEBUG)` or keep only at key checkpoints (prefill output, first decode step).
+
 ## Special Token IDs
 
 | Token | ID | Usage |
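As a concrete illustration of the dtype rules listed in lesson 2 above, a BF16-friendly Linear and softmax in tch-rs could look like the following. This is a sketch under assumed struct names; the project's actual `Linear` and `softmax()` may differ.

```rust
// Minimal sketch of the BF16 dtype-following pattern described above.
// The struct layout is an assumption, not the project's actual code.
use tch::{Kind, Tensor};

struct Linear {
    weight: Tensor,       // [out_features, in_features], stored in BF16
    bias: Option<Tensor>, // also BF16 if present
}

impl Linear {
    fn forward(&self, x: &Tensor) -> Tensor {
        // Cast the weight to the input dtype: BF16 input -> BF16 matmul,
        // F32 input -> F32 matmul. The stored weight is never converted
        // permanently.
        let w = self.weight.to_kind(x.kind());
        let out = x.matmul(&w.transpose(-2, -1));
        match &self.bias {
            Some(b) => out + b.to_kind(x.kind()),
            None => out,
        }
    }
}

// Softmax that computes in F32 internally for stability but preserves the
// input dtype, rather than hardcoding an F32 output.
fn softmax_like_input(x: &Tensor, dim: i64) -> Tensor {
    x.softmax(dim, Kind::Float).to_kind(x.kind())
}
```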
@@ -245,6 +297,14 @@ The codec transformer layers have three features not present in the backbone. Al
 
 ## Performance Notes
 
+On Apple M4 Max (TCH backend, BF16 CPU):
+- Model loading: ~18s (full tensor copy, no memory mapping)
+- Prefill (225 tokens, 26 layers): ~13s
+- Per-frame generation: ~0.55s (backbone + flow matching)
+- Codec decoding: ~0.2s
+- "Hello." (24 frames, 1.92s audio): ~30s total
+- "The quick brown fox..." (36 frames, 2.88s audio): ~30s total
+
 On Apple M4 Max (MLX backend):
 - Model loading: ~0.1s (memory-mapped safetensors)
 - Prefill (225 tokens, 26 layers): ~3.3s

src/model/backbone.rs

Lines changed: 57 additions & 2 deletions
@@ -10,9 +10,10 @@
 use std::collections::HashMap;
 
 use crate::config::VoxtralConfig;
-use crate::tensor::{Device, Tensor};
+use crate::tensor::{DType, Device, Tensor};
 
 use super::kv_cache::KVCache;
+
 use super::layers::{RMSNorm, RotaryEmbedding, TransformerLayer};
 
 // ---------------------------------------------------------------------------
@@ -180,6 +181,22 @@ impl Backbone {
             let (out, new_k, new_v) = layer.forward(&h, &self.rotary_emb, 0, kv_cache.get(i), true);
             kv_cache.update(i, new_k, new_v);
             h = out;
+
+            // Log per-layer norms during prefill for cross-backend comparison
+            if i == 0 || i == 12 || i == 25 {
+                // Sample the last position's values for comparison
+                let last_pos = seq_len as i64 - 1;
+                let vals = h.select(1, last_pos).squeeze_dim(0).to_vec_f32();
+                let norm: f32 = vals.iter().map(|v| v * v).sum::<f32>().sqrt();
+                tracing::info!(
+                    "Prefill layer {} last_pos: norm={:.4}, min={:.4}, max={:.4}, first3=[{:.4}, {:.4}, {:.4}]",
+                    i, norm,
+                    vals.iter().cloned().fold(f32::INFINITY, f32::min),
+                    vals.iter().cloned().fold(f32::NEG_INFINITY, f32::max),
+                    vals.get(0).unwrap_or(&0.0), vals.get(1).unwrap_or(&0.0),
+                    vals.get(2).unwrap_or(&0.0),
+                );
+            }
         }
 
         let h = self.norm.forward(&h);
@@ -189,7 +206,20 @@
         // Single eval materializes the entire 26-layer prefill graph at once,
         // allowing MLX to optimize the full computation on the GPU.
         out.eval();
-        tracing::debug!("Backbone prefill done, output shape: {:?}", out.size());
+        // Log prefill output norm for cross-backend comparison
+        {
+            let vals = out.to_vec_f32();
+            let norm: f32 = vals.iter().map(|v| v * v).sum::<f32>().sqrt();
+            tracing::info!(
+                "Prefill output: norm={:.4}, min={:.4}, max={:.4}, first5=[{:.4}, {:.4}, {:.4}, {:.4}, {:.4}]",
+                norm,
+                vals.iter().cloned().fold(f32::INFINITY, f32::min),
+                vals.iter().cloned().fold(f32::NEG_INFINITY, f32::max),
+                vals.get(0).unwrap_or(&0.0), vals.get(1).unwrap_or(&0.0),
+                vals.get(2).unwrap_or(&0.0), vals.get(3).unwrap_or(&0.0),
+                vals.get(4).unwrap_or(&0.0),
+            );
+        }
         out
     }
 
@@ -202,13 +232,38 @@
     pub fn forward_one_embedding(&self, embedding: &Tensor, kv_cache: &mut KVCache) -> Tensor {
         let h = embedding.reshape(&[1, 1, self.config.dim as i64]);
        let pos = kv_cache.seq_len();
+        // Log input embedding for the first few decode steps
+        if pos <= 226 {
+            let vals = embedding.to_vec_f32();
+            let norm: f32 = vals.iter().map(|v| v * v).sum::<f32>().sqrt();
+            tracing::info!(
+                "Decode pos={}: input_norm={:.4}, first5=[{:.4}, {:.4}, {:.4}, {:.4}, {:.4}]",
+                pos, norm,
+                vals.get(0).unwrap_or(&0.0), vals.get(1).unwrap_or(&0.0),
+                vals.get(2).unwrap_or(&0.0), vals.get(3).unwrap_or(&0.0),
+                vals.get(4).unwrap_or(&0.0),
+            );
+        }
+        // Log per-layer norms on first decode step (pos=225 for "Hello." with neutral_female)
+        let log_layers = pos <= 226;
         let mut h = h;
 
         for (i, layer) in self.layers.iter().enumerate() {
             let (out, new_k, new_v) =
                 layer.forward(&h, &self.rotary_emb, pos, kv_cache.get(i), true);
             kv_cache.update(i, new_k, new_v);
             h = out;
+
+            if log_layers && (i == 0 || i == 12 || i == 25) {
+                let vals = h.squeeze_dim(0).squeeze_dim(0).to_vec_f32();
+                let norm: f32 = vals.iter().map(|v| v * v).sum::<f32>().sqrt();
+                tracing::info!(
+                    "Layer {} output: norm={:.4}, min={:.4}, max={:.4}",
+                    i, norm,
+                    vals.iter().cloned().fold(f32::INFINITY, f32::min),
+                    vals.iter().cloned().fold(f32::NEG_INFINITY, f32::max),
+                );
+            }
         }
 
         let out = self.norm.forward(&h).squeeze_dim(0).squeeze_dim(0); // [dim]
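Note that the logging added in this file calls `to_vec_f32()` unconditionally, which on GPU backends forces a device→CPU copy per logged layer. Per the "Diagnostic Logging Considerations" lesson added to CLAUDE.md above, a possible follow-up (hypothetical, not part of this commit) is to gate that copy behind the tracing level:

```rust
// Hypothetical follow-up, not part of this commit: skip the device->CPU
// copy entirely unless DEBUG logging is active.
use crate::tensor::Tensor;
use tracing::Level;

fn log_layer_norm(i: usize, h: &Tensor) {
    // tracing::enabled! is cheap; to_vec_f32() forces a GPU->CPU transfer
    // on CUDA/MLX, so only pay for it when the log would actually be emitted.
    if tracing::enabled!(Level::DEBUG) {
        let vals = h.to_vec_f32();
        let norm: f32 = vals.iter().map(|v| v * v).sum::<f32>().sqrt();
        tracing::debug!("Layer {} output: norm={:.4}", i, norm);
    }
}
```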
