diff --git a/sdv/constraints/base.py b/sdv/constraints/base.py index 158ebc1e9..2d7d97fc9 100644 --- a/sdv/constraints/base.py +++ b/sdv/constraints/base.py @@ -14,7 +14,7 @@ from sdv.constraints.errors import ( AggregateConstraintsError, ConstraintMetadataError, MissingConstraintColumnError) from sdv.errors import ConstraintsNotMetError -from sdv.utils import groupby_list +from sdv.utils import format_invalid_values_string, groupby_list LOGGER = logging.getLogger(__name__) @@ -201,12 +201,11 @@ def _validate_data_meets_constraint(self, table_data): if not is_valid_data.all(): constraint_data = table_data[list(self.constraint_columns)] invalid_rows = constraint_data[~is_valid_data] + invalid_rows_str = format_invalid_values_string(invalid_rows) err_msg = ( f"Data is not valid for the '{self.__class__.__name__}' constraint:\n" - f'{invalid_rows[:5]}' + f'{invalid_rows_str}' ) - if len(invalid_rows) > 5: - err_msg += f'\n+{len(invalid_rows) - 5} more' raise ConstraintsNotMetError(err_msg) diff --git a/sdv/metadata/single_table.py b/sdv/metadata/single_table.py index 84bc53129..5a41d170a 100644 --- a/sdv/metadata/single_table.py +++ b/sdv/metadata/single_table.py @@ -9,7 +9,6 @@ import pandas as pd -from sdv.constraints.errors import AggregateConstraintsError from sdv.errors import InvalidDataError from sdv.metadata.anonymization import SDTYPE_ANONYMIZERS, is_faker_function from sdv.metadata.errors import InvalidMetadataError @@ -18,8 +17,8 @@ from sdv.metadata.visualization import ( create_columns_node, create_summarized_columns_node, visualize_graph) from sdv.utils import ( - cast_to_iterable, is_boolean_type, is_datetime_type, is_numerical_type, load_data_from_csv, - validate_datetime_format) + cast_to_iterable, format_invalid_values_string, is_boolean_type, is_datetime_type, + is_numerical_type, load_data_from_csv, validate_datetime_format) LOGGER = logging.getLogger(__name__) @@ -542,22 +541,14 @@ def _validate_keys_dont_have_missing_values(self, data): return errors - @staticmethod - def _format_invalid_values_string(invalid_values): - invalid_values = sorted(invalid_values, key=lambda x: str(x)) - if len(invalid_values) > 3: - return invalid_values[:3] + [f'+ {len(invalid_values) - 3} more'] - - return invalid_values - def _validate_key_values_are_unique(self, data): errors = [] keys = self._get_primary_and_alternate_keys() for key in sorted(keys): repeated_values = set(data[key][data[key].duplicated()]) if repeated_values: - repeated_values = self._format_invalid_values_string(repeated_values) - errors.append(f"Key column '{key}' contains repeating values: {repeated_values}") + repeated_values = format_invalid_values_string(repeated_values) + errors.append(f"Key column '{key}' contains repeating values: " + repeated_values) return errors @@ -597,21 +588,11 @@ def _validate_column_data(self, column): ) if invalid_values: - invalid_values = self._format_invalid_values_string(invalid_values) + invalid_values = format_invalid_values_string(invalid_values) return [f"Invalid values found for {sdtype} column '{column.name}': {invalid_values}."] return [] - def _validate_constraints(self, data): - """Validate that the data satisfies the constraints.""" - errors = [] - try: - self._data_processor._fit_constraints(data) - except AggregateConstraintsError as e: - errors.append(e) - - return errors - def validate_data(self, data): """Validate the data matches the metadata. diff --git a/sdv/utils.py b/sdv/utils.py index ef2b1defd..9d065034c 100644 --- a/sdv/utils.py +++ b/sdv/utils.py @@ -193,3 +193,27 @@ def create_unique_name(name, list_names): result += '_' return result + + +def format_invalid_values_string(invalid_values): + """Convert ``invalid_values`` into a string of invalid values. + + Args: + invalid_values (pd.DataFrame, set): + Object of values to be converted into string. + + Returns: + str: + A stringified version of the object. + """ + if isinstance(invalid_values, pd.DataFrame): + if len(invalid_values) > 5: + return f'{invalid_values.head(5)}\n+{len(invalid_values) - 5} more' + + if isinstance(invalid_values, set): + invalid_values = sorted(invalid_values, key=lambda x: str(x)) + if len(invalid_values) > 3: + extra_missing_values = [f'+ {len(invalid_values) - 3} more'] + return f'{invalid_values[:3] + extra_missing_values}' + + return f'{invalid_values}'