Skip to content

Commit 8a9fe63

Browse files
author
Han Wang
committed
fix(pt2): move nlist padding inside traced fn and strip shape assertions
Move nlist padding (+1 column of -1s) inside the `fn` closure in both `make_model.forward_common_lower_exportable` and `SpinModel.forward_common_lower_exportable`, making it part of the traced graph. This fixes proxy tensor shape mismatches from make_fx and removes the need for external padding in deep_eval.py. Also apply `_strip_shape_assertions` unconditionally (not just spin models) to remove spurious torch.export guards like Ne(nnei, sum(sel)). Export tests that verify atomic virial now pass `do_atomic_virial=True` to `deserialize_to_file` so the exported model includes the correction.
1 parent 217a587 commit 8a9fe63

6 files changed

Lines changed: 110 additions & 78 deletions

File tree

deepmd/pt_expt/infer/deep_eval.py

Lines changed: 0 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -759,19 +759,6 @@ def _eval_model(
759759
# returning a dict just like the .pte module.
760760
# It also filters non-tensor args automatically, matching the
761761
# export-time signature where None args were excluded.
762-
# Pad nlist with extra -1 column so n_nnei > nnei, ensuring
763-
# format_nlist's compiled sort branch executes.
764-
nlist_t = torch.cat(
765-
[
766-
nlist_t,
767-
-torch.ones(
768-
(*nlist_t.shape[:2], 1),
769-
dtype=nlist_t.dtype,
770-
device=nlist_t.device,
771-
),
772-
],
773-
dim=-1,
774-
)
775762
model_ret = self._pt2_runner(
776763
ext_coord_t, ext_atype_t, nlist_t, mapping_t, fparam_t, aparam_t
777764
)
@@ -911,19 +898,6 @@ def _eval_model_spin(
911898

912899
# Call the model with spin (7 args)
913900
if self._is_pt2:
914-
# Pad nlist with extra -1 column so n_nnei > nnei, ensuring
915-
# format_nlist's compiled sort branch executes.
916-
nlist_t = torch.cat(
917-
[
918-
nlist_t,
919-
-torch.ones(
920-
(*nlist_t.shape[:2], 1),
921-
dtype=nlist_t.dtype,
922-
device=nlist_t.device,
923-
),
924-
],
925-
dim=-1,
926-
)
927901
model_ret = self._pt2_runner(
928902
ext_coord_t,
929903
ext_atype_t,

deepmd/pt_expt/model/make_model.py

Lines changed: 29 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -346,6 +346,21 @@ def fn(
346346
aparam: torch.Tensor | None,
347347
) -> dict[str, torch.Tensor]:
348348
extended_coord = extended_coord.detach().requires_grad_(True)
349+
# Pad nlist with one extra -1 column inside the traced function.
350+
# This ensures n_nnei > sum(sel), forcing the sort branch in
351+
# _format_nlist. The padding becomes part of the compiled graph,
352+
# so callers never need to pad externally.
353+
nlist = torch.cat(
354+
[
355+
nlist,
356+
-torch.ones(
357+
(*nlist.shape[:2], 1),
358+
dtype=nlist.dtype,
359+
device=nlist.device,
360+
),
361+
],
362+
dim=-1,
363+
)
349364
return model.forward_common_lower(
350365
extended_coord,
351366
extended_atype,
@@ -356,13 +371,19 @@ def fn(
356371
do_atomic_virial=do_atomic_virial,
357372
)
358373

359-
return make_fx(fn, **make_fx_kwargs)(
360-
extended_coord,
361-
extended_atype,
362-
nlist,
363-
mapping,
364-
fparam,
365-
aparam,
366-
)
374+
# Force format_nlist to always use the sort branch during tracing.
375+
model.need_sorted_nlist_for_lower = lambda: True
376+
try:
377+
traced = make_fx(fn, **make_fx_kwargs)(
378+
extended_coord,
379+
extended_atype,
380+
nlist,
381+
mapping,
382+
fparam,
383+
aparam,
384+
)
385+
finally:
386+
del model.need_sorted_nlist_for_lower
387+
return traced
367388

368389
return CM

deepmd/pt_expt/model/spin_model.py

Lines changed: 28 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -96,6 +96,18 @@ def fn(
9696
aparam: torch.Tensor | None,
9797
) -> dict[str, torch.Tensor]:
9898
extended_coord = extended_coord.detach().requires_grad_(True)
99+
# Pad nlist inside traced function (see make_model.py for rationale).
100+
nlist = torch.cat(
101+
[
102+
nlist,
103+
-torch.ones(
104+
(*nlist.shape[:2], 1),
105+
dtype=nlist.dtype,
106+
device=nlist.device,
107+
),
108+
],
109+
dim=-1,
110+
)
99111
return model.forward_common_lower(
100112
extended_coord,
101113
extended_atype,
@@ -107,15 +119,22 @@ def fn(
107119
do_atomic_virial=do_atomic_virial,
108120
)
109121

110-
return make_fx(fn, **make_fx_kwargs)(
111-
extended_coord,
112-
extended_atype,
113-
extended_spin,
114-
nlist,
115-
mapping,
116-
fparam,
117-
aparam,
118-
)
122+
# Force format_nlist to always use the sort branch during tracing.
123+
backbone = model.backbone_model
124+
backbone.need_sorted_nlist_for_lower = lambda: True
125+
try:
126+
traced = make_fx(fn, **make_fx_kwargs)(
127+
extended_coord,
128+
extended_atype,
129+
extended_spin,
130+
nlist,
131+
mapping,
132+
fparam,
133+
aparam,
134+
)
135+
finally:
136+
del backbone.need_sorted_nlist_for_lower
137+
return traced
119138

120139
def forward_common_lower(
121140
self, *args: Any, **kwargs: Any

deepmd/pt_expt/utils/serialization.py

Lines changed: 20 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -17,23 +17,23 @@
1717

1818

1919
def _strip_shape_assertions(graph_module: torch.nn.Module) -> None:
20-
"""Remove shape-guard assertion nodes from a spin model's exported graph.
20+
"""Remove shape-guard assertion nodes from an exported graph.
2121
2222
``torch.export`` inserts ``aten._assert_scalar`` nodes for symbolic shape
23-
relationships discovered during tracing. For the spin model, the atom-
24-
doubling logic creates slice patterns that depend on ``(nall - nloc)``,
25-
producing guards like ``Ne(nall, nloc)``. These guards are spurious: the
26-
model computes correct results even when ``nall == nloc`` (NoPBC, no ghost
27-
atoms).
28-
29-
This function is **only called for spin models** (guarded by ``if is_spin``
30-
in ``_trace_and_export``). The assertion messages use opaque symbolic
31-
variable names (e.g. ``Ne(s22, s96)``) rather than human-readable names,
32-
so filtering by message content is not reliable. Since
23+
relationships discovered during tracing. These guards can be spurious:
24+
25+
* **Spin models**: atom-doubling logic creates slice patterns that depend
26+
on ``(nall - nloc)``, producing guards like ``Ne(nall, nloc)``.
27+
* **All models**: the nlist padding inside ``forward_common_lower_exportable``
28+
and the subsequent sort/truncate in ``_format_nlist`` can produce guards
29+
like ``Ne(nnei, sum(sel))``. These are spurious because the compiled
30+
graph handles any ``nnei >= sum(sel)`` correctly.
31+
32+
The assertion messages use opaque symbolic variable names (e.g.
33+
``Ne(s22, s96)``) rather than human-readable names, so filtering by
34+
message content is not reliable. Since
3335
``prefer_deferred_runtime_asserts_over_guards=True`` converts all shape
34-
guards into these deferred assertions, and the only shape relationships in
35-
the spin model involve nall/nloc, removing all of them is safe in this
36-
context.
36+
guards into these deferred assertions, removing all of them is safe.
3737
"""
3838
graph = graph_module.graph
3939
for node in list(graph.nodes):
@@ -141,10 +141,8 @@ def _make_sample_inputs(
141141
sel,
142142
distinguish_types=not mixed_types,
143143
)
144-
# Pad nlist with extra -1 columns so n_nnei > nnei at trace time.
145-
# This ensures format_nlist's distance-sort branch is traced into the
146-
# compiled graph, allowing the .pt2 model to handle variable-size
147-
# neighbor lists at runtime (e.g. LAMMPS rcut + skin).
144+
# Pad nlist so nnei > sum(sel) in the sample tensors.
145+
# This prevents torch.export from specializing nnei to sum(sel).
148146
nnei = sum(sel)
149147
n_pad = max(1, nnei // 4) # pad by ~25%, at least 1
150148
nlist = np.concatenate(
@@ -519,15 +517,10 @@ def _trace_and_export(
519517
prefer_deferred_runtime_asserts_over_guards=True,
520518
)
521519

522-
if is_spin:
523-
# torch.export re-introduces shape-guard assertions even when
524-
# the make_fx graph has none. The spin model's atom-doubling
525-
# logic creates slice patterns that depend on (nall - nloc),
526-
# producing guards like Ne(nall, nloc). These guards are
527-
# spurious: the model is correct when nall == nloc (NoPBC).
528-
# Strip them from the exported graph so the model can be
529-
# used with any valid nall >= nloc.
530-
_strip_shape_assertions(exported.graph_module)
520+
# torch.export inserts _assert_scalar guards for symbolic shape
521+
# relationships (e.g. Ne(nnei, sum(sel)), Ne(nall, nloc)). These
522+
# are spurious — the model handles any valid input shapes correctly.
523+
_strip_shape_assertions(exported.graph_module)
531524

532525
# 7. Move the exported program to the target device if needed.
533526
if target_device.type != "cpu":

source/tests/pt_expt/export_helpers.py

Lines changed: 26 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -70,6 +70,12 @@ def export_save_load_and_compare(
7070
strict=False,
7171
prefer_deferred_runtime_asserts_over_guards=True,
7272
)
73+
# Strip spurious shape-guard assertions (e.g. Ne(nnei, sum(sel)))
74+
from deepmd.pt_expt.utils.serialization import (
75+
_strip_shape_assertions,
76+
)
77+
78+
_strip_shape_assertions(exported.graph_module)
7379

7480
# 4. .pte save -> load round-trip
7581
with tempfile.NamedTemporaryFile(suffix=".pte") as f:
@@ -199,9 +205,22 @@ def model_forward_lower_export_round_trip(
199205
)
200206

201207
# 5. Symbolic trace + dynamic shapes + .pte round-trip
208+
# Pad nlist with extra -1 columns so nnei > sum(sel) in the sample.
209+
# This prevents torch.export from specializing nnei to sum(sel).
210+
nlist_padded = torch.cat(
211+
[
212+
nlist_t,
213+
-torch.ones(
214+
(*nlist_t.shape[:2], max(1, nlist_t.shape[2] // 4)),
215+
dtype=nlist_t.dtype,
216+
device=nlist_t.device,
217+
),
218+
],
219+
dim=-1,
220+
)
202221
inputs_2f = tuple(
203222
torch.cat([t, t], dim=0) if t is not None else None
204-
for t in (ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam)
223+
for t in (ext_coord, ext_atype, nlist_padded, mapping_t, fparam, aparam)
205224
)
206225
traced_sym = md_pt.forward_lower_exportable(
207226
inputs_2f[0],
@@ -221,6 +240,12 @@ def model_forward_lower_export_round_trip(
221240
strict=False,
222241
prefer_deferred_runtime_asserts_over_guards=True,
223242
)
243+
# Strip spurious shape-guard assertions (e.g. Ne(nnei, sum(sel)))
244+
from deepmd.pt_expt.utils.serialization import (
245+
_strip_shape_assertions,
246+
)
247+
248+
_strip_shape_assertions(exported_dyn.graph_module)
224249
with tempfile.NamedTemporaryFile(suffix=".pte") as f:
225250
torch.export.save(exported_dyn, f.name)
226251
loaded = torch.export.load(f.name).module()

source/tests/pt_expt/infer/test_deep_eval.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -63,11 +63,11 @@ def setUpClass(cls) -> None:
6363
cls.model = cls.model.to(torch.float64)
6464
cls.model.eval()
6565

66-
# Serialize and save to .pte
66+
# Serialize and save to .pte (with atomic virial for test_dynamic_shapes)
6767
cls.model_data = {"model": cls.model.serialize()}
6868
cls.tmpfile = tempfile.NamedTemporaryFile(suffix=".pte", delete=False)
6969
cls.tmpfile.close()
70-
deserialize_to_file(cls.tmpfile.name, cls.model_data)
70+
deserialize_to_file(cls.tmpfile.name, cls.model_data, do_atomic_virial=True)
7171

7272
# Create DeepPot for testing
7373
cls.dp = DeepPot(cls.tmpfile.name)
@@ -547,14 +547,14 @@ def setUpClass(cls) -> None:
547547
# compilation (tests/pt/__init__.py sets it to "cuda:9999999").
548548
torch.set_default_device(None)
549549
try:
550-
deserialize_to_file(cls.tmpfile.name, cls.model_data)
550+
deserialize_to_file(cls.tmpfile.name, cls.model_data, do_atomic_virial=True)
551551
finally:
552552
torch.set_default_device("cuda:9999999")
553553

554554
# Also save to .pte for cross-format comparison
555555
cls.pte_tmpfile = tempfile.NamedTemporaryFile(suffix=".pte", delete=False)
556556
cls.pte_tmpfile.close()
557-
deserialize_to_file(cls.pte_tmpfile.name, cls.model_data)
557+
deserialize_to_file(cls.pte_tmpfile.name, cls.model_data, do_atomic_virial=True)
558558

559559
# Create DeepPot for .pt2
560560
cls.dp = DeepPot(cls.tmpfile.name)
@@ -1070,7 +1070,7 @@ def setUpClass(cls) -> None:
10701070
cls.model_data = {"model": cls.model.serialize()}
10711071
cls.tmpfile = tempfile.NamedTemporaryFile(suffix=".pte", delete=False)
10721072
cls.tmpfile.close()
1073-
deserialize_to_file(cls.tmpfile.name, cls.model_data)
1073+
deserialize_to_file(cls.tmpfile.name, cls.model_data, do_atomic_virial=True)
10741074

10751075
cls.dp = DeepPot(cls.tmpfile.name)
10761076

@@ -1187,14 +1187,14 @@ def setUpClass(cls) -> None:
11871187
cls.tmpfile.close()
11881188
torch.set_default_device(None)
11891189
try:
1190-
deserialize_to_file(cls.tmpfile.name, cls.model_data)
1190+
deserialize_to_file(cls.tmpfile.name, cls.model_data, do_atomic_virial=True)
11911191
finally:
11921192
torch.set_default_device("cuda:9999999")
11931193

11941194
# Also save .pte for cross-format comparison
11951195
cls.pte_tmpfile = tempfile.NamedTemporaryFile(suffix=".pte", delete=False)
11961196
cls.pte_tmpfile.close()
1197-
deserialize_to_file(cls.pte_tmpfile.name, cls.model_data)
1197+
deserialize_to_file(cls.pte_tmpfile.name, cls.model_data, do_atomic_virial=True)
11981198

11991199
cls.dp = DeepPot(cls.tmpfile.name)
12001200
cls.dp_pte = DeepPot(cls.pte_tmpfile.name)

0 commit comments

Comments (0)