diff --git a/tests/unit/dpulearn_tests/test_dpulearn_fit.py b/tests/unit/dpulearn_tests/test_dpulearn_fit.py index ce713aa1..b06e37bb 100644 --- a/tests/unit/dpulearn_tests/test_dpulearn_fit.py +++ b/tests/unit/dpulearn_tests/test_dpulearn_fit.py @@ -23,12 +23,14 @@ def create_labels(size): def check_invalid_conditions(X, labels, min_samples=3, check_unique=True): n_samples, n_features = X.shape + n_unique_labels = len(set(labels)) n_unique_samples = len(set(map(tuple, X))) conditions = [ (np.any(np.isinf(X)) or np.any(np.isnan(X)), "X contains NaN or Inf"), (n_samples < min_samples, f"n_samples={n_samples} should be >= {min_samples}"), (n_features < 2, f"n_features={n_features} should be >= 2"), (len(labels) != n_samples, "Length of labels should match n_samples."), + (n_unique_labels < 2, f"n_unique_labels={n_unique_samples} should be >= 2") ] if check_unique: conditions.append((n_unique_samples == 1, "Feature matrix 'X' should not have all identical samples."))