@@ -135,7 +135,7 @@ impl::status_t sdp_decomp_config_t::construct_params(
135135 sub_mm1_wei_md, sub_mm1_dst_md, sub_softmax_dst_md,
136136 sub_wei2_user_md, sub_mm2_wei_md, sub_mm2_dst_md, sub_dst_md,
137137 sub_dst_user_md, sub_select_cond_md, sub_select_src0_md;
138- std::vector<memory::desc> sub_mm1_post_md;
138+ std::vector<memory::desc> sub_mm1_post_md, sub_softmax_post_md ;
139139
140140 // must use user mode to support concurrent execution
141141 primitive_attr sub_reorder0_attr;
@@ -229,6 +229,25 @@ impl::status_t sdp_decomp_config_t::construct_params(
229229 // softmax
230230 // create softmax primitive attr
231231 dnnl::primitive_attr sub_softmax_attr = make_primitive_attr (sdp_op[2 ], mgr);
232+
233+ dnnl_pops = {};
234+ ori_dnnl_pops = sub_softmax_attr.get_post_ops ();
235+ for (int i = 0 ; i < ori_dnnl_pops.get ()->len (); i++) {
236+ const auto alg = static_cast <algorithm>(
237+ ori_dnnl_pops.get ()->entry_ [i].binary .alg );
238+ const dnnl::impl::memory_desc_t &ori_desc
239+ = ori_dnnl_pops.get ()->entry_ [i].binary .user_src1_desc ;
240+ auto post_shape = ori_desc.dims ;
241+ auto post_stride = ori_desc.format_desc .blocking .strides ;
242+ auto post_dt = static_cast <memory::data_type>(ori_desc.data_type );
243+ dims post_stride_dims = dims (post_stride, post_stride + ori_desc.ndims );
244+ auto new_sub_md = memory::desc ({1 , 1 , post_shape[2 ], post_shape[3 ]},
245+ post_dt, post_stride_dims);
246+ sub_softmax_post_md.emplace_back (new_sub_md);
247+ dnnl_pops.append_binary (alg, new_sub_md);
248+ }
249+ sub_softmax_attr.set_post_ops (dnnl_pops);
250+
232251 sub_softmax_dst_md = memory::desc (sub_mm1_dst_dims, dt_src_user, tag::abcd);
233252 const auto mode = sdp_op[2 ]->get_attr <std::string>(op_attr::mode);
234253 const dnnl::algorithm algo = mode == " inf_as_zero"
@@ -337,6 +356,23 @@ impl::status_t sdp_decomp_config_t::construct_params(
337356 }
338357 // softmax
339358 sub_softmax_dst = memory (sub_softmax_dst_md, p_engine, nullptr );
359+ for (int i = 0 ; i < (int )sub_softmax_post_md.size (); i++) {
360+ sub_softmax_post_mem.emplace_back (sub_softmax_post_md[i], p_engine);
361+ auto alg = static_cast <algorithm>(
362+ ori_dnnl_pops.get ()->entry_ [i].binary .alg );
363+ if (alg == dnnl::algorithm::binary_mul) {
364+ float *ptr = reinterpret_cast <float *>(
365+ sub_softmax_post_mem[i].get_data_handle ());
366+ ptr[0 ] = get_attr_value<float , float >(
367+ sdp_op[2 ], i + 1 , op_attr::scales);
368+ }
369+ if (alg == dnnl::algorithm::binary_add) {
370+ int *ptr = reinterpret_cast <int *>(
371+ sub_softmax_post_mem[i].get_data_handle ());
372+ ptr[0 ] = get_attr_value<int64_t , int32_t >(
373+ sdp_op[2 ], i + 1 , op_attr::zps);
374+ }
375+ }
340376 // reorder2
341377 sub_wei2_user = memory (sub_wei2_user_md, p_engine, nullptr );
342378 // mm2
@@ -372,6 +408,12 @@ impl::status_t sdp_decomp_config_t::construct_params(
372408 {DNNL_ARG_DST, sub_softmax_dst},
373409 {DNNL_ARG_SCRATCHPAD, sub_scratchpad}};
374410
411+ for (int i = 0 ; i < (int )sub_softmax_post_mem.size (); i++) {
412+ sub_softmax_args.insert (
413+ {DNNL_ARG_ATTR_MULTIPLE_POST_OP (i) | DNNL_ARG_SRC_1,
414+ sub_softmax_post_mem[i]});
415+ }
416+
375417 sub_reorder2_args = {{DNNL_ARG_SRC, sub_wei2_user},
376418 {DNNL_ARG_DST, sub_mm2_wei}, {DNNL_ARG_SCRATCHPAD, sub_scratchpad}};
377419
0 commit comments