Skip to content

Commit 40962b2

Browse files
author
Han Wang
committed
feat(pt_expt): enable torch.compile for multi-task training
make_fx inserts aten.detach.default nodes when decomposing autograd.grad(create_graph=True), breaking second-order gradient flow for force training. Add _remove_detach_nodes() to strip these after tracing, restoring correct higher-order derivatives. Loop _compile_model over model_keys so each branch gets its own _CompiledModel with per-task max_nall estimation.
1 parent c20930d commit 40962b2

3 files changed

Lines changed: 356 additions & 110 deletions

File tree

deepmd/pt_expt/train/training.py

Lines changed: 131 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,28 @@ def get_additional_data_requirement(_model: Any) -> list[DataRequirementItem]:
142142
# ---------------------------------------------------------------------------
143143

144144

145+
def _remove_detach_nodes(gm: torch.fx.GraphModule) -> None:
146+
"""Remove ``aten.detach.default`` nodes from an FX graph in-place.
147+
148+
``make_fx`` inserts these nodes when recording saved tensors from the
149+
autograd backward pass (``autograd.grad`` with ``create_graph=True``).
150+
The detach breaks the gradient connection between saved activations and
151+
model parameters, causing incorrect second-order derivatives — e.g.
152+
bias gradients become zero for force-loss training.
153+
154+
Removing these nodes restores the gradient path so that higher-order
155+
derivatives flow correctly through the decomposed backward ops.
156+
"""
157+
graph = gm.graph
158+
for node in list(graph.nodes):
159+
if node.op == "call_function" and node.target == torch.ops.aten.detach.default:
160+
input_node = node.args[0]
161+
node.replace_all_uses_with(input_node)
162+
graph.erase_node(node)
163+
graph.lint()
164+
gm.recompile()
165+
166+
145167
def _trace_and_compile(
146168
model: torch.nn.Module,
147169
ext_coord: torch.Tensor,
@@ -157,7 +179,7 @@ def _trace_and_compile(
157179
Parameters
158180
----------
159181
model : torch.nn.Module
160-
The (uncompiled) model. Temporarily set to eval mode for tracing.
182+
The (uncompiled) model.
161183
ext_coord, ext_atype, nlist, mapping, fparam, aparam
162184
Sample tensors (already padded to the desired max_nall).
163185
compile_opts : dict
@@ -188,7 +210,7 @@ def fn(
188210
fparam: torch.Tensor | None,
189211
aparam: torch.Tensor | None,
190212
) -> dict[str, torch.Tensor]:
191-
extended_coord = extended_coord.detach().requires_grad_(True)
213+
extended_coord = extended_coord.requires_grad_(True)
192214
return model.forward_lower(
193215
extended_coord,
194216
extended_atype,
@@ -203,13 +225,15 @@ def fn(
203225
# change at runtime, the caller catches the error and retraces.
204226
traced_lower = make_fx(fn)(ext_coord, ext_atype, nlist, mapping, fparam, aparam)
205227

228+
# make_fx inserts aten.detach.default for saved tensors used in the
229+
# decomposed autograd.grad backward ops. These detach nodes break
230+
# second-order gradient flow (d(force)/d(params) for force training).
231+
# Removing them restores correct higher-order derivatives.
232+
_remove_detach_nodes(traced_lower)
233+
206234
if not was_training:
207235
model.eval()
208236

209-
# The inductor backend does not propagate gradients through the
210-
# make_fx-decomposed autograd.grad ops (second-order gradients for
211-
# force training). Use "aot_eager" which correctly preserves the
212-
# gradient chain while still benefiting from make_fx decomposition.
213237
if "backend" not in compile_opts:
214238
compile_opts["backend"] = "aot_eager"
215239
compiled_lower = torch.compile(traced_lower, dynamic=False, **compile_opts)
@@ -839,10 +863,6 @@ def _make_sample(
839863
# torch.compile -------------------------------------------------------
840864
self.enable_compile = training_params.get("enable_compile", False)
841865
if self.enable_compile:
842-
if self.multi_task:
843-
raise ValueError(
844-
"torch.compile is not supported with multi-task training."
845-
)
846866
compile_opts = training_params.get("compile_options", {})
847867
log.info("Compiling model with torch.compile (%s)", compile_opts)
848868
self._compile_model(compile_opts)
@@ -878,108 +898,117 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None:
878898
normalize_coord,
879899
)
880900

881-
model = self.model
882-
883-
# --- Estimate max_nall by sampling multiple batches ---
884-
n_sample = 20
885-
max_nall = 0
886-
best_sample: (
887-
tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, dict] | None
888-
) = None
889-
890-
for _ii in range(n_sample):
891-
inp, _ = self.get_data(is_train=True)
892-
coord = inp["coord"].detach()
893-
atype = inp["atype"].detach()
894-
box = inp.get("box")
895-
if box is not None:
896-
box = box.detach()
897-
898-
nframes, nloc = atype.shape[:2]
899-
coord_np = coord.cpu().numpy().reshape(nframes, nloc, 3)
900-
atype_np = atype.cpu().numpy()
901-
box_np = box.cpu().numpy().reshape(nframes, 9) if box is not None else None
902-
903-
if box_np is not None:
904-
coord_norm = normalize_coord(coord_np, box_np.reshape(nframes, 3, 3))
905-
else:
906-
coord_norm = coord_np
901+
for task_key in self.model_keys:
902+
model = self.wrapper.model[task_key]
903+
904+
# --- Estimate max_nall by sampling multiple batches ---
905+
n_sample = 20
906+
max_nall = 0
907+
best_sample: (
908+
tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, dict] | None
909+
) = None
910+
911+
for _ii in range(n_sample):
912+
inp, _ = self.get_data(is_train=True, task_key=task_key)
913+
coord = inp["coord"].detach()
914+
atype = inp["atype"].detach()
915+
box = inp.get("box")
916+
if box is not None:
917+
box = box.detach()
918+
919+
nframes, nloc = atype.shape[:2]
920+
coord_np = coord.cpu().numpy().reshape(nframes, nloc, 3)
921+
atype_np = atype.cpu().numpy()
922+
box_np = (
923+
box.cpu().numpy().reshape(nframes, 9) if box is not None else None
924+
)
907925

908-
ext_coord_np, ext_atype_np, mapping_np = extend_coord_with_ghosts(
909-
coord_norm, atype_np, box_np, model.get_rcut()
910-
)
911-
nlist_np = build_neighbor_list(
912-
ext_coord_np,
913-
ext_atype_np,
914-
nloc,
915-
model.get_rcut(),
916-
model.get_sel(),
917-
distinguish_types=False,
918-
)
919-
ext_coord_np = ext_coord_np.reshape(nframes, -1, 3)
920-
nall = ext_coord_np.shape[1]
921-
if nall > max_nall:
922-
max_nall = nall
923-
best_sample = (
926+
if box_np is not None:
927+
coord_norm = normalize_coord(
928+
coord_np, box_np.reshape(nframes, 3, 3)
929+
)
930+
else:
931+
coord_norm = coord_np
932+
933+
ext_coord_np, ext_atype_np, mapping_np = extend_coord_with_ghosts(
934+
coord_norm, atype_np, box_np, model.get_rcut()
935+
)
936+
nlist_np = build_neighbor_list(
924937
ext_coord_np,
925938
ext_atype_np,
926-
mapping_np,
927-
nlist_np,
928939
nloc,
929-
inp,
940+
model.get_rcut(),
941+
model.get_sel(),
942+
distinguish_types=False,
930943
)
944+
ext_coord_np = ext_coord_np.reshape(nframes, -1, 3)
945+
nall = ext_coord_np.shape[1]
946+
if nall > max_nall:
947+
max_nall = nall
948+
best_sample = (
949+
ext_coord_np,
950+
ext_atype_np,
951+
mapping_np,
952+
nlist_np,
953+
nloc,
954+
inp,
955+
)
931956

932-
# Add 20 % margin and round up to a multiple of 8.
933-
max_nall = ((int(max_nall * 1.2) + 7) // 8) * 8
934-
log.info(
935-
"Estimated max_nall=%d for compiled model (sampled %d batches).",
936-
max_nall,
937-
n_sample,
938-
)
939-
940-
# --- Pad the largest sample to max_nall and trace ---
941-
assert best_sample is not None
942-
ext_coord_np, ext_atype_np, mapping_np, nlist_np, nloc, sample_input = (
943-
best_sample
944-
)
945-
nframes = ext_coord_np.shape[0]
946-
actual_nall = ext_coord_np.shape[1]
947-
pad_n = max_nall - actual_nall
948-
949-
if pad_n > 0:
950-
ext_coord_np = np.pad(ext_coord_np, ((0, 0), (0, pad_n), (0, 0)))
951-
ext_atype_np = np.pad(ext_atype_np, ((0, 0), (0, pad_n)))
952-
mapping_np = np.pad(mapping_np, ((0, 0), (0, pad_n)))
957+
# Add 20 % margin and round up to a multiple of 8.
958+
max_nall = ((int(max_nall * 1.2) + 7) // 8) * 8
959+
log.info(
960+
"Estimated max_nall=%d for compiled model "
961+
"(task=%s, sampled %d batches).",
962+
max_nall,
963+
task_key,
964+
n_sample,
965+
)
953966

954-
ext_coord = torch.tensor(
955-
ext_coord_np, dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
956-
)
957-
ext_atype = torch.tensor(ext_atype_np, dtype=torch.int64, device=DEVICE)
958-
nlist_t = torch.tensor(nlist_np, dtype=torch.int64, device=DEVICE)
959-
mapping_t = torch.tensor(mapping_np, dtype=torch.int64, device=DEVICE)
960-
fparam = sample_input.get("fparam")
961-
aparam = sample_input.get("aparam")
967+
# --- Pad the largest sample to max_nall and trace ---
968+
assert best_sample is not None
969+
ext_coord_np, ext_atype_np, mapping_np, nlist_np, nloc, sample_input = (
970+
best_sample
971+
)
972+
nframes = ext_coord_np.shape[0]
973+
actual_nall = ext_coord_np.shape[1]
974+
pad_n = max_nall - actual_nall
962975

963-
compile_opts.pop("dynamic", None) # always False for padded approach
976+
if pad_n > 0:
977+
ext_coord_np = np.pad(ext_coord_np, ((0, 0), (0, pad_n), (0, 0)))
978+
ext_atype_np = np.pad(ext_atype_np, ((0, 0), (0, pad_n)))
979+
mapping_np = np.pad(mapping_np, ((0, 0), (0, pad_n)))
964980

965-
compiled_lower = _trace_and_compile(
966-
model,
967-
ext_coord,
968-
ext_atype,
969-
nlist_t,
970-
mapping_t,
971-
fparam,
972-
aparam,
973-
compile_opts,
974-
)
981+
ext_coord = torch.tensor(
982+
ext_coord_np, dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
983+
)
984+
ext_atype = torch.tensor(ext_atype_np, dtype=torch.int64, device=DEVICE)
985+
nlist_t = torch.tensor(nlist_np, dtype=torch.int64, device=DEVICE)
986+
mapping_t = torch.tensor(mapping_np, dtype=torch.int64, device=DEVICE)
987+
fparam = sample_input.get("fparam")
988+
aparam = sample_input.get("aparam")
989+
990+
task_compile_opts = dict(compile_opts)
991+
task_compile_opts.pop("dynamic", None) # always False for padded approach
992+
993+
compiled_lower = _trace_and_compile(
994+
model,
995+
ext_coord,
996+
ext_atype,
997+
nlist_t,
998+
mapping_t,
999+
fparam,
1000+
aparam,
1001+
task_compile_opts,
1002+
)
9751003

976-
self.wrapper.model["Default"] = _CompiledModel(
977-
model, compiled_lower, max_nall, compile_opts
978-
)
979-
log.info(
980-
"Model compiled with padded nall=%d (tracing_mode=real, dynamic=False).",
981-
max_nall,
982-
)
1004+
self.wrapper.model[task_key] = _CompiledModel(
1005+
model, compiled_lower, max_nall, task_compile_opts
1006+
)
1007+
log.info(
1008+
"Model compiled with padded nall=%d (task=%s, dynamic=False).",
1009+
max_nall,
1010+
task_key,
1011+
)
9831012

9841013
# ------------------------------------------------------------------
9851014
# Data helpers

0 commit comments

Comments
 (0)