diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index e6fa27e2e..a774f51c4 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -4,6 +4,7 @@ import math import re import warnings +from unittest.mock import Mock import faker import numpy as np @@ -2019,3 +2020,112 @@ def test_fit_int_primary_key_regex_includes_zero(regex): ) with pytest.raises(SynthesizerInputError, match=message): instance.fit(data) + + +def test__estimate_num_columns_to_be_modeled_various_sdtypes(): + """Test the estimated number of columns is correct for various sdtypes. + + To check that the number of columns is correct we Mock the ``_finalize`` method + and compare its output with the estimated number of columns. + + The dataset used follows the structure below: + R1 R2 + | / + GP + | + P + """ + # Setup + root1 = pd.DataFrame({'R1': [0, 1, 2]}) + root2 = pd.DataFrame({'R2': [0, 1, 2], 'data': [0, 1, 2]}) + grandparent = pd.DataFrame({'GP': [0, 1, 2], 'R1': [0, 1, 2], 'R2': [0, 1, 2]}) + parent = pd.DataFrame({ + 'P': [0, 1, 2], + 'GP': [0, 1, 2], + 'numerical': [0.1, 0.5, np.nan], + 'categorical': ['a', np.nan, 'c'], + 'datetime': [None, '2019-01-02', '2019-01-03'], + 'boolean': [float('nan'), False, True], + 'id': [0, 1, 2], + }) + data = { + 'root1': root1, + 'root2': root2, + 'grandparent': grandparent, + 'parent': parent, + } + metadata = MultiTableMetadata.load_from_dict({ + 'tables': { + 'root1': { + 'primary_key': 'R1', + 'columns': { + 'R1': {'sdtype': 'id'}, + }, + }, + 'root2': { + 'primary_key': 'R2', + 'columns': {'R2': {'sdtype': 'id'}, 'data': {'sdtype': 'numerical'}}, + }, + 'grandparent': { + 'primary_key': 'GP', + 'columns': { + 'GP': {'sdtype': 'id'}, + 'R1': {'sdtype': 'id'}, + 'R2': {'sdtype': 'id'}, + }, + }, + 'parent': { + 'primary_key': 'P', + 'columns': { + 'P': {'sdtype': 'id'}, + 'GP': {'sdtype': 'id'}, + 'numerical': {'sdtype': 'numerical'}, + 'categorical': 
{'sdtype': 'categorical'}, + 'datetime': {'sdtype': 'datetime'}, + 'boolean': {'sdtype': 'boolean'}, + 'id': {'sdtype': 'id'}, + }, + }, + }, + 'relationships': [ + { + 'parent_table_name': 'root1', + 'parent_primary_key': 'R1', + 'child_table_name': 'grandparent', + 'child_foreign_key': 'R1', + }, + { + 'parent_table_name': 'root2', + 'parent_primary_key': 'R2', + 'child_table_name': 'grandparent', + 'child_foreign_key': 'R2', + }, + { + 'parent_table_name': 'grandparent', + 'parent_primary_key': 'GP', + 'child_table_name': 'parent', + 'child_foreign_key': 'GP', + }, + ], + }) + synthesizer = HMASynthesizer(metadata) + synthesizer._finalize = Mock(return_value=data) + + # Run estimation + estimated_num_columns = synthesizer._estimate_num_columns(metadata) + + # Run actual modeling + synthesizer.fit(data) + synthesizer.sample() + + # Assert estimated number of columns is correct + tables = synthesizer._finalize.call_args[0][0] + for table_name, table in tables.items(): + # Subtract all the id columns present in the data, as those are not estimated + num_table_cols = len(table.columns) + if table_name in {'parent', 'grandparent'}: + num_table_cols -= 3 + if table_name in {'root1', 'root2'}: + num_table_cols -= 1 + + assert num_table_cols == estimated_num_columns[table_name] diff --git a/tests/unit/metadata/test_single_table.py b/tests/unit/metadata/test_single_table.py index c5b979d4f..337451f80 100644 --- a/tests/unit/metadata/test_single_table.py +++ b/tests/unit/metadata/test_single_table.py @@ -322,7 +322,8 @@ def test__validate_unexpected_kwargs_invalid(self, column_name, sdtype, kwargs, with pytest.raises(InvalidMetadataError, match=error_msg): instance._validate_unexpected_kwargs(column_name, sdtype, **kwargs) - def test__validate_column_invalid_sdtype(self): + @patch('sdv.metadata.single_table.is_faker_function') + def test__validate_column_invalid_sdtype(self, mock_is_faker_function): """Test the method with an invalid sdtype. 
If the sdtype isn't one of the supported types, anonymized types or Faker functions, @@ -330,6 +331,7 @@ def test__validate_column_invalid_sdtype(self): """ # Setup instance = SingleTableMetadata() + mock_is_faker_function.return_value = False # Run and Assert error_msg = re.escape( @@ -340,11 +342,13 @@ def test__validate_column_invalid_sdtype(self): instance._validate_column_args('column', 'fake_type') error_msg = re.escape( - 'Invalid sdtype: None is not a string. Please use one of the ' 'supported SDV sdtypes.' + 'Invalid sdtype: None is not a string. Please use one of the supported SDV sdtypes.' ) with pytest.raises(InvalidMetadataError, match=error_msg): instance._validate_column_args('column', None) + mock_is_faker_function.assert_called_once_with('fake_type') + @patch('sdv.metadata.single_table.SingleTableMetadata._validate_unexpected_kwargs') @patch('sdv.metadata.single_table.SingleTableMetadata._validate_numerical') def test__validate_column_numerical(self, mock__validate_numerical, mock__validate_kwargs): @@ -599,7 +603,8 @@ def test_add_column_sdtype_not_in_kwargs(self): with pytest.raises(InvalidMetadataError, match=error_msg): instance.add_column('synthetic') - def test_add_column_invalid_sdtype(self): + @patch('sdv.metadata.single_table.is_faker_function') + def test_add_column_invalid_sdtype(self, mock_is_faker_function): """Test the method with an invalid sdtype. If the sdtype isn't one of the supported types, anonymized types or Faker functions, @@ -607,6 +612,7 @@ def test_add_column_invalid_sdtype(self): """ # Setup instance = SingleTableMetadata() + mock_is_faker_function.return_value = False # Run and Assert error_msg = re.escape( @@ -616,6 +622,8 @@ def test_add_column_invalid_sdtype(self): with pytest.raises(InvalidMetadataError, match=error_msg): instance.add_column('column', sdtype='fake_type') + mock_is_faker_function.assert_called_once_with('fake_type') + def test_add_column(self): """Test ``add_column`` method. 
@@ -794,13 +802,15 @@ def test_update_columns_sdtype_in_kwargs_error(self): with pytest.raises(InvalidMetadataError, match=error_msg): instance.update_columns(['col_1', 'col_2'], sdtype='numerical', pii=True) - def test_update_columns_multiple_errors(self): + @patch('sdv.metadata.single_table.is_faker_function') + def test_update_columns_multiple_errors(self, mock_is_faker_function): """Test the ``update_columns`` method. Test that ``update_columns`` with multiple errors. Should raise an ``InvalidMetadataError`` with a summary of all the errors. """ # Setup + mock_is_faker_function.return_value = True instance = SingleTableMetadata() instance.columns = { 'col_1': {'sdtype': 'country_code'}, @@ -817,6 +827,8 @@ def test_update_columns_multiple_errors(self): with pytest.raises(InvalidMetadataError, match=error_msg): instance.update_columns(['col_1', 'col_2', 'col_3'], pii=True) + mock_is_faker_function.assert_called_once_with('country_code') + def test_update_columns(self): """Test the ``update_columns`` method.""" # Setup @@ -839,9 +851,11 @@ def test_update_columns(self): 'salary': {'sdtype': 'categorical'}, } - def test_update_columns_kwargs_without_sdtype(self): + @patch('sdv.metadata.single_table.is_faker_function') + def test_update_columns_kwargs_without_sdtype(self, mock_is_faker_function): """Test the ``update_columns`` method when there is no ``sdtype`` in the kwargs.""" # Setup + mock_is_faker_function.return_value = True instance = SingleTableMetadata() instance.columns = { 'col_1': {'sdtype': 'country_code'}, @@ -859,6 +873,11 @@ def test_update_columns_kwargs_without_sdtype(self): 'col_3': {'sdtype': 'longitude', 'pii': True}, } assert instance._updated is True + mock_is_faker_function.assert_has_calls([ + call('country_code'), + call('latitude'), + call('longitude'), + ]) def test_update_columns_metadata(self): """Test the ``update_columns_metadata`` method.""" @@ -1620,7 +1639,8 @@ def test_set_primary_key_validation_columns(self): 
instance.set_primary_key('b') # NOTE: used to be ('a', 'b', 'd', 'c') - def test_set_primary_key_validation_categorical(self): + @patch('sdv.metadata.single_table.is_faker_function') + def test_set_primary_key_validation_categorical(self, mock_is_faker_function): """Test that ``set_primary_key`` crashes when its sdtype is categorical. Input: @@ -1630,6 +1650,7 @@ def test_set_primary_key_validation_categorical(self): - An ``InvalidMetadataError`` should be raised. """ # Setup + mock_is_faker_function.return_value = False instance = SingleTableMetadata() instance.add_column('column1', sdtype='categorical') instance.add_column('column2', sdtype='categorical') @@ -1640,6 +1661,8 @@ def test_set_primary_key_validation_categorical(self): with pytest.raises(InvalidMetadataError, match=err_msg): instance.set_primary_key('column1') + mock_is_faker_function.assert_called_once_with('categorical') + def test_set_primary_key(self): """Test that ``set_primary_key`` sets the ``_primary_key`` value.""" # Setup @@ -1776,7 +1799,8 @@ def test_set_sequence_key_validation_columns(self): instance.set_sequence_key('b') # NOTE: used to be ('a', 'b', 'd', 'c') - def test_set_sequence_key_validation_categorical(self): + @patch('sdv.metadata.single_table.is_faker_function') + def test_set_sequence_key_validation_categorical(self, mock_is_faker_function): """Test that ``set_sequence_key`` crashes when its sdtype is categorical. Input: @@ -1786,6 +1810,7 @@ def test_set_sequence_key_validation_categorical(self): - An ``InvalidMetadataError`` should be raised. 
""" # Setup + mock_is_faker_function.return_value = False instance = SingleTableMetadata() instance.add_column('column1', sdtype='categorical') instance.add_column('column2', sdtype='categorical') @@ -1796,6 +1821,8 @@ def test_set_sequence_key_validation_categorical(self): with pytest.raises(InvalidMetadataError, match=err_msg): instance.set_sequence_key('column1') + mock_is_faker_function.assert_called_once_with('categorical') + def test_set_sequence_key(self): """Test that ``set_sequence_key`` sets the ``_sequence_key`` value.""" # Setup @@ -1887,7 +1914,8 @@ def test_add_alternate_keys_validation_columns(self): instance.add_alternate_keys(['abc', '123']) # NOTE: used to be ['abc', ('123', '213', '312'), 'bca'] - def test_add_alternate_keys_validation_categorical(self): + @patch('sdv.metadata.single_table.is_faker_function') + def test_add_alternate_keys_validation_categorical(self, mock_is_faker_function): """Test that ``add_alternate_keys`` crashes when its sdtype is categorical. Input: @@ -1897,6 +1925,7 @@ def test_add_alternate_keys_validation_categorical(self): - An ``InvalidMetadataError`` should be raised. """ # Setup + mock_is_faker_function.return_value = False instance = SingleTableMetadata() instance.add_column('column1', sdtype='categorical') instance.add_column('column2', sdtype='categorical') @@ -1909,6 +1938,8 @@ def test_add_alternate_keys_validation_categorical(self): with pytest.raises(InvalidMetadataError, match=err_msg): instance.add_alternate_keys(['column1', 'column2', 'column3']) + mock_is_faker_function.assert_has_calls([call('categorical'), call('categorical')]) + def test_add_alternate_keys_validation_primary_key(self): """Test that ``add_alternate_keys`` crashes when the key is a primary key. 
diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index c1fb04aac..63e363f1b 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -140,9 +140,11 @@ def test___init__( 'SYNTHESIZER ID': 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', }) - def test__init__column_relationship_warning(self): + @patch('sdv.metadata.single_table.is_faker_function') + def test__init__column_relationship_warning(self, mock_is_faker_function): """Test that a warning is raised only once when the metadata has column relationships.""" # Setup + mock_is_faker_function.return_value = True metadata = get_multi_table_metadata() metadata.add_column('nesreca', 'lat', sdtype='latitude') metadata.add_column('nesreca', 'lon', sdtype='longitude') @@ -165,6 +167,10 @@ def test__init__column_relationship_warning(self): warning for warning in caught_warnings if expected_warning in str(warning.message) ] assert len(column_relationship_warnings) == 1 + mock_is_faker_function.assert_has_calls([ + call('latitude'), + call('longitude'), + ]) def test___init___synthesizer_kwargs_deprecated(self): """Test that the ``synthesizer_kwargs`` method is deprecated.""" diff --git a/tests/unit/multi_table/test_hma.py b/tests/unit/multi_table/test_hma.py index 51d5aa1a4..0db06ea82 100644 --- a/tests/unit/multi_table/test_hma.py +++ b/tests/unit/multi_table/test_hma.py @@ -1119,111 +1119,3 @@ def test__estimate_num_columns_to_be_modeled(self): num_table_cols -= 1 assert num_table_cols == estimated_num_columns[table_name] - - def test__estimate_num_columns_to_be_modeled_various_sdtypes(self): - """Test the estimated number of columns is correct for various sdtypes. - - To check that the number columns is correct we Mock the ``_finalize`` method - and compare its output with the estimated number of columns. 
- - The dataset used follows the structure below: - R1 R2 - | / - GP - | - P - """ - # Setup - root1 = pd.DataFrame({'R1': [0, 1, 2]}) - root2 = pd.DataFrame({'R2': [0, 1, 2], 'data': [0, 1, 2]}) - grandparent = pd.DataFrame({'GP': [0, 1, 2], 'R1': [0, 1, 2], 'R2': [0, 1, 2]}) - parent = pd.DataFrame({ - 'P': [0, 1, 2], - 'GP': [0, 1, 2], - 'numerical': [0.1, 0.5, np.nan], - 'categorical': ['a', np.nan, 'c'], - 'datetime': [None, '2019-01-02', '2019-01-03'], - 'boolean': [float('nan'), False, True], - 'id': [0, 1, 2], - }) - data = { - 'root1': root1, - 'root2': root2, - 'grandparent': grandparent, - 'parent': parent, - } - metadata = MultiTableMetadata.load_from_dict({ - 'tables': { - 'root1': { - 'primary_key': 'R1', - 'columns': { - 'R1': {'sdtype': 'id'}, - }, - }, - 'root2': { - 'primary_key': 'R2', - 'columns': {'R2': {'sdtype': 'id'}, 'data': {'sdtype': 'numerical'}}, - }, - 'grandparent': { - 'primary_key': 'GP', - 'columns': { - 'GP': {'sdtype': 'id'}, - 'R1': {'sdtype': 'id'}, - 'R2': {'sdtype': 'id'}, - }, - }, - 'parent': { - 'primary_key': 'P', - 'columns': { - 'P': {'sdtype': 'id'}, - 'GP': {'sdtype': 'id'}, - 'numerical': {'sdtype': 'numerical'}, - 'categorical': {'sdtype': 'categorical'}, - 'datetime': {'sdtype': 'datetime'}, - 'boolean': {'sdtype': 'boolean'}, - 'id': {'sdtype': 'id'}, - }, - }, - }, - 'relationships': [ - { - 'parent_table_name': 'root1', - 'parent_primary_key': 'R1', - 'child_table_name': 'grandparent', - 'child_foreign_key': 'R1', - }, - { - 'parent_table_name': 'root2', - 'parent_primary_key': 'R2', - 'child_table_name': 'grandparent', - 'child_foreign_key': 'R2', - }, - { - 'parent_table_name': 'grandparent', - 'parent_primary_key': 'GP', - 'child_table_name': 'parent', - 'child_foreign_key': 'GP', - }, - ], - }) - synthesizer = HMASynthesizer(metadata) - synthesizer._finalize = Mock(return_value=data) - - # Run estimation - estimated_num_columns = synthesizer._estimate_num_columns(metadata) - - # Run actual modeling - 
synthesizer.fit(data) - synthesizer.sample() - - # Assert estimated number of columns is correct - tables = synthesizer._finalize.call_args[0][0] - for table_name, table in tables.items(): - # Subract all the id columns present in the data, as those are not estimated - num_table_cols = len(table.columns) - if table_name in {'parent', 'grandparent'}: - num_table_cols -= 3 - if table_name in {'root1', 'root2'}: - num_table_cols -= 1 - - assert num_table_cols == estimated_num_columns[table_name]