Is your feature request related to a problem? Please describe.
The fused operations for the grouped MLP block (see ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8 and BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8, first added in #2769) have significant CPU overhead. When I run a basic benchmark (on GB200 with 64 experts and hidden size 128), I find the forward pass takes ~1.2 ms and the backward pass takes ~2.1 ms.
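For context, host-side overhead of this kind can be isolated by timing only the CPU duration of each call (the launch path), separately from device execution. Below is a minimal sketch of that measurement pattern; `measure_cpu_overhead` and the lambda workload are stand-ins, since the actual fused op and a GB200 are not assumed available here (device-side time would additionally need CUDA events):

```python
import statistics
import time

def measure_cpu_overhead(fn, iters=100, warmup=10):
    """Median host-side duration of fn() per call, in microseconds.

    Captures CPU launch overhead only; fn here is a stand-in for the
    fused forward/backward call.
    """
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - t0)
    return statistics.median(samples) * 1e6

# Stand-in workload (hypothetical; replace with the fused op call).
median_us = measure_cpu_overhead(lambda: sum(range(1000)))
print(f"median host-side time: {median_us:.1f} us")
```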
Describe the solution you'd like
Based on profiling, here are some rough estimates for some slow sections:

- tex.get_device_pointer_for_data_and_scales: >100 us, with ~50 us before entering nvte_multi_tensor_swizzle_scaling_factors
- clear_tensor_data: 200 us
As we make optimizations, we should also adapt them to the unfused grouped linear op, and generally consider cleanups and refactors.
Describe alternatives you've considered
Additional context