Skip to content

Commit

Permalink
adding unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
amontanez24 committed Jul 28, 2023
1 parent 794c5e5 commit 8cc966c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 13 deletions.
1 change: 1 addition & 0 deletions sdv/data_processing/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,7 @@ def reverse_transform(self, data, reset_keys=False):
column_names=missing_columns
)
sampled_columns.extend(missing_columns)
reversed_data[anonymized_data.columns] = anonymized_data[anonymized_data.notna()]

if self._keys and num_rows:
generated_keys = self.generate_keys(num_rows, reset_keys)
Expand Down
41 changes: 28 additions & 13 deletions tests/unit/data_processing/test_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from sdv.errors import SynthesizerInputError
from sdv.metadata.single_table import SingleTableMetadata
from sdv.single_table.base import BaseSynthesizer
from tests.utils import DataFrameMatcher


class TestDataProcessor:
Expand Down Expand Up @@ -1821,7 +1822,7 @@ def test__transform_constraints_is_condition_false_returns_data(self):
assert result.equals(expected_result)
assert dp._constraints_to_reverse == []

def test_reverse_transform(self):
def test_reverse_transform_blah(self):
"""Test the ``reverse_transform`` method.
This method should attempt to reverse transform all the columns using the
Expand All @@ -1846,21 +1847,26 @@ def test_reverse_transform(self):
dp = DataProcessor(SingleTableMetadata())
dp.fitted = True
dp.metadata = Mock()
dp.metadata.columns = {'a': None, 'b': None, 'c': None, 'd': None}
dp.metadata.columns = {'a': None, 'b': None, 'c': None, 'key': None, 'd': None}
data = pd.DataFrame({
'a': [1, 2, 3],
'b': [True, True, False],
'c': ['d', 'e', 'f'],

})
dp._keys = ['key']
dp._hyper_transformer = Mock()
dp._hyper_transformer.create_anonymized_columns.return_value = pd.DataFrame({
'd': ['a@gmail.com', 'b@gmail.com', 'c@gmail.com']
})
dp._hyper_transformer.create_anonymized_columns.side_effect = [
pd.DataFrame({'d': ['a@gmail.com', 'b@gmail.com', 'c@gmail.com']}),
pd.DataFrame({'key': ['sdv_0', 'sdv_1', 'sdv_2']})
]
dp._constraints_to_reverse = [constraint_mock]
dp._hyper_transformer.reverse_transform_subset.return_value = data
dp._hyper_transformer.reverse_transform_subset.return_value = data.copy()
dp._hyper_transformer._output_columns = ['a', 'b', 'c']
dp._dtypes = pd.Series(
[np.float64, np.bool_, np.object_, np.object_], index=['a', 'b', 'c', 'd'])
[np.float64, np.bool_, np.object_, np.object_, np.object_],
index=['a', 'b', 'c', 'd', 'key']
)
constraint_mock.reverse_transform.return_value = data

# Run
Expand All @@ -1872,20 +1878,29 @@ def test_reverse_transform(self):
'b': [True, True, False],
'c': ['d', 'e', 'f']
})
constraint_mock.reverse_transform.assert_called_once_with(data)
expected_constraint_input = pd.DataFrame({
'a': [1, 2, 3],
'b': [True, True, False],
'c': ['d', 'e', 'f'],
'd': ['a@gmail.com', 'b@gmail.com', 'c@gmail.com'],
'key': ['sdv_0', 'sdv_1', 'sdv_2']
})
constraint_mock.reverse_transform.assert_called_once_with(
DataFrameMatcher(expected_constraint_input))
data_from_call = dp._hyper_transformer.reverse_transform_subset.mock_calls[0][1][0]
pd.testing.assert_frame_equal(input_data, data_from_call)
dp._hyper_transformer.reverse_transform_subset.assert_called_once()
dp._hyper_transformer.create_anonymized_columns.has_calls(
call(num_rows=3, column_names=['d']),
call(num_rows=3, column_names=['key'])
)
expected_output = pd.DataFrame({
'a': [1., 2., 3.],
'b': [True, True, False],
'c': ['d', 'e', 'f'],
'd': ['a@gmail.com', 'b@gmail.com', 'c@gmail.com']
'key': ['sdv_0', 'sdv_1', 'sdv_2'],
'd': ['a@gmail.com', 'b@gmail.com', 'c@gmail.com'],
})
dp._hyper_transformer.create_anonymized_columns.assert_called_once_with(
num_rows=3,
column_names=['d']
)
pd.testing.assert_frame_equal(reverse_transformed, expected_output)

@patch('sdv.data_processing.data_processor.LOGGER')
Expand Down

0 comments on commit 8cc966c

Please sign in to comment.