feat: Add gemma4 drafter model support #2240
Open
athitten wants to merge 6 commits into
Signed-off-by: Abhishree <abhishreetm@gmail.com>
Apply fixes for joint base + drafter training:
* Drop ``use_cache=False`` override in ``composite.forward``. Without the ``DynamicCache``, HF's sliding-window mask path silently degrades (SDPA mask-skip can collapse sliding layers into plain causal attention), inflating the initial training loss. The YAML's ``text_config.use_cache: true`` now takes effect.
* Change drafter label shift from ``k + 1`` to ``k``. The VLM collate pre-shifts labels by 1 so ``labels[t] == input_ids[t + 1]``; the prior ``k + 1`` shift was training the drafter to predict ``input_ids[t + 2]`` instead of ``input_ids[t + 1]``.
* Add hard asserts: ``cp_size == 1`` and ``torch_dtype == bfloat16`` in ``Gemma4WithDrafter.from_pretrained``.
* Add plan knobs: ``freeze_base_for_drafter``, ``share_embedding_with_base`` (one-shot init copy; FSDP2-safe), ``base_activation_checkpointing``.
* Recipe: factor joint loss into ``FinetuneRecipeForVLM._maybe_add_drafter_loss``, gate log on ``is_remote_logging_step`` (was per-microbatch), and make validation drafter-aware so ``val_loss`` reflects drafter drift.
* Remove dead ``from_pretrained`` override in drafter wrapper.
* Drop redundant ``text_config.output_hidden_states`` from YAML; expand the ``use_cache: true`` comment to explain the real reason (sliding-window mask, not KV sharing).
* Add ``test_post_collate_semantic_alignment`` that pins the label-shift convention so a future regression to ``k + 1`` fails loudly. Refine ``test_drafter_loss_reaches_drafter_params`` to reflect that ``post_projection`` only sees gradient in multi-step chains.
Signed-off-by: Abhishree <abhishreetm@gmail.com>
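The label-shift convention in the commit message above can be illustrated with a small, list-based sketch. This is not the PR's `_shift_labels_left` (which works on tensors and may differ in detail); `shift_labels_left` below is a hypothetical stand-in, and padding the tail with `-100` is an assumption based on PyTorch's default cross-entropy `ignore_index`.

```python
IGNORE_INDEX = -100  # assumption: PyTorch cross-entropy's default ignore_index


def shift_labels_left(labels, k, ignore_index=IGNORE_INDEX):
    """Shift labels left by k positions and pad the tail with ignore_index.

    Hypothetical stand-in for the PR's `_shift_labels_left`; the output has
    the same length as the input.
    """
    if k < 0:
        raise ValueError("k must be non-negative")
    return list(labels[k:]) + [ignore_index] * min(k, len(labels))


# The collate step pre-shifts labels by one: labels[t] == input_ids[t + 1].
input_ids = [10, 11, 12, 13, 14]
labels = input_ids[1:] + [IGNORE_INDEX]

# Shift k (the fixed behaviour): at k = 0 the target at t is input_ids[t + 1].
fixed_round0 = shift_labels_left(labels, 0)

# Shift k + 1 (the prior behaviour): at k = 0 the target at t is input_ids[t + 2].
buggy_round0 = shift_labels_left(labels, 1)
```

Because the collate step has already moved every label one position to the left, any extra shift at `k = 0` skips a token, which is the off-by-one the commit describes.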
- Composite: K-step recurrent forward where the drafter consumes its prior round's post-projected last_hidden_state and a teacher-forced shifted token id at every k>=1. shared_kv_states captured once from a single base forward and reused across rounds. post_projection conditionally unfrozen when drafter_num_steps > 1.
- Recipe load path: dispatch to model.load_pretrained when the composite exposes it so the base/ + drafter/ subdir layout produced by save_pretrained can be reloaded for resume.
- Dataset adapter: make_tulu3_magicoder_text_mix_dataset interleaves allenai/tulu-3-sft-mixture (80%) and ise-uiuc/Magicoder-OSS-Instruct-75K (20%) into a text-only VLM-shaped list consumed by default_collate_fn without producing pixel_values.
- YAMLs: rename joint_drafter.yaml -> _medpix.yaml; add _tulu_magicoder_mix.yaml at drafter_loss_weight 0.001 (1/10 of the MedPix setting to compensate for ~10x larger summed CE on longer text sequences).
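The data flow of the K-step recurrence described in the commit above can be sketched with plain Python lists. Everything here is a toy under stated assumptions: `embed`, `toy_drafter`, `shift_left`, and `k_step_inputs` are hypothetical names, the drafter is replaced by a trivial averaging function, and the sqrt(H_b) embedding scale, `post_projection`, and `shared_kv_states` are left out. Padding shifted ids with `pad_id=0` is also an assumption.

```python
def embed(token_id, dim=2):
    """Toy embedding: a deterministic vector derived from the token id."""
    return [float(token_id)] * dim


def shift_left(seq, k, pad):
    """Shift a sequence left by k, padding the tail so the length is unchanged."""
    return list(seq[k:]) + [pad] * min(k, len(seq))


def toy_drafter(inputs):
    """Stand-in for the drafter: averages the two halves of each input vector."""
    hidden = []
    for x in inputs:
        half = len(x) // 2
        hidden.append([(a + b) / 2 for a, b in zip(x[:half], x[half:])])
    return hidden


def k_step_inputs(input_ids, base_hidden, num_steps, pad_id=0):
    """Yield (k, drafter_inputs) for each round of the recurrence."""
    prev_hidden = base_hidden  # round 0 conditions on the base's final hidden state
    for k in range(num_steps):
        tokens = shift_left(input_ids, k, pad_id)  # teacher-forced ids, shifted by k
        # cat(embed(token), hidden) per position
        inputs = [embed(t) + h for t, h in zip(tokens, prev_hidden)]
        yield k, inputs
        prev_hidden = toy_drafter(inputs)  # the next round consumes this round's output


rounds = list(
    k_step_inputs([1, 2, 3], [[10.0, 10.0], [20.0, 20.0], [30.0, 30.0]], num_steps=2)
)
```

The point of the sketch is only the wiring: the hidden half of the input comes from the base at round 0 and from the previous drafter round afterwards, while the token half always comes from the ground-truth ids.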
…mark
- Move gemma4_4b_joint_drafter_medpix.yaml and gemma4_4b_joint_drafter_tulu_magicoder_mix.yaml into a dedicated examples/vlm_finetune/gemma4_joint_drafter/ subdir so the joint-drafter variants are easy to find next to each other.
- Add benchmark_mtp_inference.py for measuring speculative-decoding throughput / acceptance with the trained base + drafter pair.
Signed-off-by: Abhishree <abhishreetm@gmail.com>
/ok to test 753ee70
…aths
- gemma4_4b_joint_drafter_medpix.yaml,
gemma4_4b_joint_drafter_tulu_magicoder_mix.yaml: replace local
/workspace/hf_gemma4_e4b_it{,_assistant} paths with public HF repo
ids google/gemma-4-E4B-it and google/gemma-4-E4B-it-assistant, and
comment out the personal wandb block (leave a placeholder for users
to fill in).
- benchmark_mtp_inference.py: drop the author-specific default
checkpoint paths; make --base-ckpt and --drafter-ckpt required
arguments and update the docstring + example command to point at the
canonical joint-recipe save layout
(<run>/<epoch_X_step_Y>/{base,drafter}/model/consolidated).
Signed-off-by: Abhishree <abhishreetm@gmail.com>
What does this PR do?
This PR adds joint fine-tuning support for the Gemma 4 base (also called the target model) and the drafter/assistant models that enable multi-token prediction (MTP). The drafter is co-trained with the Gemma 4 base end-to-end via a composite model (`Gemma4WithDrafter`) that wires up shared K/V states, sqrt(H_b)-scaled embeddings, and a K-step recurrent forward matching the Gemma 4 drafter tech report. The PR provides two reference configs for joint fine-tuning of gemma-4-E4B-it and gemma-4-E4B-it-assistant: one on the MedPix VQA dataset and one on a text-only Tulu-3 + Magicoder mix. It also provides an inference benchmark script that validates speculative-decode throughput on the saved checkpoint.

The feature has been verified only against the Gemma 4 4B (E4B) base + drafter pair. The composite is architecturally model-agnostic within the Gemma 4 family, but the example YAMLs, the parity tests, and all training/inference verification in this PR target the 4B pair.
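For context on what a speculative-decode benchmark typically reports, here is a minimal sketch of greedy draft verification. It is not taken from `benchmark_mtp_inference.py`; the function names are hypothetical, and it assumes greedy decoding, where a draft token is accepted only if it equals the target model's own choice at that position and the target always contributes one token per round.

```python
def accepted_prefix_len(draft_tokens, target_tokens):
    """Length of the longest prefix where the draft agrees with the target."""
    n = 0
    for d, t in zip(draft_tokens, target_tokens):
        if d != t:
            break
        n += 1
    return n


def acceptance_stats(rounds):
    """Aggregate acceptance over (draft_tokens, target_tokens) pairs.

    Each round is assumed to emit the accepted draft tokens plus one token
    produced by the target itself (the correction or the bonus token).
    """
    proposed = accepted = emitted = num_rounds = 0
    for draft, target in rounds:
        n = accepted_prefix_len(draft, target)
        proposed += len(draft)
        accepted += n
        emitted += n + 1
        num_rounds += 1
    return {
        "acceptance_rate": accepted / proposed if proposed else 0.0,
        "tokens_per_round": emitted / num_rounds if num_rounds else 0.0,
    }


stats = acceptance_stats([
    ([5, 6, 7], [5, 6, 9]),  # first two draft tokens accepted
    ([1, 2, 3], [1, 2, 3]),  # all three accepted
    ([4, 4, 4], [8, 4, 4]),  # rejected at the first position
])
```

Under these assumptions, `tokens_per_round` is the average number of tokens gained per target forward pass, which is what drives any throughput improvement over plain autoregressive decoding.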
Implementation
`Gemma4WithDrafter` composite (`nemo_automodel/components/models/gemma4_drafter/`)
- `Gemma4ForConditionalGeneration` base + `Gemma4AssistantForCausalLM` drafter as a single `nn.Module` for FSDP2 training.
- Round 0 feeds the drafter `cat(base.embed(input_ids), base.h_final)`; round k≥1 feeds `cat(base.embed(input_ids_shifted_by_k), prev_drafter.last_hidden_state)`.
- `shared_kv_states` is captured once from a single base forward and reused at every round (per the Gemma 4 drafter tech report).
- `<ckpt>/base/` and `<ckpt>/drafter/` HF-loadable subdirs for vLLM / HF inference handoff.
- `post_projection` is unfrozen only when `drafter_num_steps > 1`.
- `masked_embedding.centroids` always frozen (the `torch.topk` inside it blocks gradient flow back to the centroids).

Recipe (`recipes/vlm/finetune.py`, `recipes/base_recipe.py`)
- `_maybe_add_drafter_loss(out, base_loss, labels, …)` sums per-step CE over the composite's `drafter_logits` list with `_shift_labels_left(labels, k)` per round.
- `base_recipe.load_checkpoint` dispatches to `model.load_pretrained` when the composite exposes it, so the `base/` + `drafter/` subdir layout reloads correctly on resume.

Text-only dataset adapter (`components/datasets/vlm/datasets.py`)
- `make_tulu3_magicoder_text_mix_dataset` interleaves `allenai/tulu-3-sft-mixture` (80%) and `ise-uiuc/Magicoder-OSS-Instruct-75K` (20%) into `{"conversation": [...]}` dicts with no `image` field, so `default_collate_fn` emits batches without `pixel_values`. The composite + base accept text-only inputs unchanged.

Example configs and benchmark (`examples/vlm_finetune/gemma4_joint_drafter/`)
- `gemma4_4b_joint_drafter_medpix.yaml`: joint fine-tuning recipe for `google/gemma-4-E4B-it` + `google/gemma-4-E4B-it-assistant` with MedPix VQA.
- `gemma4_4b_joint_drafter_tulu_magicoder_mix.yaml`: joint fine-tuning on the Tulu-3 and Magicoder mix, same base + drafter pair.
- `benchmark_mtp_inference.py`: measures speculative-decode acceptance + throughput end-to-end against the trained 4B pair.

Testing and Verification
- `examples/vlm_finetune/gemma4_joint_drafter/gemma4_4b_joint_drafter_medpix.yaml`. Loss curve: [Figure: loss curve]
- `examples/vlm_finetune/gemma4_joint_drafter/gemma4_4b_joint_drafter_tulu_magicoder_mix.yaml`. Loss curve: [Figure: loss curve]
- `benchmark_mtp_inference.py` against the saved `base/` + `drafter/` pair. Results below:

Verified save + standalone HF reload works correctly. Checkpoints are stored in separate `<ckpt>/base/` and `<ckpt>/drafter/` subdirs as a prerequisite for generation. After a 5-step training run, each sub-checkpoint loads via plain HF `from_pretrained` without any NeMo-specific code:
- `Gemma4ForConditionalGeneration.from_pretrained("<ckpt>/base/model/consolidated")` → 7.94 B params ✓
- `Gemma4AssistantForCausalLM.from_pretrained("<ckpt>/drafter/model/consolidated")` → 78.5 M params ✓
- Matches `google/gemma-4-E4B-it` and `google/gemma-4-E4B-it-assistant` exactly (0 missing, 0 extra).
- `masked_embedding.token_ordering` (int64[262144]) buffer survives the DCP → safetensors path.

Verified resume-from-checkpoint loss parity. Run A: 10 steps fresh with `ckpt_every_steps=5`. Run B: same config but `--checkpoint.restore_from <RunA>/epoch_0_step_4`. Per-step loss is bit-identical between A and B at every overlapping step (5–9): `3.1257, 3.2992, 3.4045, 2.5375, 3.2824`. This confirms that model weights, optimizer state, LR scheduler state, dataloader state, and RNG all restore correctly.

Unit and functional tests.
- `tests/unit_tests/models/gemma4_drafter/test_composite.py`: composite pre-projection input layout, shared-KV plumbing, K-step recurrence, save/load round-trip.
- `tests/unit_tests/models/gemma4_drafter/test_composite_fsdp2.py`: FSDP2 wrap + expert/grad sync.
- `tests/unit_tests/models/gemma4_drafter/test_drafter_wrapper.py`: drafter sub-module integration.
- `tests/unit_tests/recipes/test_vlm_drafter_helpers.py`: `_shift_labels_left` + `_maybe_add_drafter_loss`.
- `tests/functional_tests/L2_HF_Transformer_VLM_Gemma4_Joint_Drafter.sh`: L2 single-node smoke.

Before your PR is "Ready for review"
Pre checks:
If you haven't finished some of the above items, you can still open a "Draft" PR.
Additional Information
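As a companion to the resume-from-checkpoint parity check reported under Testing and Verification, the following is a minimal sketch of how two runs' per-step losses could be compared bit-for-bit. It is not part of the PR; `first_mismatch` is a hypothetical helper, and the loss values are the rounded figures quoted in the description, used here for both runs purely as illustration.

```python
import struct


def float_bits(x):
    """Return the raw IEEE-754 double encoding of x, for exact comparison."""
    return struct.pack("<d", x)


def first_mismatch(run_a, run_b):
    """Return the first overlapping step whose losses differ bitwise, else None.

    run_a and run_b map step index -> loss. Raises if the runs share no steps,
    since an empty comparison would otherwise pass trivially.
    """
    overlap = sorted(set(run_a) & set(run_b))
    if not overlap:
        raise ValueError("runs have no overlapping steps")
    for step in overlap:
        if float_bits(run_a[step]) != float_bits(run_b[step]):
            return step
    return None


# Steps 5-9, as quoted in the PR description (rounded as printed there).
run_a = {5: 3.1257, 6: 3.2992, 7: 3.4045, 8: 2.5375, 9: 3.2824}
run_b = dict(run_a)  # resumed run; identical by construction in this sketch

mismatch = first_mismatch(run_a, run_b)
```

Comparing the packed bytes rather than using a tolerance is deliberate: a resume that restores optimizer, scheduler, dataloader, and RNG state exactly should reproduce the same floating-point values, so any difference at all is worth investigating.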