From 0d0957aff93dae6339f4d4d0eacf5442fccecf06 Mon Sep 17 00:00:00 2001 From: Todd Ginsberg Date: Mon, 15 Jul 2024 14:26:39 -0400 Subject: [PATCH] Running and trailing average of BigDecimal + Add general WithOriginal object and Gatherer to carry an original value and a mapped value (an average, in this case). + Add function to GathererUtils to throw an IllegalArgumentException when a parameter is null, rather than use Objects.requireNotNull and get a NullPointerException. + Averaging of BigDecimal with ability to specify a positive number of trailing values to consider, whether to emit partially calculated values for a trailing average, and the ability to change the RoudingMode and MathContext for mathematical operations. --- README.md | 22 ++ .../AveragingBigDecimalGatherer.java | 147 +++++++ .../ginsberg/gatherers4j/GathererUtils.java | 9 +- .../com/ginsberg/gatherers4j/Gatherers4j.java | 23 ++ .../ginsberg/gatherers4j/WithOriginal.java | 20 + .../gatherers4j/WithOriginalGatherer.java | 62 +++ .../AveragingBigDecimalGathererTest.java | 373 ++++++++++++++++++ .../gatherers4j/GathererUtilsTest.java | 18 +- 8 files changed, 672 insertions(+), 2 deletions(-) create mode 100644 src/main/java/com/ginsberg/gatherers4j/AveragingBigDecimalGatherer.java create mode 100644 src/main/java/com/ginsberg/gatherers4j/WithOriginal.java create mode 100644 src/main/java/com/ginsberg/gatherers4j/WithOriginalGatherer.java create mode 100644 src/test/java/com/ginsberg/gatherers4j/AveragingBigDecimalGathererTest.java diff --git a/README.md b/README.md index bede1ed..807d997 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,28 @@ TBD, once I start publishing snapshots to Maven Central. (Example, TODO clean this up) +**Running Average** + +```java +Stream + .of(new BigDecimal("1.0"), new BigDecimal("2.0"), new BigDecimal("10.0")) + .gather(Gatherers4j.averageBigDecimals()) + .toList(); + +// [1, 1.5, 4.3333333333333333] +``` + +**Trailing Average** + +```java +Stream + .of(new BigDecimal("1.0"), new BigDecimal("2.0"), new BigDecimal("10.0"), new BigDecimal("20.0"), new BigDecimal("30.0")) + .gather(Gatherers4j.averageBigDecimals().trailing(2)) + .toList(); + +// [1.5, 6, 15, 25] +``` + **Removing consecutive duplicate elements:** diff --git a/src/main/java/com/ginsberg/gatherers4j/AveragingBigDecimalGatherer.java b/src/main/java/com/ginsberg/gatherers4j/AveragingBigDecimalGatherer.java new file mode 100644 index 0000000..b5dfc25 --- /dev/null +++ b/src/main/java/com/ginsberg/gatherers4j/AveragingBigDecimalGatherer.java @@ -0,0 +1,147 @@ +/* + * Copyright 2024 Todd Ginsberg + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.ginsberg.gatherers4j; + +import java.math.BigDecimal; +import java.math.MathContext; +import java.math.RoundingMode; +import java.util.Arrays; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Gatherer; + +import static com.ginsberg.gatherers4j.GathererUtils.mustNotBeNull; + +public class AveragingBigDecimalGatherer + implements Gatherer { + + private final Function mappingFunction; + private RoundingMode roundingMode = RoundingMode.HALF_UP; + private MathContext mathContext = MathContext.DECIMAL64; + private BigDecimal nullReplacement; + private int trailingCount = 1; + private boolean includePartialValues; + + AveragingBigDecimalGatherer(final Function mappingFunction) { + super(); + this.mappingFunction = mappingFunction; + } + + @Override + public Supplier initializer() { + return trailingCount == 1 ? State::new : () -> new TrailingState(trailingCount); + } + + @Override + public Integrator integrator() { + return (state, element, downstream) -> { + final BigDecimal mappedElement = element == null ? nullReplacement : mappingFunction.apply(element); + if (mappedElement != null) { + state.add(mappedElement, mathContext); + if (state.canCalculate(includePartialValues)) { + return downstream.push(state.average(roundingMode, mathContext.getPrecision())); + } + } + return !downstream.isRejecting(); + }; + } + + public AveragingBigDecimalGatherer trailing(int count) { + if (count <= 0) { + throw new IllegalArgumentException("Trailing count must be positive"); + } + trailingCount = count; + return this; + } + + public AveragingBigDecimalGatherer includePartialTailingValues() { + includePartialValues = true; + return this; + } + + public AveragingBigDecimalGatherer treatNullAsZero() { + return treatNullAs(BigDecimal.ZERO); + } + + public AveragingBigDecimalGatherer treatNullAs(final BigDecimal rule) { + this.nullReplacement = rule; + return this; + } + + public AveragingBigDecimalGatherer withMathContext(final MathContext mathContext) { + mustNotBeNull(mathContext, "MathContext must not be null"); + this.mathContext = mathContext; + return this; + } + + public AveragingBigDecimalGatherer withRoundingMode(final RoundingMode roundingMode) { + mustNotBeNull(roundingMode, "RoundingMode must not be null"); + this.roundingMode = roundingMode; + return this; + } + + public WithOriginalGatherer withOriginal() { + return new WithOriginalGatherer<>(this); + } + + public static class State { + long count; + BigDecimal sum = BigDecimal.ZERO; + + void add(final BigDecimal element, final MathContext mathContext) { + count++; + sum = sum.add(element, mathContext); + } + + boolean canCalculate(final boolean allowPartial) { + return true; + } + + BigDecimal average(final RoundingMode roundingMode, int precision) { + if (sum.equals(BigDecimal.ZERO)) { + return BigDecimal.ZERO; + } else { + return sum.divide(BigDecimal.valueOf(count), precision, roundingMode); + } + } + } + + public static class TrailingState extends State { + final BigDecimal[] series; + int index = 0; + + private TrailingState(int lookBack) { + this.series = new BigDecimal[lookBack]; + Arrays.fill(series, BigDecimal.ZERO); + } + + @Override + boolean canCalculate(final boolean allowPartial) { + return allowPartial || count >= series.length; + } + + @Override + void add(final BigDecimal element, final MathContext mathContext) { + sum = sum.subtract(series[index]).add(element, mathContext); + series[index] = element; + index = (index + 1) % series.length; + if (count < series.length) { + count++; + } + } + } +} diff --git a/src/main/java/com/ginsberg/gatherers4j/GathererUtils.java b/src/main/java/com/ginsberg/gatherers4j/GathererUtils.java index 36e39d8..4c6f349 100644 --- a/src/main/java/com/ginsberg/gatherers4j/GathererUtils.java +++ b/src/main/java/com/ginsberg/gatherers4j/GathererUtils.java @@ -16,7 +16,8 @@ package com.ginsberg.gatherers4j; public class GathererUtils { - public static boolean safeEquals(Object left, Object right) { + + public static boolean safeEquals(final Object left, final Object right) { if (left == null && right == null) { return true; } else if (left == null || right == null) { @@ -24,4 +25,10 @@ public static boolean safeEquals(Object left, Object right) { } return left.equals(right); } + + public static void mustNotBeNull(final Object subject, final String message) { + if (subject == null) { + throw new IllegalArgumentException(message); + } + } } diff --git a/src/main/java/com/ginsberg/gatherers4j/Gatherers4j.java b/src/main/java/com/ginsberg/gatherers4j/Gatherers4j.java index b9b0c30..e7fe61e 100644 --- a/src/main/java/com/ginsberg/gatherers4j/Gatherers4j.java +++ b/src/main/java/com/ginsberg/gatherers4j/Gatherers4j.java @@ -16,6 +16,7 @@ package com.ginsberg.gatherers4j; +import java.math.BigDecimal; import java.util.List; import java.util.Objects; import java.util.function.Function; @@ -24,6 +25,28 @@ public class Gatherers4j { + /** + * Create a Stream that is the running average of Stream<BigDecimal> + * + * @return AveragingBigDecimalGatherer + */ + public static AveragingBigDecimalGatherer averageBigDecimals() { + return new AveragingBigDecimalGatherer<>(Function.identity()); + } + + /** + * Create a Stream that is the running average of BigDecimal objects as mapped by + * the given function. This is useful when paired with the withOriginal function. + * + * @param mappingFunction A non-null function to map the INPUT type to BigDecimal + * @return AveragingBigDecimalGatherer + */ + public static AveragingBigDecimalGatherer averageBigDecimalsBy( + final Function mappingFunction + ) { + return new AveragingBigDecimalGatherer<>(mappingFunction); + } + /** *

Given a stream of objects, filter the objects such that any consecutively appearing * after the first one are dropped. diff --git a/src/main/java/com/ginsberg/gatherers4j/WithOriginal.java b/src/main/java/com/ginsberg/gatherers4j/WithOriginal.java new file mode 100644 index 0000000..446cced --- /dev/null +++ b/src/main/java/com/ginsberg/gatherers4j/WithOriginal.java @@ -0,0 +1,20 @@ +/* + * Copyright 2024 Todd Ginsberg + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.ginsberg.gatherers4j; + +public record WithOriginal(ORIGINAL original, CALCULATED calculated) { +} diff --git a/src/main/java/com/ginsberg/gatherers4j/WithOriginalGatherer.java b/src/main/java/com/ginsberg/gatherers4j/WithOriginalGatherer.java new file mode 100644 index 0000000..ac6c9d9 --- /dev/null +++ b/src/main/java/com/ginsberg/gatherers4j/WithOriginalGatherer.java @@ -0,0 +1,62 @@ +/* + * Copyright 2024 Todd Ginsberg + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.ginsberg.gatherers4j; + +import java.util.Deque; +import java.util.concurrent.ConcurrentLinkedDeque; +import java.util.function.Supplier; +import java.util.stream.Gatherer; + +public class WithOriginalGatherer + implements Gatherer> { + + private final Gatherer delegate; + + WithOriginalGatherer(final Gatherer delegate) { + this.delegate = delegate; + } + + @Override + public Supplier initializer() { + return delegate.initializer(); + } + + @Override + public Integrator> integrator() { + final CapturingDownstream capturingDownstream = new CapturingDownstream<>(); + final Integrator delegateIntegrator = delegate.integrator(); + + return (state, element, downstream) -> { + final boolean response = delegateIntegrator.integrate(state, element, capturingDownstream); + while (!capturingDownstream.captured.isEmpty()) { + downstream.push(new WithOriginal<>(element, capturingDownstream.captured.poll())); + } + return response; + }; + } + + private static class CapturingDownstream implements Downstream { + + private final Deque captured = new ConcurrentLinkedDeque<>(); + + @Override + public boolean push(final OUTPUT capturedElement) { + captured.push(capturedElement); + return true; // Unused + } + } +} diff --git a/src/test/java/com/ginsberg/gatherers4j/AveragingBigDecimalGathererTest.java b/src/test/java/com/ginsberg/gatherers4j/AveragingBigDecimalGathererTest.java new file mode 100644 index 0000000..a4b08d9 --- /dev/null +++ b/src/test/java/com/ginsberg/gatherers4j/AveragingBigDecimalGathererTest.java @@ -0,0 +1,373 @@ +/* + * Copyright 2024 Todd Ginsberg + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.ginsberg.gatherers4j; + + +import org.junit.jupiter.api.Test; + +import java.math.BigDecimal; +import java.math.MathContext; +import java.math.RoundingMode; +import java.util.List; +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class AveragingBigDecimalGathererTest { + + @Test + void averageOfBigDecimals() { + // Arrange + final Stream input = Stream.of( + new BigDecimal("1.0"), + new BigDecimal("2.0"), + new BigDecimal("10.0") + ); + + // Act + final List output = input + .gather(Gatherers4j.averageBigDecimals()) + .toList(); + + // Assert + assertThat(output) + .usingComparatorForType(BigDecimal::compareTo, BigDecimal.class) + .containsExactly( + new BigDecimal("1"), + new BigDecimal("1.5"), + new BigDecimal("4.3333333333333333") + ); + } + + @Test + void mathContextChange() { + // Arrange + final Stream input = Stream.of( + new BigDecimal("1.0"), + new BigDecimal("2.0"), + new BigDecimal("10.0") + ); + + // Act + final List output = input + .gather(Gatherers4j.averageBigDecimals().withMathContext(new MathContext(2))) + .toList(); + + // Assert + assertThat(output) + .usingComparatorForType(BigDecimal::compareTo, BigDecimal.class) + .containsExactly( + new BigDecimal("1"), + new BigDecimal("1.5"), + new BigDecimal("4.33") + ); + } + + @Test + void roundingModeChange() { + // Arrange + final Stream input = Stream.of( + new BigDecimal("1.0"), + new BigDecimal("2.0"), + new BigDecimal("10.0") + ); + + // Act + final List output = input + .gather(Gatherers4j.averageBigDecimals().withRoundingMode(RoundingMode.CEILING)) + .toList(); + + // Assert + assertThat(output) + .usingComparatorForType(BigDecimal::compareTo, BigDecimal.class) + .containsExactly( + new BigDecimal("1"), + new BigDecimal("1.5"), + new BigDecimal("4.3333333333333334") + ); + } + + + @Test + void averageOfZero() { + // Arrange + final Stream input = Stream.of(BigDecimal.ZERO, new BigDecimal("-1"), BigDecimal.ONE); + + // Act + final List output = input + .gather(Gatherers4j.averageBigDecimals()) + .toList(); + + // Assert + assertThat(output) + .usingComparatorForType(BigDecimal::compareTo, BigDecimal.class) + .containsExactly( + BigDecimal.ZERO, + new BigDecimal("-0.5"), + BigDecimal.ZERO + ); + } + + @Test + void treatNullAsZero() { + // Arrange + final Stream input = Stream.of(null, BigDecimal.ONE, null, BigDecimal.ONE); + + // Act + final List output = input + .gather(Gatherers4j.averageBigDecimals().treatNullAsZero()) + .toList(); + + // Assert + assertThat(output) + .usingComparatorForType(BigDecimal::compareTo, BigDecimal.class) + .containsExactly( + BigDecimal.ZERO, + new BigDecimal("0.5"), + new BigDecimal("0.3333333333333333"), + new BigDecimal("0.5") + ); + } + + @Test + void ignoresNulls() { + // Arrange + final Stream input = Stream.of(null, BigDecimal.ONE, BigDecimal.TWO); + + // Act + final List output = input + .gather(Gatherers4j.averageBigDecimals()) + .toList(); + + // Assert + assertThat(output) + .usingComparatorForType(BigDecimal::compareTo, BigDecimal.class) + .containsExactly( + BigDecimal.ONE, + new BigDecimal("1.5") + ); + } + + @Test + void treatNullAsNonZero() { + // Arrange + final Stream input = Stream.of(null, BigDecimal.ONE, null, BigDecimal.ONE); + + // Act + final List output = input + .gather(Gatherers4j.averageBigDecimals().treatNullAs(BigDecimal.TEN)) + .toList(); + + // Assert + assertThat(output) + .usingComparatorForType(BigDecimal::compareTo, BigDecimal.class) + .containsExactly( + BigDecimal.TEN, + new BigDecimal("5.5"), + new BigDecimal("7"), + new BigDecimal("5.5") + ); + } + + @Test + void trailingAverageOfBigDecimals() { + // Arrange + final Stream input = Stream.of( + new BigDecimal("1.0"), + new BigDecimal("2.0"), + new BigDecimal("10.0"), + new BigDecimal("20.0"), + new BigDecimal("30.0") + ); + + // Act + final List output = input + .gather(Gatherers4j.averageBigDecimals().trailing(2)) + .toList(); + + // Assert + assertThat(output) + .usingComparatorForType(BigDecimal::compareTo, BigDecimal.class) + .containsExactly( + new BigDecimal("1.5"), + new BigDecimal("6"), + new BigDecimal("15"), + new BigDecimal("25") + ); + } + + @Test + void trailingAverageOfBigDecimalsWithPartials() { + // Arrange + final Stream input = Stream.of( + new BigDecimal("1.0"), + new BigDecimal("2.0"), + new BigDecimal("10.0"), + new BigDecimal("20.0"), + new BigDecimal("30.0") + ); + + // Act + final List output = input + .gather(Gatherers4j.averageBigDecimals().trailing(2).includePartialTailingValues()) + .toList(); + + // Assert + assertThat(output) + .usingComparatorForType(BigDecimal::compareTo, BigDecimal.class) + .containsExactly( + new BigDecimal("1"), + new BigDecimal("1.5"), + new BigDecimal("6"), + new BigDecimal("15"), + new BigDecimal("25") + ); + } + + @Test + void withOriginalBigDecimal() { + // Arrange + final Stream input = Stream.of( + new BigDecimal("1.0"), + new BigDecimal("2.0"), + new BigDecimal("10.0"), + new BigDecimal("20.0"), + new BigDecimal("30.0") + ); + + // Act + final List> output = input + .gather(Gatherers4j.averageBigDecimals().withOriginal()) + .toList(); + + // Assert + assertThat(output) + .map(WithOriginal::calculated) + .usingComparatorForType(BigDecimal::compareTo, BigDecimal.class) + .containsExactly( + new BigDecimal("1"), + new BigDecimal("1.5"), + new BigDecimal("4.3333333333333333"), + new BigDecimal("8.25"), + new BigDecimal("12.6") + ); + + assertThat(output) + .map(WithOriginal::original) + .usingComparatorForType(BigDecimal::compareTo, BigDecimal.class) + .containsExactly( + new BigDecimal("1.0"), + new BigDecimal("2.0"), + new BigDecimal("10.0"), + new BigDecimal("20.0"), + new BigDecimal("30.0") + ); + } + + @Test + void withMappedField() { + // Arrange + final List input = List.of( + new TestValueHolder(1, new BigDecimal("1.0")), + new TestValueHolder(2, new BigDecimal("2.0")), + new TestValueHolder(3, new BigDecimal("10.0")), + new TestValueHolder(4, new BigDecimal("20.0")), + new TestValueHolder(5, new BigDecimal("30.0")) + ); + + // Act + final List output = input.stream() + .gather(Gatherers4j.averageBigDecimalsBy(TestValueHolder::value)) + .toList(); + + // Assert + assertThat(output) + .usingComparatorForType(BigDecimal::compareTo, BigDecimal.class) + .containsExactly( + new BigDecimal("1"), + new BigDecimal("1.5"), + new BigDecimal("4.3333333333333333"), + new BigDecimal("8.25"), + new BigDecimal("12.6") + ); + } + + @Test + void withOriginalRecordByMappedField() { + // Arrange + final List input = List.of( + new TestValueHolder(1, new BigDecimal("1.0")), + new TestValueHolder(2, new BigDecimal("2.0")), + new TestValueHolder(3, new BigDecimal("10.0")), + new TestValueHolder(4, new BigDecimal("20.0")), + new TestValueHolder(5, new BigDecimal("30.0")) + ); + + // Act + final List> output = input.stream() + .gather(Gatherers4j.averageBigDecimalsBy(TestValueHolder::value).withOriginal()) + .toList(); + + // Assert + assertThat(output) + .map(WithOriginal::calculated) + .usingComparatorForType(BigDecimal::compareTo, BigDecimal.class) + .containsExactly( + new BigDecimal("1"), + new BigDecimal("1.5"), + new BigDecimal("4.3333333333333333"), + new BigDecimal("8.25"), + new BigDecimal("12.6") + ); + + assertThat(output) + .map(WithOriginal::original) + .containsExactlyInAnyOrderElementsOf(input); + } + + @Test + void roundingModeCannotBeNull() { + assertThatThrownBy(() -> + Stream.of(BigDecimal.ONE).gather(Gatherers4j.averageBigDecimals().withRoundingMode(null)) + ).isExactlyInstanceOf(IllegalArgumentException.class); + } + + @Test + void mathContextCannotBeNull() { + assertThatThrownBy(() -> + Stream.of(BigDecimal.ONE).gather(Gatherers4j.averageBigDecimals().withMathContext(null)) + ).isExactlyInstanceOf(IllegalArgumentException.class); + } + + @Test + void trailingInvalidRangeZero() { + assertThatThrownBy(() -> + Stream.of(BigDecimal.ONE).gather(Gatherers4j.averageBigDecimals().trailing(0)) + ).isExactlyInstanceOf(IllegalArgumentException.class); + } + + @Test + void trailingInvalidRangeNegative() { + assertThatThrownBy(() -> + Stream.of(BigDecimal.ONE).gather(Gatherers4j.averageBigDecimals().trailing(-1)) + ).isExactlyInstanceOf(IllegalArgumentException.class); + } + + record TestValueHolder(int id, BigDecimal value) { + } +} \ No newline at end of file diff --git a/src/test/java/com/ginsberg/gatherers4j/GathererUtilsTest.java b/src/test/java/com/ginsberg/gatherers4j/GathererUtilsTest.java index a1a68ef..e51fd20 100644 --- a/src/test/java/com/ginsberg/gatherers4j/GathererUtilsTest.java +++ b/src/test/java/com/ginsberg/gatherers4j/GathererUtilsTest.java @@ -19,7 +19,8 @@ import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; -import static org.assertj.core.api.Assertions.assertThat; +import static com.ginsberg.gatherers4j.GathererUtils.mustNotBeNull; +import static org.assertj.core.api.Assertions.*; class GathererUtilsTest { @@ -54,4 +55,19 @@ void withLeftNotNullRightNull() { } } + @SuppressWarnings("DataFlowIssue") + @Nested + class MustNotBeNull { + @Test + void whenNull() { + assertThatThrownBy(() -> mustNotBeNull(null, "123")) + .isExactlyInstanceOf(IllegalArgumentException.class) + .hasMessage("123"); + } + + @Test + void whenNotNull() { + assertThatNoException().isThrownBy(() -> mustNotBeNull("NonNull", "123")); + } + } } \ No newline at end of file