[PyTorch] Reduce CPU overhead in grouped MLP block #2897

@timmoon10

Description

Is your feature request related to a problem? Please describe.

The fused operations for the grouped MLP block (see ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8 and BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8, first added in #2769) have significant CPU overhead. In a basic benchmark (on GB200 with 64 experts and a hidden size of 128), the forward pass takes ~1.2 ms and the backward pass takes ~2.1 ms.
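For reference, per-iteration CPU time can be measured with a simple wall-clock harness like the sketch below. This is a generic, stdlib-only stand-in (the `fn` here is a trivial workload, not the actual fused op); for the real benchmark the device would need to be synchronized around the timed region so that host-side launch overhead is what gets captured.

```python
import time

def mean_cpu_time_us(fn, iters=100, warmup=10):
    """Return the mean wall-clock time per call of fn, in microseconds.

    For GPU workloads, synchronize the device before and after the
    timed loop (e.g. torch.cuda.synchronize()) so the measurement
    reflects CPU-side launch overhead plus kernel time.
    """
    for _ in range(warmup):  # warm up caches, JIT, allocators
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    end = time.perf_counter()
    return (end - start) / iters * 1e6

# Trivial stand-in workload for illustration:
overhead = mean_cpu_time_us(lambda: sum(range(1000)))
```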

Describe the solution you'd like

Based on profiling, here are rough estimates for the slowest sections:

  • wgrad tensor allocations: 350 us
  • tex.get_device_pointer_for_data_and_scales: >100 us, with ~50 us before entering nvte_multi_tensor_swizzle_scaling_factors
  • clear_tensor_data: 200 us
  • Initializing weight quantizers in the forward pass: 90 us
  • Tensor reshapes before and after cuDNN kernels: ~50 us
  • cuDNN group GEMM kernels: ~150 us
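The largest item, wgrad tensor allocations, could likely be reduced by reusing buffers across iterations instead of allocating fresh tensors every step. A hedged, stdlib-only sketch of such a cache keyed by shape and dtype size (in practice the buffers would be torch tensors and the cache would live on the module):

```python
class BufferCache:
    """Reuse allocations across iterations, keyed by (shape, dtype size).

    Stand-in for caching wgrad tensors: the first request for a given
    key pays the allocation cost; subsequent requests return the same
    buffer, avoiding per-step allocator overhead.
    """

    def __init__(self):
        self._buffers = {}

    def get(self, shape, dtype_size):
        key = (tuple(shape), dtype_size)
        if key not in self._buffers:
            nbytes = dtype_size
            for dim in shape:
                nbytes *= dim
            self._buffers[key] = bytearray(nbytes)  # allocate once
        return self._buffers[key]

cache = BufferCache()
a = cache.get((64, 128), 2)  # first call allocates
b = cache.get((64, 128), 2)  # second call reuses the same buffer
```

A real implementation would also need to handle buffer lifetimes carefully (e.g. not reusing a buffer while a kernel still reads it), which is why this is only a sketch of the caching idea.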

As we make optimizations, we should also adapt them to the unfused grouped linear op, and generally consider cleanups and refactors.

Describe alternatives you've considered

Additional context
