Abstract: BigInteger has clever algorithms for multiplying large numbers. Unfortunately multiply() is single-threaded. Until now. In this newsletter I describe how the new parallelMultiply() method works and also how we can all contribute to the OpenJDK.
Welcome to the 305th edition of The Java(tm) Specialists' Newsletter. I am sitting in my trusty little Suzuki Jimny at Kalathas Beach, waiting for the rain to abate so I can do my morning exercise routine: Push-ups, run, dip in the sea. As anniversaries go, today is significant for two reasons. First off, my parents got married on the 30th November 56 years ago. Without that event, and many other extremely unlikely subsequent ones, well, I wouldn't be. Secondly, I sent out my first edition of The Java(tm) Specialists' Newsletter with fear and trepidation exactly 22 years ago this day. Twenty-two years. That's as old as Helene and I were when we got married. It's a long long time. And it would not have continued without your support. Thank you for all the amazing feedback and encouragement.
Also, really important: We are running our Black Friday Campaign until the 4th of December 2022. Make sure you don't miss it.
One of my favourite talks is where I take the Fibonacci calculation and parallelize it. We begin with Dijkstra's recursive Sum of Squares algorithm. It is easy to parallelize through divide and conquer and using Fork/Join. However, the issue we find is that the last few calculations take most of the time. We thus also have to parallelize the multiplication within BigInteger. With Java 8 this was fairly easy. Copy and paste BigInteger.java and change it, then compile and add the class file into the boot class path. With Java 9 it became more difficult, but still possible, by patching.
Some history. In Newsletter 201: Fork Join with Fibonacci and Karatsuba we looked at how we could use Karatsuba to speed up BigInteger multiply, and even parallelize it with Fork/Join. In Newsletter 236: Computational Time Complexity of BigInteger.multiply() in Java 8, we examined the two new algorithms baked into BigInteger.multiply(), namely Karatsuba for large numbers and Toom Cook 3 for huge numbers. Toom Cook 3 has a better computational time complexity, but takes longer to set up. I mentioned at the time that Toom Cook 3 could also easily be parallelized.
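As a reminder of what Karatsuba does, here is a minimal sketch that splits each operand into a high and a low half and gets away with three half-size multiplications instead of four. This is not the code inside BigInteger, which works directly on int arrays. The class name KaratsubaSketch and the THRESHOLD_BITS cutoff are my own assumptions for illustration.

```java
import java.math.BigInteger;
import java.util.Random;

// Illustrative sketch only; java.math.BigInteger has its own
// array-based Karatsuba implementation.
class KaratsubaSketch {
    // Hypothetical cutoff: at or below this size we use multiply().
    private static final int THRESHOLD_BITS = 2048;

    static BigInteger karatsuba(BigInteger x, BigInteger y) {
        // Work on magnitudes and restore the sign afterwards.
        if (x.signum() < 0 || y.signum() < 0) {
            BigInteger product = karatsuba(x.abs(), y.abs());
            return x.signum() * y.signum() < 0 ? product.negate() : product;
        }
        int bits = Math.max(x.bitLength(), y.bitLength());
        if (bits <= THRESHOLD_BITS) return x.multiply(y);

        int half = bits / 2;
        // x = xHigh * 2^half + xLow, and likewise for y
        BigInteger xHigh = x.shiftRight(half);
        BigInteger xLow = x.subtract(xHigh.shiftLeft(half));
        BigInteger yHigh = y.shiftRight(half);
        BigInteger yLow = y.subtract(yHigh.shiftLeft(half));

        // Three recursive multiplications instead of four
        BigInteger high = karatsuba(xHigh, yHigh);
        BigInteger low = karatsuba(xLow, yLow);
        BigInteger cross = karatsuba(xHigh.add(xLow), yHigh.add(yLow))
            .subtract(high).subtract(low);

        return high.shiftLeft(2 * half)
            .add(cross.shiftLeft(half))
            .add(low);
    }

    public static void main(String... args) {
        BigInteger a = new BigInteger(100_000, new Random(1));
        BigInteger b = new BigInteger(100_000, new Random(2));
        // Compare against the built-in multiply()
        System.out.println(karatsuba(a, b).equals(a.multiply(b)));
    }
}
```

The three recursive calls are independent of each other, which is what makes this kind of algorithm a natural fit for Fork/Join.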
How does parallelization work? In the simplest case, we write a recursive algorithm that divides the work at each step. If the work to be done is small enough, typically 10000 instructions or less, we execute it sequentially, otherwise we submit it as a parallel task. Let's examine this with a super simple calculation - Factorial.
import java.math.BigInteger;

public interface Factorial {
    BigInteger calculate(int n);
}
Since we will want to compare the multiply() and parallelMultiply() methods, we make this pluggable in the AbstractFactorial class:
import java.math.BigInteger;
import java.util.function.*;

public abstract class AbstractFactorial implements Factorial {
    private final BinaryOperator<BigInteger> multiply;

    protected AbstractFactorial(BinaryOperator<BigInteger> multiply) {
        this.multiply = multiply;
    }

    protected final BigInteger multiply(BigInteger a, BigInteger b) {
        return multiply.apply(a, b);
    }
}
We begin with a simple recursive sequential algorithm:
import java.math.BigInteger;
import java.util.function.*;

public class FactorialSequential extends AbstractFactorial {
    public FactorialSequential(BinaryOperator<BigInteger> multiply) {
        super(multiply);
    }

    @Override
    public BigInteger calculate(int n) {
        if (n < 0) throw new IllegalArgumentException("n < 0");
        return calculate(0, n);
    }

    private BigInteger calculate(int from, int to) {
        if (from == to) {
            if (from == 0) return BigInteger.ONE;
            return BigInteger.valueOf(from);
        }
        int mid = (from + to) >>> 1;
        BigInteger left = calculate(from, mid);
        BigInteger right = calculate(mid + 1, to);
        return multiply(left, right);
    }
}
Easy enough, I hope, and should not need any explanation.
We can parallelize this with CompletableFuture, which by default runs its async tasks in the common ForkJoinPool:
import java.math.BigInteger;
import java.util.concurrent.*;
import java.util.function.*;

import static java.util.concurrent.CompletableFuture.*;

public class FactorialCompletableFuture extends AbstractFactorial {
    public FactorialCompletableFuture(
            BinaryOperator<BigInteger> multiply) {
        super(multiply);
    }

    @Override
    public BigInteger calculate(int n) {
        if (n < 0) throw new IllegalArgumentException("n < 0");
        if (n == 0) return BigInteger.ONE;
        return calculate(1, n).join();
    }

    private CompletableFuture<BigInteger> calculate(int from, int to) {
        if (from == to)
            return completedFuture(BigInteger.valueOf(from));
        int mid = (from + to) >>> 1;
        return calculate(from, mid).thenCombineAsync(
            calculate(mid + 1, to), this::multiply);
    }
}
This works, but has a few flaws. First off, there is no
threshold below which we don't parallelize. In other words,
we will call thenCombineAsync() even if we have tiny numbers
to multiply. It does not make sense to create thousands of
parallel tasks when we only have 8 cores to execute them on.
The effort of parallelization will cost us more than we will
gain. Secondly, the biggest cost will be in the final three
multiplies, and that is not parallelized either.
A third issue is that with CompletableFuture, there are occasions where it will create a new native thread per task, for example when we run on a single-core or dual-core machine. This can happen even today if we run inside a Docker container. If I run a factorial(1_000_000) calculation on my 1-8-2 laptop (one socket, 8 cores, 2 hyperthreads per core), it uses exactly 16 threads for the CompletableFuture calculations. But if I set the cores to one with -XX:ActiveProcessorCount=1, then it fires up 1_000_000 threads and takes 60x longer than the sequential version.
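One way to sidestep the thread-per-task fallback, while staying with CompletableFuture, is to pass our own Executor to thenCombineAsync(). The sketch below is a variation on the class above, not part of the original demo. The name BoundedFactorialSketch and the pool size of 4 are assumptions chosen for illustration.

```java
import java.math.BigInteger;
import java.util.concurrent.*;

// Hypothetical variation: the caller supplies the executor, so the
// number of threads stays bounded regardless of the core count.
class BoundedFactorialSketch {
    static BigInteger factorial(int n, Executor executor) {
        if (n < 0) throw new IllegalArgumentException("n < 0");
        if (n == 0) return BigInteger.ONE;
        return calculate(1, n, executor).join();
    }

    private static CompletableFuture<BigInteger> calculate(
            int from, int to, Executor executor) {
        if (from == to)
            return CompletableFuture.completedFuture(
                BigInteger.valueOf(from));
        int mid = (from + to) >>> 1;
        // The three-argument overload runs the combining function
        // in the given executor instead of the default one.
        return calculate(from, mid, executor).thenCombineAsync(
            calculate(mid + 1, to, executor),
            BigInteger::multiply, executor);
    }

    public static void main(String... args) {
        ExecutorService pool = Executors.newFixedThreadPool(4);
        try {
            System.out.println(factorial(20, pool));
        } finally {
            pool.shutdown();
        }
    }
}
```

The multiplications never block inside the pool, since each one is only scheduled once both of its inputs have completed. This still does nothing about the missing threshold, so we continue to create far too many tiny tasks.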
Even though the CompletableFuture solution is easy enough to understand, I find it even better to code with RecursiveTask directly. That way, we are not subject to the vagaries of the CompletableFuture implementation. It will also solve the issue of an unbounded number of native threads being created to do the async calculation:
import java.math.BigInteger;
import java.util.concurrent.*;
import java.util.function.*;

public class FactorialForkJoin extends AbstractFactorial {
    public FactorialForkJoin(
            BinaryOperator<BigInteger> multiply) {
        super(multiply);
    }

    @Override
    public BigInteger calculate(int n) {
        if (n < 0) throw new IllegalArgumentException("n < 0");
        if (n == 0) return BigInteger.ONE;
        return new FactorialTask(1, n).invoke();
    }

    private class FactorialTask extends RecursiveTask<BigInteger> {
        private final int from, to;

        public FactorialTask(int from, int to) {
            this.from = from;
            this.to = to;
        }

        @Override
        protected BigInteger compute() {
            if (from == to) return BigInteger.valueOf(from);
            int mid = (from + to) >>> 1;
            var leftTask = new FactorialTask(from, mid);
            var rightTask = new FactorialTask(mid + 1, to);
            leftTask.fork();
            BigInteger right = rightTask.invoke();
            BigInteger left = leftTask.join();
            return multiply(left, right);
        }
    }
}
With this calculation, if we set the number of cores to 1 using -XX:ActiveProcessorCount=1, or even if we turn off the common ForkJoinPool altogether with -Djava.util.concurrent.ForkJoinPool.common.parallelism=0, it will then simply use the current thread for the calculation.
Let's try to solve the first issue of overeagerness to create parallel tasks by setting a threshold. One example of how to do that is shown in the RecursiveTask Javadoc. In the past they used a poor implementation of Fibonacci:
class Fibonacci extends RecursiveTask<Integer> {
    final int n;

    Fibonacci(int n) {
        this.n = n;
    }

    protected Integer compute() {
        if (n <= 1)
            return n;
        Fibonacci f1 = new Fibonacci(n - 1);
        f1.fork();
        Fibonacci f2 = new Fibonacci(n - 2);
        return f2.compute() + f1.join();
    }
}
They did mention that we should "pick some minimum granularity size (for example 10 here) for which we always sequentially solve rather than subdividing." I recommended that we change the example to Factorial, which they did in Java 19.
However, in the example they show, they have the threshold the wrong way round. Instead of looking at how large a chunk of work we have to solve, we should instead look at how much work we can realistically do in parallel. Thus the number of available threads should determine the number of forked chunks.
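To make that idea concrete, here is a minimal sketch of a factorial task that stops forking once it reaches a recursion depth derived from the common pool parallelism. It is my own illustration and not the Javadoc example. The names DepthLimitedFactorial and maxDepth are hypothetical, and the depth formula is just one plausible choice.

```java
import java.math.BigInteger;
import java.util.concurrent.*;

// Hypothetical sketch: fork only near the top of the recursion tree,
// where the number of tasks is still comparable to the thread count.
class DepthLimitedFactorial {
    static BigInteger factorial(int n) {
        if (n < 0) throw new IllegalArgumentException("n < 0");
        if (n == 0) return BigInteger.ONE;
        // The common pool parallelism may be 0 if it was disabled
        int parallelism =
            Math.max(1, ForkJoinPool.getCommonPoolParallelism());
        // Assumption: roughly log2(parallelism) + 1 levels of forking
        int maxDepth = 32 - Integer.numberOfLeadingZeros(parallelism);
        return new Task(1, n, 0, maxDepth).invoke();
    }

    private static final class Task extends RecursiveTask<BigInteger> {
        private final int from, to, depth, maxDepth;

        Task(int from, int to, int depth, int maxDepth) {
            this.from = from;
            this.to = to;
            this.depth = depth;
            this.maxDepth = maxDepth;
        }

        @Override
        protected BigInteger compute() {
            if (from == to) return BigInteger.valueOf(from);
            int mid = (from + to) >>> 1;
            Task left = new Task(from, mid, depth + 1, maxDepth);
            Task right = new Task(mid + 1, to, depth + 1, maxDepth);
            if (depth < maxDepth) {
                // Shallow enough: run the two halves in parallel
                left.fork();
                BigInteger r = right.compute();
                return left.join().multiply(r);
            }
            // Too deep: plain recursion in the current thread
            return left.compute().multiply(right.compute());
        }
    }

    public static void main(String... args) {
        System.out.println(factorial(20));
    }
}
```

With a depth limit of d we fork at most 2^d - 1 tasks in total, no matter how large n is.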
I mentioned earlier that in talks around the world, I showed how we could parallelize the BigInteger multiply() method. Towards the end of last year, I decided to try to contribute that idea to the OpenJDK. There are a few hurdles to jump through if you want to make a contribution. It is outlined nicely on the OpenJDK website under How to Contribute. Here it is in my words:
For example, when I wanted to add the parallelMultiply() method to the BigInteger class, I logged bug 8277175. Paul Sandoz kindly agreed to sponsor the patch and we added that into Java 19. Our final solution was a group effort. You can see intermediate stages here. To get there took a bit of back and forth, over several months. The actual code for parallelizing the multiply() method was quick to write. I had coded this with my eyes closed several times during conference talks. However, I also spent time on creating unit tests and benchmarks, plus discussing and incorporating the suggestions from the OpenJDK maintainers.
The part which I particularly like about our final solution is that we parallelized according to the number of processors, rather than just by task size. This means that we end up forking far fewer tasks when it just doesn't make sense. We do this by only parallelizing up to a certain depth of recursion. The maximum depth is by default determined by the common pool parallelism. The actual parallelization happens in the RecursiveOp class. Since this is new code, we used instanceof pattern matching and made it sealed.
@SuppressWarnings("serial")
private abstract static sealed class RecursiveOp
        extends RecursiveTask<BigInteger> {
    /**
     * The threshold until when we should continue forking
     * recursive ops if parallel is true. This threshold is only
     * relevant for Toom Cook 3 multiply and square.
     */
    private static final int PARALLEL_FORK_DEPTH_THRESHOLD =
        calculateMaximumDepth(ForkJoinPool.getCommonPoolParallelism());

    private static final int calculateMaximumDepth(int parallelism) {
        return 32 - Integer.numberOfLeadingZeros(parallelism);
    }

    final boolean parallel;

    /**
     * The current recursing depth. Since it is a logarithmic
     * algorithm, we do not need an int to hold the number.
     */
    final byte depth;

    private RecursiveOp(boolean parallel, int depth) {
        this.parallel = parallel;
        this.depth = (byte) depth;
    }

    private static int getParallelForkDepthThreshold() {
        if (Thread.currentThread() instanceof ForkJoinWorkerThread fjwt) {
            return calculateMaximumDepth(fjwt.getPool().getParallelism());
        } else {
            return PARALLEL_FORK_DEPTH_THRESHOLD;
        }
    }

    protected RecursiveTask<BigInteger> forkOrInvoke() {
        if (parallel && depth <= getParallelForkDepthThreshold()) fork();
        else invoke();
        return this;
    }

    @SuppressWarnings("serial")
    private static final class RecursiveMultiply extends RecursiveOp {
        private final BigInteger a;
        private final BigInteger b;

        public RecursiveMultiply(BigInteger a, BigInteger b,
                                 boolean parallel, int depth) {
            super(parallel, depth);
            this.a = a;
            this.b = b;
        }

        @Override
        public BigInteger compute() {
            return a.multiply(b, true, parallel, depth);
        }
    }

    @SuppressWarnings("serial")
    private static final class RecursiveSquare extends RecursiveOp {
        private final BigInteger a;

        public RecursiveSquare(BigInteger a, boolean parallel, int depth) {
            super(parallel, depth);
            this.a = a;
        }

        @Override
        public BigInteger compute() {
            return a.square(true, parallel, depth);
        }
    }

    private static RecursiveTask<BigInteger> multiply(
            BigInteger a, BigInteger b, boolean parallel, int depth) {
        return new RecursiveMultiply(a, b, parallel, depth).forkOrInvoke();
    }

    private static RecursiveTask<BigInteger> square(
            BigInteger a, boolean parallel, int depth) {
        return new RecursiveSquare(a, parallel, depth).forkOrInvoke();
    }
}
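To get a feel for what calculateMaximumDepth() returns, the following sketch copies its one-line formula from the listing above into a small demo class and prints the result for a few parallelism levels. The class name MaxDepthDemo and the chosen sample values are assumptions for illustration only.

```java
// Hypothetical demo class; only the formula comes from RecursiveOp.
class MaxDepthDemo {
    // Same formula as calculateMaximumDepth() in RecursiveOp:
    // the number of bits needed to represent the parallelism
    static int calculateMaximumDepth(int parallelism) {
        return 32 - Integer.numberOfLeadingZeros(parallelism);
    }

    public static void main(String... args) {
        for (int p : new int[] {1, 2, 4, 8, 11, 16}) {
            System.out.printf("parallelism=%d -> depth threshold=%d%n",
                p, calculateMaximumDepth(p));
        }
    }
}
```

The threshold grows only logarithmically: a parallelism of 1 gives 1, 8 gives 4, 11 also gives 4, and 16 gives 5. Doubling the number of threads thus adds just one more level of forking.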
Let's run our FactorialDemo and have a look at the three mechanisms. First we try FactorialSequential, both with multiply() and parallelMultiply(). Next we try FactorialCompletableFuture, and lastly FactorialForkJoin.
import java.math.BigInteger;
import java.util.*;
import java.util.function.*;
import java.util.stream.*;

public class FactorialDemo {
    public static void main(String... args) {
        record Pair(String description,
                    BinaryOperator<BigInteger> operator) {}
        List<Pair> multiplies = List.of(
            new Pair("multiply", BigInteger::multiply),
            new Pair("parallelMultiply", BigInteger::parallelMultiply)
        );
        List<Function<BinaryOperator<BigInteger>, Factorial>> funcs =
            List.of(
                FactorialSequential::new,
                FactorialCompletableFuture::new,
                FactorialForkJoin::new
            );

        for (var func : funcs) {
            var factorial = func.apply(BigInteger::multiply);
            System.out.println(factorial.getClass());
            IntStream.rangeClosed(0, 10)
                .mapToObj(factorial::calculate)
                .forEach(System.out::println);
            System.out.println();
        }

        for (int i = 0; i < 10; i++) {
            for (var func : funcs) {
                for (var multiply : multiplies) {
                    Factorial factorial = func.apply(multiply.operator());
                    System.out.print(factorial.getClass() +
                        " with " + multiply.description());
                    long time = System.nanoTime();
                    try {
                        BigInteger fac1m = factorial.calculate(1_000_000);
                    } finally {
                        time = System.nanoTime() - time;
                        System.out.printf(" time = %dms%n",
                            (time / 1_000_000));
                    }
                }
            }
            System.out.println();
        }
    }
}
Here is the abbreviated output of the run. I ran it on my 1-6-2 server, meaning single socket, 6 cores, with 2 hyperthreads per core. I have only kept the last result for display in this newsletter:
class FactorialSequential
1
1
2
6
24
120
720
5040
40320
362880
3628800

class FactorialCompletableFuture
1
1
2
6
24
120
720
5040
40320
362880
3628800

class FactorialForkJoin
1
1
2
6
24
120
720
5040
40320
362880
3628800

*snip*

class FactorialSequential with multiply time = 2591ms
class FactorialSequential with parallelMultiply time = 763ms
class FactorialCompletableFuture with multiply time = 1484ms
class FactorialCompletableFuture with parallelMultiply time = 689ms
class FactorialForkJoin with multiply time = 1416ms
class FactorialForkJoin with parallelMultiply time = 627ms
As we can see, the slowest is, not surprisingly, the single-threaded FactorialSequential with multiply(), and the fastest is our FactorialForkJoin with parallelMultiply(). What might be surprising to some is how well FactorialSequential does with parallelMultiply(). It is just a bit slower than the ForkJoin solution with parallel multiplication.
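If you would like to try parallelMultiply() in isolation, a minimal sketch along the following lines should do. It needs Java 19 or later. The class name ParallelMultiplyDemo, the helper randomBig() and the operand size of 10 million bits are assumptions for illustration, and the timings will of course depend on your hardware.

```java
import java.math.BigInteger;
import java.util.Random;

// Hypothetical demo, requires Java 19+ for parallelMultiply()
class ParallelMultiplyDemo {
    // Builds a reproducible pseudo-random number of up to 'bits' bits
    static BigInteger randomBig(int bits, long seed) {
        return new BigInteger(bits, new Random(seed));
    }

    public static void main(String... args) {
        BigInteger a = randomBig(10_000_000, 1);
        BigInteger b = randomBig(10_000_000, 2);

        long time = System.nanoTime();
        BigInteger sequential = a.multiply(b);
        long sequentialMs = (System.nanoTime() - time) / 1_000_000;

        time = System.nanoTime();
        BigInteger parallel = a.parallelMultiply(b);
        long parallelMs = (System.nanoTime() - time) / 1_000_000;

        // Both methods must produce the identical product
        System.out.println("same result = " + sequential.equals(parallel));
        System.out.printf("multiply = %dms, parallelMultiply = %dms%n",
            sequentialMs, parallelMs);
    }
}
```

A single run like this is not a proper benchmark, since it ignores warm-up, but it is enough to check that both methods agree on the result.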
Kind regards from Chorafakia
Heinz
P.S. We are still running our Black Friday Campaign until the 4th of December. Excellent opportunity to invest in your further Java education.