Skip to content

Commit

Permalink
Align default transformers between SDV and RDT (#1506)
Browse files: browse the repository at this point in the history
* use rdt default transformers

* fix minimum version

* test

* pomegranate 1

* change assert test
  • Loading branch information
R-Palazzo authored Jul 20, 2023
1 parent 64c5d9a commit 64afb7a
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 24 deletions.
25 changes: 7 additions & 18 deletions sdv/data_processing/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pandas as pd
import rdt
from pandas.api.types import is_float_dtype, is_integer_dtype
from rdt.transformers import AnonymizedFaker, RegexGenerator
from rdt.transformers import AnonymizedFaker, RegexGenerator, get_default_transformers

from sdv.constraints import Constraint
from sdv.constraints.base import get_subclasses
Expand Down Expand Up @@ -58,21 +58,6 @@ class DataProcessor:
Faker's default locale.
"""

_DEFAULT_TRANSFORMERS_BY_SDTYPE = {
'numerical': rdt.transformers.FloatFormatter(
learn_rounding_scheme=True,
enforce_min_max_values=True,
missing_value_replacement='mean',
missing_value_generation='random',
),
'categorical': rdt.transformers.LabelEncoder(add_noise=True),
'boolean': rdt.transformers.LabelEncoder(add_noise=True),
'datetime': rdt.transformers.UnixTimestampEncoder(
missing_value_replacement='mean',
missing_value_generation='random',
),
'id': rdt.transformers.RegexGenerator()
}
_DTYPE_TO_SDTYPE = {
'i': 'numerical',
'f': 'numerical',
Expand Down Expand Up @@ -101,7 +86,11 @@ def __init__(self, metadata, enforce_rounding=True, enforce_min_max_values=True,
self._constraints = []
self._constraints_to_reverse = []
self._custom_constraint_classes = {}
self._transformers_by_sdtype = self._DEFAULT_TRANSFORMERS_BY_SDTYPE.copy()

self._transformers_by_sdtype = deepcopy(get_default_transformers())
self._transformers_by_sdtype['id'] = rdt.transformers.RegexGenerator()
del self._transformers_by_sdtype['text']

self._update_numerical_transformer(enforce_rounding, enforce_min_max_values)
self._hyper_transformer = rdt.HyperTransformer()
self.table_name = table_name
Expand Down Expand Up @@ -464,7 +453,7 @@ def _create_config(self, data, columns_created_by_constraints):
for column in set(data.columns) - columns_created_by_constraints:
column_metadata = self.metadata.columns.get(column)
sdtype = column_metadata.get('sdtype')
pii = column_metadata.get('pii', sdtype not in self._DEFAULT_TRANSFORMERS_BY_SDTYPE)
pii = column_metadata.get('pii', sdtype not in self._transformers_by_sdtype)
sdtypes[column] = 'pii' if pii else sdtype

if sdtype == 'id':
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
'copulas>=0.9.0,<0.10',
'ctgan>=0.7.2,<0.8',
'deepecho>=0.4.1,<0.5',
'rdt>=1.5.0,<2',
'rdt>=1.6.1.dev0',
'sdmetrics>=0.10.0,<0.11',
'cloudpickle>=2.1.0,<3.0',
'boto3>=1.15.0,<2',
Expand Down
5 changes: 2 additions & 3 deletions tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,14 @@ def install_minimum(c):
if _validate_python_version(line):
requirement = re.match(r'[^>]*', line).group(0)
requirement = re.sub(r"""['",]""", '', requirement)
version = re.search(r'>=?(\d\.?)+', line).group(0)
version = re.search(r'>=?(\d\.?)+\w*', line).group(0)
if version:
version = re.sub(r'>=?', '==', version)
version = re.sub(r"""['",]""", '', version)
requirement += version
versions.append(requirement)

elif (line.startswith('install_requires = [') or
line.startswith('pomegranate_requires = [')):
elif (line.startswith('install_requires = [')):
started = True

c.run(f'python -m pip install {" ".join(versions)}')
Expand Down
36 changes: 34 additions & 2 deletions tests/unit/data_processing/test_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,20 @@ def test__update_numerical_transformer(self):
assert transformer.learn_rounding_scheme is False
assert transformer.enforce_min_max_values is False

@patch('sdv.data_processing.data_processor.rdt.transformers.RegexGenerator')
@patch('sdv.data_processing.data_processor.get_default_transformers')
@patch('sdv.data_processing.data_processor.rdt')
@patch('sdv.data_processing.data_processor.DataProcessor._update_numerical_transformer')
def test___init__(self, update_transformer_mock, mock_rdt):
def test___init__(
self, update_transformer_mock, mock_rdt, mock_default_transformers, mock_regex_generator
):
"""Test the ``__init__`` method.
Setup:
- Patch the ``Constraint`` module.
- Patch the ``RegexGenerator`` class.
- Patch the ``get_default_transformers`` function.
- Patch the ``rdt`` module.
- Patch the ``_update_numerical_transformer`` method.
Input:
- A mock for metadata.
Expand All @@ -63,6 +70,17 @@ def test___init__(self, update_transformer_mock, mock_rdt):
metadata.add_alternate_keys(['col_2'])
metadata.set_primary_key('col')

mock_default_transformers.return_value = {
'numerical': 'FloatFormatter()',
'categorical': 'LabelEncoder(add_noise=True)',
'boolean': 'LabelEncoder(add_noise=True)',
'datetime': 'UnixTimestampEncoder()',
'text': 'RegexGenerator()',
'pii': 'AnonymizedFaker()',
}

mock_regex_generator.return_value = 'RegexGenerator()'

# Run
data_processor = DataProcessor(
metadata=metadata,
Expand Down Expand Up @@ -90,6 +108,20 @@ def test___init__(self, update_transformer_mock, mock_rdt):
assert data_processor._hyper_transformer == mock_rdt.HyperTransformer.return_value
update_transformer_mock.assert_called_with(True, False)

mock_default_transformers.assert_called_once()
mock_regex_generator.assert_called_once()

expected_default_transformers = {
'numerical': 'FloatFormatter()',
'categorical': 'LabelEncoder(add_noise=True)',
'boolean': 'LabelEncoder(add_noise=True)',
'datetime': 'UnixTimestampEncoder()',
'id': 'RegexGenerator()',
'pii': 'AnonymizedFaker()',
}

assert data_processor._transformers_by_sdtype == expected_default_transformers

def test___init___without_mocks(self):
"""Test the ``__init__`` method without using mocks.
Expand Down

0 comments on commit 64afb7a

Please sign in to comment.