diff --git a/example/datasets/images/cyborg.png b/example/datasets/images/cyborg.png
new file mode 100644
index 0000000..a49342a
Binary files /dev/null and b/example/datasets/images/cyborg.png differ
diff --git a/example/datasets/images/image.jpeg b/example/datasets/images/image.jpeg
new file mode 100644
index 0000000..0390098
Binary files /dev/null and b/example/datasets/images/image.jpeg differ
diff --git a/lib/src/main/java/de/edux/functions/imputation/ImputationStrategy.java b/lib/src/main/java/de/edux/functions/imputation/ImputationStrategy.java
index 7ef265b..62b0e84 100644
--- a/lib/src/main/java/de/edux/functions/imputation/ImputationStrategy.java
+++ b/lib/src/main/java/de/edux/functions/imputation/ImputationStrategy.java
@@ -12,6 +12,12 @@ public enum ImputationStrategy {
    */
   AVERAGE(new AverageImputation()),
 
+  /**
+   * Imputation strategy that replaces missing values with the median of the non-missing values in
+   * the dataset column. This strategy is only applicable to numerical data.
+   */
+  MEDIAN(new MedianImputation()),
+
   /**
    * Imputation strategy that replaces missing values with the most frequently occurring value
    * (mode) in the dataset column. This strategy can be used for both numerical and categorical
diff --git a/lib/src/main/java/de/edux/functions/imputation/MedianImputation.java b/lib/src/main/java/de/edux/functions/imputation/MedianImputation.java
new file mode 100644
index 0000000..271c20d
--- /dev/null
+++ b/lib/src/main/java/de/edux/functions/imputation/MedianImputation.java
@@ -0,0 +1,90 @@
+package de.edux.functions.imputation;
+
+import java.util.Arrays;
+import java.util.regex.Pattern;
+
+/**
+ * Implements the {@code IImputationStrategy} interface to provide a median value imputation. This
+ * strategy calculates the median of the non-missing numeric values in a column and substitutes the
+ * missing values with this median.
+ *
+ * <p>It is important to note that this strategy is only applicable to columns with numeric data.
+ * Attempting to use this strategy on categorical data will result in a {@code RuntimeException}.
+ */
+public class MedianImputation implements IImputationStrategy {
+
+  /** Plain decimal notation (optional sign, digits, optional fraction); compiled once. */
+  private static final Pattern NUMERIC_PATTERN = Pattern.compile("-?\\d+(\\.\\d+)?");
+
+  /**
+   * Replaces every blank entry of the column with the median of its non-blank entries.
+   *
+   * @param datasetColumn the column values; blank strings mark missing values
+   * @return a new array with every blank entry replaced by the median; the input is not modified
+   * @throws RuntimeException if the column contains a non-numeric value, or if it contains
+   *     missing values but no numeric value from which a median could be derived
+   */
+  @Override
+  public String[] performImputation(String[] datasetColumn) {
+    checkIfColumnContainsCategoricalValues(datasetColumn);
+
+    String[] updatedDatasetColumn = datasetColumn.clone();
+    if (Arrays.stream(datasetColumn).noneMatch(String::isBlank)) {
+      // Nothing to impute (this also covers an empty column), so no median is required.
+      return updatedDatasetColumn;
+    }
+
+    String median = String.valueOf(calculateMedian(datasetColumn));
+    for (int index = 0; index < updatedDatasetColumn.length; index++) {
+      if (updatedDatasetColumn[index].isBlank()) {
+        updatedDatasetColumn[index] = median;
+      }
+    }
+
+    return updatedDatasetColumn;
+  }
+
+  /** Rejects the column as soon as one entry is neither blank nor a plain decimal number. */
+  private void checkIfColumnContainsCategoricalValues(String[] datasetColumn) {
+    for (String value : datasetColumn) {
+      if (!isNumeric(value)) {
+        throw new RuntimeException(
+            "MEDIAN imputation strategy can not be used on categorical features. "
+                + "Use MODE imputation strategy or perform a list wise deletion on the features.");
+      }
+    }
+  }
+
+  /** Blank entries count as numeric here because they are the missing values to be imputed. */
+  private boolean isNumeric(String value) {
+    return value.isBlank() || NUMERIC_PATTERN.matcher(value).matches();
+  }
+
+  /**
+   * Calculates the median of the non-blank entries of the column.
+   *
+   * @param datasetColumn the column values; blank entries are ignored
+   * @return the middle value of the sorted entries, or the mean of the two middle values when the
+   *     number of entries is even
+   * @throws IllegalArgumentException if the column contains no non-blank entry
+   */
+  double calculateMedian(String[] datasetColumn) {
+    double[] sortedValues =
+        Arrays.stream(datasetColumn)
+            .filter(value -> !value.isBlank())
+            .mapToDouble(Double::parseDouble)
+            .sorted()
+            .toArray();
+    if (sortedValues.length == 0) {
+      // Without this guard the index arithmetic below would read index -1.
+      throw new IllegalArgumentException(
+          "MEDIAN imputation strategy requires at least one non-missing numeric value.");
+    }
+
+    int middle = sortedValues.length / 2;
+    if (sortedValues.length % 2 == 0) {
+      return (sortedValues[middle - 1] + sortedValues[middle]) / 2.0;
+    }
+    return sortedValues[middle];
+  }
+}
diff --git a/lib/src/test/java/de/edux/functions/imputation/MedianImputationTest.java b/lib/src/test/java/de/edux/functions/imputation/MedianImputationTest.java
new file mode 100644
index 0000000..d5857a1
--- /dev/null
+++ b/lib/src/test/java/de/edux/functions/imputation/MedianImputationTest.java
@@ -0,0 +1,75 @@
+package de.edux.functions.imputation;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Random;
+import org.junit.jupiter.api.Test;
+
+public class MedianImputationTest {
+  @Test
+  void performImputationWithCategoricalValuesShouldThrowRuntimeException() {
+    String[] categoricalFeatures = {"A", "B", "C"};
+    assertThrows(
+        RuntimeException.class,
+        () -> new MedianImputation().performImputation(categoricalFeatures));
+  }
+
+  @Test
+  void performImputationWithNumericalValuesTest() {
+    String[] numericalFeaturesWithMissingValues = {"1", "", "2", "3", "", "4"};
+    MedianImputation imputer = new MedianImputation();
+    String[] numericalFeaturesWithImputedValues =
+        imputer.performImputation(numericalFeaturesWithMissingValues);
+    assertAll(
+        () -> assertEquals("2.5", numericalFeaturesWithImputedValues[1]),
+        () -> assertEquals("2.5", numericalFeaturesWithImputedValues[4]));
+  }
+
+  @Test
+  void performImputationWithOddNumberOfValuesTest() {
+    String[] oddNumberOfValues = {"5", "", "1", "3"};
+    assertArrayEquals(
+        new String[] {"5", "3.0", "1", "3"},
+        new MedianImputation().performImputation(oddNumberOfValues));
+  }
+
+  @Test
+  void performImputationWithEmptyColumnShouldReturnEmptyColumn() {
+    assertArrayEquals(new String[0], new MedianImputation().performImputation(new String[0]));
+  }
+
+  @Test
+  void performImputationWithOnlyMissingValuesShouldThrowRuntimeException() {
+    String[] onlyMissingValues = {"", " "};
+    assertThrows(
+        RuntimeException.class,
+        () -> new MedianImputation().performImputation(onlyMissingValues));
+  }
+
+  @Test
+  void calculateMedianWithLargeDatasetTest() {
+    // The values 1..1_000_001 have the known median 500_001, so the expectation does not have to
+    // be recomputed with the same algorithm that is under test.
+    int numericValueCount = 1_000_001;
+    int missingValueCount = 50_000; // roughly 5% empty values
+    List<String> values = new ArrayList<>(numericValueCount + missingValueCount);
+    for (int value = 1; value <= numericValueCount; value++) {
+      values.add(String.valueOf(value));
+    }
+    values.addAll(Collections.nCopies(missingValueCount, ""));
+    // A fixed seed keeps the test reproducible.
+    Collections.shuffle(values, new Random(42));
+    String[] largeDataset = values.toArray(new String[0]);
+
+    double calculatedMedian = new MedianImputation().calculateMedian(largeDataset);
+
+    assertEquals(
+        500_001.0,
+        calculatedMedian,
+        0.001,
+        "Calculated median should be equal to the expected median.");
+  }
+}