Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Align default transformers between SDV and RDT #1506

Merged
merged 5 commits into from
Jul 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 7 additions & 18 deletions sdv/data_processing/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pandas as pd
import rdt
from pandas.api.types import is_float_dtype, is_integer_dtype
from rdt.transformers import AnonymizedFaker, RegexGenerator
from rdt.transformers import AnonymizedFaker, RegexGenerator, get_default_transformers

from sdv.constraints import Constraint
from sdv.constraints.base import get_subclasses
Expand Down Expand Up @@ -58,21 +58,6 @@ class DataProcessor:
Faker's default locale.
"""

_DEFAULT_TRANSFORMERS_BY_SDTYPE = {
'numerical': rdt.transformers.FloatFormatter(
learn_rounding_scheme=True,
enforce_min_max_values=True,
missing_value_replacement='mean',
missing_value_generation='random',
),
'categorical': rdt.transformers.LabelEncoder(add_noise=True),
'boolean': rdt.transformers.LabelEncoder(add_noise=True),
'datetime': rdt.transformers.UnixTimestampEncoder(
missing_value_replacement='mean',
missing_value_generation='random',
),
'id': rdt.transformers.RegexGenerator()
}
_DTYPE_TO_SDTYPE = {
'i': 'numerical',
'f': 'numerical',
Expand Down Expand Up @@ -101,7 +86,11 @@ def __init__(self, metadata, enforce_rounding=True, enforce_min_max_values=True,
self._constraints = []
self._constraints_to_reverse = []
self._custom_constraint_classes = {}
self._transformers_by_sdtype = self._DEFAULT_TRANSFORMERS_BY_SDTYPE.copy()

self._transformers_by_sdtype = deepcopy(get_default_transformers())
self._transformers_by_sdtype['id'] = rdt.transformers.RegexGenerator()
del self._transformers_by_sdtype['text']

self._update_numerical_transformer(enforce_rounding, enforce_min_max_values)
self._hyper_transformer = rdt.HyperTransformer()
self.table_name = table_name
Expand Down Expand Up @@ -464,7 +453,7 @@ def _create_config(self, data, columns_created_by_constraints):
for column in set(data.columns) - columns_created_by_constraints:
column_metadata = self.metadata.columns.get(column)
sdtype = column_metadata.get('sdtype')
pii = column_metadata.get('pii', sdtype not in self._DEFAULT_TRANSFORMERS_BY_SDTYPE)
pii = column_metadata.get('pii', sdtype not in self._transformers_by_sdtype)
sdtypes[column] = 'pii' if pii else sdtype

if sdtype == 'id':
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
'copulas>=0.9.0,<0.10',
'ctgan>=0.7.2,<0.8',
'deepecho>=0.4.1,<0.5',
'rdt>=1.5.0,<2',
'rdt>=1.6.1.dev0',
'sdmetrics>=0.10.0,<0.11',
'cloudpickle>=2.1.0,<3.0',
'boto3>=1.15.0,<2',
Expand Down
5 changes: 2 additions & 3 deletions tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,14 @@ def install_minimum(c):
if _validate_python_version(line):
requirement = re.match(r'[^>]*', line).group(0)
requirement = re.sub(r"""['",]""", '', requirement)
version = re.search(r'>=?(\d\.?)+', line).group(0)
version = re.search(r'>=?(\d\.?)+\w*', line).group(0)
if version:
version = re.sub(r'>=?', '==', version)
version = re.sub(r"""['",]""", '', version)
requirement += version
versions.append(requirement)

elif (line.startswith('install_requires = [') or
line.startswith('pomegranate_requires = [')):
elif (line.startswith('install_requires = [')):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you leave a comment in the PR description explaining why we made this change and the one on line 74? If we're going to merge them as part of this PR, it would be good to understand why.

started = True

c.run(f'python -m pip install {" ".join(versions)}')
Expand Down
36 changes: 34 additions & 2 deletions tests/unit/data_processing/test_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,20 @@ def test__update_numerical_transformer(self):
assert transformer.learn_rounding_scheme is False
assert transformer.enforce_min_max_values is False

@patch('sdv.data_processing.data_processor.rdt.transformers.RegexGenerator')
@patch('sdv.data_processing.data_processor.get_default_transformers')
@patch('sdv.data_processing.data_processor.rdt')
@patch('sdv.data_processing.data_processor.DataProcessor._update_numerical_transformer')
def test___init__(self, update_transformer_mock, mock_rdt):
def test___init__(
self, update_transformer_mock, mock_rdt, mock_default_transformers, mock_regex_generator
):
"""Test the ``__init__`` method.

Setup:
- Patch the ``Constraint`` module.
- Patch the ``RegexGenerator`` class.
- Patch the ``get_default_transformers`` function.
- Patch the ``rdt`` module.
- Patch the ``_update_numerical_transformer`` method.

Input:
- A mock for metadata.
Expand All @@ -63,6 +70,17 @@ def test___init__(self, update_transformer_mock, mock_rdt):
metadata.add_alternate_keys(['col_2'])
metadata.set_primary_key('col')

mock_default_transformers.return_value = {
'numerical': 'FloatFormatter()',
'categorical': 'LabelEncoder(add_noise=True)',
'boolean': 'LabelEncoder(add_noise=True)',
'datetime': 'UnixTimestampEncoder()',
'text': 'RegexGenerator()',
'pii': 'AnonymizedFaker()',
}

mock_regex_generator.return_value = 'RegexGenerator()'

# Run
data_processor = DataProcessor(
metadata=metadata,
Expand Down Expand Up @@ -90,6 +108,20 @@ def test___init__(self, update_transformer_mock, mock_rdt):
assert data_processor._hyper_transformer == mock_rdt.HyperTransformer.return_value
update_transformer_mock.assert_called_with(True, False)

mock_default_transformers.assert_called_once()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should mock the return value and make sure that the `id` default is properly added and that the `text` one is deleted.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, done in d84def8.

mock_regex_generator.assert_called_once()

expected_default_transformers = {
'numerical': 'FloatFormatter()',
'categorical': 'LabelEncoder(add_noise=True)',
'boolean': 'LabelEncoder(add_noise=True)',
'datetime': 'UnixTimestampEncoder()',
'id': 'RegexGenerator()',
'pii': 'AnonymizedFaker()',
}

assert data_processor._transformers_by_sdtype == expected_default_transformers

def test___init___without_mocks(self):
"""Test the ``__init__`` method without using mocks.

Expand Down
Loading