From c9b4fc5d1f3c775c7a9a135ef38f4a5b0e01b64e Mon Sep 17 00:00:00 2001 From: Felipe Date: Thu, 17 Aug 2023 13:36:13 -0700 Subject: [PATCH] Docstrings + duplicate method --- sdv/metadata/multi_table.py | 22 ++++++++++++++++------ sdv/multi_table/base.py | 36 +++++++++++++++++++++++++++++------- 2 files changed, 45 insertions(+), 13 deletions(-) diff --git a/sdv/metadata/multi_table.py b/sdv/metadata/multi_table.py index 2506e987b..99e61db3c 100644 --- a/sdv/metadata/multi_table.py +++ b/sdv/metadata/multi_table.py @@ -534,6 +534,7 @@ def validate(self): ) def _validate_missing_tables(self, data): + """Validate the data doesn't have all the columns in the metadata.""" errors = [] missing_tables = set(self.tables) - set(data) if missing_tables: @@ -541,14 +542,12 @@ def _validate_missing_tables(self, data): return errors - def _validate_all_tables(self, data, table_synthesizers=None): + def _validate_all_tables(self, data): + """Validate every table of the data has a valid table/metadata pair.""" errors = [] for table_name, table_data in data.items(): try: - if table_synthesizers: - table_synthesizers[table_name].validate(table_data) - else: - self.tables[table_name].validate_data(table_data) + self.tables[table_name].validate_data(table_data) except InvalidDataError as error: error_msg = f"Table: '{table_name}'" @@ -566,6 +565,7 @@ def _validate_all_tables(self, data, table_synthesizers=None): return errors def _validate_foreign_keys(self, data): + """Validate all foreign key relationships.""" error_msg = None errors = [] for relation in self.relationships: @@ -597,7 +597,17 @@ def _validate_foreign_keys(self, data): return [error_msg] if error_msg else [] def validate_data(self, data): - """Validate the data matches the metadata.""" + """Validate the data matches the metadata. + + Checks the following rules: + * all tables of the metadata are present in the data + * every table of the data satisfies its own metadata + * all foreign keys belong to a primay key + + Args: + data (pd.DataFrame): + The data to validate. + """ errors = [] errors += self._validate_missing_tables(data) errors += self._validate_all_tables(data) diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 92240267c..835775016 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -141,13 +141,27 @@ def get_metadata(self): """Return the ``MultiTableMetadata`` for this synthesizer.""" return self.metadata - def _get_all_foreign_keys(self, table_name): - foreign_keys = [] - for relation in self.metadata.relationships: - if table_name == relation['child_table_name']: - foreign_keys.append(deepcopy(relation['child_foreign_key'])) + def _validate_all_tables(self, data): + """Validate every table of the data has a valid table/metadata pair.""" + errors = [] + for table_name, table_data in data.items(): + try: + self._table_synthesizers[table_name].validate(table_data) - return foreign_keys + except InvalidDataError as error: + error_msg = f"Table: '{table_name}'" + for _error in error.errors: + error_msg += f'\nError: {_error}' + + errors.append(error_msg) + + except ValueError as error: + errors.append(str(error)) + + except KeyError: + continue + + return errors def validate(self, data): """Validate data. @@ -170,7 +184,7 @@ def validate(self, data): """ errors = [] errors += self.metadata._validate_missing_tables(data) - errors += self.metadata._validate_all_tables(data, self._table_synthesizers) + errors += self._validate_all_tables(data) errors += self.metadata._validate_foreign_keys(data) if errors: @@ -180,6 +194,14 @@ def _validate_table_name(self, table_name): if table_name not in self._table_synthesizers: raise InvalidDataError([f"Table '{table_name}' is not present in the metadata."]) + def _get_all_foreign_keys(self, table_name): + foreign_keys = [] + for relation in self.metadata.relationships: + if table_name == relation['child_table_name']: + foreign_keys.append(deepcopy(relation['child_foreign_key'])) + + return foreign_keys + def _assign_table_transformers(self, synthesizer, table_name, table_data): """Update the ``synthesizer`` to ignore the foreign keys while preprocessing the data.""" synthesizer.auto_assign_transformers(table_data)