Environment
- TransformerEngine: 2.12.0
- Megatron-Core: 0.16.0
- PyTorch: 2.9.1
- CUDA: 12.9
- GPU: H100
Description
DotProductAttention with the thd (packed sequence) format crashes with SIGSEGV when cu_seqlens_q differs from cu_seqlens_q_padded and an FP8 blockwise recipe is used.
Per the TE documentation, cu_seqlens_q should contain actual token boundaries while cu_seqlens_q_padded contains the padded memory layout boundaries. This is needed when there is padding between sequences in a packed batch (e.g., for FP8 alignment).
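To make the relationship between the two arrays concrete, here is a small illustrative helper (not a TE API; `build_cu_seqlens` and the 16-token alignment are assumptions for this sketch) that derives both boundary arrays from the real sequence lengths used in the reproduction below:

```python
# Hypothetical helper (not part of TE): derives both cu_seqlens arrays
# from the real sequence lengths, assuming 16-token alignment padding.
def build_cu_seqlens(real_seqlens, align=16):
    padded_seqlens = [(s + align - 1) // align * align for s in real_seqlens]
    cu, cu_padded = [0], [0]
    for s, p in zip(real_seqlens, padded_seqlens):
        cu.append(cu[-1] + s)                 # actual token boundaries
        cu_padded.append(cu_padded[-1] + p)   # padded memory-layout boundaries
    return cu, cu_padded

cu, cu_padded = build_cu_seqlens([100, 150, 120, 130])
# cu        -> [0, 100, 250, 370, 500]  (what attention should mask with)
# cu_padded -> [0, 112, 272, 400, 544]  (how Q/K/V are laid out in memory)
```

TE expects both arrays as int32 CUDA tensors; the tokens between cu[i+1] and cu_padded[i+1] within each segment are padding that should never be attended to.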
Expected Behavior
Attention should:
- Use cu_seqlens_q for attention masking (only attend to real tokens)
- Use cu_seqlens_q_padded for memory layout (tensor indexing)
- Return the output tensor in the padded layout (same shape as the input Q)
Actual Behavior
- FP8 blockwise: SIGSEGV in the tex.fused_attn_fwd C++ kernel
- FP8 delayed: the output tensor has the unpadded size instead of the padded size, causing downstream shape mismatches
- BF16 (no FP8): works correctly when the cu_seqlens differ (non-FP8 backends handle it)
Impact
This bug makes FP8 training with sequence packing unusable when FP8 alignment padding is needed. The VERL framework (volcengine/verl) adds FP8 alignment padding for TE compatibility but passes identical cu_seqlens_q and cu_seqlens_q_padded as a workaround. This causes padding tokens to be visible to attention, corrupting the model output (training perplexity goes from 3.7 to 3055).
Reproduction
```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8BlockScaling
from transformer_engine.pytorch.attention import DotProductAttention

# Setup
batch_size = 4
real_seqlens = [100, 150, 120, 130]    # actual sequence lengths
padded_seqlens = [112, 160, 128, 144]  # padded to 16-token alignment for FP8
total_padded = sum(padded_seqlens)
hidden_dim = 1536
num_heads = 32
head_dim = hidden_dim // num_heads

# Create cu_seqlens (cumulative boundaries with a leading 0)
cu_seqlens_q = torch.cumsum(torch.tensor([0] + real_seqlens), 0).to(
    dtype=torch.int32, device='cuda')
cu_seqlens_q_padded = torch.cumsum(torch.tensor([0] + padded_seqlens), 0).to(
    dtype=torch.int32, device='cuda')

# Create Q, K, V in thd format (padded layout)
q = torch.randn(total_padded, num_heads, head_dim, device='cuda', dtype=torch.bfloat16)
k = torch.randn(total_padded, num_heads, head_dim, device='cuda', dtype=torch.bfloat16)
v = torch.randn(total_padded, num_heads, head_dim, device='cuda', dtype=torch.bfloat16)

attn = DotProductAttention(num_heads, head_dim)

# This works (cu_seqlens_q == cu_seqlens_q_padded):
with te.fp8_autocast(enabled=True, fp8_recipe=Float8BlockScaling()):
    out = attn(q, k, v,
               qkv_format='thd',
               cu_seqlens_q=cu_seqlens_q_padded,         # same as padded
               cu_seqlens_kv=cu_seqlens_q_padded,
               cu_seqlens_q_padded=cu_seqlens_q_padded,
               cu_seqlens_kv_padded=cu_seqlens_q_padded,
               max_seqlen_q=max(padded_seqlens),
               max_seqlen_kv=max(padded_seqlens),
               attn_mask_type='causal')

# This CRASHES (cu_seqlens_q != cu_seqlens_q_padded):
with te.fp8_autocast(enabled=True, fp8_recipe=Float8BlockScaling()):
    out = attn(q, k, v,
               qkv_format='thd',
               cu_seqlens_q=cu_seqlens_q,                # actual boundaries
               cu_seqlens_kv=cu_seqlens_q,
               cu_seqlens_q_padded=cu_seqlens_q_padded,  # padded boundaries
               cu_seqlens_kv_padded=cu_seqlens_q_padded,
               max_seqlen_q=max(real_seqlens),
               max_seqlen_kv=max(real_seqlens),
               attn_mask_type='causal')
# SIGSEGV ^^^
```
Core Problem: cu_seqlens_q used for BOTH attention masking AND SP communication
When cu_seqlens_q != cu_seqlens_q_padded:
- DotProductAttention: Works correctly in isolation (masks padding, returns padded-size output)
- ColumnParallelLinear with sequence_parallel=True: uses cu_seqlens_q for allgather sizing → the allgathered tensor gets the UNPADDED size → breaks downstream RoPE/Linear layers that expect the padded size
This means TE uses cu_seqlens_q for two conflicting purposes:
- Attention masking (should use actual/unpadded boundaries)
- SP allgather/scatter (should use padded boundaries)
Proposed fix: TE should separate these two uses. SP communication should ALWAYS use cu_seqlens_q_padded for sizing, while attention masking should use cu_seqlens_q.
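A hedged sketch of the proposed sizing rule (illustrative only; `sp_gather_sizes` is a made-up name, and the actual fix belongs in TE's sequence-parallel gather/scatter path): the gathered sequence length must come from cu_seqlens_q_padded so it matches the padded thd layout.

```python
def sp_gather_sizes(cu_seqlens, cu_seqlens_padded, tp_size):
    """Sizing for the SP allgather: ALWAYS use the padded total.

    Sizing from cu_seqlens[-1] (the unpadded total) yields a gathered
    tensor smaller than the padded thd layout, which is exactly the
    downstream shape mismatch described above.
    """
    total_padded = cu_seqlens_padded[-1]   # 544 in the repro, not 500
    assert total_padded % tp_size == 0, "padded total must divide evenly across TP ranks"
    return total_padded // tp_size, total_padded

shard, total = sp_gather_sizes([0, 100, 250, 370, 500],
                               [0, 112, 272, 400, 544], tp_size=8)
# 68 tokens per rank; gathered length 544 matches the rows of Q
```

Attention masking would keep consuming cu_seqlens_q unchanged; only the communication sizing switches to the padded array.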
Impact on Training
With VERL framework training DeepSeek 10B MoE with Megatron-Core:
- BF16 baseline: grad_norm=0.28, training_ppl=3.7
- FP8 E2E (cu_seqlens_q == cu_seqlens_q_padded): grad_norm=130-500, training_ppl=3055 (garbage)
- BF16 + FP8 padding only (no FP8 compute): grad_norm=1064, training_ppl=6.6 (proves padding is the cause)
Related Issues
Attachment: repro_fp8_padding_bug.py
- The DotProductAttention docstring shows cu_seqlens_q=[0,3,5,9] with cu_seqlens_q_padded=[0,4,8,13] as a valid use case (with pad_between_seqs=True)
- multi_latent_attention.py:567-574 expects cu_seqlens_q to differ from cu_seqlens_q_padded