Skip to content

Commit 37e232f

Browse files
authored
fix(zero): detach flat buffer to prevent autograd inplace error on CPU accelerator (#7948)

The on-device flatten path (introduced in #7828) passes nn.Parameter objects with requires_grad=True to torch.cat(), creating a flat buffer with a CatBackward0 grad_fn. Later, _unflatten_dense_tensors produces SplitBackward0 views that are assigned to model params. An inplace copy_() on these views during the optimizer step raises: RuntimeError: Output 0 of SplitBackward0 is a view and is being modified inplace. This especially affects CPU training, where CPU_Accelerator.is_available() returns True and available_memory() returns system RAM, so the on-device path is always taken. Fix: add .detach() to the flattened buffer, matching the implicit detach behavior of the CPU-offload path (param.data.cpu() + .to(device)). Also rename flatten_on_gpu -> flatten_on_accelerator and replace GPU-specific terminology in comments/logs with accelerator-generic equivalents. --------- Signed-off-by: Guokai Ma <[email protected]> Signed-off-by: Ma, Guokai <[email protected]>
1 parent bf0126b commit 37e232f

File tree

2 files changed

+63
-18
lines changed

2 files changed

+63
-18
lines changed

deepspeed/runtime/zero/stage_1_and_2.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ def _enforce_cpu_offload():
368368
# not sure why apex was cloning the weights before flattening
369369
# removing cloning here
370370

371-
# Compute group size for VRAM check (need 2x model size on GPU to flatten in place: params + flat copy)
371+
# Compute group size for memory check (need 2x model size on accelerator to flatten in place: params + flat copy)
372372
orig_group_numel = sum(param.numel() for param in self.bit16_groups[i])
373373
alignment = self.nccl_start_alignment_factor * dist.get_world_size(group=self.real_dp_process_group[i])
374374
aligned_numel = int(math.ceil(orig_group_numel / alignment)) * alignment
@@ -378,13 +378,13 @@ def _enforce_cpu_offload():
378378

379379
empty_cache()
380380
accelerator = get_accelerator()
381-
available_vram = accelerator.available_memory() if accelerator.is_available() else 0
382-
# Flatten on GPU only if we have enough VRAM for the flat buffer (2x = params already there + copy)
383-
flatten_on_gpu = (accelerator.is_available() and (available_vram >= flat_buffer_bytes))
381+
available_memory = accelerator.available_memory() if accelerator.is_available() else 0
382+
# Flatten on accelerator device if we have enough memory for the flat buffer
383+
flatten_on_accelerator = (accelerator.is_available() and (available_memory >= flat_buffer_bytes))
384384

385-
if not flatten_on_gpu:
385+
if not flatten_on_accelerator:
386386
see_memory_usage(f"Before moving param group {i} to CPU")
387-
# move all the parameters to cpu to free up GPU space for creating flat buffer
387+
# move all the parameters to cpu to free up accelerator memory for creating flat buffer
388388
for param in self.bit16_groups[i]:
389389
param.cpu_data = param.data.cpu()
390390
param.data = torch.empty(1).to(param.device)
@@ -409,21 +409,21 @@ def _enforce_cpu_offload():
409409
# Create meta tensors list, ordered according to round_robin_tensors
410410
meta_tensors = []
411411
for param in round_robin_tensors:
412-
if flatten_on_gpu:
412+
if flatten_on_accelerator:
413413
meta_tensors.append(torch.zeros_like(param.data, device="meta"))
414414
else:
415415
meta_tensors.append(torch.zeros_like(param.cpu_data, device="meta"))
416416
self.round_robin_bit16_meta.append(meta_tensors)
417417

418-
if flatten_on_gpu:
419-
logger.info(f"Flattening param group {i} on GPU (sufficient VRAM)")
418+
if flatten_on_accelerator:
419+
logger.info(f"Flattening param group {i} on {accelerator.device_name()} (sufficient memory)")
420420
flattened_buffer = self.flatten_dense_tensors_aligned(self.round_robin_bit16_groups[i],
421421
alignment,
422-
use_cpu_data=False)
422+
use_cpu_data=False).detach()
423423
self.bit16_groups_flat.append(flattened_buffer)
424-
see_memory_usage(f"After flattening param group {i} on GPU", force=False)
424+
see_memory_usage(f"After flattening param group {i} on {accelerator.device_name()}", force=False)
425425
else:
426-
logger.info(f"Flattening param group {i} on CPU (insufficient VRAM)")
426+
logger.info(f"Flattening param group {i} on CPU (insufficient memory)")
427427

428428
flattened_buffer = self.flatten_dense_tensors_aligned(self.round_robin_bit16_groups[i],
429429
alignment,
@@ -437,7 +437,8 @@ def _enforce_cpu_offload():
437437
self.bit16_groups_flat.append(flattened_buffer.to(get_accelerator().current_device_name()))
438438
del flattened_buffer
439439

440-
see_memory_usage(f"After flattening and moving param group {i} to GPU", force=False)
440+
see_memory_usage(f"After flattening and moving param group {i} to {get_accelerator().device_name()}",
441+
force=False)
441442

442443
if dist.get_rank(group=self.real_dp_process_group[i]) == 0:
443444
see_memory_usage(f"After Flattening and after emptying param group {i} cache", force=False)

tests/unit/v1/zero/test_stage2_flatten_on_gpu.py

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,14 @@
88
"""
99

1010
import pytest
11+
import torch
1112
import deepspeed
1213
from deepspeed.accelerator import get_accelerator
1314
from deepspeed.utils import set_log_level_from_string
1415
from unit.common import DistributedTest
15-
from unit.simple_model import SimpleModel
16+
from unit.simple_model import SimpleModel, random_dataloader
17+
18+
_DTYPE_MAP = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
1619

1720

1821
def _apply_dtype_to_config(config_dict, dtype):
@@ -70,10 +73,10 @@ def mock_logger_info(msg, *args, **kwargs):
7073
model_parameters=model.parameters(),
7174
)
7275

73-
# Small model + no CPU offload => GPU path; that path logs "on GPU"
74-
gpu_path_logs = [m for m in log_messages if "Flattening param group" in m and "on GPU" in m]
75-
assert gpu_path_logs, (
76-
f"Expected GPU flatten path (logger.info should be called with 'Flattening param group' and 'on GPU'). "
76+
# Small model + no CPU offload => accelerator path logs "Flattening param group ... (sufficient memory)"
77+
accel_path_logs = [m for m in log_messages if "Flattening param group" in m and "(sufficient memory)" in m]
78+
assert accel_path_logs, (
79+
f"Expected accelerator flatten path (log should contain 'Flattening param group' and '(sufficient memory)'). "
7780
f"Captured messages: {log_messages}")
7881

7982
def test_flat_buffers_on_accelerator(self, zero_stage, dtype):
@@ -107,3 +110,44 @@ def test_flat_buffers_on_accelerator(self, zero_stage, dtype):
107110
device_type = get_accelerator().device_name()
108111
for i, flat in enumerate(opt.bit16_groups_flat):
109112
assert flat.device.type == device_type, (f"Flat buffer {i} must be on {device_type}, got {flat.device}")
113+
114+
@pytest.mark.world_size(1)
115+
def test_flatten_on_accelerator_training_step(self, zero_stage, dtype):
116+
"""Regression: flat buffer must be detached so inplace ops during step don't crash."""
117+
if not get_accelerator().is_available():
118+
pytest.skip("Accelerator not available")
119+
config_dict = {
120+
"train_micro_batch_size_per_gpu": 2,
121+
"gradient_accumulation_steps": 1,
122+
"zero_optimization": {
123+
"stage": zero_stage
124+
},
125+
"optimizer": {
126+
"type": "Adam",
127+
"params": {
128+
"lr": 1e-3
129+
}
130+
},
131+
}
132+
_apply_dtype_to_config(config_dict, dtype)
133+
134+
hidden_dim = 64
135+
model = SimpleModel(hidden_dim=hidden_dim, nlayers=2)
136+
engine, _, _, _ = deepspeed.initialize(
137+
config=config_dict,
138+
model=model,
139+
model_parameters=model.parameters(),
140+
)
141+
for flat in engine.optimizer.bit16_groups_flat:
142+
assert flat.grad_fn is None, ("Flat buffer must be detached from autograd graph"
143+
" to prevent inplace-modification errors during optimizer step")
144+
145+
data_loader = random_dataloader(model=engine,
146+
total_samples=8,
147+
hidden_dim=hidden_dim,
148+
device=engine.device,
149+
dtype=_DTYPE_MAP[dtype])
150+
for batch in data_loader:
151+
loss = engine(batch[0], batch[1])
152+
engine.backward(loss)
153+
engine.step()

0 commit comments

Comments
 (0)