Skip to content

Commit ee2df8e

Browse files
committed
add topk logic
1 parent ccdfedd commit ee2df8e

4 files changed

Lines changed: 242 additions & 59 deletions

File tree

deepmd/pt/train/validation.py

Lines changed: 123 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
)
77

88
import logging
9+
import re
910
import traceback
1011
from dataclasses import (
1112
dataclass,
@@ -68,11 +69,16 @@
6869
("V_RMSE", "rmse_v_per_atom"),
6970
]
7071

71-
BEST_METRIC_INFO_KEY = "full_validation_best_metric"
72-
BEST_STEP_INFO_KEY = "full_validation_best_step"
72+
TOPK_RECORDS_INFO_KEY = "full_validation_topk_records"
7373
BEST_METRIC_NAME_INFO_KEY = "full_validation_metric"
74-
BEST_CKPT_GLOB = "best.ckpt-*.pt"
75-
BEST_PATH_INFO_KEY_COMPAT = "full_validation_best_path"
74+
BEST_CKPT_GLOB = "best.ckpt-*.t-*.pt"
75+
BEST_CKPT_PATTERN = re.compile(r"^best\.ckpt-(\d+)\.t-(\d+)\.pt$")
76+
STALE_FULL_VALIDATION_INFO_KEYS = (
77+
"full_validation_best_metric",
78+
"full_validation_best_step",
79+
"full_validation_best_path",
80+
"full_validation_best_records",
81+
)
7682
BATCH_SIZE_LOGGER_NAME = "deepmd.utils.batch_size"
7783
VAL_LOG_SIGNIFICANT_DIGITS = 5
7884
VAL_LOG_COLUMN_GAP = " "
@@ -96,6 +102,14 @@ class FullValidationResult:
96102
saved_best_path: str | None
97103

98104

