diff --git a/nemo_rl/experience/interfaces.py b/nemo_rl/experience/interfaces.py new file mode 100644 index 0000000000..83d591f4d1 --- /dev/null +++ b/nemo_rl/experience/interfaces.py @@ -0,0 +1,40 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Optional + +from nemo_rl.data.interfaces import LLMMessageLogType, VLMMessageLogType + + +@dataclass +class Completion: + """A single generated completion for one prompt.""" + + message_log: LLMMessageLogType + env_extras: dict[str, Any] + truncated: bool + reward: float + + +@dataclass +class PromptGroupRecord: + """All completions for a single prompt, with prompt-level metadata.""" + + prompt_idx: int + prompt: LLMMessageLogType | VLMMessageLogType + extra_env_info: Optional[dict[str, Any]] + metadata: dict[str, Any] + completions: list["Completion"] + rollout_metrics: dict[str, Any] diff --git a/nemo_rl/experience/rollout_manager.py b/nemo_rl/experience/rollout_manager.py new file mode 100644 index 0000000000..28ac778c96 --- /dev/null +++ b/nemo_rl/experience/rollout_manager.py @@ -0,0 +1,257 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import json +import warnings +from typing import Any, Optional + +from transformers import PreTrainedTokenizerBase +from wandb import Table + +from nemo_rl.data.interfaces import DatumSpec +from nemo_rl.environments.interfaces import EnvironmentInterface +from nemo_rl.experience.interfaces import Completion, PromptGroupRecord +from nemo_rl.experience.rollouts import _calculate_single_metric, _tensorize_by_key +from nemo_rl.models.generation.interfaces import GenerationConfig +from nemo_rl.utils.timer import Timer + +TokenizerType = PreTrainedTokenizerBase + + +class AsyncNemoGymRolloutManager: + """Manages per-prompt NeMo-Gym rollouts, producing a PromptGroupRecord per call. + + Each run_rollout takes one prompt and returns num_generations_per_prompt completions + batched through a single NeMo-Gym run_rollouts call. + """ + + def __init__( + self, + tokenizer: TokenizerType, + task_to_env: dict[str, EnvironmentInterface], + generation_config: GenerationConfig, + num_generations_per_prompt: int, + max_seq_len: Optional[int] = None, + max_rollout_turns: Optional[int] = None, + ) -> None: + self._tokenizer = tokenizer + self._task_to_env = task_to_env + self._generation_config = generation_config + self._num_generations_per_prompt = num_generations_per_prompt + self._max_seq_len = max_seq_len + self._max_rollout_turns = max_rollout_turns + self._engine_max_model_len = generation_config["vllm_cfg"]["max_model_len"] + + self._validate_init_params() + + async def run_rollout(self, input_sample: DatumSpec) -> PromptGroupRecord: + """Run num_generations_per_prompt rollouts for one prompt. + + Args: + input_sample: A single prompt (one DatumSpec entry). + + Returns: + PromptGroupRecord with num_generations_per_prompt completions. + """ + timer = Timer() + timer_prefix = "timing/rollout" + timer.start(f"{timer_prefix}/total") + + rollout_inputs = self._build_inputs(input_sample) + completions, rollout_metrics = await self._run_rollouts( + rollout_inputs, timer, timer_prefix + ) + + timer.stop(f"{timer_prefix}/total") + rollout_metrics.update(timer.get_timing_metrics("sum")) + + return PromptGroupRecord( + prompt_idx=input_sample["idx"], + prompt=input_sample["message_log"], + extra_env_info=input_sample["extra_env_info"], + metadata={"task_name": "nemo_gym"}, + completions=completions, + rollout_metrics=rollout_metrics, + ) + + def _validate_init_params(self) -> None: + """Validate initialization parameters.""" + # Validate generation config. + for key in ["stop_strings", "stop_token_ids", "top_k"]: + assert not self._generation_config[key], ( # type: ignore + f"{key} is not supported in the generation config in NeMo-Gym path!" + ) + + # Validate max_seq_len. + if ( + self._max_seq_len is not None + and self._max_seq_len > self._engine_max_model_len + ): + warnings.warn( + f"policy max_total_sequence_length ({self._max_seq_len}) is greater than the " + f"generation engine's max_model_len ({self._engine_max_model_len}). The engine " + "will truncate sequences to its own limit, so the policy cap will not be " + "honored. Lower max_total_sequence_length or raise the engine's max_model_len." + ) + + # Validate max_rollout_turns. + assert self._max_rollout_turns is None, ( + "`max_rollout_turns` is not supported in NeMo-Gym path!" + ) + + # Validate num_generations_per_prompt. + assert self._num_generations_per_prompt >= 1, ( + "`num_generations_per_prompt` must be >= 1!" + ) + + def _build_inputs(self, input_sample: DatumSpec) -> list[dict]: + """Build N row dicts from input_sample, applying generation config params.""" + # Build a template row from the input_sample's extra_env_info, applying generation params. + template_row: dict = copy.deepcopy(input_sample["extra_env_info"]) # type: ignore + + # We do not translate max_seq_len into row-level max_tokens here because that would + # change semantics from "total sequence length" to "max new tokens". + responses_create_params = template_row["responses_create_params"] + responses_create_params["temperature"] = self._generation_config["temperature"] + responses_create_params["top_p"] = self._generation_config["top_p"] + if self._generation_config["max_new_tokens"] is not None: + existing = responses_create_params.get("max_output_tokens") + responses_create_params["max_output_tokens"] = ( + min(existing, self._generation_config["max_new_tokens"]) + if existing is not None + else self._generation_config["max_new_tokens"] + ) + + # Build N rows with distinct rowidxs so run_rollouts can sort them correctly. + rows = [] + for i in range(self._num_generations_per_prompt): + row = copy.deepcopy(template_row) + row["_rowidx"] = i + rows.append(row) + return rows + + async def _run_rollouts( + self, inputs: list[dict], timer: Timer, timer_prefix: str + ) -> tuple[list[Completion], dict[str, Any]]: + """Dispatch rows to NeMo-Gym and return completions + metrics.""" + nemo_gym_env = self._task_to_env["nemo_gym"] + + # Run generation. + with timer.time(f"{timer_prefix}/run_rollouts"): + results, env_timing_metrics = await nemo_gym_env.run_rollouts.remote( + inputs, self._tokenizer, timer_prefix + ) + # Convert results to completions. + completions = [self._result_to_completion(r) for r in results] + + # Compute rollout metrics. + with timer.time(f"{timer_prefix}/compute_metrics"): + rollout_metrics = self._compute_rollout_metrics( + completions, inputs[0]["agent_ref"]["name"] + ) + + rollout_metrics.update(env_timing_metrics) + + return completions, rollout_metrics + + def _result_to_completion(self, result: dict) -> Completion: + """Convert one run_rollouts result dict into a Completion.""" + # Tensorize token fields. + _tensorize_by_key(result["input_message_log"], "token_ids") + _tensorize_by_key(result["message_log"], "token_ids") + _tensorize_by_key( + [m for m in result["message_log"] if m["role"] == "assistant"], + "generation_logprobs", + ) + + # Calculate truncation. + truncated = ( + sum(len(m["token_ids"]) for m in result["message_log"]) + == self._engine_max_model_len + ) + + return Completion( + message_log=result["message_log"], + env_extras=result["full_result"], + truncated=truncated, + reward=float(result["full_result"]["reward"]), + ) + + def _compute_rollout_metrics( + self, + completions: list[Completion], + agent_name: str, + ) -> dict[str, Any]: + """Aggregate per-sample and per-agent metrics.""" + n = len(completions) + + # Aggregate metrics across all samples + rollout_metrics: dict[str, Any] = { + **_calculate_single_metric( + [ + sum(1 for m in c.message_log if m["role"] == "user") + for c in completions + ], + n, + "turns_per_sample", + ), + **_calculate_single_metric( + [sum(len(m["token_ids"]) for m in c.message_log) for c in completions], + n, + "total_tokens_per_sample", + ), + **_calculate_single_metric( + [ + sum( + len(m["token_ids"]) + for m in c.message_log + if m["role"] == "assistant" + ) + for c in completions + ], + n, + "gen_tokens_per_sample", + ), + **_calculate_single_metric( + [c.reward for c in completions], + n, + "total_reward", + ), + "natural_termination_rate": sum(not c.truncated for c in completions) / n, + "truncation_rate": sum(c.truncated for c in completions) / n, + } + + # Agent-level metrics. + agent_extras = [c.env_extras for c in completions] + for key in agent_extras[0].keys(): + values = [ + float(r[key]) + for r in agent_extras + if isinstance(r.get(key), (bool, int, float)) + ] + if values: + rollout_metrics.update( + _calculate_single_metric(values, n, f"{agent_name}/{key}") + ) + rollout_metrics[f"{agent_name}/full_result"] = Table( + data=[[json.dumps(r, separators=(",", ":"))] for r in agent_extras], + columns=["Full result"], + ) + + # Necessary for downstream nemo rl logging/printing. + rollout_metrics["mean_gen_tokens_per_sample"] = rollout_metrics[ + "gen_tokens_per_sample/mean" + ] + return rollout_metrics diff --git a/nemo_rl/models/generation/__init__.py b/nemo_rl/models/generation/__init__.py index 8465112c6c..c6c64ac1f6 100644 --- a/nemo_rl/models/generation/__init__.py +++ b/nemo_rl/models/generation/__init__.py @@ -42,7 +42,7 @@ def configure_generation_config( # vllm setting if config["backend"] == "vllm": - config = cast(VllmConfig, config) + config = cast(VllmConfig, config) # type: ignore # set load_format config["vllm_cfg"]["load_format"] = "auto" if is_eval else "dummy" speculative_config = config.get("vllm_kwargs", {}).get("speculative_config") diff --git a/nemo_rl/models/generation/interfaces.py b/nemo_rl/models/generation/interfaces.py index 037b4880f5..bd97418bdb 100644 --- a/nemo_rl/models/generation/interfaces.py +++ b/nemo_rl/models/generation/interfaces.py @@ -18,6 +18,8 @@ import torch from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.models.generation.sglang.config import SglangSpecificArgs +from nemo_rl.models.generation.vllm.config import VllmSpecificArgs def verify_right_padding( @@ -127,6 +129,12 @@ class GenerationConfig(TypedDict): stop_token_ids: list[int] | None stop_strings: list[str] | None colocated: NotRequired[ColocationConfig] + + # backend-specific configs + vllm_cfg: NotRequired[VllmSpecificArgs] + sglang_cfg: NotRequired[SglangSpecificArgs] + mcore_generation_config: NotRequired[dict[str, Any]] + # This isn't meant to be passed by the user, but is populated by nemo_rl.models.generation.__init__.configure_generation_config _pad_token_id: NotRequired[int] diff --git a/pyrefly.toml b/pyrefly.toml index 9bf76baeac..f323c6ea96 100644 --- a/pyrefly.toml +++ b/pyrefly.toml @@ -122,6 +122,8 @@ project-includes = [ "nemo_rl/evals/__init__.py", "nemo_rl/evals/answer_parsing.py", "nemo_rl/experience/__init__.py", + "nemo_rl/experience/interfaces.py", + "nemo_rl/experience/rollout_manager.py", "nemo_rl/experience/rollouts.py", "nemo_rl/modelopt/__init__.py", "nemo_rl/modelopt/models/__init__.py", diff --git a/tests/unit/experience/test_rollouts.py b/tests/unit/experience/test_rollouts.py index 3b4ef1dede..4b2c38c668 100644 --- a/tests/unit/experience/test_rollouts.py +++ b/tests/unit/experience/test_rollouts.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import gc import json import tempfile @@ -36,6 +37,8 @@ SlidingPuzzleGameLogic, SlidingPuzzleMetadata, ) +from nemo_rl.experience.interfaces import Completion, PromptGroupRecord +from nemo_rl.experience.rollout_manager import AsyncNemoGymRolloutManager from nemo_rl.experience.rollouts import ( _calculate_single_metric, run_async_multi_turn_rollout, @@ -970,3 +973,236 @@ def _standardize(d: dict) -> dict: 1. In nemo_rl/experience/rollouts.py::run_async_nemo_gym_rollout, the sampling params are passed appropriately 2. In nemo_rl/models/generation/vllm/vllm_worker_async.py::VllmAsyncGenerationWorker::_setup_vllm_server::create_chat_completion, the sampling params (like top_k) are set as appropriate """ + + +# --------------------------------------------------------------------------- +# Tests for AsyncNemoGymRolloutManager +# --------------------------------------------------------------------------- + + +@pytest.mark.nemo_gym +def test_async_nemo_gym_rollout_manager( + nemo_gym, # noqa: F811 + nemo_gym_vllm_generation, # noqa: F811 + nemo_gym_sanity_test_data, # noqa: F811 + nemo_gym_tokenizer, # noqa: F811 +): + """Standalone test for AsyncNemoGymRolloutManager. + + Given 1 prompt with num_generations_per_prompt=N, asserts: + - output is a PromptGroupRecord with N Completion objects + - each Completion has a reward (float) and a non-empty message_log + - completions hold independent message_log objects + + If the result here does not match, please check the following: + 1. Test data changed: re-run test_nemo_gym_sanity (tests/unit/environments/test_nemo_gym.py) + and use _write_actual_test_data output to refresh test_nemo_gym_sanity.json. + 2. Logic changed: inspect recent changes to AsyncNemoGymRolloutManager or the gym env. + """ + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for data in nemo_gym_sanity_test_data["input"]: + f.write(json.dumps(data) + "\n") + data_path = f.name + + dataset = NemoGymDataset(data_path) + examples = [ + nemo_gym_data_processor(dataset.dataset[idx], None, None, None, idx) + for idx in range(len(dataset.dataset)) + ] + input_batch: BatchedDataDict[DatumSpec] = rl_collate_fn(examples) + + # Use only the first prompt + single_prompt = { + "message_log": input_batch["message_log"][0], + "extra_env_info": input_batch["extra_env_info"][0], + "task_name": "nemo_gym", + "idx": 0, + "loss_multiplier": float(input_batch["loss_multiplier"][0]), + } + num_generations = 2 + + manager = AsyncNemoGymRolloutManager( + tokenizer=nemo_gym_tokenizer, + task_to_env={"nemo_gym": nemo_gym}, + generation_config=nemo_gym_vllm_generation.cfg, + num_generations_per_prompt=num_generations, + max_seq_len=nemo_gym_vllm_generation.cfg["vllm_cfg"]["max_model_len"], + ) + record = asyncio.run(manager.run_rollout(single_prompt)) + + assert isinstance(record, PromptGroupRecord) + assert len(record.completions) == num_generations, ( + f"Expected {num_generations} completions, got {len(record.completions)}" + ) + assert record.prompt_idx == 0 + + for i, completion in enumerate(record.completions): + assert isinstance(completion, Completion) + + # 1. message_log length + assert len(completion.message_log) == 2, ( + f"Completion {i}: expected 2 messages, got {len(completion.message_log)}" + ) + + # 2. last assistant token_ids + last_assistant = next( + (m for m in reversed(completion.message_log) if m["role"] == "assistant"), + None, + ) + assert last_assistant is not None, f"Completion {i}: no assistant message found" + assert torch.equal( + last_assistant["token_ids"], + torch.tensor([151667, 198, 32313, 11, 1077]), + ), ( + f"Completion {i}: last assistant token_ids {last_assistant['token_ids'].tolist()} " + f"!= [151667, 198, 32313, 11, 1077]" + ) + + # 3. reward + assert completion.reward == 0.0, ( + f"Completion {i}: reward {completion.reward} != 0.0" + ) + + # completions must be independent objects + assert record.completions[0].message_log is not record.completions[1].message_log + + +@pytest.mark.nemo_gym +def test_async_nemo_gym_rollout_manager_matches_original( + nemo_gym, # noqa: F811 + nemo_gym_vllm_generation, # noqa: F811 + nemo_gym_sanity_test_data, # noqa: F811 + nemo_gym_tokenizer, # noqa: F811 +): + """Comparison test: AsyncNemoGymRolloutManager output is structurally equivalent to the original. + + Calls run_async_nemo_gym_rollout with a batch of N identical rows, + then calls AsyncNemoGymRolloutManager with 1 prompt, N generations. + Asserts that both produce N results and rewards are in the same numeric domain. + + TODO: remove this test together with run_async_nemo_gym_rollout when the legacy path is deleted. + """ + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for data in nemo_gym_sanity_test_data["input"]: + f.write(json.dumps(data) + "\n") + data_path = f.name + + dataset = NemoGymDataset(data_path) + examples = [ + nemo_gym_data_processor(dataset.dataset[idx], None, None, None, idx) + for idx in range(len(dataset.dataset)) + ] + input_batch: BatchedDataDict[DatumSpec] = rl_collate_fn(examples) + + num_generations = 2 + single_prompt = { + "message_log": input_batch["message_log"][0], + "extra_env_info": input_batch["extra_env_info"][0], + "task_name": "nemo_gym", + "idx": 0, + "loss_multiplier": float(input_batch["loss_multiplier"][0]), + } + + # Build a batch of N identical rows for the original function + repeated_batch = BatchedDataDict( + { + "message_log": [ + deepcopy(input_batch["message_log"][0]) for _ in range(num_generations) + ], + "extra_env_info": [ + deepcopy(input_batch["extra_env_info"][0]) + for _ in range(num_generations) + ], + "loss_multiplier": input_batch["loss_multiplier"][0:1].repeat( + num_generations + ), + "idx": list(range(num_generations)), + "task_name": ["nemo_gym"] * num_generations, + } + ) + + original_result = run_async_nemo_gym_rollout( + policy_generation=nemo_gym_vllm_generation, + input_batch=repeated_batch, + tokenizer=nemo_gym_tokenizer, + task_to_env={"nemo_gym": nemo_gym}, + generation_config=nemo_gym_vllm_generation.cfg, + max_seq_len=nemo_gym_vllm_generation.cfg["vllm_cfg"]["max_model_len"], + max_rollout_turns=None, + ) + + manager = AsyncNemoGymRolloutManager( + tokenizer=nemo_gym_tokenizer, + task_to_env={"nemo_gym": nemo_gym}, + generation_config=nemo_gym_vllm_generation.cfg, + num_generations_per_prompt=num_generations, + max_seq_len=nemo_gym_vllm_generation.cfg["vllm_cfg"]["max_model_len"], + ) + record = asyncio.run(manager.run_rollout(single_prompt)) + + # Both should produce N completions + assert len(original_result.final_batch["message_log"]) == num_generations + assert len(record.completions) == num_generations + + for i in range(num_generations): + orig_msg_log = original_result.final_batch["message_log"][i] + new_msg_log = record.completions[i].message_log + + # 1. message_log length matches + assert len(orig_msg_log) == len(new_msg_log), ( + f"Completion {i}: message_log length {len(new_msg_log)} != original {len(orig_msg_log)}" + ) + + # 2. last assistant token_ids match + def _last_assistant_token_ids(msg_log): + for m in reversed(msg_log): + if m["role"] == "assistant": + return m.get("token_ids") + return None + + orig_token_ids = _last_assistant_token_ids(orig_msg_log) + new_token_ids = _last_assistant_token_ids(new_msg_log) + assert orig_token_ids is not None, ( + f"Completion {i}: no assistant message in original" + ) + assert new_token_ids is not None, ( + f"Completion {i}: no assistant message in manager" + ) + assert torch.equal(orig_token_ids, new_token_ids), ( + f"Completion {i}: last assistant token_ids mismatch\n" + f" original: {orig_token_ids.tolist()}\n" + f" manager: {new_token_ids.tolist()}" + ) + + # 3. reward matches + orig_reward = original_result.final_batch["total_reward"][i].item() + new_reward = record.completions[i].reward + assert orig_reward == new_reward, ( + f"Completion {i}: reward mismatch — original {orig_reward}, manager {new_reward}" + ) + + # 4. rollout_metrics numeric values match (timing and Table fields are excluded) + orig_metrics = original_result.rollout_metrics + new_metrics = record.rollout_metrics + for key in orig_metrics.keys(): + # Skip timing and full_result fields + if key.startswith("timing/") or key.endswith("/full_result"): + continue + + # Check that the key is present in the new metrics + assert key in new_metrics, f"rollout_metrics[{key!r}] missing from manager" + + orig_val = orig_metrics[key] + new_val = new_metrics[key] + + # Skip non-numeric fields + assert type(orig_val) == type(new_val), ( + f"rollout_metrics[{key!r}] type mismatch: {type(orig_val)} != {type(new_val)}" + ) + if not isinstance(orig_val, (bool, int, float)): + continue + + # Check equal + assert orig_val == pytest.approx(new_val), ( + f"rollout_metrics[{key!r}] mismatch — original {orig_val}, manager {new_val}" + )