Skip to content

Commit

Permalink
Simplify and fix bugs in StreamUtil (#551)
Browse files Browse the repository at this point in the history
* Simplify and fix bugs in StreamUtil

* Fix checkstyle

---------

Co-authored-by: Karthik Ramgopal <kramgopa@linkedin.com>
  • Loading branch information
karthikrg and li-kramgopa authored Feb 16, 2024
1 parent 9f838a0 commit 2d0eab0
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 38 deletions.
2 changes: 2 additions & 0 deletions avro-builder/builder-spi/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ dependencies {
implementation "org.apache.logging.log4j:log4j-api:2.17.1"
implementation "commons-io:commons-io:2.11.0"
implementation "jakarta.json:jakarta.json-api:2.0.1"
implementation "com.pivovarit:parallel-collectors:2.5.0"

testImplementation project(":test-common")
testImplementation "org.apache.avro:avro:1.9.2"
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@

package com.linkedin.avroutil1.builder.util;

import java.util.Collection;
import java.util.concurrent.CompletableFuture;
import com.pivovarit.collectors.ParallelCollectors;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Semaphore;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -52,7 +51,7 @@ private StreamUtil() {
* @return a {@code Collector} which collects all processed elements into a {@code Stream} in parallel.
*/
public static <T, R> Collector<T, ?, Stream<R>> toParallelStream(Function<T, R> mapper, int parallelism) {
return toParallelStream(mapper, parallelism, 1);
return ParallelCollectors.parallelToStream(mapper, WORK_EXECUTOR, parallelism);
}

/**
Expand All @@ -72,8 +71,9 @@ private StreamUtil() {
*/
public static <T, R> Collector<T, ?, Stream<R>> toParallelStream(Function<T, R> mapper, int parallelism,
int batchSize) {
if (parallelism <= 0 || batchSize <= 0) {
throw new IllegalArgumentException("Parallelism and batch size must be >= 1");
// When batch size is 1, batching adds nothing: fall back to the non-batching toParallelStream(mapper, parallelism) overload
if (batchSize == 1) {
return toParallelStream(mapper, parallelism);
}

return Collectors.collectingAndThen(Collectors.toList(), list -> {
Expand All @@ -85,39 +85,16 @@ private StreamUtil() {
return list.stream().map(mapper);
}

final Executor limitingExecutor = new LimitingExecutor(parallelism);
final int batchCount = (list.size() - 1) / batchSize;
return IntStream.rangeClosed(0, batchCount)
.mapToObj(batch -> {
int startIndex = batch * batchSize;
int endIndex = (batch == batchCount) ? list.size() : (batch + 1) * batchSize;
return list.subList(startIndex, endIndex);
})
.map(batch -> CompletableFuture.supplyAsync(() -> batch.stream().map(mapper).collect(Collectors.toList()),
limitingExecutor))
.map(CompletableFuture::join)
.flatMap(Collection::stream);
});
}

private final static class LimitingExecutor implements Executor {
final Function<List<T>, List<R>> batchingMapper =
batch -> batch.stream().map(mapper).collect(Collectors.toList());
List<List<T>> sublists = IntStream.rangeClosed(0, batchCount).mapToObj(batch -> {
int startIndex = batch * batchSize;
int endIndex = (batch == batchCount) ? list.size() : (batch + 1) * batchSize;
return list.subList(startIndex, endIndex);
}).collect(Collectors.toList());

private final Semaphore _limiter;

private LimitingExecutor(int maxParallelism) {
_limiter = new Semaphore(maxParallelism);
}

@Override
public void execute(Runnable command) {
try {
_limiter.acquire();
WORK_EXECUTOR.execute(command);
} catch (InterruptedException e) {
throw new RuntimeException(e);
} finally {
_limiter.release();
}
}
return sublists.stream().collect(toParallelStream(batchingMapper, parallelism)).flatMap(List::stream);
});
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright 2024 LinkedIn Corp.
* Licensed under the BSD 2-Clause License (the "License").
* See License in the project root for license information.
*/

package com.linkedin.avroutil1.builder.util;

import java.util.stream.IntStream;
import org.testng.Assert;
import org.testng.annotations.Test;


/**
 * Unit tests for {@link StreamUtil}.
 */
public class StreamUtilTest {

  /**
   * Feeds the integers 1..100 through the batched parallel collector
   * (parallelism 3, batch size 4), squaring each element, and checks that the
   * sum of the emitted values equals the sequentially computed sum of squares.
   * The check is a sum, so it does not depend on the order in which elements are emitted.
   */
  @Test
  public void testParallelStreaming() {
    final int lowerBound = 1;
    final int upperBound = 100;

    // Reference value computed without any streams or parallelism.
    int expectedSum = 0;
    for (int value = lowerBound; value <= upperBound; value++) {
      expectedSum += value * value;
    }

    final int actualSum = IntStream.rangeClosed(lowerBound, upperBound)
        .boxed()
        .collect(StreamUtil.toParallelStream(value -> value * value, 3, 4))
        .mapToInt(Integer::intValue)
        .sum();

    Assert.assertEquals(actualSum, expectedSum);
  }
}

0 comments on commit 2d0eab0

Please sign in to comment.