Skip to content

Commit

Permalink
Update discrete_detect_columns logic (#1546)
Browse files Browse the repository at this point in the history
* Update discrete_detector

* Add threshold for unique ints

* Clean up detect_discrete logic
  • Loading branch information
fealho authored Aug 18, 2023
1 parent cc6fa97 commit 18f2c12
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 15 deletions.
43 changes: 29 additions & 14 deletions sdv/single_table/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
def detect_discrete_columns(metadata, data):
"""Detect the discrete columns in a dataset.
Because the metadata doesn't necessarily match the data (we only preprocess the data,
while the metadata stays static), this method tries to infer whether the data is
discrete.
Args:
metadata (sdv.metadata.SingleTableMetadata):
Metadata that belongs to the given ``data``.
Expand All @@ -27,24 +31,35 @@ def detect_discrete_columns(metadata, data):
A list of discrete columns to be used with some of ``sdv`` synthesizers.
"""
discrete_columns = []

for column in data.columns:
if column in metadata.columns:
if metadata.columns[column]['sdtype'] not in ['numerical', 'datetime']:
discrete_columns.append(column)
# Numerical and datetime columns never get preprocessed into categorical ones
if column in metadata.columns and \
metadata.columns[column]['sdtype'] in ['numerical', 'datetime']:
continue

else:
column_data = data[column].dropna()
if set(column_data.unique()) == {0.0, 1.0}:
column_data = column_data.astype(bool)
column_data = data[column].dropna()

try:
dtype = column_data.infer_objects().dtype.kind
if dtype in ['O', 'b']:
discrete_columns.append(column)
# Ignore columns with only nans and empty datasets
if column_data.empty:
continue

except Exception:
discrete_columns.append(column)
# Non-integer floats and integers with too many unique values are not categorical
try:
column_data = column_data.astype('float')
is_int = column_data.equals(column_data.round())
is_float = not is_int
num_values = len(column_data)
num_categories = column_data.nunique()
threshold = max(10, num_values * .1)
has_many_categories = num_categories > threshold
if is_float or (is_int and has_many_categories):
continue

except (ValueError, TypeError):
pass

# Everything else is presumed categorical
discrete_columns.append(column)

return discrete_columns

Expand Down
21 changes: 20 additions & 1 deletion tests/unit/single_table/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_detect_discrete_columns():
'join_date': ['2021-02-02', '2022-03-04', '2015-05-06', '2018-09-30'],
'uses_synthetic': [np.nan, True, False, False],
'surname': [object(), object(), object(), object()],
'bool': [0., 0., 1., np.nan]
'bool': [0., 0., 1., np.nan],
})

# Run
Expand All @@ -46,6 +46,25 @@ def test_detect_discrete_columns():
assert result == ['name', 'subscribed', 'uses_synthetic', 'surname', 'bool']


def test_detect_discrete_columns_numerical():
"""Test it for numerical columns."""
# Setup
metadata = SingleTableMetadata()
data = pd.DataFrame({
'float': [.1] * 1000,
'nan': [np.nan] * 1000,
'cat_int': list(range(100)) * 10,
'num_int': list(range(125)) * 8,
'float_int': [1, np.nan] * 500,
})

# Run
result = detect_discrete_columns(metadata, data)

# Assert
assert result == ['cat_int', 'float_int']


def test_flatten_array_default():
"""Test get flatten array."""
# Run
Expand Down

0 comments on commit 18f2c12

Please sign in to comment.