Skip to content

Commit

Permalink
Mock every usage of is_faker_function to speed up the unit tests (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo authored Aug 7, 2024
1 parent 0d24da1 commit 17aaa28
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 117 deletions.
110 changes: 110 additions & 0 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import math
import re
import warnings
from unittest.mock import Mock

import faker
import numpy as np
Expand Down Expand Up @@ -2019,3 +2020,112 @@ def test_fit_int_primary_key_regex_includes_zero(regex):
)
with pytest.raises(SynthesizerInputError, match=message):
instance.fit(data)


def test__estimate_num_columns_to_be_modeled_various_sdtypes():
    """Check the estimated number of modeled columns for a mix of sdtypes.

    ``_finalize`` is swapped for a ``Mock`` so the tables passed to its most
    recent call can be inspected. For every table, the number of columns it
    holds (minus its id columns) must equal the value reported by
    ``_estimate_num_columns``.

    The dataset used follows the structure below:
        R1 R2
        | /
        GP
        |
        P
    """
    # Setup
    data = {
        'root1': pd.DataFrame({'R1': [0, 1, 2]}),
        'root2': pd.DataFrame({'R2': [0, 1, 2], 'data': [0, 1, 2]}),
        'grandparent': pd.DataFrame({'GP': [0, 1, 2], 'R1': [0, 1, 2], 'R2': [0, 1, 2]}),
        'parent': pd.DataFrame({
            'P': [0, 1, 2],
            'GP': [0, 1, 2],
            'numerical': [0.1, 0.5, np.nan],
            'categorical': ['a', np.nan, 'c'],
            'datetime': [None, '2019-01-02', '2019-01-03'],
            'boolean': [float('nan'), False, True],
            'id': [0, 1, 2],
        }),
    }

    # Each relationship links a parent's primary key to a foreign key of the
    # same name in the child table.
    relationships = [
        {
            'parent_table_name': parent_table,
            'parent_primary_key': key,
            'child_table_name': child_table,
            'child_foreign_key': key,
        }
        for parent_table, child_table, key in (
            ('root1', 'grandparent', 'R1'),
            ('root2', 'grandparent', 'R2'),
            ('grandparent', 'parent', 'GP'),
        )
    ]
    table_definitions = {
        'root1': {
            'primary_key': 'R1',
            'columns': {
                'R1': {'sdtype': 'id'},
            },
        },
        'root2': {
            'primary_key': 'R2',
            'columns': {
                'R2': {'sdtype': 'id'},
                'data': {'sdtype': 'numerical'},
            },
        },
        'grandparent': {
            'primary_key': 'GP',
            'columns': {
                'GP': {'sdtype': 'id'},
                'R1': {'sdtype': 'id'},
                'R2': {'sdtype': 'id'},
            },
        },
        'parent': {
            'primary_key': 'P',
            'columns': {
                'P': {'sdtype': 'id'},
                'GP': {'sdtype': 'id'},
                'numerical': {'sdtype': 'numerical'},
                'categorical': {'sdtype': 'categorical'},
                'datetime': {'sdtype': 'datetime'},
                'boolean': {'sdtype': 'boolean'},
                'id': {'sdtype': 'id'},
            },
        },
    }
    metadata = MultiTableMetadata.load_from_dict({
        'tables': table_definitions,
        'relationships': relationships,
    })

    synthesizer = HMASynthesizer(metadata)
    synthesizer._finalize = Mock(return_value=data)

    # Run estimation
    estimated_num_columns = synthesizer._estimate_num_columns(metadata)

    # Run actual modeling
    synthesizer.fit(data)
    synthesizer.sample()

    # Assert estimated number of columns is correct
    # The id columns present in each table are not part of the estimate, so
    # they are subtracted from the observed column count before comparing.
    id_columns_per_table = {
        'parent': 3,
        'grandparent': 3,
        'root1': 1,
        'root2': 1,
    }
    finalized_tables = synthesizer._finalize.call_args[0][0]
    for table_name, table in finalized_tables.items():
        num_table_cols = len(table.columns) - id_columns_per_table.get(table_name, 0)

        assert num_table_cols == estimated_num_columns[table_name]
