Skip to content

Commit

Permalink
Move string formatting to utils
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Aug 17, 2023
1 parent 1123861 commit 09fcc6f
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 28 deletions.
7 changes: 3 additions & 4 deletions sdv/constraints/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from sdv.constraints.errors import (
AggregateConstraintsError, ConstraintMetadataError, MissingConstraintColumnError)
from sdv.errors import ConstraintsNotMetError
from sdv.utils import groupby_list
from sdv.utils import format_invalid_values_string, groupby_list

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -201,12 +201,11 @@ def _validate_data_meets_constraint(self, table_data):
if not is_valid_data.all():
constraint_data = table_data[list(self.constraint_columns)]
invalid_rows = constraint_data[~is_valid_data]
invalid_rows_str = format_invalid_values_string(invalid_rows)
err_msg = (
f"Data is not valid for the '{self.__class__.__name__}' constraint:\n"
f'{invalid_rows[:5]}'
f'{invalid_rows_str}'
)
if len(invalid_rows) > 5:
err_msg += f'\n+{len(invalid_rows) - 5} more'

raise ConstraintsNotMetError(err_msg)

Expand Down
29 changes: 5 additions & 24 deletions sdv/metadata/single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import pandas as pd

from sdv.constraints.errors import AggregateConstraintsError
from sdv.errors import InvalidDataError
from sdv.metadata.anonymization import SDTYPE_ANONYMIZERS, is_faker_function
from sdv.metadata.errors import InvalidMetadataError
Expand All @@ -18,8 +17,8 @@
from sdv.metadata.visualization import (
create_columns_node, create_summarized_columns_node, visualize_graph)
from sdv.utils import (
cast_to_iterable, is_boolean_type, is_datetime_type, is_numerical_type, load_data_from_csv,
validate_datetime_format)
cast_to_iterable, format_invalid_values_string, is_boolean_type, is_datetime_type,
is_numerical_type, load_data_from_csv, validate_datetime_format)

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -542,22 +541,14 @@ def _validate_keys_dont_have_missing_values(self, data):

return errors

@staticmethod
def _format_invalid_values_string(invalid_values):
invalid_values = sorted(invalid_values, key=lambda x: str(x))
if len(invalid_values) > 3:
return invalid_values[:3] + [f'+ {len(invalid_values) - 3} more']

return invalid_values

def _validate_key_values_are_unique(self, data):
errors = []
keys = self._get_primary_and_alternate_keys()
for key in sorted(keys):
repeated_values = set(data[key][data[key].duplicated()])
if repeated_values:
repeated_values = self._format_invalid_values_string(repeated_values)
errors.append(f"Key column '{key}' contains repeating values: {repeated_values}")
repeated_values = format_invalid_values_string(repeated_values)
errors.append(f"Key column '{key}' contains repeating values: " + repeated_values)

return errors

Expand Down Expand Up @@ -597,21 +588,11 @@ def _validate_column_data(self, column):
)

if invalid_values:
invalid_values = self._format_invalid_values_string(invalid_values)
invalid_values = format_invalid_values_string(invalid_values)
return [f"Invalid values found for {sdtype} column '{column.name}': {invalid_values}."]

return []

def _validate_constraints(self, data):
"""Validate that the data satisfies the constraints."""
errors = []
try:
self._data_processor._fit_constraints(data)
except AggregateConstraintsError as e:
errors.append(e)

return errors

def validate_data(self, data):
"""Validate the data matches the metadata.
Expand Down
24 changes: 24 additions & 0 deletions sdv/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,27 @@ def create_unique_name(name, list_names):
result += '_'

return result


def format_invalid_values_string(invalid_values):
"""Convert ``invalid_values`` into a string of invalid values.
Args:
invalid_values (pd.DataFrame, set):
Object of values to be converted into string.
Returns:
str:
A stringified version of the object.
"""
if isinstance(invalid_values, pd.DataFrame):
if len(invalid_values) > 5:
return f'{invalid_values.head(5)}\n+{len(invalid_values) - 5} more'

if isinstance(invalid_values, set):
invalid_values = sorted(invalid_values, key=lambda x: str(x))
if len(invalid_values) > 3:
extra_missing_values = [f'+ {len(invalid_values) - 3} more']
return f'{invalid_values[:3] + extra_missing_values}'

return f'{invalid_values}'

0 comments on commit 09fcc6f

Please sign in to comment.