Update test according to sdmetrics release
fealho committed Aug 11, 2023
1 parent f7415a7 commit 42f62da
Showing 2 changed files with 37 additions and 11 deletions.
4 changes: 2 additions & 2 deletions setup.py

@@ -24,8 +24,8 @@
     'copulas>=0.9.0,<0.10',
     'ctgan>=0.7.2,<0.8',
     'deepecho>=0.4.1,<0.5',
-    'rdt>=1.6.1.dev0',
-    'sdmetrics>=0.10.0,<0.11',
+    'rdt>=1.6.1',
+    'sdmetrics>=0.11.0,<0.12',
     'cloudpickle>=2.1.0,<3.0',
     'boto3>=1.15.0,<2',
     'botocore>=1.18,<2'
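
The pins above move rdt from a dev build to its stable 1.6.1 release and allow the sdmetrics 0.11.x line; per the commit title, the sdmetrics bump is what the test update below accounts for. The following is a minimal sketch (not part of the commit) for checking that an installed environment satisfies the new ranges. It assumes the `packaging` library is available, which setuptools-based environments typically have.

    # Minimal sketch: verify installed versions against the updated pins from setup.py.
    # Assumes `packaging` is installed; not part of the commit.
    from importlib.metadata import version

    from packaging.specifiers import SpecifierSet
    from packaging.version import Version

    PINS = {
        'rdt': SpecifierSet('>=1.6.1'),
        'sdmetrics': SpecifierSet('>=0.11.0,<0.12'),
    }

    for package, spec in PINS.items():
        installed = Version(version(package))
        status = 'ok' if installed in spec else 'OUT OF RANGE'
        print(f'{package} {installed}: {status} (expected {spec})')
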
44 changes: 35 additions & 9 deletions tests/integration/evaluation/test_multi_table.py

@@ -3,23 +3,49 @@
 from sdv.evaluation.multi_table import evaluate_quality, run_diagnostic
 from sdv.metadata.multi_table import MultiTableMetadata
-from sdv.multi_table.hma import HMASynthesizer
 
 
 def test_evaluation():
     """Test ``evaluate_quality`` and ``run_diagnostic``."""
     # Setup
-    table = pd.DataFrame({'col': [1, 2, 3]})
-    data = {'table': table}
-    metadata = MultiTableMetadata()
-    metadata.detect_table_from_dataframe('table', table)
-    synthesizer = HMASynthesizer(metadata)
+    table = pd.DataFrame({'id': [0, 1, 2, 3], 'col': [1, 2, 3, 4]})
+    slightly_different_table = pd.DataFrame({'id': [0, 1, 2, 3], 'col': [1, 2, 3, 3.5]})
+    data = {
+        'table1': table,
+        'table2': table,
+    }
+    samples = {
+        'table1': table,
+        'table2': slightly_different_table,
+    }
+    metadata = MultiTableMetadata().load_from_dict({
+        'tables': {
+            'table1': {
+                'columns': {
+                    'id': {'sdtype': 'id'},
+                    'col': {'sdtype': 'numerical'},
+                },
+            },
+            'table2': {
+                'columns': {
+                    'id': {'sdtype': 'id'},
+                    'col': {'sdtype': 'numerical'},
+                },
+            }
+        },
+        'relationships': [
+            {
+                'parent_table_name': 'table1',
+                'parent_primary_key': 'id',
+                'child_table_name': 'table2',
+                'child_foreign_key': 'id'
+            }
+        ]
+    })
 
     # Run and Assert
-    synthesizer.fit(data)
-    samples = synthesizer.sample()
     score = evaluate_quality(data, samples, metadata).get_score()
-    assert score == 0.6666666666666667
+    assert score == .9375
 
     diagnostic = run_diagnostic(data, samples, metadata).get_results()
     assert diagnostic == {
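
For context, the rewritten test no longer fits an HMASynthesizer and samples from it; it compares a hand-built two-table dataset against a slightly perturbed copy, with metadata loaded from a dict. Below is a standalone sketch of the same calls (not part of the commit), assuming the sdv 1.x multi-table evaluation API used in the diff; the exact printed score depends on the installed sdmetrics version.

    # Minimal sketch mirroring the updated test fixture; not part of the commit.
    import pandas as pd

    from sdv.evaluation.multi_table import evaluate_quality, run_diagnostic
    from sdv.metadata.multi_table import MultiTableMetadata

    real_table = pd.DataFrame({'id': [0, 1, 2, 3], 'col': [1, 2, 3, 4]})
    perturbed_table = pd.DataFrame({'id': [0, 1, 2, 3], 'col': [1, 2, 3, 3.5]})

    real_data = {'table1': real_table, 'table2': real_table}
    synthetic_data = {'table1': real_table, 'table2': perturbed_table}

    metadata = MultiTableMetadata.load_from_dict({
        'tables': {
            'table1': {'columns': {'id': {'sdtype': 'id'}, 'col': {'sdtype': 'numerical'}}},
            'table2': {'columns': {'id': {'sdtype': 'id'}, 'col': {'sdtype': 'numerical'}}},
        },
        'relationships': [{
            'parent_table_name': 'table1',
            'parent_primary_key': 'id',
            'child_table_name': 'table2',
            'child_foreign_key': 'id',
        }],
    })

    quality_report = evaluate_quality(real_data, synthetic_data, metadata)
    print(quality_report.get_score())  # the test pins this to 0.9375

    diagnostic_report = run_diagnostic(real_data, synthetic_data, metadata)
    print(diagnostic_report.get_results())

The hard-coded 0.9375 reflects the scoring behaviour of the sdmetrics range now pinned in setup.py, which is presumably why the assertion is updated together with the dependency.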
