Skip to content

feat: Add gemma4 drafter model support#2240

Open
athitten wants to merge 6 commits into
mainfrom
athitten/gemma4_drafter_support
Open

feat: Add gemma4 drafter model support#2240
athitten wants to merge 6 commits into
mainfrom
athitten/gemma4_drafter_support

Conversation

@athitten
Copy link
Copy Markdown
Contributor

@athitten athitten commented May 15, 2026

What does this PR do ?

This PR adds joint fine-tuning support for Gemma 4 base (also called target model) and drafter/assistant models that enable multi-token prediction (MTP). The drafter is co-trained with the Gemma 4 base end-to-end via a composite model (Gemma4WithDrafter) that wires up shared K/V states, sqrt(H_b)-scaled embeddings, and a K-step recurrent forward matching the Gemma 4 drafter tech report. The PR provides two reference configs for joint fine-tuning of gemma-4-E4B-it and gemma-4-E4B-it-assistant, one with MedPix VQA dataset and the other with a text-only Tulu-3 + Magicoder mix. Also provides an inference benchmark script that validates speculative-decode throughput on the saved checkpoint.

The feature has been verified only against the Gemma 4 4B (E4B) base + drafter pair. The composite is architecturally model-agnostic within the Gemma 4 family, but the example YAMLs, the parity tests, and all training/inference verification in this PR target the 4B pair.

Implementation

  • Gemma4WithDrafter composite (nemo_automodel/components/models/gemma4_drafter/)

    • Wraps a Gemma4ForConditionalGeneration base + Gemma4AssistantForCausalLM drafter as a single nn.Module for FSDP2 training.
    • K-step recurrent forward: round 0 feeds cat(base.embed(input_ids), base.h_final); round k≥1 feeds cat(base.embed(input_ids_shifted_by_k), prev_drafter.last_hidden_state). shared_kv_states is captured once from a single base forward and reused at every round (per the Gemma 4 drafter tech report).
    • Saves to <ckpt>/base/ and <ckpt>/drafter/ HF-loadable subdirs for vLLM / HF inference handoff.
    • Conditionally unfreezes post_projection only when drafter_num_steps > 1
    • masked_embedding.centroids always frozen (the torch.topk inside it blocks gradient flow back to the centroids).
  • Recipe (recipes/vlm/finetune.py, recipes/base_recipe.py)

    • New _maybe_add_drafter_loss(out, base_loss, labels, …) sums per-step CE over the composite's drafter_logits list with _shift_labels_left(labels, k) per round.
    • base_recipe.load_checkpoint dispatches to model.load_pretrained when the composite exposes it, so the base/ + drafter/ subdir layout reloads correctly on resume.
  • Text-only dataset adapter (components/datasets/vlm/datasets.py)

    • make_tulu3_magicoder_text_mix_dataset interleaves allenai/tulu-3-sft-mixture (80 %) and ise-uiuc/Magicoder-OSS-Instruct-75K (20 %) into {"conversation": [...]} dicts with no image field, so default_collate_fn emits batches without pixel_values. The composite + base accept text-only inputs unchanged.
  • Example configs and benchmark (examples/vlm_finetune/gemma4_joint_drafter/)

    • gemma4_4b_joint_drafter_medpix.yaml — Joint fine-tuning recipe of google/gemma-4-E4B-it + google/gemma-4-E4B-it-assistant with MedPix VQA
    • gemma4_4b_joint_drafter_tulu_magicoder_mix.yaml — Joint fine-tuning on tulu and magiccoder mix, same base + drafter pair
    • benchmark_mtp_inference.py — measures speculative-decode acceptance + throughput end-to-end against the trained 4B pair

Testing and Verification

All numbers below are for joint fine-tuning of gemma-4-E4B-it with /gemma-4-E4B-it-assistant

  1. Fine-tuning the joint model on MedPix-VQA: no NaNs, no loss spikes, no grad-norm spikes. Using the recipe examples/vlm_finetune/gemma4_joint_drafter/gemma4_4b_joint_drafter_medpix.yaml
    Loss curve:
Screenshot 2026-05-24 at 4 25 48 PM
  1. Large-scale fine-tuning on a Tulu-3 (80 %) + Magicoder (20 %) mix for 500 steps. Stable fine-tuning run observed with the recipe examples/vlm_finetune/gemma4_joint_drafter/gemma4_4b_joint_drafter_tulu_magicoder_mix.yaml
    Loss curve:
Screenshot 2026-05-24 at 4 28 18 PM
  1. Inference run on the tulu + magicoder fine-tuned checkpoint after 500 steps. Ran benchmark_mtp_inference.py against the saved base/ + drafter/ pair. Results below:
image
  1. Verified save + standalone HF reload works correctly. Checkpoints are stored in separate <ckpt>/base/ and <ckpt>/drafter/ subdirs as a prerequisite for generation. After a 5-step training run, each sub-checkpoint loads via plain HF from_pretrained without any NeMo-specific code:

    • Gemma4ForConditionalGeneration.from_pretrained("<ckpt>/base/model/consolidated")7.94 B params
    • Gemma4AssistantForCausalLM.from_pretrained("<ckpt>/drafter/model/consolidated")78.5 M params
    • State-dict key sets match the upstream released references (google/gemma-4-E4B-it and google/gemma-4-E4B-it-assistant) exactly (0 missing, 0 extra).
    • masked_embedding.token_ordering (int64[262144]) buffer survives the DCP → safetensors path.
  2. Verified resume-from-checkpoint loss parity. Run A: 10 steps fresh with ckpt_every_steps=5. Run B: same config but --checkpoint.restore_from <RunA>/epoch_0_step_4. Per-step loss is bit-identical between A and B at every overlapping step (5–9): 3.1257, 3.2992, 3.4045, 2.5375, 3.2824. Confirms model weights + optimizer state + LR scheduler state + dataloader state + RNG all restore correctly.

  3. Unit and functional tests.

  • tests/unit_tests/models/gemma4_drafter/test_composite.py — composite pre-projection input layout, shared-KV plumbing, K-step recurrence, save/load round-trip.
  • tests/unit_tests/models/gemma4_drafter/test_composite_fsdp2.py — FSDP2 wrap + expert/grad sync.
  • tests/unit_tests/models/gemma4_drafter/test_drafter_wrapper.py — drafter sub-module integration.
  • tests/unit_tests/recipes/test_vlm_drafter_helpers.py_shift_labels_left + _maybe_add_drafter_loss.
  • tests/functional_tests/L2_HF_Transformer_VLM_Gemma4_Joint_Drafter.sh — L2 single-node smoke.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?

