Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update detect_discrete_columns logic #1546

Merged
merged 3 commits into from
Aug 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 29 additions & 14 deletions sdv/single_table/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
def detect_discrete_columns(metadata, data):
"""Detect the discrete columns in a dataset.

Because the metadata doesn't necessarily match the data (we only preprocess the data,
while the metadata stays static), this method tries to infer whether the data is
discrete.

Args:
metadata (sdv.metadata.SingleTableMetadata):
Metadata that belongs to the given ``data``.
Expand All @@ -27,24 +31,35 @@ def detect_discrete_columns(metadata, data):
A list of discrete columns to be used with some of ``sdv`` synthesizers.
"""
discrete_columns = []

for column in data.columns:
if column in metadata.columns:
if metadata.columns[column]['sdtype'] not in ['numerical', 'datetime']:
discrete_columns.append(column)
# Numerical and datetime columns never get preprocessed into categorical ones
if column in metadata.columns and \
metadata.columns[column]['sdtype'] in ['numerical', 'datetime']:
continue

else:
column_data = data[column].dropna()
if set(column_data.unique()) == {0.0, 1.0}:
column_data = column_data.astype(bool)
column_data = data[column].dropna()

try:
dtype = column_data.infer_objects().dtype.kind
if dtype in ['O', 'b']:
discrete_columns.append(column)
# Ignore columns with only nans and empty datasets
if column_data.empty:
continue

except Exception:
discrete_columns.append(column)
# Non-integer floats and integers with too many unique values are not categorical
try:
column_data = column_data.astype('float')
is_int = column_data.equals(column_data.round())
is_float = not is_int
num_values = len(column_data)
num_categories = column_data.nunique()
threshold = max(10, num_values * .1)
fealho marked this conversation as resolved.
Show resolved Hide resolved
has_many_categories = num_categories > threshold
if is_float or (is_int and has_many_categories):
continue

except (ValueError, TypeError):
pass

# Everything else is presumed categorical
discrete_columns.append(column)
fealho marked this conversation as resolved.
Show resolved Hide resolved

return discrete_columns

Expand Down
21 changes: 20 additions & 1 deletion tests/unit/single_table/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_detect_discrete_columns():
'join_date': ['2021-02-02', '2022-03-04', '2015-05-06', '2018-09-30'],
'uses_synthetic': [np.nan, True, False, False],
'surname': [object(), object(), object(), object()],
'bool': [0., 0., 1., np.nan]
'bool': [0., 0., 1., np.nan],
})

# Run
Expand All @@ -46,6 +46,25 @@ def test_detect_discrete_columns():
assert result == ['name', 'subscribed', 'uses_synthetic', 'surname', 'bool']


def test_detect_discrete_columns_numerical():
    """Check ``detect_discrete_columns`` against a table of numeric-only columns.

    The metadata object is freshly constructed and nothing is added to it, and
    every column holds 1000 rows. The test expects only ``cat_int`` (100
    distinct integers) and ``float_int`` (the value 1 mixed with NaN) to be
    reported as discrete; ``float`` (non-integer values), ``nan`` (nothing but
    NaN) and ``num_int`` (125 distinct integers) are expected to be left out.
    """
    # Setup
    table_metadata = SingleTableMetadata()
    # Column order matters: the expected list below follows it.
    columns = {
        'float': [0.1] * 1000,
        'nan': [np.nan] * 1000,
        'cat_int': 10 * list(range(100)),
        'num_int': 8 * list(range(125)),
        'float_int': 500 * [1, np.nan],
    }
    table = pd.DataFrame(columns)

    # Run
    detected = detect_discrete_columns(table_metadata, table)

    # Assert
    expected = ['cat_int', 'float_int']
    assert detected == expected


def test_flatten_array_default():
"""Test get flatten array."""
# Run
Expand Down
Loading