This project trains a Transformer decoder-only language model on WikiText-2 using JAX/Flax.
- Create and activate a virtual environment:

  ```bash
  python3 -m venv .venv
  source .venv/bin/activate
  ```

- Upgrade pip and install the dependencies:

  ```bash
  pip install --upgrade pip
  pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
  pip install -r requirements.txt
  ```
Note for WSL2 GPU users: ensure you have both the NVIDIA CUDA toolkit and the cuDNN libraries installed inside WSL2. For example:

```bash
sudo apt update
# Install the CUDA toolkit (version may vary, e.g. 12.8) from the local CUDA repo
sudo apt install -y cuda-12-8

# (Optional) To enable deeper GPU support, install cuDNN for your CUDA version:
# 1. Visit https://developer.nvidia.com/rdp/cudnn-download and download the cuDNN
#    Debian package matching your CUDA version and OS.
# 2. Install the local repository package, e.g.:
#    wget https://developer.download.nvidia.com/compute/cudnn/<version>/local_installers/cudnn-local-repo-*-deb
#    sudo dpkg -i cudnn-local-repo-*.deb
#    sudo cp /var/cudnn-local-repo-*/cudnn-local.keyring /usr/share/keyrings/
#    sudo apt update
#    sudo apt install libcudnn8 libcudnn8-dev
```

Then add the WSL driver library path to the loader:

```bash
echo "/usr/lib/wsl/lib" | sudo tee /etc/ld.so.conf.d/nvidia.conf
sudo ldconfig
```

Verify with:
```bash
python - << 'EOF'
import jax
import jax.numpy as jnp
print("Devices:", jax.devices())
EOF
```

- Download the WikiText-2 data:

  ```bash
  python data/download_wikitext2.py
  ```

- Train the model:

  ```bash
  python src/train.py --config configs/config.yaml
  ```

  For a larger, chat-style model, try the advanced config:

  ```bash
  python src/train.py --config configs/config_advanced.yaml
  ```
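The training command above drives the standard JAX pattern: compute a next-token cross-entropy loss, take its gradient with `jax.value_and_grad`, and update the parameters. A minimal sketch of that pattern, using a toy lookup-table "model" rather than the repo's actual `model.py` (all names, shapes, and the learning rate here are hypothetical):

```python
import jax
import jax.numpy as jnp

def loss_fn(params, tokens):
    """Next-token cross-entropy for a toy bigram model: each token's
    embedding row is read directly as the logits for the next token."""
    logits = params["embed"][tokens[:-1]]              # (T-1, vocab)
    targets = tokens[1:]                               # (T-1,)
    logp = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.mean(logp[jnp.arange(targets.size), targets])

vocab = 16
key = jax.random.PRNGKey(0)
params = {"embed": jax.random.normal(key, (vocab, vocab)) * 0.01}
tokens = jnp.array([1, 2, 3, 4, 5])

loss, grads = jax.value_and_grad(loss_fn)(params, tokens)
# One plain SGD step, standing in for the real optimizer in train.py.
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
```

With near-zero initial weights the loss starts near `log(vocab)`; the real training loop adds batching, a Transformer forward pass, and checkpointing on top of this skeleton.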
```text
├── .venv/                   # Python virtual environment
├── data/                    # Scripts and processed data
│   └── download_wikitext2.py
├── configs/                 # Configuration files
│   ├── config.yaml          # default training
│   ├── config_smoke.yaml    # smoke-test config
│   └── config_advanced.yaml # larger model config
├── logs/                    # Monitoring & checkpoint logs
├── requirements.txt         # Python dependencies
└── src/                     # Source code
    ├── tokenizer.py
    ├── model.py
    ├── train.py
    └── generate.py          # Interactive text generation
```
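`src/model.py` implements the decoder-only Transformer; its defining ingredient is causal self-attention, where each position may only attend to itself and earlier positions. A single-head sketch of that masking in JAX (illustrative only, not the repo's actual implementation):

```python
import jax
import jax.numpy as jnp

def causal_attention(q, k, v):
    """Single-head scaled dot-product attention with a causal mask.
    q, k, v: arrays of shape (seq_len, d_head)."""
    seq_len, d_head = q.shape
    scores = q @ k.T / jnp.sqrt(d_head)                # (seq_len, seq_len)
    # Causal mask: position i may only attend to positions j <= i.
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    scores = jnp.where(mask, scores, -jnp.inf)
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ v, weights

key = jax.random.PRNGKey(0)
q, k, v = jax.random.normal(key, (3, 4, 8))            # seq_len=4, d_head=8
out, w = causal_attention(q, k, v)
```

Because masked scores are set to `-inf` before the softmax, each row of the attention matrix is zero above the diagonal, so no information flows from future tokens.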
After training, you can generate text interactively:

```bash
source .venv/bin/activate

# Generate with optional sampling parameters:
JAX_PLATFORM_NAME=cpu python src/generate.py \
    --config configs/config.yaml \
    --length 20 \
    --temperature 0.8 \
    --top_k 50 \
    --top_p 0.9
```

At the prompt, enter a starting text (empty input exits), and the model will produce the next tokens.
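The `--temperature`, `--top_k`, and `--top_p` flags are standard sampling controls; `generate.py`'s internals may differ, but a minimal NumPy sketch of the usual filtering order (temperature scaling, then top-k, then nucleus/top-p) looks like this (the function name and defaults are hypothetical):

```python
import numpy as np

def sample_next_token(logits, temperature=0.8, top_k=50, top_p=0.9, rng=None):
    """Sample one token id after temperature, top-k, and top-p filtering."""
    rng = np.random.default_rng(0) if rng is None else rng
    logits = np.asarray(logits, dtype=np.float64) / temperature
    # Top-k: keep only the k highest logits.
    if top_k is not None and top_k < logits.size:
        kth = np.sort(logits)[-top_k]
        logits = np.where(logits < kth, -np.inf, logits)
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    # Top-p (nucleus): keep the smallest high-probability set whose mass >= p.
    order = np.argsort(probs)[::-1]
    cutoff = int(np.searchsorted(np.cumsum(probs[order]), top_p)) + 1
    keep = order[:cutoff]
    filtered = np.zeros_like(probs)
    filtered[keep] = probs[keep]
    filtered /= filtered.sum()
    return int(rng.choice(probs.size, p=filtered))

logits = np.array([2.0, 1.0, 0.5, -1.0])
token = sample_next_token(logits, temperature=1.0, top_k=3, top_p=0.9)
```

Lower temperature sharpens the distribution, top-k drops everything outside the k best candidates, and top-p then trims the remaining tail, so here token id 3 can never be sampled.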