Skip to content

Commit

Permalink
Add threshold for unique ints
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Aug 17, 2023
1 parent dc26f5e commit 6680d25
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 4 deletions.
15 changes: 12 additions & 3 deletions sdv/single_table/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,21 @@ def detect_discrete_columns(metadata, data):
if column_data.empty:
continue

# Non-integer floats cannot be categorical
# Non-integer floats and integers with too many unique values cannot be categorical
try:
column_data = column_data.astype('float')
if list(column_data) != list(column_data.round()):
rounded_column_data = list(column_data.round())
column_data = list(column_data)
is_float = column_data != rounded_column_data
is_int = column_data == rounded_column_data
num_values = len(column_data)
num_categories = len(set(column_data))
threshold = max(10, num_values * .1)
has_many_categories = num_categories > threshold
if is_float or (is_int and has_many_categories):
continue
except BaseException:

except (ValueError, TypeError):
pass

# Everything else is presumed categorical
Expand Down
21 changes: 20 additions & 1 deletion tests/unit/single_table/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_detect_discrete_columns():
'join_date': ['2021-02-02', '2022-03-04', '2015-05-06', '2018-09-30'],
'uses_synthetic': [np.nan, True, False, False],
'surname': [object(), object(), object(), object()],
'bool': [0., 0., 1., np.nan]
'bool': [0., 0., 1., np.nan],
})

# Run
Expand All @@ -46,6 +46,25 @@ def test_detect_discrete_columns():
assert result == ['name', 'subscribed', 'uses_synthetic', 'surname', 'bool']


def test_detect_discrete_columns_numerical():
"""Test it for numerical columns."""
# Setup
metadata = SingleTableMetadata()
data = pd.DataFrame({
'float': [.1] * 1000,
'nan': [np.nan] * 1000,
'cat_int': list(range(100)) * 10,
'num_int': list(range(125)) * 8,
'float_int': [1, np.nan] * 500,
})

# Run
result = detect_discrete_columns(metadata, data)

# Assert
assert result == ['cat_int', 'float_int']


def test_flatten_array_default():
"""Test get flatten array."""
# Run
Expand Down

0 comments on commit 6680d25

Please sign in to comment.