105+
@dataclass(order=True, frozen=True)
106+
class BestCheckpointRecord:
107+
"""One best-checkpoint record ordered by metric then step."""
108+
109+
metric: float
110+
step: int
111+
112+
99113
def resolve_full_validation_start_step(
100114
full_val_start: float, num_steps: int
101115
) -> int | None:
@@ -251,6 +265,7 @@ def __init__(
251265
self.full_validation = bool(validating_params.get("full_validation", False))
252266
self.validation_freq = int(validating_params.get("validation_freq", 5000))
253267
self.save_best = bool(validating_params.get("save_best", True))
268+
self.max_best_ckpt = int(validating_params.get("max_best_ckpt", 1))
254269
self.metric_name, self.metric_key = parse_validation_metric(
255270
str(validating_params.get("validation_metric", "E:MAE"))
256271
)
@@ -278,15 +293,7 @@ def __init__(
278293
(metric_key, header_label, max(len(header_label), 18))
279294
)
280295

281-
if self.train_infos.get(BEST_METRIC_NAME_INFO_KEY) == self.metric_name:
282-
best_metric = self.train_infos.get(BEST_METRIC_INFO_KEY)
283-
self.best_metric_value = (
284-
float(best_metric) if best_metric is not None else None
285-
)
286-
self.best_step = self.train_infos.get(BEST_STEP_INFO_KEY)
287-
else:
288-
self.best_metric_value = None
289-
self.best_step = None
296+
self.topk_records = self._load_topk_records()
290297
self._sync_train_infos()
291298
if self.rank == 0:
292299
self._initialize_best_checkpoints(restart_training=restart_training)
@@ -342,7 +349,7 @@ def run(
342349
else:
343350
save_checkpoint(Path(save_path[0]), lr=lr, step=step_id)
344351
if self.rank == 0:
345-
self._prune_best_checkpoints(keep_names={Path(save_path[0]).name})
352+
self._reconcile_best_checkpoints()
346353
except Exception as exc:
347354
caught_exception = exc
348355
error_message = (
@@ -553,31 +560,62 @@ def _update_best_state(
553560
display_step: int,
554561
selected_metric_value: float,
555562
) -> str | None:
556-
"""Update the best metric state and return the checkpoint path to save."""
557-
if (
558-
self.best_metric_value is not None
559-
and selected_metric_value >= self.best_metric_value
560-
):
563+
"""Update the top-K records and return the checkpoint path to save."""
564+
candidate = BestCheckpointRecord(
565+
metric=selected_metric_value,
566+
step=display_step,
567+
)
568+
updated_records = [
569+
record for record in self.topk_records if record.step != display_step
570+
]
571+
updated_records.append(candidate)
572+
updated_records.sort()
573+
updated_records = updated_records[: self.max_best_ckpt]
574+
if candidate not in updated_records:
561575
return None
562576

563-
new_best_path = (
564-
self._best_checkpoint_name(display_step) if self.save_best else None
565-
)
566-
self.best_metric_value = selected_metric_value
567-
self.best_step = display_step
577+
self.topk_records = updated_records
568578
self._sync_train_infos()
569-
return new_best_path
579+
if not self.save_best:
580+
return None
581+
candidate_rank = self.topk_records.index(candidate) + 1
582+
return self._best_checkpoint_name(display_step, candidate_rank)
570583

571584
def _sync_train_infos(self) -> None:
572-
"""Synchronize best validation state into train infos."""
573-
self.train_infos.pop(BEST_PATH_INFO_KEY_COMPAT, None)
585+
"""Synchronize top-K validation state into train infos."""
586+
for key in STALE_FULL_VALIDATION_INFO_KEYS:
587+
self.train_infos.pop(key, None)
574588
self.train_infos[BEST_METRIC_NAME_INFO_KEY] = self.metric_name
575-
self.train_infos[BEST_METRIC_INFO_KEY] = self.best_metric_value
576-
self.train_infos[BEST_STEP_INFO_KEY] = self.best_step
589+
self.train_infos[TOPK_RECORDS_INFO_KEY] = [
590+
{"metric": record.metric, "step": record.step}
591+
for record in self.topk_records
592+
]
593+
594+
def _load_topk_records(self) -> list[BestCheckpointRecord]:
595+
"""Load top-K records from train infos for the current metric."""
596+
if self.train_infos.get(BEST_METRIC_NAME_INFO_KEY) != self.metric_name:
597+
return []
598+
raw_records = self.train_infos.get(TOPK_RECORDS_INFO_KEY, [])
599+
if not isinstance(raw_records, list):
600+
return []
601+
records = []
602+
for raw_record in raw_records:
603+
if not isinstance(raw_record, dict):
604+
continue
605+
if "metric" not in raw_record or "step" not in raw_record:
606+
continue
607+
records.append(
608+
BestCheckpointRecord(
609+
metric=float(raw_record["metric"]),
610+
step=int(raw_record["step"]),
611+
)
612+
)
613+
records.sort()
614+
return records[: self.max_best_ckpt]
577615

578-
def _best_checkpoint_name(self, step: int) -> str:
616+
def _best_checkpoint_name(self, step: int, rank: int) -> str:
579617
"""Build the best-checkpoint filename for one step."""
580-
return f"best.ckpt-{step}.pt"
618+
return f"best.ckpt-{step}.t-{rank}.pt"
581619

582620
def _list_best_checkpoints(self) -> list[Path]:
583621
"""List all managed best checkpoints in the working directory."""
@@ -589,21 +627,63 @@ def _list_best_checkpoints(self) -> list[Path]:
589627
best_checkpoints.sort(key=lambda path: path.stat().st_mtime)
590628
return best_checkpoints
591629

592-
def _prune_best_checkpoints(self, keep_names: set[str] | None = None) -> None:
593-
"""Delete managed best checkpoints except the requested ones."""
594-
keep_names = set() if keep_names is None else keep_names
595-
for checkpoint_path in self._list_best_checkpoints():
596-
if checkpoint_path.name not in keep_names:
597-
checkpoint_path.unlink(missing_ok=True)
630+
def _expected_topk_checkpoint_names(self) -> dict[int, str]:
631+
"""Return the expected checkpoint filename for each retained step."""
632+
return {
633+
record.step: self._best_checkpoint_name(record.step, rank)
634+
for rank, record in enumerate(self.topk_records, start=1)
635+
}
636+
637+
def _reconcile_best_checkpoints(self) -> None:
638+
"""Rename retained best checkpoints to ranked names and delete stale ones."""
639+
expected_names = self._expected_topk_checkpoint_names()
640+
current_files = self._list_best_checkpoints()
641+
files_by_step: dict[int, list[Path]] = {}
642+
stale_files: list[Path] = []
643+
for checkpoint_path in current_files:
644+
match = BEST_CKPT_PATTERN.match(checkpoint_path.name)
645+
if match is None:
646+
stale_files.append(checkpoint_path)
647+
continue
648+
step = int(match.group(1))
649+
files_by_step.setdefault(step, []).append(checkpoint_path)
650+
651+
temp_moves: list[tuple[Path, Path]] = []
652+
for step, checkpoint_paths in files_by_step.items():
653+
expected_name = expected_names.get(step)
654+
if expected_name is None:
655+
stale_files.extend(checkpoint_paths)
656+
continue
657+
658+
keep_path = next(
659+
(
660+
checkpoint_path
661+
for checkpoint_path in checkpoint_paths
662+
if checkpoint_path.name == expected_name
663+
),
664+
checkpoint_paths[0],
665+
)
666+
for checkpoint_path in checkpoint_paths:
667+
if checkpoint_path != keep_path:
668+
stale_files.append(checkpoint_path)
669+
if keep_path.name != expected_name:
670+
temp_path = keep_path.with_name(f"{keep_path.name}.tmp")
671+
keep_path.rename(temp_path)
672+
temp_moves.append((temp_path, keep_path.with_name(expected_name)))
673+
674+
for checkpoint_path in stale_files:
675+
checkpoint_path.unlink(missing_ok=True)
676+
for temp_path, final_path in temp_moves:
677+
final_path.unlink(missing_ok=True)
678+
temp_path.rename(final_path)
598679

599680
def _initialize_best_checkpoints(self, restart_training: bool) -> None:
600681
"""Align on-disk best checkpoints with the current training mode."""
601-
if restart_training and self.save_best and self.best_step is not None:
602-
self._prune_best_checkpoints(
603-
keep_names={self._best_checkpoint_name(int(self.best_step))}
604-
)
605-
else:
606-
self._prune_best_checkpoints()
682+
if restart_training and self.save_best and self.topk_records:
683+
self._reconcile_best_checkpoints()
684+
return
685+
for checkpoint_path in self._list_best_checkpoints():
686+
checkpoint_path.unlink(missing_ok=True)
607687

608688
def _raise_if_distributed_error(
609689
self,

deepmd/utils/argcheck.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4085,6 +4085,11 @@ def validating_args() -> Argument:
40854085
"Whether to save an extra checkpoint when the selected full validation "
40864086
"metric reaches a new best value."
40874087
)
4088+
doc_max_best_ckpt = (
4089+
"The maximum number of top-ranked best checkpoints to keep. The best "
4090+
"checkpoints are ranked by the selected validation metric in ascending "
4091+
"order. Default is 1."
4092+
)
40884093
doc_validation_metric = (
40894094
"Metric used to determine the best checkpoint during full validation. "
40904095
f"Supported values are {valid_metrics}. The string is case-insensitive. "
@@ -4126,6 +4131,15 @@ def validating_args() -> Argument:
41264131
default=True,
41274132
doc=doc_only_pt_supported + doc_save_best,
41284133
),
4134+
Argument(
4135+
"max_best_ckpt",
4136+
int,
4137+
optional=True,
4138+
default=1,
4139+
doc=doc_only_pt_supported + doc_max_best_ckpt,
4140+
extra_check=lambda x: x > 0,
4141+
extra_check_errmsg="must be greater than 0",
4142+
),
41294143
Argument(
41304144
"validation_metric",
41314145
str,
@@ -4149,7 +4163,7 @@ def validating_args() -> Argument:
41494163
"full_val_start",
41504164
[int, float],
41514165
optional=True,
4152-
default=0.0,
4166+
default=0.5,
41534167
doc=doc_only_pt_supported + doc_full_val_start,
41544168
extra_check=lambda x: x >= 0,
41554169
extra_check_errmsg="must be greater than or equal to 0",

source/tests/pt/test_training.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -773,13 +773,14 @@ def setUp(self) -> None:
773773
self.config["training"]["training_data"]["systems"] = data_file
774774
self.config["training"]["validation_data"]["systems"] = data_file
775775
self.config["model"] = deepcopy(model_se_e2_a)
776-
self.config["training"]["numb_steps"] = 2
776+
self.config["training"]["numb_steps"] = 4
777777
self.config["training"]["save_freq"] = 100
778778
self.config["training"]["disp_training"] = False
779779
self.config["validating"] = {
780780
"full_validation": True,
781781
"validation_freq": 1,
782782
"save_best": True,
783+
"max_best_ckpt": 2,
783784
"validation_metric": "E:MAE",
784785
"full_val_file": "val.log",
785786
"full_val_start": 0.0,
@@ -792,24 +793,33 @@ def tearDown(self) -> None:
792793
@patch("deepmd.pt.train.validation.FullValidator.evaluate_all_systems")
793794
def test_full_validation_rotates_best_checkpoint(self, mocked_eval) -> None:
794795
mocked_eval.side_effect = [
795-
{"mae_e_per_atom": 2.0},
796796
{"mae_e_per_atom": 1.0},
797+
{"mae_e_per_atom": 2.0},
798+
{"mae_e_per_atom": 0.5},
799+
{"mae_e_per_atom": 1.5},
797800
]
798-
Path("best.ckpt-999.pt").touch()
801+
Path("best.ckpt-999.t-1.pt").touch()
799802
trainer = get_trainer(deepcopy(self.config))
800803
trainer.run()
801804

802-
self.assertFalse(Path("best.ckpt-999.pt").exists())
803-
self.assertFalse(Path("best.ckpt-1.pt").exists())
804-
self.assertTrue(Path("best.ckpt-2.pt").exists())
805+
self.assertFalse(Path("best.ckpt-999.t-1.pt").exists())
806+
self.assertFalse(Path("best.ckpt-1.t-1.pt").exists())
807+
self.assertFalse(Path("best.ckpt-2.t-1.pt").exists())
808+
self.assertTrue(Path("best.ckpt-3.t-1.pt").exists())
809+
self.assertTrue(Path("best.ckpt-1.t-2.pt").exists())
805810
train_infos = trainer._get_inner_module().train_infos
806-
self.assertEqual(train_infos["full_validation_best_step"], 2)
807-
self.assertEqual(train_infos["full_validation_best_metric"], 1.0)
808-
self.assertNotIn("full_validation_best_path", train_infos)
811+
self.assertEqual(
812+
train_infos["full_validation_topk_records"],
813+
[
814+
{"metric": 0.5, "step": 3},
815+
{"metric": 1.0, "step": 1},
816+
],
817+
)
809818
with open("val.log") as fp:
810819
val_lines = [line for line in fp.readlines() if not line.startswith("#")]
811-
self.assertEqual(val_lines[0].split()[1], "2000.0")
812-
self.assertEqual(val_lines[1].split()[1], "1000.0")
820+
self.assertEqual(len(val_lines), 4)
821+
self.assertEqual(val_lines[0].split()[1], "1000.0")
822+
self.assertEqual(val_lines[1].split()[1], "2000.0")
813823

814824
def test_full_validation_rejects_spin_loss(self) -> None:
815825
config = deepcopy(self.config)

0 commit comments

Comments (0)