Skip to content

Commit e329579

Browse files
authored
Merge branch 'deepmodeling:master' into fix/rmse-loss-normalization
2 parents c856d72 + 3f91293 commit e329579

45 files changed

Lines changed: 9090 additions & 582 deletions

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

deepmd/dpmodel/descriptor/dpa1.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,10 @@ def get_env_protection(self) -> float:
405405
"""Returns the protection of building environment matrix."""
406406
return self.se_atten.get_env_protection()
407407

408+
def get_numb_attn_layer(self) -> int:
409+
"""Returns the number of se_atten attention layers."""
410+
return self.se_atten.attn_layer
411+
408412
def share_params(
409413
self, base_class: "DescrptDPA1", shared_level: int, resume: bool = False
410414
) -> NoReturn:

deepmd/dpmodel/descriptor/repformers.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,14 @@ def get_rcut(self) -> float:
345345
"""Returns the cut-off radius."""
346346
return self.rcut
347347

348+
def get_rcut_smth(self) -> float:
349+
"""Returns the radius where the neighbor information starts to smoothly decay to 0."""
350+
return self.rcut_smth
351+
352+
def get_env_protection(self) -> float:
353+
"""Returns the protection of building environment matrix."""
354+
return self.env_protection
355+
348356
def get_nsel(self) -> int:
349357
"""Returns the number of selected atoms in the cut-off radius."""
350358
return sum(self.sel)

