
NVFP4 ONNX export fails when rotate=True (FWHT trace + post-processor) #1424

@Clemxxx

Description



Describe the bug

When rotate: True is set on *weight_quantizer in an NVFP4 quant config (e.g. NVFP4_FP8_MHA_CONFIG), ONNX export breaks in two related ways that surface together:

  1. Trace-side: normalized_hadamard_transform in modelopt/torch/quantization/nn/functional.py calls fast_hadamard_transform.hadamard_transform(), which is a custom CUDA op without an ONNX symbolic function. During torch.onnx.export, it gets evaluated eagerly — the rotated activation tensor is baked as a Constant in the exported graph, severing the runtime input flow. The compiled engine then produces output that's independent of its inputs.

  2. Post-processor: NVFP4QuantExporter (in modelopt/onnx/export/nvfp4_exporter.py) assumes weight tensors live in graph.initializer. With rotation enabled, they end up in upstream Constant nodes' value attribute (or further upstream as a Reshape/MatMul subgraph). Three sites — compute_scales (L219), compress_weights (L263), post_process (L369) — hit:

AssertionError: Initializer for weight '/.../weight_quantizer/Constant_output_0' not found.

Discovered while compiling FLUX.2 Klein 4B for TensorRT. Investigation done with Claude's assistance.

Steps/Code to reproduce bug

import modelopt.torch.quantization as mtq
import torch.onnx
import onnx
from modelopt.onnx.export.nvfp4_exporter import NVFP4QuantExporter

# `model` is the torch.nn.Module to quantize;
# `forward_loop` runs calibration batches through it.
cfg = mtq.NVFP4_FP8_MHA_CONFIG.copy()
cfg["quant_cfg"]["*weight_quantizer"]["rotate"] = True
mtq.quantize(model, cfg, forward_loop=...)

torch.onnx.export(model, ..., "model.onnx")
m = onnx.load("model.onnx")
NVFP4QuantExporter.process_model(m)
# AssertionError: Initializer for weight '/.../weight_quantizer/Constant_output_0' not found.

Expected behavior

rotate: True is documented as a supported flag in NVFP4 configs, so either:

  • Export should produce a usable engine (rotation works end-to-end through ONNX), or
  • The flag should fail clearly at config validation time if NVFP4-export + rotation isn't supported.

Workaround we used locally

Two user-side changes to unblock our pipeline:

  1. Trace fix: monkey-patched normalized_hadamard_transform to use torch.matmul(x, H) against an explicit pre-computed Hadamard matrix. Mathematically identical (the FWHT butterfly is just a fast factorization of x @ H), but pure PyTorch ops are ONNX-traceable. Slower at calibration time, no runtime cost (TRT folds the matmul against the constant H).
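A quick numpy check of that equivalence (sketch only; the actual patch uses `torch.matmul` so the op stays traceable — `hadamard`/`fwht` here are illustrative names):

```python
import numpy as np

def hadamard(n):
    """Sylvester construction of the n x n Hadamard matrix (n a power of two)."""
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H

def fwht(x):
    """Iterative fast Walsh-Hadamard transform (the butterfly the CUDA op runs)."""
    x = x.copy()
    n = x.shape[-1]
    h = 1
    while h < n:
        for i in range(0, n, h * 2):
            for j in range(i, i + h):
                a, b = x[..., j].copy(), x[..., j + h].copy()
                x[..., j], x[..., j + h] = a + b, a - b
        h *= 2
    return x

n = 8
x = np.random.default_rng(0).standard_normal((3, n))
H = hadamard(n)
# Normalized transform: butterfly output / sqrt(n) == x @ (H / sqrt(n))
assert np.allclose(fwht(x) / np.sqrt(n), x @ (H / np.sqrt(n)))
```

Because `H / sqrt(n)` is a plain constant tensor, the traced graph contains only `MatMul`, which TRT can fold and fuse.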

  2. Post-processor fix: added a _resolve_weight_tensor() helper to handle three weight-source patterns:

    • Direct initializer (existing path, unchanged — non-rotation users see no change)
    • Upstream Constant node value attribute
    • Upstream constant subgraph (Reshape/MatMul/Mul/Add/Sub/Div/Cast/Transpose) — evaluated to numpy at compile time

    Plus post_process promotes any remaining inline Constant ≥ 64 KB to a graph initializer so save_as_external_data can externalize them. Without this, large rotation H matrices stay inline and the proto exceeds protobuf's 2 GB serialization limit, blocking the TRT parser.

End-to-end this produces a working rotated NVFP4 engine for FLUX.2 Klein 4B on RTX 5090 + TensorRT 10.15.1.

Question for maintainers

What's the right fix shape upstream? A few options:

  • Replace the FWHT call in normalized_hadamard_transform with an ONNX-traceable equivalent (matmul against explicit H, or expressed as a Reshape/Add/Sub butterfly chain)?
  • Add an ONNX symbolic for FastHadamardTransform.apply() so the existing CUDA op exports correctly?
  • Accept user-side post-processor patches that handle non-initializer weight sources (the _resolve_weight_tensor approach)?
  • Mark rotation as PyTorch-inference-only and reject it at config validation for ONNX export paths?

Happy to PR whichever direction you prefer. We can share the matmul-Hadamard + post-processor patches if useful.

Related: #614

Who can help?

Anyone familiar with the NVFP4 ONNX export path or the rotation/Hadamard helpers in modelopt.torch.quantization.nn.functional.

System information

  • OS: Windows 11 (10.0.26100)
  • CPU architecture: x86_64
  • GPU: NVIDIA GeForce RTX 5090 (Blackwell SM_120)
  • GPU memory size: 32 GB
  • Number of GPUs: 1
  • Library versions:
    • Python: 3.12.10
    • ModelOpt: 0.42.0
    • CUDA: 13.0
    • PyTorch: 2.11.0+cu130
    • Transformers: 5.5.4
    • TensorRT: 10.15.1.29
  • Other: FLUX.2 Klein 4B (rectified flow distilled)