47 changes: 39 additions & 8 deletions tests/unit/metadata/test_single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,14 +322,16 @@ def test__validate_unexpected_kwargs_invalid(self, column_name, sdtype, kwargs,
with pytest.raises(InvalidMetadataError, match=error_msg):
instance._validate_unexpected_kwargs(column_name, sdtype, **kwargs)

def test__validate_column_invalid_sdtype(self):
@patch('sdv.metadata.single_table.is_faker_function')
def test__validate_column_invalid_sdtype(self, mock_is_faker_function):
"""Test the method with an invalid sdtype.
If the sdtype isn't one of the supported types, anonymized types or Faker functions,
then an error should be raised.
"""
# Setup
instance = SingleTableMetadata()
mock_is_faker_function.return_value = False

# Run and Assert
error_msg = re.escape(
Expand All @@ -340,11 +342,13 @@ def test__validate_column_invalid_sdtype(self):
instance._validate_column_args('column', 'fake_type')

error_msg = re.escape(
'Invalid sdtype: None is not a string. Please use one of the ' 'supported SDV sdtypes.'
'Invalid sdtype: None is not a string. Please use one of the supported SDV sdtypes.'
)
with pytest.raises(InvalidMetadataError, match=error_msg):
instance._validate_column_args('column', None)

mock_is_faker_function.assert_called_once_with('fake_type')

@patch('sdv.metadata.single_table.SingleTableMetadata._validate_unexpected_kwargs')
@patch('sdv.metadata.single_table.SingleTableMetadata._validate_numerical')
def test__validate_column_numerical(self, mock__validate_numerical, mock__validate_kwargs):
Expand Down Expand Up @@ -599,14 +603,16 @@ def test_add_column_sdtype_not_in_kwargs(self):
with pytest.raises(InvalidMetadataError, match=error_msg):
instance.add_column('synthetic')

def test_add_column_invalid_sdtype(self):
@patch('sdv.metadata.single_table.is_faker_function')
def test_add_column_invalid_sdtype(self, mock_is_faker_function):
"""Test the method with an invalid sdtype.
If the sdtype isn't one of the supported types, anonymized types or Faker functions,
then an error should be raised.
"""
# Setup
instance = SingleTableMetadata()
mock_is_faker_function.return_value = False

# Run and Assert
error_msg = re.escape(
Expand All @@ -616,6 +622,8 @@ def test_add_column_invalid_sdtype(self):
with pytest.raises(InvalidMetadataError, match=error_msg):
instance.add_column('column', sdtype='fake_type')

mock_is_faker_function.assert_called_once_with('fake_type')

def test_add_column(self):
"""Test ``add_column`` method.
Expand Down Expand Up @@ -794,13 +802,15 @@ def test_update_columns_sdtype_in_kwargs_error(self):
with pytest.raises(InvalidMetadataError, match=error_msg):
instance.update_columns(['col_1', 'col_2'], sdtype='numerical', pii=True)

def test_update_columns_multiple_errors(self):
@patch('sdv.metadata.single_table.is_faker_function')
def test_update_columns_multiple_errors(self, mock_is_faker_function):
"""Test the ``update_columns`` method.
Test that ``update_columns`` with multiple errors.
Should raise an ``InvalidMetadataError`` with a summary of all the errors.
"""
# Setup
mock_is_faker_function.return_value = True
instance = SingleTableMetadata()
instance.columns = {
'col_1': {'sdtype': 'country_code'},
Expand All @@ -817,6 +827,8 @@ def test_update_columns_multiple_errors(self):
with pytest.raises(InvalidMetadataError, match=error_msg):
instance.update_columns(['col_1', 'col_2', 'col_3'], pii=True)

mock_is_faker_function.assert_called_once_with('country_code')

def test_update_columns(self):
"""Test the ``update_columns`` method."""
# Setup
Expand All @@ -839,9 +851,11 @@ def test_update_columns(self):
'salary': {'sdtype': 'categorical'},
}

def test_update_columns_kwargs_without_sdtype(self):
@patch('sdv.metadata.single_table.is_faker_function')
def test_update_columns_kwargs_without_sdtype(self, mock_is_faker_function):
"""Test the ``update_columns`` method when there is no ``sdtype`` in the kwargs."""
# Setup
mock_is_faker_function.return_value = True
instance = SingleTableMetadata()
instance.columns = {
'col_1': {'sdtype': 'country_code'},
Expand All @@ -859,6 +873,11 @@ def test_update_columns_kwargs_without_sdtype(self):
'col_3': {'sdtype': 'longitude', 'pii': True},
}
assert instance._updated is True
mock_is_faker_function.assert_has_calls([
call('country_code'),
call('latitude'),
call('longitude'),
])

def test_update_columns_metadata(self):
"""Test the ``update_columns_metadata`` method."""
Expand Down Expand Up @@ -1620,7 +1639,8 @@ def test_set_primary_key_validation_columns(self):
instance.set_primary_key('b')
# NOTE: used to be ('a', 'b', 'd', 'c')

def test_set_primary_key_validation_categorical(self):
@patch('sdv.metadata.single_table.is_faker_function')
def test_set_primary_key_validation_categorical(self, mock_is_faker_function):
"""Test that ``set_primary_key`` crashes when its sdtype is categorical.
Input:
Expand All @@ -1630,6 +1650,7 @@ def test_set_primary_key_validation_categorical(self):
- An ``InvalidMetadataError`` should be raised.
"""
# Setup
mock_is_faker_function.return_value = False
instance = SingleTableMetadata()
instance.add_column('column1', sdtype='categorical')
instance.add_column('column2', sdtype='categorical')
Expand All @@ -1640,6 +1661,8 @@ def test_set_primary_key_validation_categorical(self):
with pytest.raises(InvalidMetadataError, match=err_msg):
instance.set_primary_key('column1')

mock_is_faker_function.assert_called_once_with('categorical')

def test_set_primary_key(self):
"""Test that ``set_primary_key`` sets the ``_primary_key`` value."""
# Setup
Expand Down Expand Up @@ -1776,7 +1799,8 @@ def test_set_sequence_key_validation_columns(self):
instance.set_sequence_key('b')
# NOTE: used to be ('a', 'b', 'd', 'c')

def test_set_sequence_key_validation_categorical(self):
@patch('sdv.metadata.single_table.is_faker_function')
def test_set_sequence_key_validation_categorical(self, mock_is_faker_function):
"""Test that ``set_sequence_key`` crashes when its sdtype is categorical.
Input:
Expand All @@ -1786,6 +1810,7 @@ def test_set_sequence_key_validation_categorical(self):
- An ``InvalidMetadataError`` should be raised.
"""
# Setup
mock_is_faker_function.return_value = False
instance = SingleTableMetadata()
instance.add_column('column1', sdtype='categorical')
instance.add_column('column2', sdtype='categorical')
Expand All @@ -1796,6 +1821,8 @@ def test_set_sequence_key_validation_categorical(self):
with pytest.raises(InvalidMetadataError, match=err_msg):
instance.set_sequence_key('column1')

mock_is_faker_function.assert_called_once_with('categorical')

def test_set_sequence_key(self):
"""Test that ``set_sequence_key`` sets the ``_sequence_key`` value."""
# Setup
Expand Down Expand Up @@ -1887,7 +1914,8 @@ def test_add_alternate_keys_validation_columns(self):
instance.add_alternate_keys(['abc', '123'])
# NOTE: used to be ['abc', ('123', '213', '312'), 'bca']

def test_add_alternate_keys_validation_categorical(self):
@patch('sdv.metadata.single_table.is_faker_function')
def test_add_alternate_keys_validation_categorical(self, mock_is_faker_function):
"""Test that ``add_alternate_keys`` crashes when its sdtype is categorical.
Input:
Expand All @@ -1897,6 +1925,7 @@ def test_add_alternate_keys_validation_categorical(self):
- An ``InvalidMetadataError`` should be raised.
"""
# Setup
mock_is_faker_function.return_value = False
instance = SingleTableMetadata()
instance.add_column('column1', sdtype='categorical')
instance.add_column('column2', sdtype='categorical')
Expand All @@ -1909,6 +1938,8 @@ def test_add_alternate_keys_validation_categorical(self):
with pytest.raises(InvalidMetadataError, match=err_msg):
instance.add_alternate_keys(['column1', 'column2', 'column3'])

mock_is_faker_function.assert_has_calls([call('categorical'), call('categorical')])

def test_add_alternate_keys_validation_primary_key(self):
"""Test that ``add_alternate_keys`` crashes when the key is a primary key.
Expand Down
8 changes: 7 additions & 1 deletion tests/unit/multi_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,11 @@ def test___init__(
'SYNTHESIZER ID': 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5',
})

def test__init__column_relationship_warning(self):
@patch('sdv.metadata.single_table.is_faker_function')
def test__init__column_relationship_warning(self, mock_is_faker_function):
"""Test that a warning is raised only once when the metadata has column relationships."""
# Setup
mock_is_faker_function.return_value = True
metadata = get_multi_table_metadata()
metadata.add_column('nesreca', 'lat', sdtype='latitude')
metadata.add_column('nesreca', 'lon', sdtype='longitude')
Expand All @@ -165,6 +167,10 @@ def test__init__column_relationship_warning(self):
warning for warning in caught_warnings if expected_warning in str(warning.message)
]
assert len(column_relationship_warnings) == 1
mock_is_faker_function.assert_has_calls([
call('latitude'),
call('longitude'),
])

def test___init___synthesizer_kwargs_deprecated(self):
"""Test that the ``synthesizer_kwargs`` method is deprecated."""
Expand Down
Loading

0 comments on commit 17aaa28

Please sign in to comment.