4242import deepspeed .comm as dist
4343import importlib .metadata
4444import math
45+ import re
4546import torch
4647import torch .distributed .nn
4748
@@ -121,6 +122,10 @@ def __init__(
121122 self .skip_all_but_last_attention_debug_mode = False
122123 self .rotating_layer_counter = 0 # used for dev work
123124
125+ self .core_attn_implementation = None # set by register_with_transformers
126+ self ._flex_block_mask_cached = None # cached BlockMask for flex_attention
127+ self ._flex_block_mask_cache_key = None # (batch_size, seq_len) for cache invalidation
128+
124129 self .local_q_head_count = attn_head_count // self .world_size
125130
126131 # if we have 4 kv heads and sp 8, we need to replicate kv heads 2x
@@ -272,23 +277,21 @@ def forward(
272277 key = rearrange (key , "bs hc sl hs -> sl bs hc hs" ) # .contiguous()
273278 value = rearrange (value , "bs hc sl hs -> sl bs hc hs" ) # .contiguous()
274279
275- # core attn like FA2 expects an unsharded `position_ids` - without which packed samples
276- # will return loss=nan.
277- #
278- # XXX: need to figure out if we can do the same for SDPA - as it doesn't require this and
279- # wants an attention mask, so possibly doing this for FA2 only?
280- #
281- # Ideally we would passing the original unsharded position_ids - but we have no way to pass
282- # it here as HF Transformers drops unexpected keys in `batch` - so either we need to stash
283- # it somewhere in UlyssesSPDataLoaderAdapter and retrieve it here or we could gather it once
284- # per batch and stash it inside `module` arg - I already have a machinery to figure out
285- # which layer number is being called below in the skip_all_but_last_attention_debug_mode
286- # code where rotating_layer_counter is used - so we could calculate it on the first layer
287- # and re-use on the remaining layers
288- if "position_ids" in kwargs :
289- position_ids_list = [torch .empty_like (kwargs ["position_ids" ]) for _ in range (self .world_size )]
290- dist .all_gather (position_ids_list , kwargs ["position_ids" ], group = self .process_group )
291- kwargs ["position_ids" ] = torch .cat (position_ids_list , dim = 1 )
280+ # All attention backends need unsharded position_ids after the all-to-all.
281+ # FA2 uses them for packed-sequence detection (flash_varlen_fn), sdpa/flex_attention
282+ # need them to be monotonically increasing so causal masking works correctly.
283+ # UlyssesSPDataLoaderAdapter ensures position_ids are in the batch before sharding,
284+ # so after gathering here they reconstruct to the correct global positions.
285+ assert "position_ids" in kwargs , (
286+ "Ulysses SP requires position_ids in every forward() call so that after all_gather "
287+ "causal masking works correctly. Without them each rank generates local [0..chunk_len-1] "
288+ "positions which, after gathering, look like packed sequences and break attention. "
289+ "For non-packed sequences: position_ids = torch.arange(seq_len) per sample. "
290+ "For packed sequences: position_ids must reset at document boundaries. "
291+ "Ensure your data collator or UlyssesSPDataLoaderAdapter includes position_ids." )
292+ position_ids_list = [torch .empty_like (kwargs ["position_ids" ]) for _ in range (self .world_size )]
293+ dist .all_gather (position_ids_list , kwargs ["position_ids" ], group = self .process_group )
294+ kwargs ["position_ids" ] = torch .cat (position_ids_list , dim = 1 )
292295
293296 # please don't remove the white-space vertical alignment in the error message
294297 assert query .shape == self .required_query_shape , (
@@ -311,6 +314,41 @@ def forward(
311314 if self .kv_replication_factor > 1 :
312315 module .num_key_value_groups = query_layer .size (- 3 ) // key_layer .size (- 3 )
313316
317+ # For flex_attention: the wrapper preserved the BlockMask from the model, but it
318+ # was built for the local shard's sequence length. Rebuild it for the full gathered
319+ # sequence length after the all-to-all.
320+ # XXX: currently hardcodes a causal mask_mod — models with sliding window or other
321+ # non-standard patterns would need the mask_mod extracted from the original BlockMask.
322+ if self .core_attn_implementation == "flex_attention" :
323+ from torch .nn .attention .flex_attention import BlockMask , create_block_mask
324+ if isinstance (attention_mask , BlockMask ):
325+ seq_len = query_layer .shape [2 ]
326+ batch_size = query_layer .shape [0 ]
327+ cache_key = (batch_size , seq_len )
328+
329+ # Cache the BlockMask — create_block_mask is expensive and the mask is the
330+ # same for all layers within a forward pass. Only rebuild when dimensions change.
331+ if self ._flex_block_mask_cache_key != cache_key :
332+
333+ def causal_mask (batch_idx , head_idx , q_idx , kv_idx ):
334+ return q_idx >= kv_idx
335+
336+ # Don't compile create_block_mask here — it runs inside the model's
337+ # forward pass where flex_attention already uses torch.compile, and
338+ # nesting compiled contexts causes gradient explosion in the backward
339+ # pass. The BlockMask is cached so creation cost is negligible.
340+ self ._flex_block_mask_cached = create_block_mask (
341+ mask_mod = causal_mask ,
342+ B = batch_size ,
343+ H = None ,
344+ Q_LEN = seq_len ,
345+ KV_LEN = seq_len ,
346+ device = query_layer .device ,
347+ )
348+ self ._flex_block_mask_cache_key = cache_key
349+
350+ attention_mask = self ._flex_block_mask_cached
351+
314352 if not self .skip_all_but_last_attention_debug_mode :
315353 # expects: [bs hc_l sl hs]
316354 context_layer , attn_weights = self .attn (module , query_layer , key_layer , value_layer , attention_mask , * args ,
@@ -411,15 +449,34 @@ def register_with_transformers(
411449 # if we don't have the model yet at this stage
412450 hf_model_config = AutoConfig .from_pretrained (model_name_or_path )
413451
414- supported_attn_implementation = ["flash_attention_2" , "flash_attention_3" , "sdpa" ]
415- if core_attn_implementation not in supported_attn_implementation :
416- # notes on the excluded ones:
417- # - eager: The problem is that `eager` wants an attention_mask and it creates the wrong attention mask it seems if we don't provide one - it's possible that we could somehow solve this, but it's also unlikely someone will want to use the slow eager attention with sequence parallelism
418- # - flex_attention: haven't tried
419-
452+ model_attn_implementation = getattr (hf_model_config , "_attn_implementation" , None )
453+ if model_attn_implementation is not None and model_attn_implementation != core_attn_implementation :
454+ raise ValueError (
455+ f"core_attn_implementation='{ core_attn_implementation } ' does not match "
456+ f"model config attn_implementation='{ model_attn_implementation } '. "
457+ "Set both to the same value so sequence-parallel wrapper can intercept the active attention path." )
458+
459+ # eager always materializes a 4D attention_mask (O(n²) memory) and cannot fall back
460+ # to is_causal=True like sdpa — so it's incompatible with SP which discards masks.
461+ unsupported_attn_implementation = ["eager" , "paged|eager" ]
462+ if core_attn_implementation in unsupported_attn_implementation :
420463 raise ValueError (
421464 f"{ core_attn_implementation } attn_implementation isn't currently supported by Ulysses sequence"
422- f" parallelism. Set core_attn_implementation arg to one of { supported_attn_implementation } ." )
465+ f" parallelism because it requires a 4D attention_mask (O(n²) memory)."
466+ f" Use any flash attention variant, 'flex_attention', 'sdpa',"
467+ f" or a hub-hosted kernel (e.g. 'kernels-community/flash-attn2')." )
468+
469+ # Hub kernels (e.g. kernels-community/flash-attn2) are registered lazily in transformers.
470+ # Ensure registration happens before validating against ALL_ATTENTION_FUNCTIONS.
471+ is_hub_kernel_attn = (isinstance (core_attn_implementation , str ) and re .search (
472+ r"^[^/:]+/[^/:]+(?:@[^/:]+)?(?::[^/:]+)?$" , core_attn_implementation ) is not None )
473+ if is_hub_kernel_attn :
474+ try :
475+ from transformers .modeling_flash_attention_utils import lazy_import_flash_attention
476+ except ImportError as e :
477+ raise ImportError ("Hub kernel attention requires a transformers version exposing "
478+ "`transformers.modeling_flash_attention_utils.lazy_import_flash_attention`." ) from e
479+ lazy_import_flash_attention (core_attn_implementation )
423480
424481 if core_attn_implementation not in ALL_ATTENTION_FUNCTIONS :
425482 raise ValueError (
@@ -448,6 +505,7 @@ def register_with_transformers(
448505 global_seq_length = global_seq_length ,
449506 disable_in_eval = disable_in_eval ,
450507 )
508+ uattn .core_attn_implementation = core_attn_implementation
451509
452510 def uattn_wrapper (
453511 module : torch .nn .Module ,
@@ -459,27 +517,41 @@ def uattn_wrapper(
459517 ** kwargs ,
460518 ) -> Tuple [torch .Tensor , torch .Tensor ]:
461519
462- # We are relaying on position_ids for SP to work so attention_mask has to be None
463- # the problem is that HF currently doesn't know anything about ALL_ATTENTION_FUNCTIONS["ulysses"] so it doesn't make a special case like for "flash_attention_2" and "sdpa" and it creates an attention mask on the fly and it breaks things.
464- attention_mask = None
520+ # SP relies on position_ids (not attention_mask) for causal masking.
521+ # HF doesn't know about the SP wrapper, so it creates an attention_mask for
522+ # the local shard's sequence length — which is invalid after the SP all-to-all
523+ # gathers the full sequence. A 4D mask at full sequence length would also be
524+ # O(n²) memory. So we discard 4D tensor masks.
525+ #
526+ # Keep BlockMask (flex_attention) — it's a compressed sparse representation.
527+ # It will be rebuilt for the full gathered sequence in forward().
528+ _is_block_mask = False
529+ if core_attn_implementation == "flex_attention" :
530+ from torch .nn .attention .flex_attention import BlockMask
531+ _is_block_mask = isinstance (attention_mask , BlockMask )
532+
533+ if not _is_block_mask :
534+ attention_mask = None
465535
466536 attn_output , attn_weights = uattn (
467537 module ,
468538 query ,
469539 key ,
470540 value ,
471541 attention_mask ,
472- # XXX: fixme
473542 * args ,
474543 ** kwargs ,
475544 )
476545 return attn_output , attn_weights
477546
478547 # We don't do: ALL_ATTENTION_FUNCTIONS.register("ulysses", uattn_wrapper)
479- # The problem with this approach is that we are missing on all the special use cases in HF Transformers that do things like: if self.config._attn_implementation == "flash_attention_2": ...
480- # So instead we hack `ALL_ATTENTION_FUNCTIONS` to override all existing keys with our implementation, since it only gets used at the point of calling the attention and that's what we want, all other code branches relying on the original core `attn_implementation` will still be executed. This is what we called "Being John Malkovich"
481- for key in ALL_ATTENTION_FUNCTIONS .keys ():
482- ALL_ATTENTION_FUNCTIONS [key ] = uattn_wrapper
548+ # The problem with that approach is that we'd miss all the special-case branches in
549+ # HF Transformers that check `if self.config._attn_implementation == "flash_attention_2": ...`
550+ # So instead we override the requested core implementation key in ALL_ATTENTION_FUNCTIONS
551+ # with our wrapper. All other code paths relying on the original core attn_implementation
552+ # will still be executed — we only intercept at the point of calling attention.
553+ # This is what we called "Being John Malkovich".
554+ ALL_ATTENTION_FUNCTIONS [core_attn_implementation ] = uattn_wrapper
483555
484556 return mpu
485557
@@ -574,6 +646,18 @@ def refill(self):
574646 micro_batches = defaultdict (dict )
575647 # XXX: replace with more efficient all-to-all?
576648
649+ # position_ids must exist before sharding so that after all_gather in
650+ # UlyssesSPAttentionHF.forward() they reconstruct to correct global positions.
651+ # Without them, the Trainer generates local [0,...,chunk_len-1] per rank AFTER
652+ # sharding, which after all_gather looks like packed sequences and breaks
653+ # sdpa/flex_attention causal masking.
654+ if "position_ids" not in batch :
655+ raise ValueError ("Ulysses SP requires `position_ids` in every dataloader batch so that "
656+ "each token retains its correct global position after sequence sharding. "
657+ "For non-packed sequences: position_ids = torch.arange(seq_len) per sample. "
658+ "For packed sequences: position_ids must reset at document boundaries. "
659+ "Ensure your data collator includes position_ids in its output." )
660+
577661 # we have batches of variable seqlen so in order to do all_gather on batches - we need to know the exact length of each tensor on each rank
578662 seqlen = torch .tensor (batch ["input_ids" ].shape [1 ], dtype = torch .int64 , device = self .device )
579663 seqlens = [torch .zeros (1 , dtype = torch .int64 , device = self .device ) for _ in range (self .sp_world_size )]
0 commit comments