From 42f62dade616f4f7a1b77d3ca18a6900a6e90a20 Mon Sep 17 00:00:00 2001 From: Felipe Date: Fri, 11 Aug 2023 12:19:04 -0700 Subject: [PATCH] Update test according to sdmetrics release --- setup.py | 4 +- .../evaluation/test_multi_table.py | 44 +++++++++++++++---- 2 files changed, 37 insertions(+), 11 deletions(-) diff --git a/setup.py b/setup.py index a89b20717..fad4b479d 100644 --- a/setup.py +++ b/setup.py @@ -24,8 +24,8 @@ 'copulas>=0.9.0,<0.10', 'ctgan>=0.7.2,<0.8', 'deepecho>=0.4.1,<0.5', - 'rdt>=1.6.1.dev0', - 'sdmetrics>=0.10.0,<0.11', + 'rdt>=1.6.1', + 'sdmetrics>=0.11.0,<0.12', 'cloudpickle>=2.1.0,<3.0', 'boto3>=1.15.0,<2', 'botocore>=1.18,<2' diff --git a/tests/integration/evaluation/test_multi_table.py b/tests/integration/evaluation/test_multi_table.py index b2ef4d712..80087a768 100644 --- a/tests/integration/evaluation/test_multi_table.py +++ b/tests/integration/evaluation/test_multi_table.py @@ -3,23 +3,49 @@ from sdv.evaluation.multi_table import evaluate_quality, run_diagnostic from sdv.metadata.multi_table import MultiTableMetadata -from sdv.multi_table.hma import HMASynthesizer def test_evaluation(): """Test ``evaluate_quality`` and ``run_diagnostic``.""" # Setup - table = pd.DataFrame({'col': [1, 2, 3]}) - data = {'table': table} - metadata = MultiTableMetadata() - metadata.detect_table_from_dataframe('table', table) - synthesizer = HMASynthesizer(metadata) + table = pd.DataFrame({'id': [0, 1, 2, 3], 'col': [1, 2, 3, 4]}) + slightly_different_table = pd.DataFrame({'id': [0, 1, 2, 3], 'col': [1, 2, 3, 3.5]}) + data = { + 'table1': table, + 'table2': table, + } + samples = { + 'table1': table, + 'table2': slightly_different_table, + } + metadata = MultiTableMetadata().load_from_dict({ + 'tables': { + 'table1': { + 'columns': { + 'id': {'sdtype': 'id'}, + 'col': {'sdtype': 'numerical'}, + }, + }, + 'table2': { + 'columns': { + 'id': {'sdtype': 'id'}, + 'col': {'sdtype': 'numerical'}, + }, + } + }, + 'relationships': [ + { + 'parent_table_name': 'table1', + 'parent_primary_key': 'id', + 'child_table_name': 'table2', + 'child_foreign_key': 'id' + } + ] + }) # Run and Assert - synthesizer.fit(data) - samples = synthesizer.sample() score = evaluate_quality(data, samples, metadata).get_score() - assert score == 0.6666666666666667 + assert score == .9375 diagnostic = run_diagnostic(data, samples, metadata).get_results() assert diagnostic == {