diff --git a/avro-builder/builder-spi/build.gradle b/avro-builder/builder-spi/build.gradle index d66b3306..6fd27db7 100644 --- a/avro-builder/builder-spi/build.gradle +++ b/avro-builder/builder-spi/build.gradle @@ -17,7 +17,9 @@ dependencies { implementation "org.apache.logging.log4j:log4j-api:2.17.1" implementation "commons-io:commons-io:2.11.0" implementation "jakarta.json:jakarta.json-api:2.0.1" + implementation "com.pivovarit:parallel-collectors:2.5.0" + testImplementation project(":test-common") testImplementation "org.apache.avro:avro:1.9.2" } diff --git a/avro-builder/builder-spi/src/main/java/com/linkedin/avroutil1/builder/util/StreamUtil.java b/avro-builder/builder-spi/src/main/java/com/linkedin/avroutil1/builder/util/StreamUtil.java index 77d6fb16..3d5b047e 100644 --- a/avro-builder/builder-spi/src/main/java/com/linkedin/avroutil1/builder/util/StreamUtil.java +++ b/avro-builder/builder-spi/src/main/java/com/linkedin/avroutil1/builder/util/StreamUtil.java @@ -6,11 +6,10 @@ package com.linkedin.avroutil1.builder.util; -import java.util.Collection; -import java.util.concurrent.CompletableFuture; +import com.pivovarit.collectors.ParallelCollectors; +import java.util.List; import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; -import java.util.concurrent.Semaphore; import java.util.concurrent.SynchronousQueue; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; @@ -52,7 +51,7 @@ private StreamUtil() { * @return a {@code Collector} which collects all processed elements into a {@code Stream} in parallel. */ public static Collector> toParallelStream(Function mapper, int parallelism) { - return toParallelStream(mapper, parallelism, 1); + return ParallelCollectors.parallelToStream(mapper, WORK_EXECUTOR, parallelism); } /** @@ -72,8 +71,9 @@ private StreamUtil() { */ public static Collector> toParallelStream(Function mapper, int parallelism, int batchSize) { - if (parallelism <= 0 || batchSize <= 0) { - throw new IllegalArgumentException("Parallelism and batch size must be >= 1"); + // When batch size is 1, fallback to toParallelStream + if (batchSize == 1) { + return toParallelStream(mapper, parallelism); } return Collectors.collectingAndThen(Collectors.toList(), list -> { @@ -85,39 +85,16 @@ private StreamUtil() { return list.stream().map(mapper); } - final Executor limitingExecutor = new LimitingExecutor(parallelism); final int batchCount = (list.size() - 1) / batchSize; - return IntStream.rangeClosed(0, batchCount) - .mapToObj(batch -> { - int startIndex = batch * batchSize; - int endIndex = (batch == batchCount) ? list.size() : (batch + 1) * batchSize; - return list.subList(startIndex, endIndex); - }) - .map(batch -> CompletableFuture.supplyAsync(() -> batch.stream().map(mapper).collect(Collectors.toList()), - limitingExecutor)) - .map(CompletableFuture::join) - .flatMap(Collection::stream); - }); - } - - private final static class LimitingExecutor implements Executor { + final Function, List> batchingMapper = + batch -> batch.stream().map(mapper).collect(Collectors.toList()); + List> sublists = IntStream.rangeClosed(0, batchCount).mapToObj(batch -> { + int startIndex = batch * batchSize; + int endIndex = (batch == batchCount) ? list.size() : (batch + 1) * batchSize; + return list.subList(startIndex, endIndex); + }).collect(Collectors.toList()); - private final Semaphore _limiter; - - private LimitingExecutor(int maxParallelism) { - _limiter = new Semaphore(maxParallelism); - } - - @Override - public void execute(Runnable command) { - try { - _limiter.acquire(); - WORK_EXECUTOR.execute(command); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } finally { - _limiter.release(); - } - } + return sublists.stream().collect(toParallelStream(batchingMapper, parallelism)).flatMap(List::stream); + }); } } diff --git a/avro-builder/builder-spi/src/test/java/com/linkedin/avroutil1/builder/util/StreamUtilTest.java b/avro-builder/builder-spi/src/test/java/com/linkedin/avroutil1/builder/util/StreamUtilTest.java new file mode 100644 index 00000000..93d704fe --- /dev/null +++ b/avro-builder/builder-spi/src/test/java/com/linkedin/avroutil1/builder/util/StreamUtilTest.java @@ -0,0 +1,30 @@ +/* + * Copyright 2024 LinkedIn Corp. + * Licensed under the BSD 2-Clause License (the "License"). + * See License in the project root for license information. + */ + +package com.linkedin.avroutil1.builder.util; + +import java.util.stream.IntStream; +import org.testng.Assert; +import org.testng.annotations.Test; + + +/** + * This is to test {@link StreamUtil} + */ +public class StreamUtilTest { + + @Test + public void testParallelStreaming() { + int result = IntStream.rangeClosed(1, 100) + .boxed() + .collect(StreamUtil.toParallelStream(x -> x * x, 3, 4)) + .reduce(0, Integer::sum); + + int expected = IntStream.rangeClosed(1, 100).map(x -> x * x).sum(); + + Assert.assertEquals(result, expected); + } +} \ No newline at end of file