This repository implements Bootstrap Your Own Latent (BYOL) self-supervised learning using the LightlySSL framework for pretraining on unlabeled images.
BYOL is a self-supervised learning algorithm that learns visual representations by training an online network to predict a target network's representation of a differently augmented view of the same image, without using negative pairs. This implementation uses LightlySSL for a clean and efficient setup.
The current setup is configured for a driver behavior classification dataset:
- Labeled training images: ~22,000 images across 10 classes (c0-c9)
- SSL pretraining: Uses only the training set (~22,000 images), split into train/validation
- No information leakage: The test set is reserved for final evaluation only
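The train/validation split for pretraining can be sketched as follows. This is an illustrative stdlib-only version (the function name, the 10% validation fraction, and the placeholder file names are assumptions; the actual split logic lives in `prepare_dataset.py`):

```python
import random

def split_train_val(image_paths, val_fraction=0.1, seed=42):
    """Shuffle image paths deterministically and split into train/val lists."""
    paths = sorted(image_paths)          # sort first so the split is reproducible
    rng = random.Random(seed)
    rng.shuffle(paths)
    n_val = int(len(paths) * val_fraction)
    return paths[n_val:], paths[:n_val]

# Placeholder file names standing in for the ~22,000 labeled training images.
train, val = split_train_val([f"img_{i:05d}.jpg" for i in range(22000)])
```

Because only the training set is split, the test set never enters pretraining.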
```
├── requirements.txt      # Python dependencies
├── prepare_dataset.py    # Dataset preparation script
├── byol_model.py         # BYOL model implementation
├── train_byol.py         # Training script
├── evaluate_byol.py      # Evaluation script
├── run_byol_pipeline.py  # Complete pipeline script
└── README.md             # This file
```
Install dependencies:

```bash
pip install -r requirements.txt
```

Run the complete pipeline:

```bash
python run_byol_pipeline.py --epochs 100 --batch_size 64 --backbone resnet18
```

Or run each stage separately. Prepare the dataset:

```bash
python prepare_dataset.py
```

Train BYOL:

```bash
python train_byol.py \
    --train_data_dir data/ssl_dataset/train_unified \
    --val_data_dir data/ssl_dataset/val_unified \
    --output_dir outputs/byol \
    --backbone resnet18 \
    --epochs 100 \
    --batch_size 64 \
    --learning_rate 3e-4
```

Evaluate:

```bash
python evaluate_byol.py \
    --model_path outputs/byol/best_model.pth \
    --backbone resnet18 \
    --output_dir outputs/evaluation
```

- Momentum-based target network: Prevents representation collapse during training
- Symmetric loss: Each view's prediction is matched against the other view's target representation
- Strong augmentations: BYOL-specific augmentation pipeline
- Flexible backbones: ResNet18, ResNet34, ResNet50 support
- Automatic checkpointing: Saves best model and periodic checkpoints
- Learning rate scheduling: Cosine annealing with warmup
- Progress tracking: Real-time loss monitoring
- Visualization: Training history plots
- Linear probing: Tests frozen feature quality
- Fine-tuning: End-to-end evaluation
- Comprehensive metrics: Accuracy, confusion matrix, classification report
- Visualization: Confusion matrices and training curves
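The momentum-based target update is an exponential moving average of the online network's weights. A minimal sketch in plain Python (flat lists of floats stand in for the model's parameter tensors; the function name and `tau` default are illustrative):

```python
def ema_update(target_params, online_params, tau=0.996):
    """Exponential moving average: target <- tau * target + (1 - tau) * online.

    With tau close to 1, the target network changes slowly, which is what
    stabilizes BYOL training and prevents collapse.
    """
    return [tau * t + (1.0 - tau) * o
            for t, o in zip(target_params, online_params)]

# After one update the target has moved only 0.4% of the way toward the
# online weights.
target = ema_update([1.0, 0.0], [0.0, 1.0])
```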
| Parameter | Default | Description |
|---|---|---|
| `--epochs` | 100 | Number of training epochs |
| `--batch_size` | 64 | Batch size for training |
| `--learning_rate` | 3e-4 | Learning rate |
| `--weight_decay` | 1e-6 | Weight decay for regularization |
| `--backbone` | resnet18 | Backbone architecture |
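The cosine-annealing-with-warmup schedule can be expressed as a single function of the step index. This is a sketch under assumptions (linear warmup over a hypothetical 500 steps, decay to zero; the real schedule in `train_byol.py` may differ in its warmup length and minimum LR):

```python
import math

def lr_at_step(step, total_steps, base_lr=3e-4, warmup_steps=500, min_lr=0.0):
    """Linear warmup to base_lr, then cosine decay down to min_lr."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

The LR ramps up linearly during warmup, peaks at `base_lr`, and decays smoothly to `min_lr` at the final step.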
| Parameter | Default | Description |
|---|---|---|
| `--min_scale` | 0.2 | Minimum scale for random resized crop |
| `--cj_prob` | 0.8 | Color jitter probability |
| `--cj_bright` | 0.4 | Color jitter brightness |
| `--cj_contrast` | 0.4 | Color jitter contrast |
| `--cj_sat` | 0.2 | Color jitter saturation |
| `--cj_hue` | 0.1 | Color jitter hue |
| `--gaussian_blur_prob` | 0.1 | Gaussian blur probability |
| `--solarization_prob` | 0.2 | Solarization probability |
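The flags above map directly onto an `argparse` parser. A minimal sketch with the defaults copied from the table (how `train_byol.py` actually wires these values into the augmentation pipeline is not shown here):

```python
import argparse

def build_aug_parser():
    """CLI flags matching the augmentation table; defaults taken from it."""
    p = argparse.ArgumentParser(description="BYOL augmentation settings")
    p.add_argument("--min_scale", type=float, default=0.2)
    p.add_argument("--cj_prob", type=float, default=0.8)
    p.add_argument("--cj_bright", type=float, default=0.4)
    p.add_argument("--cj_contrast", type=float, default=0.4)
    p.add_argument("--cj_sat", type=float, default=0.2)
    p.add_argument("--cj_hue", type=float, default=0.1)
    p.add_argument("--gaussian_blur_prob", type=float, default=0.1)
    p.add_argument("--solarization_prob", type=float, default=0.2)
    return p

# Passing an empty argv yields the defaults from the table.
args = build_aug_parser().parse_args([])
```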
- Two Views: Creates two different augmented views of the same image
- Online Network: Main network that learns representations
- Target Network: Momentum-updated copy of the online network
- Prediction: Online network predicts target network's output
- Symmetric Loss: Both views predict each other's target representation
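The symmetric loss in the last step is the negative cosine similarity between each view's prediction and the other view's target output, averaged over both directions. A plain-Python sketch on small vectors (the real loss runs on batched tensors, with the target branch detached from the gradient):

```python
import math

def neg_cosine(p, z):
    """Negative cosine similarity between a prediction p and a target z."""
    dot = sum(a * b for a, b in zip(p, z))
    norm_p = math.sqrt(sum(a * a for a in p))
    norm_z = math.sqrt(sum(b * b for b in z))
    return -dot / (norm_p * norm_z)

def byol_loss(p1, z2, p2, z1):
    """Symmetric loss: view 1's prediction vs. view 2's target, and vice versa."""
    return 0.5 * (neg_cosine(p1, z2) + neg_cosine(p2, z1))

# Perfectly aligned predictions and targets reach the minimum value of -1.
loss = byol_loss([1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0])  # -1.0
```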
- No negative examples: Unlike contrastive methods, doesn't need negative pairs
- Stable training: Momentum updates prevent collapse
- Strong representations: Learns rich visual features without labels
- Transfer learning: Pretrained features work well for downstream tasks
- Loss: Should decrease steadily during training
- Learning rate: Follows cosine annealing schedule
- Momentum updates: Target network slowly follows online network
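In the original BYOL recipe the momentum coefficient itself follows a cosine schedule, rising from its base value toward 1.0 so the target network freezes near the end of training. A sketch, assuming this schedule is used here with the default base of 0.996:

```python
import math

def tau_at_step(step, total_steps, tau_base=0.996):
    """Momentum coefficient: tau_base at step 0, rising to 1.0 at the last step."""
    return 1.0 - (1.0 - tau_base) * (math.cos(math.pi * step / total_steps) + 1.0) / 2.0
```

A higher `tau` means slower target updates, so the target network becomes progressively more stable as training proceeds.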
- Linear Probe Accuracy: Tests quality of frozen features
- Fine-tuning Accuracy: Tests end-to-end performance
- Confusion Matrix: Shows per-class performance
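The confusion matrix reduces to a simple tally of (true class, predicted class) pairs. An illustrative stdlib version (`evaluate_byol.py` presumably uses a library implementation; the 10 classes match c0-c9):

```python
def confusion_matrix(y_true, y_pred, num_classes=10):
    """Rows index the true class, columns the predicted class."""
    m = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        m[t][p] += 1
    return m

# One sample of class 0 is misclassified as class 1; the rest are correct.
m = confusion_matrix([0, 0, 1, 2], [0, 1, 1, 2], num_classes=3)
```

Diagonal entries count correct predictions; off-diagonal entries reveal which classes get confused with each other.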
- Out of Memory: Reduce batch size or use gradient accumulation
- Slow Training: Increase number of workers or use mixed precision
- Poor Performance: Check augmentation parameters or learning rate
- Convergence Issues: Adjust momentum parameter or learning rate schedule
- Use GPU: Training is much faster on GPU
- Adjust batch size: Larger batches often lead to better performance
- Monitor loss: Should decrease steadily without oscillations
- Check augmentations: Strong augmentations are crucial for BYOL
To use with your own dataset:
- Organize images in a single directory (no subdirectories needed)
- Update the data path in training script
- Adjust input size if needed
- Modify augmentation parameters for your domain
Key parameters to tune:
- Learning rate: Start with 3e-4, adjust based on loss curve
- Momentum: Default 0.996 works well, try 0.99-0.999
- Augmentation strength: Adjust based on your data characteristics
- Batch size: Larger is generally better if memory allows
- BYOL paper: Grill et al., "Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning" (NeurIPS 2020)
- LightlySSL documentation: The self-supervised learning framework used by this implementation
- BYOL Implementation Guide
This implementation is provided for educational and research purposes.