Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add detect_from_csvs and detect_from_dataframes methods to MultiTableMetadata #1533

Merged
merged 6 commits into from
Aug 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions sdv/metadata/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@
import warnings
from collections import defaultdict
from copy import deepcopy
from pathlib import Path

import pandas as pd

from sdv.metadata.errors import InvalidMetadataError
from sdv.metadata.metadata_upgrader import convert_metadata
from sdv.metadata.single_table import SingleTableMetadata
from sdv.metadata.utils import read_json, validate_file_does_not_exist
from sdv.metadata.visualization import (
create_columns_node, create_summarized_columns_node, visualize_graph)
from sdv.utils import cast_to_iterable
from sdv.utils import cast_to_iterable, load_data_from_csv

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -344,6 +347,19 @@ def detect_table_from_dataframe(self, table_name, data):
self.tables[table_name] = table
self._log_detected_table(table)

def detect_from_dataframes(self, data):
    """Detect the metadata for all tables in a dictionary of dataframes.

    The container and the type of every value are validated before any table is
    detected.

    Args:
        data (dict):
            Dictionary of table names to dataframes.

    Raises:
        ValueError:
            If ``data`` is not a non-empty dictionary whose values are all
            ``pandas.DataFrame`` objects.
    """
    from collections.abc import Mapping

    # Check the container type first: evaluating the truthiness of a DataFrame that was
    # passed by mistake raises pandas' unrelated "truth value is ambiguous" error, and
    # other non-mapping inputs (e.g. a list) would fail with an ``AttributeError``.
    is_valid = (
        isinstance(data, Mapping)
        and bool(data)
        and all(isinstance(df, pd.DataFrame) for df in data.values())
    )
    if not is_valid:
        raise ValueError('The provided dictionary must contain only pandas DataFrame objects.')

    for table_name, dataframe in data.items():
        self.detect_table_from_dataframe(table_name, dataframe)

def detect_table_from_csv(self, table_name, filepath):
"""Detect the metadata for a table from a csv file.

Expand All @@ -355,11 +371,33 @@ def detect_table_from_csv(self, table_name, filepath):
"""
self._validate_table_not_detected(table_name)
table = SingleTableMetadata()
data = table._load_data_from_csv(filepath)
data = load_data_from_csv(filepath)
table._detect_columns(data)
self.tables[table_name] = table
self._log_detected_table(table)

def detect_from_csvs(self, folder_name):
    """Detect the metadata for all tables in a folder of csv files.

    The folder is searched recursively and every csv file found is detected as a
    table named after the file, without its extension.

    Args:
        folder_name (str):
            Name of the folder to detect the metadata from.

    Raises:
        ValueError:
            If the folder does not exist or if it contains no csv files.
    """
    folder_path = Path(folder_name)
    if not folder_path.is_dir():
        raise ValueError(f"The folder '{folder_name}' does not exist.")

    # ``rglob`` also matches directories whose name ends in ``.csv``, so keep regular
    # files only. Sorting makes the detection order independent of the filesystem.
    csv_files = sorted(path for path in folder_path.rglob('*.csv') if path.is_file())
    if not csv_files:
        raise ValueError(f"No CSV files detected in the folder '{folder_name}'.")

    for csv_file in csv_files:
        self.detect_table_from_csv(csv_file.stem, str(csv_file))