deepmd/dpmodel/fitting/general_fitting.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ def compute_input_stats(
255255
stat_file_path : Optional[DPPath]
256256
The path to the stat file.
257257
"""
258+
self._param_stats: dict[str, list[StatItem]] = {}
258259
if self.numb_fparam == 0 and self.numb_aparam == 0:
259260
# skip data statistics
260261
return
@@ -296,6 +297,7 @@ def compute_input_stats(
296297
self._save_param_stats_to_file(
297298
stat_file_path, "fparam", fparam_stats
298299
)
300+
self._param_stats["fparam"] = fparam_stats
299301
fparam_avg = np.array(
300302
[s.compute_avg() for s in fparam_stats], dtype=np.float64
301303
)
@@ -362,6 +364,7 @@ def compute_input_stats(
362364
self._save_param_stats_to_file(
363365
stat_file_path, "aparam", aparam_stats
364366
)
367+
self._param_stats["aparam"] = aparam_stats
365368
aparam_avg = np.array(
366369
[s.compute_avg() for s in aparam_stats], dtype=np.float64
367370
)
@@ -407,6 +410,10 @@ def _load_param_stats_from_file(
407410
for ii in range(numb)
408411
]
409412

413+
def get_param_stats(self) -> dict[str, list[StatItem]]:
414+
"""Get the stored fparam/aparam statistics (populated by compute_input_stats)."""
415+
return getattr(self, "_param_stats", {})
416+
410417
@abstractmethod
411418
def _net_out_dim(self) -> int:
412419
"""Set the FittingNet output dim."""
@@ -666,11 +673,13 @@ def _call_common(
666673
# check fparam dim, concatenate to input descriptor
667674
if self.numb_fparam > 0:
668675
assert fparam is not None, "fparam should not be None"
669-
if fparam.shape[-1] != self.numb_fparam:
676+
try:
677+
fparam = xp.reshape(fparam, (nf, self.numb_fparam))
678+
except (ValueError, RuntimeError) as e:
670679
raise ValueError(
671-
f"get an input fparam of dim {fparam.shape[-1]}, "
672-
f"which is not consistent with {self.numb_fparam}."
673-
)
680+
f"input fparam: cannot reshape {fparam.shape} "
681+
f"into ({nf}, {self.numb_fparam})."
682+
) from e
674683
fparam = (fparam - self.fparam_avg[...]) * self.fparam_inv_std[...]
675684
fparam = xp.tile(
676685
xp.reshape(fparam, (nf, 1, self.numb_fparam)), (1, nloc, 1)
@@ -687,12 +696,13 @@ def _call_common(
687696
# check aparam dim, concatenate to input descriptor
688697
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
689698
assert aparam is not None, "aparam should not be None"
690-
if aparam.shape[-1] != self.numb_aparam:
699+
try:
700+
aparam = xp.reshape(aparam, (nf, nloc, self.numb_aparam))
701+
except (ValueError, RuntimeError) as e:
691702
raise ValueError(
692-
f"get an input aparam of dim {aparam.shape[-1]}, "
693-
f"which is not consistent with {self.numb_aparam}."
694-
)
695-
aparam = xp.reshape(aparam, (nf, nloc, self.numb_aparam))
703+
f"input aparam: cannot reshape {aparam.shape} "
704+
f"into ({nf}, {nloc}, {self.numb_aparam})."
705+
) from e
696706
aparam = (aparam - self.aparam_avg[...]) * self.aparam_inv_std[...]
697707
xx = xp.concat(
698708
[xx, aparam],
@@ -735,7 +745,8 @@ def _call_common(
735745
)
736746
for type_i in range(self.ntypes):
737747
mask = xp.tile(
738-
xp.reshape((atype == type_i), (nf, nloc, 1)), (1, 1, net_dim_out)
748+
xp.reshape((atype == type_i), (nf, nloc, 1)),
749+
(1, 1, net_dim_out),
739750
)
740751
atom_property = self.nets[(type_i,)](xx)
741752
if self.remove_vaccum_contribution is not None and not (

deepmd/dpmodel/utils/env_mat.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,11 @@ def _make_env_mat(
6868
xp = array_api_compat.array_namespace(nlist)
6969
nf, nloc, nnei = nlist.shape
7070
# nf x nall x 3
71-
coord = xp.reshape(coord, (nf, -1, 3))
71+
# Callers may pass either (nf, nall*3) or (nf, nall, 3); normalise
72+
# both to (nf, nall, 3) using shape-based inference so the concrete nf
73+
# value is not baked into the reshape.
74+
if coord.ndim == 2:
75+
coord = xp.reshape(coord, (-1, coord.shape[1] // 3, 3))
7276
mask = nlist >= 0
7377
nlist = nlist * xp.astype(mask, nlist.dtype)
7478
# nf x (nloc x nnei) x 3
@@ -77,7 +81,7 @@ def _make_env_mat(
7781
# nf x nloc x nnei x 3
7882
coord_r = xp.reshape(coord_r, (nf, nloc, nnei, 3))
7983
# nf x nloc x 1 x 3
80-
coord_l = xp.reshape(xp_take_first_n(coord, 1, nloc), (nf, -1, 1, 3))
84+
coord_l = xp.reshape(xp_take_first_n(coord, 1, nloc), (nf, nloc, 1, 3))
8185
# nf x nloc x nnei x 3
8286
diff = coord_r - coord_l
8387
# nf x nloc x nnei

deepmd/dpmodel/utils/env_mat_stat.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,75 @@
4040
)
4141

4242

43+
def merge_env_stat(
44+
base_obj: Union["Descriptor", "DescriptorBlock"],
45+
link_obj: Union["Descriptor", "DescriptorBlock"],
46+
model_prob: float = 1.0,
47+
) -> None:
48+
"""Merge descriptor env mat stats from link_obj into base_obj.
49+
50+
Uses probability-weighted merging: merged = base_stats + link_stats * model_prob,
51+
where model_prob = link_prob / base_prob.
52+
Mutates base_obj.stats for chaining (3+ models).
53+
54+
Parameters
55+
----------
56+
base_obj : Descriptor or DescriptorBlock
57+
The base descriptor whose stats will be updated.
58+
link_obj : Descriptor or DescriptorBlock
59+
The linked descriptor whose stats will be merged in.
60+
model_prob : float
61+
The probability weight ratio (link_prob / base_prob).
62+
"""
63+
if (
64+
getattr(base_obj, "stats", None) is None
65+
or getattr(link_obj, "stats", None) is None
66+
):
67+
return
68+
if getattr(base_obj, "set_stddev_constant", False) and getattr(
69+
base_obj, "set_davg_zero", False
70+
):
71+
return
72+
73+
# Weighted merge of StatItem objects
74+
base_stats = base_obj.stats
75+
link_stats = link_obj.stats
76+
merged_stats = {}
77+
for kk in base_stats:
78+
merged_stats[kk] = base_stats[kk] + link_stats[kk] * model_prob
79+
80+
# Compute mean/stddev from merged stats
81+
base_env = EnvMatStatSe(base_obj)
82+
base_env.stats = merged_stats
83+
mean, stddev = base_env()
84+
85+
# Update base_obj stats for chaining
86+
base_obj.stats = merged_stats
87+
88+
# Update buffers in-place: davg/dstd (simple) or mean/stddev (blocks)
89+
# mean/stddev are numpy arrays; convert to match the buffer's backend
90+
if hasattr(base_obj, "davg"):
91+
xp = array_api_compat.array_namespace(base_obj.dstd)
92+
device = array_api_compat.device(base_obj.dstd)
93+
if not getattr(base_obj, "set_davg_zero", False):
94+
base_obj.davg[...] = xp.asarray(
95+
mean, dtype=base_obj.davg.dtype, device=device
96+
)
97+
base_obj.dstd[...] = xp.asarray(
98+
stddev, dtype=base_obj.dstd.dtype, device=device
99+
)
100+
elif hasattr(base_obj, "mean"):
101+
xp = array_api_compat.array_namespace(base_obj.stddev)
102+
device = array_api_compat.device(base_obj.stddev)
103+
if not getattr(base_obj, "set_davg_zero", False):
104+
base_obj.mean[...] = xp.asarray(
105+
mean, dtype=base_obj.mean.dtype, device=device
106+
)
107+
base_obj.stddev[...] = xp.asarray(
108+
stddev, dtype=base_obj.stddev.dtype, device=device
109+
)
110+
111+
43112
class EnvMatStat(BaseEnvMatStat):
44113
def compute_stat(self, env_mat: dict[str, Array]) -> dict[str, StatItem]:
45114
"""Compute the statistics of the environment matrix for a single system.

deepmd/pt/model/descriptor/dpa1.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,10 @@ def get_env_protection(self) -> float:
383383
"""Returns the protection of building environment matrix."""
384384
return self.se_atten.get_env_protection()
385385

386+
def get_numb_attn_layer(self) -> int:
387+
"""Returns the number of se_atten attention layers."""
388+
return self.se_atten.attn_layer
389+
386390
def share_params(
387391
self, base_class: Any, shared_level: int, resume: bool = False
388392
) -> None:

deepmd/pt/model/task/fitting.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -779,10 +779,10 @@ def _forward_common(
779779
assert fparam is not None, "fparam should not be None"
780780
assert self.fparam_avg is not None
781781
assert self.fparam_inv_std is not None
782-
if fparam.shape[-1] != self.numb_fparam:
782+
if fparam.numel() != nf * self.numb_fparam:
783783
raise ValueError(
784-
"get an input fparam of dim {fparam.shape[-1]}, ",
785-
"which is not consistent with {self.numb_fparam}.",
784+
f"input fparam: cannot reshape {list(fparam.shape)} "
785+
f"into ({nf}, {self.numb_fparam})."
786786
)
787787
fparam = fparam.view([nf, self.numb_fparam])
788788
nb, _ = fparam.shape
@@ -804,10 +804,10 @@ def _forward_common(
804804
assert aparam is not None, "aparam should not be None"
805805
assert self.aparam_avg is not None
806806
assert self.aparam_inv_std is not None
807-
if aparam.shape[-1] != self.numb_aparam:
807+
if aparam.numel() % (nf * self.numb_aparam) != 0:
808808
raise ValueError(
809-
f"get an input aparam of dim {aparam.shape[-1]}, ",
810-
f"which is not consistent with {self.numb_aparam}.",
809+
f"input aparam: cannot reshape {list(aparam.shape)} "
810+
f"into ({nf}, nloc, {self.numb_aparam})."
811811
)
812812
aparam = aparam.view([nf, -1, self.numb_aparam])
813813
nb, nloc, _ = aparam.shape

deepmd/pt_expt/descriptor/dpa1.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
cast_precision,
1010
)
1111
from deepmd.dpmodel.descriptor.dpa1 import DescrptDPA1 as DescrptDPA1DP
12+
from deepmd.dpmodel.utils.env_mat_stat import (
13+
merge_env_stat,
14+
)
1215
from deepmd.pt_expt.common import (
1316
torch_module,
1417
)
@@ -26,6 +29,31 @@
2629
class DescrptDPA1(DescrptDPA1DP):
2730
_update_sel_cls = UpdateSel
2831

32+
def share_params(
33+
self,
34+
base_class: Any,
35+
shared_level: int,
36+
model_prob: float = 1.0,
37+
resume: bool = False,
38+
) -> None:
39+
"""Share parameters with base_class for multi-task training.
40+
41+
Level 0: share type_embedding and se_atten (all modules and buffers).
42+
Level 1: share type_embedding only.
43+
"""
44+
assert self.__class__ == base_class.__class__, (
45+
"Only descriptors of the same type can share params!"
46+
)
47+
if shared_level == 0:
48+
self._modules["type_embedding"] = base_class._modules["type_embedding"]
49+
if not resume:
50+
merge_env_stat(base_class.se_atten, self.se_atten, model_prob)
51+
self._modules["se_atten"] = base_class._modules["se_atten"]
52+
elif shared_level == 1:
53+
self._modules["type_embedding"] = base_class._modules["type_embedding"]
54+
else:
55+
raise NotImplementedError
56+
2957
def enable_compression(
3058
self,
3159
min_nbor_dist: float,

deepmd/pt_expt/descriptor/dpa2.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
build_multiple_neighbor_list,
1515
get_multiple_nlist_key,
1616
)
17+
from deepmd.dpmodel.utils.env_mat_stat import (
18+
merge_env_stat,
19+
)
1720
from deepmd.pt_expt.common import (
1821
torch_module,
1922
)
@@ -30,6 +33,47 @@
3033
class DescrptDPA2(DescrptDPA2DP):
3134
_update_sel_cls = UpdateSel
3235

36+
def share_params(
37+
self,
38+
base_class: "DescrptDPA2",
39+
shared_level: int,
40+
model_prob: float = 1.0,
41+
resume: bool = False,
42+
) -> None:
43+
"""Share parameters with base_class for multi-task training.
44+
45+
Level 0: share type_embedding, repinit, repinit_three_body,
46+
g1_shape_tranform, and repformers.
47+
Level 1: share type_embedding only.
48+
"""
49+
assert self.__class__ == base_class.__class__, (
50+
"Only descriptors of the same type can share params!"
51+
)
52+
if shared_level == 0:
53+
self._modules["type_embedding"] = base_class._modules["type_embedding"]
54+
if not resume:
55+
merge_env_stat(base_class.repinit, self.repinit, model_prob)
56+
if self.use_three_body and "repinit_three_body" in base_class._modules:
57+
merge_env_stat(
58+
base_class.repinit_three_body,
59+
self.repinit_three_body,
60+
model_prob,
61+
)
62+
merge_env_stat(base_class.repformers, self.repformers, model_prob)
63+
self._modules["repinit"] = base_class._modules["repinit"]
64+
if self.use_three_body and "repinit_three_body" in base_class._modules:
65+
self._modules["repinit_three_body"] = base_class._modules[
66+
"repinit_three_body"
67+
]
68+
self._modules["g1_shape_tranform"] = base_class._modules[
69+
"g1_shape_tranform"
70+
]
71+
self._modules["repformers"] = base_class._modules["repformers"]
72+
elif shared_level == 1:
73+
self._modules["type_embedding"] = base_class._modules["type_embedding"]
74+
else:
75+
raise NotImplementedError
76+
3377
def enable_compression(
3478
self,
3579
min_nbor_dist: float,

deepmd/pt_expt/descriptor/dpa3.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22

33
from deepmd.dpmodel.descriptor.dpa3 import DescrptDPA3 as DescrptDPA3DP
4+
from deepmd.dpmodel.utils.env_mat_stat import (
5+
merge_env_stat,
6+
)
47
from deepmd.pt_expt.common import (
58
torch_module,
69
)
@@ -16,3 +19,28 @@
1619
@torch_module
1720
class DescrptDPA3(DescrptDPA3DP):
1821
_update_sel_cls = UpdateSel
22+
23+
def share_params(
24+
self,
25+
base_class: "DescrptDPA3",
26+
shared_level: int,
27+
model_prob: float = 1.0,
28+
resume: bool = False,
29+
) -> None:
30+
"""Share parameters with base_class for multi-task training.
31+
32+
Level 0: share type_embedding and repflows.
33+
Level 1: share type_embedding only.
34+
"""
35+
assert self.__class__ == base_class.__class__, (
36+
"Only descriptors of the same type can share params!"
37+
)
38+
if shared_level == 0:
39+
self._modules["type_embedding"] = base_class._modules["type_embedding"]
40+
if not resume:
41+
merge_env_stat(base_class.repflows, self.repflows, model_prob)
42+
self._modules["repflows"] = base_class._modules["repflows"]
43+
elif shared_level == 1:
44+
self._modules["type_embedding"] = base_class._modules["type_embedding"]
45+
else:
46+
raise NotImplementedError

0 commit comments

Comments
 (0)