This is the code used for the distillation process of the Audio Tower of LAION CLAP.
With the following commands you create the virtual environment, install all the dependencies and start the training:
# Setup and install
python3 -m venv venv && source venv/bin/activate && pip install -r requirements.txt
# Run training
screen -S training
python3 train_real.py --config config.yaml
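# After detaching from the session (Ctrl+A, D), reattach with: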
screen -r training
To run the final report you can use this (you need to change the name of the .onnx model in the code):
python final_test.py
To re-run only the validation on a checkpoint:
python3 revalidate_checkpoint.py --ckpt checkpoint_epoch_1.pth
You can check how the average cosine similarity (training and validation) evolves per epoch with this command:
python3 - <<'PY'
import glob, torch
for f in sorted(glob.glob('student_clap/checkpoints/checkpoint_epoch_*.pth')):
    ckpt = torch.load(f, map_location='cpu')
    m = ckpt.get('train_metrics', {})
    val_mse = ckpt.get('val_mse', ckpt.get('last_val_mse', ckpt.get('best_val_mse', 'N/A')))
    val_cos = ckpt.get('val_cosine_sim', ckpt.get('last_val_cosine', ckpt.get('best_val_cosine', 'N/A')))
    print(f"{f}: train_cos={m.get('avg_cosine_sim')}, train_mse={m.get('avg_mse')}, val_mse={val_mse}, val_cos={val_cos}, lr={m.get('learning_rate')}")
PY
With semantic metrics:
python3 - <<'PY'
import glob, torch
for f in sorted(glob.glob('student_clap/checkpoints/checkpoint_epoch_*.pth')):
    ckpt = torch.load(f, map_location='cpu')
    m = ckpt.get('train_metrics', {})
    val_mse = ckpt.get('val_mse', ckpt.get('last_val_mse', 'N/A'))
    val_sem = ckpt.get('val_semantic_error', 'N/A')
    print(f"{f}: train_cos={m.get('avg_cosine_sim')}, train_mse={m.get('avg_mse')}, train_sem={m.get('avg_semantic','N/A')}, val_mse={val_mse}, val_sem={val_sem}, lr={m.get('learning_rate')}")
PY
With both:
python3 - <<'PY'
import glob, torch
for f in sorted(glob.glob('student_clap/checkpoints/checkpoint_epoch_*.pth')):
    ckpt = torch.load(f, map_location='cpu')
    m = ckpt.get('train_metrics', {})
    val_mse = ckpt.get('val_mse', ckpt.get('last_val_mse', 'N/A'))
    val_cos = ckpt.get('val_cosine', 'N/A')
    val_met = ckpt.get('val_metric', 'N/A')
    val_met_name = ckpt.get('val_metric_name', 'N/A')
    val_sem = ckpt.get('val_semantic_error', 'N/A')
    print(f"{f}: train_cos={m.get('avg_cosine_sim')}, train_mse={m.get('avg_mse')}, train_sem={m.get('avg_semantic','N/A')}, val_mse={val_mse}, {val_met_name}={val_met}, val_sem={val_sem}, lr={m.get('learning_rate')}")
PY
You can check the number of parameters (in millions) for your input configuration with this command:
PYTHONPATH=.. python -c "import yaml; from student_clap.models.student_onnx_model import StudentCLAPAudio; config=yaml.safe_load(open('config.yaml')); m=StudentCLAPAudio(config); print(m.count_parameters())"
To check which input configuration was used for a checkpoint, use this command:
PYTHONPATH=.. python -c "import torch; m=torch.load('student_clap/checkpoints/CHECKPOINT-NAME-HERE.pth', map_location='cpu'); print({k: v for k, v in m['config']['model'].items() if k.startswith('efficientat_') or k=='efficientat_model'})"
To force training to read the LR from config.yaml after a stop, instead of restoring it from the saved scheduler state:
python3 - <<'PY'
import torch, glob
for p in glob.glob('student_clap/checkpoints/checkpoint_epoch_*.pth'):
    ckpt = torch.load(p, map_location='cpu')
    ckpt.pop('optimizer_state_dict', None)
    ckpt.pop('scheduler_state_dict', None)
    torch.save(ckpt, p)
    print("Stripped optimizer/scheduler from", p)
PY
To force training to read the weight decay value from config.yaml after a stop, instead of restoring it from the saved optimizer/scheduler state (a .bak copy of each checkpoint is kept):
python3 - <<'PY'
import torch, glob, shutil
paths = glob.glob('student_clap/checkpoints/checkpoint_epoch_*.pth') + ['student_clap/checkpoints/latest.pth']
for p in paths:
    try:
        shutil.copy(p, p + '.bak')
        ckpt = torch.load(p, map_location='cpu')
        changed = False
        for k in ('scheduler_state_dict', 'optimizer_state_dict'):
            if ckpt.pop(k, None) is not None:
                changed = True
        if changed:
            torch.save(ckpt, p)
            print('Cleaned', p)
        else:
            print('No optimizer/scheduler state in', p)
    except Exception as e:
        print('Skipped', p, ':', e)
PY
Reset the logit scale to its initial value (2.6592 ≈ ln(1/0.07), the standard CLIP/CLAP initialization):
python3 -c "
import torch
ckpt = torch.load('student_clap/checkpoints/latest.pth', map_location='cpu')
ckpt['model_state_dict']['logit_scale'] = torch.tensor(2.6592) # Reset to init
torch.save(ckpt, 'student_clap/checkpoints/latest.pth')
print('Reset logit_scale to 2.6592')
"
Clear the AdamW optimizer and scheduler state from a checkpoint (and reset the patience counter):
python -c "
import torch
ck = torch.load('student_clap/checkpoints/checkpoint_epoch_9.pth', map_location='cpu')
ck.pop('optimizer_state_dict', None)
ck.pop('scheduler_state_dict', None)
ck['patience_counter'] = 0
torch.save(ck, 'student_clap/checkpoints/checkpoint_epoch_9_nostate.pth')
print('Keys saved:', list(ck.keys()))
"
Check the training and validation cosine similarity, including checkpoints in subfolders:
find student_clap/checkpoints -name "checkpoint_epoch_*.pth" | sort -V | python3 -c '
import torch, sys
for line in sys.stdin:
    f = line.strip()
    try:
        ckpt = torch.load(f, map_location="cpu", weights_only=False)
        m = ckpt.get("train_metrics", {})
        avg = m.get("avg_cosine_sim", "null")
        lr = m.get("learning_rate", "null")
        val = ckpt.get("last_val_cosine", ckpt.get("val_cosine_sim", ckpt.get("best_val_cosine", "null")))
        print(f"{f}: cosine={avg}, val_cosine={val}, lr={lr}")
    except Exception as e:
        print(f"{f}: ERROR - {e}")
'
Architecture
- Base student: EfficientAT MobileNet (default mn10_as, n_mels=128 -> 512-dim). Optional pretrained variants are selectable via model.efficientat_model.
- Fusion mode (enabled when model.specialist_checkpoint is non-null): a frozen specialist (a previously trained EfficientAT) is paired with a lightweight trainable student backbone. The current default backbone is EdgeNeXt-XX-Small (fusion_backbone: edgenext), but efficientat, deit_tiny and mobilevitv2 are also supported. A 512-dim per-channel gate blends specialist and student outputs (see the sketch after this list).
- Projection head: residual MLP (backbone_dim -> 512), with configurable dropout and bias. Final embeddings are L2-normalized.
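For intuition, the fusion gate and residual projection head described above might look roughly like the PyTorch sketch below. Module and argument names (e.g. GatedFusionHead) are hypothetical and not the actual classes in student_clap/models/student_onnx_model.py.

import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedFusionHead(nn.Module):
    # Hypothetical sketch of the per-channel gate plus residual projection head.
    def __init__(self, backbone_dim, embed_dim=512, dropout=0.1):
        super().__init__()
        self.gate = nn.Parameter(torch.zeros(embed_dim))    # 512-dim per-channel gate
        self.proj_in = nn.Linear(backbone_dim, embed_dim)   # backbone_dim -> 512
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim, embed_dim),
        )

    def forward(self, student_feat, specialist_emb):
        x = self.proj_in(student_feat)
        x = x + self.mlp(x)                                  # residual MLP projection
        g = torch.sigmoid(self.gate)                         # per-channel blend weight
        fused = g * x + (1.0 - g) * specialist_emb           # specialist branch stays frozen
        return F.normalize(fused, dim=-1)                    # embeddings are L2-normalized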
Training schedule
- Stage 1 – distill the CLAP teacher across all parameters; length controlled by training.epochs.
- Stage 2 – triggered automatically after stage 1; encoder weights are frozen and only the projection head (and the gate, in fusion mode) remain trainable (see the sketch after this list). Duration, learning rate and scheduler are configured via the training.stage2_* fields. The shortcut flag training.projection_only can force stage-2 behaviour on startup.
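As a rough illustration of the stage-2 switch (the attribute prefixes projection/gate and the stage2_* config keys below are assumptions, not the exact repository layout):

import torch

def enter_stage2(model, config):
    # Freeze everything, then re-enable only the projection head (and gate in fusion mode).
    for p in model.parameters():
        p.requires_grad = False
    for name, p in model.named_parameters():
        if name.startswith(('projection', 'gate')):   # hypothetical module names
            p.requires_grad = True
    trainable = [p for p in model.parameters() if p.requires_grad]
    # Stage-2 hyper-parameters come from the training.stage2_* fields in config.yaml.
    return torch.optim.AdamW(
        trainable,
        lr=config['training'].get('stage2_learning_rate', 1e-4),
        weight_decay=config['training'].get('stage2_weight_decay', 0.0),
    )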
Warmup, LR scheduling (ReduceLROnPlateau by default or optional CosineAnnealingLR), mixed precision, augmentation, mixup and other hyper‑parameters live entirely in config.yaml.
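For example, picking the scheduler from config.yaml could look like this sketch (the training.scheduler key and its values are assumptions):

import torch

def build_scheduler(optimizer, cfg):
    if cfg['training'].get('scheduler') == 'cosine':
        # Optional CosineAnnealingLR over the configured number of epochs.
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg['training']['epochs'])
    # Default: ReduceLROnPlateau driven by the validation metric.
    return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)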
Segmentation & batching
10‑second segments with 50 % overlap; segments are processed in small groups (model.segment_batch_size) and aggregated either individually, averaged or both (training.training_strategy).
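As an illustration, building 10-second segments with 50 % overlap might look like this sketch (the 32 kHz sample rate and the helper name are assumptions):

import torch

def split_into_segments(waveform, sample_rate=32000, segment_seconds=10.0, overlap=0.5):
    # 10-second windows with a 5-second hop (50 % overlap).
    seg_len = int(segment_seconds * sample_rate)
    hop = int(seg_len * (1.0 - overlap))
    if waveform.shape[-1] < seg_len:
        # Pad short clips up to one full window.
        waveform = torch.nn.functional.pad(waveform, (0, seg_len - waveform.shape[-1]))
    starts = range(0, waveform.shape[-1] - seg_len + 1, hop)
    return [waveform[..., s:s + seg_len] for s in starts]

The resulting segments would then be fed to the model in groups of model.segment_batch_size.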
Loss & scaling options
- Loss function selectable between
cosine,mseorkl(cosine/KL disable semantic loss). - Static temperature (
loss_temperature) or learnable logit-scale (use_logit_scale+init_logit_scale, clamped bymax_logit_scale_T). - Optional focal weighting, embedding normalization, semantic alignment (
lambda_semantic), and teacher cache toggles.
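For reference, a minimal cosine distillation objective with a learnable logit scale could look like this sketch (assumed names; the actual loss implementation may differ):

import torch
import torch.nn.functional as F

def cosine_distill_loss(student_emb, teacher_emb):
    # 1 - cosine similarity between matched student/teacher embeddings.
    s = F.normalize(student_emb, dim=-1)
    t = F.normalize(teacher_emb, dim=-1)
    return (1.0 - (s * t).sum(dim=-1)).mean()

# Learnable logit scale, initialized like CLIP/CLAP: exp(2.6592) ~= 1 / 0.07.
logit_scale = torch.nn.Parameter(torch.tensor(2.6592))

def scaled_similarity(student_emb, teacher_emb, max_T=100.0):
    # Temperature-scaled similarity matrix; the scale is clamped (cf. max_logit_scale_T).
    scale = logit_scale.exp().clamp(max=max_T)
    s = F.normalize(student_emb, dim=-1)
    t = F.normalize(teacher_emb, dim=-1)
    return scale * s @ t.T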