Skip to content

Commit

Permalink
Add validate_data for single table (#1540)
Browse files Browse the repository at this point in the history
* Working version for single_table validation

* Update get methods

* Rename method

* Move string formatting to utils

* Add parameter to format method
  • Loading branch information
fealho authored Aug 18, 2023
1 parent 473ecee commit cc6fa97
Show file tree
Hide file tree
Showing 13 changed files with 580 additions and 510 deletions.
Empty file added output.pkl
Empty file.
7 changes: 3 additions & 4 deletions sdv/constraints/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from sdv.constraints.errors import (
AggregateConstraintsError, ConstraintMetadataError, MissingConstraintColumnError)
from sdv.errors import ConstraintsNotMetError
from sdv.utils import groupby_list
from sdv.utils import format_invalid_values_string, groupby_list

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -201,12 +201,11 @@ def _validate_data_meets_constraint(self, table_data):
if not is_valid_data.all():
constraint_data = table_data[list(self.constraint_columns)]
invalid_rows = constraint_data[~is_valid_data]
invalid_rows_str = format_invalid_values_string(invalid_rows, 5)
err_msg = (
f"Data is not valid for the '{self.__class__.__name__}' constraint:\n"
f'{invalid_rows[:5]}'
f'{invalid_rows_str}'
)
if len(invalid_rows) > 5:
err_msg += f'\n+{len(invalid_rows) - 5} more'

raise ConstraintsNotMetError(err_msg)

Expand Down
13 changes: 13 additions & 0 deletions sdv/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,16 @@ class SamplingError(Exception):

class NonParametricError(Exception):
    """Exception to indicate that a model is not parametric.

    NOTE(review): no raise site is visible in this chunk; presumably raised
    when a parameter-based operation is attempted on a non-parametric model
    — confirm against the synthesizer implementations.
    """


class InvalidDataError(Exception):
    """Error to raise when data is not valid.

    Args:
        errors (list):
            The individual validation error messages/objects; each is
            rendered with ``str`` and joined into the final message.
    """

    def __init__(self, errors):
        self.errors = errors

    def __str__(self):
        details = '\n\n'.join(str(error) for error in self.errors)
        return 'The provided data does not match the metadata:\n' + details
157 changes: 152 additions & 5 deletions sdv/metadata/single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,18 @@
from copy import deepcopy
from datetime import datetime

import pandas as pd

from sdv.errors import InvalidDataError
from sdv.metadata.anonymization import SDTYPE_ANONYMIZERS, is_faker_function
from sdv.metadata.errors import InvalidMetadataError
from sdv.metadata.metadata_upgrader import convert_metadata
from sdv.metadata.utils import read_json, validate_file_does_not_exist
from sdv.metadata.visualization import (
create_columns_node, create_summarized_columns_node, visualize_graph)
from sdv.utils import cast_to_iterable, load_data_from_csv
from sdv.utils import (
cast_to_iterable, format_invalid_values_string, is_boolean_type, is_datetime_type,
is_numerical_type, load_data_from_csv, validate_datetime_format)

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -152,7 +157,7 @@ def _validate_sdtype(self, sdtype):
'supported SDV sdtypes.'
)

def _validate_column(self, column_name, sdtype, **kwargs):
def _validate_column_args(self, column_name, sdtype, **kwargs):
self._validate_sdtype(sdtype)
self._validate_unexpected_kwargs(column_name, sdtype, **kwargs)
if sdtype == 'categorical':
Expand Down Expand Up @@ -193,7 +198,7 @@ def add_column(self, column_name, **kwargs):
if sdtype is None:
raise InvalidMetadataError(f"Please provide a 'sdtype' for column '{column_name}'.")

self._validate_column(column_name, **kwargs)
self._validate_column_args(column_name, **kwargs)
column_kwargs = deepcopy(kwargs)
if sdtype not in self._SDTYPE_KWARGS:
pii = column_kwargs.get('pii', True)
Expand Down Expand Up @@ -234,7 +239,7 @@ def update_column(self, column_name, **kwargs):
sdtype = self.columns[column_name]['sdtype']
_kwargs['sdtype'] = sdtype

self._validate_column(column_name, sdtype, **kwargs)
self._validate_column_args(column_name, sdtype, **kwargs)
self.columns[column_name] = _kwargs

