diff --git a/uvloop/sslproto.pxd b/uvloop/sslproto.pxd index 3da10f00..87fc6a9e 100644 --- a/uvloop/sslproto.pxd +++ b/uvloop/sslproto.pxd @@ -53,6 +53,7 @@ cdef class SSLProtocol: object _sslobj object _sslobj_read object _sslobj_write + object _sslobj_pending object _incoming object _incoming_write object _outgoing diff --git a/uvloop/sslproto.pyx b/uvloop/sslproto.pyx index 42bb7644..27bdad8a 100644 --- a/uvloop/sslproto.pyx +++ b/uvloop/sslproto.pyx @@ -480,6 +480,7 @@ cdef class SSLProtocol: server_hostname=self._server_hostname) self._sslobj_read = self._sslobj.read self._sslobj_write = self._sslobj.write + self._sslobj_pending = self._sslobj.pending except Exception as ex: self._on_handshake_complete(ex) else: @@ -719,48 +720,91 @@ cdef class SSLProtocol: cdef _do_read__buffered(self): cdef: - Py_buffer pybuf - bint pybuf_inited = False - size_t wants, offset = 0 - int count = 1 - object buf + Py_ssize_t total_pending = (self._incoming.pending + + self._sslobj_pending()) + # Ask for a little extra in case when decrypted data is bigger than + # original + object app_buffer = self._app_protocol_get_buffer( + total_pending + 256) + Py_ssize_t app_buffer_size = len(app_buffer) + + if app_buffer_size == 0: + return - buf = self._app_protocol_get_buffer(self._get_read_buffer_size()) - wants = len(buf) + cdef: + Py_ssize_t last_bytes_read = -1 + Py_ssize_t total_bytes_read = 0 + Py_buffer pybuf + bint pybuf_initialized = False try: - count = self._sslobj_read(wants, buf) - - if count > 0: - offset = count - if offset < wants: - PyObject_GetBuffer(buf, &pybuf, PyBUF_WRITABLE) - pybuf_inited = True - while offset < wants: - buf = PyMemoryView_FromMemory( - (pybuf.buf) + offset, - wants - offset, + # SSLObject.read may not return all available data in one go. + # We have to keep calling read until it throw SSLWantReadError. + # However, throwing SSLWantReadError is very expensive even in + # the master trunk of cpython. + # See https://github.com/python/cpython/issues/123954 + + # One way to reduce reliance on SSLWantReadError is to check + # self._incoming.pending > 0 and SSLObject.pending() > 0. + # SSLObject.read may still throw SSLWantReadError even when + # self._incoming.pending > 0 SSLObject.pending() == 0, + # but this should happen relatively rarely, only when ssl frame + # is partially received. + + # This optimization works really well especially for peers + # exchanging small messages and wanting to have minimal latency. + + # self._incoming.pending means how much data hasn't + # been processed by ssl yet (read: "still encrypted"). The final + # unencrypted data size maybe different. + + # self._sslobj.pending() means how much data has been already + # decrypted and can be directly read with SSLObject.read. + + # Run test_create_server_ssl_over_ssl to reproduce different cases + # for this method. + while total_pending > 0: + if total_bytes_read > 0: + if not pybuf_initialized: + PyObject_GetBuffer(app_buffer, &pybuf, PyBUF_WRITABLE) + pybuf_initialized = True + + app_buffer = PyMemoryView_FromMemory( + (pybuf.buf) + total_bytes_read, + app_buffer_size - total_bytes_read, PyBUF_WRITE) - count = self._sslobj_read(wants - offset, buf) - if count > 0: - offset += count - else: - break - else: + + last_bytes_read = self._sslobj_read( + app_buffer_size, app_buffer) + total_bytes_read += last_bytes_read + + if last_bytes_read == 0: + break + + # User buffer may not fit all available data. + if total_bytes_read == app_buffer_size: self._loop._call_soon_handle( new_MethodHandle(self._loop, "SSLProtocol._do_read", - self._do_read, + self._do_read, None, # current context is good self)) + break + + total_pending = (self._incoming.pending + + self._sslobj_pending()) except ssl_SSLAgainErrors as exc: pass finally: - if pybuf_inited: + if pybuf_initialized: PyBuffer_Release(&pybuf) - if offset > 0: - self._app_protocol_buffer_updated(offset) - if not count: + + if total_bytes_read > 0: + self._app_protocol_buffer_updated(total_bytes_read) + + # SSLObject.read() may return 0 instead of throwing SSLWantReadError + # This indicates that we reached EOF + if last_bytes_read == 0: # close_notify self._call_eof_received() self._start_shutdown() @@ -772,7 +816,8 @@ cdef class SSLProtocol: bint zero = True, one = False try: - while True: + while (self._incoming.pending > 0 or + self._sslobj_pending() > 0): chunk = self._sslobj_read(SSL_READ_MAX_SIZE) if not chunk: break