Skip to content

Commit 8c8eb43

Browse files
Gu, Yonghao and TaoLv
authored and committed
graph: backend: dnnl: support int8 sdpa for softmax
1 parent 49ceeaf commit 8c8eb43

5 files changed

Lines changed: 71 additions & 5 deletions

File tree

src/graph/backend/dnnl/kernels/sdp_decomp.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,14 @@ status_t sdp_decomp_kernel_t<quantized, dt>::compile_impl(
7777
BACKEND_DNNL_ADD_PASS(pipeline, fuse_post_ops);
7878
BACKEND_DNNL_ADD_PASS(pipeline, insert_permute_for_matmul);
7979
if (quantized) {
80+
BACKEND_DNNL_ADD_PASS(pipeline, remove_quant_data_with_no_effect);
8081
BACKEND_DNNL_ADD_PASS(pipeline, convert_to_runtime_dst_scales);
8182
BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_scales);
8283
BACKEND_DNNL_ADD_PASS(pipeline, convert_to_runtime_dst_zero_points);
8384
BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_zero_points);
84-
BACKEND_DNNL_ADD_PASS(pipeline, remove_quant_data_with_no_effect);
85+
// fuse those new post-binaries converted from add_zps and mul_scales
86+
BACKEND_DNNL_ADD_PASS(pipeline, replace_quant_data_with_binary_post_op);
87+
BACKEND_DNNL_ADD_PASS(pipeline, fuse_post_ops);
8588
}
8689
pipeline.reset_visualize_arg(true, false);
8790
BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_predecessor);

