
Commit af4384f

feat: add NemotronH hybrid model support with multi-GPU VRAM calibration
NemotronH (Mamba2 SSM + MoE + Attention) requires several changes to load and abliterate correctly on multi-GPU systems.

Architecture support (model.py):
- Add backbone.layers fallback in get_layers() for NemotronH's model.backbone.layers structure (see the sketch after this message)
- Add get_layer_modules() patterns for NemotronH's unified mixer attribute: mixer.out_proj (Mamba2), mixer.o_proj (attention), mixer.down_proj / mixer.experts[*].down_proj / mixer.shared_experts.down_proj (MoE)
- Scan all layers in get_abliterable_components() instead of only layer 0, to discover the full union of component types in hybrid architectures
- Add _get_hidden_states_via_hooks() fallback for models that don't return hidden_states through generate() (NemotronH returns a tuple of Nones); use forward hooks on each layer with device-aware stacking for multi-GPU compatibility (sketched below)
- Skip meta-device and NaN-weight modules in abliterate() to prevent NaN corruption when layers are CPU-offloaded by Accelerate
- Add _has_mamba_layers() to detect hybrid SSM architectures

Multi-GPU VRAM calibration (model.py):
- After inference warmup on multi-GPU systems, check whether any GPU has less than 6 GiB free; if so, release the model, measure actual free VRAM per GPU, and reload once with corrected per-GPU caps (sketched below)
- Overloaded GPUs get a 0.7 correction factor for Accelerate's layer-size underestimation; other GPUs get their full budget to absorb displaced layers; gated to hybrid SSM models via _has_mamba_layers() so regular transformers are unaffected

User experience:
- Show a trust_remote_code explanation with a link to the model repo before prompting, replacing the bare HuggingFace error message
- Auto-install mamba-ssm when required, with clear nvcc/CUDA-toolkit guidance on build failure
- Suggest installing causal-conv1d and mamba-ssm after loading any model with Mamba layers when fast kernels are missing

Other fixes:
- Sum VRAM across all GPUs in print_memory_usage() (utils.py)
- Show total and per-GPU VRAM in startup output (main.py)
- Fix division by zero in the evaluator when base_refusals is 0
- Add a mamba optional dependency group to pyproject.toml
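The architecture-support changes hinge on two small helpers. Below is a minimal sketch of what the get_layers() fallback and _has_mamba_layers() might look like; the names come from the commit message, but the bodies are illustrative assumptions, not the commit's actual code:

    import torch.nn as nn

    def get_layers(model: nn.Module) -> nn.ModuleList:
        # Regular causal LMs expose decoder layers at model.model.layers;
        # NemotronH nests them under model.backbone.layers instead.
        inner = getattr(model, "model", None)
        if inner is not None and hasattr(inner, "layers"):
            return inner.layers
        backbone = getattr(model, "backbone", None)
        if backbone is not None and hasattr(backbone, "layers"):
            return backbone.layers
        raise AttributeError("could not locate decoder layers on this model")

    def _has_mamba_layers(model: nn.Module) -> bool:
        # Hybrid-SSM detection: any submodule whose class name mentions
        # "mamba" (e.g. NemotronH's Mamba2 mixers) marks the model as hybrid.
        return any("mamba" in type(m).__name__.lower() for m in model.modules())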
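The hook-based hidden-state fallback is the most involved piece. A sketch of the approach described above, assuming layers are already discoverable via get_layers() and inputs is a dict of tensors: the hook registration follows PyTorch's register_forward_hook API, but the function name and everything else here are illustrative:

    import torch

    def get_hidden_states_via_hooks(model, layers, inputs) -> torch.Tensor:
        # Capture each layer's output with a forward hook, for models whose
        # generate() does not return hidden_states (NemotronH yields a
        # tuple of Nones).
        captured: list[torch.Tensor | None] = [None] * len(layers)
        handles = []

        def make_hook(index: int):
            def hook(module, args, output):
                # Layer outputs may be a tensor or a tuple; the hidden
                # state is the tensor (or the first tuple element).
                hidden = output[0] if isinstance(output, tuple) else output
                captured[index] = hidden.detach()
            return hook

        for i, layer in enumerate(layers):
            handles.append(layer.register_forward_hook(make_hook(i)))
        try:
            with torch.no_grad():
                model(**inputs)  # a plain forward pass triggers every hook
        finally:
            for handle in handles:
                handle.remove()

        # Device-aware stacking: with an Accelerate device_map the captures
        # live on different GPUs, so move them to one device before stacking.
        target = captured[0].device
        return torch.stack([h.to(target) for h in captured])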
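The VRAM recalibration can likewise be sketched as two phases, matching the commit's description: detect overloaded GPUs while the model is still resident, then derive per-GPU caps from the VRAM measured after the model is released. The 6 GiB threshold and 0.7 factor come from the commit message; the function names and structure are illustrative:

    import torch

    GIB = 1024 ** 3

    def find_overloaded_gpus(min_free_gib: float = 6.0) -> set[int]:
        # Phase 1: with the model still loaded after warmup, flag every
        # GPU with less than the threshold of free VRAM.
        return {
            d for d in range(torch.cuda.device_count())
            if torch.cuda.mem_get_info(d)[0] < min_free_gib * GIB
        }

    def corrected_max_memory(overloaded: set[int],
                             correction: float = 0.7) -> dict[int, int]:
        # Phase 2: after releasing the model, measure the VRAM actually
        # available per GPU and derive the reload caps.
        caps: dict[int, int] = {}
        for d in range(torch.cuda.device_count()):
            free, _total = torch.cuda.mem_get_info(d)
            if d in overloaded:
                # Compensate for Accelerate's layer-size underestimation.
                caps[d] = int(free * correction)
            else:
                # Full budget, so this GPU can absorb displaced layers.
                caps[d] = free
        return caps

If find_overloaded_gpus() is non-empty, the model would be released and reloaded once, with the resulting dict passed as the max_memory argument that transformers' from_pretrained() forwards to Accelerate's device-map planner.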
1 parent: 4c80c4b

4 files changed: 2868 additions & 2522 deletions


src/heretic/evaluator.py

Lines changed: 3 additions & 1 deletion
@@ -110,7 +110,9 @@ def get_score(self) -> tuple[tuple[float, float], float, int]:
         kl_divergence_scale = self.settings.kl_divergence_scale
         kl_divergence_target = self.settings.kl_divergence_target
 
-        refusals_score = refusals / self.base_refusals
+        refusals_score = (
+            refusals / self.base_refusals if self.base_refusals > 0 else 0.0
+        )
 
         if kl_divergence >= kl_divergence_target:
             kld_score = kl_divergence / kl_divergence_scale
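Design note: when base_refusals is 0 there is no baseline refusal behavior to normalize against, so the ratio is undefined; defaulting the score to 0.0 sidesteps the ZeroDivisionError and leaves the KL-divergence term to drive the overall score.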

src/heretic/main.py

Lines changed: 12 additions & 4 deletions
@@ -54,14 +54,22 @@
 )
 
 
-def obtain_merge_strategy(settings: Settings) -> str | None:
+def obtain_merge_strategy(settings: Settings, model: Model) -> str | None:
     """
     Prompts the user for how to proceed with saving the model.
     Provides info to the user if the model is quantized on memory use.
     Returns "merge", "adapter", or None (if cancelled/invalid).
     """
 
-    if settings.quantization == QuantizationMethod.BNB_4BIT:
+    # Also detect pre-quantized models (FP8, MXFP4, etc.) via their built-in
+    # quantization_config, which HuggingFace stores in the model's config.json.
+    pre_quantized = (
+        getattr(model.model.config, "quantization_config", None) is not None
+        and settings.quantization == QuantizationMethod.NONE
+    )
+    is_quantized = settings.quantization == QuantizationMethod.BNB_4BIT or pre_quantized
+
+    if is_quantized:
         print()
         print(
             "Model was loaded with quantization. Merging requires reloading the base model."
@@ -753,7 +761,7 @@ def count_completed_trials() -> int:
         if not save_directory:
             continue
 
-        strategy = obtain_merge_strategy(settings)
+        strategy = obtain_merge_strategy(settings, model)
         if strategy is None:
             continue
 
@@ -802,7 +810,7 @@ def count_completed_trials() -> int:
         )
         private = visibility == "Private"
 
-        strategy = obtain_merge_strategy(settings)
+        strategy = obtain_merge_strategy(settings, model)
         if strategy is None:
             continue
 
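Design note: gating the quantization_config check on QuantizationMethod.NONE keeps the two detection paths disjoint: a run where the user explicitly requested BNB 4-bit quantization is caught by the first condition, while the config check only fires for checkpoints that ship pre-quantized weights (FP8, MXFP4, etc.) and were loaded without the tool applying any quantization of its own.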