def set_primary_key(self, table_name, column_name):
"""Set the primary key of a table.

Expand Down
136 changes: 136 additions & 0 deletions tests/integration/metadata/test_multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import json

from sdv.datasets.demo import download_demo
from sdv.metadata import MultiTableMetadata


Expand Down Expand Up @@ -132,3 +133,138 @@ def test_upgrade_metadata(tmp_path):
assert new_metadata['tables'] == expected_metadata['tables']
for relationship in new_metadata['relationships']:
assert relationship in expected_metadata['relationships']


def test_detect_from_dataframes():
"""Test the ``detect_from_dataframes`` method."""
# Setup
real_data, _ = download_demo(
modality='multi_table',
dataset_name='fake_hotels'
)

metadata = MultiTableMetadata()

# Run
metadata.detect_from_dataframes(real_data)

# Assert
expected_metadata = {
'tables': {
'hotels': {
'columns': {
'hotel_id': {'sdtype': 'categorical'},
'city': {'sdtype': 'categorical'},
'state': {'sdtype': 'categorical'},
'rating': {'sdtype': 'numerical'},
'classification': {'sdtype': 'categorical'}
}
},
'guests': {
'columns': {
'guest_email': {'sdtype': 'categorical'},
'hotel_id': {'sdtype': 'categorical'},
'has_rewards': {'sdtype': 'boolean'},
'room_type': {'sdtype': 'categorical'},
'amenities_fee': {'sdtype': 'numerical'},
'checkin_date': {'sdtype': 'categorical'},
'checkout_date': {'sdtype': 'categorical'},
'room_rate': {'sdtype': 'numerical'},
'billing_address': {'sdtype': 'categorical'},
'credit_card_number': {'sdtype': 'numerical'}
}
}
},
'relationships': [],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will add the logic to detect the relationships in another issue, correct?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1'
}

assert metadata.to_dict() == expected_metadata


def test_detect_from_csvs(tmp_path):
    """Test that ``detect_from_csvs`` detects every csv file written to a folder."""
    # Setup
    tables, _ = download_demo(modality='multi_table', dataset_name='fake_hotels')
    for name, table in tables.items():
        table.to_csv(tmp_path / f'{name}.csv', index=False)

    metadata = MultiTableMetadata()

    # Run
    metadata.detect_from_csvs(folder_name=tmp_path)

    # Assert
    hotels_columns = {
        'hotel_id': {'sdtype': 'categorical'},
        'city': {'sdtype': 'categorical'},
        'state': {'sdtype': 'categorical'},
        'rating': {'sdtype': 'numerical'},
        'classification': {'sdtype': 'categorical'},
    }
    guests_columns = {
        'guest_email': {'sdtype': 'categorical'},
        'hotel_id': {'sdtype': 'categorical'},
        'has_rewards': {'sdtype': 'boolean'},
        'room_type': {'sdtype': 'categorical'},
        'amenities_fee': {'sdtype': 'numerical'},
        'checkin_date': {'sdtype': 'categorical'},
        'checkout_date': {'sdtype': 'categorical'},
        'room_rate': {'sdtype': 'numerical'},
        'billing_address': {'sdtype': 'categorical'},
        'credit_card_number': {'sdtype': 'numerical'},
    }
    expected_metadata = {
        'tables': {
            'hotels': {'columns': hotels_columns},
            'guests': {'columns': guests_columns},
        },
        'relationships': [],
        'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1',
    }

    assert metadata.to_dict() == expected_metadata


def test_detect_table_from_csv(tmp_path):
    """Test that ``detect_table_from_csv`` detects only the requested csv file."""
    # Setup
    tables, _ = download_demo(modality='multi_table', dataset_name='fake_hotels')
    for name, table in tables.items():
        table.to_csv(tmp_path / f'{name}.csv', index=False)

    metadata = MultiTableMetadata()
    hotels_path = tmp_path / 'hotels.csv'

    # Run
    metadata.detect_table_from_csv('hotels', hotels_path)

    # Assert
    hotels_columns = {
        'hotel_id': {'sdtype': 'categorical'},
        'city': {'sdtype': 'categorical'},
        'state': {'sdtype': 'categorical'},
        'rating': {'sdtype': 'numerical'},
        'classification': {'sdtype': 'categorical'},
    }
    expected_metadata = {
        'tables': {'hotels': {'columns': hotels_columns}},
        'relationships': [],
        'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1',
    }

    assert metadata.to_dict() == expected_metadata
102 changes: 99 additions & 3 deletions tests/unit/metadata/test_multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1589,7 +1589,8 @@ def test_update_column_table_does_not_exist(self):

@patch('sdv.metadata.multi_table.LOGGER')
@patch('sdv.metadata.multi_table.SingleTableMetadata')
def test_detect_table_from_csv(self, single_table_mock, log_mock):
@patch('sdv.metadata.multi_table.load_data_from_csv')
def test_detect_table_from_csv(self, load_csv_mock, single_table_mock, log_mock):
"""Test the ``detect_table_from_csv`` method.

If the table does not already exist, a ``SingleTableMetadata`` instance
Expand All @@ -1604,7 +1605,7 @@ def test_detect_table_from_csv(self, single_table_mock, log_mock):
# Setup
metadata = MultiTableMetadata()
fake_data = Mock()
single_table_mock.return_value._load_data_from_csv.return_value = fake_data
load_csv_mock.return_value = fake_data
single_table_mock.return_value.to_dict.return_value = {
'columns': {'a': {'sdtype': 'numerical'}}
}
Expand All @@ -1613,7 +1614,7 @@ def test_detect_table_from_csv(self, single_table_mock, log_mock):
metadata.detect_table_from_csv('table', 'path.csv')

