Development notes, architecture decisions, and lessons learned during the port of Voxtral-4B-TTS from Python to Rust.
```bash
# macOS (MLX — Apple Silicon GPU)
git submodule update --init --recursive
cargo build --release --no-default-features --features mlx

# macOS (libtorch — CPU, for testing/development)
# Download libtorch 2.7.x (must match tch crate version — tch 0.20 requires 2.7)
curl -Lo libtorch.zip https://download.pytorch.org/libtorch/cpu/libtorch-macos-arm64-2.7.1.zip
unzip libtorch.zip
export LIBTORCH=$(pwd)/libtorch
export LIBTORCH_BYPASS_VERSION_CHECK=1
cargo build --release

# Linux (libtorch — CPU or CUDA GPU)
export LIBTORCH=$(pwd)/libtorch
cargo build --release

# Run (libtorch needs library path)
DYLD_LIBRARY_PATH=$LIBTORCH/lib ./target/release/voxtral-tts ...  # macOS
LD_LIBRARY_PATH=$LIBTORCH/lib ./target/release/voxtral-tts ...    # Linux

# Run tests
cargo test

# Run with debug logging
RUST_LOG=debug ./target/release/voxtral-tts models/voxtral-4b-tts --text "Hello." --voice neutral_female --output output.wav
```

The safetensors checkpoint uses Mistral-style weight naming, not the HuggingFace convention. This is the most common source of load errors.
| Weight | Key pattern |
|---|---|
| Token embeddings | mm_audio_embeddings.tok_embeddings.weight [131072, 3072] |
| Audio codebook embeddings | mm_audio_embeddings.audio_codebook_embeddings.embeddings.weight [9088, 3072] |
| Layer attention Q/K/V/O | layers.{i}.attention.wq.weight, .wk.weight, .wv.weight, .wo.weight |
| Layer FFN gate/down/up | layers.{i}.feed_forward.w1.weight, .w2.weight, .w3.weight |
| Layer norms | layers.{i}.attention_norm.weight, layers.{i}.ffn_norm.weight |
| Final norm | norm.weight |
| LM head | Absent -- tied to tok_embeddings.weight |
Prefix: acoustic_transformer. (NOT multimodal.acoustic_transformer.)
| Weight | Key |
|---|---|
| Input projection | acoustic_transformer.input_projection.weight [3072, 36] |
| LLM projection | acoustic_transformer.llm_projection.weight [3072, 3072] |
| Time projection | acoustic_transformer.time_projection.weight [3072, 3072] |
| Semantic output | acoustic_transformer.semantic_codebook_output.weight [8320, 3072] |
| Acoustic output | acoustic_transformer.acoustic_codebook_output.weight [36, 3072] |
| Layers | acoustic_transformer.layers.{i}.attention.{wq,wk,wv,wo}.weight |
Prefix: audio_tokenizer. (NOT multimodal.audio_tokenizer.)
Convolutions use weight normalization with parametrizations.weight.original0 (g, magnitude [out, 1, 1]) and parametrizations.weight.original1 (v, direction [out, in, kernel]).
| Weight | Key pattern |
|---|---|
| Decoder conv blocks | audio_tokenizer.decoder_blocks.{i}.conv.parametrizations.weight.original{0,1} |
| Decoder transformer layers | audio_tokenizer.decoder_blocks.{i}.layers.{j}.attention.{wq,wk,wv,wo}.weight |
| Layer scale | audio_tokenizer.decoder_blocks.{i}.layers.{j}.attention_scale, .ffn_scale |
| QK norm | audio_tokenizer.decoder_blocks.{i}.layers.{j}.attention.q_norm.weight, .k_norm.weight |
| Semantic codebook (EMA) | audio_tokenizer.quantizer.semantic_codebook.embedding_sum [8192, 256] |
| Cluster usage (EMA) | audio_tokenizer.quantizer.semantic_codebook.cluster_usage [8192] |
| Output projection | audio_tokenizer.output_proj.conv.parametrizations.weight.original{0,1} |
Decoder block layout (8 blocks total):
- Even indices (0, 2, 4, 6): Conv blocks
- Odd indices (1, 3, 5, 7): Transformer blocks (2 layers each)
params.json stores decoder config as comma-separated strings, not JSON arrays:
```json
{
  "decoder_convs_strides_str": "1,2,2,2",
  "decoder_convs_kernels_str": "3,4,4,4",
  "decoder_transformer_lengths_str": "2,2,2,2"
}
```

The code must parse these with `resolve_str_fields()` after serde deserialization. Without this, the codec decoder loads with 0 blocks and the dequantized latent passes straight through to the output projection, causing a channel mismatch.
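A minimal sketch of that parsing step (the helper and struct names below are illustrative, not the actual code):

```rust
// Sketch: turn the comma-separated string fields into numeric vectors,
// which is what resolve_str_fields() must do after serde deserialization.
fn parse_usize_list(s: &str) -> Vec<usize> {
    s.split(',')
        .map(str::trim)
        .filter(|p| !p.is_empty())
        .map(|p| p.parse().expect("invalid integer in params.json list field"))
        .collect()
}

struct DecoderConvConfig {
    strides: Vec<usize>,
    kernels: Vec<usize>,
    transformer_lengths: Vec<usize>,
}

fn resolve_decoder_config(strides: &str, kernels: &str, lengths: &str) -> DecoderConvConfig {
    DecoderConvConfig {
        strides: parse_usize_list(strides),             // "1,2,2,2" -> [1, 2, 2, 2]
        kernels: parse_usize_list(kernels),             // "3,4,4,4" -> [3, 4, 4, 4]
        transformer_lengths: parse_usize_list(lengths), // "2,2,2,2" -> [2, 2, 2, 2]
    }
}
```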
MLX builds a computation graph lazily. The graph must be evaluated periodically to prevent unbounded growth, but over-evaluating kills performance. Per the MLX documentation and mlx-lm reference implementation, eval() should be called at outer loop boundaries, not per-layer.
Where eval() is required (current optimized placement):
- After the full 26-layer backbone forward pass in `forward_one_embedding()` — 1 eval per frame
- After the full 26-layer backbone forward pass in `forward_prefill_embeddings()` — 1 eval total
- After each Euler step in `decode_acoustic()` (flow matching ODE) — 7 evals per frame
- After the full codec decoder (all 4 blocks) in `run_decoder()` — 1 eval per decode

Where eval() is NOT needed:
- Per transformer layer (the graph for 26 layers is fine — "thousands of ops" per eval is OK)
- Per flow matching `predict_velocity()` call (only 3 layers on 3 tokens — a tiny graph)
- Per codec conv/transformer block
Symptom of too few eval(): Graph grows across iterations, causing exponential slowdowns. Symptom of too many eval(): Each eval() has fixed overhead for graph traversal, scheduling, and GPU synchronization. Reducing from ~130 to ~8 eval() calls per frame improved flow matching from 0.53s to 0.28s per frame.
MLX fast_rope has a traditional parameter controlling how dimension pairs are formed:
- `traditional=true` (interleaved): pairs consecutive dimensions `(x[2d], x[2d+1])` — correct for Mistral native checkpoints
- `traditional=false` (split-half): pairs `(x[d], x[d + dim/2])`
Mistral's native safetensors checkpoint uses the interleaved convention. HuggingFace/vLLM internally permute Q/K weights to use split-half, but our weights are in native format.
The tch backend must match this exactly — reshape to [..., D/2, 2], select even/odd on the last dim, apply rotation, stack back and reshape. See apply_rotary_emb() in layers.rs.
Symptom: Backbone hidden states have wrong norms, semantic code stuck at one value (e.g., 855 or 10), END_AUDIO never predicted. All weights are correct but outputs diverge at Layer 0.
Discovery method: Compare layer-by-layer outputs against mlx-audio reference. Divergence appears immediately at Layer 0 attention output when RoPE convention is wrong.
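For reference, the difference between the two pairing conventions on a single per-head vector, sketched with plain `f32` slices rather than the real tch/MLX tensors (`theta` is a parameter here; the actual base value is not asserted):

```rust
// Interleaved convention: rotate the consecutive pair (x[2d], x[2d+1]).
fn rope_interleaved(x: &mut [f32], pos: f32, theta: f32) {
    let dim = x.len() as f32;
    for d in 0..x.len() / 2 {
        let angle = pos * theta.powf(-2.0 * d as f32 / dim);
        let (sin, cos) = angle.sin_cos();
        let (a, b) = (x[2 * d], x[2 * d + 1]);
        x[2 * d] = a * cos - b * sin;
        x[2 * d + 1] = a * sin + b * cos;
    }
}

// Split-half convention: rotate the pair (x[d], x[d + dim/2]).
fn rope_split_half(x: &mut [f32], pos: f32, theta: f32) {
    let dim = x.len() as f32;
    let half = x.len() / 2;
    for d in 0..half {
        let angle = pos * theta.powf(-2.0 * d as f32 / dim);
        let (sin, cos) = angle.sin_cos();
        let (a, b) = (x[d], x[d + half]);
        x[d] = a * cos - b * sin;
        x[d + half] = a * sin + b * cos;
    }
}
```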
The attention layer already concatenates old cache + new K/V internally before returning. The KV cache update() method must replace the stored tensors, not concatenate again:
```rust
// CORRECT: simple replacement (attention already did the cat)
pub fn update(&mut self, layer_idx: usize, new_k: Tensor, new_v: Tensor) {
    self.k_cache[layer_idx] = Some(new_k);
    self.v_cache[layer_idx] = Some(new_v);
}

// WRONG: double-concatenation causes exponential cache growth
// 221 -> 443 -> 886 -> 1773 -> ...
```

Symptom: matmul shape errors with large, unexpected dimensions (e.g., 113664 instead of 222).
| Operation | PyTorch | MLX |
|---|---|---|
| Conv1d input | [N, C, L] (NCL) | [N, L, C] (NLC) |
| Conv1d weight | [C_out, C_in, K] | [C_out, K, C_in] |
| ConvTranspose1d weight | [C_in, C_out, K] | [C_out, K, C_in] |
The tensor.rs conv methods handle these transposes automatically. The weights from safetensors are in PyTorch format and get transposed before calling MLX ops.
Device::best_available() must call init_mlx(true) before any MLX operations. Without this, all MLX calls panic with "MLX not initialized".
The tch Rust crate pins to a specific PyTorch major version. tch 0.20 requires libtorch 2.7.x. Using an older version (e.g., 2.4, 2.6) causes C++ compilation errors in torch-sys (no member named '_dyn_quant_matmul_4bit', etc.). Set LIBTORCH_BYPASS_VERSION_CHECK=1 for patch version mismatches (e.g., 2.7.1 vs 2.7.0).
libtorch 2.7+ supports BF16 matmul on Apple Silicon ARM64 CPU. Weights can stay in BF16 (as loaded from safetensors) without converting to F32. This is ~180x faster than F32 matmul on CPU for this model.
Key implementation details:
- `Linear::forward` casts the weight to match the input dtype, so BF16 input → BF16 matmul, F32 input → F32 matmul
- `RMSNorm::forward` computes the variance in F32 for stability, then casts the output back to the input dtype (sketched below)
- `softmax()` computes in F32 internally, then casts back to the input dtype (not hardcoded to F32 output)
- Voice embeddings and all model weights are kept in BF16 (no `need_f32` conversion)
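A minimal sketch of the mixed-precision RMSNorm described above, using the `half` crate's `bf16` purely for illustration (the real code operates on tch/MLX tensors):

```rust
use half::bf16; // illustration only; the real code stays on tch/MLX tensors

fn rms_norm_bf16(x: &[bf16], weight: &[bf16], eps: f32) -> Vec<bf16> {
    // Reduction in F32 for numerical stability...
    let mean_sq = x.iter().map(|v| { let f = v.to_f32(); f * f }).sum::<f32>() / x.len() as f32;
    let inv_rms = (mean_sq + eps).sqrt().recip();
    // ...then cast the output back to the input dtype (BF16).
    x.iter()
        .zip(weight)
        .map(|(v, w)| bf16::from_f32(v.to_f32() * inv_rms * w.to_f32()))
        .collect()
}
```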
The single most critical line in Attention::forward is:
```rust
let (q, k) = rotary_emb.forward_at_pos(&q, &k, pos, seq_len);
```

This must appear after the Q/K reshape to multi-head format and before KV cache concatenation. Without this line, the model has no positional information and produces degenerate output (stuck semantic codes, wrong hidden state norms).
This line was accidentally deleted during an Edit operation and took extensive debugging to find. The symptom (11% prefill norm divergence, degenerate codes) looked like a precision issue but was actually a missing computation.
The TCH backend supports CUDA GPUs with zero code changes. Device::best_available() auto-detects CUDA via tch::Cuda::is_available(), and Device::Gpu(i) maps to tch::Device::Cuda(i). All tensor operations (matmul, softmax, RoPE, etc.) work identically on CUDA. Download the CUDA variant of libtorch:
```bash
# CUDA 12.6 example
curl -Lo libtorch.zip https://download.pytorch.org/libtorch/cu126/libtorch-cxx11-abi-shared-with-deps-2.7.1%2Bcu126-linux-x86_64.zip
```
On GPU (CUDA or MLX), to_vec_f32() calls in diagnostic logging trigger device→CPU copies. Gate expensive logging behind tracing::enabled!(Level::DEBUG) or keep only at key checkpoints (prefill output, first decode step).
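One way to structure that gating, as a sketch only (`fetch` stands in for a closure that calls something like `to_vec_f32()` from this codebase):

```rust
use tracing::{debug, enabled, Level};

// Only pay for the device -> CPU copy when DEBUG logging is actually enabled.
fn log_norm_if_debug(label: &str, fetch: impl FnOnce() -> Vec<f32>) {
    if enabled!(Level::DEBUG) {
        let host = fetch();
        let norm = host.iter().map(|v| v * v).sum::<f32>().sqrt();
        debug!("{}: l2 norm = {}", label, norm);
    }
}
```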
| Token | ID | Usage |
|---|---|---|
| BOS | 1 | Start of sequence |
| AUDIO | 24 | Fed after prefill to trigger first frame generation |
| BEGIN_AUDIO | 25 | Marks start of audio region (before voice embs, before generation) |
| REPEAT_AUDIO_TEXT | 35 | Marks end of text, before second BEGIN_AUDIO |
| NEXT_AUDIO_TEXT | 36 | Marks end of voice embs, before text tokens |
| EMPTY_AUDIO | 0 | Semantic code: never valid (masked to -1e9) |
| END_AUDIO | 1 | Semantic code: signals end of generation |
The authoritative Python reference is mlx-audio (mlx_audio.tts), specifically:
- `mlx_audio/tts/voxtral_tts/voxtral_tts.py` — model class, `generate()`, `_encode_text()`, `_build_input_embeddings()`
- `mlx_audio/tts/voxtral_tts/acoustic_head.py` — flow matching transformer, Euler ODE, CFG
Install from git main (PyPI may lag): pip3 install git+https://github.com/Blaizzy/mlx-audio.git@main
[BOS(1)] [BEGIN_AUDIO(25)] [voice_emb_0, ..., voice_emb_N] [NEXT_AUDIO_TEXT(36)] [text_tok_0, ..., text_tok_M] [REPEAT_AUDIO_TEXT(35)] [BEGIN_AUDIO(25)]
- BOS, BEGIN_AUDIO, NEXT_AUDIO_TEXT, REPEAT_AUDIO_TEXT, and the final BEGIN_AUDIO: looked up in `tok_embeddings` [131072, 3072]
- Voice embeddings: pre-computed backbone hidden states [N, 3072], injected directly (not via the embedding table)
- Text tokens: looked up in `tok_embeddings`
After prefill, the AUDIO token (ID 24) is fed as the first decode step to produce the initial hidden state for frame generation.
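A sketch of that layout, with plain `Vec<f32>` rows standing in for [*, 3072] tensors (the function name and `embed` closure are illustrative, not the actual code; token IDs come from the table above):

```rust
const BOS: u32 = 1;
const BEGIN_AUDIO: u32 = 25;
const REPEAT_AUDIO_TEXT: u32 = 35;
const NEXT_AUDIO_TEXT: u32 = 36;

// `embed` stands in for the tok_embeddings [131072, 3072] lookup.
fn build_prefill_embeddings(
    embed: impl Fn(u32) -> Vec<f32>,
    voice_embs: &[Vec<f32>], // pre-computed backbone hidden states [N, 3072]
    text_ids: &[u32],
) -> Vec<Vec<f32>> {
    let mut seq = vec![embed(BOS), embed(BEGIN_AUDIO)];
    seq.extend(voice_embs.iter().cloned()); // injected directly, not via the embedding table
    seq.push(embed(NEXT_AUDIO_TEXT));
    seq.extend(text_ids.iter().map(|&t| embed(t)));
    seq.push(embed(REPEAT_AUDIO_TEXT));
    seq.push(embed(BEGIN_AUDIO));
    seq
}
```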
- Semantic code: cast the hidden state to F32, then `semantic_codebook_output.forward(hidden_state_f32)` → argmax over [8320] logits (masking sketched after the acoustic-code list)
  - Code 0 = EMPTY_AUDIO (masked to -1e9, never predicted)
  - Code 1 = END_AUDIO (left unmasked; signals stop when predicted)
  - Valid semantic codes: [2, 8194); codes >= 8194 are masked to -1e9
  - F32 precision is required for the matmul (matches the mlx-audio reference)
- 36 acoustic codes: Euler ODE flow matching (sketched below)
  - Initialize `x_0 ~ N(0, 1)` with shape [1, 36]
  - 7 Euler steps from t=0 to t=1
  - Each step: build a batched 3-token sequence `[acoustic_proj, time_emb, llm_proj]` with batch=2 (cond + uncond), run the 3 bidirectional transformer layers
  - Classifier-free guidance: `v = 1.2 * v_cond - 0.2 * v_uncond` (batched CFG: both passes in a single forward)
  - Quantize the output to FSQ levels: map [-1, 1] to [0, 20], add a +2 offset for special tokens
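A minimal sketch of both per-frame heads on plain slices rather than tensors. Function names and signatures are illustrative; `predict_velocity` stands in for the batched 3-layer transformer, and the exact Euler time grid is an assumption of this sketch:

```rust
// Semantic head: mask invalid codes, then argmax over the remaining logits.
fn pick_semantic_code(logits: &mut [f32]) -> usize {
    logits[0] = -1e9;                                 // EMPTY_AUDIO: never valid
    for l in logits[8194..].iter_mut() { *l = -1e9; } // codes >= 8194: invalid
    // END_AUDIO (index 1) stays unmasked so it can signal end of generation.
    logits.iter()
        .enumerate()
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
        .map(|(i, _)| i)
        .unwrap()
}

// Acoustic head: 7-step Euler integration of the flow matching ODE with CFG.
fn decode_acoustic_sketch(
    predict_velocity: impl Fn(&[f32], f32) -> (Vec<f32>, Vec<f32>), // (v_cond, v_uncond) at time t
    mut x: Vec<f32>,                                                // x_0 ~ N(0, 1), 36 values
) -> Vec<u32> {
    const STEPS: usize = 7;
    let dt = 1.0 / STEPS as f32;
    for s in 0..STEPS {
        let t = s as f32 * dt;
        let (v_cond, v_uncond) = predict_velocity(&x, t);
        for i in 0..x.len() {
            let v = 1.2 * v_cond[i] - 0.2 * v_uncond[i]; // classifier-free guidance
            x[i] += v * dt;
        }
    }
    // Quantize to FSQ levels: [-1, 1] -> [0, 20], then +2 offset for special tokens.
    x.iter()
        .map(|&v| ((v.clamp(-1.0, 1.0) + 1.0) / 2.0 * 20.0).round() as u32 + 2)
        .collect()
}
```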
The backbone has a single [9088, 3072] codebook embedding table for 37 codebooks:
- Codebook 0 (semantic): 8192 + 2 special = 8194 entries, offset 0
- Codebooks 1-36 (acoustic): 21 + 2 special = 23 entries each
Each frame's 37 codebook embeddings are summed together to produce a single [dim] vector. This is fed directly into the backbone for the next step — the AUDIO token (ID 24) embedding is not added per-frame (it is only used once as the initial decode step after prefill).
- Input: dequantized codes [1, 292, T] where 292 = 256 (semantic) + 36 (acoustic)
- Block 0: Conv1d [292 -> 1024], stride 1, kernel 3
- Blocks 1-3: ConvTranspose1d [1024 -> 1024], strides [2, 2, 2], kernels [4, 4, 4]
- Each conv is followed by 2 transformer layers
- Output projection: Conv1d [1024 -> 240], kernel 7
- Final reshape: [1, 240, T'] -> [1, 1, T' * 240] (patch_size=240)
- Total upsampling: T * 1 * 2 * 2 * 2 = 8T frames -> 8T * 240 = 1920T samples at 24kHz
Semantic dequantization: look up code in EMA codebook (embedding_sum / cluster_usage) -> [256] vector.
Acoustic dequantization: FSQ decode level [0,20] -> value in [-1, 1].
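A sketch of both dequantization paths on plain slices (the epsilon guard on `cluster_usage` is an assumption of this sketch, not taken from the reference):

```rust
// Semantic path: EMA codebook lookup, embedding_sum[code] / cluster_usage[code] -> [256] vector.
fn dequant_semantic(code: usize, embedding_sum: &[Vec<f32>], cluster_usage: &[f32]) -> Vec<f32> {
    let usage = cluster_usage[code].max(1e-6); // assumed guard against empty clusters
    embedding_sum[code].iter().map(|v| v / usage).collect()
}

// Acoustic path: FSQ decode, level in [0, 20] -> value in [-1, 1].
fn dequant_acoustic(level: u32) -> f32 {
    level as f32 / 20.0 * 2.0 - 1.0
}
```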
Effective weight = g * v / ||v|| where:
- g = `parametrizations.weight.original0` (magnitude, shape [out_ch, 1, 1])
- v = `parametrizations.weight.original1` (direction, shape [out_ch, in_ch, kernel])
- ||v|| = L2 norm over the (in_ch, kernel) dimensions
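A sketch of resolving the effective weight from g and v, with nested `Vec`s standing in for the real [out, in, kernel] tensor:

```rust
fn weight_norm_effective(g: &[f32], v: &[Vec<Vec<f32>>]) -> Vec<Vec<Vec<f32>>> {
    v.iter()
        .zip(g)
        .map(|(chan, &mag)| {
            // ||v||: L2 norm over this output channel's (in_ch, kernel) elements.
            let norm = chan.iter().flatten().map(|x| x * x).sum::<f32>().sqrt();
            chan.iter()
                .map(|row| row.iter().map(|x| mag * x / norm).collect())
                .collect()
        })
        .collect()
}
```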
Voice embeddings are pre-computed backbone hidden states. Each voice is a tensor of shape [N, 3072] where N varies by voice (typically 100-300 frames = 8-24 seconds of reference).
The original checkpoint stores these as PyTorch .pt files. For the MLX backend, they must be converted to .safetensors format (key: embedding).
The codec transformer layers have several features not present in the backbone. All are required for correct audio output — without them, decoder values explode and produce static noise:

- QK Norm: RMSNorm applied to the Q and K projections before the multi-head reshape (weight shape [1024] matches the full projected dim, not per-head). Uses `qk_norm_eps` = 1e-6.
- Layer Scale: per-channel learnable scales applied to the attention and FFN outputs before the residual add: `x + scale * attn_out`. Without this, values explode through the decoder blocks (89 → 46 → 260 → 700 max_abs).
- norm_eps = 0.01: the codec uses a much larger norm_eps (0.01) for attention_norm/ffn_norm than the backbone (1e-5). This is separate from qk_norm_eps (1e-6).
- Causal attention: the codec uses causal (not bidirectional) attention.
- Sliding window: `attn_sliding_window_size: 16` — implemented in the attention layer for the codec transformer.
On Apple M4 Max (TCH backend, BF16 CPU):
- Model loading: ~18s (full tensor copy, no memory mapping)
- Prefill (225 tokens, 26 layers): ~13s
- Per-frame generation: ~0.55s (backbone + flow matching)
- Codec decoding: ~0.2s
- "Hello." (24 frames, 1.92s audio): ~30s total
- "The quick brown fox..." (36 frames, 2.88s audio): ~30s total
On Apple M4 Max (MLX backend):
- Model loading: ~0.1s (memory-mapped safetensors)
- Prefill (225 tokens, 26 layers): ~3.3s
- Per-frame generation: ~0.34s (backbone ~70ms + flow matching ~270ms)
- Codec decoding: ~0.16s
- "Hello." (19 frames, 1.52s audio): ~10.5s total
- "The quick brown fox jumps over the lazy dog." (41 frames, 3.28s audio): ~17s total
- Fused MLX ops: SDPA (`fast_scaled_dot_product_attention`), RMSNorm (`fast_rms_norm`), and RoPE (`fast_rope`) use fused Metal kernels. SDPA handles GQA natively (no `repeat_kv` expansion needed). RMSNorm replaces 6 discrete ops with 1 kernel.
- Batched CFG: the flow matching CFG conditional and unconditional passes are batched together (batch=2) into a single forward pass through the 3 transformer layers, halving Metal kernel dispatch overhead.
- Reduced eval() frequency: from ~130 eval() calls per frame to ~2 (backbone + flow matching). The backbone does a single eval() after all 26 layers, and the flow matching does one eval() after all 7 Euler steps.
- BF16 flow matching: random noise, zeros, and time embeddings are cast to BF16 to match the weight dtype, avoiding implicit F32 promotion.
- Pre-computed time projections: the 7 sinusoidal time step embeddings (constant across all frames) are pre-computed at model init.
- SDPA mask NaN fix: the causal attention mask uses `-1e9` instead of `-inf` to avoid `0 * -inf = NaN` in IEEE 754.
Flow matching is 80% of per-frame time (270ms vs 70ms backbone). Each frame requires 7 Euler steps × 3 transformer layers (batch=2, seq=3). The small matrix sizes [6, 3072] cannot saturate the GPU efficiently. The backbone is relatively efficient at ~2.6ms/layer for single-token decode with KV cache.
The Voxtral checkpoint uses a Tekken BPE tokenizer (tekken.json). Key details:
- Special token offset: BPE token IDs are offset by `num_special_tokens` (1000). Special/control tokens occupy IDs 0–999, and BPE tokens start at 1000. The tokenizer must add this offset when encoding (sketched below).
- Format: `tekken.json` can be either `{ "config": ..., "vocab": [...] }` (v7) or a bare array `[...]` (legacy). The code handles both.
- Vocab cap: the tokenizer caps output IDs to `vocab_size` (131072) to prevent OOB on `tok_embeddings`.
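A sketch of the post-BPE ID handling (constants come from the notes above; the exact clamping behaviour of the real tokenizer is an assumption):

```rust
const NUM_SPECIAL_TOKENS: u32 = 1000;
const VOCAB_SIZE: u32 = 131_072;

// Add the special-token offset, then clamp so no ID can index past tok_embeddings.
fn finalize_token_ids(bpe_ids: &[u32]) -> Vec<u32> {
    bpe_ids
        .iter()
        .map(|&id| (id + NUM_SPECIAL_TOKENS).min(VOCAB_SIZE - 1))
        .collect()
}
```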
The API server supports 6 output formats: wav, pcm, mp3, flac, ogg/opus. All encoding functions are in src/audio.rs, dispatched by encode_audio().
| Format | Crate | Content-Type | Notes |
|---|---|---|---|
| wav | hound | audio/wav | 24kHz 16-bit mono |
| pcm | raw | audio/pcm | 24kHz 16-bit LE mono |
| mp3 | mp3lame-encoder 0.2 | audio/mpeg | Resampled to 44.1kHz, 128kbps CBR |
| flac | flacenc 0.5 | audio/flac | 24kHz 16-bit mono, lossless |
| ogg/opus | audiopus 0.2 + ogg 0.9 | audio/ogg | Resampled to 48kHz via rubato |
mp3lame-encoder: Use MonoPcm (not InterleavedPcm) for mono audio — InterleavedPcm always divides sample count by 2, producing half-duration chipmunk audio. Use encode_to_vec() and flush_to_vec::<FlushNoGap>(), not encode() / flush(). The non-vec variants require &mut [MaybeUninit<u8>] buffers and unsafe set_len().
flacenc:
- `into_verified()` returns `Result<Verified<Encoder>, (Encoder, VerifyError)>` — the error is a tuple, not a Display type. Map with `|(_enc, e)| ...`.
- `MemSource::from_samples(&samples, channels, bits_per_sample, sample_rate)` — all params are `usize`, not `u32`. Cast `sample_rate as usize`.
- Block size is `config.block_size` (singular `usize`), not `config.block_sizes[0]`.
- Write output to `ByteSink` (alias for `MemSink<u8>`), not `Vec<u8>`. Use `ByteSink::new()`, `stream.write(&mut sink)`, `sink.into_inner()`.
audiopus: Encoder::new() returns a non-mut encoder; encode_float() takes &self. Don't declare let mut encoder.