diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index 4bc080a1e..ccc13fd05 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -468,7 +468,9 @@ def fit_processed_data(self, processed_data): processed_data (pandas.DataFrame): The transformed data used to fit the model to. """ - self._fit(processed_data) + if not processed_data.empty: + self._fit(processed_data) + self._fitted = True self._fitted_date = datetime.datetime.today().strftime('%Y-%m-%d') self._fitted_sdv_version = pkg_resources.get_distribution('sdv').version diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index 572832d0d..8a8db79d7 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -717,6 +717,20 @@ def test_fit_processed_data(self): instance._augment_tables.assert_called_once_with(data) instance._model_tables.assert_called_once_with(instance._augment_tables.return_value) assert instance._fitted + + def test_fit_processed_data_empty_table(self): + """Test the fit attributes are properly set when data is empty.""" + # Setup + instance = Mock() + data = pd.DataFrame() + + # Run + BaseMultiTableSynthesizer.fit_processed_data(instance, data) + + # Assert + assert instance._fitted + assert instance._fitted_date + assert instance._fitted_sdv_version def test_fit(self): """Test that ``fit`` calls ``preprocess`` and then ``fit_processed_data``."""