Skip to content

Commit de2a2e4

Browse files
Anndrey24 authored and Sqvid committed
cpu: aarch64: fix ASIMD leaky ReLU when alpha > 1
1 parent 187507f commit de2a2e4

1 file changed

Lines changed: 17 additions & 5 deletions

File tree

src/cpu/aarch64/injectors/jit_uni_eltwise_injector.cpp

Lines changed: 17 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1551,7 +1551,9 @@ size_t jit_uni_eltwise_injector_t<isa>::aux_vecs_count() {
15511551
if (is_fwd_) {
15521552
switch (alg_) {
15531553
case eltwise_relu_use_dst_for_bwd:
1554-
case eltwise_relu: return (alpha_ == 0.f) ? 2 : 3;
1554+
case eltwise_relu:
1555+
return (isa == asimd) ? ((alpha_ > 1.f) ? 3 : 2)
1556+
: ((alpha_ == 0.f) ? 0 : 2);
15551557
case eltwise_elu_use_dst_for_bwd:
15561558
case eltwise_elu: return (isa == asimd) ? 7 : 5; /* = exp + 2 */
15571559
case eltwise_tanh_use_dst_for_bwd:
@@ -2662,10 +2664,20 @@ void jit_uni_eltwise_injector_t<asimd>::relu_zero_ns_compute_vector_fwd(
26622664
template <>
26632665
void jit_uni_eltwise_injector_t<asimd>::relu_compute_vector_fwd(
26642666
const TRegS &vmm_src) {
2665-
// vmm_tmp = alpha * vmm_tmp
2666-
h->fmul(vmm_aux0, vmm_src, z_tmp);
2667-
// vmm_src = max(vmm_src, vmm_tmp)
2668-
h->fmaxnm(vmm_src, vmm_src, vmm_aux0);
2667+
// Compute x > 0 ? x : alpha * x.
2668+
// For alpha <= 1, this is equivalent to max(x, alpha * x).
2669+
if (alpha_ <= 1.f) {
2670+
h->fmul(vmm_aux0, vmm_src, z_tmp);
2671+
h->fmaxnm(vmm_src, vmm_src, vmm_aux0);
2672+
} else {
2673+
// For alpha > 1, keep positive lanes unchanged and scale only
2674+
// x <= 0 lanes.
2675+
h->mov(VReg16B(vmm_aux1.getIdx()), VReg16B(vmm_src.getIdx()));
2676+
h->fcmgt(vmm_aux0, vmm_aux1, 0.);
2677+
h->fmul(vmm_src, vmm_src, z_tmp);
2678+
h->bit(VReg16B(vmm_src.getIdx()), VReg16B(vmm_aux1.getIdx()),
2679+
VReg16B(vmm_aux0.getIdx()));
2680+
}
26692681
}
26702682

26712683
template <>

0 commit comments

Comments
 (0)