Skip to content

Commit 6c59d54

Browse files
tohtana and rraminen authored
Fix hook count performance regression from v0.18.5 (#7886)
Fixes performance regressions reported in #7882 and #7885. PR #7780 added dynamic hook count computation for reentrant checkpointing correctness, but placed the call inside every gradient hook closure. For a model with n parameter tensors, this creates significant overhead per backward pass.

Summary:
1. Added a `should_refresh_expected_hook_count()` predicate that returns true only at backward phase boundaries (the first hook of a backward, or the start of a new reentrant phase), so `count_used_parameters_in_backward()` is called once per phase instead of once per hook.
2. Applied this predicate in ZeRO-1/2 (stage_1_and_2.py) and at both ZeRO-3 hook sites (stage3.py), reusing the cached `max_expected_hooks_seen` value when a refresh isn't needed.
3. Changed `enter_backward()` to reset hook counters on the first real backward entry of a step, preventing pollution from pre-user-backward autograd calls (e.g., TiledFusedLogitsLoss).

Benchmark — 24-layer transformer, ~267M params (147 parameter tensors), ZeRO-2, 8×H100 80GB, bf16, batch size 8, 20 warmup + 20 measured iterations:
- Before fix: 0.1265 s/iter
- After fix: 0.0505 s/iter

---------

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Co-authored-by: Ramya Ramineni <rraminen@users.noreply.github.com>
1 parent 4dba1e2 commit 6c59d54

File tree

4 files changed

+155
-5
lines changed

4 files changed

+155
-5
lines changed

deepspeed/runtime/base_optimizer.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,17 @@ def run_grad_acc_post_hooks(self):
109109

110110
def enter_backward(self):
111111
"""Enter backward context. Call at the start of backward pass."""
112+
# On first real backward entry of a step, reset counters that may have been
113+
# polluted by pre-user-backward hooks (e.g. TiledFusedLogitsLoss calling
114+
# torch.autograd.backward() from forward). Do NOT reset on reentrant
115+
# phase re-entry (backward_seen_this_step == True) so phase-to-phase
116+
# state remains intact.
117+
if self.backward_active_depth == 0 and not self.backward_seen_this_step:
118+
self.hooks_fired_this_backward = 0
119+
self.max_expected_hooks_seen = 0
120+
self.remaining_grad_acc_hooks = 0
121+
self.post_backward_callback_queued = False
122+
self.post_backward_callback_graph_task_id = None
112123
self.backward_active_depth += 1
113124
# Track that backward has been active at some point in this step.
114125
# This is used to detect subsequent gradient hook phases with reentrant checkpointing.
@@ -128,6 +139,22 @@ def reset_for_new_step(self):
128139
self.post_backward_callback_queued = False
129140
self.post_backward_callback_graph_task_id = None
130141

142+
def should_refresh_expected_hook_count(self):
143+
"""Return True when count_used_parameters_in_backward() should be re-evaluated.
144+
145+
Refresh is needed in two cases:
146+
1. First hook of a backward (or backward phase): hooks_fired == 0.
147+
2. A new reentrant phase started: remaining hooks exhausted, we exited
148+
backward, but backward was active earlier this step.
149+
150+
The predicate must be evaluated BEFORE reenter_backward_if_needed()
151+
because re-entering changes backward_active_depth and hides the
152+
phase-boundary signal.
153+
"""
154+
return (self.hooks_fired_this_backward == 0
155+
or (self.remaining_grad_acc_hooks == 0 and self.backward_active_depth == 0
156+
and self.backward_seen_this_step))
157+
131158
def reenter_backward_if_needed(self):
132159
"""Re-enter backward context for subsequent phases in reentrant checkpointing.
133160
@@ -401,6 +428,10 @@ def clear_backward_seen_flag(self):
401428
"""Clear the backward seen flag and reset hook counters at the start of each step."""
402429
self._backward_hook_state.reset_for_new_step()
403430

431+
def should_refresh_expected_hook_count(self):
432+
"""Return True when count_used_parameters_in_backward() should be re-evaluated."""
433+
return self._backward_hook_state.should_refresh_expected_hook_count()
434+
404435
def reenter_backward_if_needed(self):
405436
"""Re-enter backward context for subsequent phases in reentrant checkpointing."""
406437
self._backward_hook_state.reenter_backward_if_needed()

deepspeed/runtime/zero/stage3.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1279,14 +1279,19 @@ def wrapper(param):
12791279

12801280
@instrument_w_nvtx
12811281
def reduce_partition_and_remove_grads(*notneeded):
1282+
# Evaluate refresh condition before reenter_backward_if_needed()
1283+
refresh_expected = self.should_refresh_expected_hook_count()
12821284
# Re-enter backward for subsequent phases in reentrant checkpointing
12831285
self.reenter_backward_if_needed()
12841286

12851287
self.reduce_ready_partitions_and_remove_grads(param)
12861288

12871289
# Update hook state and run epilogue if all expected hooks have fired
1288-
current_expected = count_used_parameters_in_backward(
1289-
non_leaf_params_requiring_grad) + leaf_module_count
1290+
if refresh_expected:
1291+
current_expected = count_used_parameters_in_backward(
1292+
non_leaf_params_requiring_grad) + leaf_module_count
1293+
else:
1294+
current_expected = self._max_expected_hooks_seen
12901295
self.update_hook_state_and_maybe_run_epilogue(current_expected)
12911296

12921297
self._grad_acc_hooks.append(register_grad_hook(param, reduce_partition_and_remove_grads))
@@ -1303,6 +1308,8 @@ def reduce_partition_and_remove_grads(*notneeded):
13031308
def make_hook(params):
13041309

13051310
def reduce_leaf_module_grads(module, grad_input, grad_output):
1311+
# Evaluate refresh condition before reenter_backward_if_needed()
1312+
refresh_expected = self.should_refresh_expected_hook_count()
13061313
self.reenter_backward_if_needed()
13071314

13081315
for param in params:
@@ -1311,8 +1318,11 @@ def reduce_leaf_module_grads(module, grad_input, grad_output):
13111318
param.grad = torch.zeros_like(param)
13121319
self.reduce_ready_partitions_and_remove_grads(param)
13131320

1314-
current_expected = count_used_parameters_in_backward(
1315-
non_leaf_params_requiring_grad) + leaf_module_count
1321+
if refresh_expected:
1322+
current_expected = count_used_parameters_in_backward(
1323+
non_leaf_params_requiring_grad) + leaf_module_count
1324+
else:
1325+
current_expected = self._max_expected_hooks_seen
13161326
self.update_hook_state_and_maybe_run_epilogue(current_expected)
13171327

13181328
return reduce_leaf_module_grads

deepspeed/runtime/zero/stage_1_and_2.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1046,9 +1046,14 @@ def create_gradient_handling_hooks(self):
10461046
def wrapper(param, i):
10471047

10481048
def grad_handling_hook(*notneeded):
1049+
# Evaluate refresh condition before reenter_backward_if_needed()
1050+
refresh_expected = self.should_refresh_expected_hook_count()
10491051
self.reenter_backward_if_needed()
10501052
self.process_gradients(param, i)
1051-
current_expected = count_used_parameters_in_backward(all_params_requiring_grad)
1053+
if refresh_expected:
1054+
current_expected = count_used_parameters_in_backward(all_params_requiring_grad)
1055+
else:
1056+
current_expected = self._max_expected_hooks_seen
10521057
self.update_hook_state_and_maybe_run_epilogue(current_expected)
10531058

10541059
self._grad_acc_hooks.append(register_grad_hook(param, grad_handling_hook))
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# DeepSpeed Team
5+
"""Regression tests for count_used_parameters_in_backward() call count.
6+
7+
Verifies fix for https://github.com/deepspeedai/DeepSpeed/issues/7885:
8+
count_used_parameters_in_backward() was called once per gradient hook
9+
(O(n) calls per backward) instead of once per backward phase (O(1)
10+
for non-reentrant, O(p) for reentrant with p phases).
11+
"""
12+
13+
import pytest
14+
import torch
15+
from unittest.mock import patch
16+
17+
import deepspeed
18+
from deepspeed.accelerator import get_accelerator
19+
from unit.common import DistributedTest
20+
from unit.simple_model import SimpleModel, random_dataloader
21+
22+
23+
def get_config_dict(zero_stage):
24+
config_dict = {
25+
"train_micro_batch_size_per_gpu": 2,
26+
"gradient_accumulation_steps": 1,
27+
"steps_per_print": 1,
28+
"zero_optimization": {
29+
"stage": zero_stage,
30+
},
31+
"optimizer": {
32+
"type": "Adam",
33+
"params": {
34+
"lr": 1e-3
35+
}
36+
},
37+
}
38+
39+
if zero_stage == 3:
40+
config_dict["zero_optimization"]["stage3_param_persistence_threshold"] = 0
41+
42+
if get_accelerator().is_bf16_supported():
43+
config_dict["bf16"] = {"enabled": True}
44+
elif get_accelerator().is_fp16_supported():
45+
config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
46+
47+
return config_dict
48+
49+
50+
class TestHookCountRegression(DistributedTest):
51+
"""Test that count_used_parameters_in_backward is not called per-hook."""
52+
world_size = 2
53+
54+
@pytest.mark.parametrize("zero_stage", [2, 3])
55+
def test_non_reentrant_single_count_call(self, zero_stage):
56+
"""Non-reentrant backward should call count_used_parameters_in_backward exactly once."""
57+
hidden_dim = 16
58+
model = SimpleModel(hidden_dim)
59+
config = get_config_dict(zero_stage)
60+
engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config)
61+
62+
data_loader = random_dataloader(model=engine, total_samples=4, hidden_dim=hidden_dim, device=engine.device)
63+
64+
# Determine the correct module path to patch based on stage
65+
if zero_stage == 2:
66+
patch_target = "deepspeed.runtime.zero.stage_1_and_2.count_used_parameters_in_backward"
67+
else:
68+
patch_target = "deepspeed.runtime.zero.stage3.count_used_parameters_in_backward"
69+
70+
call_counts = []
71+
72+
for batch in data_loader:
73+
with patch(patch_target, wraps=deepspeed.runtime.utils.count_used_parameters_in_backward) as mock_count:
74+
loss = engine(batch[0], batch[1])
75+
engine.backward(loss)
76+
call_counts.append(mock_count.call_count)
77+
engine.step()
78+
break
79+
80+
# Non-reentrant: exactly 1 call per backward
81+
assert call_counts[0] == 1, (f"Expected exactly 1 call to count_used_parameters_in_backward "
82+
f"per backward, got {call_counts[0]}")
83+
84+
@pytest.mark.parametrize("zero_stage", [2, 3])
85+
def test_training_step_succeeds_after_fix(self, zero_stage):
86+
"""Verify a full training step produces a finite loss after the caching fix."""
87+
hidden_dim = 16
88+
model = SimpleModel(hidden_dim)
89+
config = get_config_dict(zero_stage)
90+
engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config)
91+
92+
data_loader = random_dataloader(model=engine, total_samples=8, hidden_dim=hidden_dim, device=engine.device)
93+
94+
losses = []
95+
for i, batch in enumerate(data_loader):
96+
loss = engine(batch[0], batch[1])
97+
assert torch.isfinite(loss), f"Loss is not finite at step {i}: {loss.item()}"
98+
losses.append(loss.item())
99+
engine.backward(loss)
100+
engine.step()
101+
if i >= 1:
102+
break
103+
104+
assert len(losses) >= 2, "Expected at least 2 training steps"

0 commit comments

Comments (0)