@@ -1551,7 +1551,9 @@ size_t jit_uni_eltwise_injector_t<isa>::aux_vecs_count() {
15511551 if (is_fwd_) {
15521552 switch (alg_) {
15531553 case eltwise_relu_use_dst_for_bwd:
1554- case eltwise_relu: return (alpha_ == 0 .f ) ? 2 : 3 ;
1554+ case eltwise_relu:
1555+ return (isa == asimd) ? ((alpha_ > 1 .f ) ? 3 : 2 )
1556+ : ((alpha_ == 0 .f ) ? 0 : 2 );
15551557 case eltwise_elu_use_dst_for_bwd:
15561558 case eltwise_elu: return (isa == asimd) ? 7 : 5 ; /* = exp + 2 */
15571559 case eltwise_tanh_use_dst_for_bwd:
@@ -2662,10 +2664,20 @@ void jit_uni_eltwise_injector_t<asimd>::relu_zero_ns_compute_vector_fwd(
26622664template <>
26632665void jit_uni_eltwise_injector_t <asimd>::relu_compute_vector_fwd(
26642666 const TRegS &vmm_src) {
2665- // vmm_tmp = alpha * vmm_tmp
2666- h->fmul (vmm_aux0, vmm_src, z_tmp);
2667- // vmm_src = max(vmm_src, vmm_tmp)
2668- h->fmaxnm (vmm_src, vmm_src, vmm_aux0);
2667+ // Compute x > 0 ? x : alpha * x.
2668+ // For alpha <= 1, this is equivalent to max(x, alpha * x).
2669+ if (alpha_ <= 1 .f ) {
2670+ h->fmul (vmm_aux0, vmm_src, z_tmp);
2671+ h->fmaxnm (vmm_src, vmm_src, vmm_aux0);
2672+ } else {
2673+ // For alpha > 1, keep positive lanes unchanged and scale only
2674+ // x <= 0 lanes.
2675+ h->mov (VReg16B (vmm_aux1.getIdx ()), VReg16B (vmm_src.getIdx ()));
2676+ h->fcmgt (vmm_aux0, vmm_aux1, 0 .);
2677+ h->fmul (vmm_src, vmm_src, z_tmp);
2678+ h->bit (VReg16B (vmm_src.getIdx ()), VReg16B (vmm_aux1.getIdx ()),
2679+ VReg16B (vmm_aux0.getIdx ()));
2680+ }
26692681}
26702682
26712683template <>
0 commit comments