Commit 0dcff08

fix(pt): recognize AOTInductor-wrapped CUDA OOM in AutoBatchSize
When running `dp --pt-expt test` (or any path that goes through `deepmd.pt_expt.infer.deep_eval`) against a `.pt2` AOTInductor package, `AutoBatchSize` doubles the batch on every success. For models with a large `sel` the exploration eventually saturates GPU memory, and the CUDA caching allocator raises the usual `CUDA out of memory` from inside the AOTInductor runtime. AOTInductor then rewraps that error as a generic

    RuntimeError: run_func_(...) API call failed at .../aoti_runner/model_container_runner.cpp, line 144

The original "CUDA out of memory" text is printed only to stderr, so the old `is_oom_error` -- which keyed on a short list of substrings in `e.args[0]` -- never matched. `execute()` therefore did not shrink the batch; the exception propagated, and the run crashed on a GPU that was otherwise completely idle (as confirmed by monitoring `nvidia-smi --query-compute-apps`, which showed dp itself as the sole consumer holding tens of GiB just before the failure).

Widen `is_oom_error` to:

* walk the exception chain via `__cause__` / `__context__`, so that a future PyTorch preserving the original OOM text is handled for free;
* keep matching the four plain CUDA OOM markers on every message in the chain;
* additionally treat the AOTInductor wrapper signature (`run_func_(` plus `model_container_runner`) as an OOM candidate.

If the AOTInductor wrapper ever hides a non-OOM failure, the batch shrinker will halve down to 1 and then raise `OutOfMemoryError`, so the fallback is bounded -- non-OOM bugs still surface with a clear terminal error rather than being silently retried forever.
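The exception-chain walk described above can be sketched in isolation. This is a minimal standalone illustration, not the patched deepmd code; `chain_messages` is a hypothetical helper name:

```python
# Sketch of the chain walk: follow __cause__ / __context__, collect string
# messages, and guard against reference cycles with an id() set.
# `chain_messages` is a hypothetical name used only for this illustration.
def chain_messages(e: BaseException) -> list[str]:
    msgs: list[str] = []
    cur: BaseException | None = e
    seen: set[int] = set()
    while cur is not None and id(cur) not in seen:
        seen.add(id(cur))
        if cur.args and isinstance(cur.args[0], str):
            msgs.append(cur.args[0])
        # Prefer the explicit cause ("raise ... from ..."), fall back to the
        # implicit context (exception raised while handling another).
        cur = cur.__cause__ or cur.__context__
    return msgs


if __name__ == "__main__":
    # Simulate a runtime that preserves the original OOM as __cause__ when
    # rewrapping: the walk recovers the underlying message.
    try:
        try:
            raise RuntimeError("CUDA out of memory.")
        except RuntimeError as oom:
            raise RuntimeError(
                "run_func_(...) API call failed at model_container_runner.cpp"
            ) from oom
    except RuntimeError as wrapped:
        msgs = chain_messages(wrapped)
    assert "CUDA out of memory." in msgs[-1]
```

Collecting every message in the chain (rather than only `e.args[0]`) is what makes the matcher robust to wrappers that may or may not preserve the root cause.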
1 parent 54f42d9 commit 0dcff08

1 file changed: deepmd/pt/utils/auto_batch_size.py (47 additions & 14 deletions)
@@ -46,20 +46,53 @@ def is_oom_error(self, e: Exception) -> bool:
         e : Exception
             Exception
         """
-        # several sources think CUSOLVER_STATUS_INTERNAL_ERROR is another out-of-memory error,
-        # such as https://github.com/JuliaGPU/CUDA.jl/issues/1924
-        # (the meaningless error message should be considered as a bug in cusolver)
-        if (
-            isinstance(e, RuntimeError)
-            and (
-                "CUDA out of memory." in e.args[0]
-                or "CUDA driver error: out of memory" in e.args[0]
-                or "cusolver error: CUSOLVER_STATUS_INTERNAL_ERROR" in e.args[0]
-                # https://github.com/deepmodeling/deepmd-kit/issues/4594
-                or "CUDA error: out of memory" in e.args[0]
-            )
-        ) or isinstance(e, torch.cuda.OutOfMemoryError):
-            # Release all unoccupied cached memory
+        if isinstance(e, torch.cuda.OutOfMemoryError):
             torch.cuda.empty_cache()
             return True
+
+        if not isinstance(e, RuntimeError) or not e.args:
+            return False
+
+        # Gather messages from the exception itself and its chain. AOTInductor
+        # (.pt2) sometimes strips the underlying OOM message when rewrapping,
+        # but not always; checking ``__cause__`` / ``__context__`` catches the
+        # remaining cases when the original error is preserved.
+        msgs: list[str] = []
+        cur: BaseException | None = e
+        seen: set[int] = set()
+        while cur is not None and id(cur) not in seen:
+            seen.add(id(cur))
+            if cur.args:
+                first = cur.args[0]
+                if isinstance(first, str):
+                    msgs.append(first)
+            cur = cur.__cause__ or cur.__context__
+
+        # Several sources treat CUSOLVER_STATUS_INTERNAL_ERROR as an OOM, e.g.
+        # https://github.com/JuliaGPU/CUDA.jl/issues/1924
+        plain_oom_markers = (
+            "CUDA out of memory.",
+            "CUDA driver error: out of memory",
+            "CUDA error: out of memory",
+            "cusolver error: CUSOLVER_STATUS_INTERNAL_ERROR",
+        )
+        if any(m in msg for msg in msgs for m in plain_oom_markers):
+            torch.cuda.empty_cache()
+            return True
+
+        # AOTInductor (.pt2) wraps the underlying CUDA OOM as a generic
+        # ``run_func_(...) API call failed at .../model_container_runner.cpp``.
+        # The original "CUDA out of memory" text is printed to stderr only and
+        # is absent from the Python-level RuntimeError, so we match on the
+        # wrapper signature. If the root cause turns out to be something
+        # other than OOM, ``execute()`` will keep shrinking the batch and
+        # eventually raise ``OutOfMemoryError`` at batch size 1, which is a
+        # clean failure rather than an uncaught exception.
+        aoti_wrapped = any(
+            "run_func_(" in msg and "model_container_runner" in msg for msg in msgs
+        )
+        if aoti_wrapped:
+            torch.cuda.empty_cache()
+            return True
+
         return False
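The widened matching logic can be exercised without torch or a GPU. A minimal torch-free sketch, where `looks_like_oom` and `PLAIN_OOM_MARKERS` are hypothetical names for illustration only:

```python
# Torch-free sketch of the widened matcher: the four plain CUDA OOM markers
# are checked against every message gathered from the exception chain, and
# the AOTInductor wrapper signature ("run_func_(" plus
# "model_container_runner") is treated as an OOM candidate even when the
# original OOM text is absent. Names here are illustrative, not deepmd's.
PLAIN_OOM_MARKERS = (
    "CUDA out of memory.",
    "CUDA driver error: out of memory",
    "CUDA error: out of memory",
    "cusolver error: CUSOLVER_STATUS_INTERNAL_ERROR",
)


def looks_like_oom(msgs: list[str]) -> bool:
    # Plain markers: the original text survived somewhere in the chain.
    if any(m in msg for msg in msgs for m in PLAIN_OOM_MARKERS):
        return True
    # AOTInductor wrapper: the OOM text went to stderr only, so match the
    # generic wrapper signature instead.
    return any(
        "run_func_(" in msg and "model_container_runner" in msg for msg in msgs
    )


if __name__ == "__main__":
    assert looks_like_oom(["CUDA error: out of memory"])
    assert looks_like_oom(
        ["run_func_(...) API call failed at model_container_runner.cpp, line 144"]
    )
    assert not looks_like_oom(["shape mismatch in matmul"])
```

Accepting the wrapper signature is deliberately permissive; as the commit message notes, a non-OOM failure hidden behind it is still bounded by the batch shrinker halving to 1 and raising `OutOfMemoryError`.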
