10 changes: 10 additions & 0 deletions src/heretic/config.py
@@ -215,6 +215,16 @@ class Settings(BaseSettings):
),
)

kl_tokens: int = Field(
default=1,
description=(
"Number of tokens to generate when computing KL divergence. "
"Higher values give a more robust quality signal at the cost of slower evaluation. "
"The KL divergence is averaged across all token positions. "
"Recommended: 1 (fastest, default), 3-5 (good tradeoff), >5 (diminishing returns)."
),
)
Comment on lines +218 to +226
Severity: medium

According to rule #8 of the repository style guide, when a new setting is added to config.py, it should also be added to config.default.toml. The new kl_tokens setting is missing from config.default.toml.

Please add the following to config.default.toml after winsorization_quantile:

# Number of tokens to generate when computing KL divergence.
# Higher values give a more robust quality signal at the cost of slower evaluation.
# The KL divergence is averaged across all token positions.
# Recommended: 1 (fastest, default), 3-5 (good tradeoff), >5 (diminishing returns).
kl_tokens = 1
References
  1. When new settings are added in config.py, they should also be added to config.default.toml, set to their default value and with their description as a comment. The order of settings in config.default.toml should match that in config.py. (link)


n_trials: int = Field(
default=200,
description="Number of abliteration trials to run during optimization.",
21 changes: 17 additions & 4 deletions src/heretic/evaluator.py
@@ -28,7 +28,12 @@ def __init__(self, settings: Settings, model: Model):
self.good_prompts = load_prompts(settings, settings.good_evaluation_prompts)
print(f"* [bold]{len(self.good_prompts)}[/] prompts loaded")

print("* Obtaining first-token probability distributions...")
kl_label = (
f"* Obtaining {settings.kl_tokens}-token probability distributions..."
if settings.kl_tokens > 1
else "* Obtaining first-token probability distributions..."
)
print(kl_label)
self.base_logprobs = model.get_logprobs_batched(self.good_prompts)

print()
@@ -93,7 +98,12 @@ def count_refusals(self) -> int:
return refusal_count

def get_score(self) -> tuple[tuple[float, float], float, int]:
print(" * Obtaining first-token probability distributions...")
kl_label = (
f" * Obtaining {self.settings.kl_tokens}-token probability distributions..."
if self.settings.kl_tokens > 1
else " * Obtaining first-token probability distributions..."
)
print(kl_label)

Severity: medium

The logic that builds the kl_label string duplicates the __init__ method (lines 31-36). To keep the two call sites from drifting apart, consider extracting it into a private helper method on the Evaluator class.
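
A minimal sketch of such a helper; the method name `_kl_label` and its `indent` parameter are illustrative, not part of the PR:

```python
def _kl_label(self, indent: str = "") -> str:
    # Build the progress message in one place so both call sites stay in sync.
    if self.settings.kl_tokens > 1:
        return f"{indent}* Obtaining {self.settings.kl_tokens}-token probability distributions..."
    return f"{indent}* Obtaining first-token probability distributions..."
```

Both call sites would then reduce to `print(self._kl_label())` in `__init__` and `print(self._kl_label(indent="  "))` in `get_score` (assuming `self.settings` is assigned before the first call).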

logprobs = self.model.get_logprobs_batched(self.good_prompts)
kl_divergence = F.kl_div(
logprobs,
@@ -107,8 +117,11 @@ def get_score(self) -> tuple[tuple[float, float], float, int]:
refusals = self.count_refusals()
print(f" * Refusals: [bold]{refusals}[/]/{len(self.bad_prompts)}")

kl_divergence_scale = self.settings.kl_divergence_scale
kl_divergence_target = self.settings.kl_divergence_target
# Scale thresholds by kl_tokens since multi-token KL produces
# proportionally larger absolute values.
kl_tokens = self.settings.kl_tokens
kl_divergence_scale = self.settings.kl_divergence_scale * kl_tokens
kl_divergence_target = self.settings.kl_divergence_target * kl_tokens

Severity: high

The KL divergence thresholds are scaled by kl_tokens. However, the implementation in model.py reshapes the logits so that F.kl_div with reduction="batchmean" computes the average KL divergence across all generated token positions, and an average is not proportional to the number of tokens. This scaling therefore appears incorrect and may cause the optimizer to behave in unintended ways; the thresholds should likely not be scaled.

As noted in model.py, batchmean already handles the averaging, so the resulting kl_divergence value should have a similar magnitude regardless of kl_tokens.

Suggested change
-        # Scale thresholds by kl_tokens since multi-token KL produces
-        # proportionally larger absolute values.
-        kl_tokens = self.settings.kl_tokens
-        kl_divergence_scale = self.settings.kl_divergence_scale * kl_tokens
-        kl_divergence_target = self.settings.kl_divergence_target * kl_tokens
+        kl_divergence_scale = self.settings.kl_divergence_scale
+        kl_divergence_target = self.settings.kl_divergence_target
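
A quick standalone check of that claim; the shapes below are made up for illustration:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_prompts, vocab = 4, 10

for n_tokens in (1, 3, 5):
    # Per-position log-probabilities for a base and a modified model,
    # flattened to (n_prompts * n_tokens, vocab) as in get_logprobs.
    base = F.log_softmax(torch.randn(n_prompts * n_tokens, vocab), dim=-1)
    other = F.log_softmax(torch.randn(n_prompts * n_tokens, vocab), dim=-1)

    # reduction="batchmean" divides the summed KL by the size of the first
    # dimension, so the result is a per-position average that does not
    # grow with n_tokens.
    kl = F.kl_div(other, base, reduction="batchmean", log_target=True)
    print(n_tokens, kl.item())
```

The printed values stay on the same order of magnitude as n_tokens grows, which is why the thresholds need no scaling.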


refusals_score = refusals / self.base_refusals

23 changes: 13 additions & 10 deletions src/heretic/model.py
@@ -656,25 +656,28 @@ def get_residuals_batched(self, prompts: list[Prompt]) -> Tensor:
# We work with logprobs rather than probabilities for numerical stability
# when computing the KL divergence.
def get_logprobs(self, prompts: list[Prompt]) -> Tensor:
# We only generate one token, and we return the (log) probability distributions
# over the vocabulary at that token position, for each prompt.
n_tokens = self.settings.kl_tokens

_, outputs = self.generate(
prompts,
max_new_tokens=1,
max_new_tokens=n_tokens,
output_scores=True,
return_dict_in_generate=True,
)

# This cast is valid because GenerateDecoderOnlyOutput is the return type
# of model.generate with return_dict_in_generate=True.
outputs = cast(GenerateDecoderOnlyOutput, outputs)
scores = cast(tuple[FloatTensor], outputs.scores)

if n_tokens == 1:
# Original single-token path: shape (prompt, vocab).
return F.log_softmax(scores[0], dim=-1)

# Logits for the first (only) generated token.
# This cast is valid because we passed output_scores=True above.
logits = cast(tuple[FloatTensor], outputs.scores)[0]
# Multi-token: stack all positions, reshape to (prompt * n_tokens, vocab)
# so KL div with batchmean naturally averages across all positions.

Severity: medium

This comment does not fully adhere to the style guide (rule #4), which requires comments to use correct grammar, start with a capital letter, and end with a period: the first line does not end with a period, and the second line is a continuation that starts with a lowercase letter. Please rephrase to follow the style guide.

Suggested change
-    # Multi-token: stack all positions, reshape to (prompt * n_tokens, vocab)
-    # so KL div with batchmean naturally averages across all positions.
+    # Multi-token: stack all positions and reshape to (prompt * n_tokens, vocab).
+    # This allows `F.kl_div` with `reduction="batchmean"` to average the KL divergence across all positions.
References
  1. Comments should start with a capital letter and end with a period. They should use correct grammar and spelling. (link)

all_logits = torch.stack(list(scores), dim=0) # (n_tokens, prompt, vocab)
all_logits = all_logits.permute(1, 0, 2).reshape(-1, all_logits.shape[-1])

# The returned tensor has shape (prompt, token).
return F.log_softmax(logits, dim=-1)
return F.log_softmax(all_logits, dim=-1)

def get_logprobs_batched(self, prompts: list[Prompt]) -> Tensor:
logprobs = []