
# Small Transformer LM with Flax

This project trains a Transformer decoder-only language model on WikiText-2 using JAX/Flax.

## Setup

1. Create and activate a virtual environment:

   ```bash
   python3 -m venv .venv
   source .venv/bin/activate
   ```

2. Upgrade pip and install dependencies:

   ```bash
   pip install --upgrade pip
   pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
   pip install -r requirements.txt
   ```

Note for WSL2 GPU users: ensure that both the NVIDIA CUDA toolkit and the cuDNN libraries are installed inside WSL2. For example:

```bash
sudo apt update
# Install the CUDA toolkit (version may vary, e.g. 12.8) from the local CUDA repo
sudo apt install -y cuda-12-8
# (Optional) To enable full GPU support, install cuDNN for your CUDA version:
# 1. Visit https://developer.nvidia.com/rdp/cudnn-download and download the cuDNN
#    Debian package matching your CUDA version and OS.
# 2. Install the local repository package, e.g.:
#    wget https://developer.download.nvidia.com/compute/cudnn/<version>/local_installers/cudnn-local-repo-*.deb
#    sudo dpkg -i cudnn-local-repo-*.deb
#    sudo cp /var/cudnn-local-repo-*/cudnn-local.keyring /usr/share/keyrings/
#    sudo apt update
#    sudo apt install libcudnn8 libcudnn8-dev
```

Then add the WSL driver library path to the loader:

```bash
echo "/usr/lib/wsl/lib" | sudo tee /etc/ld.so.conf.d/nvidia.conf
sudo ldconfig
```

Verify the installation with:

```bash
python - << 'EOF'
import jax
print("Devices:", jax.devices())
EOF
```
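If `jax.devices()` reports only a `CpuDevice`, JAX has fallen back to the CPU backend. You can also force the CPU backend deliberately (as the generation command later in this README does) by setting `JAX_PLATFORM_NAME` before JAX is first imported. A minimal sketch (assumes JAX is installed in the active environment; it is not part of this repo's code):

```python
import os

# JAX_PLATFORM_NAME must be set before the first `import jax`,
# otherwise JAX will already have chosen a backend.
# "cpu" forces the CPU backend even when CUDA is installed.
os.environ["JAX_PLATFORM_NAME"] = "cpu"

try:
    import jax
    print("Devices:", jax.devices())  # expected: a list of CpuDevice entries
except ImportError:
    print("JAX is not installed in this environment")
```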
3. Download the WikiText-2 data:

   ```bash
   python data/download_wikitext2.py
   ```

4. Train the model:

   ```bash
   python src/train.py --config configs/config.yaml
   ```

For a larger, chat-style model, try the advanced config:

```bash
python src/train.py --config configs/config_advanced.yaml
```

## Directory Structure

```text
.
├── .venv/                   # Python virtual environment
├── data/                    # Scripts and processed data
│   └── download_wikitext2.py
├── configs/                 # Configuration files
│   ├── config.yaml          # default training config
│   ├── config_smoke.yaml    # smoke-test config
│   └── config_advanced.yaml # larger model config
├── logs/                    # Monitoring & checkpoint logs
├── requirements.txt         # Python dependencies
└── src/                     # Source code
    ├── tokenizer.py
    ├── model.py
    ├── train.py
    └── generate.py          # Interactive text generation
```
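For orientation on `src/model.py` (a conceptual sketch, not code from this repo): the defining feature of a decoder-only Transformer is the causal attention mask, which lets position `i` attend only to positions `j <= i`, so the model can never look at future tokens during training. In pure Python:

```python
NEG_INF = float("-inf")

def causal_mask(seq_len):
    """Additive attention mask: 0.0 where attention is allowed
    (key position <= query position), -inf where it is blocked.
    Adding -inf to an attention score zeroes it out after softmax."""
    return [
        [0.0 if j <= i else NEG_INF for j in range(seq_len)]
        for i in range(seq_len)
    ]

mask = causal_mask(4)
# Row 0 can only attend to position 0; row 3 attends to all four positions.
```

In practice the mask is built once per sequence length and added to the scaled dot-product attention scores before the softmax.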

## Generation

After training, you can generate text interactively:

```bash
source .venv/bin/activate
# Generate with optional sampling parameters:
JAX_PLATFORM_NAME=cpu python src/generate.py \
    --config configs/config.yaml \
    --length 20 \
    --temperature 0.8 \
    --top_k 50 \
    --top_p 0.9
```

At the prompt, enter a starting text (press Enter on an empty line to exit) and the model will generate the next tokens.
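The sampling flags above compose as follows: temperature rescales the logits, top-k keeps only the k most probable tokens, and top-p (nucleus sampling) keeps the smallest set of tokens whose cumulative probability reaches p. A behavioral sketch in pure Python (illustrative only, not the exact code in `src/generate.py`):

```python
import math
import random

def sample_next(logits, temperature=0.8, top_k=50, top_p=0.9, rng=random):
    """Pick a token id from `logits` using temperature, top-k, then top-p."""
    # Temperature: < 1.0 sharpens the distribution, > 1.0 flattens it.
    scaled = [l / temperature for l in logits]
    # Numerically stable softmax.
    m = max(scaled)
    exps = [math.exp(l - m) for l in scaled]
    total = sum(exps)
    probs = [(i, e / total) for i, e in enumerate(exps)]
    # Top-k: keep only the k most probable tokens.
    probs.sort(key=lambda ip: ip[1], reverse=True)
    probs = probs[:top_k]
    # Top-p: keep the smallest prefix whose cumulative mass reaches top_p.
    kept, mass = [], 0.0
    for i, p in probs:
        kept.append((i, p))
        mass += p
        if mass >= top_p:
            break
    # Renormalize over the surviving tokens and draw one.
    total = sum(p for _, p in kept)
    r = rng.random() * total
    for i, p in kept:
        r -= p
        if r <= 0:
            return i
    return kept[-1][0]
```

With `top_k=1` this degenerates to greedy decoding; a very small `top_p` has the same effect, since only the single most probable token survives the nucleus cutoff.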

## About

Example GPT-from-scratch setup using Flax with WikiText-2 data.
