diff --git a/pytd/tests/test_writer.py b/pytd/tests/test_writer.py
index 0202d7a..3ff14b1 100644
--- a/pytd/tests/test_writer.py
+++ b/pytd/tests/test_writer.py
@@ -1,7 +1,7 @@
-import io
 import os
+import tempfile
 import unittest
-from unittest.mock import ANY, MagicMock
+from unittest.mock import ANY, MagicMock, patch
 
 import numpy as np
 import pandas as pd
@@ -281,14 +281,13 @@ def test_write_dataframe_msgpack(self):
         api_client = self.table.client.api_client
         self.assertTrue(api_client.create_bulk_import.called)
         self.assertTrue(api_client.create_bulk_import().upload_part.called)
-        _bytes = BulkImportWriter()._write_msgpack_stream(
-            df.to_dict(orient="records"), io.BytesIO()
-        )
-        size = _bytes.getbuffer().nbytes
+        fp = tempfile.NamedTemporaryFile(delete=False)
+        fp = BulkImportWriter()._write_msgpack_stream(df.to_dict(orient="records"), fp)
         api_client.create_bulk_import().upload_part.assert_called_with(
-            "part-0", ANY, size
+            "part-0", ANY, 62
         )
         self.assertFalse(api_client.create_bulk_import().upload_file.called)
+        os.unlink(fp.name)
 
     def test_write_dataframe_msgpack_with_int_na(self):
         # Although this conversion ensures pd.NA Int64 dtype to None,
@@ -305,13 +304,14 @@ def test_write_dataframe_msgpack_with_int_na(self):
             {"a": 3, "b": 4, "c": 5, "time": 1234},
         )
         self.writer._write_msgpack_stream = MagicMock()
-        self.writer.write_dataframe(df, self.table, "overwrite", fmt="msgpack")
-        self.assertTrue(self.writer._write_msgpack_stream.called)
-        print(self.writer._write_msgpack_stream.call_args[0][0][0:2])
-        self.assertEqual(
-            self.writer._write_msgpack_stream.call_args[0][0][0:2],
-            expected_list,
-        )
+        with patch("pytd.writer.os.unlink"):
+            self.writer.write_dataframe(df, self.table, "overwrite", fmt="msgpack")
+            self.assertTrue(self.writer._write_msgpack_stream.called)
+            print(self.writer._write_msgpack_stream.call_args[0][0][0:2])
+            self.assertEqual(
+                self.writer._write_msgpack_stream.call_args[0][0][0:2],
+                expected_list,
+            )
 
     @unittest.skipIf(
         pd.__version__ < "1.0.0", "pd.NA not supported in this pandas version"
@@ -327,12 +327,13 @@ def test_write_dataframe_msgpack_with_string_na(self):
             {"a": "buzz", "b": "buzz", "c": "alice", "time": 1234},
         )
         self.writer._write_msgpack_stream = MagicMock()
-        self.writer.write_dataframe(df, self.table, "overwrite", fmt="msgpack")
-        self.assertTrue(self.writer._write_msgpack_stream.called)
-        self.assertEqual(
-            self.writer._write_msgpack_stream.call_args[0][0][0:2],
-            expected_list,
-        )
+        with patch("pytd.writer.os.unlink"):
+            self.writer.write_dataframe(df, self.table, "overwrite", fmt="msgpack")
+            self.assertTrue(self.writer._write_msgpack_stream.called)
+            self.assertEqual(
+                self.writer._write_msgpack_stream.call_args[0][0][0:2],
+                expected_list,
+            )
 
     @unittest.skipIf(
         pd.__version__ < "1.0.0", "pd.NA not supported in this pandas version"
@@ -348,12 +349,13 @@ def test_write_dataframe_msgpack_with_boolean_na(self):
             {"a": "false", "b": "true", "c": "true", "time": 1234},
         )
         self.writer._write_msgpack_stream = MagicMock()
-        self.writer.write_dataframe(df, self.table, "overwrite", fmt="msgpack")
-        self.assertTrue(self.writer._write_msgpack_stream.called)
-        self.assertEqual(
-            self.writer._write_msgpack_stream.call_args[0][0][0:2],
-            expected_list,
-        )
+        with patch("pytd.writer.os.unlink"):
+            self.writer.write_dataframe(df, self.table, "overwrite", fmt="msgpack")
+            self.assertTrue(self.writer._write_msgpack_stream.called)
+            self.assertEqual(
+                self.writer._write_msgpack_stream.call_args[0][0][0:2],
+                expected_list,
+            )
 
     def test_write_dataframe_invalid_if_exists(self):
         with self.assertRaises(ValueError):
diff --git a/pytd/writer.py b/pytd/writer.py
index 9bb3346..9423f2d 100644
--- a/pytd/writer.py
+++ b/pytd/writer.py
@@ -1,6 +1,5 @@
 import abc
 import gzip
-import io
 import logging
 import os
 import tempfile
@@ -450,11 +449,18 @@ def write_dataframe(
             _replace_pd_na(dataframe)
             records = dataframe.to_dict(orient="records")
 
-            for group in zip_longest(*(iter(records),) * chunk_record_size):
-                fp = io.BytesIO()
-                fp = self._write_msgpack_stream(group, fp)
-                fps.append(fp)
-                stack.callback(fp.close)
+            try:
+                for group in zip_longest(*(iter(records),) * chunk_record_size):
+                    fp = tempfile.NamedTemporaryFile(suffix=".msgpack.gz", delete=False)
+                    fp = self._write_msgpack_stream(group, fp)
+                    fps.append(fp)
+                    stack.callback(os.unlink, fp.name)
+                    stack.callback(fp.close)
+            except OSError as e:
+                raise RuntimeError(
+                    "failed to create a temporary file. "
+                    "Increasing chunk_record_size may mitigate the issue."
+                ) from e
         else:
             raise ValueError(
                 f"unsupported format '{fmt}' for bulk import. "
@@ -514,19 +520,21 @@ def _bulk_import(self, table, file_like, if_exists, fmt="csv", max_workers=5):
         bulk_import = table.client.api_client.create_bulk_import(
             session_name, table.database, table.table, params=params
         )
+        s_time = time.time()
         try:
             logger.info(f"uploading data converted into a {fmt} file")
             if fmt == "msgpack":
                 with ThreadPoolExecutor(max_workers=max_workers) as executor:
-                    _ = [
+                    for i, fp in enumerate(file_like):
+                        fsize = fp.tell()
+                        fp.seek(0)
                         executor.submit(
                             bulk_import.upload_part,
                             f"part-{i}",
                             fp,
-                            fp.getbuffer().nbytes,
+                            fsize,
                         )
-                        for i, fp in enumerate(file_like)
-                    ]
+                        logger.debug(f"uploading {fp.name} to TD. File size: {fsize}B")
             else:
                 fp = file_like[0]
                 bulk_import.upload_file("part", fmt, fp)
@@ -535,6 +543,8 @@ def _bulk_import(self, table, file_like, if_exists, fmt="csv", max_workers=5):
             bulk_import.delete()
             raise RuntimeError(f"failed to upload file: {e}")
 
+        logger.info(f"uploaded data in {time.time() - s_time:.2f} sec")
+
         logger.info("performing a bulk import job")
         job = bulk_import.perform(wait=True)
 
@@ -581,7 +591,9 @@ def _write_msgpack_stream(self, items, stream):
                 mp = packer.pack(normalized_msgpack(item))
                 gz.write(mp)
 
-        stream.seek(0)
+        logger.debug(
+            f"created a msgpack file: {stream.name}. File size: {stream.tell()}B"
+        )
 
         return stream
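For context on the `writer.py` change, here is a minimal, self-contained sketch of the spooling pattern it adopts: each chunk of records is packed as gzipped msgpack into a `NamedTemporaryFile(delete=False)`, cleanup is registered on the `ExitStack` (close first, then unlink, since callbacks run LIFO), and the cursor is left at EOF so the uploader can recover the part size with `tell()`. `spool_records` is an illustrative name, not pytd API, and this sketch inlines a simplified `_write_msgpack_stream` without the `normalized_msgpack` fallback.

```python
import gzip
import os
import tempfile
from contextlib import ExitStack
from itertools import zip_longest

import msgpack


def spool_records(records, chunk_record_size=10_000):
    fps = []
    with ExitStack() as stack:
        # zip_longest splits records into fixed-size chunks, padding the
        # last chunk with None.
        for group in zip_longest(*(iter(records),) * chunk_record_size):
            fp = tempfile.NamedTemporaryFile(suffix=".msgpack.gz", delete=False)
            # Callbacks run LIFO on stack unwind: close the file first,
            # then unlink it from disk.
            stack.callback(os.unlink, fp.name)
            stack.callback(fp.close)
            with gzip.GzipFile(mode="wb", fileobj=fp) as gz:
                packer = msgpack.Packer()
                for item in group:
                    if item is not None:  # skip zip_longest padding
                        gz.write(packer.pack(item))
            fps.append(fp)
        # Each cursor now sits at EOF, so tell() reports the file size.
        for i, fp in enumerate(fps):
            print(f"part-{i}: {fp.name} ({fp.tell()}B)")


spool_records([{"a": n, "time": 1234} for n in range(25)], chunk_record_size=10)
```

This is also why the tests above patch `pytd.writer.os.unlink`: the temp files are deleted when the stack unwinds, so a test that only mocks `_write_msgpack_stream` must suppress the unlink of files that were never created.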
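And a sketch of the matching upload loop from `_bulk_import`, assuming `fps` are spooled files with cursors at EOF and `bulk_import` is a td-client bulk import session whose `upload_part(name, stream, size)` the diff already calls. The explicit `result()` drain is an addition here; the diff itself relies on the executor's context exit to wait and does not surface per-part failures.

```python
from concurrent.futures import ThreadPoolExecutor


def upload_parts(bulk_import, fps, max_workers=5):
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = []
        for i, fp in enumerate(fps):
            fsize = fp.tell()  # size is known because the cursor is at EOF
            fp.seek(0)         # rewind so upload_part reads the whole payload
            futures.append(
                executor.submit(bulk_import.upload_part, f"part-{i}", fp, fsize)
            )
        for future in futures:
            future.result()  # re-raise the first upload error, if any
```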