|
17 | 17 |
|
18 | 18 |
|
19 | 19 | def _strip_shape_assertions(graph_module: torch.nn.Module) -> None: |
20 | | - """Remove shape-guard assertion nodes from a spin model's exported graph. |
| 20 | + """Remove shape-guard assertion nodes from an exported graph. |
21 | 21 |
|
22 | 22 | ``torch.export`` inserts ``aten._assert_scalar`` nodes for symbolic shape |
23 | | - relationships discovered during tracing. For the spin model, the atom- |
24 | | - doubling logic creates slice patterns that depend on ``(nall - nloc)``, |
25 | | - producing guards like ``Ne(nall, nloc)``. These guards are spurious: the |
26 | | - model computes correct results even when ``nall == nloc`` (NoPBC, no ghost |
27 | | - atoms). |
28 | | -
|
29 | | - This function is **only called for spin models** (guarded by ``if is_spin`` |
30 | | - in ``_trace_and_export``). The assertion messages use opaque symbolic |
31 | | - variable names (e.g. ``Ne(s22, s96)``) rather than human-readable names, |
32 | | - so filtering by message content is not reliable. Since |
| 23 | + relationships discovered during tracing. These guards can be spurious: |
| 24 | +
|
| 25 | + * **Spin models**: atom-doubling logic creates slice patterns that depend |
| 26 | + on ``(nall - nloc)``, producing guards like ``Ne(nall, nloc)``. |
| 27 | + * **All models**: the nlist padding inside ``forward_common_lower_exportable`` |
| 28 | + and the subsequent sort/truncate in ``_format_nlist`` can produce guards |
| 29 | + like ``Ne(nnei, sum(sel))``. These are spurious because the compiled |
| 30 | + graph handles any ``nnei >= sum(sel)`` correctly. |
| 31 | +
|
| 32 | + The assertion messages use opaque symbolic variable names (e.g. |
| 33 | + ``Ne(s22, s96)``) rather than human-readable names, so filtering by |
| 34 | + message content is not reliable. Since |
33 | 35 | ``prefer_deferred_runtime_asserts_over_guards=True`` converts all shape |
34 | | - guards into these deferred assertions, and the only shape relationships in |
35 | | - the spin model involve nall/nloc, removing all of them is safe in this |
36 | | - context. |
| 36 | + guards into these deferred assertions, removing all of them is safe. |
37 | 37 | """ |
38 | 38 | graph = graph_module.graph |
39 | 39 | for node in list(graph.nodes): |
@@ -141,10 +141,8 @@ def _make_sample_inputs( |
141 | 141 | sel, |
142 | 142 | distinguish_types=not mixed_types, |
143 | 143 | ) |
144 | | - # Pad nlist with extra -1 columns so n_nnei > nnei at trace time. |
145 | | - # This ensures format_nlist's distance-sort branch is traced into the |
146 | | - # compiled graph, allowing the .pt2 model to handle variable-size |
147 | | - # neighbor lists at runtime (e.g. LAMMPS rcut + skin). |
| 144 | + # Pad nlist so nnei > sum(sel) in the sample tensors. |
| 145 | + # This prevents torch.export from specializing nnei to sum(sel). |
148 | 146 | nnei = sum(sel) |
149 | 147 | n_pad = max(1, nnei // 4) # pad by ~25%, at least 1 |
150 | 148 | nlist = np.concatenate( |
@@ -519,15 +517,10 @@ def _trace_and_export( |
519 | 517 | prefer_deferred_runtime_asserts_over_guards=True, |
520 | 518 | ) |
521 | 519 |
|
522 | | - if is_spin: |
523 | | - # torch.export re-introduces shape-guard assertions even when |
524 | | - # the make_fx graph has none. The spin model's atom-doubling |
525 | | - # logic creates slice patterns that depend on (nall - nloc), |
526 | | - # producing guards like Ne(nall, nloc). These guards are |
527 | | - # spurious: the model is correct when nall == nloc (NoPBC). |
528 | | - # Strip them from the exported graph so the model can be |
529 | | - # used with any valid nall >= nloc. |
530 | | - _strip_shape_assertions(exported.graph_module) |
| 520 | + # torch.export inserts _assert_scalar guards for symbolic shape |
| 521 | + # relationships (e.g. Ne(nnei, sum(sel)), Ne(nall, nloc)). These |
| 522 | + # are spurious — the model handles all valid input shapes correctly. |
| 523 | + _strip_shape_assertions(exported.graph_module) |
531 | 524 |
|
532 | 525 | # 7. Move the exported program to the target device if needed. |
533 | 526 | if target_device.type != "cpu": |
|
0 commit comments