From 367e0a65561f95aad61b40930d5f46843fee3444 Mon Sep 17 00:00:00 2001
From: fabioromano1 <51378941+fabioromano1@users.noreply.github.com>
Date: Sat, 3 Aug 2024 13:08:54 +0000
Subject: [PATCH] 8334755: Asymptotically faster implementation of square root
algorithm
Reviewed-by: rgiulietti
---
.../share/classes/java/math/BigInteger.java | 12 +-
.../classes/java/math/MutableBigInteger.java | 307 +++++++++++++-----
.../java/math/BigInteger/BigIntegerTest.java | 24 +-
.../bench/java/math/BigIntegerSquareRoot.java | 132 ++++++++
4 files changed, 390 insertions(+), 85 deletions(-)
create mode 100644 test/micro/org/openjdk/bench/java/math/BigIntegerSquareRoot.java
diff --git a/src/java.base/share/classes/java/math/BigInteger.java b/src/java.base/share/classes/java/math/BigInteger.java
index c69e291d0e2..3a5fd143937 100644
--- a/src/java.base/share/classes/java/math/BigInteger.java
+++ b/src/java.base/share/classes/java/math/BigInteger.java
@@ -2723,7 +2723,7 @@ public BigInteger sqrt() {
throw new ArithmeticException("Negative BigInteger");
}
- return new MutableBigInteger(this.mag).sqrt().toBigInteger();
+ return new MutableBigInteger(this.mag).sqrtRem(false)[0].toBigInteger();
}
/**
@@ -2742,10 +2742,12 @@ public BigInteger sqrt() {
* @since 9
*/
public BigInteger[] sqrtAndRemainder() {
- BigInteger s = sqrt();
- BigInteger r = this.subtract(s.square());
- assert r.compareTo(BigInteger.ZERO) >= 0;
- return new BigInteger[] {s, r};
+        if (signum < 0) {
+            throw new ArithmeticException("Negative BigInteger");
+        }
+        // Root and remainder are produced together by a single sqrtRem() call
+        final MutableBigInteger[] rootAndRem = new MutableBigInteger(mag).sqrtRem(true);
+        return new BigInteger[] { rootAndRem[0].toBigInteger(), rootAndRem[1].toBigInteger() };
}
/**
diff --git a/src/java.base/share/classes/java/math/MutableBigInteger.java b/src/java.base/share/classes/java/math/MutableBigInteger.java
index 30ea8e130fc..b84e50f567e 100644
--- a/src/java.base/share/classes/java/math/MutableBigInteger.java
+++ b/src/java.base/share/classes/java/math/MutableBigInteger.java
@@ -109,9 +109,26 @@ class MutableBigInteger {
* the int val.
*/
MutableBigInteger(int val) {
- value = new int[1];
- intLen = 1;
- value[0] = val;
+ init(val); // zero is stored with intLen == 0 rather than as a single zero word
+ }
+
+ /**
+ * Construct a new MutableBigInteger with a magnitude specified by
+ * the long val.
+ */
+ MutableBigInteger(long val) {
+ int hi = (int) (val >>> 32);
+ if (hi == 0) {
+ init((int) val);
+ } else {
+ value = new int[] { hi, (int) val };
+ intLen = 2;
+ }
+ }
+
+    private void init(int val) {
+        intLen = (val == 0) ? 0 : 1; // zero has no significant words
+        value = new int[] { val };
}
/**
@@ -260,6 +277,7 @@ void reset() {
* Compare the magnitude of two MutableBigIntegers. Returns -1, 0 or 1
* as this MutableBigInteger is numerically less than, equal to, or
* greater than {@code b}.
+ * Assumes no leading unnecessary zeros.
*/
final int compare(MutableBigInteger b) {
int blen = b.intLen;
@@ -285,6 +303,7 @@ final int compare(MutableBigInteger b) {
/**
* Returns a value equal to what {@code b.leftShift(32*ints); return compare(b);}
* would return, but doesn't change the value of {@code b}.
+ * Assumes no leading unnecessary zeros.
*/
private int compareShifted(MutableBigInteger b, int ints) {
int blen = b.intLen;
@@ -538,6 +557,7 @@ void safeRightShift(int n) {
/**
* Right shift this MutableBigInteger n bits. The MutableBigInteger is left
* in normal form.
+ * Assumes {@code Math.ceilDiv(n, 32) <= intLen || intLen == 0}
*/
void rightShift(int n) {
if (intLen == 0)
@@ -911,6 +931,58 @@ void addLower(MutableBigInteger addend, int n) {
add(a);
}
+    /**
+     * Replaces {@code this} with {@code this} shifted left by {@code n} ints
+     * plus {@code addend}. Assumes {@code n > 0} for speed.
+     */
+    void shiftAdd(MutableBigInteger addend, int n) {
+        if (addend.intLen <= n) {
+            // The addend fits in the words freed by the shift
+            shiftAddDisjoint(addend, n);
+        } else if (intLen != 0) {
+            leftShift(n << 5);
+            add(addend);
+        } else {
+            copyValue(addend); // this is zero, so the sum is the addend itself
+        }
+    }
+
+    /**
+     * Replaces {@code this} with {@code this} shifted left by {@code n} ints
+     * plus {@code addend}. Assumes {@code addend.intLen <= n}, so the addend
+     * is simply copied into the low-order words freed by the shift.
+     */
+    void shiftAddDisjoint(MutableBigInteger addend, int n) {
+        if (intLen == 0) { // Avoid unnormal values
+            copyValue(addend);
+            return;
+        }
+
+        final int newLen = intLen + n;
+        final int addendStart = newLen - addend.intLen; // relative to the result offset
+        final boolean reuse = newLen <= value.length;
+        // The magnitude can stay where it is only if the result fits after offset
+        final boolean inPlace = reuse && offset + newLen <= value.length;
+
+        final int[] dest = reuse ? value : new int[newLen];
+        final int destOffset = inPlace ? offset : 0;
+        if (!inPlace) {
+            // Move the magnitude to the front of the destination array
+            System.arraycopy(value, offset, dest, 0, intLen);
+        }
+
+        // A reused array may hold stale words between the two parts
+        if (reuse && addend.intLen < n) {
+            Arrays.fill(dest, destOffset + intLen, destOffset + addendStart, 0);
+        }
+
+        System.arraycopy(addend.value, addend.offset, dest, destOffset + addendStart, addend.intLen);
+
+        value = dest;
+        offset = destOffset;
+        intLen = newLen;
+    }
+
/**
* Subtracts the smaller of this and b from the larger and places the
* result into this MutableBigInteger.
@@ -1003,6 +1075,7 @@ private int difference(MutableBigInteger b) {
/**
* Multiply the contents of two MutableBigInteger objects. The result is
* placed into MutableBigInteger z. The contents of y are not changed.
+ * Assumes {@code intLen > 0}.
*/
void multiply(MutableBigInteger y, MutableBigInteger z) {
int xLen = intLen;
@@ -1793,93 +1866,169 @@ private boolean unsignedLongCompare(long one, long two) {
}
/**
- * Calculate the integer square root {@code floor(sqrt(this))} where
- * {@code sqrt(.)} denotes the mathematical square root. The contents of
- * {@code this} are not changed. The value of {@code this} is assumed
- * to be non-negative.
+ * Calculate the integer square root {@code floor(sqrt(this))} and the remainder
+ * if needed, where {@code sqrt(.)} denotes the mathematical square root.
+ * The contents of {@code this} are not changed.
+ * The value of {@code this} is assumed to be non-negative.
*
- * @implNote The implementation is based on the material in Henry S. Warren,
- * Jr., Hacker's Delight (2nd ed.) (Addison Wesley, 2013), 279-282.
- *
- * @throws ArithmeticException if the value returned by {@code bitLength()}
- * overflows the range of {@code int}.
- * @return the integer square root of {@code this}
- * @since 9
+ * @return a two-element array holding the integer square root of {@code this} and the remainder ({@code null} if not requested)
*/
- MutableBigInteger sqrt() {
+ MutableBigInteger[] sqrtRem(boolean needRemainder) {
// Special cases.
- if (this.isZero()) {
- return new MutableBigInteger(0);
- } else if (this.value.length == 1
- && (this.value[0] & LONG_MASK) < 4) { // result is unity
- return ONE;
- }
-
- if (bitLength() <= 63) {
- // Initial estimate is the square root of the positive long value.
- long v = new BigInteger(this.value, 1).longValueExact();
- long xk = (long)Math.floor(Math.sqrt(v));
-
- // Refine the estimate.
- do {
- long xk1 = (xk + v/xk)/2;
-
- // Terminate when non-decreasing.
- if (xk1 >= xk) {
- return new MutableBigInteger(new int[] {
- (int)(xk >>> 32), (int)(xk & LONG_MASK)
- });
+ if (this.intLen <= 2) { // the magnitude fits in an unsigned 64-bit long
+ final long x = this.toLong(); // unsigned
+ long s = unsignedLongSqrt(x);
+
+ return new MutableBigInteger[] {
+ new MutableBigInteger((int) s), // the root of a 64-bit value fits in 32 bits
+ needRemainder ? new MutableBigInteger(x - s * s) : null // remainder only on request
+ };
+ }
+
+ // Normalize: shift left by an even amount so that intLen is even and the top word has at most one leading zero
+ MutableBigInteger x = this;
+ final int shift = (Integer.numberOfLeadingZeros(x.value[x.offset]) & ~1) // shift must be even
+ + ((x.intLen & 1) << 5); // x.intLen must be even
+
+ if (shift != 0) {
+ x = new MutableBigInteger(x);
+ x.leftShift(shift);
+ }
+
+ // Compute sqrt and remainder
+ MutableBigInteger[] sqrtRem = x.sqrtRemKaratsuba(x.intLen, needRemainder);
+
+ // Unnormalize
+ if (shift != 0) {
+ final int halfShift = shift >> 1;
+ if (needRemainder) {
+ // shift <= 62, so s0 is at most 31 bits long
+ final long s0 = sqrtRem[0].value[sqrtRem[0].offset + sqrtRem[0].intLen - 1]
+ & (-1 >>> -halfShift); // Keep only the low halfShift bits of the root's lowest word
+ if (s0 != 0L) { // An optimization: with s0 == 0 both correction terms below are zero
+ MutableBigInteger doubleProd = new MutableBigInteger();
+ sqrtRem[0].mul((int) (s0 << 1), doubleProd);
+
+ sqrtRem[1].add(doubleProd);
+ sqrtRem[1].subtract(new MutableBigInteger(s0 * s0));
}
+ sqrtRem[1].rightShift(shift); // undo the normalization shift on the remainder
+ }
+ sqrtRem[0].primitiveRightShift(halfShift); // undo half the normalization shift on the root
+ }
+ return sqrtRem;
+ }
- xk = xk1;
- } while (true);
- } else {
- // Set up the initial estimate of the iteration.
+ private static long unsignedLongSqrt(long x) {
+ /* For every long value s in [0, 2^32) such that x == s * s,
+ * it is true that s - 1 <= (long) Math.sqrt(x >= 0 ? x : x + 0x1p64) <= s,
+ * and if x == 2^64 - 1, then (long) Math.sqrt(x >= 0 ? x : x + 0x1p64) == 2^32.
+ * Since both cast to long and `Math.sqrt()` are (weakly) increasing,
+ * this means that the value returned by Math.sqrt()
+ * for a long value in the range [0, 2^64) is either correct,
+ * or rounded up/down by one if the value is too high
+ * and too close to a perfect square.
+ */
+ long s = (long) Math.sqrt(x >= 0 ? x : x + 0x1p64);
+ long s2 = s * s; // overflows iff s == 2^32
+ return Long.compareUnsigned(x, s2) < 0 || s > LONG_MASK
+ ? s - 1
+ : (Long.compareUnsigned(x, s2 + (s << 1)) <= 0 // x <= (s + 1)^2 - 1, does not overflow
+ ? s
+ : s + 1);
+ }
- // Obtain the bitLength > 63.
- int bitLength = (int) this.bitLength();
- if (bitLength != this.bitLength()) {
- throw new ArithmeticException("bitLength() integer overflow");
- }
+ /**
+ * Computes recursively the square root and, if needed, the remainder of the {@code len} most significant ints.
+ * Assumes {@code 2 <= len <= intLen && len % 2 == 0
+ * && Integer.numberOfLeadingZeros(value[offset]) <= 1}
+ * @implNote Based on P. Zimmermann, "Karatsuba Square Root" (INRIA RR-3805), and
+ * Y. Bertot, N. Magaud, P. Zimmermann, "A Proof of GMP Square Root".
+ */
+ private MutableBigInteger[] sqrtRemKaratsuba(int len, boolean needRemainder) {
+ if (len == 2) { // Base case: the two leading ints are handled as one unsigned long
+ long x = ((value[offset] & LONG_MASK) << 32) | (value[offset + 1] & LONG_MASK);
+ long s = unsignedLongSqrt(x);
- // Determine an even valued right shift into positive long range.
- int shift = bitLength - 63;
- if (shift % 2 == 1) {
- shift++;
- }
+ // Allocate sufficient space to hold the final square root, assuming intLen % 2 == 0
+ MutableBigInteger sqrt = new MutableBigInteger(new int[intLen >> 1]);
- // Shift the value into positive long range.
- MutableBigInteger xk = new MutableBigInteger(this);
- xk.rightShift(shift);
- xk.normalize();
-
- // Use the square root of the shifted value as an approximation.
- double d = new BigInteger(xk.value, 1).doubleValue();
- BigInteger bi = BigInteger.valueOf((long)Math.ceil(Math.sqrt(d)));
- xk = new MutableBigInteger(bi.mag);
-
- // Shift the approximate square root back into the original range.
- xk.leftShift(shift / 2);
-
- // Refine the estimate.
- MutableBigInteger xk1 = new MutableBigInteger();
- do {
- // xk1 = (xk + n/xk)/2
- this.divide(xk, xk1, false);
- xk1.add(xk);
- xk1.rightShift(1);
-
- // Terminate when non-decreasing.
- if (xk1.compare(xk) >= 0) {
- return xk;
- }
+ // Place the partial square root
+ sqrt.intLen = 1;
+ sqrt.value[0] = (int) s;
+
+ return new MutableBigInteger[] { sqrt, new MutableBigInteger(x - s * s) }; // the base case always returns the remainder
+ }
- // xk = xk1
- xk.copyValue(xk1);
+ // Recursive step (len >= 4)
- xk1.reset();
- } while (true);
+ final int halfLen = len >> 1;
+ // Recursive invocation on the leading part, rounded up to an even number of ints; its remainder is always needed
+ MutableBigInteger[] sr = sqrtRemKaratsuba(halfLen + (halfLen & 1), true);
+
+ final int blockLen = halfLen >> 1;
+ MutableBigInteger dividend = sr[1];
+ dividend.shiftAddDisjoint(getBlockForSqrt(1, len, blockLen), blockLen);
+
+ // Compute dividend / (2*sqrt) by dividing by sqrt and halving q; an odd q moves one sqrt into the remainder u
+ MutableBigInteger sqrt = sr[0];
+ MutableBigInteger q = new MutableBigInteger();
+ MutableBigInteger u = dividend.divide(sqrt, q);
+ if (q.isOdd())
+ u.add(sqrt);
+ q.rightShift(1);
+
+ sqrt.shiftAdd(q, blockLen);
+ // Corresponds to ub + a_0 in the paper
+ u.shiftAddDisjoint(getBlockForSqrt(0, len, blockLen), blockLen);
+ BigInteger qBig = q.toBigInteger(); // Cast to BigInteger to use fast multiplication
+ MutableBigInteger qSqr = new MutableBigInteger(qBig.multiply(qBig).mag);
+
+ MutableBigInteger rem;
+ if (needRemainder) {
+ rem = u;
+ if (rem.subtract(qSqr) < 0) { // u < q^2: the root estimate is one too large
+ MutableBigInteger twiceSqrt = new MutableBigInteger(sqrt);
+ twiceSqrt.leftShift(1);
+
+ // Since subtract() performs an absolute difference, to get the correct algebraic sum
+ // we must first add the sum of absolute values of addends concordant with the sign of rem
+ // and then subtract the sum of absolute values of addends that are discordant
+ rem.add(ONE);
+ rem.subtract(twiceSqrt);
+ sqrt.subtract(ONE);
+ }
+ } else {
+ rem = null;
+ if (u.compare(qSqr) < 0) // same correction test, without computing the remainder
+ sqrt.subtract(ONE);
}
+
+ sr[1] = rem;
+ return sr;
+ }
+
+    /**
+     * Extracts a block of {@code blockLen} ints from {@code this} number, ending
+     * (exclusive) {@code blockIndex*blockLen} ints before the end of its {@code len}
+     * leading ints; leading zeros are dropped. Used in Karatsuba square root.
+     * @param blockIndex the block index, starting from the lowest
+     * @param len the logical length of the input value in units of 32 bits
+     * @param blockLen the length of the block in units of 32 bits
+     *
+     * @return the block, or a default-constructed {@code MutableBigInteger} if it is all zeros
+     */
+    private MutableBigInteger getBlockForSqrt(int blockIndex, int len, int blockLen) {
+        final int end = offset + len - blockIndex * blockLen;
+        int start = end - blockLen;
+        // Skip leading zeros
+        while (start < end && value[start] == 0) {
+            start++;
+        }
+        if (start == end) {
+            return new MutableBigInteger();
+        }
+        return new MutableBigInteger(Arrays.copyOfRange(value, start, end));
}
/**
diff --git a/test/jdk/java/math/BigInteger/BigIntegerTest.java b/test/jdk/java/math/BigInteger/BigIntegerTest.java
index 2ac4750e43f..7da3fdac618 100644
--- a/test/jdk/java/math/BigInteger/BigIntegerTest.java
+++ b/test/jdk/java/math/BigInteger/BigIntegerTest.java
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 1998, 2023, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 1998, 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@@ -293,8 +293,30 @@ private static void squareRootSmall() {
report("squareRootSmall", failCount);
}
+ private static void perfectSquaresLong() {
+ /* For every long value n in [0, 2^32) such that x == n * n,
+ * n - 1 <= (long) Math.sqrt(x >= 0 ? x : x + 0x1p64) <= n
+ * must be true.
+ * This property is used to implement MutableBigInteger.unsignedLongSqrt().
+ */
+ int failCount = 0;
+
+ long limit = 1L << 32;
+ for (long n = 0; n < limit; n++) {
+ long x = n * n;
+ long s = (long) Math.sqrt(x >= 0 ? x : x + 0x1p64);
+ if (!(s == n || s == n - 1)) {
+ failCount++;
+ System.err.println(s + "^2 != " + x + " && (" + s + "+1)^2 != " + x);
+ }
+ }
+
+ report("perfectSquaresLong", failCount);
+ }
+
public static void squareRoot() {
squareRootSmall();
+ perfectSquaresLong();
ToIntFunction f = (n) -> {
int failCount = 0;
diff --git a/test/micro/org/openjdk/bench/java/math/BigIntegerSquareRoot.java b/test/micro/org/openjdk/bench/java/math/BigIntegerSquareRoot.java
new file mode 100644
index 00000000000..4b78b4cd8fa
--- /dev/null
+++ b/test/micro/org/openjdk/bench/java/math/BigIntegerSquareRoot.java
@@ -0,0 +1,132 @@
+/*
+ * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+package org.openjdk.bench.java.math;
+
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.OperationsPerInvocation;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.Param;
+import org.openjdk.jmh.annotations.Warmup;
+import org.openjdk.jmh.infra.Blackhole;
+
+import java.math.BigInteger;
+import java.util.Random;
+import java.util.concurrent.TimeUnit;
+
+@BenchmarkMode(Mode.AverageTime)
+@OutputTimeUnit(TimeUnit.NANOSECONDS)
+@State(Scope.Thread)
+@Warmup(iterations = 5, time = 1)
+@Measurement(iterations = 5, time = 1)
+@Fork(value = 3)
+public class BigIntegerSquareRoot {
+
+    /** Number of operands in each pool. */
+    private static final int TESTSIZE = 1000;
+
+    /* Operand pools; the bit length of every entry is itself random. */
+    private BigInteger[] xsArray; // fewer than 64 bits
+    private BigInteger[] sArray;  // fewer than 256 bits
+    private BigInteger[] mArray;  // fewer than 1024 bits
+    private BigInteger[] lArray;  // fewer than 4096 bits
+    private BigInteger[] xlArray; // fewer than 16384 bits
+
+    /**
+     * Fills every pool with {@code TESTSIZE} non-negative pseudo-random
+     * values drawn from a generator with a fixed seed.
+     */
+    @Setup
+    public void setup() {
+        final Random rnd = new Random(1123);
+        xsArray = new BigInteger[TESTSIZE];
+        sArray = new BigInteger[TESTSIZE];
+        mArray = new BigInteger[TESTSIZE];
+        lArray = new BigInteger[TESTSIZE];
+        xlArray = new BigInteger[TESTSIZE];
+        // One draw per pool and index, always in the same order
+        for (int idx = 0; idx < TESTSIZE; idx++) {
+            xsArray[idx] = randomOperand(rnd, 64);
+            sArray[idx] = randomOperand(rnd, 256);
+            mArray[idx] = randomOperand(rnd, 1024);
+            lArray[idx] = randomOperand(rnd, 4096);
+            xlArray[idx] = randomOperand(rnd, 16384);
+        }
+    }
+
+    /** Returns a random value made of {@code rnd.nextInt(bitBound)} random bits. */
+    private static BigInteger randomOperand(Random rnd, int bitBound) {
+        return new BigInteger(rnd.nextInt(bitBound), rnd);
+    }
+
+    /** Measures BigInteger.sqrt() on operands of fewer than 64 bits. */
+    @Benchmark
+    @OperationsPerInvocation(TESTSIZE)
+    public void testSqrtXS(Blackhole bh) {
+        for (BigInteger operand : xsArray) {
+            bh.consume(operand.sqrt());
+        }
+    }
+
+    /** Measures BigInteger.sqrt() on operands of fewer than 256 bits. */
+    @Benchmark
+    @OperationsPerInvocation(TESTSIZE)
+    public void testSqrtS(Blackhole bh) {
+        for (BigInteger operand : sArray) {
+            bh.consume(operand.sqrt());
+        }
+    }
+
+    /** Measures BigInteger.sqrt() on operands of fewer than 1024 bits. */
+    @Benchmark
+    @OperationsPerInvocation(TESTSIZE)
+    public void testSqrtM(Blackhole bh) {
+        for (BigInteger operand : mArray) {
+            bh.consume(operand.sqrt());
+        }
+    }
+
+    /** Measures BigInteger.sqrt() on operands of fewer than 4096 bits. */
+    @Benchmark
+    @OperationsPerInvocation(TESTSIZE)
+    public void testSqrtL(Blackhole bh) {
+        for (BigInteger operand : lArray) {
+            bh.consume(operand.sqrt());
+        }
+    }
+
+    /** Measures BigInteger.sqrt() on operands of fewer than 16384 bits. */
+    @Benchmark
+    @OperationsPerInvocation(TESTSIZE)
+    public void testSqrtXL(Blackhole bh) {
+        for (BigInteger operand : xlArray) {
+            bh.consume(operand.sqrt());
+        }
+    }
+}