Commit ebbe957

deepcompile: Fix backward graph recompilation due to unbalanced forward/backward visits
In recent PyTorch AOT Autograd, having tensors that require grad among the inputs does not guarantee that a backward graph is compiled. If no output requires grad and no input requiring grad is mutated, aot_autograd skips backward compilation (see [1]). DeepCompile previously required backward compilation for every forward graph involving grad, relying solely on the presence of requires_grad tensors in the inputs. This mismatch caused unbalanced forward/backward visits, leaving graphs unvisited in `frames_needing_bwd`. The patched FunctionMeta then remained in effect during backward execution, causing graphs to recompile on every execution and raising exceptions in `frames_needing_bwd.remove`.

Fix by:
- Removing the `frames_needing_bwd` set and the `needs_backward` tracking
- Using the context manager `collect_backward_inputs()` to scope the patching of compiled functions to the forward pass only

This ensures FunctionMeta patching is only in effect during the forward pass and prevents unnecessary recompilation during backward passes.

References

[1] https://github.com/pytorch/pytorch/blob/aea31e0c306e2315bf6d84255e0dde7adf09762a/torch/_functorch/aot_autograd.py#L618

Signed-off-by: Junjie Mao <junjie.mao@linux.alibaba.com>
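For illustration only (not part of this commit), a minimal sketch of why the old input-based heuristic can disagree with AOT Autograd's decision, assuming a toy graph whose only output is detached:

import torch
import torch.utils._pytree as pytree

def old_heuristic(real_inputs):
    # DeepCompile's former check: any tensor input with requires_grad=True
    return pytree.tree_any(lambda x: x.requires_grad if torch.is_tensor(x) else False, real_inputs)

def graph_fn(t):
    # No output requires grad and no input requiring grad is mutated,
    # so AOT Autograd would skip backward compilation for such a graph.
    return (t * 2).detach()

x = torch.randn(4, requires_grad=True)
print(old_heuristic((x, )))        # True  -> old code expected a backward visit
print(graph_fn(x).requires_grad)   # False -> no backward graph is compiled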
1 parent 83d43c3 commit ebbe957

5 files changed: 23 additions & 23 deletions

deepspeed/compile/backend.py

Lines changed: 4 additions & 17 deletions
@@ -12,7 +12,6 @@
 from torch.fx import Graph, GraphModule
 
 try:
-    import torch.utils._pytree as pytree
     import torch._dynamo
     from functorch.compile import make_boxed_func
     from torch._functorch.aot_autograd import aot_module_simplified
@@ -28,7 +27,7 @@
 from .graph_param import DSGraphParamManager
 from .profilers import ProfilingResult
 from .profilers.graph_profile import MemoryProfilingInterpreter
-from .patch_compiled_func import patch_compiled_func, unpatch_compiled_func, get_backward_inputs
+from .patch_compiled_func import get_backward_inputs
 from .util import get_input_nodes, get_activation_node_names, get_index_by_graph_id, get_deepcompile_handle, log_rank0, is_backend_inductor
 from .partitioner import get_wrapped_partitioner
 from .inductor import register_custom_ops, patch_create_aot_dispatcher_function
@@ -47,9 +46,9 @@ class GraphOrder:
     def __init__(self):
         self.frames = OrderedDict()
 
-    def add_graph(self, graph_id: int, frame_id: int, needs_backward: bool):
+    def add_graph(self, graph_id: int, frame_id: int):
         if frame_id not in self.frames:
-            self.frames[frame_id] = (graph_id, needs_backward)
+            self.frames[frame_id] = (graph_id, )
 
     def get_graph_order(self) -> List[Tuple[int, bool]]:
         return list(self.frames.values())
@@ -60,7 +59,6 @@ def clear(self):
 
 graph_order_with_frame_id = GraphOrder()
 
-frames_needing_bwd = set()
 profiling_results: Dict[int, ProfilingResult] = {}
 opt_pass_times = []
 opt_passes = {}
@@ -225,9 +223,8 @@ def make_backend(backend, compile_config, compile_kwargs={}):
     def backend_fn(gm: GraphModule, real_inputs):
         graph_id = id(gm.graph)
 
-        needs_backward = pytree.tree_any(lambda x: x.requires_grad if torch.is_tensor(x) else False, real_inputs)
         frame_id = gm.meta["dynamo_compile_id"].frame_id
-        graph_order_with_frame_id.add_graph(graph_id, frame_id, needs_backward)
+        graph_order_with_frame_id.add_graph(graph_id, frame_id)
 
         graph_order = graph_order_with_frame_id.get_graph_order()
 
@@ -258,17 +255,11 @@ def backend_fn(gm: GraphModule, real_inputs):
         if graph_id not in profiling_results:
             profiling_results[graph_id] = ProfilingResult()
         profiling_results[graph_id].param_indices = param_indices
-        profiling_results[graph_id].needs_backward = needs_backward
 
         def make_fw_graph(gm, sample_inputs):
            time_start = time.time()
            graph_index = len(graph_order) - 1
 
-            if needs_backward:
-                if len(frames_needing_bwd) == 0:
-                    patch_compiled_func()
-                frames_needing_bwd.add(frame_id)
-
            # Try to get real_inputs from the list first, then from storage
            if fwd_real_inputs:
                real_inputs = fwd_real_inputs.pop(0)
@@ -347,10 +338,6 @@ def make_bw_graph(gm, sample_inputs):
            add_free_activations(graph_id, gm.graph,
                                 get_activation_node_names(gm.graph, param_nodes_bw, non_param_input_names))
 
-            frames_needing_bwd.remove(frame_id)
-            if len(frames_needing_bwd) == 0:
-                unpatch_compiled_func()
-
            log_rank0(
                f"Bwd end {graph_index} graph_id={graph_id} alloc_mem={get_accelerator().memory_allocated()} graph={gm.graph}",
                enable=debug_log)

deepspeed/compile/patch_compiled_func.py

Lines changed: 8 additions & 0 deletions
@@ -3,6 +3,7 @@
 
 # DeepSpeed Team
 
+from contextlib import contextmanager
 import torch
 from deepspeed.utils.torch import required_torch_version
 
@@ -89,5 +90,12 @@ def unpatch_compiled_func():
     torch.autograd.Function = original_grad_fn
 
 
+@contextmanager
+def collect_backward_inputs():
+    patch_compiled_func()
+    yield
+    unpatch_compiled_func()
+
+
 def get_backward_inputs():
     return backward_inputs
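A minimal usage sketch of the new context manager; the Linear module and inputs here are placeholders standing in for a DeepCompile-compiled module:

import torch
from deepspeed.compile.patch_compiled_func import collect_backward_inputs, get_backward_inputs

model = torch.nn.Linear(8, 8)   # stand-in for a DeepCompile-compiled module
x = torch.randn(2, 8)

# torch.autograd.Function is patched only inside this block, so the patched
# FunctionMeta is no longer in effect when backward graphs execute later.
with collect_backward_inputs():
    loss = model(x).sum()

loss.backward()
inputs_seen = get_backward_inputs()  # inputs captured from compiled autograd functions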

deepspeed/compile/profilers/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -13,7 +13,6 @@
 class ProfilingResult:
     fwd_graph: Graph = None
     bwd_graph: Graph = None
-    needs_backward: bool = False
     fwd_mem: List[Tuple[str, int, int, int]] = field(default_factory=list)  # name, current_alloc, delta, peak
     bwd_mem: List[Tuple[str, int, int, int]] = field(default_factory=list)
     fwd_time: List[Tuple[str, int, int]] = field(default_factory=list)  # name, device_time, wall_time

deepspeed/compile/util.py

Lines changed: 1 addition & 1 deletion
@@ -469,7 +469,7 @@ def is_release_node(n: Node) -> bool:
 
 
 def get_index_by_graph_id(graph_order, target_graph_id):
-    for index, (graph_id, _) in enumerate(graph_order):
+    for index, (graph_id, ) in enumerate(graph_order):
         if graph_id == target_graph_id:
             return index
     return -1

deepspeed/runtime/engine.py

Lines changed: 10 additions & 4 deletions
@@ -11,6 +11,7 @@
 from collections import defaultdict, OrderedDict, deque
 from shutil import copyfile
 import gc
+from contextlib import nullcontext
 
 from torch.nn.modules import Module
 from torch.nn.parameter import Parameter
@@ -125,6 +126,7 @@
 from deepspeed.compile.util import is_deepcompile_supported, get_deepcompile_handle, deepcompile_backward_prologue
 from deepspeed.compile.backend import register_compile_pass, opt_passes
 from deepspeed.compile.passes import zero3_compile, prefetch, selective_gather, offload_adam_states
+from deepspeed.compile.patch_compiled_func import collect_backward_inputs
 from deepspeed.compile.init_z1 import init_z1
 from deepspeed.compile.init_z3 import init_z3
 from deepspeed.compile.init_sp import init_autosp
@@ -2350,11 +2352,15 @@ def forward(self, *inputs, **kwargs):
                 "DeepCompile is enabled but engine.compile() has not been called; executing without DeepCompile until compile() runs.",
                 ranks=[0])
 
-        if self.is_deepcompile_active() and hasattr(self, "launch_compile_passes"):
-            # We can't have this in forward prologue as the compiler compiles hooks including the forward prologue.
-            self.launch_compile_passes(self.global_steps)
+        if self.is_deepcompile_active():
+            collect_backward_input_ctx = collect_backward_inputs
+            if hasattr(self, "launch_compile_passes"):
+                # We can't have this in forward prologue as the compiler compiles hooks including the forward prologue.
+                self.launch_compile_passes(self.global_steps)
+        else:
+            collect_backward_input_ctx = nullcontext
 
-        with autocast_if_enabled(self):
+        with autocast_if_enabled(self), collect_backward_input_ctx():
             loss = self.module(*inputs, **kwargs)
 
         # Register output backward hooks