diff --git a/sdv/data_processing/data_processor.py b/sdv/data_processing/data_processor.py index d2cdc3c74..5ee7ed58b 100644 --- a/sdv/data_processing/data_processor.py +++ b/sdv/data_processing/data_processor.py @@ -9,7 +9,7 @@ import pandas as pd import rdt from pandas.api.types import is_float_dtype, is_integer_dtype -from rdt.transformers import AnonymizedFaker, RegexGenerator +from rdt.transformers import AnonymizedFaker, RegexGenerator, get_default_transformers from sdv.constraints import Constraint from sdv.constraints.base import get_subclasses @@ -58,21 +58,6 @@ class DataProcessor: Faker's default locale. """ - _DEFAULT_TRANSFORMERS_BY_SDTYPE = { - 'numerical': rdt.transformers.FloatFormatter( - learn_rounding_scheme=True, - enforce_min_max_values=True, - missing_value_replacement='mean', - missing_value_generation='random', - ), - 'categorical': rdt.transformers.LabelEncoder(add_noise=True), - 'boolean': rdt.transformers.LabelEncoder(add_noise=True), - 'datetime': rdt.transformers.UnixTimestampEncoder( - missing_value_replacement='mean', - missing_value_generation='random', - ), - 'id': rdt.transformers.RegexGenerator() - } _DTYPE_TO_SDTYPE = { 'i': 'numerical', 'f': 'numerical', @@ -101,7 +86,11 @@ def __init__(self, metadata, enforce_rounding=True, enforce_min_max_values=True, self._constraints = [] self._constraints_to_reverse = [] self._custom_constraint_classes = {} - self._transformers_by_sdtype = self._DEFAULT_TRANSFORMERS_BY_SDTYPE.copy() + + self._transformers_by_sdtype = deepcopy(get_default_transformers()) + self._transformers_by_sdtype['id'] = rdt.transformers.RegexGenerator() + del self._transformers_by_sdtype['text'] + self._update_numerical_transformer(enforce_rounding, enforce_min_max_values) self._hyper_transformer = rdt.HyperTransformer() self.table_name = table_name @@ -464,7 +453,7 @@ def _create_config(self, data, columns_created_by_constraints): for column in set(data.columns) - columns_created_by_constraints: column_metadata = self.metadata.columns.get(column) sdtype = column_metadata.get('sdtype') - pii = column_metadata.get('pii', sdtype not in self._DEFAULT_TRANSFORMERS_BY_SDTYPE) + pii = column_metadata.get('pii', sdtype not in self._transformers_by_sdtype) sdtypes[column] = 'pii' if pii else sdtype if sdtype == 'id': diff --git a/setup.py b/setup.py index dabcd5f29..78dae06fc 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ 'copulas>=0.9.0,<0.10', 'ctgan>=0.7.2,<0.8', 'deepecho>=0.4.1,<0.5', - 'rdt>=1.5.0,<2', + 'rdt>=1.6.1.dev0', 'sdmetrics>=0.10.0,<0.11', 'cloudpickle>=2.1.0,<3.0', 'boto3>=1.15.0,<2', diff --git a/tasks.py b/tasks.py index bde50b3a5..f3f5c82af 100644 --- a/tasks.py +++ b/tasks.py @@ -71,15 +71,14 @@ def install_minimum(c): if _validate_python_version(line): requirement = re.match(r'[^>]*', line).group(0) requirement = re.sub(r"""['",]""", '', requirement) - version = re.search(r'>=?(\d\.?)+', line).group(0) + version = re.search(r'>=?(\d\.?)+\w*', line).group(0) if version: version = re.sub(r'>=?', '==', version) version = re.sub(r"""['",]""", '', version) requirement += version versions.append(requirement) - elif (line.startswith('install_requires = [') or - line.startswith('pomegranate_requires = [')): + elif (line.startswith('install_requires = [')): started = True c.run(f'python -m pip install {" ".join(versions)}') diff --git a/tests/unit/data_processing/test_data_processor.py b/tests/unit/data_processing/test_data_processor.py index 93051a6ea..abc39732b 100644 --- a/tests/unit/data_processing/test_data_processor.py +++ b/tests/unit/data_processing/test_data_processor.py @@ -43,13 +43,20 @@ def test__update_numerical_transformer(self): assert transformer.learn_rounding_scheme is False assert transformer.enforce_min_max_values is False + @patch('sdv.data_processing.data_processor.rdt.transformers.RegexGenerator') + @patch('sdv.data_processing.data_processor.get_default_transformers') @patch('sdv.data_processing.data_processor.rdt') @patch('sdv.data_processing.data_processor.DataProcessor._update_numerical_transformer') - def test___init__(self, update_transformer_mock, mock_rdt): + def test___init__( + self, update_transformer_mock, mock_rdt, mock_default_transformers, mock_regex_generator + ): """Test the ``__init__`` method. Setup: - - Patch the ``Constraint`` module. + - Patch the ``RegexGenerator`` class. + - Patch the ``get_default_transformers`` function. + - Patch the ``rdt`` module. + - Patch the ``_update_numerical_transformer`` method. Input: - A mock for metadata. @@ -63,6 +70,17 @@ def test___init__(self, update_transformer_mock, mock_rdt): metadata.add_alternate_keys(['col_2']) metadata.set_primary_key('col') + mock_default_transformers.return_value = { + 'numerical': 'FloatFormatter()', + 'categorical': 'LabelEncoder(add_noise=True)', + 'boolean': 'LabelEncoder(add_noise=True)', + 'datetime': 'UnixTimestampEncoder()', + 'text': 'RegexGenerator()', + 'pii': 'AnonymizedFaker()', + } + + mock_regex_generator.return_value = 'RegexGenerator()' + # Run data_processor = DataProcessor( metadata=metadata, @@ -90,6 +108,20 @@ def test___init__(self, update_transformer_mock, mock_rdt): assert data_processor._hyper_transformer == mock_rdt.HyperTransformer.return_value update_transformer_mock.assert_called_with(True, False) + mock_default_transformers.assert_called_once() + mock_regex_generator.assert_called_once() + + expected_default_transformers = { + 'numerical': 'FloatFormatter()', + 'categorical': 'LabelEncoder(add_noise=True)', + 'boolean': 'LabelEncoder(add_noise=True)', + 'datetime': 'UnixTimestampEncoder()', + 'id': 'RegexGenerator()', + 'pii': 'AnonymizedFaker()', + } + + assert data_processor._transformers_by_sdtype == expected_default_transformers + def test___init___without_mocks(self): """Test the ``__init__`` method without using mocks.