Skip to content

Commit 7526682

Browse files
ParagEkbote authored and PanAndy committed
add initial trackio integration for roll.
1 parent a49a915 commit 7526682

2 files changed

Lines changed: 49 additions & 0 deletions

File tree

requirements_common.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@ sglang-router
2424

2525
wandb
2626
swanlab
27+
trackio
2728

2829
math-verify
2930
openai

roll/utils/tracking.py

Lines changed: 48 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -168,7 +168,55 @@ def create_tracker(tracker_name: str, config: dict, **kwargs) -> BaseTracker:
168168
tracker_cls = tracker_registry[tracker_name]
169169
return tracker_cls(config, **kwargs)
170170

171+
class TrackioTracker(BaseTracker):
    """Tracker backend that forwards metrics to a Trackio run.

    The run is created in ``__init__`` via ``trackio.init`` and kept on
    ``self.run``; ``log``, ``log_system`` and ``finish`` delegate to it.
    """

    def __init__(self, config: dict, **kwargs):
        """Start a Trackio run.

        Args:
            config: Run configuration. Stored on ``self.config`` and passed
                to ``trackio.init`` as ``config``.
            **kwargs: Optional settings forwarded to ``trackio.init``:
                ``project``, ``name``, ``group``, ``space_id``,
                ``dataset_id`` and ``tags`` (each defaults to ``None``),
                ``auto_log_gpu`` (defaults to ``True``) and
                ``gpu_log_interval`` (defaults to ``2``). Any other keyword
                is ignored.
        """
        self.config = config

        project = kwargs.pop("project", None)
        name = kwargs.pop("name", None)
        group = kwargs.pop("group", None)
        space_id = kwargs.pop("space_id", None)
        dataset_id = kwargs.pop("dataset_id", None)
        tags = kwargs.pop("tags", None)

        auto_log_gpu = kwargs.pop("auto_log_gpu", True)
        gpu_log_interval = kwargs.pop("gpu_log_interval", 2)

        # Imported here rather than at module level, so trackio is only
        # required when this tracker is actually instantiated.
        import trackio

        if space_id:
            logger.info(f"[Trackio] Using HF Space: {space_id}")
        if dataset_id:
            logger.info(f"[Trackio] Syncing to dataset: {dataset_id}")

        self.run = trackio.init(
            project=project,
            name=name,
            group=group,
            config=config,
            space_id=space_id,
            dataset_id=dataset_id,
            tags=tags,
            auto_log_gpu=auto_log_gpu,
            gpu_log_interval=gpu_log_interval,
        )

    @strip_at_tag_in_log
    def log(self, values: dict, step: Optional[int], **kwargs):
        """Log a dict of metrics to the Trackio run.

        Args:
            values: Mapping of metric name to value. It is passed to the run
                unchanged; no ``"step"`` entry is added to it.
            step: Step index for this record, or ``None`` to let Trackio
                assign one.
            **kwargs: Accepted for signature compatibility; not used.
        """
        if step is None:
            self.run.log(values)
            return
        # Hand the step to Trackio through its dedicated ``step`` argument
        # instead of writing it into the metrics dict as a "step" key.
        # NOTE(review): "step" appears to be a reserved metric key in
        # Trackio -- confirm against the installed version.
        self.run.log(values, step=step)

    def log_system(self, values: dict):
        """Forward system metrics to the run's ``log_system``."""
        self.run.log_system(values)

    def finish(self):
        """Finish the underlying Trackio run."""
        self.run.finish()
217+
171218
# Register each tracker class under the name that create_tracker looks up.
tracker_registry["tensorboard"] = TensorBoardTracker
172219
tracker_registry["wandb"] = WandbTracker
173220
tracker_registry["stdout"] = StdoutTracker
174221
tracker_registry["swanlab"] = SwanlabTracker
222+
tracker_registry["trackio"] = TrackioTracker

0 commit comments

Comments
 (0)