src/graph/backend/dnnl/kernels/sdp_decomp.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,7 @@ struct sdp_decomp_kernel_t : public kernel_base_t {
118118
mem_map[ori_mem.get()][tid]
119119
= memory(ori_mem.get_desc(),
120120
ori_mem.get_engine(), nullptr);
121-
if (iter.first >= DNNL_ARG_ATTR_SCALES
122-
&& iter.first <= DNNL_ARG_ATTR_POST_OP_DW) {
121+
if (iter.first >= DNNL_ARG_ATTR_SCALES) {
123122
mem_map[ori_mem.get()][tid].set_data_handle(
124123
ori_mem.get_data_handle());
125124
}

src/graph/backend/dnnl/kernels/sdp_decomp_config.cpp

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ impl::status_t sdp_decomp_config_t::construct_params(
135135
sub_mm1_wei_md, sub_mm1_dst_md, sub_softmax_dst_md,
136136
sub_wei2_user_md, sub_mm2_wei_md, sub_mm2_dst_md, sub_dst_md,
137137
sub_dst_user_md, sub_select_cond_md, sub_select_src0_md;
138-
std::vector<memory::desc> sub_mm1_post_md;
138+
std::vector<memory::desc> sub_mm1_post_md, sub_softmax_post_md;
139139

140140
// must use user mode to support concurrent execution
141141
primitive_attr sub_reorder0_attr;
@@ -229,6 +229,25 @@ impl::status_t sdp_decomp_config_t::construct_params(
229229
// softmax
230230
// create softmax primitive attr
231231
dnnl::primitive_attr sub_softmax_attr = make_primitive_attr(sdp_op[2], mgr);
232+
233+
dnnl_pops = {};
234+
ori_dnnl_pops = sub_softmax_attr.get_post_ops();
235+
for (int i = 0; i < ori_dnnl_pops.get()->len(); i++) {
236+
const auto alg = static_cast<algorithm>(
237+
ori_dnnl_pops.get()->entry_[i].binary.alg);
238+
const dnnl::impl::memory_desc_t &ori_desc
239+
= ori_dnnl_pops.get()->entry_[i].binary.user_src1_desc;
240+
auto post_shape = ori_desc.dims;
241+
auto post_stride = ori_desc.format_desc.blocking.strides;
242+
auto post_dt = static_cast<memory::data_type>(ori_desc.data_type);
243+
dims post_stride_dims = dims(post_stride, post_stride + ori_desc.ndims);
244+
auto new_sub_md = memory::desc({1, 1, post_shape[2], post_shape[3]},
245+
post_dt, post_stride_dims);
246+
sub_softmax_post_md.emplace_back(new_sub_md);
247+
dnnl_pops.append_binary(alg, new_sub_md);
248+
}
249+
sub_softmax_attr.set_post_ops(dnnl_pops);
250+
232251
sub_softmax_dst_md = memory::desc(sub_mm1_dst_dims, dt_src_user, tag::abcd);
233252
const auto mode = sdp_op[2]->get_attr<std::string>(op_attr::mode);
234253
const dnnl::algorithm algo = mode == "inf_as_zero"
@@ -337,6 +356,23 @@ impl::status_t sdp_decomp_config_t::construct_params(
337356
}
338357
// softmax
339358
sub_softmax_dst = memory(sub_softmax_dst_md, p_engine, nullptr);
359+
for (int i = 0; i < (int)sub_softmax_post_md.size(); i++) {
360+
sub_softmax_post_mem.emplace_back(sub_softmax_post_md[i], p_engine);
361+
auto alg = static_cast<algorithm>(
362+
ori_dnnl_pops.get()->entry_[i].binary.alg);
363+
if (alg == dnnl::algorithm::binary_mul) {
364+
float *ptr = reinterpret_cast<float *>(
365+
sub_softmax_post_mem[i].get_data_handle());
366+
ptr[0] = get_attr_value<float, float>(
367+
sdp_op[2], i + 1, op_attr::scales);
368+
}
369+
if (alg == dnnl::algorithm::binary_add) {
370+
int *ptr = reinterpret_cast<int *>(
371+
sub_softmax_post_mem[i].get_data_handle());
372+
ptr[0] = get_attr_value<int64_t, int32_t>(
373+
sdp_op[2], i + 1, op_attr::zps);
374+
}
375+
}
340376
// reorder2
341377
sub_wei2_user = memory(sub_wei2_user_md, p_engine, nullptr);
342378
// mm2
@@ -372,6 +408,12 @@ impl::status_t sdp_decomp_config_t::construct_params(
372408
{DNNL_ARG_DST, sub_softmax_dst},
373409
{DNNL_ARG_SCRATCHPAD, sub_scratchpad}};
374410

411+
for (int i = 0; i < (int)sub_softmax_post_mem.size(); i++) {
412+
sub_softmax_args.insert(
413+
{DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_SRC_1,
414+
sub_softmax_post_mem[i]});
415+
}
416+
375417
sub_reorder2_args = {{DNNL_ARG_SRC, sub_wei2_user},
376418
{DNNL_ARG_DST, sub_mm2_wei}, {DNNL_ARG_SCRATCHPAD, sub_scratchpad}};
377419

src/graph/backend/dnnl/kernels/sdp_decomp_config.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ struct sdp_decomp_config_t {
109109
//mm1
110110
memory sub_mm1_src, sub_mm1_wei, sub_mm1_dst;
111111
// sub_mm1_post_mem contains [post_scale, attn_mask(optional)]
112-
std::vector<memory> sub_mm1_post_mem;
112+
std::vector<memory> sub_mm1_post_mem, sub_softmax_post_mem;
113113
//select binary
114114
memory sub_select_cond, sub_select_src0, sub_select_dst;
115115
//softmax

src/graph/backend/dnnl/passes/transform.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1207,6 +1207,17 @@ status_t fuse_dst_scales(std::shared_ptr<subgraph_t> &sg) {
12071207
if (consumers.size() != 1) continue;
12081208
auto &next_op = consumers[0].get_op();
12091209
if (next_op.get_kind() != op_kind::dnnl_mul_scales) continue;
1210+
// For these three ops, the dst zps are not supported
1211+
if (impl::utils::one_of(cur_op->get_kind(), op_kind::dnnl_softmax,
1212+
op_kind::dnnl_layernorm, op_kind::dnnl_groupnorm)) {
1213+
out_val = next_op.get_output_value(0);
1214+
consumers = out_val->get_consumers();
1215+
if (consumers.size() == 1) {
1216+
auto &next2_op = consumers[0].get_op();
1217+
if (next2_op.get_kind() == op_kind::dnnl_add_zps) continue;
1218+
}
1219+
}
1220+
12101221
fuse_groups.emplace_back(cur_op.get(), &next_op);
12111222
visited.insert(cur_op.get());
12121223
visited.insert(&next_op);
@@ -1249,6 +1260,17 @@ status_t convert_to_runtime_dst_scales(std::shared_ptr<subgraph_t> &sg) {
12491260
|| visited.count(cur_op.get()))
12501261
continue;
12511262

1263+
if (impl::utils::one_of(cur_op->get_input_op(0)->get_kind(),
1264+
op_kind::dnnl_softmax, op_kind::dnnl_layernorm,
1265+
op_kind::dnnl_groupnorm)) {
1266+
auto out_val = cur_op->get_output_value(0);
1267+
auto consumers = out_val->get_consumers();
1268+
if (consumers.size() == 1) {
1269+
auto &next_op = consumers[0].get_op();
1270+
if (next_op.get_kind() == op_kind::dnnl_add_zps) continue;
1271+
}
1272+
}
1273+
12521274
// This pass only handle static quantization
12531275
bool dync_quantization = cur_op->has_attr(op_attr::with_runtime_scales)
12541276
&& cur_op->get_attr<bool>(op_attr::with_runtime_scales);

0 commit comments

Comments (0)