
Commit 2f219a9

add multi-turn RL support
1 parent 81155e3 commit 2f219a9

17 files changed

Lines changed: 1857 additions & 277 deletions

README.md

Lines changed: 42 additions & 1 deletion
@@ -31,6 +31,47 @@ We then extend `KernelBenchEnv` to support:
- **Batching**: `KernelBenchEnvGroupBuilder` groups multiple rollouts for the same problem, enabling **GRPO-style** training where rewards are normalized within groups.
- **Dataset Construction**: `KernelBenchDatasetBuilder` handles the iteration over KernelBench levels and problems, partitioning them into training and evaluation sets. You are welcome to extend it to support more problems beyond what is currently in KernelBench.

### Multi-Turn RL

We extend the single-turn pipeline with multi-turn iterative refinement, following the approach in [Kevin](https://arxiv.org/abs/2507.11948). Instead of generating one kernel per problem, the model generates a kernel, receives evaluation feedback (compilation errors, correctness failures, or speedup results), and refines its solution over multiple turns.

`MultiTurnKernelBenchEnv` manages the multi-turn loop:
- **History management**: Prior turns (prompt, response, feedback) are kept in context, with token-based truncation to stay within the context window (a minimal truncation sketch follows this list).
- **Evaluation feedback**: Structured feedback tells the model what went wrong (compilation error, incorrect output, or correct but slow) so it can fix specific issues.
- **Early stopping**: Optionally stop the episode when the kernel passes all correctness tests.
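
For illustration, here is a minimal sketch of token-budget history truncation as described above. The `Turn` record and `count_tokens` callback are hypothetical placeholders, not this repo's actual interfaces; the budget corresponds to `multiturn.prompt_max_tokens`:

```python
# Illustrative sketch of token-budget history truncation; `Turn` and
# `count_tokens` are hypothetical stand-ins, not this repo's actual API.
from dataclasses import dataclass
from typing import Callable

@dataclass
class Turn:
    response: str  # model's kernel attempt from a prior turn
    feedback: str  # structured evaluation feedback for that attempt

def truncate_history(
    header: str,                      # system prompt + problem statement
    turns: list[Turn],
    count_tokens: Callable[[str], int],
    budget: int,
) -> list[Turn]:
    """Drop the oldest (response, feedback) pairs until the prompt fits."""
    kept = list(turns)
    while kept:
        rendered = "\n".join([header] + [t.response + "\n" + t.feedback for t in kept])
        if count_tokens(rendered) <= budget:
            break
        kept.pop(0)  # evict the oldest turn first; recent feedback matters most
    return kept
```
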
Training uses GRPO with discounted returns across turns (a worked sketch follows this list):
- Per-turn scores are computed as `S = 0.3 * correct + speedup` (the speedup term applies only to correct kernels).
- Discounted returns: `R_t = S_t + γ * R_{t+1}` (backward recursion, γ = 0.4 by default).
- Advantages are normalized across all `group_size × max_turns` turn-level samples: `(R - mean) / (std + ε)`.
- PPO with asymmetric clipping (Clip-Higher, ε_low = 0.2, ε_high = 0.28) and constant length normalization.
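
Concretely, a minimal sketch of the return and advantage computation (variable names and example scores are illustrative, not this repo's actual training code):

```python
# Minimal sketch of discounted returns and group-normalized advantages;
# names and shapes are illustrative, not this repo's actual training code.
import numpy as np

def discounted_returns(scores: list[float], gamma: float = 0.4) -> list[float]:
    """R_t = S_t + gamma * R_{t+1}, computed by backward recursion."""
    returns = [0.0] * len(scores)
    running = 0.0
    for t in reversed(range(len(scores))):
        running = scores[t] + gamma * running
        returns[t] = running
    return returns

# One group: group_size trajectories, each with max_turns per-turn scores
# S = 0.3 * correct + speedup (0.0 where the kernel was incorrect).
group_scores = [[0.0, 0.3, 1.1], [0.3, 0.9, 1.5]]
flat = np.concatenate([discounted_returns(s) for s in group_scores])
advantages = (flat - flat.mean()) / (flat.std() + 1e-6)  # (R - mean) / (std + eps)
```
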
Enable multi-turn via config:

```yaml
multiturn:
  enabled: true
  max_turns: 4        # Refinement turns per trajectory
  gamma: 0.4          # Discount factor
  aggregation: "sum"  # "sum" or "max"
```

Or via CLI:

```bash
uv run python -m kernelbench_tinker.scripts.train_kernel_rl \
  --config src/kernelbench_tinker/config/rl_kernelbench.yaml \
  multiturn.enabled=true \
  log_path=./runs/my_multiturn_experiment
```

Multi-turn inference is also supported via the eval script:

```bash
uv run python -m kernelbench_tinker.scripts.eval_kernel_rl \
  checkpoint_path=<your_checkpoint> \
  multiturn_enabled=true \
  multiturn_max_turns=8 \
  level=1
```

### Directory Structure
```text
@@ -54,6 +95,7 @@ src/kernelbench_tinker/
  envs/
    kernelbench_client.py          # KernelBench Python API wrapper
    kernelbench_env.py             # Single-turn RL environment
    multiturn_kernelbench_env.py   # Multi-turn RL environment
  training/
    models.py                      # Model/renderer configuration
    reward.py                      # Reward shaping
@@ -282,7 +324,6 @@ Note the scope of this repo is an open-source implementation of KernelBench-Tink

* More reward examples leveraging more fine-grained metrics
* More reward hack checking
- * Multi-turn RL to have denser reward signal like [Kevin](https://arxiv.org/abs/2507.11948)
* Improve Step time and training efficiency


src/kernelbench_tinker/config/configs.py

Lines changed: 54 additions & 0 deletions
@@ -81,3 +81,57 @@ class DatasetConfig:
    # Train/test split
    test_fraction: float = 0.1


@dataclass
class MultiTurnConfig:
    """
    Configuration for multi-turn RL training.

    Controls the iterative refinement loop where the model receives
    evaluation feedback and can fix errors across multiple turns.
    """

    # Enable multi-turn mode (False = single-turn)
    enabled: bool = False

    # Maximum refinement turns per trajectory
    max_turns: int = 4

    # Discount factor for multi-turn returns: R_t = S_t + gamma * R_{t+1}
    gamma: float = 0.4

    # Return aggregation mode: "sum" or "max"
    #   sum: R_t = Σ γ^(i-t) × S_i        (reward turns leading to many good kernels)
    #   max: R_t = max_i { γ^(i-t) × S_i } (reward turns leading to one great kernel)
    aggregation: str = "sum"

    # Stop the episode early when the kernel is correct.
    # Default False for training: the model needs post-correctness turns to
    # learn speedup optimization. Set True at eval time if desired.
    early_stop_on_correct: bool = False

    # Optional: require this speedup before early stopping
    speedup_threshold: float | None = None

    # Prompt
    prompt_max_tokens: int | None = None  # Token budget for history truncation (None = char fallback)
    inject_think_token: bool = False  # Append <think>\n to generation prompts

    # Generation
    temperature: float = 0.9
    top_p: float = 1.0
    seed: int | None = None

    # Response length extension mid-training (0 = disabled)
    max_tokens_extended: int = 22000
    max_tokens_extend_after_step: int = 30

    # Training
    loss_fn: str = "ppo"
    max_grad_norm: float = 0.05
    warmup_ratio: float = 0.03
    clip_epsilon_low: float = 0.2
    clip_epsilon_high: float = 0.28
    constant_length_norm: int = 16384
    num_substeps: int = 2
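
For intuition, a hedged sketch of what the two `aggregation` modes compute; this follows the formulas in the comments above, and is illustrative rather than the repo's actual return computation:

```python
# Illustrative sketch of the two aggregation modes; not this repo's code.
def aggregate_returns(scores: list[float], gamma: float, mode: str) -> list[float]:
    """Compute R_t for each turn t from per-turn scores S_i (i >= t)."""
    n = len(scores)
    if mode == "sum":
        # R_t = sum_{i >= t} gamma^(i-t) * S_i
        return [sum(gamma ** (i - t) * scores[i] for i in range(t, n))
                for t in range(n)]
    if mode == "max":
        # R_t = max_{i >= t} gamma^(i-t) * S_i
        return [max(gamma ** (i - t) * scores[i] for i in range(t, n))
                for t in range(n)]
    raise ValueError(f"unknown aggregation mode: {mode}")

# Example: scores [0.3, 1.3] with gamma=0.4 ->
#   sum: [0.3 + 0.4*1.3, 1.3] = [0.82, 1.3]
#   max: [max(0.3, 0.4*1.3), 1.3] = [0.52, 1.3]
```

Under `sum`, an early turn is credited for every good kernel downstream of it; under `max`, only for the single best one.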

src/kernelbench_tinker/config/rl_kernelbench.yaml

Lines changed: 31 additions & 0 deletions
@@ -26,6 +26,33 @@ learning_rate: 0.000002 # 2e-6 as explicit float
max_tokens: 16384
temperature: 1.0

# =============================================================================
# Multi-turn Configuration (disabled by default)
# =============================================================================
multiturn:
  enabled: false                    # true to enable iterative refinement
  max_turns: 4                      # Maximum refinement turns per trajectory
  gamma: 0.4                        # Discount factor for multi-turn returns
  aggregation: "sum"                # "sum" (reward many good kernels) or "max" (reward one great kernel)
  early_stop_on_correct: false      # Stop episode when kernel passes all tests
  speedup_threshold: null           # Required speedup before early stopping (null = any correct)
  # Prompt
  prompt_max_tokens: null           # Token budget for history truncation (null = char fallback)
  inject_think_token: false         # Append <think>\n to generation prompts
  # Generation
  temperature: 0.9                  # Generation temperature
  top_p: 1.0                        # Nucleus sampling (1.0 = disabled)
  seed: null                        # Random seed for generation (null = random)
  max_tokens_extended: 22000        # Extend max_tokens mid-training (0 = disabled)
  max_tokens_extend_after_step: 30  # Step at which to switch
  # Training
  loss_fn: "ppo"                    # Loss function (single-turn uses top-level loss_fn)
  max_grad_norm: 0.05               # Gradient clipping (0.0 = disabled)
  warmup_ratio: 0.03                # Linear LR warmup fraction
  clip_epsilon_low: 0.2             # PPO clip lower bound
  clip_epsilon_high: 0.28           # PPO clip upper bound (Clip-Higher)
  constant_length_norm: 16384       # GRPO constant length normalization (0 = disabled)

# =============================================================================
# Training Configuration
# =============================================================================
@@ -57,6 +84,7 @@ dataset_builder:
  # ---------------------------------------------------------------------------
  # Problem Selection
  # ---------------------------------------------------------------------------
  level: 1                    # KernelBench level (1, 2, 3, or 4)
  levels: null                # Train on multiple levels (e.g. [1, 2]); overrides level when set
  start_problem: null         # First problem ID (null = start from 1)
  end_problem: null           # Last problem ID (null = all problems)
  dataset_src: "huggingface"  # "huggingface" or "local"
@@ -107,6 +135,9 @@ dataset_builder:
  reward_correctness_weight: 0.3
  reward_speed_weight: 1.0
  reward_length_weight: 0.0
  reward_speed_max_reward: 10.0  # Cap on speed reward component (set high to uncap)
  reward_clip_min: null          # Lower bound on total reward (null = no clipping)
  reward_clip_max: null          # Upper bound on total reward (null = no clipping)

  # ---------------------------------------------------------------------------
  # Reward Hacking Detection (Static Checker)
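
For intuition, a minimal sketch of how the new reward knobs (`reward_speed_max_reward`, `reward_clip_min`, `reward_clip_max`) plausibly compose with the existing weights. This is an assumed reading of the comments above, not the repo's actual `compute_reward`:

```python
# Assumed semantics of the reward knobs, inferred from the YAML comments;
# illustrative only, not this repo's actual compute_reward implementation.
def shaped_reward(correct: bool, speedup: float,
                  correctness_weight: float = 0.3, speed_weight: float = 1.0,
                  speed_max_reward: float = 10.0,
                  clip_min: float | None = None,
                  clip_max: float | None = None) -> float:
    if not correct:
        return 0.0  # speedup counts only for correct kernels
    speed_term = min(speed_weight * speedup, speed_max_reward)  # cap speed component
    reward = correctness_weight + speed_term  # S = 0.3 * correct + speedup
    if clip_min is not None:
        reward = max(reward, clip_min)  # lower bound on total reward
    if clip_max is not None:
        reward = min(reward, clip_max)  # upper bound on total reward
    return reward
```
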
Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
"""
Shared utilities for KernelBench environments.

Contains helpers used by both the single-turn and multi-turn environments:
- System prompt construction
- Step evaluation (parse → evaluate → reward → metrics)
"""

from __future__ import annotations

import logging
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING

from tinker_cookbook import renderers
from tinker_cookbook.rl.types import Action, Metrics

from kernelbench_tinker.config.configs import EvalConfig
from kernelbench_tinker.envs.kernelbench_client import (
    KernelBenchProblem,
    KernelEvalResult,
    ParsedResponse,
    evaluate_kernel_async,
    parse_structured_response,
)
from kernelbench_tinker.training.reward import (
    RewardConfig,
    compute_reward,
)

logger = logging.getLogger(__name__)


@dataclass
class EvalStepResult:
    """Result from evaluate_step(), shared by single-turn and multi-turn envs."""

    parsed: ParsedResponse
    eval_result: KernelEvalResult
    format_ok: bool
    kernel_code: str
    reward: float
    metrics: Metrics
    response_text: str  # Raw response content from renderer (before structured parsing)


def build_system_prompt(backend: str) -> str:
    """Build a backend-specific system prompt for kernel generation.

    Used by both single-turn and multi-turn environments.
    """
    return (
        f"You are an expert GPU kernel developer. Your task is to optimize PyTorch "
        f"operations by writing efficient custom {backend.upper()} kernels.\n"
        f"\n"
        f"When given a PyTorch model, write an optimized kernel implementation.\n"
        f"\n"
        f"Your solution must:\n"
        f"- Be a drop-in replacement as a class named `ModelNew`\n"
        f"- Use custom {backend.upper()} kernels, not just PyTorch operations\n"
        f"- Be correct and produce the same results as the reference\n"
        f"\n"
        f"You MUST respond in exactly this format:\n"
        f"\n"
        f"<KERNEL>\n"
        f"```python\n"
        f"# Your complete optimized implementation here\n"
        f"class ModelNew(nn.Module):\n"
        f"    ...\n"
        f"```\n"
        f"</KERNEL>"
    )


async def evaluate_step(
    problem: KernelBenchProblem,
    renderer: renderers.Renderer,
    action: Action,
    eval_config: EvalConfig,
    reward_config: RewardConfig,
    step_start: float,
) -> EvalStepResult:
    """Parse, evaluate, and compute reward for a single action.

    Shared by KernelBenchEnv.step() and MultiTurnKernelBenchEnv.step().
    """
    message, _ = renderer.parse_response(action)
    response_text = message.get("content", "")

    parsed = parse_structured_response(response_text)
    kernel_code = parsed.kernel
    format_ok = parsed.format_ok

    eval_start = time.perf_counter()
    cfg = eval_config
    eval_result = await evaluate_kernel_async(
        level=problem.level,
        problem_id=problem.problem_id,
        backend=problem.backend,
        kernel_code=kernel_code,
        dataset_src=problem.dataset_src,
        num_correct_trials=cfg.num_correct_trials,
        measure_performance=cfg.measure_performance,
        num_perf_trials=cfg.num_perf_trials,
        timing_method=cfg.timing_method,
        precision=cfg.precision,
        check_for_excessive_speedup=cfg.check_for_excessive_speedup,
        excessive_speedup_threshold=cfg.excessive_speedup_threshold,
        timeout=cfg.modal_timeout,
    )
    eval_time = time.perf_counter() - eval_start

    reward = compute_reward(
        eval_result,
        reward_config,
        kernel_code=kernel_code,
        backend=problem.backend,
    )

    metrics: Metrics = {
        "level": problem.level,
        "problem_id": problem.problem_id,
        "format_ok": float(format_ok),
        "compiled": float(eval_result["compiled"]),
        "correctness": float(eval_result["correctness"]),
        "tests_passed": eval_result["tests_passed"],
        "tests_total": eval_result["tests_total"],
    }
    if eval_result.get("speedup") is not None:
        metrics["speedup"] = eval_result["speedup"]
    if eval_result.get("runtime_ms") is not None:
        metrics["runtime_ms"] = eval_result["runtime_ms"]
    metrics["time/eval"] = eval_time
    timing_metadata = (eval_result.get("metadata") or {}).get("timings", {})
    if "reference_load_s" in timing_metadata:
        metrics["time/ref_load"] = timing_metadata["reference_load_s"]
    if "modal_eval_s" in timing_metadata:
        metrics["time/modal_eval"] = timing_metadata["modal_eval_s"]
    metrics["time/step_total"] = time.perf_counter() - step_start

    return EvalStepResult(
        parsed=parsed,
        eval_result=eval_result,
        format_ok=format_ok,
        kernel_code=kernel_code,
        reward=reward,
        metrics=metrics,
        response_text=response_text,
    )
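
For orientation, a hedged sketch of how an environment's `step()` might call this helper; the `env` wiring and return shape here are illustrative scaffolding, not code from this commit:

```python
# Hypothetical caller showing the intended flow; the env attributes used
# here are illustrative scaffolding, not code from this commit.
import time

async def run_one_turn(env, action):
    """Evaluate one model action and return (reward, metrics)."""
    step_start = time.perf_counter()
    result = await evaluate_step(
        problem=env.problem,
        renderer=env.renderer,
        action=action,
        eval_config=env.eval_config,
        reward_config=env.reward_config,
        step_start=step_start,
    )
    # A single-turn env ends the episode here; the multi-turn env instead
    # renders result.eval_result into feedback text for the next turn.
    return result.reward, result.metrics
```
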