If you haven't finished some of the above items you can still open "Draft" PR.

Additional Information

  • Related to # (issue)

Signed-off-by: Abhishree <abhishreetm@gmail.com>
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 15, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@athitten athitten changed the title Add gemma4 drafter model support [WIP] Add gemma4 drafter model support May 15, 2026
athitten and others added 3 commits May 15, 2026 14:49
Apply fixes for joint base + drafter training:

* Drop ``use_cache=False`` override in ``composite.forward``. Without the
  ``DynamicCache``, HF's sliding-window mask path silently degrades
  (SDPA mask-skip can collapse sliding layers into plain causal attention),
  inflating the initial training loss. The YAML's
  ``text_config.use_cache: true`` now takes effect.
* Change drafter label shift from ``k + 1`` to ``k``. The VLM collate
  pre-shifts labels by 1 so ``labels[t] == input_ids[t + 1]``; the prior
  ``k + 1`` shift was training the drafter to predict ``input_ids[t + 2]``
  instead of ``input_ids[t + 1]``.
* Add hard asserts: ``cp_size == 1`` and ``torch_dtype == bfloat16`` in
  ``Gemma4WithDrafter.from_pretrained``.
* Add plan knobs: ``freeze_base_for_drafter``, ``share_embedding_with_base``
  (one-shot init copy; FSDP2-safe), ``base_activation_checkpointing``.
* Recipe: factor joint loss into ``FinetuneRecipeForVLM._maybe_add_drafter_loss``,
  gate log on ``is_remote_logging_step`` (was per-microbatch), and make
  validation drafter-aware so ``val_loss`` reflects drafter drift.
* Remove dead ``from_pretrained`` override in drafter wrapper.
* Drop redundant ``text_config.output_hidden_states`` from YAML; expand the
  ``use_cache: true`` comment to explain the real reason (sliding-window mask,
  not KV sharing).
* Add ``test_post_collate_semantic_alignment`` that pins the label-shift
  convention so a future regression to ``k + 1`` fails loudly. Refine
  ``test_drafter_loss_reaches_drafter_params`` to reflect that
  ``post_projection`` only sees gradient in multi-step chains.

Signed-off-by: Abhishree <abhishreetm@gmail.com>
- Composite: K-step recurrent forward where the drafter consumes its prior
  round's post-projected last_hidden_state and a teacher-forced shifted
  token id at every k>=1. shared_kv_states captured once from a single base
  forward and reused across rounds. post_projection conditionally unfrozen
  when drafter_num_steps > 1.
- Recipe load path: dispatch to model.load_pretrained when the composite
  exposes it so the base/ + drafter/ subdir layout produced by
  save_pretrained can be reloaded for resume.
- Dataset adapter: make_tulu3_magicoder_text_mix_dataset interleaves
  allenai/tulu-3-sft-mixture (80%) and ise-uiuc/Magicoder-OSS-Instruct-75K
  (20%) into a text-only VLM-shaped list consumed by default_collate_fn
  without producing pixel_values.
- YAMLs: rename joint_drafter.yaml -> _medpix.yaml; add
  _tulu_magicoder_mix.yaml at drafter_loss_weight 0.001 (1/10 of the MedPix
  setting to compensate for ~10x larger summed CE on longer text sequences).
…mark

- Move gemma4_4b_joint_drafter_medpix.yaml and
  gemma4_4b_joint_drafter_tulu_magicoder_mix.yaml into a dedicated
  examples/vlm_finetune/gemma4_joint_drafter/ subdir so the joint-drafter
  variants are easy to find next to each other.
- Add benchmark_mtp_inference.py for measuring speculative-decoding
  throughput / acceptance with the trained base + drafter pair.

Signed-off-by: Abhishree <abhishreetm@gmail.com>
@athitten athitten changed the title [WIP] Add gemma4 drafter model support feat: Add gemma4 drafter model support May 24, 2026
@athitten
Copy link
Copy Markdown
Contributor Author

/ok to test 753ee70

@athitten athitten requested review from pthombre and zyzhou5 as code owners May 24, 2026 23:48
…aths

- gemma4_4b_joint_drafter_medpix.yaml,
  gemma4_4b_joint_drafter_tulu_magicoder_mix.yaml: replace local
  /workspace/hf_gemma4_e4b_it{,_assistant} paths with public HF repo
  ids google/gemma-4-E4B-it and google/gemma-4-E4B-it-assistant, and
  comment out the personal wandb block (leave a placeholder for users
  to fill in).
- benchmark_mtp_inference.py: drop the author-specific default
  checkpoint paths; make --base-ckpt and --drafter-ckpt required
  arguments and update the docstring + example command to point at the
  canonical joint-recipe save layout
  (<run>/<epoch_X_step_Y>/{base,drafter}/model/consolidated).

Signed-off-by: Abhishree <abhishreetm@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant