Commit 7c139d6

add multi-turn RL support
1 parent 81155e3 commit 7c139d6

13 files changed: 1712 additions & 250 deletions

README.md

Lines changed: 42 additions & 1 deletion
@@ -31,6 +31,47 @@ We then extend `KernelBenchEnv` to support:
- **Batching**: `KernelBenchEnvGroupBuilder` groups multiple rollouts for the same problem, enabling **GRPO-style** training where rewards are normalized within groups.
- **Dataset Construction**: `KernelBenchDatasetBuilder` handles the iteration over KernelBench levels and problems, partitioning them into training and evaluation sets. You are welcome to extend it to support more problems beyond what is currently in KernelBench.

### Multi-Turn RL
We extend the single-turn pipeline with multi-turn iterative refinement, following the approach in [Kevin](https://arxiv.org/abs/2507.11948). Instead of generating one kernel per problem, the model generates a kernel, receives evaluation feedback (compilation errors, correctness failures, or speedup results), and refines its solution over multiple turns.
`MultiTurnKernelBenchEnv` manages the multi-turn loop:

- **History management**: Prior turns (prompt, response, feedback) are kept in context, with token-based truncation to stay within the context window.
- **Evaluation feedback**: Structured feedback tells the model what went wrong (compilation error, incorrect output, or correct but slow) so it can fix specific issues.
- **Early stopping**: Optionally stop the episode when the kernel passes all correctness tests.
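The refinement loop can be sketched roughly as follows (an illustrative simplification, not the actual `MultiTurnKernelBenchEnv` code; `generate` and `evaluate` are stand-ins for the sampler and the KernelBench evaluation harness):

```python
def run_episode(problem_prompt, generate, evaluate, max_turns=4, early_stop=False):
    """Sketch of one multi-turn trajectory: generate, evaluate, feed back, refine."""
    history = [problem_prompt]          # turn 1 uses the full problem prompt
    turns = []
    for _ in range(max_turns):
        response = generate("\n".join(history))
        result = evaluate(response)     # dict: compiled, correct, speedup, error
        turns.append((response, result))
        if early_stop and result["correct"]:
            break
        # Structured feedback so the next turn can fix the specific failure
        if not result["compiled"]:
            feedback = f"Compilation failed:\n{result['error']}"
        elif not result["correct"]:
            feedback = "Kernel output is incorrect; fix correctness first."
        else:
            feedback = f"Correct, speedup {result['speedup']:.2f}x; optimize further."
        history += [response, feedback]
    return turns
```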
Training uses GRPO with discounted returns across turns:

- Per-turn scores are computed as `S = 0.3 * correct + speedup`, with the speedup term applied only to correct kernels.
- Discounted returns: `R_t = S_t + γ * R_{t+1}` (backward recursion, γ = 0.4 by default).
- Advantages are normalized across all `group_size × max_turns` turn-level samples: `(R - mean) / (std + ε)`.
- PPO uses asymmetric clipping (Clip-Higher, ε_low = 0.2, ε_high = 0.28) and constant length normalization.
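The score, return, and advantage computations above can be sketched in plain Python (an illustration under the listed defaults, not the training code itself):

```python
from statistics import pstdev

def turn_scores(results, correctness_weight=0.3):
    # S = 0.3 * correct + speedup; the speedup term counts only for correct kernels
    return [correctness_weight * r["correct"] + (r["speedup"] if r["correct"] else 0.0)
            for r in results]

def discounted_returns(scores, gamma=0.4):
    # Backward recursion: R_t = S_t + gamma * R_{t+1}
    returns, running = [], 0.0
    for s in reversed(scores):
        running = s + gamma * running
        returns.append(running)
    return returns[::-1]

def normalized_advantages(group_returns, eps=1e-6):
    # Flatten all group_size x max_turns turn-level returns, then (R - mean) / (std + eps)
    flat = [r for returns in group_returns for r in returns]
    mean = sum(flat) / len(flat)
    return [(r - mean) / (pstdev(flat) + eps) for r in flat]
```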
Enable multi-turn via config:
```yaml
multiturn:
  enabled: true
  max_turns: 4        # Refinement turns per trajectory
  gamma: 0.4          # Discount factor
  aggregation: "sum"  # "sum" or "max"
```

Or via CLI:
```bash
uv run python -m kernelbench_tinker.scripts.train_kernel_rl \
    --config src/kernelbench_tinker/config/rl_kernelbench.yaml \
    multiturn.enabled=true \
    log_path=./runs/my_multiturn_experiment
```

Multi-turn inference is also supported via the eval script:
```bash
uv run python -m kernelbench_tinker.scripts.eval_kernel_rl \
    checkpoint_path=<your_checkpoint> \
    multiturn_enabled=true \
    multiturn_max_turns=8 \
    level=1
```

### Directory Structure
```text
@@ -54,6 +95,7 @@ src/kernelbench_tinker/
  envs/
    kernelbench_client.py         # KernelBench Python API wrapper
    kernelbench_env.py            # Single-turn RL environment
    multiturn_kernelbench_env.py  # Multi-turn RL environment
  training/
    models.py                     # Model/renderer configuration
    reward.py                     # Reward shaping
@@ -282,7 +324,6 @@ Note the scope of this repo is an open-source implementation of KernelBench-Tinker
 * More reward examples leveraging more fine-grained metrics
 * More reward hack checking
-* Multi-turn RL to have denser reward signal like [Kevin](https://arxiv.org/abs/2507.11948)
 * Improve Step time and training efficiency


src/kernelbench_tinker/config/configs.py

Lines changed: 54 additions & 0 deletions
@@ -81,3 +81,57 @@ class DatasetConfig:

    # Train/test split
    test_fraction: float = 0.1

@dataclass
class MultiTurnConfig:
    """
    Configuration for multi-turn RL training.

    Controls the iterative refinement loop where the model receives
    evaluation feedback and can fix errors across multiple turns.
    """

    # Enable multi-turn mode (False = single-turn)
    enabled: bool = False

    # Maximum refinement turns per trajectory
    max_turns: int = 4

    # Discount factor for multi-turn returns: R_t = S_t + gamma * R_{t+1}
    gamma: float = 0.4

    # Return aggregation mode: "sum" or "max"
    #   sum: R_t = Σ γ^(i-t) × S_i    (reward turns leading to many good kernels)
    #   max: R_t = max_i γ^(i-t) × S_i (reward turns leading to one great kernel)
    aggregation: str = "sum"

    # Stop the episode early when the kernel is correct.
    # Default False for training: the model needs post-correctness turns to
    # learn speedup optimization. Set True at eval time if desired.
    early_stop_on_correct: bool = False

    # Optional: require this speedup before early stopping
    speedup_threshold: float | None = None

    # Prompt
    prompt_max_tokens: int | None = None  # Token budget for history truncation (None = char fallback)
    inject_think_token: bool = False      # Append <think>\n to generation prompts

    # Generation
    temperature: float = 0.9
    top_p: float = 1.0
    seed: int | None = None

    # Response length extension mid-training (0 = disabled)
    max_tokens_extended: int = 22000
    max_tokens_extend_after_step: int = 30

    # Training
    loss_fn: str = "ppo"
    max_grad_norm: float = 0.05
    warmup_ratio: float = 0.03
    clip_epsilon_low: float = 0.2
    clip_epsilon_high: float = 0.28
    constant_length_norm: int = 16384
    num_substeps: int = 2
src/kernelbench_tinker/config/rl_kernelbench.yaml

Lines changed: 31 additions & 0 deletions
@@ -26,6 +26,33 @@ learning_rate: 0.000002  # 2e-6 as explicit float
max_tokens: 16384
temperature: 1.0

# =============================================================================
# Multi-turn Configuration (disabled by default)
# =============================================================================
multiturn:
  enabled: false                    # true to enable iterative refinement
  max_turns: 4                      # Maximum refinement turns per trajectory
  gamma: 0.4                        # Discount factor for multi-turn returns
  aggregation: "sum"                # "sum" (reward many good kernels) or "max" (reward one great kernel)
  early_stop_on_correct: false      # Stop episode when kernel passes all tests
  speedup_threshold: null           # Required speedup before early stopping (null = any correct)
  # Prompt
  prompt_max_tokens: null           # Token budget for history truncation (null = char fallback)
  inject_think_token: false         # Append <think>\n to generation prompts
  # Generation
  temperature: 0.9                  # Generation temperature
  top_p: 1.0                        # Nucleus sampling (1.0 = disabled)
  seed: null                        # Random seed for generation (null = random)
  max_tokens_extended: 22000        # Extend max_tokens mid-training (0 = disabled)
  max_tokens_extend_after_step: 30  # Step at which to switch
  # Training
  loss_fn: "ppo"                    # Loss function (single-turn uses top-level loss_fn)
  max_grad_norm: 0.05               # Gradient clipping (0.0 = disabled)
  warmup_ratio: 0.03                # Linear LR warmup fraction
  clip_epsilon_low: 0.2             # PPO clip lower bound
  clip_epsilon_high: 0.28           # PPO clip upper bound (Clip-Higher)
  constant_length_norm: 16384       # GRPO constant length normalization (0 = disabled)

# =============================================================================
# Training Configuration
# =============================================================================
@@ -57,6 +84,7 @@ dataset_builder:
  # Problem Selection
  # ---------------------------------------------------------------------------
  level: 1                    # KernelBench level (1, 2, 3, or 4)
  levels: null                # Train on multiple levels (e.g. [1, 2]); overrides level when set
  start_problem: null         # First problem ID (null = start from 1)
  end_problem: null           # Last problem ID (null = all problems)
  dataset_src: "huggingface"  # "huggingface" or "local"
@@ -107,6 +135,9 @@ dataset_builder:
  reward_correctness_weight: 0.3
  reward_speed_weight: 1.0
  reward_length_weight: 0.0
  reward_speed_max_reward: 10.0  # Cap on speed reward component (set high to uncap)
  reward_clip_min: null          # Lower bound on total reward (null = no clipping)
  reward_clip_max: null          # Upper bound on total reward (null = no clipping)
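Putting the reward fields above together, the shaping might look like the following sketch (illustrative only; the actual logic lives in `reward.py`, and the length term is omitted since `reward_length_weight` defaults to 0):

```python
def shaped_reward(correct, speedup,
                  correctness_weight=0.3, speed_weight=1.0,
                  speed_max_reward=10.0, clip_min=None, clip_max=None):
    # Correctness bonus plus a capped speed term (speed counts only when
    # the kernel is correct), with optional clipping of the total reward.
    reward = correctness_weight * float(correct)
    if correct:
        reward += speed_weight * min(speedup, speed_max_reward)
    if clip_min is not None:
        reward = max(reward, clip_min)
    if clip_max is not None:
        reward = min(reward, clip_max)
    return reward
```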

  # ---------------------------------------------------------------------------
  # Reward Hacking Detection (Static Checker)

src/kernelbench_tinker/envs/kernelbench_client.py

Lines changed: 35 additions & 0 deletions
@@ -33,11 +33,18 @@
    re.DOTALL | re.IGNORECASE
)

# Summary block pattern - reasoning summary inside <SUMMARY>...</SUMMARY>
SUMMARY_BLOCK_PATTERN = re.compile(
    r"<SUMMARY>(.*?)</SUMMARY>",
    re.DOTALL | re.IGNORECASE
)

@dataclass
class ParsedResponse:
    """Parsed model response with kernel blocks."""
    kernel: str       # Kernel code (from <KERNEL> block or extracted code block)
    cot_summary: str  # Reasoning summary (from <SUMMARY> block)
    raw: str          # Original raw response
    format_ok: bool   # Whether we successfully extracted kernel code

@@ -94,8 +101,15 @@ def parse_structured_response(text: str) -> ParsedResponse:
    # Check if we got valid kernel code
    format_ok = bool(kernel) and ("class ModelNew" in kernel or "def forward" in kernel)

    # Extract CoT summary from <SUMMARY> block
    cot_summary = ""
    summary_match = SUMMARY_BLOCK_PATTERN.search(text)
    if summary_match:
        cot_summary = summary_match.group(1).strip()

    return ParsedResponse(
        kernel=kernel,
        cot_summary=cot_summary,
        raw=raw,
        format_ok=format_ok,
    )
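The `<SUMMARY>` extraction above can be exercised in isolation (a minimal standalone reproduction of the pattern, not an import from this package):

```python
import re

# Same pattern as above: non-greedy body, DOTALL spans newlines,
# IGNORECASE accepts lowercase tags.
SUMMARY_BLOCK_PATTERN = re.compile(r"<SUMMARY>(.*?)</SUMMARY>", re.DOTALL | re.IGNORECASE)

def extract_summary(text: str) -> str:
    match = SUMMARY_BLOCK_PATTERN.search(text)
    return match.group(1).strip() if match else ""
```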
@@ -487,6 +501,7 @@ class KernelBenchProblem:
    prompt_gpu_name: str | None = None

    _prompt: str | None = field(default=None, repr=False)
    _base_prompt: str | None = field(default=None, repr=False)

    @property
    def prompt(self) -> str:
@@ -504,3 +519,23 @@ def prompt(self) -> str:
            )
        return self._prompt

    @property
    def base_prompt(self) -> str:
        """Get the zero-shot prompt (no examples) for refinement turns.

        In multi-turn training, the one-shot example is included only on the
        first turn. Subsequent turns use this stripped-down prompt to save
        context tokens.
        """
        if self._base_prompt is None:
            self._base_prompt = get_prompt_for_problem(
                self.level,
                self.problem_id,
                self.backend,
                option="zero_shot",
                dataset_src=self.dataset_src,
                precision=self.prompt_precision,
                include_hardware=self.prompt_include_hardware,
                gpu_name=self.prompt_gpu_name,
            )
        return self._base_prompt
