@@ -142,6 +142,28 @@ def get_additional_data_requirement(_model: Any) -> list[DataRequirementItem]:
142142# ---------------------------------------------------------------------------
143143
144144
145+ def _remove_detach_nodes (gm : torch .fx .GraphModule ) -> None :
146+ """Remove ``aten.detach.default`` nodes from an FX graph in-place.
147+
148+ ``make_fx`` inserts these nodes when recording saved tensors from the
149+ autograd backward pass (``autograd.grad`` with ``create_graph=True``).
150+ The detach breaks the gradient connection between saved activations and
151+ model parameters, causing incorrect second-order derivatives — e.g.
152+ bias gradients become zero for force-loss training.
153+
154+ Removing these nodes restores the gradient path so that higher-order
155+ derivatives flow correctly through the decomposed backward ops.
156+ """
157+ graph = gm .graph
158+ for node in list (graph .nodes ):
159+ if node .op == "call_function" and node .target == torch .ops .aten .detach .default :
160+ input_node = node .args [0 ]
161+ node .replace_all_uses_with (input_node )
162+ graph .erase_node (node )
163+ graph .lint ()
164+ gm .recompile ()
165+
166+
145167def _trace_and_compile (
146168 model : torch .nn .Module ,
147169 ext_coord : torch .Tensor ,
@@ -157,7 +179,7 @@ def _trace_and_compile(
157179 Parameters
158180 ----------
159181 model : torch.nn.Module
160- The (uncompiled) model. Temporarily set to eval mode for tracing.
182+ The (uncompiled) model.
161183 ext_coord, ext_atype, nlist, mapping, fparam, aparam
162184 Sample tensors (already padded to the desired max_nall).
163185 compile_opts : dict
@@ -188,7 +210,7 @@ def fn(
188210 fparam : torch .Tensor | None ,
189211 aparam : torch .Tensor | None ,
190212 ) -> dict [str , torch .Tensor ]:
191- extended_coord = extended_coord .detach (). requires_grad_ (True )
213+ extended_coord = extended_coord .requires_grad_ (True )
192214 return model .forward_lower (
193215 extended_coord ,
194216 extended_atype ,
@@ -203,13 +225,15 @@ def fn(
203225 # change at runtime, the caller catches the error and retraces.
204226 traced_lower = make_fx (fn )(ext_coord , ext_atype , nlist , mapping , fparam , aparam )
205227
228+ # make_fx inserts aten.detach.default for saved tensors used in the
229+ # decomposed autograd.grad backward ops. These detach nodes break
230+ # second-order gradient flow (d(force)/d(params) for force training).
231+ # Removing them restores correct higher-order derivatives.
232+ _remove_detach_nodes (traced_lower )
233+
206234 if not was_training :
207235 model .eval ()
208236
209- # The inductor backend does not propagate gradients through the
210- # make_fx-decomposed autograd.grad ops (second-order gradients for
211- # force training). Use "aot_eager" which correctly preserves the
212- # gradient chain while still benefiting from make_fx decomposition.
213237 if "backend" not in compile_opts :
214238 compile_opts ["backend" ] = "aot_eager"
215239 compiled_lower = torch .compile (traced_lower , dynamic = False , ** compile_opts )
@@ -839,10 +863,6 @@ def _make_sample(
839863 # torch.compile -------------------------------------------------------
840864 self .enable_compile = training_params .get ("enable_compile" , False )
841865 if self .enable_compile :
842- if self .multi_task :
843- raise ValueError (
844- "torch.compile is not supported with multi-task training."
845- )
846866 compile_opts = training_params .get ("compile_options" , {})
847867 log .info ("Compiling model with torch.compile (%s)" , compile_opts )
848868 self ._compile_model (compile_opts )
@@ -878,108 +898,117 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None:
878898 normalize_coord ,
879899 )
880900
881- model = self .model
882-
883- # --- Estimate max_nall by sampling multiple batches ---
884- n_sample = 20
885- max_nall = 0
886- best_sample : (
887- tuple [np .ndarray , np .ndarray , np .ndarray , np .ndarray , int , dict ] | None
888- ) = None
889-
890- for _ii in range (n_sample ):
891- inp , _ = self .get_data (is_train = True )
892- coord = inp ["coord" ].detach ()
893- atype = inp ["atype" ].detach ()
894- box = inp .get ("box" )
895- if box is not None :
896- box = box .detach ()
897-
898- nframes , nloc = atype .shape [:2 ]
899- coord_np = coord .cpu ().numpy ().reshape (nframes , nloc , 3 )
900- atype_np = atype .cpu ().numpy ()
901- box_np = box .cpu ().numpy ().reshape (nframes , 9 ) if box is not None else None
902-
903- if box_np is not None :
904- coord_norm = normalize_coord (coord_np , box_np .reshape (nframes , 3 , 3 ))
905- else :
906- coord_norm = coord_np
901+ for task_key in self .model_keys :
902+ model = self .wrapper .model [task_key ]
903+
904+ # --- Estimate max_nall by sampling multiple batches ---
905+ n_sample = 20
906+ max_nall = 0
907+ best_sample : (
908+ tuple [np .ndarray , np .ndarray , np .ndarray , np .ndarray , int , dict ] | None
909+ ) = None
910+
911+ for _ii in range (n_sample ):
912+ inp , _ = self .get_data (is_train = True , task_key = task_key )
913+ coord = inp ["coord" ].detach ()
914+ atype = inp ["atype" ].detach ()
915+ box = inp .get ("box" )
916+ if box is not None :
917+ box = box .detach ()
918+
919+ nframes , nloc = atype .shape [:2 ]
920+ coord_np = coord .cpu ().numpy ().reshape (nframes , nloc , 3 )
921+ atype_np = atype .cpu ().numpy ()
922+ box_np = (
923+ box .cpu ().numpy ().reshape (nframes , 9 ) if box is not None else None
924+ )
907925
908- ext_coord_np , ext_atype_np , mapping_np = extend_coord_with_ghosts (
909- coord_norm , atype_np , box_np , model .get_rcut ()
910- )
911- nlist_np = build_neighbor_list (
912- ext_coord_np ,
913- ext_atype_np ,
914- nloc ,
915- model .get_rcut (),
916- model .get_sel (),
917- distinguish_types = False ,
918- )
919- ext_coord_np = ext_coord_np .reshape (nframes , - 1 , 3 )
920- nall = ext_coord_np .shape [1 ]
921- if nall > max_nall :
922- max_nall = nall
923- best_sample = (
926+ if box_np is not None :
927+ coord_norm = normalize_coord (
928+ coord_np , box_np .reshape (nframes , 3 , 3 )
929+ )
930+ else :
931+ coord_norm = coord_np
932+
933+ ext_coord_np , ext_atype_np , mapping_np = extend_coord_with_ghosts (
934+ coord_norm , atype_np , box_np , model .get_rcut ()
935+ )
936+ nlist_np = build_neighbor_list (
924937 ext_coord_np ,
925938 ext_atype_np ,
926- mapping_np ,
927- nlist_np ,
928939 nloc ,
929- inp ,
940+ model .get_rcut (),
941+ model .get_sel (),
942+ distinguish_types = False ,
930943 )
944+ ext_coord_np = ext_coord_np .reshape (nframes , - 1 , 3 )
945+ nall = ext_coord_np .shape [1 ]
946+ if nall > max_nall :
947+ max_nall = nall
948+ best_sample = (
949+ ext_coord_np ,
950+ ext_atype_np ,
951+ mapping_np ,
952+ nlist_np ,
953+ nloc ,
954+ inp ,
955+ )
931956
932- # Add 20 % margin and round up to a multiple of 8.
933- max_nall = ((int (max_nall * 1.2 ) + 7 ) // 8 ) * 8
934- log .info (
935- "Estimated max_nall=%d for compiled model (sampled %d batches)." ,
936- max_nall ,
937- n_sample ,
938- )
939-
940- # --- Pad the largest sample to max_nall and trace ---
941- assert best_sample is not None
942- ext_coord_np , ext_atype_np , mapping_np , nlist_np , nloc , sample_input = (
943- best_sample
944- )
945- nframes = ext_coord_np .shape [0 ]
946- actual_nall = ext_coord_np .shape [1 ]
947- pad_n = max_nall - actual_nall
948-
949- if pad_n > 0 :
950- ext_coord_np = np .pad (ext_coord_np , ((0 , 0 ), (0 , pad_n ), (0 , 0 )))
951- ext_atype_np = np .pad (ext_atype_np , ((0 , 0 ), (0 , pad_n )))
952- mapping_np = np .pad (mapping_np , ((0 , 0 ), (0 , pad_n )))
957+ # Add 20 % margin and round up to a multiple of 8.
958+ max_nall = ((int (max_nall * 1.2 ) + 7 ) // 8 ) * 8
959+ log .info (
960+ "Estimated max_nall=%d for compiled model "
961+ "(task=%s, sampled %d batches)." ,
962+ max_nall ,
963+ task_key ,
964+ n_sample ,
965+ )
953966
954- ext_coord = torch . tensor (
955- ext_coord_np , dtype = GLOBAL_PT_FLOAT_PRECISION , device = DEVICE
956- )
957- ext_atype = torch . tensor ( ext_atype_np , dtype = torch . int64 , device = DEVICE )
958- nlist_t = torch . tensor ( nlist_np , dtype = torch . int64 , device = DEVICE )
959- mapping_t = torch . tensor ( mapping_np , dtype = torch . int64 , device = DEVICE )
960- fparam = sample_input . get ( "fparam" )
961- aparam = sample_input . get ( "aparam" )
967+ # --- Pad the largest sample to max_nall and trace ---
968+ assert best_sample is not None
969+ ext_coord_np , ext_atype_np , mapping_np , nlist_np , nloc , sample_input = (
970+ best_sample
971+ )
972+ nframes = ext_coord_np . shape [ 0 ]
973+ actual_nall = ext_coord_np . shape [ 1 ]
974+ pad_n = max_nall - actual_nall
962975
963- compile_opts .pop ("dynamic" , None ) # always False for padded approach
976+ if pad_n > 0 :
977+ ext_coord_np = np .pad (ext_coord_np , ((0 , 0 ), (0 , pad_n ), (0 , 0 )))
978+ ext_atype_np = np .pad (ext_atype_np , ((0 , 0 ), (0 , pad_n )))
979+ mapping_np = np .pad (mapping_np , ((0 , 0 ), (0 , pad_n )))
964980
965- compiled_lower = _trace_and_compile (
966- model ,
967- ext_coord ,
968- ext_atype ,
969- nlist_t ,
970- mapping_t ,
971- fparam ,
972- aparam ,
973- compile_opts ,
974- )
981+ ext_coord = torch .tensor (
982+ ext_coord_np , dtype = GLOBAL_PT_FLOAT_PRECISION , device = DEVICE
983+ )
984+ ext_atype = torch .tensor (ext_atype_np , dtype = torch .int64 , device = DEVICE )
985+ nlist_t = torch .tensor (nlist_np , dtype = torch .int64 , device = DEVICE )
986+ mapping_t = torch .tensor (mapping_np , dtype = torch .int64 , device = DEVICE )
987+ fparam = sample_input .get ("fparam" )
988+ aparam = sample_input .get ("aparam" )
989+
990+ task_compile_opts = dict (compile_opts )
991+ task_compile_opts .pop ("dynamic" , None ) # always False for padded approach
992+
993+ compiled_lower = _trace_and_compile (
994+ model ,
995+ ext_coord ,
996+ ext_atype ,
997+ nlist_t ,
998+ mapping_t ,
999+ fparam ,
1000+ aparam ,
1001+ task_compile_opts ,
1002+ )
9751003
976- self .wrapper .model ["Default" ] = _CompiledModel (
977- model , compiled_lower , max_nall , compile_opts
978- )
979- log .info (
980- "Model compiled with padded nall=%d (tracing_mode=real, dynamic=False)." ,
981- max_nall ,
982- )
1004+ self .wrapper .model [task_key ] = _CompiledModel (
1005+ model , compiled_lower , max_nall , task_compile_opts
1006+ )
1007+ log .info (
1008+ "Model compiled with padded nall=%d (task=%s, dynamic=False)." ,
1009+ max_nall ,
1010+ task_key ,
1011+ )
9831012
9841013 # ------------------------------------------------------------------
9851014 # Data helpers
0 commit comments