diff --git a/docs/changelog-fragments/638.bugfix.rst b/docs/changelog-fragments/638.bugfix.rst new file mode 100644 index 000000000..4c4849c70 --- /dev/null +++ b/docs/changelog-fragments/638.bugfix.rst @@ -0,0 +1 @@ +Fixed SFTP reads of large files and overwriting of existing files -- by :user:`Jakuje`. diff --git a/src/pylibsshext/sftp.pyx b/src/pylibsshext/sftp.pyx index 6331c529d..0b42ff8f4 100644 --- a/src/pylibsshext/sftp.pyx +++ b/src/pylibsshext/sftp.pyx @@ -89,18 +89,19 @@ cdef class SFTP: rf = sftp.sftp_open(self._libssh_sftp_session, remote_file_b, O_RDONLY, sftp.S_IRWXU) if rf is NULL: - raise LibsshSFTPException("Opening remote file [%s] for read failed with error [%s]" % (remote_file, self._get_sftp_error_str())) - - while True: - file_data = sftp.sftp_read(rf, read_buffer, sizeof(char) * 1024) - if file_data == 0: - break - elif file_data < 0: - sftp.sftp_close(rf) - raise LibsshSFTPException("Reading data from remote file [%s] failed with error [%s]" - % (remote_file, self._get_sftp_error_str())) - - with open(local_file, 'wb+') as f: + raise LibsshSFTPException("Opening remote file [%s] for read failed with error [%s]" + % (remote_file, self._get_sftp_error_str())) + + with open(local_file, 'wb') as f: + while True: + file_data = sftp.sftp_read(rf, read_buffer, sizeof(char) * 1024) + if file_data == 0: + break + elif file_data < 0: + sftp.sftp_close(rf) + raise LibsshSFTPException("Reading data from remote file [%s] failed with error [%s]" + % (remote_file, self._get_sftp_error_str())) + bytes_written = f.write(read_buffer[:file_data]) if bytes_written and file_data != bytes_written: sftp.sftp_close(rf) diff --git a/tests/unit/sftp_test.py b/tests/unit/sftp_test.py index fc70c4512..1822522dc 100644 --- a/tests/unit/sftp_test.py +++ b/tests/unit/sftp_test.py @@ -2,6 +2,8 @@ """Tests suite for sftp.""" +import random +import string import uuid import pytest @@ -48,6 +50,22 @@ def dst_path(file_paths_pair): return path +@pytest.fixture +def other_payload(): 
+ """Generate a binary test payload.""" + uuid_name = uuid.uuid4() + return 'Original content: {name!s}'.format(name=uuid_name).encode() + + +@pytest.fixture +def dst_exists_path(file_paths_pair, other_payload): + """Return a data destination path.""" + path = file_paths_pair[1] + path.write_bytes(other_payload) + assert path.exists() + return path + + def test_make_sftp(sftp_session): """Smoke-test SFTP instance creation.""" assert sftp_session @@ -63,3 +81,47 @@ def test_get(dst_path, src_path, sftp_session, transmit_payload): """Check that SFTP file download works.""" sftp_session.get(str(src_path), str(dst_path)) assert dst_path.read_bytes() == transmit_payload + + +def test_get_existing(dst_exists_path, src_path, sftp_session, transmit_payload): + """Check that SFTP file download works when target file exists.""" + sftp_session.get(str(src_path), str(dst_exists_path)) + assert dst_exists_path.read_bytes() == transmit_payload + + +def test_put_existing(dst_exists_path, src_path, sftp_session, transmit_payload): + """Check that SFTP file upload works when target file exists.""" + sftp_session.put(str(src_path), str(dst_exists_path)) + assert dst_exists_path.read_bytes() == transmit_payload + + +@pytest.fixture +def large_payload(): + """Generate a large 1025 byte (1024 + 1B) test payload.""" + payload_len = 1024 + 1 + random_bytes = [ord(random.choice(string.printable)) for _ in range(payload_len)] + return bytes(random_bytes) + + +@pytest.fixture +def src_path_large(tmp_path, large_payload): + """Return a remote path to a 1025 byte-sized file. + + The pylibssh chunk size is 1024 so the test needs a file that would + execute at least two loops. 
+ """ + path = tmp_path / 'large.txt' + path.write_bytes(large_payload) + return path + + +def test_put_large(dst_path, src_path_large, sftp_session, large_payload): + """Check that SFTP can upload large file.""" + sftp_session.put(str(src_path_large), str(dst_path)) + assert dst_path.read_bytes() == large_payload + + +def test_get_large(dst_path, src_path_large, sftp_session, large_payload): + """Check that SFTP can download large file.""" + sftp_session.get(str(src_path_large), str(dst_path)) + assert dst_path.read_bytes() == large_payload