66)
77
88import logging
9+ import re
910import traceback
1011from dataclasses import (
1112 dataclass ,
6869 ("V_RMSE" , "rmse_v_per_atom" ),
6970]
7071
# Train-infos key under which the ranked top-K validation records persist.
TOPK_RECORDS_INFO_KEY = "full_validation_topk_records"
# Train-infos key recording which validation metric the stored records use.
BEST_METRIC_NAME_INFO_KEY = "full_validation_metric"
# Managed best checkpoints are named best.ckpt-<step>.t-<rank>.pt; the glob
# finds them on disk and the pattern extracts (step, rank) for reconciliation.
BEST_CKPT_GLOB = "best.ckpt-*.t-*.pt"
BEST_CKPT_PATTERN = re.compile(r"^best\.ckpt-(\d+)\.t-(\d+)\.pt$")
# Legacy single-best keys that are purged from train infos on each sync.
STALE_FULL_VALIDATION_INFO_KEYS = (
    "full_validation_best_metric",
    "full_validation_best_step",
    "full_validation_best_path",
    "full_validation_best_records",
)
BATCH_SIZE_LOGGER_NAME = "deepmd.utils.batch_size"
VAL_LOG_SIGNIFICANT_DIGITS = 5
VAL_LOG_COLUMN_GAP = " "
@@ -96,6 +102,14 @@ class FullValidationResult:
96102 saved_best_path : str | None
97103
98104
105+ @dataclass (order = True , frozen = True )
106+ class BestCheckpointRecord :
107+ """One best-checkpoint record ordered by metric then step."""
108+
109+ metric : float
110+ step : int
111+
112+
99113def resolve_full_validation_start_step (
100114 full_val_start : float , num_steps : int
101115) -> int | None :
@@ -251,6 +265,7 @@ def __init__(
251265 self .full_validation = bool (validating_params .get ("full_validation" , False ))
252266 self .validation_freq = int (validating_params .get ("validation_freq" , 5000 ))
253267 self .save_best = bool (validating_params .get ("save_best" , True ))
268+ self .max_best_ckpt = int (validating_params .get ("max_best_ckpt" , 1 ))
254269 self .metric_name , self .metric_key = parse_validation_metric (
255270 str (validating_params .get ("validation_metric" , "E:MAE" ))
256271 )
@@ -278,15 +293,7 @@ def __init__(
278293 (metric_key , header_label , max (len (header_label ), 18 ))
279294 )
280295
281- if self .train_infos .get (BEST_METRIC_NAME_INFO_KEY ) == self .metric_name :
282- best_metric = self .train_infos .get (BEST_METRIC_INFO_KEY )
283- self .best_metric_value = (
284- float (best_metric ) if best_metric is not None else None
285- )
286- self .best_step = self .train_infos .get (BEST_STEP_INFO_KEY )
287- else :
288- self .best_metric_value = None
289- self .best_step = None
296+ self .topk_records = self ._load_topk_records ()
290297 self ._sync_train_infos ()
291298 if self .rank == 0 :
292299 self ._initialize_best_checkpoints (restart_training = restart_training )
@@ -342,7 +349,7 @@ def run(
342349 else :
343350 save_checkpoint (Path (save_path [0 ]), lr = lr , step = step_id )
344351 if self .rank == 0 :
345- self ._prune_best_checkpoints ( keep_names = { Path ( save_path [ 0 ]). name } )
352+ self ._reconcile_best_checkpoints ( )
346353 except Exception as exc :
347354 caught_exception = exc
348355 error_message = (
@@ -553,31 +560,62 @@ def _update_best_state(
553560 display_step : int ,
554561 selected_metric_value : float ,
555562 ) -> str | None :
556- """Update the best metric state and return the checkpoint path to save."""
557- if (
558- self .best_metric_value is not None
559- and selected_metric_value >= self .best_metric_value
560- ):
563+ """Update the top-K records and return the checkpoint path to save."""
564+ candidate = BestCheckpointRecord (
565+ metric = selected_metric_value ,
566+ step = display_step ,
567+ )
568+ updated_records = [
569+ record for record in self .topk_records if record .step != display_step
570+ ]
571+ updated_records .append (candidate )
572+ updated_records .sort ()
573+ updated_records = updated_records [: self .max_best_ckpt ]
574+ if candidate not in updated_records :
561575 return None
562576
563- new_best_path = (
564- self ._best_checkpoint_name (display_step ) if self .save_best else None
565- )
566- self .best_metric_value = selected_metric_value
567- self .best_step = display_step
577+ self .topk_records = updated_records
568578 self ._sync_train_infos ()
569- return new_best_path
579+ if not self .save_best :
580+ return None
581+ candidate_rank = self .topk_records .index (candidate ) + 1
582+ return self ._best_checkpoint_name (display_step , candidate_rank )
570583
571584 def _sync_train_infos (self ) -> None :
572- """Synchronize best validation state into train infos."""
573- self .train_infos .pop (BEST_PATH_INFO_KEY_COMPAT , None )
585+ """Synchronize top-K validation state into train infos."""
586+ for key in STALE_FULL_VALIDATION_INFO_KEYS :
587+ self .train_infos .pop (key , None )
574588 self .train_infos [BEST_METRIC_NAME_INFO_KEY ] = self .metric_name
575- self .train_infos [BEST_METRIC_INFO_KEY ] = self .best_metric_value
576- self .train_infos [BEST_STEP_INFO_KEY ] = self .best_step
589+ self .train_infos [TOPK_RECORDS_INFO_KEY ] = [
590+ {"metric" : record .metric , "step" : record .step }
591+ for record in self .topk_records
592+ ]
593+
594+ def _load_topk_records (self ) -> list [BestCheckpointRecord ]:
595+ """Load top-K records from train infos for the current metric."""
596+ if self .train_infos .get (BEST_METRIC_NAME_INFO_KEY ) != self .metric_name :
597+ return []
598+ raw_records = self .train_infos .get (TOPK_RECORDS_INFO_KEY , [])
599+ if not isinstance (raw_records , list ):
600+ return []
601+ records = []
602+ for raw_record in raw_records :
603+ if not isinstance (raw_record , dict ):
604+ continue
605+ if "metric" not in raw_record or "step" not in raw_record :
606+ continue
607+ records .append (
608+ BestCheckpointRecord (
609+ metric = float (raw_record ["metric" ]),
610+ step = int (raw_record ["step" ]),
611+ )
612+ )
613+ records .sort ()
614+ return records [: self .max_best_ckpt ]
577615
578- def _best_checkpoint_name (self , step : int ) -> str :
616+ def _best_checkpoint_name (self , step : int , rank : int ) -> str :
579617 """Build the best-checkpoint filename for one step."""
580- return f"best.ckpt-{ step } .pt"
618+ return f"best.ckpt-{ step } .t- { rank } . pt"
581619
582620 def _list_best_checkpoints (self ) -> list [Path ]:
583621 """List all managed best checkpoints in the working directory."""
@@ -589,21 +627,63 @@ def _list_best_checkpoints(self) -> list[Path]:
589627 best_checkpoints .sort (key = lambda path : path .stat ().st_mtime )
590628 return best_checkpoints
591629
592- def _prune_best_checkpoints (self , keep_names : set [str ] | None = None ) -> None :
593- """Delete managed best checkpoints except the requested ones."""
594- keep_names = set () if keep_names is None else keep_names
595- for checkpoint_path in self ._list_best_checkpoints ():
596- if checkpoint_path .name not in keep_names :
597- checkpoint_path .unlink (missing_ok = True )
630+ def _expected_topk_checkpoint_names (self ) -> dict [int , str ]:
631+ """Return the expected checkpoint filename for each retained step."""
632+ return {
633+ record .step : self ._best_checkpoint_name (record .step , rank )
634+ for rank , record in enumerate (self .topk_records , start = 1 )
635+ }
636+
    def _reconcile_best_checkpoints(self) -> None:
        """Rename retained best checkpoints to ranked names and delete stale ones.

        Reconciles on-disk files against ``self.topk_records``: files whose
        name does not parse, whose step is no longer retained, or which are
        duplicates for a retained step are deleted; the surviving file for
        each retained step is renamed to its current ranked filename.
        Renames go through a two-phase ``.tmp`` move so that rank shifts
        (e.g. two files swapping ranks) cannot clobber each other.
        """
        expected_names = self._expected_topk_checkpoint_names()
        current_files = self._list_best_checkpoints()
        # Group parseable files by their embedded step; anything unparseable
        # is stale by definition.
        files_by_step: dict[int, list[Path]] = {}
        stale_files: list[Path] = []
        for checkpoint_path in current_files:
            match = BEST_CKPT_PATTERN.match(checkpoint_path.name)
            if match is None:
                stale_files.append(checkpoint_path)
                continue
            step = int(match.group(1))
            files_by_step.setdefault(step, []).append(checkpoint_path)

        temp_moves: list[tuple[Path, Path]] = []
        for step, checkpoint_paths in files_by_step.items():
            expected_name = expected_names.get(step)
            if expected_name is None:
                # Step fell out of the top-K: every file for it is stale.
                stale_files.extend(checkpoint_paths)
                continue

            # Prefer a file already carrying the expected name; otherwise
            # keep an arbitrary one and rename it below.
            keep_path = next(
                (
                    checkpoint_path
                    for checkpoint_path in checkpoint_paths
                    if checkpoint_path.name == expected_name
                ),
                checkpoint_paths[0],
            )
            for checkpoint_path in checkpoint_paths:
                if checkpoint_path != keep_path:
                    stale_files.append(checkpoint_path)
            if keep_path.name != expected_name:
                # Phase 1: park under a .tmp name (never matched by
                # BEST_CKPT_GLOB) so a concurrent rank-swap target is freed.
                temp_path = keep_path.with_name(f"{keep_path.name}.tmp")
                keep_path.rename(temp_path)
                temp_moves.append((temp_path, keep_path.with_name(expected_name)))

        for checkpoint_path in stale_files:
            checkpoint_path.unlink(missing_ok=True)
        # Phase 2: land the parked files on their final ranked names.
        # NOTE(review): Path.rename overwrites silently on POSIX but raises
        # on Windows if the target exists; the unlink above guards that, but
        # a leftover .tmp from a crash could still collide — confirm the
        # supported platforms.
        for temp_path, final_path in temp_moves:
            final_path.unlink(missing_ok=True)
            temp_path.rename(final_path)
598679
599680 def _initialize_best_checkpoints (self , restart_training : bool ) -> None :
600681 """Align on-disk best checkpoints with the current training mode."""
601- if restart_training and self .save_best and self .best_step is not None :
602- self ._prune_best_checkpoints (
603- keep_names = {self ._best_checkpoint_name (int (self .best_step ))}
604- )
605- else :
606- self ._prune_best_checkpoints ()
682+ if restart_training and self .save_best and self .topk_records :
683+ self ._reconcile_best_checkpoints ()
684+ return
685+ for checkpoint_path in self ._list_best_checkpoints ():
686+ checkpoint_path .unlink (missing_ok = True )
607687
608688 def _raise_if_distributed_error (
609689 self ,
0 commit comments