Skip to content
40 changes: 40 additions & 0 deletions nemo_rl/experience/interfaces.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Any, Optional

from nemo_rl.data.interfaces import LLMMessageLogType, VLMMessageLogType


@dataclass
class Completion:
"""A single generated completion for one prompt."""

message_log: LLMMessageLogType
env_extras: dict[str, Any]
truncated: bool
reward: float


@dataclass
class PromptGroupRecord:
"""All completions for a single prompt, with prompt-level metadata."""

prompt_idx: int
prompt: LLMMessageLogType | VLMMessageLogType
extra_env_info: Optional[dict[str, Any]]
metadata: dict[str, Any]
completions: list["Completion"]
rollout_metrics: dict[str, Any]
257 changes: 257 additions & 0 deletions nemo_rl/experience/rollout_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import json
import warnings
from typing import Any, Optional

from transformers import PreTrainedTokenizerBase
from wandb import Table

from nemo_rl.data.interfaces import DatumSpec
from nemo_rl.environments.interfaces import EnvironmentInterface
from nemo_rl.experience.interfaces import Completion, PromptGroupRecord
from nemo_rl.experience.rollouts import _calculate_single_metric, _tensorize_by_key
from nemo_rl.models.generation.interfaces import GenerationConfig
from nemo_rl.utils.timer import Timer

TokenizerType = PreTrainedTokenizerBase


class AsyncNemoGymRolloutManager:
"""Manages per-prompt NeMo-Gym rollouts, producing a PromptGroupRecord per call.

Each run_rollout takes one prompt and returns num_generations_per_prompt completions
batched through a single NeMo-Gym run_rollouts call.
"""

def __init__(
self,
tokenizer: TokenizerType,
task_to_env: dict[str, EnvironmentInterface],
generation_config: GenerationConfig,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rollout_manager.py:44

generation_config is annotated GenerationConfig, but L55 indexes generation_config["vllm_cfg"]["max_model_len"]. vllm_cfg is not a key on the GenerationConfig TypedDict — it's defined only on the subclass VllmConfig(GenerationConfig):

class VllmConfig(GenerationConfig):
vllm_cfg: VllmSpecificArgs

So L55 is a real type error (pyrefly typed-dict-key-error), currently masked by the pyrefly.toml omission above. The existing run_async_nemo_gym_rollout sidesteps this by reading policy_generation.cfg["vllm_cfg"] off the policy object rather than off generation_config. The simplest fix here is to tighten the annotation to VllmConfig (and from nemo_rl.models.generation.vllm.config import VllmConfig):

Suggested change
generation_config: GenerationConfig,
generation_config: VllmConfig,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe it's better to just fold in the vllmConfig into generationconfig for now? so no change here, but no need for vllmconfig

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

num_generations_per_prompt: int,
max_seq_len: Optional[int] = None,
max_rollout_turns: Optional[int] = None,
) -> None:
self._tokenizer = tokenizer
self._task_to_env = task_to_env
self._generation_config = generation_config
self._num_generations_per_prompt = num_generations_per_prompt
self._max_seq_len = max_seq_len
self._max_rollout_turns = max_rollout_turns
self._engine_max_model_len = generation_config["vllm_cfg"]["max_model_len"]

self._validate_init_params()

async def run_rollout(self, input_sample: DatumSpec) -> PromptGroupRecord:
"""Run num_generations_per_prompt rollouts for one prompt.

Args:
input_sample: A single prompt (one DatumSpec entry).

Returns:
PromptGroupRecord with num_generations_per_prompt completions.
"""
timer = Timer()
timer_prefix = "timing/rollout"
timer.start(f"{timer_prefix}/total")

rollout_inputs = self._build_inputs(input_sample)
completions, rollout_metrics = await self._run_rollouts(
rollout_inputs, timer, timer_prefix
)

timer.stop(f"{timer_prefix}/total")
rollout_metrics.update(timer.get_timing_metrics("sum"))

