From be3b72ca1774cbcf64ebb8b55364f9ec1d68474a Mon Sep 17 00:00:00 2001
From: Plamen Valentinov Kolev <41479552+pvk-developer@users.noreply.github.com>
Date: Mon, 29 Apr 2024 22:31:55 +0200
Subject: [PATCH] Implement CSVHandler (#1958)

---
 sdv/io/local/__init__.py                 |   8 +
 sdv/io/local/local.py                    | 194 +++++++++++++++++++++
 tests/integration/io/__init__.py         |   0
 tests/integration/io/local/__init__.py   |   0
 tests/integration/io/local/test_local.py |  32 ++++
 tests/unit/io/__init__.py                |   0
 tests/unit/io/local/__init__.py          |   0
 tests/unit/io/local/test_local.py        | 212 +++++++++++++++++++++++
 8 files changed, 446 insertions(+)
 create mode 100644 sdv/io/local/__init__.py
 create mode 100644 sdv/io/local/local.py
 create mode 100644 tests/integration/io/__init__.py
 create mode 100644 tests/integration/io/local/__init__.py
 create mode 100644 tests/integration/io/local/test_local.py
 create mode 100644 tests/unit/io/__init__.py
 create mode 100644 tests/unit/io/local/__init__.py
 create mode 100644 tests/unit/io/local/test_local.py

diff --git a/sdv/io/local/__init__.py b/sdv/io/local/__init__.py
new file mode 100644
index 000000000..a233b25be
--- /dev/null
+++ b/sdv/io/local/__init__.py
@@ -0,0 +1,8 @@
+"""Local I/O module."""
+
+from sdv.io.local.local import BaseLocalHandler, CSVHandler
+
+__all__ = (
+    'BaseLocalHandler',
+    'CSVHandler'
+)
diff --git a/sdv/io/local/local.py b/sdv/io/local/local.py
new file mode 100644
index 000000000..0d81ab634
--- /dev/null
+++ b/sdv/io/local/local.py
@@ -0,0 +1,194 @@
+"""Local file handlers."""
+import codecs
+import inspect
+import os
+from pathlib import Path
+
+import pandas as pd
+
+from sdv.metadata import MultiTableMetadata
+
+
+class BaseLocalHandler:
+    """Base class for local handlers."""
+
+    def __init__(self, decimal='.', float_format=None):
+        self.decimal = decimal
+        self.float_format = float_format
+
+    def _infer_metadata(self, data):
+        """Detect the metadata for all tables in a dictionary of dataframes.
+
+        Args:
+            data (dict):
+                Dictionary of table names to dataframes.
+
+        Returns:
+            MultiTableMetadata:
+                An ``sdv.metadata.MultiTableMetadata`` object with the detected metadata
+                properties from the data.
+        """
+        metadata = MultiTableMetadata()
+        metadata.detect_from_dataframes(data)
+        return metadata
+
+    def read(self):
+        """Read data from files and return it along with metadata.
+
+        This method must be implemented by subclasses.
+
+        Returns:
+            tuple:
+                A tuple containing the read data as a dictionary and metadata. The dictionary maps
+                table names to pandas DataFrames. The metadata is an object describing the data.
+        """
+        raise NotImplementedError()
+
+    def write(self):
+        """Write data to files.
+
+        This method must be implemented by subclasses.
+        """
+        raise NotImplementedError()
+
+
+class CSVHandler(BaseLocalHandler):
+    """A class for handling CSV files.
+
+    Args:
+        sep (str):
+            The separator used for reading and writing CSV files. Defaults to ``,``.
+        encoding (str):
+            The character encoding to use for reading and writing CSV files. Defaults to ``UTF``.
+        decimal (str):
+            The character used to denote the decimal point. Defaults to ``.``.
+        float_format (str or None):
+            The formatting string for floating-point numbers. Optional.
+        quotechar (str):
+            Character used to denote the start and end of a quoted item.
+            Quoted items can include the delimiter and it will be ignored. Defaults to '"'.
+        quoting (int or None):
+            Control field quoting behavior. Defaults to 0.
+
+    Raises:
+        ValueError:
+            If the provided encoding is not available in the system.
+ """ + + def __init__(self, sep=',', encoding='UTF', decimal='.', float_format=None, + quotechar='"', quoting=0): + super().__init__(decimal, float_format) + try: + codecs.lookup(encoding) + except LookupError as error: + raise ValueError( + f"The provided encoding '{encoding}' is not available in your system." + ) from error + + self.sep = sep + self.encoding = encoding + self.quotechar = quotechar + self.quoting = quoting + + def read(self, folder_name, file_names=None): + """Read data from CSV files and returns it along with metadata. + + Args: + folder_name (str): + The name of the folder containing CSV files. + file_names (list of str, optional): + The names of CSV files to read. If None, all files ending with '.csv' + in the folder are read. + + Returns: + tuple: + A tuple containing the data as a dictionary and metadata. The dictionary maps + table names to pandas DataFrames. The metadata is an object describing the data. + + Raises: + FileNotFoundError: + If the specified files do not exist in the folder. + """ + data = {} + metadata = MultiTableMetadata() + + folder_path = Path(folder_name) + + if file_names is None: + # If file_names is None, read all files in the folder ending with ".csv" + file_paths = folder_path.glob('*.csv') + else: + # Validate if the given files exist in the folder + file_names = file_names + missing_files = [ + file + for file in file_names + if not (folder_path / file).exists() + ] + if missing_files: + raise FileNotFoundError( + f"The following files do not exist in the folder: {', '.join(missing_files)}." + ) + + file_paths = [folder_path / file for file in file_names] + + # Read CSV files + kwargs = { + 'sep': self.sep, + 'encoding': self.encoding, + 'parse_dates': False, + 'low_memory': False, + 'decimal': self.decimal, + 'on_bad_lines': 'warn', + 'quotechar': self.quotechar, + 'quoting': self.quoting + } + + args = inspect.getfullargspec(pd.read_csv) + if 'on_bad_lines' not in args.kwonlyargs: + kwargs.pop('on_bad_lines') + kwargs['error_bad_lines'] = False + + for file_path in file_paths: + table_name = file_path.stem # Remove file extension to get table name + data[table_name] = pd.read_csv( + file_path, + **kwargs + ) + + metadata = self._infer_metadata(data) + return data, metadata + + def write(self, synthetic_data, folder_name, file_name_suffix=None, mode='x'): + """Write synthetic data to CSV files. + + Args: + synthetic_data (dict): + A dictionary mapping table names to pandas DataFrames containing synthetic data. + folder_name (str): + The name of the folder to write CSV files to. + file_name_suffix (str, optional): + An optional suffix to add to each file name. If ``None``, no suffix is added. + mode (str, optional): + The mode of writing to use. Defaults to 'x'. + 'x': Write to new files, raising errors if existing files exist with the same name. + 'w': Write to new files, clearing any existing files that exist. + 'a': Append the new CSV rows to any existing files. 
+ """ + folder_path = Path(folder_name) + if not os.path.exists(folder_path): + os.makedirs(folder_path) + + for table_name, table_data in synthetic_data.items(): + file_name = f'{table_name}{file_name_suffix}' if file_name_suffix else f'{table_name}' + file_path = f'{folder_path / file_name}.csv' + table_data.to_csv( + file_path, + sep=self.sep, + encoding=self.encoding, + index=False, + float_format=self.float_format, + quotechar=self.quotechar, + quoting=self.quoting, + mode=mode, + ) diff --git a/tests/integration/io/__init__.py b/tests/integration/io/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/io/local/__init__.py b/tests/integration/io/local/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/io/local/test_local.py b/tests/integration/io/local/test_local.py new file mode 100644 index 000000000..87b3c80ea --- /dev/null +++ b/tests/integration/io/local/test_local.py @@ -0,0 +1,32 @@ +import pandas as pd + +from sdv.io.local import CSVHandler +from sdv.metadata import MultiTableMetadata + + +class TestCSVHandler: + + def test_integration_read_write(self, tmpdir): + """Test end to end the read and write methods of ``CSVHandler``.""" + # Prepare synthetic data + synthetic_data = { + 'table1': pd.DataFrame({'col1': [1, 2, 3], 'col2': ['a', 'b', 'c']}), + 'table2': pd.DataFrame({'col3': [4, 5, 6], 'col4': ['d', 'e', 'f']}) + } + + # Write synthetic data to CSV files + handler = CSVHandler() + handler.write(synthetic_data, tmpdir) + + # Read data from CSV files + data, metadata = handler.read(tmpdir) + + # Check if data was read correctly + assert len(data) == 2 + assert 'table1' in data + assert 'table2' in data + assert isinstance(metadata, MultiTableMetadata) is True + + # Check if the dataframes match the original synthetic data + pd.testing.assert_frame_equal(data['table1'], synthetic_data['table1']) + pd.testing.assert_frame_equal(data['table2'], synthetic_data['table2']) diff --git a/tests/unit/io/__init__.py b/tests/unit/io/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/io/local/__init__.py b/tests/unit/io/local/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/io/local/test_local.py b/tests/unit/io/local/test_local.py new file mode 100644 index 000000000..e69d18636 --- /dev/null +++ b/tests/unit/io/local/test_local.py @@ -0,0 +1,212 @@ +"""Unit tests for local file handlers.""" +import os +from pathlib import Path +from unittest.mock import patch + +import pandas as pd +import pytest + +from sdv.io.local.local import CSVHandler +from sdv.metadata.multi_table import MultiTableMetadata + + +class TestCSVHandler: + + def test___init__(self): + """Test the dafault initialization of the class.""" + # Run + instance = CSVHandler() + + # Assert + assert instance.decimal == '.' 
+        assert instance.float_format is None
+        assert instance.encoding == 'UTF'
+        assert instance.sep == ','
+        assert instance.quotechar == '"'
+        assert instance.quoting == 0
+
+    def test___init___custom(self):
+        """Test custom initialization of the class."""
+        # Run
+        instance = CSVHandler(
+            sep=';',
+            encoding='utf-8',
+            decimal=',',
+            float_format='%.2f',
+            quotechar="'",
+            quoting=2
+        )
+
+        # Assert
+        assert instance.decimal == ','
+        assert instance.float_format == '%.2f'
+        assert instance.encoding == 'utf-8'
+        assert instance.sep == ';'
+        assert instance.quotechar == "'"
+        assert instance.quoting == 2
+
+    def test___init___error_encoding(self):
+        """Test that initialization raises an error when the encoding is unknown."""
+        # Run and Assert
+        error_msg = "The provided encoding 'sdvutf-8' is not available in your system."
+        with pytest.raises(ValueError, match=error_msg):
+            CSVHandler(sep=';', encoding='sdvutf-8', decimal=',', float_format='%.2f')
+
+    @patch('sdv.io.local.local.Path.glob')
+    @patch('pandas.read_csv')
+    def test_read(self, mock_read_csv, mock_glob):
+        """Test the read method of the ``CSVHandler`` class with a folder."""
+        # Setup
+        mock_glob.return_value = [
+            Path('/path/to/data/parent.csv'),
+            Path('/path/to/data/child.csv')
+        ]
+        mock_read_csv.side_effect = [
+            pd.DataFrame({'col1': [1, 2, 3], 'col2': ['a', 'b', 'c']}),
+            pd.DataFrame({'col3': [4, 5, 6], 'col4': ['d', 'e', 'f']})
+        ]
+
+        handler = CSVHandler()
+
+        # Run
+        data, metadata = handler.read('/path/to/data')
+
+        # Assert
+        assert len(data) == 2
+        assert 'parent' in data
+        assert 'child' in data
+        assert isinstance(metadata, MultiTableMetadata)
+        assert mock_read_csv.call_count == 2
+        pd.testing.assert_frame_equal(
+            data['parent'],
+            pd.DataFrame({'col1': [1, 2, 3], 'col2': ['a', 'b', 'c']})
+        )
+        pd.testing.assert_frame_equal(
+            data['child'],
+            pd.DataFrame({'col3': [4, 5, 6], 'col4': ['d', 'e', 'f']})
+        )
+
+    def test_read_files(self, tmpdir):
+        """Test the read method of the ``CSVHandler`` class with given ``file_names``."""
+        # Setup
+        file_path = Path(tmpdir)
+        pd.DataFrame({'col1': [1, 2, 3], 'col2': ['a', 'b', 'c']}).to_csv(
+            file_path / 'parent.csv',
+            index=False
+        )
+        pd.DataFrame({'col3': [4, 5, 6], 'col4': ['d', 'e', 'f']}).to_csv(
+            file_path / 'child.csv',
+            index=False
+        )
+
+        handler = CSVHandler()
+
+        # Run
+        data, metadata = handler.read(tmpdir, file_names=['parent.csv'])
+
+        # Assert
+        assert 'parent' in data
+        pd.testing.assert_frame_equal(
+            data['parent'],
+            pd.DataFrame({'col1': [1, 2, 3], 'col2': ['a', 'b', 'c']})
+        )
+
+    def test_read_files_missing(self, tmpdir):
+        """Test the read method of the ``CSVHandler`` class with missing ``file_names``."""
+        # Setup
+        file_path = Path(tmpdir)
+        pd.DataFrame({'col1': [1, 2, 3], 'col2': ['a', 'b', 'c']}).to_csv(
+            file_path / 'parent.csv',
+            index=False
+        )
+        pd.DataFrame({'col3': [4, 5, 6], 'col4': ['d', 'e', 'f']}).to_csv(
+            file_path / 'child.csv',
+            index=False
+        )
+
+        handler = CSVHandler()
+
+        # Run and Assert
+        error_msg = 'The following files do not exist in the folder: grandchild.csv, parents.csv.'
+        with pytest.raises(FileNotFoundError, match=error_msg):
+            handler.read(tmpdir, file_names=['grandchild.csv', 'parents.csv'])
+
+    def test_write(self, tmpdir):
+        """Test the write functionality of a CSVHandler."""
+        # Setup
+        synthetic_data = {
+            'table1': pd.DataFrame({'col1': [1, 2, 3], 'col2': ['a', 'b', 'c']}),
+            'table2': pd.DataFrame({'col3': [4, 5, 6], 'col4': ['d', 'e', 'f']})
+        }
+        handler = CSVHandler()
+
+        assert os.path.exists(tmpdir / 'synthetic_data') is False
+
+        # Run
+        handler.write(synthetic_data, tmpdir / 'synthetic_data', file_name_suffix='_synthetic')
+
+        # Assert
+        assert 'table1_synthetic.csv' in os.listdir(tmpdir / 'synthetic_data')
+        assert 'table2_synthetic.csv' in os.listdir(tmpdir / 'synthetic_data')
+
+    def test_write_file_exists(self, tmpdir):
+        """Test that an error is raised when the file already exists and the mode is ``x``."""
+        # Setup
+        synthetic_data = {
+            'table1': pd.DataFrame({'col1': [1, 2, 3], 'col2': ['a', 'b', 'c']}),
+            'table2': pd.DataFrame({'col3': [4, 5, 6], 'col4': ['d', 'e', 'f']})
+        }
+
+        os.makedirs(tmpdir / 'synthetic_data')
+        synthetic_data['table1'].to_csv(tmpdir / 'synthetic_data' / 'table1.csv', index=False)
+        handler = CSVHandler()
+
+        # Run
+        with pytest.raises(FileExistsError):
+            handler.write(synthetic_data, tmpdir / 'synthetic_data')
+
+    def test_write_file_exists_mode_is_a(self, tmpdir):
+        """Test the write functionality of a CSVHandler when the mode is ``a``."""
+        # Setup
+        synthetic_data = {
+            'table1': pd.DataFrame({'col1': [1, 2, 3], 'col2': ['a', 'b', 'c']}),
+            'table2': pd.DataFrame({'col3': [4, 5, 6], 'col4': ['d', 'e', 'f']})
+        }
+
+        os.makedirs(tmpdir / 'synthetic_data')
+        synthetic_data['table1'].to_csv(tmpdir / 'synthetic_data' / 'table1.csv', index=False)
+        handler = CSVHandler()
+
+        # Run
+        handler.write(synthetic_data, tmpdir / 'synthetic_data', mode='a')
+
+        # Assert
+        dataframe = pd.read_csv(tmpdir / 'synthetic_data' / 'table1.csv')
+        expected_dataframe = pd.DataFrame({
+            'col1': ['1', '2', '3', 'col1', '1', '2', '3'],
+            'col2': ['a', 'b', 'c', 'col2', 'a', 'b', 'c']
+        })
+        pd.testing.assert_frame_equal(dataframe, expected_dataframe)
+
+    def test_write_file_exists_mode_is_w(self, tmpdir):
+        """Test the write functionality of a CSVHandler when the mode is ``w``."""
+        # Setup
+        synthetic_data = {
+            'table1': pd.DataFrame({'col1': [1, 2, 3], 'col2': ['a', 'b', 'c']}),
+            'table2': pd.DataFrame({'col3': [4, 5, 6], 'col4': ['d', 'e', 'f']})
+        }
+
+        os.makedirs(tmpdir / 'synthetic_data')
+        synthetic_data['table1'].to_csv(tmpdir / 'synthetic_data' / 'table1.csv', index=False)
+        handler = CSVHandler()
+
+        # Run
+        handler.write(synthetic_data, tmpdir / 'synthetic_data', mode='w')
+
+        # Assert
+        dataframe = pd.read_csv(tmpdir / 'synthetic_data' / 'table1.csv')
+        expected_dataframe = pd.DataFrame({
+            'col1': [1, 2, 3],
+            'col2': ['a', 'b', 'c']
+        })
+        pd.testing.assert_frame_equal(dataframe, expected_dataframe)
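
Usage sketch: the snippet below illustrates how the new ``CSVHandler`` is meant to be used,
following the docstrings in sdv/io/local/local.py above. The folder name 'demo_data', the
example tables, and the closing ``metadata.to_dict()`` call are illustrative assumptions and
not part of this patch.

    import pandas as pd

    from sdv.io.local import CSVHandler

    # Tables keyed by table name; each becomes '<table name><suffix>.csv' on disk.
    tables = {
        'users': pd.DataFrame({'user_id': [1, 2, 3], 'country': ['US', 'DE', 'ES']}),
        'orders': pd.DataFrame({'order_id': [10, 11], 'user_id': [1, 3]}),
    }

    handler = CSVHandler()  # defaults: sep=',', encoding='UTF'

    # Write each table to 'demo_data/<name>_v1.csv'. The folder is created if needed,
    # and the default mode='x' raises FileExistsError rather than overwriting files.
    handler.write(tables, 'demo_data', file_name_suffix='_v1')

    # Read every '*.csv' file in the folder back as DataFrames, together with the
    # MultiTableMetadata detected from them.
    data, metadata = handler.read('demo_data')
    print(list(data))  # e.g. ['users_v1', 'orders_v1'] -- table names come from the file stems
    print(metadata.to_dict())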