From ef3887a0c9322811643b31a3006e8e494f11bce9 Mon Sep 17 00:00:00 2001
From: Aki Ariga
Date: Sun, 17 Sep 2023 15:10:33 -0700
Subject: [PATCH] Add parallel upload for msgpack option

---
 pytd/tests/test_writer.py | 27 +++++++++++----------
 pytd/writer.py            | 50 ++++++++++++++++++++++++++++++---------
 2 files changed, 54 insertions(+), 23 deletions(-)

diff --git a/pytd/tests/test_writer.py b/pytd/tests/test_writer.py
index 510d90d..0202d7a 100644
--- a/pytd/tests/test_writer.py
+++ b/pytd/tests/test_writer.py
@@ -264,7 +264,7 @@ def test_write_dataframe_tempfile_deletion(self):
         # file pointer to a temp CSV file
         fp = self.writer._bulk_import.call_args[0][1]
         # temp file should not exist
-        self.assertFalse(os.path.isfile(fp.name))
+        self.assertFalse(os.path.isfile(fp[0].name))

         # Case #2: bulk import failed
         self.writer._bulk_import = MagicMock(side_effect=Exception())
@@ -273,7 +273,7 @@ def test_write_dataframe_tempfile_deletion(self):
             pd.DataFrame([[1, 2], [3, 4]]), self.table, "overwrite"
         )
         fp = self.writer._bulk_import.call_args[0][1]
-        self.assertFalse(os.path.isfile(fp.name))
+        self.assertFalse(os.path.isfile(fp[0].name))

     def test_write_dataframe_msgpack(self):
         df = pd.DataFrame([[1, 2], [3, 4]])
@@ -286,7 +286,7 @@ def test_write_dataframe_msgpack(self):
         )
         size = _bytes.getbuffer().nbytes
         api_client.create_bulk_import().upload_part.assert_called_with(
-            "part", ANY, size
+            "part-0", ANY, size
         )
         self.assertFalse(api_client.create_bulk_import().upload_file.called)

@@ -300,15 +300,16 @@ def test_write_dataframe_msgpack_with_int_na(self):
             ],
             dtype="Int64",
         )
-        expected_list = [
+        expected_list = (
             {"a": 1, "b": 2, "c": None, "time": 1234},
             {"a": 3, "b": 4, "c": 5, "time": 1234},
-        ]
+        )
         self.writer._write_msgpack_stream = MagicMock()
         self.writer.write_dataframe(df, self.table, "overwrite", fmt="msgpack")
         self.assertTrue(self.writer._write_msgpack_stream.called)
         self.assertEqual(
-            self.writer._write_msgpack_stream.call_args[0][0], expected_list
+            self.writer._write_msgpack_stream.call_args[0][0][0:2],
+            expected_list,
         )

     @unittest.skipIf(
@@ -320,15 +321,16 @@ def test_write_dataframe_msgpack_with_string_na(self):
             dtype="string",
         )
         df["time"] = 1234
-        expected_list = [
+        expected_list = (
             {"a": "foo", "b": "bar", "c": None, "time": 1234},
             {"a": "buzz", "b": "buzz", "c": "alice", "time": 1234},
-        ]
+        )
         self.writer._write_msgpack_stream = MagicMock()
         self.writer.write_dataframe(df, self.table, "overwrite", fmt="msgpack")
         self.assertTrue(self.writer._write_msgpack_stream.called)
         self.assertEqual(
-            self.writer._write_msgpack_stream.call_args[0][0], expected_list
+            self.writer._write_msgpack_stream.call_args[0][0][0:2],
+            expected_list,
         )

     @unittest.skipIf(
@@ -340,15 +342,16 @@ def test_write_dataframe_msgpack_with_boolean_na(self):
             dtype="boolean",
         )
         df["time"] = 1234
-        expected_list = [
+        expected_list = (
             {"a": "true", "b": "false", "c": None, "time": 1234},
             {"a": "false", "b": "true", "c": "true", "time": 1234},
-        ]
+        )
         self.writer._write_msgpack_stream = MagicMock()
         self.writer.write_dataframe(df, self.table, "overwrite", fmt="msgpack")
         self.assertTrue(self.writer._write_msgpack_stream.called)
         self.assertEqual(
-            self.writer._write_msgpack_stream.call_args[0][0], expected_list
+            self.writer._write_msgpack_stream.call_args[0][0][0:2],
+            expected_list,
         )

     def test_write_dataframe_invalid_if_exists(self):
diff --git a/pytd/writer.py b/pytd/writer.py
index d486c39..fae5b46 100644
--- a/pytd/writer.py
+++ b/pytd/writer.py
@@ -6,7 +6,9 @@
 import tempfile
 import time
 import uuid
+from concurrent.futures import ThreadPoolExecutor
 from contextlib import ExitStack
+from itertools import zip_longest

 import msgpack
 import numpy as np
@@ -308,7 +310,9 @@ class BulkImportWriter(Writer):
     td-client-python's bulk importer.
     """

-    def write_dataframe(self, dataframe, table, if_exists, fmt="csv", keep_list=False):
+    def write_dataframe(
+        self, dataframe, table, if_exists, fmt="csv", keep_list=False, max_workers=5
+    ):
         """Write a given DataFrame to a Treasure Data table.

         This method internally converts a given :class:`pandas.DataFrame` into a
@@ -403,6 +407,10 @@ def write_dataframe(self, dataframe, table, if_exists, fmt="csv", keep_list=Fals
         Or, you can use :func:`Client.load_table_from_dataframe` function as well.

         >>> client.load_table_from_dataframe(df, "bulk_import", keep_list=True)
+
+        max_workers : int, optional, default: 5
+            The maximum number of threads that can be used to execute the given calls.
+            This is used only when ``fmt`` is ``msgpack``.
         """
         if self.closed:
             raise RuntimeError("this writer is already closed and no longer available")
@@ -420,26 +428,31 @@ def write_dataframe(self, dataframe, table, if_exists, fmt="csv", keep_list=Fals
         _cast_dtypes(dataframe, keep_list=keep_list)

         with ExitStack() as stack:
+            fps = []
             if fmt == "csv":
                 fp = tempfile.NamedTemporaryFile(suffix=".csv", delete=False)
                 stack.callback(os.unlink, fp.name)
                 stack.callback(fp.close)
                 dataframe.to_csv(fp.name)
+                fps.append(fp)
             elif fmt == "msgpack":
                 _replace_pd_na(dataframe)

-                fp = io.BytesIO()
-                fp = self._write_msgpack_stream(dataframe.to_dict(orient="records"), fp)
-                stack.callback(fp.close)
+                records = dataframe.to_dict(orient="records")
+                for group in zip_longest(*(iter(records),) * 10000):
+                    fp = io.BytesIO()
+                    fp = self._write_msgpack_stream(group, fp)
+                    fps.append(fp)
+                    stack.callback(fp.close)
             else:
                 raise ValueError(
                     f"unsupported format '{fmt}' for bulk import. "
                     "should be 'csv' or 'msgpack'"
                 )
-            self._bulk_import(table, fp, if_exists, fmt)
+            self._bulk_import(table, fps, if_exists, fmt, max_workers=max_workers)
             stack.close()

-    def _bulk_import(self, table, file_like, if_exists, fmt="csv"):
+    def _bulk_import(self, table, file_like, if_exists, fmt="csv", max_workers=5):
         """Write a specified CSV file to a Treasure Data table.

         This method uploads the file to Treasure Data via bulk import API.
@@ -449,7 +462,7 @@ def _bulk_import(self, table, file_like, if_exists, fmt="csv"):
         table : :class:`pytd.table.Table`
             Target table.

-        file_like : File like object
+        file_like : List of file like objects
             Data in this file will be loaded to a target table.

         if_exists : str, {'error', 'overwrite', 'append', 'ignore'}
@@ -462,6 +475,10 @@ def _bulk_import(self, table, file_like, if_exists, fmt="csv"):

         fmt : str, optional, {'csv', 'msgpack'}, default: 'csv'
             File format for bulk import. See also :func:`write_dataframe`
+
+        max_workers : int, optional, default: 5
+            The maximum number of threads that can be used to execute the given calls.
+            This is used only when ``fmt`` is ``msgpack``.
         """
         params = None
         if table.exists:
@@ -489,11 +506,19 @@ def _bulk_import(self, table, file_like, if_exists, fmt="csv"):
         try:
             logger.info(f"uploading data converted into a {fmt} file")
             if fmt == "msgpack":
-                size = file_like.getbuffer().nbytes
-                # To skip API._prepare_file(), which recreate msgpack again.
-                bulk_import.upload_part("part", file_like, size)
+                with ThreadPoolExecutor(max_workers=max_workers) as executor:
+                    _ = [
+                        executor.submit(
+                            bulk_import.upload_part,
+                            f"part-{i}",
+                            fp,
+                            fp.getbuffer().nbytes,
+                        )
+                        for i, fp in enumerate(file_like)
+                    ]
             else:
-                bulk_import.upload_file("part", fmt, file_like)
+                fp = file_like[0]
+                bulk_import.upload_file("part", fmt, fp)
             bulk_import.freeze()
         except Exception as e:
             bulk_import.delete()
@@ -535,6 +560,9 @@ def _write_msgpack_stream(self, items, stream):
         with gzip.GzipFile(mode="wb", fileobj=stream) as gz:
             packer = msgpack.Packer()
             for item in items:
+                # Skip the trailing None padding appended by zip_longest
+                if item is None:
+                    break
                 try:
                     mp = packer.pack(item)
                 except (OverflowError, ValueError):
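
Usage sketch: the new ``max_workers`` option only takes effect with ``fmt="msgpack"``.
A minimal example, assuming a valid API key and that keyword arguments are passed
through to the writer; the credentials, database, and table name below are hypothetical:

    import pandas as pd
    import pytd

    client = pytd.Client(
        apikey="1/XXXXX",  # hypothetical credentials
        endpoint="https://api.treasuredata.com/",
        database="sample_datasets",
    )
    df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})

    # fmt="msgpack" splits the records into 10,000-row chunks and uploads
    # them concurrently, using at most max_workers threads.
    client.load_table_from_dataframe(
        df,
        "bulk_import_test",
        writer="bulk_import",
        if_exists="overwrite",
        fmt="msgpack",
        max_workers=5,
    )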
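
The chunking in ``write_dataframe`` uses the classic ``zip_longest`` grouper idiom:
replicating a single iterator N times makes ``zip_longest`` draw N consecutive items
per tuple, padding the final tuple with ``None``. A self-contained illustration of
the 10,000-row grouping at a smaller chunk size:

    from itertools import zip_longest

    records = [{"a": i} for i in range(7)]
    # One iterator object repeated 3 times: each zip_longest step pulls
    # 3 consecutive records; only the final group is padded with None.
    groups = list(zip_longest(*(iter(records),) * 3))
    # [({'a': 0}, {'a': 1}, {'a': 2}),
    #  ({'a': 3}, {'a': 4}, {'a': 5}),
    #  ({'a': 6}, None, None)]

This padding is why ``_write_msgpack_stream`` stops at the first ``None`` it sees.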
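
One behavior of the upload loop in ``_bulk_import`` worth noting: the futures from
``executor.submit`` are discarded, so leaving the ``ThreadPoolExecutor`` context waits
for every ``upload_part`` call to finish, but an exception raised inside a worker is
only re-raised when a future's ``result()`` is read, so a failed part upload may not
surface until later. A sketch of the same fan-out that propagates the first failure,
using the names from the patch (``upload_parts`` is a hypothetical helper, not part
of the change):

    from concurrent.futures import ThreadPoolExecutor, as_completed

    def upload_parts(bulk_import, file_like, max_workers=5):
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(
                    bulk_import.upload_part, f"part-{i}", fp, fp.getbuffer().nbytes
                )
                for i, fp in enumerate(file_like)
            ]
            # result() re-raises any exception thrown inside upload_part,
            # letting the caller's try/except delete the bulk import session.
            for future in as_completed(futures):
                future.result()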