Skip to content

Commit e6ac6cd

Browse files
committed
merge: bring latest master into aws test torch branch
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
2 parents 97f3cbd + bf0126b commit e6ac6cd

6 files changed

Lines changed: 435 additions & 83 deletions

File tree

deepspeed/module_inject/layers.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,12 @@ def _set_param_uc_meta(self,
369369
def _mark_uc_metadata(self):
370370
return
371371

372+
def _should_materialize_tp_partition(self):
373+
# AutoTP partitioning should only materialize parameters when an actual
374+
# TP process group is present. Metadata-only construction with
375+
# mp_group=None should not touch device placement.
376+
return self.mp_group is not None
377+
372378
def is_training_mode(self):
373379
global DEEPSPEED_AUTOTP_MODE
374380
return DEEPSPEED_AUTOTP_MODE == AUTOTP_MODE.TRAINING
@@ -579,7 +585,8 @@ def __init__(self, module, mp_group, **kwargs):
579585
self.weight = module.weight
580586
self.bias = module.bias
581587

582-
self._tp_partition([self.weight, self.bias])
588+
if self._should_materialize_tp_partition():
589+
self._tp_partition([self.weight, self.bias])
583590
self.support_training = True
584591
self.config_tp_params(self.weight)
585592
if self.bias is not None:
@@ -674,7 +681,7 @@ def __init__(self, module, mp_group=None, skip_partition=False, **kwargs):
674681
super(LinearLayer, self).__init__(mp_group, **kwargs)
675682
self.weight = module.weight
676683
self.bias = module.bias
677-
if not skip_partition:
684+
if not skip_partition and self._should_materialize_tp_partition():
678685
self._tp_partition([self.weight, self.bias])
679686
self.support_training = True
680687
self.config_tp_params(self.weight)
@@ -1234,7 +1241,8 @@ def __init__(self, module, mp_group, shape, partition_dim=0, **kwargs):
12341241
raise ValueError(f"AutoTP layer '{self.name}' bias size {self.bias.numel()} does not match output shape "
12351242
f"{self._output_shape}.")
12361243

1237-
self._tp_partition([self.weight, self.bias])
1244+
if self._should_materialize_tp_partition():
1245+
self._tp_partition([self.weight, self.bias])
12381246
self.support_training = True
12391247
self.config_tp_params(self.weight)
12401248
if self.bias is not None:
@@ -1352,7 +1360,8 @@ def __init__(self, module, mp_group, shape, partition_dim=1, **kwargs):
13521360
self._bias_partition_dim) = _infer_subparam_logical_shapes(self._orig_weight_shape, self.shape,
13531361
self.partition_dim, self.name)
13541362

1355-
self._tp_partition([self.weight, self.bias])
1363+
if self._should_materialize_tp_partition():
1364+
self._tp_partition([self.weight, self.bias])
13561365
self.support_training = True
13571366
self.config_tp_params(self.weight)
13581367
if self.bias is not None:

deepspeed/runtime/sequence_parallel/ulysses_sp.py

Lines changed: 116 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import deepspeed.comm as dist
4343
import importlib.metadata
4444
import math
45+
import re
4546
import torch
4647
import torch.distributed.nn
4748

