From dc26f5ec02e0d94fd6fa3f2b9cba88df3c626a21 Mon Sep 17 00:00:00 2001 From: Felipe Date: Tue, 15 Aug 2023 14:57:57 -0700 Subject: [PATCH 1/3] Update discrete_detector --- sdv/single_table/utils.py | 36 ++++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/sdv/single_table/utils.py b/sdv/single_table/utils.py index 69b36a720..373fc6fef 100644 --- a/sdv/single_table/utils.py +++ b/sdv/single_table/utils.py @@ -15,6 +15,10 @@ def detect_discrete_columns(metadata, data): """Detect the discrete columns in a dataset. + Because the metadata doesn't necessarily match the data (we only preprocess the data, + while the metadata stays static), this method tries to infer whether the data is + discrete. + Args: metadata (sdv.metadata.SingleTableMetadata): Metadata that belongs to the given ``data``. @@ -27,24 +31,28 @@ def detect_discrete_columns(metadata, data): A list of discrete columns to be used with some of ``sdv`` synthesizers. """ discrete_columns = [] - for column in data.columns: - if column in metadata.columns: - if metadata.columns[column]['sdtype'] not in ['numerical', 'datetime']: - discrete_columns.append(column) + # Numerical and datetime columns never get preprocessed into categorical ones + if column in metadata.columns and \ + metadata.columns[column]['sdtype'] in ['numerical', 'datetime']: + continue - else: - column_data = data[column].dropna() - if set(column_data.unique()) == {0.0, 1.0}: - column_data = column_data.astype(bool) + column_data = data[column].dropna() + + # Ignore columns with only nans and empty datasets + if column_data.empty: + continue - try: - dtype = column_data.infer_objects().dtype.kind - if dtype in ['O', 'b']: - discrete_columns.append(column) + # Non-integer floats cannot be categorical + try: + column_data = column_data.astype('float') + if list(column_data) != list(column_data.round()): + continue + except BaseException: + pass - except Exception: - discrete_columns.append(column) + # Everything else is presumed categorical + discrete_columns.append(column) return discrete_columns From 6680d2590fa2d81ee47032fd14f395c17803e4a3 Mon Sep 17 00:00:00 2001 From: Felipe Date: Thu, 17 Aug 2023 08:55:58 -0700 Subject: [PATCH 2/3] Add threshold for unique ints --- sdv/single_table/utils.py | 15 ++++++++++++--- tests/unit/single_table/test_utils.py | 21 ++++++++++++++++++++- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/sdv/single_table/utils.py b/sdv/single_table/utils.py index 373fc6fef..4259efa52 100644 --- a/sdv/single_table/utils.py +++ b/sdv/single_table/utils.py @@ -43,12 +43,21 @@ def detect_discrete_columns(metadata, data): if column_data.empty: continue - # Non-integer floats cannot be categorical + # Non-integer floats and integers with too many unique values cannot be categorical try: column_data = column_data.astype('float') - if list(column_data) != list(column_data.round()): + rounded_column_data = list(column_data.round()) + column_data = list(column_data) + is_float = column_data != rounded_column_data + is_int = column_data == rounded_column_data + num_values = len(column_data) + num_categories = len(set(column_data)) + threshold = max(10, num_values * .1) + has_many_categories = num_categories > threshold + if is_float or (is_int and has_many_categories): continue - except BaseException: + + except (ValueError, TypeError): pass # Everything else is presumed categorical diff --git a/tests/unit/single_table/test_utils.py b/tests/unit/single_table/test_utils.py index 619d907f8..94dd77a8a 100644 --- a/tests/unit/single_table/test_utils.py +++ b/tests/unit/single_table/test_utils.py @@ -36,7 +36,7 @@ def test_detect_discrete_columns(): 'join_date': ['2021-02-02', '2022-03-04', '2015-05-06', '2018-09-30'], 'uses_synthetic': [np.nan, True, False, False], 'surname': [object(), object(), object(), object()], - 'bool': [0., 0., 1., np.nan] + 'bool': [0., 0., 1., np.nan], }) # Run @@ -46,6 +46,25 @@ def test_detect_discrete_columns(): assert result == ['name', 'subscribed', 'uses_synthetic', 'surname', 'bool'] +def test_detect_discrete_columns_numerical(): + """Test it for numerical columns.""" + # Setup + metadata = SingleTableMetadata() + data = pd.DataFrame({ + 'float': [.1] * 1000, + 'nan': [np.nan] * 1000, + 'cat_int': list(range(100)) * 10, + 'num_int': list(range(125)) * 8, + 'float_int': [1, np.nan] * 500, + }) + + # Run + result = detect_discrete_columns(metadata, data) + + # Assert + assert result == ['cat_int', 'float_int'] + + def test_flatten_array_default(): """Test get flatten array.""" # Run From 4f28458c45cf6fb03728ce37d6098fe7591c70f4 Mon Sep 17 00:00:00 2001 From: Felipe Date: Thu, 17 Aug 2023 11:48:32 -0700 Subject: [PATCH 3/3] Clean up detect_discrete logic --- sdv/single_table/utils.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/sdv/single_table/utils.py b/sdv/single_table/utils.py index 4259efa52..45d95db2f 100644 --- a/sdv/single_table/utils.py +++ b/sdv/single_table/utils.py @@ -43,15 +43,13 @@ def detect_discrete_columns(metadata, data): if column_data.empty: continue - # Non-integer floats and integers with too many unique values cannot be categorical + # Non-integer floats and integers with too many unique values are not categorical try: column_data = column_data.astype('float') - rounded_column_data = list(column_data.round()) - column_data = list(column_data) - is_float = column_data != rounded_column_data - is_int = column_data == rounded_column_data + is_int = column_data.equals(column_data.round()) + is_float = not is_int num_values = len(column_data) - num_categories = len(set(column_data)) + num_categories = column_data.nunique() threshold = max(10, num_values * .1) has_many_categories = num_categories > threshold if is_float or (is_int and has_many_categories):