aarch64: support for per_dim_0 scales and bf16 dst_dt in jit int8 matmul#4987
michalowski-arm wants to merge 3 commits into uxlfoundation:main
Conversation
dzarukin left a comment:
Please share the motivation for extending functional capabilities and confirm readiness to enable it uniformly across the stack. Further guidance will follow based on those answers. Thank you.
            res->reason = skip_reason::case_not_supported;
            return;
        }
    }
The preferred way of doing things is to enable the reference implementation (shared between backends), which avoids multiple parties having to handle a case that would otherwise be supported only in a single point of a single backend.
That makes sense. I can rework it to use ref_matmul_int8 as the generic fallback instead. If I'm not mistaken, the support is already there and the case is only rejected during pd creation. Would that work?
Yes, it will.
Besides the parts already touched, the shared code also includes the matmul_pd::attr_scales_ok() function and the ref_matmul.{c,h}pp files.
Let me know if the approach I took with the latest commit works.
7056882 to 3cc0836
      const std::vector<int> &supported_qmodes
-             = {quantization_mode::static_sazp}) const {
+             = {quantization_mode::static_sazp},
+             bool allow_src_per_m = false) const {
It seems to me that the signature of this function is not well suited to extending mask support in only one or two specific matmul implementations, and it should be extended further.
If you need to wrap up this activity sooner, I suggest expanding the implementation-local conditions to let the new mask in.
If you are up for going the extra mile, I think the idea below should suit the future:
// This function covers common ground across ALL implementations. Masks that are
// supported only by specific implementations can be passed through `extra_masks`,
// which takes elements as pairs of `{arg, {1, 5, ...}}` that are additionally checked.
virtual bool attr_scales_ok(const std::vector<int> &supported_args
        = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST},
        const std::vector<int> &supported_qmodes
        = {quantization_mode::static_sazp},
        const std::map<int, std::vector<int>> &extra_masks = {}) {
    ...
    // Masks supported in all implementations.
    bool mask_ok = utils::one_of(mask, 0, src_qmask_K(),
            src_qmask_M() + src_qmask_K(),
            full_tensor_mask());
    // If the passed mask wasn't found, check the extra masks coming from the impl.
    if (!mask_ok) {
        if (extra_masks.find(arg) != extra_masks.end()) {
            for (const auto &em : extra_masks.at(arg))
                if (mask == em) mask_ok = true;
        }
    }
    ...
Having it this way should avoid having to update other backends, leaving their testing results and support capabilities unaffected.
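For illustration, a hypothetical call site in an implementation-specific pd::init() could look like the sketch below. The helper names (src_qmask_M(), quantization_mode::static_sazp) are taken from the snippet above and assumed to be in scope; this is a sketch rather than actual PR code.

// Hypothetical call site in an implementation-specific pd::init(): on top of the
// masks common to all implementations, this impl additionally allows a per-M
// source scales mask via `extra_masks`.
const std::map<int, std::vector<int>> extra_masks
        = {{DNNL_ARG_SRC, {src_qmask_M()}}};
if (!attr_scales_ok({DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST},
            {quantization_mode::static_sazp}, extra_masks))
    return status::unimplemented;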
3cc0836 to 0915e90
Description
This change adds AArch64 jit:int8 matmul support for row-wise source scales (src:per_dim_0) and bf16 destination output.
The original motivation for this change is to improve W8A8 serving performance for workloads such as Llama and Whisper in vLLM. In these models, activations are symmetrically quantized with per-token scales, which map to matmul src:per_dim_0 scales (the matmul M dimension).
Before this change, the AArch64 jit:int8 matmul path did not fully support this case, so vLLM had to run an additional epilogue to apply the activation scales and convert the result to bf16. With this PR, the scale application and bf16 destination handling stay inside the oneDNN matmul path, removing that extra epilogue and the associated memory traffic and output-side work.
In vLLM testing, this improved output-token throughput by roughly 5–10%, depending on the model.
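For context, here is a minimal sketch (not from the PR) of the case this description targets, expressed through the oneDNN C++ API: an s8 x s8 -> bf16 matmul with row-wise (per-token) source scales. The shapes and the per_oc weight scales are illustrative assumptions.

// Minimal sketch: s8 x s8 -> bf16 matmul with per-row (per-token) source scales.
// Shapes and the per_oc weight scales are illustrative assumptions, not PR code.
#include "oneapi/dnnl/dnnl.hpp"
using namespace dnnl;

int main() {
    engine eng(engine::kind::cpu, 0);

    const memory::dim M = 64, K = 512, N = 256;
    memory::desc src_md({M, K}, memory::data_type::s8, memory::format_tag::ab);
    memory::desc wei_md({K, N}, memory::data_type::s8, memory::format_tag::ab);
    memory::desc dst_md({M, N}, memory::data_type::bf16, memory::format_tag::ab);

    primitive_attr attr;
    // Row-wise (per-token) source scales: mask bit 0 selects the M dimension,
    // i.e. src:per_dim_0 in benchdnn terms.
    attr.set_scales_mask(DNNL_ARG_SRC, 1 << 0);
    // Per-channel weight scales over N (wei:per_oc).
    attr.set_scales_mask(DNNL_ARG_WEIGHTS, 1 << 1);

    matmul::primitive_desc pd(eng, src_md, wei_md, dst_md, attr);
    matmul prim(pd);
    // At execution time the scales are passed as DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC
    // (an M-element f32 memory) and DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS.
    return 0;
}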
Checklist
Do all unit and benchdnn tests (make test and make test_benchdnn_*) pass locally for each commit?