return PromptGroupRecord(
prompt_idx=input_sample["idx"],
prompt=input_sample["message_log"],
extra_env_info=input_sample["extra_env_info"],
metadata={"task_name": "nemo_gym"},
completions=completions,
rollout_metrics=rollout_metrics,
)

def _validate_init_params(self) -> None:
"""Validate initialization parameters."""
# Validate generation config.
for key in ["stop_strings", "stop_token_ids", "top_k"]:
assert not self._generation_config[key], ( # type: ignore
f"{key} is not supported in the generation config in NeMo-Gym path!"
)

# Validate max_seq_len.
if (
self._max_seq_len is not None
and self._max_seq_len > self._engine_max_model_len
):
warnings.warn(
f"policy max_total_sequence_length ({self._max_seq_len}) is greater than the "
f"generation engine's max_model_len ({self._engine_max_model_len}). The engine "
"will truncate sequences to its own limit, so the policy cap will not be "
"honored. Lower max_total_sequence_length or raise the engine's max_model_len."
)

# Validate max_rollout_turns.
assert self._max_rollout_turns is None, (
"`max_rollout_turns` is not supported in NeMo-Gym path!"
)
Comment thread
yuki-97 marked this conversation as resolved.

# Validate num_generations_per_prompt.
assert self._num_generations_per_prompt >= 1, (
"`num_generations_per_prompt` must be >= 1!"
)

def _build_inputs(self, input_sample: DatumSpec) -> list[dict]:
"""Build N row dicts from input_sample, applying generation config params."""
# Build a template row from the input_sample's extra_env_info, applying generation params.
template_row: dict = copy.deepcopy(input_sample["extra_env_info"]) # type: ignore

# We do not translate max_seq_len into row-level max_tokens here because that would
# change semantics from "total sequence length" to "max new tokens".
responses_create_params = template_row["responses_create_params"]
responses_create_params["temperature"] = self._generation_config["temperature"]
responses_create_params["top_p"] = self._generation_config["top_p"]
if self._generation_config["max_new_tokens"] is not None:
existing = responses_create_params.get("max_output_tokens")
responses_create_params["max_output_tokens"] = (
min(existing, self._generation_config["max_new_tokens"])
if existing is not None
else self._generation_config["max_new_tokens"]
)

# Build N rows with distinct rowidxs so run_rollouts can sort them correctly.
rows = []
for i in range(self._num_generations_per_prompt):
row = copy.deepcopy(template_row)
row["_rowidx"] = i
rows.append(row)
return rows

async def _run_rollouts(
self, inputs: list[dict], timer: Timer, timer_prefix: str
) -> tuple[list[Completion], dict[str, Any]]:
"""Dispatch rows to NeMo-Gym and return completions + metrics."""
nemo_gym_env = self._task_to_env["nemo_gym"]

# Run generation.
with timer.time(f"{timer_prefix}/run_rollouts"):
results, env_timing_metrics = await nemo_gym_env.run_rollouts.remote(
inputs, self._tokenizer, timer_prefix
)
# Convert results to completions.
completions = [self._result_to_completion(r) for r in results]

# Compute rollout metrics.
with timer.time(f"{timer_prefix}/compute_metrics"):
rollout_metrics = self._compute_rollout_metrics(
completions, inputs[0]["agent_ref"]["name"]
)

rollout_metrics.update(env_timing_metrics)

return completions, rollout_metrics

def _result_to_completion(self, result: dict) -> Completion:
"""Convert one run_rollouts result dict into a Completion."""
# Tensorize token fields.
_tensorize_by_key(result["input_message_log"], "token_ids")
_tensorize_by_key(result["message_log"], "token_ids")
_tensorize_by_key(
[m for m in result["message_log"] if m["role"] == "assistant"],
"generation_logprobs",
)

