"""
Shared utilities for KernelBench environments.

Contains helpers used by both the single-turn and multi-turn environments:
- System prompt construction
- Step evaluation (parse → evaluate → reward → metrics)
"""

from __future__ import annotations

import logging
import time
from dataclasses import dataclass
from tinker_cookbook import renderers
from tinker_cookbook.rl.types import Action, Metrics

from kernelbench_tinker.config.configs import EvalConfig
from kernelbench_tinker.envs.kernelbench_client import (
    KernelBenchProblem,
    KernelEvalResult,
    ParsedResponse,
    evaluate_kernel_async,
    parse_structured_response,
)
from kernelbench_tinker.training.reward import (
    RewardConfig,
    compute_reward,
)

logger = logging.getLogger(__name__)


@dataclass
class EvalStepResult:
    """Result from evaluate_step(), shared by single-turn and multi-turn envs."""

    parsed: ParsedResponse
    eval_result: KernelEvalResult
    format_ok: bool
    kernel_code: str
    reward: float
    metrics: Metrics
    response_text: str  # Raw response content from renderer (before structured parsing)


def build_system_prompt(backend: str) -> str:
    """Build a backend-specific system prompt for kernel generation.

    Used by both single-turn and multi-turn environments.
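
    Example::

        prompt = build_system_prompt("cuda")  # system prompt for CUDA kernel generation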
    """
    return (
        "You are an expert GPU kernel developer. Your task is to optimize PyTorch "
        f"operations by writing efficient custom {backend.upper()} kernels.\n"
        "\n"
        "When given a PyTorch model, write an optimized kernel implementation.\n"
        "\n"
        "Your solution must:\n"
        "- Be a drop-in replacement as a class named `ModelNew`\n"
        f"- Use custom {backend.upper()} kernels, not just PyTorch operations\n"
        "- Be correct and produce the same results as the reference\n"
        "\n"
        "You MUST respond in exactly this format:\n"
        "\n"
        "<KERNEL>\n"
        "```python\n"
        "# Your complete optimized implementation here\n"
        "class ModelNew(nn.Module):\n"
        "    ...\n"
        "```\n"
        "</KERNEL>"
    )


async def evaluate_step(
    problem: KernelBenchProblem,
    renderer: renderers.Renderer,
    action: Action,
    eval_config: EvalConfig,
    reward_config: RewardConfig,
    step_start: float,
) -> EvalStepResult:
    """Parse, evaluate, and compute reward for a single action.

    Shared by KernelBenchEnv.step() and MultiTurnKernelBenchEnv.step().
    """
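    # Decode the sampled action into a chat message via the renderer; the
    # second return value is unused here.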
    message, _ = renderer.parse_response(action)
    response_text = message.get("content", "")

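    # Extract the kernel source from the structured <KERNEL> block; format_ok
    # records whether the response followed the required output format.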
    parsed = parse_structured_response(response_text)
    kernel_code = parsed.kernel
    format_ok = parsed.format_ok

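    # Run compilation, correctness trials, and (optionally) performance timing
    # on the evaluation backend; cfg.modal_timeout bounds the remote call.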
    eval_start = time.perf_counter()
    cfg = eval_config
    eval_result = await evaluate_kernel_async(
        level=problem.level,
        problem_id=problem.problem_id,
        backend=problem.backend,
        kernel_code=kernel_code,
        dataset_src=problem.dataset_src,
        num_correct_trials=cfg.num_correct_trials,
        measure_performance=cfg.measure_performance,
        num_perf_trials=cfg.num_perf_trials,
        timing_method=cfg.timing_method,
        precision=cfg.precision,
        check_for_excessive_speedup=cfg.check_for_excessive_speedup,
        excessive_speedup_threshold=cfg.excessive_speedup_threshold,
        timeout=cfg.modal_timeout,
    )
    eval_time = time.perf_counter() - eval_start

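    # Convert the evaluation outcome into a scalar RL reward as configured by
    # reward_config.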
    reward = compute_reward(
        eval_result,
        reward_config,
        kernel_code=kernel_code,
        backend=problem.backend,
    )

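    # Per-step metrics for logging: format/compile/correctness flags, test
    # counts, and performance numbers when they were measured.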
    metrics: Metrics = {
        "level": problem.level,
        "problem_id": problem.problem_id,
        "format_ok": float(format_ok),
        "compiled": float(eval_result["compiled"]),
        "correctness": float(eval_result["correctness"]),
        "tests_passed": eval_result["tests_passed"],
        "tests_total": eval_result["tests_total"],
    }
    if eval_result.get("speedup") is not None:
        metrics["speedup"] = eval_result["speedup"]
    if eval_result.get("runtime_ms") is not None:
        metrics["runtime_ms"] = eval_result["runtime_ms"]
    metrics["time/eval"] = eval_time
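    # Surface server-side timing breakdowns when the eval backend reports them.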
    timing_metadata = (eval_result.get("metadata") or {}).get("timings", {})
    if "reference_load_s" in timing_metadata:
        metrics["time/ref_load"] = timing_metadata["reference_load_s"]
    if "modal_eval_s" in timing_metadata:
        metrics["time/modal_eval"] = timing_metadata["modal_eval_s"]
    metrics["time/step_total"] = time.perf_counter() - step_start

    return EvalStepResult(
        parsed=parsed,
        eval_result=eval_result,
        format_ok=format_ok,
        kernel_code=kernel_code,
        reward=reward,
        metrics=metrics,
        response_text=response_text,
    )