def to_dict(self):
Expand Down Expand Up @@ -460,6 +465,7 @@ def validate(self):
- ``InvalidMetadataError`` if the metadata is invalid.
"""
errors = []

# Validate keys
self._append_error(errors, self._validate_key, self.primary_key, 'primary')
self._append_error(errors, self._validate_key, self.sequence_key, 'sequence')
Expand All @@ -471,14 +477,155 @@ def validate(self):

# Validate columns
for column, kwargs in self.columns.items():
self._append_error(errors, self._validate_column, column, **kwargs)
self._append_error(errors, self._validate_column_args, column, **kwargs)

if errors:
raise InvalidMetadataError(
'The following errors were found in the metadata:\n\n'
+ '\n'.join([str(e) for e in errors])
)

def _validate_metadata_matches_data(self, columns):
errors = []
metadata_columns = self.columns or {}
missing_data_columns = set(columns).difference(metadata_columns)
if missing_data_columns:
errors.append(
f'The columns {sorted(missing_data_columns)} are not present in the metadata.')

missing_metadata_columns = set(metadata_columns).difference(columns)
if missing_metadata_columns:
errors.append(
f'The metadata columns {sorted(missing_metadata_columns)} '
'are not present in the data.'
)

if errors:
raise InvalidDataError(errors)

def _get_primary_and_alternate_keys(self):
"""Get set of primary and alternate keys.
Returns:
set:
Set of keys.
"""
keys = set(self.alternate_keys)
if self.primary_key:
keys.update({self.primary_key})

return keys

def _get_set_of_sequence_keys(self):
"""Get set with a sequence key.
Returns:
set:
Set of keys.
"""
if isinstance(self.sequence_key, tuple):
return set(self.sequence_key)

if isinstance(self.sequence_key, str):
return {self.sequence_key}

return set()

def _validate_keys_dont_have_missing_values(self, data):
errors = []
keys = self._get_primary_and_alternate_keys()
keys.update(self._get_set_of_sequence_keys())
for key in sorted(keys):
if pd.isna(data[key]).any():
errors.append(f"Key column '{key}' contains missing values.")

return errors

def _validate_key_values_are_unique(self, data):
errors = []
keys = self._get_primary_and_alternate_keys()
for key in sorted(keys):
repeated_values = set(data[key][data[key].duplicated()])
if repeated_values:
repeated_values = format_invalid_values_string(repeated_values, 3)
errors.append(f"Key column '{key}' contains repeating values: " + repeated_values)

return errors

@staticmethod
def _get_invalid_column_values(column, validation_function):
valid = column.apply(validation_function)
return set(column[~valid])

def _validate_column_data(self, column):
    """Validate that the values of ``column`` satisfy its sdtype properties.

    Args:
        column (pd.Series):
            The data column; its name must be a key of ``self.columns``.

    Returns:
        list:
            A single-element list with an error message if invalid values
            were found, otherwise an empty list.
    """
    metadata = self.columns[column.name]
    sdtype = metadata['sdtype']
    bad_values = None

    if sdtype == 'boolean':
        # Only True/False (or missing values) are acceptable; int/str are not.
        bad_values = self._get_invalid_column_values(column, is_boolean_type)
    elif sdtype == 'numerical':
        # Only int/float (or missing values) are acceptable; str/bool are not.
        bad_values = self._get_invalid_column_values(column, is_numerical_type)
    elif sdtype == 'datetime':
        datetime_format = metadata.get('datetime_format')
        if datetime_format:
            bad_values = validate_datetime_format(column, datetime_format)
        else:
            # Without an explicit format, only check castability to datetime
            # on a sample (at most 1000 rows) to keep validation fast.
            sample_size = min(len(column), 1000)
            bad_values = self._get_invalid_column_values(
                column.sample(sample_size),
                lambda value: pd.isna(value) | is_datetime_type(value)
            )

    if bad_values:
        formatted = format_invalid_values_string(bad_values, 3)
        return [f"Invalid values found for {sdtype} column '{column.name}': {formatted}."]

    return []

def validate_data(self, data):
    """Validate that ``data`` matches this metadata.

    The following rules are checked:
        * the data columns match the metadata columns
        * key columns do not contain missing values
        * primary and alternate key values are unique
        * every column's values satisfy their sdtype

    Args:
        data (pd.DataFrame):
            The data to validate.

    Raises:
        ValueError:
            If ``data`` is not a ``pandas.DataFrame``.
        InvalidDataError:
            If any of the rules above is violated.
    """
    if not isinstance(data, pd.DataFrame):
        raise ValueError(f'Data must be a DataFrame, not a {type(data)}.')

    # A column mismatch aborts immediately; the remaining checks assume
    # every metadata column is present in the data.
    self._validate_metadata_matches_data(data.columns)

    errors = []
    errors += self._validate_keys_dont_have_missing_values(data)
    errors += self._validate_key_values_are_unique(data)
    for column_name in data:
        errors += self._validate_column_data(data[column_name])

    if errors:
        raise InvalidDataError(errors)

def visualize(self, show_table_details='full', output_filepath=None):
"""Create a visualization of the single-table dataset.
Expand Down
3 changes: 1 addition & 2 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@
import pkg_resources
from tqdm import tqdm

from sdv.errors import SynthesizerInputError
from sdv.errors import InvalidDataError, SynthesizerInputError
from sdv.single_table.copulas import GaussianCopulaSynthesizer
from sdv.single_table.errors import InvalidDataError


class BaseMultiTableSynthesizer:
Expand Down
Loading

0 comments on commit cc6fa97

Please sign in to comment.