Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 21 additions & 19 deletions nemo_rl/algorithms/loss/loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,12 +785,12 @@ def __call__(
}


class DPOLossConfig(TypedDict):
reference_policy_kl_penalty: float
preference_loss_weight: float
sft_loss_weight: float
preference_average_log_probs: bool
sft_average_log_probs: bool
class DPOLossConfig(BaseModel, extra="allow"):
reference_policy_kl_penalty: float = 0.05
preference_loss_weight: float = 1.0
sft_loss_weight: float = 0.0
preference_average_log_probs: bool = False
sft_average_log_probs: bool = False


class DPOLossDataDict(TypedDict):
Expand Down Expand Up @@ -861,12 +861,14 @@ class DPOLossFn(PreferenceLossFn):
loss_type = LossType.SEQUENCE_LEVEL
input_type = LossInputType.LOGPROB

def __init__(self, cfg: DPOLossConfig, use_linear_ce_fusion: bool = False):
self.reference_policy_kl_penalty = cfg["reference_policy_kl_penalty"]
self.preference_loss_weight = cfg["preference_loss_weight"]
self.sft_loss_weight = cfg["sft_loss_weight"]
self.preference_average_log_probs = cfg["preference_average_log_probs"]
self.sft_average_log_probs = cfg["sft_average_log_probs"]
def __init__(self, cfg: DPOLossConfig | dict, use_linear_ce_fusion: bool = False):
if isinstance(cfg, dict):
cfg = DPOLossConfig(**cfg)
Comment on lines +864 to +866
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's not use this tricky way and fix the places that fail because of this.

Suggested change
def __init__(self, cfg: DPOLossConfig | dict, use_linear_ce_fusion: bool = False):
if isinstance(cfg, dict):
cfg = DPOLossConfig(**cfg)
def __init__(self, cfg: DPOLossConfig, use_linear_ce_fusion: bool = False):

Copy link
Copy Markdown
Contributor Author

@NolenLiang NolenLiang May 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On this branch dpo.py is not modified — DPOLossFn(master_config.dpo) at dpo.py:270 still passes a plain dict (DPOConfig is still a TypedDict here). Removing the guard now would break the L1 functional test. Will remove it once DPO PR #2524 merges.

self.reference_policy_kl_penalty = cfg.reference_policy_kl_penalty
self.preference_loss_weight = cfg.preference_loss_weight
self.sft_loss_weight = cfg.sft_loss_weight
self.preference_average_log_probs = cfg.preference_average_log_probs
self.sft_average_log_probs = cfg.sft_average_log_probs
self.use_linear_ce_fusion = use_linear_ce_fusion
self.sft_loss = NLLLossFn(use_linear_ce_fusion=use_linear_ce_fusion)

Expand Down Expand Up @@ -945,10 +947,10 @@ def __call__( # type: ignore
}


class DistillationLossConfig(TypedDict):
kl_type: str
mixed_kl_weight: float
zero_outside_topk: bool
class DistillationLossConfig(BaseModel, extra="allow"):
kl_type: str = "mixed"
mixed_kl_weight: float = 0.5
zero_outside_topk: bool = False


class DistillationLossDataDict(TypedDict):
Expand All @@ -967,9 +969,9 @@ class DistillationLossFn(LossFunction):
input_type = LossInputType.DISTILLATION

def __init__(self, cfg: DistillationLossConfig):
self.kl_type = cfg["kl_type"]
self.mixed_kl_weight = cfg["mixed_kl_weight"]
self.zero_outside_topk = cfg["zero_outside_topk"]
self.kl_type = cfg.kl_type
self.mixed_kl_weight = cfg.mixed_kl_weight
self.zero_outside_topk = cfg.zero_outside_topk
self.log_infinitesimal = -100

assert self.kl_type in ["forward", "reverse", "mixed"], "Invalid KL type"
Expand Down
12 changes: 6 additions & 6 deletions tests/unit/algorithms/test_distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
distillation_train,
validate,
)
from nemo_rl.algorithms.loss import DistillationLossFn
from nemo_rl.algorithms.loss import DistillationLossConfig, DistillationLossFn
from nemo_rl.data.interfaces import DatumSpec
from nemo_rl.distributed.batched_data_dict import BatchedDataDict

Expand Down Expand Up @@ -107,11 +107,11 @@ def val_iter(self):
tokenizer.pad_token_id = 0

loss_fn = DistillationLossFn(
{
"kl_type": "forward",
"mixed_kl_weight": 0.5,
"zero_outside_topk": False,
}
DistillationLossConfig(
kl_type="forward",
mixed_kl_weight=0.5,
zero_outside_topk=False,
)
)

