@@ -255,6 +255,7 @@ def compute_input_stats(
255255 stat_file_path : Optional[DPPath]
256256 The path to the stat file.
257257 """
258+ self ._param_stats : dict [str , list [StatItem ]] = {}
258259 if self .numb_fparam == 0 and self .numb_aparam == 0 :
259260 # skip data statistics
260261 return
@@ -296,6 +297,7 @@ def compute_input_stats(
296297 self ._save_param_stats_to_file (
297298 stat_file_path , "fparam" , fparam_stats
298299 )
300+ self ._param_stats ["fparam" ] = fparam_stats
299301 fparam_avg = np .array (
300302 [s .compute_avg () for s in fparam_stats ], dtype = np .float64
301303 )
@@ -362,6 +364,7 @@ def compute_input_stats(
362364 self ._save_param_stats_to_file (
363365 stat_file_path , "aparam" , aparam_stats
364366 )
367+ self ._param_stats ["aparam" ] = aparam_stats
365368 aparam_avg = np .array (
366369 [s .compute_avg () for s in aparam_stats ], dtype = np .float64
367370 )
@@ -407,6 +410,10 @@ def _load_param_stats_from_file(
407410 for ii in range (numb )
408411 ]
409412
413+ def get_param_stats (self ) -> dict [str , list [StatItem ]]:
414+ """Get the stored fparam/aparam statistics (populated by compute_input_stats)."""
415+ return getattr (self , "_param_stats" , {})
416+
410417 @abstractmethod
411418 def _net_out_dim (self ) -> int :
412419 """Set the FittingNet output dim."""
@@ -666,11 +673,13 @@ def _call_common(
666673 # check fparam dim, concate to input descriptor
667674 if self .numb_fparam > 0 :
668675 assert fparam is not None , "fparam should not be None"
669- if fparam .shape [- 1 ] != self .numb_fparam :
676+ try :
677+ fparam = xp .reshape (fparam , (nf , self .numb_fparam ))
678+ except (ValueError , RuntimeError ) as e :
670679 raise ValueError (
671- f"get an input fparam of dim { fparam .shape [ - 1 ] } , "
672- f"which is not consistent with { self .numb_fparam } ."
673- )
680+ f"input fparam: cannot reshape { fparam .shape } "
681+ f"into ( { nf } , { self .numb_fparam } ) ."
682+ ) from e
674683 fparam = (fparam - self .fparam_avg [...]) * self .fparam_inv_std [...]
675684 fparam = xp .tile (
676685 xp .reshape (fparam , (nf , 1 , self .numb_fparam )), (1 , nloc , 1 )
@@ -687,12 +696,13 @@ def _call_common(
687696 # check aparam dim, concate to input descriptor
688697 if self .numb_aparam > 0 and not self .use_aparam_as_mask :
689698 assert aparam is not None , "aparam should not be None"
690- if aparam .shape [- 1 ] != self .numb_aparam :
699+ try :
700+ aparam = xp .reshape (aparam , (nf , nloc , self .numb_aparam ))
701+ except (ValueError , RuntimeError ) as e :
691702 raise ValueError (
692- f"get an input aparam of dim { aparam .shape [- 1 ]} , "
693- f"which is not consistent with { self .numb_aparam } ."
694- )
695- aparam = xp .reshape (aparam , (nf , nloc , self .numb_aparam ))
703+ f"input aparam: cannot reshape { aparam .shape } "
704+ f"into ({ nf } , { nloc } , { self .numb_aparam } )."
705+ ) from e
696706 aparam = (aparam - self .aparam_avg [...]) * self .aparam_inv_std [...]
697707 xx = xp .concat (
698708 [xx , aparam ],
@@ -735,7 +745,8 @@ def _call_common(
735745 )
736746 for type_i in range (self .ntypes ):
737747 mask = xp .tile (
738- xp .reshape ((atype == type_i ), (nf , nloc , 1 )), (1 , 1 , net_dim_out )
748+ xp .reshape ((atype == type_i ), (nf , nloc , 1 )),
749+ (1 , 1 , net_dim_out ),
739750 )
740751 atom_property = self .nets [(type_i ,)](xx )
741752 if self .remove_vaccum_contribution is not None and not (
0 commit comments