Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
Expand Down Expand Up @@ -158,11 +158,6 @@ protected void square(long[] a, long[] r) {
*/
@Override
protected void mult(long[] a, long[] b, long[] r) {
multImpl(a, b, r);
reducePositive(r);
}

private void multImpl(long[] a, long[] b, long[] r) {
long aa0 = a[0];
long aa1 = a[1];
long aa2 = a[2];
Expand Down Expand Up @@ -394,17 +389,43 @@ private void multImpl(long[] a, long[] b, long[] r) {
dd4 += MathUtil.unsignedMultiplyHigh(n, modulus[4]) << shift1 | (n4 >>> shift2);
d4 += n4 & LIMB_MASK;

// Final carry propagate
c5 += d1 + dd0 + (d0 >>> BITS_PER_LIMB);
c6 += d2 + dd1;
c7 += d3 + dd2;
c8 += d4 + dd3;
c9 = dd4;

r[0] = c5;
r[1] = c6;
r[2] = c7;
r[3] = c8;
r[4] = c9;
c6 += d2 + dd1 + (c5 >>> BITS_PER_LIMB);
c7 += d3 + dd2 + (c6 >>> BITS_PER_LIMB);
c8 += d4 + dd3 + (c7 >>> BITS_PER_LIMB);
c9 = dd4 + (c8 >>> BITS_PER_LIMB);

c5 &= LIMB_MASK;
c6 &= LIMB_MASK;
c7 &= LIMB_MASK;
c8 &= LIMB_MASK;

// At this point, the result {c5, c6, c7, c8, c9} could overflow by
// one modulus. Subtract one modulus (with carry propagation), into
// {c0, c1, c2, c3, c4}. Note that in this calculation, limbs are
// signed
c0 = c5 - modulus[0];
c1 = c6 - modulus[1] + (c0 >> BITS_PER_LIMB);
c0 &= LIMB_MASK;
c2 = c7 - modulus[2] + (c1 >> BITS_PER_LIMB);
c1 &= LIMB_MASK;
c3 = c8 - modulus[3] + (c2 >> BITS_PER_LIMB);
c2 &= LIMB_MASK;
c4 = c9 - modulus[4] + (c3 >> BITS_PER_LIMB);
c3 &= LIMB_MASK;

// We now must select a result that is in the range [0, modulus), i.e.
// either {c0-4} or {c5-9}. Iff {c0-4} is negative, then {c5-9} contains
// the result. (After carry propagation) if c4 is negative, {c0-4} is
// negative. An arithmetic shift of c4 by 63 bits replicates its sign bit
// into a mask (all ones when negative, all zeros otherwise) that can be
// used to select, in 'constant time', either {c0-4} or {c5-9}.
long mask = c4 >> 63;
r[0] = ((c5 & mask) | (c0 & ~mask));
r[1] = ((c6 & mask) | (c1 & ~mask));
r[2] = ((c7 & mask) | (c2 & ~mask));
r[3] = ((c8 & mask) | (c3 & ~mask));
r[4] = ((c9 & mask) | (c4 & ~mask));
}

@Override
Expand Down Expand Up @@ -522,27 +543,4 @@ protected void reduceIn(long[] limbs, long v, int i) {
limbs[i - 5] += (v << 4) & LIMB_MASK;
limbs[i - 4] += v >> 48;
}

/**
 * Conditionally subtracts one modulus from {@code a}, in place.
 * Used when limbs {@code a} could overflow by one modulus.
 *
 * Two candidate results are computed from the five input limbs
 * {@code a[0..4]}:
 * <ul>
 * <li>{aa0..aa4}: the input with each limb's bits above
 *     {@code BITS_PER_LIMB} carried into the next limb;</li>
 * <li>{c0..c4}: the input minus {@code modulus}, with the borrow
 *     propagated between limbs by a signed shift.</li>
 * </ul>
 * If the subtraction went negative (top limb {@code c4} is negative) the
 * carried input is kept; otherwise the difference is kept. The choice is
 * made with a bit mask rather than a branch (presumably so that timing
 * does not depend on the value -- confirm against the class's
 * constant-time requirements).
 *
 * @param a limbs to reduce; {@code a[0..4]} are read and overwritten.
 *          On return {@code a[0..3]} are masked with {@code LIMB_MASK};
 *          {@code a[4]} is stored unmasked.
 */
// @ForceInline
protected void reducePositive(long[] a) {
    // Candidate 1: the input itself, with carries pushed upward.
    long aa0 = a[0];
    long aa1 = a[1] + (aa0 >> BITS_PER_LIMB);
    long aa2 = a[2] + (aa1 >> BITS_PER_LIMB);
    long aa3 = a[3] + (aa2 >> BITS_PER_LIMB);
    long aa4 = a[4] + (aa3 >> BITS_PER_LIMB);

    // Candidate 2: input - modulus. Limbs are treated as signed here, so
    // the arithmetic shift carries a borrow (-1) into the next limb.
    long c0 = a[0] - modulus[0];
    long c1 = a[1] - modulus[1] + (c0 >> BITS_PER_LIMB);
    long c2 = a[2] - modulus[2] + (c1 >> BITS_PER_LIMB);
    long c3 = a[3] - modulus[3] + (c2 >> BITS_PER_LIMB);
    long c4 = a[4] - modulus[4] + (c3 >> BITS_PER_LIMB);

    // Signed shift by 63 replicates the sign bit of c4 across the whole
    // word: all ones iff c4 is negative, all zeros otherwise. Shifting by
    // BITS_PER_LIMB instead only yields a clean mask while
    // |c4| < 2^BITS_PER_LIMB; using the sign bit holds for every c4.
    long mask = c4 >> 63;

    // mask == -1 selects {aa0..aa4}; mask == 0 selects {c0..c4}.
    a[0] = ((aa0 & mask) | (c0 & ~mask)) & LIMB_MASK;
    a[1] = ((aa1 & mask) | (c1 & ~mask)) & LIMB_MASK;
    a[2] = ((aa2 & mask) | (c2 & ~mask)) & LIMB_MASK;
    a[3] = ((aa3 & mask) | (c3 & ~mask)) & LIMB_MASK;
    a[4] = ((aa4 & mask) | (c4 & ~mask));
}
}
Loading