logger = MagicMock()
Expand Down
81 changes: 41 additions & 40 deletions tests/unit/algorithms/test_loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from nemo_rl.algorithms.loss import (
ClippedPGLossConfig,
ClippedPGLossFn,
DistillationLossConfig,
DistillationLossFn,
DPOLossFn,
NLLLossFn,
Expand Down Expand Up @@ -1804,11 +1805,11 @@ def test_distillation_loss_different_settings(kl_type, zero_outside_topk):
data, student_logits = setup_distillation_test_data()

loss_fn = DistillationLossFn(
{
"kl_type": kl_type,
"mixed_kl_weight": 0.3,
"zero_outside_topk": zero_outside_topk,
}
DistillationLossConfig(
kl_type=kl_type,
mixed_kl_weight=0.3,
zero_outside_topk=zero_outside_topk,
)
)

loss_input, data = prepare_loss_input(student_logits, data, loss_fn)
Expand Down Expand Up @@ -1849,11 +1850,11 @@ def test_distillation_loss_topk_filtering(k, zero_outside_topk):
data, student_logits = setup_distillation_test_data(topk=k)

loss_fn = DistillationLossFn(
{
"kl_type": "forward",
"mixed_kl_weight": 0.5,
"zero_outside_topk": zero_outside_topk,
}
DistillationLossConfig(
kl_type="forward",
mixed_kl_weight=0.5,
zero_outside_topk=zero_outside_topk,
)
)

loss_input, data = prepare_loss_input(student_logits, data, loss_fn)
Expand Down Expand Up @@ -1887,11 +1888,11 @@ def test_distillation_loss_invalid_k_zero():
data, student_logits = setup_distillation_test_data(topk=0)

loss_fn = DistillationLossFn(
{
"kl_type": "forward",
"mixed_kl_weight": 0.5,
"zero_outside_topk": False,
}
DistillationLossConfig(
kl_type="forward",
mixed_kl_weight=0.5,
zero_outside_topk=False,
)
)

# This should raise a ValueError for k=0
Expand All @@ -1907,11 +1908,11 @@ def test_distillation_loss_gradient_flow():
student_logits.requires_grad_(True)

loss_fn = DistillationLossFn(
{
"kl_type": "forward",
"mixed_kl_weight": 0.5,
"zero_outside_topk": False,
}
DistillationLossConfig(
kl_type="forward",
mixed_kl_weight=0.5,
zero_outside_topk=False,
)
)

loss_input, data = prepare_loss_input(student_logits, data, loss_fn)
Expand Down Expand Up @@ -1939,11 +1940,11 @@ def test_distillation_loss_edge_cases():
data, student_logits = setup_distillation_test_data()

loss_fn = DistillationLossFn(
{
"kl_type": "forward",
"mixed_kl_weight": 0.5,
"zero_outside_topk": False,
}
DistillationLossConfig(
kl_type="forward",
mixed_kl_weight=0.5,
zero_outside_topk=False,
)
)

# Test with all-zero logits
Expand Down Expand Up @@ -1992,22 +1993,22 @@ def test_distillation_loss_edge_cases():
def test_distillation_loss_fn_initialization():
"""Test DistillationLossFn initialization."""
# Test with default values
default_config = {
"kl_type": "forward",
"mixed_kl_weight": 0.5,
"zero_outside_topk": False,
}
default_config = DistillationLossConfig(
kl_type="forward",
mixed_kl_weight=0.5,
zero_outside_topk=False,
)
loss_fn = DistillationLossFn(default_config)
assert loss_fn.kl_type == "forward"
assert loss_fn.mixed_kl_weight == 0.5
assert not loss_fn.zero_outside_topk

# Test with custom values
custom_config = {
"kl_type": "reverse",
"mixed_kl_weight": 0.3,
"zero_outside_topk": True,
}
custom_config = DistillationLossConfig(
kl_type="reverse",
mixed_kl_weight=0.3,
zero_outside_topk=True,
)
loss_fn = DistillationLossFn(custom_config)
assert loss_fn.kl_type == "reverse"
assert loss_fn.mixed_kl_weight == 0.3
Expand All @@ -2019,11 +2020,11 @@ def test_distillation_loss_fn_call():
data, student_logits = setup_distillation_test_data()

loss_fn = DistillationLossFn(
{
"kl_type": "forward",
"mixed_kl_weight": 0.5,
"zero_outside_topk": False,
}
DistillationLossConfig(
kl_type="forward",
mixed_kl_weight=0.5,
zero_outside_topk=False,
)
)

loss_input, data = prepare_loss_input(student_logits, data, loss_fn)
Expand Down
Loading