feat: multi-token KL divergence for more robust quality measurement #209
base: master
Changes from 1 commit
```diff
@@ -28,7 +28,12 @@ def __init__(self, settings: Settings, model: Model):
         self.good_prompts = load_prompts(settings, settings.good_evaluation_prompts)
         print(f"* [bold]{len(self.good_prompts)}[/] prompts loaded")

-        print("* Obtaining first-token probability distributions...")
+        kl_label = (
+            f"* Obtaining {settings.kl_tokens}-token probability distributions..."
+            if settings.kl_tokens > 1
+            else "* Obtaining first-token probability distributions..."
+        )
+        print(kl_label)
         self.base_logprobs = model.get_logprobs_batched(self.good_prompts)

         print()
```
```diff
@@ -93,7 +98,12 @@ def count_refusals(self) -> int:
         return refusal_count

     def get_score(self) -> tuple[tuple[float, float], float, int]:
-        print(" * Obtaining first-token probability distributions...")
+        kl_label = (
+            f" * Obtaining {self.settings.kl_tokens}-token probability distributions..."
+            if self.settings.kl_tokens > 1
+            else " * Obtaining first-token probability distributions..."
+        )
+        print(kl_label)
         logprobs = self.model.get_logprobs_batched(self.good_prompts)
         kl_divergence = F.kl_div(
             logprobs,
```
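The `F.kl_div` call in the hunk above is truncated in the diff, but the surrounding code makes clear that both arguments are log-probability tensors. A pure-Python sketch of the underlying math, mirroring the semantics of PyTorch's `F.kl_div(input=logq, target=logp, log_target=True)` (the helper names here are illustrative, not from the codebase):

```python
import math

def log_softmax(logits: list[float]) -> list[float]:
    # Shift by the max before exponentiating, so large logits cannot overflow.
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def kl_div_logspace(logq: list[float], logp: list[float]) -> float:
    # KL(P || Q) computed entirely from log-probabilities:
    # sum over the vocabulary of exp(logp) * (logp - logq).
    return sum(math.exp(lp) * (lp - lq) for lq, lp in zip(logq, logp))

base = log_softmax([2.0, 1.0, 0.5])   # Stand-in for the base model's logprobs.
tuned = log_softmax([1.5, 1.2, 0.8])  # Stand-in for the modified model's logprobs.

print(kl_div_logspace(tuned, base))  # Positive; 0.0 only when the two match.
```

Working in log space avoids ever materializing probabilities that could underflow, which is why the model code returns logprobs rather than probabilities.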
```diff
@@ -107,8 +117,11 @@ def get_score(self) -> tuple[tuple[float, float], float, int]:
         refusals = self.count_refusals()
         print(f" * Refusals: [bold]{refusals}[/]/{len(self.bad_prompts)}")

-        kl_divergence_scale = self.settings.kl_divergence_scale
-        kl_divergence_target = self.settings.kl_divergence_target
+        # Scale thresholds by kl_tokens since multi-token KL produces
+        # proportionally larger absolute values.
+        kl_tokens = self.settings.kl_tokens
+        kl_divergence_scale = self.settings.kl_divergence_scale * kl_tokens
+        kl_divergence_target = self.settings.kl_divergence_target * kl_tokens

         refusals_score = refusals / self.base_refusals
```

Contributor:
The KL divergence thresholds are being scaled by `kl_tokens`. As noted in …

Suggested change: …
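The reviewer's comment above is truncated, but there is a tension worth noting: the multi-token path reshapes logprobs to `(prompt * n_tokens, vocab)` and, per the code comment in the model diff below, relies on `batchmean` reduction to average across positions. Under `batchmean`, the KL value is divided by the number of rows, so it does not grow proportionally with `kl_tokens`, which would make the threshold scaling above a change in behavior rather than a compensation. A pure-Python sketch of `reduction="batchmean"` semantics (a toy stand-in, not the project's code):

```python
import math

def kl_batchmean(logq_rows: list[list[float]], logp_rows: list[list[float]]) -> float:
    # Sum of pointwise exp(logp) * (logp - logq) over every row and element,
    # divided by the number of rows -- what reduction="batchmean" computes.
    total = sum(
        math.exp(lp) * (lp - lq)
        for logq, logp in zip(logq_rows, logp_rows)
        for lq, lp in zip(logq, logp)
    )
    return total / len(logq_rows)

# Two prompts, one token position each (rows are log-probability distributions).
q = [[math.log(0.7), math.log(0.3)], [math.log(0.6), math.log(0.4)]]
p = [[math.log(0.5), math.log(0.5)], [math.log(0.9), math.log(0.1)]]

one_token = kl_batchmean(q, p)             # 2 rows: one position per prompt.
three_tokens = kl_batchmean(q * 3, p * 3)  # 6 rows: three positions per prompt.
```

Because the per-row divergences are averaged, replicating positions leaves the result unchanged, so the absolute magnitude stays comparable to the single-token case.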
```diff
@@ -656,25 +656,28 @@ def get_residuals_batched(self, prompts: list[Prompt]) -> Tensor:
     # We work with logprobs rather than probabilities for numerical stability
     # when computing the KL divergence.
     def get_logprobs(self, prompts: list[Prompt]) -> Tensor:
-        # We only generate one token, and we return the (log) probability distributions
-        # over the vocabulary at that token position, for each prompt.
+        n_tokens = self.settings.kl_tokens

         _, outputs = self.generate(
             prompts,
-            max_new_tokens=1,
+            max_new_tokens=n_tokens,
             output_scores=True,
             return_dict_in_generate=True,
         )

+        # This cast is valid because GenerateDecoderOnlyOutput is the return type
+        # of model.generate with return_dict_in_generate=True.
+        outputs = cast(GenerateDecoderOnlyOutput, outputs)
+        scores = cast(tuple[FloatTensor], outputs.scores)
+
+        if n_tokens == 1:
+            # Original single-token path: shape (prompt, vocab).
+            return F.log_softmax(scores[0], dim=-1)

-        # Logits for the first (only) generated token.
-        # This cast is valid because we passed output_scores=True above.
-        logits = cast(tuple[FloatTensor], outputs.scores)[0]
+        # Multi-token: stack all positions, reshape to (prompt * n_tokens, vocab)
+        # so KL div with batchmean naturally averages across all positions.
+        all_logits = torch.stack(list(scores), dim=0)  # (n_tokens, prompt, vocab)
+        all_logits = all_logits.permute(1, 0, 2).reshape(-1, all_logits.shape[-1])

-        # The returned tensor has shape (prompt, token).
-        return F.log_softmax(logits, dim=-1)
+        return F.log_softmax(all_logits, dim=-1)

     def get_logprobs_batched(self, prompts: list[Prompt]) -> Tensor:
         logprobs = []
```

Contributor (on the multi-token comment):
This comment does not fully adhere to the style guide (rule #4), which requires comments to use correct grammar, start with a capital letter, and end with a period. The second line is a continuation of the first and starts with a lowercase letter. Please rephrase for clarity and to follow the style guide.

Suggested change: …
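The `stack`/`permute`/`reshape` sequence in `get_logprobs` can be illustrated without torch. In this sketch each tensor is modeled as a nested list, with toy values that encode their own `(token, prompt, vocab)` indices so the row ordering after flattening is easy to check:

```python
# scores: a tuple of n_tokens tensors, each of shape (prompt, vocab).
n_tokens, n_prompts, vocab = 2, 3, 4
scores = [
    [[(t, p, v) for v in range(vocab)] for p in range(n_prompts)]
    for t in range(n_tokens)
]

# torch.stack(list(scores), dim=0) -> (n_tokens, prompt, vocab); `scores` already is.
# .permute(1, 0, 2) swaps the first two axes -> (prompt, n_tokens, vocab).
permuted = [[scores[t][p] for t in range(n_tokens)] for p in range(n_prompts)]

# .reshape(-1, vocab) flattens the leading two axes -> (prompt * n_tokens, vocab).
flat = [row for prompt_block in permuted for row in prompt_block]
```

The permute before the reshape is what groups rows prompt-major: all token positions of prompt 0 come first, then all positions of prompt 1, and so on, so each prompt's rows line up against the corresponding rows of the stored base distribution.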
Contributor:
According to rule #8 of the repository style guide, when a new setting is added to `config.py`, it should also be added to `config.default.toml`. The new `kl_tokens` setting is missing from `config.default.toml`. Please add the following to `config.default.toml` after `winsorization_quantile`:

References:
"If new settings are added to `config.py`, they should also be added to `config.default.toml`, set to their default value and with their description as a comment. The order of settings in `config.default.toml` should match that in `config.py`." (link)
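The reviewer's suggested snippet did not survive extraction. A plausible sketch of such an entry, in which the default value of `1` (matching the single-token path the code falls back to) and the comment wording are assumptions, not the reviewer's actual text:

```toml
# Number of generated tokens over which the KL divergence is measured.
# Values above 1 compare probability distributions at multiple positions.
kl_tokens = 1
```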