Russian TTS training pipeline based on the Kokoro architecture, with strong Apple Silicon (MPS) support, MFA duration alignment, feature caching, and stability-focused training defaults.
- Stabilized training defaults that reduce gradient spikes in practice (longer warmup + lower OneCycle peak LR pressure).
- Built-in spike safeguards for projection/attention layers and warmup-aware gradient explosion detection (see the sketch after this list).
- Dynamic frame-based batching enabled by default with MPS-aware auto-caps.
- Cleaner diagnostics: detailed stabilization logs are now behind `--verbose`.
- Better CI portability in tests (device-safe behavior when MPS is unavailable).
- MPS-first training flow (with CUDA/CPU fallback).
- MFA integration for phoneme-duration supervision.
- Precomputed feature caching for faster epochs.
- Variance prediction (pitch + energy) for improved prosody.
- Validation + early stopping + checkpoint resume.
- Gradient checkpointing and adaptive memory management.
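To illustrate the warmup-aware detection idea: the explosion threshold is relaxed while LR warmup is still in progress, since early gradients are naturally noisier. A hypothetical sketch (the names `warmup_steps` and `base_threshold` are illustrative, not the project's internals):

```python
# Hypothetical warmup-aware explosion check; `warmup_steps` and
# `base_threshold` are illustrative names, not the project's internals.
import torch

def gradient_spike_detected(model: torch.nn.Module, step: int,
                            warmup_steps: int = 1200,
                            base_threshold: float = 10.0) -> bool:
    """Return True if the global gradient norm looks explosive for this step."""
    total_sq = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total_sq += float(p.grad.detach().float().norm()) ** 2
    grad_norm = total_sq ** 0.5
    # Early steps are noisier while the LR is still warming up, so the
    # threshold is relaxed until warmup completes.
    threshold = base_threshold * (2.0 if step < warmup_steps else 1.0)
    return grad_norm > threshold
```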
Quick install:

```bash
pip install -r requirements.txt
pip install -e .
```

Recommended explicit environment setup (macOS / Linux):
```bash
# Use a supported Python version (recommended: 3.11)
python3 -m venv .venv
source .venv/bin/activate
python -m pip install --upgrade pip setuptools wheel

# Install package requirements and the project in editable mode
pip install -r requirements.txt
pip install -e .

# Quick verification of core dependencies
python - <<PY
import sys, torch
print('Python', sys.version.split()[0])
print('PyTorch', torch.__version__)
print('MPS available:', getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available())
PY
```

Optional dev extras:

```bash
pip install -e .[dev]
```

Expected corpus structure:
```
ruslan_corpus/
├── metadata_RUSLAN_22200.csv
└── wavs/
    ├── 000001_RUSLAN.wav
    └── ...
```
Metadata format:
```
audio_filename|transcription
```
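For reference, the metadata can be read with a few lines of Python (an illustrative sketch, not the project's loader):

```python
# Illustrative reader for the pipe-delimited metadata (not the project's loader).
from pathlib import Path

corpus = Path("./ruslan_corpus")
entries = []
for line in (corpus / "metadata_RUSLAN_22200.csv").read_text(encoding="utf-8").splitlines():
    if not line.strip():
        continue
    audio_filename, transcription = line.split("|", maxsplit=1)
    # Assumption: the first column names a file under wavs/; append ".wav"
    # only if the metadata stores bare IDs without the extension.
    wav_path = corpus / "wavs" / audio_filename
    if wav_path.suffix != ".wav":
        wav_path = wav_path.with_suffix(".wav")
    entries.append((wav_path, transcription))

print(f"{len(entries)} utterances; first: {entries[0]}")
```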
```bash
kokoro-preprocess --corpus ./ruslan_corpus --output ./mfa_output
```

Quick verification and smoke tests (minimal reproducible checks):
Minimal smoke run (small, non-production):

```bash
# If you want a very short training run as a sanity check, use --no-mfa and a tiny corpus.
# Replace ./my_small_corpus with a directory containing a single wav and a metadata CSV in the expected format.
kokoro-train --corpus ./my_small_corpus --output ./tmp_kokoro_test --no-mfa --epochs 1 --batch-size 1 --verbose
# Expected: the run starts, prints a few log lines including the model init and dataloader length,
# and either finishes one epoch or fails gracefully with a clear error (dataset/corpus misconfiguration).
```

Precompute features and start a full training run:

```bash
kokoro-precompute --corpus ./ruslan_corpus
kokoro-train --corpus ./ruslan_corpus --output ./models/kokoro_russian
```

If you hit occasional MPS backend fallback issues:
```bash
PYTORCH_ENABLE_MPS_FALLBACK=1 kokoro-train --corpus ./ruslan_corpus
```

Module equivalents:
```bash
python -m kokoro.cli.precompute_features --corpus ./ruslan_corpus
python -m kokoro.cli.preprocess --corpus ./ruslan_corpus --output ./mfa_output
python -m kokoro.cli.training --corpus ./ruslan_corpus
```

Complete CLI flags and interactions (concise):
- `--corpus, -c` (default: `./ruslan_corpus`): Corpus directory containing `metadata_*.csv` and `wavs/`.
- `--output, -o` (default: `./kokoro_russian_model`): Output model directory.
- `--resume, -r` (default: `None`): `auto` or an explicit checkpoint path to resume training.
- `--batch-size, -b` (default: `8`): Per-device batch size (subject to dynamic batching overrides).
- `--epochs, -e` (default: `50`): Number of epochs.
- `--learning-rate, -lr` (default: `1e-4`): Base learning rate.
- `--save-every` (default: `2`): Save a checkpoint every N epochs.
- `--mfa-alignments` (default: `auto`): Path to the MFA `alignments/` directory. Use `--no-mfa` to disable MFA usage.
- `--no-mfa` (flag): Disable MFA and use estimated durations.
- `--val-split` (default: `0.1`): Validation split fraction.
- `--no-validation` (flag): Disable validation entirely.
- `--validation-interval` (default: `1`): Validate every N epochs.
- `--early-stopping-patience` (default: `10`): Early stopping patience.
- Dynamic batching and frame caps (see the batching sketch after this flag reference):
  - `--dynamic-batching` (enabled by default): Use frame-based dynamic batching.
  - `--no-dynamic-batching` (flag): Use fixed-size batching.
  - `--max-frames` (default: config-driven): Maximum mel frames allowed in a dynamic batch.
  - `--min-batch-size` (default: `4`): Minimum batch size under dynamic batching.
  - `--max-batch-size` (default: `32`): Maximum batch size under dynamic batching (may be auto-capped on MPS).
- Profiling / AMP:
  - `--profile-amp` (flag): Run AMP profiling to select stable AMP usage before training.
  - `--profile-amp-batches` (default: `10`): Number of batches used for AMP profiling.
- Optimizer / fused AdamW flags & interactions:
  - `--fused-adamw` (flag): Force-enable fused AdamW (may only be supported on some backends).
  - `--no-fused-adamw` (flag): Force-disable fused AdamW.
  - `--try-fused-adamw-mps` (default: `True`): Attempt to use fused AdamW on MPS.
Optimizer selection behavior summary:
- If neither `--fused-adamw` nor `--no-fused-adamw` is set, selection is automatic: fused AdamW is used when the device and PyTorch version support it.
- `--fused-adamw` forces attempted use; if fused AdamW is unavailable, the forced path may raise an error.
- `--no-fused-adamw` forces the standard `torch.optim.AdamW` implementation.
- On MPS, `--try-fused-adamw-mps` enables an experimental code path that attempts a fused variant; it auto-falls back if unsupported.
- Diagnostics and memory:
  - `--verbose, -v` (flag): Enable verbose stabilization diagnostics (duration pred vs target stats, mask counts).
  - `--no-memory-cache` (flag): Disable in-memory feature caching (use the on-disk cache only).
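Frame-based dynamic batching can be pictured as packing utterances into batches under a total-frame budget. A minimal sketch of the idea (illustrative only; the project's sampler also honors `--min-batch-size` and MPS auto-caps):

```python
# Illustrative frame-budget batching: pack length-sorted utterances into
# batches whose total mel frames stay under max_frames (not the project's code).
from typing import Iterator, List

def frame_based_batches(mel_lengths: List[int], max_frames: int = 18000,
                        max_batch_size: int = 32) -> Iterator[List[int]]:
    # Sorting by length keeps similar-length clips together, reducing padding.
    order = sorted(range(len(mel_lengths)), key=lambda i: mel_lengths[i])
    batch: List[int] = []
    frames = 0
    for idx in order:
        if batch and (frames + mel_lengths[idx] > max_frames
                      or len(batch) == max_batch_size):
            yield batch
            batch, frames = [], 0
        batch.append(idx)
        frames += mel_lengths[idx]
    if batch:
        yield batch

# Example: three short clips share one batch, the long clip forces a new one.
print(list(frame_based_batches([400, 500, 600, 17000], max_frames=1600)))
```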
Examples:

```bash
# 1) Basic training with MFA (default) and dynamic batching
kokoro-train --corpus ./ruslan_corpus --output ./models/kokoro_russian --batch-size 8 --epochs 50

# 2) Force fused AdamW (may fail if unsupported) or force-disable it
kokoro-train --corpus ./ruslan_corpus --output ./models/kokoro_russian --fused-adamw
kokoro-train --corpus ./ruslan_corpus --output ./models/kokoro_russian --no-fused-adamw

# 3) Try fused AdamW on MPS (experimental) - auto-fallback if not supported
kokoro-train --corpus ./ruslan_corpus --output ./models/kokoro_russian --try-fused-adamw-mps

# 4) Minimal debugging run: single epoch, no MFA, verbose logs for duration diagnostics
kokoro-train --corpus ./my_small_corpus --output ./tmp_kokoro_test --no-mfa --epochs 1 --batch-size 1 --verbose

# 5) Explicit alignment directory
kokoro-train --corpus ./ruslan_corpus --output ./my_model --mfa-alignments ./mfa_output/alignments
```

Notes:
- If you see fused-optimizer errors on startup, pass `--no-fused-adamw` to force the fallback optimizer and avoid runtime crashes.
- The `--try-fused-adamw-mps` flag is safe: it attempts the fused code path on Apple Silicon and falls back when necessary, but behavior can vary by PyTorch version.
- `--verbose` prints helpful diagnostics (duration pred vs target mean/std/min/max and phoneme mask counts), useful when diagnosing duration-loss convergence.
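The selection logic amounts to a try-fused-then-fall-back pattern. A sketch under assumptions (the function and its arguments are illustrative, not the project's actual code):

```python
# Sketch of the try-fused-then-fall-back selection described above.
import torch

def build_optimizer(params, lr: float = 1e-4, force_fused=None):
    """force_fused: True (--fused-adamw), False (--no-fused-adamw), None (auto)."""
    params = list(params)  # materialize so we can retry after a failed attempt
    if force_fused is False:
        return torch.optim.AdamW(params, lr=lr)
    try:
        # fused=True is only supported on some device/PyTorch combinations.
        return torch.optim.AdamW(params, lr=lr, fused=True)
    except (RuntimeError, TypeError, ValueError):
        if force_fused:
            raise  # --fused-adamw was forced, so surface the error
        return torch.optim.AdamW(params, lr=lr)
```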
From `TrainingConfig` in `src/kokoro/training/config.py`:
- OneCycle LR enabled, `max_lr_multiplier=2.0`.
- Linear warmup enabled, `warmup_steps=1200`.
- Gradient accumulation default: `2`.
- Dynamic batching default: on.
- In-memory feature cache: enabled by default. Use `--no-memory-cache` to disable keeping precomputed features in RAM (reduces host memory usage at the cost of slightly higher I/O and cache latency).
- Stability safeguards: projection/attention pre-clipping + warmup-aware explosion thresholds.
- MPS-aware auto-limits can reduce oversized values (e.g., frame caps, sequence length, batch sizes).
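To make the schedule concrete, here is how those defaults could be wired with stock PyTorch (a sketch under assumptions; the project may combine its linear warmup and OneCycle differently, and `total_steps` is illustrative):

```python
# Wiring the defaults above with stock PyTorch (sketch, not the project's code).
import torch

model = torch.nn.Linear(10, 10)  # stand-in model
base_lr, max_lr_multiplier = 1e-4, 2.0
warmup_steps, total_steps = 1200, 20_000  # total_steps is illustrative

optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=base_lr * max_lr_multiplier,    # peak LR capped at 2x base for stability
    total_steps=total_steps,
    pct_start=warmup_steps / total_steps,  # ramp up over the first 1200 steps
)

for step in range(total_steps):
    # (forward pass and loss.backward() omitted in this sketch)
    optimizer.step()
    scheduler.step()  # step the schedule once per optimizer step
```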
```bash
# Verify feature cache health before training
python3 -m kokoro.utils.cache_manager --corpus ./ruslan_corpus --status

# Resume automatically from latest checkpoint
kokoro-train --corpus ./ruslan_corpus --output ./models/kokoro_russian --resume auto

# Train without MFA durations
kokoro-train --corpus ./ruslan_corpus --no-mfa

# Train with explicit dynamic batching bounds
kokoro-train --corpus ./ruslan_corpus --max-frames 18000 --min-batch-size 4 --max-batch-size 12

# Force fused AdamW (or force-disable it)
kokoro-train --corpus ./ruslan_corpus --fused-adamw
kokoro-train --corpus ./ruslan_corpus --no-fused-adamw

# Fused AdamW on MPS is enabled by default (experimental)
kokoro-train --corpus ./ruslan_corpus --try-fused-adamw-mps
```
```bash
# Inference from the final model or the latest checkpoint in a model directory
python -m kokoro.inference.inference --model ./my_model --text "Привет, это тест." --output output.wav --device mps

# Inference tuning (helps early checkpoints avoid very short outputs)
# Note: an explicit `--stop-threshold` passed on the CLI overrides any
# checkpoint-tuned or internal model default and will be honored during generation.
python -m kokoro.inference.inference --model ./my_model --text "Привет, это тест." --output output.wav --device mps --stop-threshold 0.6 --min-len-ratio 0.9 --max-len 1600
```
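The three tuning knobs typically interact as follows during stop-token-gated decoding; `should_stop` and `expected_len` below are hypothetical names for illustration, not the CLI's code:

```python
# Hypothetical helper showing how the three knobs interact during decoding.
def should_stop(stop_prob: float, frame: int, expected_len: int,
                stop_threshold: float = 0.6, min_len_ratio: float = 0.9,
                max_len: int = 1600) -> bool:
    if frame >= max_len:                           # --max-len: hard length cap
        return True
    if frame < int(min_len_ratio * expected_len):  # --min-len-ratio: ignore early stops
        return False
    return stop_prob >= stop_threshold             # --stop-threshold: raise to lengthen output
```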
```bash
# Run focused unit tests
python -m pytest tests/unit/test_attention_operations.py tests/unit/test_multi_layer_attention.py tests/unit/test_trainer_adaptive_stabilization.py -q
```

The trainer writes logs to `<output_dir>/logs` (SummaryWriter) and profiler traces to `<output_dir>/profiler_logs/<timestamp>`. Example:

```bash
tensorboard --logdir my_model/logs --bind_all

# To view profiler traces (TensorBoard Profiler), point TensorBoard at the
# profiler directory or the parent profiler_logs directory:
tensorboard --logdir my_model/profiler_logs --bind_all
```

What to look for in TensorBoard:
- Scalars: `total_loss`, `mel_loss`, `duration_loss`, `stop_token_loss`, `pitch_loss`, `energy_loss`.
- Histograms: model parameter distributions and gradients (if enabled).
- Graph: model graph (if exported) and profiler timelines (CPU/GPU/MPS activity).
Tips:
- If running on a remote machine, forward the TensorBoard port (default 6006) to your local machine. Example: `ssh -L 6006:localhost:6006 user@remote`, then run TensorBoard on the remote host and open http://localhost:6006 locally.
- Use `--verbose` during a short validation run to get additional diagnostics printed to the training logs that complement TensorBoard (duration pred vs target stats, mask counts).
`scripts/analyze_training_regression.py` is the primary diagnostic tool for monitoring training health and catching regressions. It combines checkpoint weight inspection with a comprehensive TensorBoard metrics analysis in a single terminal report.
```bash
# Default: reads my_model/ (checkpoints) and my_model/logs/ (TensorBoard events)
python scripts/analyze_training_regression.py

# Custom model directory
python scripts/analyze_training_regression.py --model path/to/model
```

Requirements: the `tensorboard` package installed; checkpoints present at `<model_dir>/checkpoint_epoch_N.pth`; TensorBoard event files in `<model_dir>/logs/`.
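If you want to poke at the underlying metrics yourself, TensorBoard's event reader can load the same files (a minimal sketch; the script's internals may differ):

```python
# Minimal direct read of the TensorBoard event files.
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

acc = EventAccumulator("my_model/logs")
acc.Reload()  # parse event files from disk
for tag in acc.Tags()["scalars"]:
    events = acc.Scalars(tag)
    print(f"{tag}: first={events[0].value:.4f} last={events[-1].value:.4f} "
          f"steps={len(events)}")
```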
Checkpoint weight analysis (per epoch, per layer):
- Weight norm table: norm, delta-norm, NaN/Inf flags for every saved checkpoint
- Key layer weight norms and deltas across epochs (encoder, decoder, postnet, stop token projection)
- Top-10 largest weight changes per epoch transition — useful for spotting sudden layer-level drift
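The weight-norm side of this report can be approximated with a few lines of PyTorch; the `model_state_dict` key below is an assumption about the checkpoint layout:

```python
# Approximate the per-layer weight-norm table from saved checkpoints.
import glob
import torch

for path in sorted(glob.glob("my_model/checkpoint_epoch_*.pth")):
    ckpt = torch.load(path, map_location="cpu")
    # "model_state_dict" is a guess about the layout; fall back to the dict itself.
    state = ckpt.get("model_state_dict", ckpt) if isinstance(ckpt, dict) else ckpt
    for name, w in state.items():
        if not torch.is_tensor(w):
            continue  # skip non-tensor entries (epoch counters, config, ...)
        w = w.float()
        bad = bool(torch.isnan(w).any() or torch.isinf(w).any())
        print(f"{path} {name}: norm={w.norm():.3f} nan_or_inf={bad}")
```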
TensorBoard analysis (printed in order):
| Section | What it shows |
|---|---|
| Step-level loss summary | All 6 losses (total, mel, duration, stop, pitch, energy): first/last/Δ/trend/mean/min/max |
| Val mel epoch series | Per-epoch validation mel loss with explicit ▲/▼ Δ, best epoch, total Δ |
| Epoch train/val table | Train vs val side-by-side per epoch with ▲ regression flags |
| Mel vs stop 200-step window correlation | 200-step windows: mel mean/Δ, stop mean/Δ, LR%, co-move label (both↑ LR pressure, both↓ improving, stop↑ only, mel↑ only) |
| Stop token analysis | Loss percentiles (p50/p90/p99), burst detection split first vs second half |
| Gradient health | Spike counts (>5/10/20 thresholds), overall + per-epoch clipping saturation % |
| Late spike context | Per-spike: raw grad norm, LR % of peak, stop nearby, stop elevated, attribution label (LR at peak, LR peak + stop, stop burst, outlier batch) |
| LR trajectory | 8-point sample across training, phase detection (warmup/ramp/peak/decay) |
| LR phase detail | 100-step resolution from 90% of peak — decoder and encoder % of peak |
| Regression flag summary | PASS/WARN/FAIL checklist for 6 key indicators |
| Analysis & recommendations | Prioritized CRITICAL / WARN / INFO recommendations with specific config guidance |
Co-movement labels:
- `both↑ (LR pressure)`: mel and stop both rising together; the root cause is LR, not a stop-specific problem.
- `both↓ (improving)`: healthy descent.
- `stop↑ only (stop source)`: stop rising while mel is stable or falling; investigate `stop_token_pos_weight` or `stop_token_loss_weight`.
- `mel↑ only`: mel regression without stop involvement; check for outlier batches or LR decay issues.
Late spike attribution labels:

| Label | Condition |
|---|---|
| `LR at peak` | LR ≥ 97% of peak, stop not elevated |
| `LR peak + stop` | LR ≥ 97% of peak AND stop > p75 |
| `stop burst` | Stop > p75 AND LR < 97% of peak |
| `outlier batch` | Neither LR pressure nor elevated stop |
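For clarity, the attribution conditions in the table translate directly into code (illustrative):

```python
# The table's attribution rules as straight-line code.
def attribute_spike(lr_pct_of_peak: float, stop_loss: float, stop_p75: float) -> str:
    lr_at_peak = lr_pct_of_peak >= 0.97     # LR >= 97% of its peak value
    stop_elevated = stop_loss > stop_p75    # stop loss above its 75th percentile
    if lr_at_peak and stop_elevated:
        return "LR peak + stop"
    if lr_at_peak:
        return "LR at peak"
    if stop_elevated:
        return "stop burst"
    return "outlier batch"
```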
- MPS out-of-memory: lower `--max-frames` and/or `--batch-size`; see MPS OOM Solutions.
- Missing metadata/audio: verify the corpus layout and `metadata_RUSLAN_22200.csv`.
- Slower-than-expected startup: the first epoch may build caches; precompute features to speed it up.
- Gradient spike warnings: use the defaults first, then reduce `--learning-rate` or `--max-frames` if needed.
Typical output directory:
```
models/kokoro_russian/
├── checkpoint_epoch_2.pth
├── checkpoint_epoch_4.pth
├── ...
└── kokoro_russian_final.pth
```
PRs are welcome. For larger changes, open an issue first so the implementation direction can be agreed on.
This project is intended for educational and research use with the Ruslan corpus and Kokoro-style TTS training workflows. Contact the owner with questions and/or for commercial usage.