From 18f2c12c6e306fca4fa5488fcd2169ec5fdceddc Mon Sep 17 00:00:00 2001 From: Felipe Alex Hofmann Date: Fri, 18 Aug 2023 09:04:06 -0700 Subject: [PATCH] Update `discrete_detect_columns` logic (#1546) * Update discrete_detector * Add threshold for unique ints * Clean up detect_discrete logic --- sdv/single_table/utils.py | 43 ++++++++++++++++++--------- tests/unit/single_table/test_utils.py | 21 ++++++++++++- 2 files changed, 49 insertions(+), 15 deletions(-) diff --git a/sdv/single_table/utils.py b/sdv/single_table/utils.py index 69b36a720..45d95db2f 100644 --- a/sdv/single_table/utils.py +++ b/sdv/single_table/utils.py @@ -15,6 +15,10 @@ def detect_discrete_columns(metadata, data): """Detect the discrete columns in a dataset. + Because the metadata doesn't necessarily match the data (we only preprocess the data, + while the metadata stays static), this method tries to infer whether the data is + discrete. + Args: metadata (sdv.metadata.SingleTableMetadata): Metadata that belongs to the given ``data``. @@ -27,24 +31,35 @@ def detect_discrete_columns(metadata, data): A list of discrete columns to be used with some of ``sdv`` synthesizers. """ discrete_columns = [] - for column in data.columns: - if column in metadata.columns: - if metadata.columns[column]['sdtype'] not in ['numerical', 'datetime']: - discrete_columns.append(column) + # Numerical and datetime columns never get preprocessed into categorical ones + if column in metadata.columns and \ + metadata.columns[column]['sdtype'] in ['numerical', 'datetime']: + continue - else: - column_data = data[column].dropna() - if set(column_data.unique()) == {0.0, 1.0}: - column_data = column_data.astype(bool) + column_data = data[column].dropna() - try: - dtype = column_data.infer_objects().dtype.kind - if dtype in ['O', 'b']: - discrete_columns.append(column) + # Ignore columns with only nans and empty datasets + if column_data.empty: + continue - except Exception: - discrete_columns.append(column) + # Non-integer floats and integers with too many unique values are not categorical + try: + column_data = column_data.astype('float') + is_int = column_data.equals(column_data.round()) + is_float = not is_int + num_values = len(column_data) + num_categories = column_data.nunique() + threshold = max(10, num_values * .1) + has_many_categories = num_categories > threshold + if is_float or (is_int and has_many_categories): + continue + + except (ValueError, TypeError): + pass + + # Everything else is presumed categorical + discrete_columns.append(column) return discrete_columns diff --git a/tests/unit/single_table/test_utils.py b/tests/unit/single_table/test_utils.py index 619d907f8..94dd77a8a 100644 --- a/tests/unit/single_table/test_utils.py +++ b/tests/unit/single_table/test_utils.py @@ -36,7 +36,7 @@ def test_detect_discrete_columns(): 'join_date': ['2021-02-02', '2022-03-04', '2015-05-06', '2018-09-30'], 'uses_synthetic': [np.nan, True, False, False], 'surname': [object(), object(), object(), object()], - 'bool': [0., 0., 1., np.nan] + 'bool': [0., 0., 1., np.nan], }) # Run @@ -46,6 +46,25 @@ def test_detect_discrete_columns(): assert result == ['name', 'subscribed', 'uses_synthetic', 'surname', 'bool'] +def test_detect_discrete_columns_numerical(): + """Test it for numerical columns.""" + # Setup + metadata = SingleTableMetadata() + data = pd.DataFrame({ + 'float': [.1] * 1000, + 'nan': [np.nan] * 1000, + 'cat_int': list(range(100)) * 10, + 'num_int': list(range(125)) * 8, + 'float_int': [1, np.nan] * 500, + }) + + # Run + result = detect_discrete_columns(metadata, data) + + # Assert + assert result == ['cat_int', 'float_int'] + + def test_flatten_array_default(): """Test get flatten array.""" # Run