@@ -121,6 +122,10 @@ def __init__(
121122
self.skip_all_but_last_attention_debug_mode = False
122123
self.rotating_layer_counter = 0 # used for dev work
123124

125+
self.core_attn_implementation = None # set by register_with_transformers
126+
self._flex_block_mask_cached = None # cached BlockMask for flex_attention
127+
self._flex_block_mask_cache_key = None # (batch_size, seq_len) for cache invalidation
128+
124129
self.local_q_head_count = attn_head_count // self.world_size
125130

126131
# if we have 4 kv heads and sp 8, we need to replicate kv heads 2x
@@ -272,23 +277,21 @@ def forward(
272277
key = rearrange(key, "bs hc sl hs -> sl bs hc hs") # .contiguous()
273278
value = rearrange(value, "bs hc sl hs -> sl bs hc hs") # .contiguous()
274279

275-
# core attn like FA2 expects an unsharded `position_ids` - without which packed samples
276-
# will return loss=nan.
277-
#
278-
# XXX: need to figure out if we can do the same for SDPA - as it doesn't require this and
279-
# wants an attention mask, so possibly doing this for FA2 only?
280-
#
281-
# Ideally we would be passing the original unsharded position_ids - but we have no way to pass
282-
# it here as HF Transformers drops unexpected keys in `batch` - so either we need to stash
283-
# it somewhere in UlyssesSPDataLoaderAdapter and retrieve it here or we could gather it once
284-
# per batch and stash it inside `module` arg - I already have a machinery to figure out
285-
# which layer number is being called below in the skip_all_but_last_attention_debug_mode
286-
# code where rotating_layer_counter is used - so we could calculate it on the first layer
287-
# and re-use on the remaining layers
288-
if "position_ids" in kwargs:
289-
position_ids_list = [torch.empty_like(kwargs["position_ids"]) for _ in range(self.world_size)]
290-
dist.all_gather(position_ids_list, kwargs["position_ids"], group=self.process_group)
291-
kwargs["position_ids"] = torch.cat(position_ids_list, dim=1)
280+
# All attention backends need unsharded position_ids after the all-to-all.
281+
# FA2 uses them for packed-sequence detection (flash_varlen_fn), sdpa/flex_attention
282+
# need them to be monotonically increasing so causal masking works correctly.
283+
# UlyssesSPDataLoaderAdapter ensures position_ids are in the batch before sharding,
284+
# so after gathering here they reconstruct to the correct global positions.
285+
assert "position_ids" in kwargs, (
286+
"Ulysses SP requires position_ids in every forward() call so that after all_gather "
287+
"causal masking works correctly. Without them each rank generates local [0..chunk_len-1] "
288+
"positions which, after gathering, look like packed sequences and break attention. "
289+
"For non-packed sequences: position_ids = torch.arange(seq_len) per sample. "
290+
"For packed sequences: position_ids must reset at document boundaries. "
291+
"Ensure your data collator or UlyssesSPDataLoaderAdapter includes position_ids.")
292+
position_ids_list = [torch.empty_like(kwargs["position_ids"]) for _ in range(self.world_size)]
293+
dist.all_gather(position_ids_list, kwargs["position_ids"], group=self.process_group)
294+
kwargs["position_ids"] = torch.cat(position_ids_list, dim=1)
292295

293296
# please don't remove the white-space vertical alignment in the error message
294297
assert query.shape == self.required_query_shape, (
@@ -311,6 +314,41 @@ def forward(
311314
if self.kv_replication_factor > 1:
312315
module.num_key_value_groups = query_layer.size(-3) // key_layer.size(-3)
313316

317+
# For flex_attention: the wrapper preserved the BlockMask from the model, but it
318+
# was built for the local shard's sequence length. Rebuild it for the full gathered
319+
# sequence length after the all-to-all.
320+
# XXX: currently hardcodes a causal mask_mod — models with sliding window or other
321+
# non-standard patterns would need the mask_mod extracted from the original BlockMask.
322+
if self.core_attn_implementation == "flex_attention":
323+
from torch.nn.attention.flex_attention import BlockMask, create_block_mask
324+
if isinstance(attention_mask, BlockMask):
325+
seq_len = query_layer.shape[2]
326+
batch_size = query_layer.shape[0]
327+
cache_key = (batch_size, seq_len)
328+
329+
# Cache the BlockMask — create_block_mask is expensive and the mask is the
330+
# same for all layers within a forward pass. Only rebuild when dimensions change.
331+
if self._flex_block_mask_cache_key != cache_key:
332+
333+
def causal_mask(batch_idx, head_idx, q_idx, kv_idx):
334+
return q_idx >= kv_idx
335+
336+
# Don't compile create_block_mask here — it runs inside the model's
337+
# forward pass where flex_attention already uses torch.compile, and
338+
# nesting compiled contexts causes gradient explosion in the backward
339+
# pass. The BlockMask is cached so creation cost is negligible.
340+
self._flex_block_mask_cached = create_block_mask(
341+
mask_mod=causal_mask,
342+
B=batch_size,
343+
H=None,
344+
Q_LEN=seq_len,
345+
KV_LEN=seq_len,
346+
device=query_layer.device,
347+
)
348+
self._flex_block_mask_cache_key = cache_key
349+
350+
attention_mask = self._flex_block_mask_cached
351+
314352
if not self.skip_all_but_last_attention_debug_mode:
315353
# expects: [bs hc_l sl hs]
316354
context_layer, attn_weights = self.attn(module, query_layer, key_layer, value_layer, attention_mask, *args,
@@ -411,15 +449,34 @@ def register_with_transformers(
411449
# if we don't have the model yet at this stage
412450
hf_model_config = AutoConfig.from_pretrained(model_name_or_path)
413451

414-
supported_attn_implementation = ["flash_attention_2", "flash_attention_3", "sdpa"]
415-
if core_attn_implementation not in supported_attn_implementation:
416-
# notes on the excluded ones:
417-
# - eager: The problem is that `eager` wants an attention_mask and it creates the wrong attention mask it seems if we don't provide one - it's possible that we could somehow solve this, but it's also unlikely someone will want to use the slow eager attention with sequence parallelism
418-
# - flex_attention: haven't tried
419-
452+
model_attn_implementation = getattr(hf_model_config, "_attn_implementation", None)
453+
if model_attn_implementation is not None and model_attn_implementation != core_attn_implementation:
454+
raise ValueError(
455+
f"core_attn_implementation='{core_attn_implementation}' does not match "
456+
f"model config attn_implementation='{model_attn_implementation}'. "
457+
"Set both to the same value so sequence-parallel wrapper can intercept the active attention path.")
458+
459+
# eager always materializes a 4D attention_mask (O(n²) memory) and cannot fall back
460+
# to is_causal=True like sdpa — so it's incompatible with SP which discards masks.
461+
unsupported_attn_implementation = ["eager", "paged|eager"]
462+
if core_attn_implementation in unsupported_attn_implementation:
420463
raise ValueError(
421464
f"{core_attn_implementation} attn_implementation isn't currently supported by Ulysses sequence"
422-
f" parallelism. Set core_attn_implementation arg to one of {supported_attn_implementation}.")
465+
f" parallelism because it requires a 4D attention_mask (O(n²) memory)."
466+
f" Use any flash attention variant, 'flex_attention', 'sdpa',"
467+
f" or a hub-hosted kernel (e.g. 'kernels-community/flash-attn2').")
468+
469+
# Hub kernels (e.g. kernels-community/flash-attn2) are registered lazily in transformers.
470+
# Ensure registration happens before validating against ALL_ATTENTION_FUNCTIONS.
471+
is_hub_kernel_attn = (isinstance(core_attn_implementation, str) and re.search(
472+
r"^[^/:]+/[^/:]+(?:@[^/:]+)?(?::[^/:]+)?$", core_attn_implementation) is not None)
473+
if is_hub_kernel_attn:
474+
try:
475+
from transformers.modeling_flash_attention_utils import lazy_import_flash_attention
476+
except ImportError as e:
477+
raise ImportError("Hub kernel attention requires a transformers version exposing "
478+
"`transformers.modeling_flash_attention_utils.lazy_import_flash_attention`.") from e
479+
lazy_import_flash_attention(core_attn_implementation)
423480

424481
if core_attn_implementation not in ALL_ATTENTION_FUNCTIONS:
425482
raise ValueError(
@@ -448,6 +505,7 @@ def register_with_transformers(
448505
global_seq_length=global_seq_length,
449506
disable_in_eval=disable_in_eval,
450507
)
508+
uattn.core_attn_implementation = core_attn_implementation
451509

452510
def uattn_wrapper(
453511
module: torch.nn.Module,
@@ -459,27 +517,41 @@ def uattn_wrapper(
459517
**kwargs,
460518
) -> Tuple[torch.Tensor, torch.Tensor]:
461519

462-
# We are relying on position_ids for SP to work so attention_mask has to be None
463-
# the problem is that HF currently doesn't know anything about ALL_ATTENTION_FUNCTIONS["ulysses"] so it doesn't make a special case like for "flash_attention_2" and "sdpa" and it creates an attention mask on the fly and it breaks things.
464-
attention_mask = None
520+
# SP relies on position_ids (not attention_mask) for causal masking.
521+
# HF doesn't know about the SP wrapper, so it creates an attention_mask for
522+
# the local shard's sequence length — which is invalid after the SP all-to-all
523+
# gathers the full sequence. A 4D mask at full sequence length would also be
524+
# O(n²) memory. So we discard 4D tensor masks.
525+
#
526+
# Keep BlockMask (flex_attention) — it's a compressed sparse representation.
527+
# It will be rebuilt for the full gathered sequence in forward().
528+
_is_block_mask = False
529+
if core_attn_implementation == "flex_attention":
530+
from torch.nn.attention.flex_attention import BlockMask
531+
_is_block_mask = isinstance(attention_mask, BlockMask)
532+
533+
if not _is_block_mask:
534+
attention_mask = None
465535

466536
attn_output, attn_weights = uattn(
467537
module,
468538
query,
469539
key,
470540
value,
471541
attention_mask,
472-
# XXX: fixme
473542
*args,
474543
**kwargs,
475544
)
476545
return attn_output, attn_weights
477546

478547
# We don't do: ALL_ATTENTION_FUNCTIONS.register("ulysses", uattn_wrapper)
479-
# The problem with this approach is that we are missing on all the special use cases in HF Transformers that do things like: if self.config._attn_implementation == "flash_attention_2": ...
480-
# So instead we hack `ALL_ATTENTION_FUNCTIONS` to override all existing keys with our implementation, since it only gets used at the point of calling the attention and that's what we want, all other code branches relying on the original core `attn_implementation` will still be executed. This is what we called "Being John Malkovich"
481-
for key in ALL_ATTENTION_FUNCTIONS.keys():
482-
ALL_ATTENTION_FUNCTIONS[key] = uattn_wrapper
548+
# The problem with that approach is that we'd miss all the special-case branches in
549+
# HF Transformers that check `if self.config._attn_implementation == "flash_attention_2": ...`
550+
# So instead we override the requested core implementation key in ALL_ATTENTION_FUNCTIONS
551+
# with our wrapper. All other code paths relying on the original core attn_implementation
552+
# will still be executed — we only intercept at the point of calling attention.
553+
# This is what we called "Being John Malkovich".
554+
ALL_ATTENTION_FUNCTIONS[core_attn_implementation] = uattn_wrapper
483555

484556
return mpu
485557

@@ -574,6 +646,18 @@ def refill(self):
574646
micro_batches = defaultdict(dict)
575647
# XXX: replace with more efficient all-to-all?
576648

649+
# position_ids must exist before sharding so that after all_gather in
650+
# UlyssesSPAttentionHF.forward() they reconstruct to correct global positions.
651+
# Without them, the Trainer generates local [0,...,chunk_len-1] per rank AFTER
652+
# sharding, which after all_gather looks like packed sequences and breaks
653+
# sdpa/flex_attention causal masking.
654+
if "position_ids" not in batch:
655+
raise ValueError("Ulysses SP requires `position_ids` in every dataloader batch so that "
656+
"each token retains its correct global position after sequence sharding. "
657+
"For non-packed sequences: position_ids = torch.arange(seq_len) per sample. "
658+
"For packed sequences: position_ids must reset at document boundaries. "
659+
"Ensure your data collator includes position_ids in its output.")
660+
577661
# we have batches of variable seqlen so in order to do all_gather on batches - we need to know the exact length of each tensor on each rank
578662
seqlen = torch.tensor(batch["input_ids"].shape[1], dtype=torch.int64, device=self.device)
579663
seqlens = [torch.zeros(1, dtype=torch.int64, device=self.device) for _ in range(self.sp_world_size)]

deepspeed/runtime/zero/stage3.py

Lines changed: 4 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import gc
88
import collections
99
import itertools
10-
from typing import Deque, Dict, Set, List, Tuple, Container, Optional
10+
from typing import Deque, Dict, Set, List, Container, Optional
1111
from contextlib import contextmanager
1212
from dataclasses import dataclass, field
1313

@@ -21,13 +21,13 @@
2121
from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
2222
from deepspeed.runtime.torch_autocast import get_autocast_dtype, get_all_comm_dtypes, is_autocast_initialized, sort_dtypes
2323
from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce, all_to_all_loco_quant_reduce
24-
from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item, mask_nan_or_inf_with_val_inplace, count_used_parameters_in_backward
24+
from deepspeed.runtime.utils import inf, is_model_parallel_parameter, mask_nan_or_inf_with_val_inplace, count_used_parameters_in_backward
2525
from deepspeed.runtime.zero.partition_parameters import *
2626
from deepspeed.runtime.zero.config import ZeroStageEnum
2727
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum
2828
from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload
2929
import deepspeed.runtime.zenflow.engine_stage3 as zf_engine_stage3
30-
from deepspeed.runtime.zero.utils import get_mapping_to_flat_buffer
30+
from deepspeed.runtime.zero.utils import get_mapping_to_flat_buffer, defragment
3131
from deepspeed.runtime.zero.offload_states import offload_adam_states, reload_adam_states
3232
from deepspeed.ops.adam import DeepSpeedCPUAdam
3333
from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus
@@ -655,38 +655,6 @@ def get_lr(self):
655655
"""Return the current learning rate."""
656656
return self.optimizer.param_groups[0]["lr"]
657657

658-
# TODO. factor out to a utility outside of stage3
659-
@staticmethod
660-
def defragment(tensors: List[Tensor]) -> Tensor:
661-
"""move provided tensors into a contiguous flat buffer, with some additional
662-
measures taken to reduce memory fragmentation"""
663-
assert len(set(t.dtype for t in tensors)) == 1
664-
assert len(set(t.device for t in tensors)) == 1
665-
666-
cpu_buffer = torch.empty(sum(p.numel() for p in tensors),
667-
dtype=get_only_unique_item(t.dtype for t in tensors),
668-
device="cpu")
669-
tensor_infos: List[Tuple[Tensor, int, int]] = get_mapping_to_flat_buffer(tensors)
670-
orig_device = get_only_unique_item(t.device for t in tensors)
671-
672-
offset = 0
673-
for tensor, offset, tensor_numel in tensor_infos:
674-
# move the tensor from device memory to host memory
675-
cpu_buffer.narrow(0, offset, tensor_numel).copy_(tensor)
676-
tensor.data = torch.empty(0, dtype=tensor.dtype, device=tensor.device)
677-
678-
gc.collect()
679-
get_accelerator().empty_cache()
680-
681-
# copy tensors (now flattened and contiguous) back to GPU
682-
device_buffer = cpu_buffer.to(orig_device)
683-
684-
# restore device tensors
685-
for tensor, offset, tensor_numel in tensor_infos:
686-
tensor.data = device_buffer.narrow(0, offset, tensor_numel)
687-
688-
return device_buffer
689-
690658
def _get_param_coordinator(self):
691659
return self.parameter_offload.get_param_coordinator()
692660

@@ -834,7 +802,7 @@ def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups):
834802
parameter_partitions = self._get_parameter_partitions()
835803

836804
# We need to keep the reference to this buffer to make sure you can free it in `offload_states`
837-
self.lp_param_buffer = __class__.defragment(parameter_partitions)
805+
self.lp_param_buffer = defragment(parameter_partitions)
838806
self._set_fp16_partitioned_groups_flat()
839807

840808
else: # partitioned params offloaded to CPU when not in use

0 commit comments

Comments
 (0)