Skip to content

Commit

Permalink
[dtypes] Numerical Formatter Fails to Learn Format of New Data Types (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo authored Aug 7, 2024
1 parent 1c07c22 commit 0517e68
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 128 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ dependencies = [
'copulas>=0.11.0',
'ctgan>=0.10.0',
'deepecho>=0.6.0',
'rdt>=1.12.0',
'rdt @ git+https://github.com/sdv-dev/RDT@main',
'sdmetrics>=0.14.0',
'platformdirs>=4.0',
'pyyaml>=6.0.1',
Expand Down
41 changes: 5 additions & 36 deletions sdv/data_processing/numerical_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import logging
import sys

import numpy as np
import pandas as pd
from rdt.transformers.utils import learn_rounding_digits

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -51,34 +51,6 @@ def __init__(
self.enforce_min_max_values = enforce_min_max_values
self.computer_representation = computer_representation

@staticmethod
def _learn_rounding_digits(data):
    """Learn how many decimal digits are needed to represent a column.

    Missing and infinite entries are ignored. Columns backed by pandas nullable
    dtypes (for example ``Int64`` or ``Float64`` holding ``pd.NA``) are accepted
    in addition to plain numpy-backed columns.

    Args:
        data (pandas.Series):
            Column to inspect. Its ``name`` is only used in the log message.

    Returns:
        int or None:
            ``0`` if no finite value has a decimal part; otherwise the smallest
            number of decimals, up to ``MAX_DECIMALS``, at which rounding leaves
            every finite value unchanged. ``None`` if there are no finite values
            or no such number of decimals exists.
    """
    name = data.name

    # ``np.array`` turns a nullable column holding ``pd.NA`` into an object array,
    # which ``np.isinf`` rejects. Casting to float64 maps the missing entries to NaN,
    # so a single ``np.isfinite`` call filters out both NaN and +/-inf.
    finite_mask = np.isfinite(data.astype('float64').to_numpy())

    # Nullable dtypes expose the numpy dtype they wrap. Converting the NA-free
    # remainder to it keeps the column's own precision (e.g. float32) for the
    # equality checks below. Plain numpy dtypes have no such attribute, so they
    # fall back to the default conversion.
    native_dtype = getattr(data.dtype, 'numpy_dtype', None)
    roundable_data = data[finite_mask].to_numpy(dtype=native_dtype)

    # Doesn't contain numbers
    if len(roundable_data) == 0:
        return None

    # Doesn't contain decimal digits
    if ((roundable_data % 1) == 0).all():
        return 0

    # Try to round to fewer digits
    if (roundable_data == roundable_data.round(MAX_DECIMALS)).all():
        for decimal in range(MAX_DECIMALS + 1):
            if (roundable_data == roundable_data.round(decimal)).all():
                return decimal

    # Can't round, not equal after MAX_DECIMALS digits of precision
    LOGGER.info(
        f"No rounding scheme detected for column '{name}'."
        ' Synthetic data will not be rounded.'
    )
    return None

def learn_format(self, column):
"""Learn the format of a column.
Expand All @@ -92,7 +64,7 @@ def learn_format(self, column):
self._max_value = column.max()

if self.enforce_rounding:
self._rounding_digits = self._learn_rounding_digits(column)
self._rounding_digits = learn_rounding_digits(column)

def format_data(self, column):
"""Format a column according to the learned format.
Expand All @@ -105,20 +77,17 @@ def format_data(self, column):
numpy.ndarray:
containing the formatted data.
"""
column = column.copy().to_numpy()
column = column.copy()
if self.enforce_min_max_values:
column = column.clip(self._min_value, self._max_value)
elif self.computer_representation != 'Float':
elif not self.computer_representation.startswith('Float'):
min_bound, max_bound = INTEGER_BOUNDS[self.computer_representation]
column = column.clip(min_bound, max_bound)

is_integer = np.dtype(self._dtype).kind == 'i'
is_integer = pd.api.types.is_integer_dtype(self._dtype)
if self.enforce_rounding and self._rounding_digits is not None:
column = column.round(self._rounding_digits)
elif is_integer:
column = column.round(0)

if pd.isna(column).any() and is_integer:
return column

return column.astype(self._dtype)
2 changes: 2 additions & 0 deletions sdv/metadata/single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ class SingleTableMetadata:
}

_NUMERICAL_REPRESENTATIONS = frozenset([
'Float32',
'Float64',
'Float',
'Int64',
'Int32',
Expand Down
34 changes: 34 additions & 0 deletions tests/integration/single_table/test_copulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,3 +469,37 @@ def test_datetime_values_inside_real_data_range():
assert check_in_synthetic.max() <= check_in_real.max()
assert check_out_synthetic.min() >= check_out_real.min()
assert check_out_synthetic.max() <= check_out_real.max()


def test_support_new_pandas_dtypes():
    """Check that nullable pandas numerical dtypes survive a fit/sample round trip."""
    # Setup
    # Each key is both the column name and the pandas nullable dtype of that column.
    nullable_columns = {
        'Int8': [1, 2, -3, pd.NA],
        'Int16': [1, 2, -3, pd.NA],
        'Int32': [1, 2, -3, pd.NA],
        'Int64': [1, 2, pd.NA, -3],
        'Float32': [1.1, 2.2, 3.3, pd.NA],
        'Float64': [1.113, 2.22, 3.3, pd.NA],
    }
    data = pd.DataFrame({
        dtype: pd.Series(values, dtype=dtype) for dtype, values in nullable_columns.items()
    })
    column_metadata = {
        dtype: {'sdtype': 'numerical', 'computer_representation': dtype}
        for dtype in nullable_columns
    }
    metadata = SingleTableMetadata().load_from_dict({'columns': column_metadata})

    synthesizer = GaussianCopulaSynthesizer(metadata)

    # Run
    synthesizer.fit(data)
    synthetic_data = synthesizer.sample(10)

    # Assert
    assert (synthetic_data.dtypes == data.dtypes).all()
    # The float columns must keep the number of decimals present in the real data.
    for column, decimals in (('Float32', 1), ('Float64', 3)):
        sampled = synthetic_data[column]
        assert (sampled == sampled.round(decimals)).all(skipna=True)
92 changes: 1 addition & 91 deletions tests/unit/data_processing/test_numerical_formatter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from unittest.mock import Mock, patch
from unittest.mock import Mock

import numpy as np
import pandas as pd
Expand All @@ -20,96 +20,6 @@ def test___init__(self):
assert formatter.enforce_min_max_values is True
assert formatter.computer_representation == 'Int8'

@patch('sdv.data_processing.numerical_formatter.LOGGER')
def test__learn_rounding_digits_more_than_15_decimals(self, log_mock):
    """Test ``_learn_rounding_digits`` on data with more than 15 decimals.

    When no rounding scheme can be found, the method is expected to return
    ``None`` and to report it once through ``LOGGER.info``.
    """
    # Setup
    random_values = np.random.random(size=10).round(20)
    column = pd.Series(random_values, name='col')

    # Run
    rounding_digits = NumericalFormatter._learn_rounding_digits(column)

    # Assert
    log_mock.info.assert_called_once_with(
        "No rounding scheme detected for column 'col'. Synthetic data will not be rounded."
    )
    assert rounding_digits is None

def test__learn_rounding_digits_less_than_15_decimals(self):
    """Test ``_learn_rounding_digits`` on data with fewer than 15 decimals.

    The largest number of decimals found in the data should be returned.

    Input:
        - a series of floats with at most 3 decimals, plus a NaN.

    Output:
        - 3
    """
    # Setup
    values = np.array([10, 0.0, 0.1, 0.12, 0.123, np.nan])
    column = pd.Series(values)

    # Run
    rounding_digits = NumericalFormatter._learn_rounding_digits(column)

    # Assert
    assert rounding_digits == 3

def test__learn_rounding_digits_negative_decimals_float(self):
    """Test the ``_learn_rounding_digits`` method with floats multiples of powers of 10.

    If every value is a multiple of 10, the output should be 0: no decimal digits
    are needed and negative rounding digits are not learned.

    Input:
        - an array that contains floats that are multiples of 10, 100 and 1000 and a NaN.

    Output:
        - 0
    """
    # Setup
    data = pd.Series(np.array([1230.0, 12300.0, 123000.0, np.nan]))

    # Run
    output = NumericalFormatter._learn_rounding_digits(data)

    # Assert
    assert output == 0

def test__learn_rounding_digits_negative_decimals_integer(self):
    """Test the ``_learn_rounding_digits`` method with integers multiples of powers of 10.

    If every value is a multiple of 10, the output should be 0: no decimal digits
    are needed and negative rounding digits are not learned.

    Input:
        - an array built from integer literals that are multiples of 10, 100 and 1000
          and a NaN. Because of the NaN, numpy stores the values as floats.

    Output:
        - 0
    """
    # Setup
    data = pd.Series(np.array([1230, 12300, 123000, np.nan]))

    # Run
    output = NumericalFormatter._learn_rounding_digits(data)

    # Assert
    assert output == 0

def test__learn_rounding_digits_all_nans(self):
    """Test the ``_learn_rounding_digits`` method with data that is all NaNs.

    If the data is all NaNs there is nothing to learn from, so the output
    is expected to be ``None``.

    Input:
        - an array of NaN.

    Output:
        - None
    """
    # Setup
    data = pd.Series(np.array([np.nan, np.nan, np.nan, np.nan]))

    # Run
    output = NumericalFormatter._learn_rounding_digits(data)

    # Assert
    assert output is None

def test_learn_format(self):
"""Test that ``learn_format`` method.
Expand Down

0 comments on commit 0517e68

Please sign in to comment.