Add parallel upload for msgpack option
chezou committed Sep 17, 2023
1 parent 3b298d1 commit ef3887a
Showing 2 changed files with 55 additions and 23 deletions.
28 changes: 16 additions & 12 deletions pytd/tests/test_writer.py
@@ -264,7 +264,7 @@ def test_write_dataframe_tempfile_deletion(self):
         # file pointer to a temp CSV file
         fp = self.writer._bulk_import.call_args[0][1]
         # temp file should not exist
-        self.assertFalse(os.path.isfile(fp.name))
+        self.assertFalse(os.path.isfile(fp[0].name))

         # Case #2: bulk import failed
         self.writer._bulk_import = MagicMock(side_effect=Exception())
@@ -273,7 +273,7 @@ def test_write_dataframe_tempfile_deletion(self):
                 pd.DataFrame([[1, 2], [3, 4]]), self.table, "overwrite"
             )
         fp = self.writer._bulk_import.call_args[0][1]
-        self.assertFalse(os.path.isfile(fp.name))
+        self.assertFalse(os.path.isfile(fp[0].name))

     def test_write_dataframe_msgpack(self):
         df = pd.DataFrame([[1, 2], [3, 4]])
@@ -286,7 +286,7 @@ def test_write_dataframe_msgpack(self):
         )
         size = _bytes.getbuffer().nbytes
         api_client.create_bulk_import().upload_part.assert_called_with(
-            "part", ANY, size
+            "part-0", ANY, size
         )
         self.assertFalse(api_client.create_bulk_import().upload_file.called)

@@ -300,15 +300,17 @@ def test_write_dataframe_msgpack_with_int_na(self):
             ],
             dtype="Int64",
         )
-        expected_list = [
+        expected_list = (
             {"a": 1, "b": 2, "c": None, "time": 1234},
             {"a": 3, "b": 4, "c": 5, "time": 1234},
-        ]
+        )
         self.writer._write_msgpack_stream = MagicMock()
         self.writer.write_dataframe(df, self.table, "overwrite", fmt="msgpack")
         self.assertTrue(self.writer._write_msgpack_stream.called)
+        print(self.writer._write_msgpack_stream.call_args[0][0][0:2])
         self.assertEqual(
-            self.writer._write_msgpack_stream.call_args[0][0], expected_list
+            self.writer._write_msgpack_stream.call_args[0][0][0:2],
+            expected_list,
         )

     @unittest.skipIf(
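Note: expected_list becomes a tuple here because the chunking introduced in writer.py groups records with itertools.zip_longest, which yields fixed-size tuples padded with None; the [0:2] slice keeps only the two real records. A minimal sketch of that behavior:

    from itertools import zip_longest

    records = [{"a": 1}, {"a": 3}]
    group = next(zip_longest(*(iter(records),) * 10000))
    assert group[0:2] == ({"a": 1}, {"a": 3})  # real records, as a tuple
    assert group[2] is None  # the remaining 9,998 slots are None padding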
@@ -320,15 +322,16 @@ def test_write_dataframe_msgpack_with_string_na(self):
             dtype="string",
         )
         df["time"] = 1234
-        expected_list = [
+        expected_list = (
             {"a": "foo", "b": "bar", "c": None, "time": 1234},
             {"a": "buzz", "b": "buzz", "c": "alice", "time": 1234},
-        ]
+        )
         self.writer._write_msgpack_stream = MagicMock()
         self.writer.write_dataframe(df, self.table, "overwrite", fmt="msgpack")
         self.assertTrue(self.writer._write_msgpack_stream.called)
         self.assertEqual(
-            self.writer._write_msgpack_stream.call_args[0][0], expected_list
+            self.writer._write_msgpack_stream.call_args[0][0][0:2],
+            expected_list,
         )

     @unittest.skipIf(
@@ -340,15 +343,16 @@ def test_write_dataframe_msgpack_with_boolean_na(self):
             dtype="boolean",
         )
         df["time"] = 1234
-        expected_list = [
+        expected_list = (
             {"a": "true", "b": "false", "c": None, "time": 1234},
             {"a": "false", "b": "true", "c": "true", "time": 1234},
-        ]
+        )
         self.writer._write_msgpack_stream = MagicMock()
         self.writer.write_dataframe(df, self.table, "overwrite", fmt="msgpack")
         self.assertTrue(self.writer._write_msgpack_stream.called)
         self.assertEqual(
-            self.writer._write_msgpack_stream.call_args[0][0], expected_list
+            self.writer._write_msgpack_stream.call_args[0][0][0:2],
+            expected_list,
         )

     def test_write_dataframe_invalid_if_exists(self):
50 changes: 39 additions & 11 deletions pytd/writer.py
@@ -6,7 +6,9 @@
 import tempfile
 import time
 import uuid
+from concurrent.futures import ThreadPoolExecutor
 from contextlib import ExitStack
+from itertools import zip_longest

 import msgpack
 import numpy as np
@@ -308,7 +310,9 @@ class BulkImportWriter(Writer):
     td-client-python's bulk importer.
     """

-    def write_dataframe(self, dataframe, table, if_exists, fmt="csv", keep_list=False):
+    def write_dataframe(
+        self, dataframe, table, if_exists, fmt="csv", keep_list=False, max_workers=5
+    ):
         """Write a given DataFrame to a Treasure Data table.

         This method internally converts a given :class:`pandas.DataFrame` into a
@@ -403,6 +407,10 @@ def write_dataframe(self, dataframe, table, if_exists, fmt="csv", keep_list=False):
         Or, you can use :func:`Client.load_table_from_dataframe` function as well.

         >>> client.load_table_from_dataframe(df, "bulk_import", keep_list=True)
+
+        max_workers : int, optional, default: 5
+            The maximum number of threads that can be used to execute the given calls.
+            This is used only when ``fmt`` is ``msgpack``.
         """
         if self.closed:
             raise RuntimeError("this writer is already closed and no longer available")
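A hypothetical call exercising the new parameter (writer, df, and table are placeholders, not part of this diff):

    # Upload msgpack chunks with up to 8 concurrent upload threads.
    writer.write_dataframe(df, table, "overwrite", fmt="msgpack", max_workers=8)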
@@ -420,26 +428,31 @@ def write_dataframe(self, dataframe, table, if_exists, fmt="csv", keep_list=False):
         _cast_dtypes(dataframe, keep_list=keep_list)

         with ExitStack() as stack:
+            fps = []
             if fmt == "csv":
                 fp = tempfile.NamedTemporaryFile(suffix=".csv", delete=False)
                 stack.callback(os.unlink, fp.name)
                 stack.callback(fp.close)
                 dataframe.to_csv(fp.name)
+                fps.append(fp)
             elif fmt == "msgpack":
                 _replace_pd_na(dataframe)

-                fp = io.BytesIO()
-                fp = self._write_msgpack_stream(dataframe.to_dict(orient="records"), fp)
-                stack.callback(fp.close)
+                records = dataframe.to_dict(orient="records")
+                for group in zip_longest(*(iter(records),) * 10000):
+                    fp = io.BytesIO()
+                    fp = self._write_msgpack_stream(group, fp)
+                    fps.append(fp)
+                    stack.callback(fp.close)
             else:
                 raise ValueError(
                     f"unsupported format '{fmt}' for bulk import. "
                     "should be 'csv' or 'msgpack'"
                 )
-            self._bulk_import(table, fp, if_exists, fmt)
+            self._bulk_import(table, fps, if_exists, fmt, max_workers=max_workers)
             stack.close()

-    def _bulk_import(self, table, file_like, if_exists, fmt="csv"):
+    def _bulk_import(self, table, file_like, if_exists, fmt="csv", max_workers=5):
         """Write a specified CSV file to a Treasure Data table.

         This method uploads the file to Treasure Data via bulk import API.
@@ -449,7 +462,7 @@ def _bulk_import(self, table, file_like, if_exists, fmt="csv"):
         table : :class:`pytd.table.Table`
             Target table.

-        file_like : File like object
+        file_like : List of file like objects
             Data in this file will be loaded to a target table.

         if_exists : str, {'error', 'overwrite', 'append', 'ignore'}
@@ -462,6 +475,10 @@ def _bulk_import(self, table, file_like, if_exists, fmt="csv"):
         fmt : str, optional, {'csv', 'msgpack'}, default: 'csv'
             File format for bulk import. See also :func:`write_dataframe`
+
+        max_workers : int, optional, default: 5
+            The maximum number of threads that can be used to execute the given calls.
+            This is used only when ``fmt`` is ``msgpack``.
         """
         params = None
         if table.exists:
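The chunking in write_dataframe uses the grouper idiom: repeating one shared iterator 10,000 times makes zip_longest pull successive records into fixed-size chunks, padding the last chunk with None. A minimal standalone sketch:

    from itertools import zip_longest

    records = [{"n": i} for i in range(25)]
    chunks = list(zip_longest(*(iter(records),) * 10))
    # Three chunks of 10; the last holds 5 records plus 5 None pads,
    # which _write_msgpack_stream skips (see below).
    assert len(chunks) == 3
    assert chunks[2][4] == {"n": 24} and chunks[2][5] is None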
@@ -489,11 +506,19 @@ def _bulk_import(self, table, file_like, if_exists, fmt="csv"):
         try:
             logger.info(f"uploading data converted into a {fmt} file")
             if fmt == "msgpack":
-                size = file_like.getbuffer().nbytes
-                # To skip API._prepare_file(), which recreate msgpack again.
-                bulk_import.upload_part("part", file_like, size)
+                with ThreadPoolExecutor(max_workers=max_workers) as executor:
+                    _ = [
+                        executor.submit(
+                            bulk_import.upload_part,
+                            f"part-{i}",
+                            fp,
+                            fp.getbuffer().nbytes,
+                        )
+                        for i, fp in enumerate(file_like)
+                    ]
             else:
-                bulk_import.upload_file("part", fmt, file_like)
+                fp = file_like[0]
+                bulk_import.upload_file("part", fmt, fp)
             bulk_import.freeze()
         except Exception as e:
             bulk_import.delete()
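The executor's with-block waits for all submitted uploads to finish before freeze() runs, but the futures are discarded, so an exception inside one upload is not re-raised by this try block. A minimal sketch of the same fan-out with explicit error propagation; upload_part below is a stand-in, not the td-client API:

    import io
    from concurrent.futures import ThreadPoolExecutor, as_completed

    def upload_part(name, fp, size):  # stand-in for bulk_import.upload_part
        print(f"uploaded {name}: {size} bytes")

    parts = [io.BytesIO(b"chunk %d" % i) for i in range(3)]
    with ThreadPoolExecutor(max_workers=5) as executor:
        futures = [
            executor.submit(upload_part, f"part-{i}", fp, fp.getbuffer().nbytes)
            for i, fp in enumerate(parts)
        ]
        for future in as_completed(futures):
            future.result()  # re-raise any upload error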
@@ -535,6 +560,9 @@ def _write_msgpack_stream(self, items, stream):
         with gzip.GzipFile(mode="wb", fileobj=stream) as gz:
             packer = msgpack.Packer()
             for item in items:
+                # Ignore None created by zip_longest
+                if not item:
+                    break
                 try:
                     mp = packer.pack(item)
                 except (OverflowError, ValueError):
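For reference, a minimal sketch of what _write_msgpack_stream produces for each chunk, mirroring the diff's logic (not the full method): a gzip-compressed stream of msgpack maps, stopping at the first None pad:

    import gzip
    import io

    import msgpack

    items = ({"a": 1, "time": 1234}, {"a": 2, "time": 1234}, None, None)
    stream = io.BytesIO()
    with gzip.GzipFile(mode="wb", fileobj=stream) as gz:
        packer = msgpack.Packer()
        for item in items:
            if not item:  # ignore None created by zip_longest
                break
            gz.write(packer.pack(item))
    stream.seek(0)  # ready to upload as one "part-N"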
