deepcompile: Fix backward graph recompilation due to unbalanced forward/backward visits
In recent PyTorch, AOT Autograd does not guarantee backward graph
compilation just because some inputs require grad: if no output requires
grad and no input requiring grad is mutated, aot_autograd skips backward
compilation entirely (see [1]).
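For illustration, a minimal case that takes this forward-only path (a
hedged sketch; assumes a recent PyTorch with `torch.compile`):

```python
import torch

def fn(x):
    # The input requires grad, but the output does not (it is detached)
    # and no input requiring grad is mutated, so AOT Autograd compiles
    # an inference-only graph and never builds a backward graph.
    return x.detach() + 1

compiled = torch.compile(fn)
x = torch.randn(4, requires_grad=True)
out = compiled(x)
assert not out.requires_grad  # no backward graph exists for this frame
```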
DeepCompile previously assumed that every forward graph whose inputs
required grad would also get a backward graph, relying solely on the
presence of requires_grad tensors. This mismatch caused unbalanced
forward/backward visits, leaving stale entries in `frames_needing_bwd`.
The patched FunctionMeta then remained in effect during backward
execution, causing graphs to recompile on every execution and raising
exceptions in `frames_needing_bwd.remove`.
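A simplified sketch of the failure mode (the names come from this commit;
the bodies are illustrative, not DeepCompile's actual code):

```python
frames_needing_bwd = set()

def on_forward(frame_id, inputs):
    # Old heuristic: any input requiring grad implies a later backward
    # visit. This over-counts whenever aot_autograd skips the backward.
    if any(t.requires_grad for t in inputs):
        frames_needing_bwd.add(frame_id)

def on_backward(frame_id):
    # Once forward/backward visits are unbalanced, the set drifts out of
    # sync and this remove() eventually raises KeyError.
    frames_needing_bwd.remove(frame_id)
```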
Fix by:
- Remove the `frames_needing_bwd` set and the `needs_backward` tracking
- Use the context manager `collect_backward_inputs()` to scope the
patching of compiled functions to the forward pass only (see the sketch
below)
This ensures FunctionMeta patching is in effect only during the forward
pass and prevents unnecessary recompilation during backward passes.
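The commit names `collect_backward_inputs()`, but the body below is a
hypothetical sketch of the scoping pattern rather than the actual
DeepCompile implementation: patching is applied on entry and always
undone on exit, so backward runs against the original functions.
`function_metas` and `patch` are assumed stand-ins for DeepCompile's
real structures.

```python
from contextlib import contextmanager

@contextmanager
def collect_backward_inputs(function_metas, patch):
    # Save the original callables, patch them for the forward pass, and
    # restore them unconditionally so backward never sees patched state.
    originals = [(meta, meta.forward) for meta in function_metas]
    try:
        for meta in function_metas:
            meta.forward = patch(meta.forward)
        yield
    finally:
        for meta, fwd in originals:
            meta.forward = fwd

# Usage (illustrative):
# with collect_backward_inputs(metas, patch):
#     out = compiled_forward(x)   # patched functions record backward inputs
# loss(out).backward()            # backward executes unpatched functions
```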
References
[1] https://github.com/pytorch/pytorch/blob/aea31e0c306e2315bf6d84255e0dde7adf09762a/torch/_functorch/aot_autograd.py#L618
Signed-off-by: Junjie Mao <junjie.mao@linux.alibaba.com>