# Calculate truncation.
truncated = (
sum(len(m["token_ids"]) for m in result["message_log"])
== self._engine_max_model_len
)

return Completion(
message_log=result["message_log"],
env_extras=result["full_result"],
truncated=truncated,
reward=float(result["full_result"]["reward"]),
)

def _compute_rollout_metrics(
self,
completions: list[Completion],
agent_name: str,
) -> dict[str, Any]:
"""Aggregate per-sample and per-agent metrics."""
n = len(completions)

# Aggregate metrics across all samples
rollout_metrics: dict[str, Any] = {
**_calculate_single_metric(
[
sum(1 for m in c.message_log if m["role"] == "user")
for c in completions
],
n,
"turns_per_sample",
),
**_calculate_single_metric(
[sum(len(m["token_ids"]) for m in c.message_log) for c in completions],
n,
"total_tokens_per_sample",
),
**_calculate_single_metric(
[
sum(
len(m["token_ids"])
for m in c.message_log
if m["role"] == "assistant"
)
for c in completions
],
n,
"gen_tokens_per_sample",
),
**_calculate_single_metric(
[c.reward for c in completions],
n,
"total_reward",
),
"natural_termination_rate": sum(not c.truncated for c in completions) / n,
"truncation_rate": sum(c.truncated for c in completions) / n,
}

# Agent-level metrics.
agent_extras = [c.env_extras for c in completions]
for key in agent_extras[0].keys():
values = [
float(r[key])
for r in agent_extras
if isinstance(r.get(key), (bool, int, float))
]
if values:
rollout_metrics.update(
_calculate_single_metric(values, n, f"{agent_name}/{key}")
)
rollout_metrics[f"{agent_name}/full_result"] = Table(
data=[[json.dumps(r, separators=(",", ":"))] for r in agent_extras],
columns=["Full result"],
)

# Necessary for downstream nemo rl logging/printing.
rollout_metrics["mean_gen_tokens_per_sample"] = rollout_metrics[
"gen_tokens_per_sample/mean"
]
return rollout_metrics
2 changes: 1 addition & 1 deletion nemo_rl/models/generation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def configure_generation_config(

# vllm setting
if config["backend"] == "vllm":
config = cast(VllmConfig, config)
config = cast(VllmConfig, config) # type: ignore
# set load_format
config["vllm_cfg"]["load_format"] = "auto" if is_eval else "dummy"
speculative_config = config.get("vllm_kwargs", {}).get("speculative_config")
Expand Down
8 changes: 8 additions & 0 deletions nemo_rl/models/generation/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import torch

from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.models.generation.sglang.config import SglangSpecificArgs
from nemo_rl.models.generation.vllm.config import VllmSpecificArgs


def verify_right_padding(
Expand Down Expand Up @@ -127,6 +129,12 @@ class GenerationConfig(TypedDict):
stop_token_ids: list[int] | None
stop_strings: list[str] | None
colocated: NotRequired[ColocationConfig]

# backend-specific configs
vllm_cfg: NotRequired[VllmSpecificArgs]
sglang_cfg: NotRequired[SglangSpecificArgs]
mcore_generation_config: NotRequired[dict[str, Any]]

# This isn't meant to be passed by the user, but is populated by nemo_rl.models.generation.__init__.configure_generation_config
_pad_token_id: NotRequired[int]

Expand Down
2 changes: 2 additions & 0 deletions pyrefly.toml
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ project-includes = [
"nemo_rl/evals/__init__.py",
"nemo_rl/evals/answer_parsing.py",
"nemo_rl/experience/__init__.py",
"nemo_rl/experience/interfaces.py",
Comment thread
yuki-97 marked this conversation as resolved.
"nemo_rl/experience/rollout_manager.py",
"nemo_rl/experience/rollouts.py",
"nemo_rl/modelopt/__init__.py",
"nemo_rl/modelopt/models/__init__.py",
Expand Down
Loading
Loading