@@ -168,7 +168,55 @@ def create_tracker(tracker_name: str, config: dict, **kwargs) -> BaseTracker:
168168 tracker_cls = tracker_registry [tracker_name ]
169169 return tracker_cls (config , ** kwargs )
170170
171+ class TrackioTracker (BaseTracker ):
172+
173+ def __init__ (self , config : dict , ** kwargs ):
174+ self .config = config
175+
176+ project = kwargs .pop ("project" , None )
177+ name = kwargs .pop ("name" , None )
178+ group = kwargs .pop ("group" , None )
179+ space_id = kwargs .pop ("space_id" , None )
180+ dataset_id = kwargs .pop ("dataset_id" , None )
181+ tags = kwargs .pop ("tags" , None )
182+
183+ auto_log_gpu = kwargs .pop ("auto_log_gpu" , True )
184+ gpu_log_interval = kwargs .pop ("gpu_log_interval" , 2 )
185+
186+ import trackio
187+
188+ if space_id :
189+ logger .info (f"[Trackio] Using HF Space: { space_id } " )
190+ if dataset_id :
191+ logger .info (f"[Trackio] Syncing to dataset: { dataset_id } " )
192+
193+ self .run = trackio .init (
194+ project = project ,
195+ name = name ,
196+ group = group ,
197+ config = config ,
198+ space_id = space_id ,
199+ dataset_id = dataset_id ,
200+ tags = tags ,
201+ auto_log_gpu = auto_log_gpu ,
202+ gpu_log_interval = gpu_log_interval ,
203+ )
204+
205+ @strip_at_tag_in_log
206+ def log (self , values : dict , step : Optional [int ], ** kwargs ):
207+ if step is not None :
208+ values = dict (values )
209+ values ["step" ] = step
210+ self .run .log (values )
211+
212+ def log_system (self , values : dict ):
213+ self .run .log_system (values )
214+
215+ def finish (self ):
216+ self .run .finish ()
217+
171218tracker_registry ["tensorboard" ] = TensorBoardTracker
172219tracker_registry ["wandb" ] = WandbTracker
173220tracker_registry ["stdout" ] = StdoutTracker
174221tracker_registry ["swanlab" ] = SwanlabTracker
222+ tracker_registry ["trackio" ] = TrackioTracker
0 commit comments