27 changes: 26 additions & 1 deletion src/heretic/model.py
@@ -192,6 +192,14 @@ def _apply_lora(self):
# so the result is a PeftModel rather than a PeftMixedModel.
self.model = cast(PeftModel, get_peft_model(self.model, self.peft_config))

# FP8 dtypes (e.g. float8_e4m3fn) are not supported by standard torch.addmm,
# which nn.Linear uses internally. Models distributed in FP8 (e.g. MiniMax-M2.5)
# will cause LoRA forward passes to fail. Cast LoRA adapter weights to bfloat16
# so that the adapter matmuls use a supported dtype.
for name, param in self.model.named_parameters():
    if "lora_" in name and param.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        param.data = param.data.to(torch.bfloat16)

print(f"* LoRA adapters initialized (targets: {', '.join(target_modules)})")

def _get_quantization_config(self, dtype: str) -> BitsAndBytesConfig | None:
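
For context, a minimal standalone sketch of the failure mode the cast above works around (illustrative only, not part of this PR; assumes a PyTorch build that ships the float8 dtypes): `torch.addmm`, which `nn.Linear` uses internally, has no FP8 kernel, so the matmul fails until the tensors are cast to a supported dtype.

```python
import torch

# Illustrative only: a matmul on float8_e4m3fn tensors raises at runtime
# because PyTorch ships no FP8 addmm/mm kernels for ordinary linear layers.
x = torch.randn(2, 8).to(torch.float8_e4m3fn)
w = torch.randn(4, 8).to(torch.float8_e4m3fn)
try:
    torch.nn.functional.linear(x, w)
except RuntimeError as e:
    print(f"FP8 matmul rejected: {e}")

# Casting to bfloat16 gives the matmul a supported dtype; the loop above
# does the same for the LoRA adapter weights.
y = torch.nn.functional.linear(x.to(torch.bfloat16), w.to(torch.bfloat16))
print(y.dtype)  # torch.bfloat16
```
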
@@ -257,9 +265,26 @@ def get_merged_model(self) -> PreTrainedModel:
merged_model = peft_model.merge_and_unload()
return merged_model
else:
# Non-quantized model - can merge directly.
# FP8 base weights don't support in-place addition (+=) needed by merge,
# so upcast them to bfloat16 first, merge, then cast back.
fp8_params = {}
for name, module in self.model.named_modules():
    if hasattr(module, "weight") and not isinstance(module, Linear):
Contributor (critical):

The condition `not isinstance(module, Linear)` prevents the logic from running on LoRA-wrapped layers. These are precisely the layers that need their base weights upcast, because `merge_and_unload()` performs an in-place addition on them, which fails for FP8 dtypes. The `weight` property on a `peft.tuners.lora.layer.Linear` module correctly delegates to the base layer's weight, so these modules should be processed. Removing this part of the condition will fix the issue and allow the merge to succeed with FP8 models.

Suggested change
-    if hasattr(module, "weight") and not isinstance(module, Linear):
+    if hasattr(module, "weight"):
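
To make the reviewer's point concrete, here is a hedged toy sketch (not from this PR; the `Toy` module and the `proj` target name are invented for illustration, relying on peft's documented custom-model support): the `weight` property of a LoRA-wrapped `Linear` delegates to the frozen base layer's weight, so dropping the `isinstance` check lets the upcast loop reach the tensors that `merge_and_unload()` mutates.

```python
import torch.nn as nn
from peft import LoraConfig, get_peft_model
from peft.tuners.lora.layer import Linear as LoraLinear

# Toy module standing in for the real model; "proj" is a made-up target name.
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)

peft_model = get_peft_model(Toy(), LoraConfig(target_modules=["proj"], r=4))
for name, module in peft_model.named_modules():
    if isinstance(module, LoraLinear):
        # The wrapper's .weight is the base layer's weight, not a LoRA matrix,
        # so reading module.weight reaches the tensor the merge updates in place.
        print(name, module.weight is module.get_base_layer().weight)  # True
```
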

        w = module.weight
        if w.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
            fp8_params[name] = w.dtype
            module.weight.data = w.data.to(torch.bfloat16)

print("* Merging LoRA adapters into base model...")
merged_model = self.model.merge_and_unload()

# Cast merged weights back to their original FP8 dtype.
if fp8_params:
    for name, module in merged_model.named_modules():
        if name in fp8_params and hasattr(module, "weight"):
            module.weight.data = module.weight.data.to(fp8_params[name])

# merge_and_unload() modifies self.model in-place, destroying LoRA adapters.
# Mark for full reload if user switches trials later.
self.needs_reload = True
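
For reference, a standalone sketch of the round trip described above (illustrative only, not part of the PR; the tensors stand in for an FP8 base weight and a bfloat16 LoRA delta): the in-place add fails on FP8, while upcasting, adding, and casting back succeeds and restores the original dtype.

```python
import torch

# Illustrative only: in-place addition into an FP8 tensor, which the merge
# effectively performs on each base weight, fails because PyTorch provides
# no elementwise float8 kernels or promotion rules for this combination.
base = torch.randn(4, 4).to(torch.float8_e4m3fn)   # stands in for an FP8 base weight
delta = torch.randn(4, 4, dtype=torch.bfloat16)    # stands in for the LoRA delta
try:
    base += delta
except (RuntimeError, TypeError) as e:
    # The exact error type and message depend on the PyTorch build.
    print(f"in-place add rejected: {e}")

# The workaround: upcast, add, cast back, the same round trip as get_merged_model().
merged = (base.to(torch.bfloat16) + delta).to(torch.float8_e4m3fn)
print(merged.dtype)  # torch.float8_e4m3fn
```
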