diff --git a/config.default.toml b/config.default.toml
index abfa0fc7..13c144d4 100644
--- a/config.default.toml
+++ b/config.default.toml
@@ -62,6 +62,12 @@ kl_divergence_scale = 1.0
 # This helps prevent the sampler from extensively exploring parameter combinations that "do nothing".
 kl_divergence_target = 0.01
 
+# Number of tokens to generate when computing KL divergence.
+# Higher values give a more robust quality signal at the cost of slower evaluation.
+# The KL divergence is averaged across all token positions.
+# Recommended: 1 (fastest, default), 3-5 (good tradeoff), >5 (diminishing returns).
+kl_tokens = 1
+
 # Whether to adjust the refusal directions so that only the component that is
 # orthogonal to the good direction is subtracted during abliteration.
 orthogonalize_direction = false
diff --git a/src/heretic/config.py b/src/heretic/config.py
index 8ed3f80c..0b94352f 100644
--- a/src/heretic/config.py
+++ b/src/heretic/config.py
@@ -215,6 +215,16 @@ class Settings(BaseSettings):
         ),
     )
 
+    kl_tokens: int = Field(
+        default=1,
+        description=(
+            "Number of tokens to generate when computing KL divergence. "
+            "Higher values give a more robust quality signal at the cost of slower evaluation. "
+            "The KL divergence is averaged across all token positions. "
+            "Recommended: 1 (fastest, default), 3-5 (good tradeoff), >5 (diminishing returns)."
+        ),
+    )
+
     n_trials: int = Field(
         default=200,
         description="Number of abliteration trials to run during optimization.",
diff --git a/src/heretic/evaluator.py b/src/heretic/evaluator.py
index f2a8a258..642da390 100644
--- a/src/heretic/evaluator.py
+++ b/src/heretic/evaluator.py
@@ -28,7 +28,7 @@ def __init__(self, settings: Settings, model: Model):
         self.good_prompts = load_prompts(settings, settings.good_evaluation_prompts)
         print(f"* [bold]{len(self.good_prompts)}[/] prompts loaded")
 
-        print("* Obtaining first-token probability distributions...")
+        print(f"* {self._kl_label()}")
         self.base_logprobs = model.get_logprobs_batched(self.good_prompts)
 
         print()
@@ -92,8 +92,14 @@ def count_refusals(self) -> int:
 
         return refusal_count
 
+    def _kl_label(self) -> str:
+        """Return a human-readable label for the KL computation step."""
+        if self.settings.kl_tokens > 1:
+            return f"Obtaining {self.settings.kl_tokens}-token probability distributions..."
+        return "Obtaining first-token probability distributions..."
+
     def get_score(self) -> tuple[tuple[float, float], float, int]:
-        print(" * Obtaining first-token probability distributions...")
+        print(f" * {self._kl_label()}")
         logprobs = self.model.get_logprobs_batched(self.good_prompts)
         kl_divergence = F.kl_div(
             logprobs,
diff --git a/src/heretic/model.py b/src/heretic/model.py
index 58300b16..cfe03e22 100644
--- a/src/heretic/model.py
+++ b/src/heretic/model.py
@@ -656,25 +656,28 @@ def get_residuals_batched(self, prompts: list[Prompt]) -> Tensor:
     # We work with logprobs rather than probabilities for numerical stability
     # when computing the KL divergence.
     def get_logprobs(self, prompts: list[Prompt]) -> Tensor:
-        # We only generate one token, and we return the (log) probability distributions
-        # over the vocabulary at that token position, for each prompt.
+        n_tokens = self.settings.kl_tokens
+
         _, outputs = self.generate(
             prompts,
-            max_new_tokens=1,
+            max_new_tokens=n_tokens,
             output_scores=True,
             return_dict_in_generate=True,
         )
 
-        # This cast is valid because GenerateDecoderOnlyOutput is the return type
-        # of model.generate with return_dict_in_generate=True.
         outputs = cast(GenerateDecoderOnlyOutput, outputs)
+        scores = cast(tuple[FloatTensor], outputs.scores)
+
+        if n_tokens == 1:
+            # Original single-token path: shape (prompt, vocab).
+            return F.log_softmax(scores[0], dim=-1)
 
-        # Logits for the first (only) generated token.
-        # This cast is valid because we passed output_scores=True above.
-        logits = cast(tuple[FloatTensor], outputs.scores)[0]
+        # Multi-token: stack all positions and reshape to (prompt * n_tokens, vocab).
+        # This allows F.kl_div with reduction="batchmean" to average across all positions.
+        all_logits = torch.stack(list(scores), dim=0)  # (n_tokens, prompt, vocab)
+        all_logits = all_logits.permute(1, 0, 2).reshape(-1, all_logits.shape[-1])
 
-        # The returned tensor has shape (prompt, token).
-        return F.log_softmax(logits, dim=-1)
+        return F.log_softmax(all_logits, dim=-1)
 
     def get_logprobs_batched(self, prompts: list[Prompt]) -> Tensor:
         logprobs = []
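
Not part of the diff: a minimal standalone sketch of why the `(prompt * n_tokens, vocab)` reshape makes `reduction="batchmean"` average the KL divergence over prompts and token positions alike. The helper name `flatten_logprobs` and the tensor sizes are illustrative, and the sketch assumes the evaluator passes log-probabilities as the target with `log_target=True`; only `reduction="batchmean"` is confirmed by the patch itself.

```python
# Standalone sketch (not from the PR): checks that flattening per-position
# scores to (prompt * n_tokens, vocab) makes reduction="batchmean" equal to
# the mean per-position KL divergence. Names and sizes are illustrative.
import torch
import torch.nn.functional as F

n_prompts, n_tokens, vocab = 4, 3, 10

# Simulate `outputs.scores`: one (prompt, vocab) logits tensor per generated
# token position, as returned by generate(output_scores=True).
base_scores = tuple(torch.randn(n_prompts, vocab) for _ in range(n_tokens))
trial_scores = tuple(torch.randn(n_prompts, vocab) for _ in range(n_tokens))

def flatten_logprobs(scores: tuple[torch.Tensor, ...]) -> torch.Tensor:
    # Same transform as the patched get_logprobs():
    # (n_tokens, prompt, vocab) -> (prompt, n_tokens, vocab) -> (prompt * n_tokens, vocab)
    stacked = torch.stack(list(scores), dim=0)
    flat = stacked.permute(1, 0, 2).reshape(-1, stacked.shape[-1])
    return F.log_softmax(flat, dim=-1)

base = flatten_logprobs(base_scores)    # reference model (target)
trial = flatten_logprobs(trial_scores)  # abliterated candidate (input)

# "batchmean" divides the summed KL by the first dimension (prompt * n_tokens),
# so it is exactly the average KL across prompts AND token positions.
kl = F.kl_div(trial, base, reduction="batchmean", log_target=True)

# Cross-check against an explicit per-row mean.
per_row = F.kl_div(trial, base, reduction="none", log_target=True).sum(dim=-1)
assert torch.allclose(kl, per_row.mean())
```

Because both models are flattened with the same permute order, row i of `trial` and row i of `base` refer to the same (prompt, position) pair, which is what makes the row-wise comparison meaningful.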