diff --git a/pyproject.toml b/pyproject.toml index 026cebf80..91d33107e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ 'copulas>=0.11.0', 'ctgan>=0.10.0', 'deepecho>=0.6.0', - 'rdt>=1.12.0', + 'rdt @ git+https://github.com/sdv-dev/RDT@main', 'sdmetrics>=0.14.0', 'platformdirs>=4.0', 'pyyaml>=6.0.1', diff --git a/sdv/data_processing/numerical_formatter.py b/sdv/data_processing/numerical_formatter.py index fe9f72881..1d7100ba1 100644 --- a/sdv/data_processing/numerical_formatter.py +++ b/sdv/data_processing/numerical_formatter.py @@ -3,8 +3,8 @@ import logging import sys -import numpy as np import pandas as pd +from rdt.transformers.utils import learn_rounding_digits LOGGER = logging.getLogger(__name__) @@ -51,34 +51,6 @@ def __init__( self.enforce_min_max_values = enforce_min_max_values self.computer_representation = computer_representation - @staticmethod - def _learn_rounding_digits(data): - """Check if data has any decimals.""" - name = data.name - data = np.array(data) - roundable_data = data[~(np.isinf(data) | pd.isna(data))] - - # Doesn't contain numbers - if len(roundable_data) == 0: - return None - - # Doesn't contain decimal digits - if ((roundable_data % 1) == 0).all(): - return 0 - - # Try to round to fewer digits - if (roundable_data == roundable_data.round(MAX_DECIMALS)).all(): - for decimal in range(MAX_DECIMALS + 1): - if (roundable_data == roundable_data.round(decimal)).all(): - return decimal - - # Can't round, not equal after MAX_DECIMALS digits of precision - LOGGER.info( - f"No rounding scheme detected for column '{name}'." - ' Synthetic data will not be rounded.' - ) - return None - def learn_format(self, column): """Learn the format of a column. @@ -92,7 +64,7 @@ def learn_format(self, column): self._max_value = column.max() if self.enforce_rounding: - self._rounding_digits = self._learn_rounding_digits(column) + self._rounding_digits = learn_rounding_digits(column) def format_data(self, column): """Format a column according to the learned format. @@ -105,20 +77,17 @@ def format_data(self, column): numpy.ndarray: containing the formatted data. """ - column = column.copy().to_numpy() + column = column.copy() if self.enforce_min_max_values: column = column.clip(self._min_value, self._max_value) - elif self.computer_representation != 'Float': + elif not self.computer_representation.startswith('Float'): min_bound, max_bound = INTEGER_BOUNDS[self.computer_representation] column = column.clip(min_bound, max_bound) - is_integer = np.dtype(self._dtype).kind == 'i' + is_integer = pd.api.types.is_integer_dtype(self._dtype) if self.enforce_rounding and self._rounding_digits is not None: column = column.round(self._rounding_digits) elif is_integer: column = column.round(0) - if pd.isna(column).any() and is_integer: - return column - return column.astype(self._dtype) diff --git a/sdv/metadata/single_table.py b/sdv/metadata/single_table.py index ac77f9ff9..799f7ab2b 100644 --- a/sdv/metadata/single_table.py +++ b/sdv/metadata/single_table.py @@ -55,6 +55,8 @@ class SingleTableMetadata: } _NUMERICAL_REPRESENTATIONS = frozenset([ + 'Float32', + 'Float64', 'Float', 'Int64', 'Int32', diff --git a/tests/integration/single_table/test_copulas.py b/tests/integration/single_table/test_copulas.py index a2553c90e..8bf756ccf 100644 --- a/tests/integration/single_table/test_copulas.py +++ b/tests/integration/single_table/test_copulas.py @@ -469,3 +469,37 @@ def test_datetime_values_inside_real_data_range(): assert check_in_synthetic.max() <= check_in_real.max() assert check_out_synthetic.min() >= check_out_real.min() assert check_out_synthetic.max() <= check_out_real.max() + + +def test_support_new_pandas_dtypes(): + """Test that the synthesizer supports the nullable numerical pandas dtypes.""" + # Setup + data = pd.DataFrame({ + 'Int8': pd.Series([1, 2, -3, pd.NA], dtype='Int8'), + 'Int16': pd.Series([1, 2, -3, pd.NA], dtype='Int16'), + 'Int32': pd.Series([1, 2, -3, pd.NA], dtype='Int32'), + 'Int64': pd.Series([1, 2, pd.NA, -3], dtype='Int64'), + 'Float32': pd.Series([1.1, 2.2, 3.3, pd.NA], dtype='Float32'), + 'Float64': pd.Series([1.113, 2.22, 3.3, pd.NA], dtype='Float64'), + }) + metadata = SingleTableMetadata().load_from_dict({ + 'columns': { + 'Int8': {'sdtype': 'numerical', 'computer_representation': 'Int8'}, + 'Int16': {'sdtype': 'numerical', 'computer_representation': 'Int16'}, + 'Int32': {'sdtype': 'numerical', 'computer_representation': 'Int32'}, + 'Int64': {'sdtype': 'numerical', 'computer_representation': 'Int64'}, + 'Float32': {'sdtype': 'numerical', 'computer_representation': 'Float32'}, + 'Float64': {'sdtype': 'numerical', 'computer_representation': 'Float64'}, + } + }) + + synthesizer = GaussianCopulaSynthesizer(metadata) + + # Run + synthesizer.fit(data) + synthetic_data = synthesizer.sample(10) + + # Assert + assert (synthetic_data.dtypes == data.dtypes).all() + assert (synthetic_data['Float32'] == synthetic_data['Float32'].round(1)).all(skipna=True) + assert (synthetic_data['Float64'] == synthetic_data['Float64'].round(3)).all(skipna=True) diff --git a/tests/unit/data_processing/test_numerical_formatter.py b/tests/unit/data_processing/test_numerical_formatter.py index 826599d9e..ec214db58 100644 --- a/tests/unit/data_processing/test_numerical_formatter.py +++ b/tests/unit/data_processing/test_numerical_formatter.py @@ -1,4 +1,4 @@ -from unittest.mock import Mock, patch +from unittest.mock import Mock import numpy as np import pandas as pd @@ -20,96 +20,6 @@ def test___init__(self): assert formatter.enforce_min_max_values is True assert formatter.computer_representation == 'Int8' - @patch('sdv.data_processing.numerical_formatter.LOGGER') - def test__learn_rounding_digits_more_than_15_decimals(self, log_mock): - """Test the ``_learn_rounding_digits`` method with more than 15 decimals. - - If the data has more than 15 decimals, return None and use ``LOGGER`` to inform the user. - """ - # Setup - data = pd.Series(np.random.random(size=10).round(20), name='col') - - # Run - output = NumericalFormatter._learn_rounding_digits(data) - - # Assert - log_msg = ( - "No rounding scheme detected for column 'col'. Synthetic data will not be rounded." - ) - log_mock.info.assert_called_once_with(log_msg) - assert output is None - - def test__learn_rounding_digits_less_than_15_decimals(self): - """Test the ``_learn_rounding_digits`` method with less than 15 decimals. - - If the data has less than 15 decimals, the maximum number of decimals should be returned. - - Input: - - an array that contains floats with a maximum of 3 decimals and a NaN. - - Output: - - 3 - """ - # Setup - data = pd.Series(np.array([10, 0.0, 0.1, 0.12, 0.123, np.nan])) - - # Run - output = NumericalFormatter._learn_rounding_digits(data) - - # Assert - assert output == 3 - - def test__learn_rounding_digits_negative_decimals_float(self): - """Test the ``_learn_rounding_digits`` method with floats multiples of powers of 10. - - If the data has all multiples of 10 the output should be None. - - Input: - - an array that contains floats that are multiples of 10, 100 and 1000 and a NaN. - """ - # Setup - data = pd.Series(np.array([1230.0, 12300.0, 123000.0, np.nan])) - - # Run - output = NumericalFormatter._learn_rounding_digits(data) - - # Assert - assert output == 0 - - def test__learn_rounding_digits_negative_decimals_integer(self): - """Test the ``_learn_rounding_digits`` method with integers multiples of powers of 10. - - If the data has all multiples of 10 the output should be None. - - Input: - - an array that contains integers that are multiples of 10, 100 and 1000 and a NaN. - """ - # Setup - data = pd.Series(np.array([1230, 12300, 123000, np.nan])) - - # Run - output = NumericalFormatter._learn_rounding_digits(data) - - # Assert - assert output == 0 - - def test__learn_rounding_digits_all_nans(self): - """Test the ``_learn_rounding_digits`` method with data that is all NaNs. - - If the data is all NaNs, expect that the output is 0. - - Input: - - an array of NaN. - """ - # Setup - data = pd.Series(np.array([np.nan, np.nan, np.nan, np.nan])) - - # Run - output = NumericalFormatter._learn_rounding_digits(data) - - # Assert - assert output is None - def test_learn_format(self): """Test that ``learn_format`` method.