From b70aa3a54172ccf4370266c721db803daceb0438 Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev Date: Thu, 3 Oct 2024 18:28:27 +0200 Subject: [PATCH] Update creating single table synthesizer within multitable synthesizer to use Metadata instead of SingleTableMetadata --- sdv/metadata/metadata.py | 1 + sdv/multi_table/base.py | 6 +++++- sdv/multi_table/hma.py | 13 +++++++------ sdv/sampling/hierarchical_sampler.py | 9 +++------ tests/unit/multi_table/test_hma.py | 7 +++++-- 5 files changed, 21 insertions(+), 15 deletions(-) diff --git a/sdv/metadata/metadata.py b/sdv/metadata/metadata.py index 1d0ebaec2..6075b8d6c 100644 --- a/sdv/metadata/metadata.py +++ b/sdv/metadata/metadata.py @@ -131,6 +131,7 @@ def _set_metadata_dict(self, metadata, single_table_name=None): 'No table name was provided to metadata containing only one table. ' f'Assigning name: {single_table_name}' ) + self.tables[single_table_name] = SingleTableMetadata.load_from_dict(metadata) def _get_single_table_name(self): diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 843653d06..fa8bb3b4d 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -209,8 +209,12 @@ def set_table_parameters(self, table_name, table_parameters): A dictionary with the parameters as keys and the values to be used to instantiate the table's synthesizer. """ + single_table_metadata = Metadata.load_from_dict( + self.metadata.tables[table_name].to_dict(), single_table_name=table_name + ) + self._table_synthesizers[table_name] = self._synthesizer( - metadata=self.metadata.tables[table_name], **table_parameters + metadata=single_table_metadata, **table_parameters ) self._table_parameters[table_name].update(deepcopy(table_parameters)) diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index 5de322ac4..f0de2045b 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -11,6 +11,7 @@ from sdv._utils import _get_root_tables from sdv.errors import SynthesizerInputError +from sdv.metadata import Metadata from sdv.multi_table.base import BaseMultiTableSynthesizer from sdv.sampling import BaseHierarchicalSampler @@ -23,9 +24,8 @@ class HMASynthesizer(BaseHierarchicalSampler, BaseMultiTableSynthesizer): """Hierarchical Modeling Algorithm One. Args: - metadata (sdv.metadata.multi_table.MultiTableMetadata): - Multi table metadata representing the data tables that this synthesizer will be used - for. + metadata (sdv.metadata.Metadata): + Metadata representing the data tables that this synthesizer will be used for. locales (list or str): The default locale(s) to use for AnonymizedFaker transformers. Defaults to ``['en_US']``. @@ -47,8 +47,8 @@ def _get_num_data_columns(metadata): """Get the number of data columns, ie colums that are not id, for each table. Args: - metadata (MultiTableMetadata): - Metadata of the datasets. + metadata (sdv.metadata.Metadata): + Metadata representing the data tables that this synthesizer will be used for. """ columns_per_table = {} for table_name, table in metadata.tables.items(): @@ -552,8 +552,9 @@ def _recreate_child_synthesizer(self, child_name, parent_name, parent_row): if parent_row is not None: parameters = self._extract_parameters(parent_row, child_name, foreign_key) default_parameters = getattr(self, '_default_parameters', {}).get(child_name, {}) - table_meta = self.metadata.tables[child_name] + table_meta = Metadata.load_from_dict(table_meta.to_dict(), single_table_name=child_name) + synthesizer = self._synthesizer(table_meta, **self._table_parameters[child_name]) synthesizer._set_parameters(parameters, default_parameters) else: diff --git a/sdv/sampling/hierarchical_sampler.py b/sdv/sampling/hierarchical_sampler.py index be0fbaa3e..3606d5cbf 100644 --- a/sdv/sampling/hierarchical_sampler.py +++ b/sdv/sampling/hierarchical_sampler.py @@ -95,13 +95,10 @@ def _add_child_rows(self, child_name, parent_name, parent_row, sampled_data, num foreign_key = self.metadata._get_foreign_keys(parent_name, child_name)[0] if num_rows is None: num_rows = parent_row[f'__{child_name}__{foreign_key}__num_rows'] - with warnings.catch_warnings(): - warnings.filterwarnings('ignore', message=".*The 'SingleTableMetadata' is deprecated.*") - child_synthesizer = self._recreate_child_synthesizer( - child_name, parent_name, parent_row - ) - sampled_rows = self._sample_rows(child_synthesizer, num_rows) + child_synthesizer = self._recreate_child_synthesizer(child_name, parent_name, parent_row) + + sampled_rows = self._sample_rows(child_synthesizer, num_rows) if len(sampled_rows): parent_key = self.metadata.tables[parent_name].primary_key if foreign_key in sampled_rows: diff --git a/tests/unit/multi_table/test_hma.py b/tests/unit/multi_table/test_hma.py index 727b40f5a..f5a9ce47b 100644 --- a/tests/unit/multi_table/test_hma.py +++ b/tests/unit/multi_table/test_hma.py @@ -451,7 +451,8 @@ def test__extract_parameters(self): assert result == expected_result - def test__recreate_child_synthesizer(self): + @patch('sdv.multi_table.hma.Metadata') + def test__recreate_child_synthesizer(self, mock_metadata): """Test that this method returns a synthesizer for the given child table.""" # Setup instance = Mock() @@ -477,7 +478,9 @@ def test__recreate_child_synthesizer(self): # Assert assert synthesizer == instance._synthesizer.return_value assert synthesizer._data_processor == table_synthesizer._data_processor - instance._synthesizer.assert_called_once_with(table_meta, a=1) + instance._synthesizer.assert_called_once_with( + mock_metadata.load_from_dict.return_value, a=1 + ) synthesizer._set_parameters.assert_called_once_with( instance._extract_parameters.return_value, {'colA': 'default_param', 'colB': 'default_param'},