# Assert
single_table_mock.return_value._load_data_from_csv.assert_called_once_with('path.csv')
load_csv_mock.assert_called_once_with('path.csv')
single_table_mock.return_value._detect_columns.assert_called_once_with(fake_data)
assert metadata.tables == {'table': single_table_mock.return_value}

Expand Down Expand Up @@ -1656,6 +1657,59 @@ def test_detect_table_from_csv_table_already_exists(self):
with pytest.raises(InvalidMetadataError, match=error_message):
metadata.detect_table_from_csv('table', 'path.csv')

def test_detect_from_csvs(self, tmp_path):
    """Test the ``detect_from_csvs`` method.

    ``detect_table_from_csv`` is expected to be called once per csv file found in the
    folder, while files with any other extension are ignored.
    """
    # Setup
    instance = MultiTableMetadata()
    instance.detect_table_from_csv = Mock()

    frames = {
        'table1': pd.DataFrame({'col1': [1, 2], 'col2': [3, 4]}),
        'table2': pd.DataFrame({'col1': [5, 6], 'col2': [7, 8]}),
    }
    csv_paths = {}
    for name, frame in frames.items():
        csv_paths[name] = tmp_path / f'{name}.csv'
        frame.to_csv(csv_paths[name], index=False)

    # A file that is not a csv and therefore must not be detected.
    (tmp_path / 'not_csv.json').write_text('{"key": "value"}')

    # Run
    instance.detect_from_csvs(tmp_path)

    # Assert
    detect_mock = instance.detect_table_from_csv
    expected_calls = [call(name, str(path)) for name, path in csv_paths.items()]
    detect_mock.assert_has_calls(expected_calls, any_order=True)
    assert detect_mock.call_count == 2

def test_detect_from_csvs_no_csv(self, tmp_path):
    """Test that ``detect_from_csvs`` raises an error when there is nothing to detect.

    Both a folder without any csv file and a folder that does not exist are expected
    to raise a ``ValueError``.
    """
    # Setup
    instance = MultiTableMetadata()
    (tmp_path / 'not_csv.json').write_text('{"key": "value"}')

    missing_folder = 'not_a_folder'
    no_csv_message = re.escape(f"No CSV files detected in the folder '{tmp_path}'.")
    missing_folder_message = re.escape(f"The folder '{missing_folder}' does not exist.")

    # Run and Assert
    with pytest.raises(ValueError, match=no_csv_message):
        instance.detect_from_csvs(tmp_path)

    with pytest.raises(ValueError, match=missing_folder_message):
        instance.detect_from_csvs(missing_folder)

@patch('sdv.metadata.multi_table.LOGGER')
@patch('sdv.metadata.multi_table.SingleTableMetadata')
def test_detect_table_from_dataframe(self, single_table_mock, log_mock):
Expand Down Expand Up @@ -1723,6 +1777,48 @@ def test_detect_table_from_dataframe_table_already_exists(self):
with pytest.raises(InvalidMetadataError, match=error_message):
metadata.detect_table_from_dataframe('table', pd.DataFrame())

def test_detect_from_dataframes(self):
    """Test ``detect_from_dataframes``.

    Expected to call ``detect_table_from_dataframe`` exactly once for each table name
    and dataframe in the input.
    """
    # Setup
    metadata = MultiTableMetadata()
    metadata.detect_table_from_dataframe = Mock()

    guests_table = pd.DataFrame()
    hotels_table = pd.DataFrame()

    # Run
    metadata.detect_from_dataframes(
        data={
            'guests': guests_table,
            'hotels': hotels_table
        }
    )

    # Assert
    metadata.detect_table_from_dataframe.assert_any_call('guests', guests_table)
    metadata.detect_table_from_dataframe.assert_any_call('hotels', hotels_table)
    # ``assert_any_call`` alone would still pass with duplicated or unexpected calls,
    # so also pin the number of calls.
    assert metadata.detect_table_from_dataframe.call_count == 2

def test_detect_from_dataframes_no_dataframes(self):
    """Test ``detect_from_dataframes`` when the input holds no valid dataframes.

    Both an empty dictionary and a dictionary with a value that is not a dataframe are
    expected to raise a ``ValueError``.
    """
    # Setup
    metadata = MultiTableMetadata()
    expected_message = 'The provided dictionary must contain only pandas DataFrame objects.'
    invalid_inputs = [{}, {'a': 1}]

    # Run and Assert
    for invalid_data in invalid_inputs:
        with pytest.raises(ValueError, match=expected_message):
            metadata.detect_from_dataframes(data=invalid_data)

def test__validate_table_exists(self):
"""Test ``_validate_table_exists``.

Expand Down
Loading