
aarch64: support for per_dim_0 scales and bf16 dst_dt in jit int8 matmul #4987

Open

michalowski-arm wants to merge 3 commits into uxlfoundation:main from michalowski-arm:jitint8@final

Conversation

@michalowski-arm (Contributor) commented Apr 9, 2026

Description

This change adds AArch64 jit:int8 matmul support for row-wise source scales (src:per_dim_0) and bf16 destination output.

The original motivation for this change is to improve W8A8 serving performance for workloads such as Llama and Whisper in vLLM. In these models, activations are symmetrically quantized with per-token scales, which map to matmul src:per_dim_0 scales (the matmul M dimension).

Before this change, the AArch64 jit:int8 matmul path did not fully support this case, so vLLM had to run an additional epilogue to apply the activation scales and convert the result to bf16. With this PR, both the scale application and the bf16 destination handling stay inside the oneDNN matmul path, removing that epilogue and the associated memory traffic and output-side work.
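
For context, a minimal sketch of how a caller such as vLLM could request this fused configuration through the public oneDNN C++ API (not taken from this PR; the shapes, format tags, and helper name are illustrative):

// Minimal sketch, assuming the oneDNN v3.x C++ API; all names and shapes
// here are illustrative, not from this PR.
#include "oneapi/dnnl/dnnl.hpp"
using namespace dnnl;

matmul make_w8a8_matmul(const engine &eng, memory::dim M, memory::dim K,
        memory::dim N) {
    memory::desc src_md({M, K}, memory::data_type::s8, memory::format_tag::ab);
    memory::desc wei_md({K, N}, memory::data_type::s8, memory::format_tag::ab);
    // bf16 destination handled by the matmul itself: no separate convert pass.
    memory::desc dst_md({M, N}, memory::data_type::bf16, memory::format_tag::ab);

    primitive_attr attr;
    // Per-token (per-row) activation scales: mask bit 0 selects dim 0 (M)
    // of src, i.e. src:per_dim_0.
    attr.set_scales_mask(DNNL_ARG_SRC, 1 << 0);
    // Per-output-channel weight scales, as is common in W8A8 schemes.
    attr.set_scales_mask(DNNL_ARG_WEIGHTS, 1 << 1);

    matmul::primitive_desc pd(eng, src_md, wei_md, dst_md, attr);
    return matmul(pd);
}

The runtime scale buffers would then be supplied at execution time via the DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC and DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS memory arguments.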

In vLLM testing, this improved output-token throughput by roughly 5–10%, depending on the model.

Checklist

  • Do all unit and benchdnn tests (make test and make test_benchdnn_*) pass locally for each commit?
  • Have you formatted the code using clang-format?

@michalowski-arm requested review from a team as code owners, April 9, 2026 10:38
The github-actions bot added the platform:cpu-aarch64 (Codeowner: @oneapi-src/onednn-cpu-aarch64), component:tests (Codeowner: @oneapi-src/onednn-arch), and component:common labels, Apr 9, 2026
@dzarukin (Contributor) left a comment

Please share the motivation for extending the functional capabilities and confirm readiness to enable this uniformly across the stack. Further guidance will follow based on those answers. Thank you.

Comment thread src/common/matmul.cpp
Comment thread tests/benchdnn/matmul/matmul.cpp Outdated
res->reason = skip_reason::case_not_supported;
return;
}
}
Contributor:

The preferred way of doing this is to enable the reference implementation (shared between backends) first, so that multiple parties don't each have to handle a case that can be supported at a single point in a single backend.

Contributor (Author):

That makes sense. I can rework it to use ref_matmul_int8 as the generic fallback instead. If I'm not mistaken, the support is already there and the case is only rejected during pd creation. Would that work?

Contributor:

Yes, it will.
Besides the parts already touched, the shared code also includes the matmul_pd::attr_scales_ok() function and the ref_matmul.{c,h}pp files.

Contributor (Author):

Let me know if the approach I took with the latest commit works.

Comment thread src/common/matmul_pd.hpp Outdated

             const std::vector<int> &supported_qmodes
-                    = {quantization_mode::static_sazp}) const {
+                    = {quantization_mode::static_sazp},
+            bool allow_src_per_m = false) const {
Contributor:

It seems to me the signature of this function is not well suited to extending mask support in just one or two specific matmul implementations, and it should be extended further.

If you need to wrap up this activity sooner, I suggest expanding the implementation-local conditions to let the new mask in.
If you are up for going the extra mile, I think the idea below should work for the future:

// This function covers common ground across ALL implementations. Masks that
// are supported only by specific implementations can be passed through
// `extra_masks`, which takes elements as pairs of `{arg, {1, 5, ...}}` that
// are checked in addition to the common ones.
virtual bool attr_scales_ok(const std::vector<int> &supported_args
                = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST},
        const std::vector<int> &supported_qmodes
                = {quantization_mode::static_sazp},
        const std::map<int, std::vector<int>> &extra_masks = {}) const {
    ...
    // Masks supported in all implementations.
    bool mask_ok = utils::one_of(mask, 0, src_qmask_K(),
            src_qmask_M() + src_qmask_K(), full_tensor_mask());
    // If the mask wasn't found, check the extra masks coming from the impl.
    if (!mask_ok) {
        const auto it = extra_masks.find(arg);
        if (it != extra_masks.end()) {
            for (const int em : it->second)
                if (mask == em) mask_ok = true;
        }
    }
    ...
Having it this way should avoid the need to update other backends, leaving their testing results and support capabilities unaffected.
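
For illustration, a hypothetical call site in the AArch64 jit:int8 matmul pd could then admit the per-M src mask without widening the common check (the argument lists and src_qmask_M() follow the sketch above):

// Hypothetical: allow per_dim_0 (per-M) src scales only in this
// implementation, on top of the masks common to all backends.
const bool ok = attr_scales_ok(
        {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST},
        {quantization_mode::static_sazp},
        {{DNNL_ARG_SRC, {src_qmask_M()}}});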

Comment thread tests/benchdnn/dnn_types.cpp Outdated
Comment thread tests/benchdnn/inputs/matmul/test_matmul_ci Outdated

Labels

component:common · component:tests (Codeowner: @oneapi-src/onednn-arch) · platform:cpu-aarch64 (Codeowner: @oneapi-src/onednn-cpu-aarch64)
