diff --git a/mocket/mocket.py b/mocket/mocket.py index fcc2a8af..b9393aac 100644 --- a/mocket/mocket.py +++ b/mocket/mocket.py @@ -134,15 +134,75 @@ def wrap_socket(sock=sock, *args, **kwargs): @staticmethod def wrap_bio(incoming, outcoming, *args, **kwargs): - ssl_obj = MocketSocket() - ssl_obj._host = kwargs["server_hostname"] - return ssl_obj + return FakeSSLObject(kwargs["server_hostname"], incoming, outcoming) def __getattr__(self, name): if self.sock is not None: return getattr(self.sock, name) +class FakeSSLObject: + cipher = lambda s: ("ADH", "AES256", "SHA") + compression = lambda s: ssl.OP_NO_COMPRESSION + + _did_handshake = False + _sent_non_empty_bytes = False + + def __init__(self, server_hostname, incoming, outgoing): + self._host = server_hostname + self._port = None + self._incoming = incoming + self._outgoing = outgoing + + def do_handshake(self): + self._did_handshake = True + + def getpeercert(self, *args, **kwargs): + if not (self._host and self._port): + self._address = self._host, self._port = Mocket._address + + now = datetime.now() + shift = now + timedelta(days=30 * 12) + return { + "notAfter": shift.strftime("%b %d %H:%M:%S GMT"), + "subjectAltName": ( + ("DNS", "*.%s" % self._host), + ("DNS", self._host), + ("DNS", "*"), + ), + "subject": ( + (("organizationName", "*.%s" % self._host),), + (("organizationalUnitName", "Domain Control Validated"),), + (("commonName", "*.%s" % self._host),), + ), + } + + def write(self, data): + return self._outgoing.write(data) + + def read(self, max_size): + rv = self._incoming.read(max_size) + if rv: + self._sent_non_empty_bytes = True + if self._did_handshake and not self._sent_non_empty_bytes: + raise ssl.SSLWantReadError("The operation did not complete (read)") + return rv + + def pending(self): + return bool(self._incoming.pending) + + def unwrap(self): + pass + + def __getattr__(self, name): + """Do nothing catchall function, for methods like shutdown()""" + + def do_nothing(*args, **kwargs): + pass + + return do_nothing + + def create_connection(address, timeout=None, source_address=None): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP) if timeout: @@ -171,13 +231,9 @@ class MocketSocket: _host = None _port = None _address = None - cipher = lambda s: ("ADH", "AES256", "SHA") - compression = lambda s: ssl.OP_NO_COMPRESSION _mode = None _bufsize = None _secure_socket = False - _did_handshake = False - _sent_non_empty_bytes = False read_fd = None write_fd = None @@ -228,9 +284,6 @@ def settimeout(self, timeout): def getsockopt(level, optname, buflen=None): return socket.SOCK_STREAM - def do_handshake(self): - self._did_handshake = True - def getpeername(self): return self._address @@ -240,26 +293,6 @@ def setblocking(self, block): def getsockname(self): return socket.gethostbyname(self._address[0]), self._address[1] - def getpeercert(self, *args, **kwargs): - if not (self._host and self._port): - self._address = self._host, self._port = Mocket._address - - now = datetime.now() - shift = now + timedelta(days=30 * 12) - return { - "notAfter": shift.strftime("%b %d %H:%M:%S GMT"), - "subjectAltName": ( - ("DNS", "*.%s" % self._host), - ("DNS", self._host), - ("DNS", "*"), - ), - "subject": ( - (("organizationName", "*.%s" % self._host),), - (("organizationalUnitName", "Domain Control Validated"),), - (("commonName", "*.%s" % self._host),), - ), - } - def unwrap(self): return self @@ -304,12 +337,7 @@ def sendall(self, data, entry=None, *args, **kwargs): self.fd.seek(0) def read(self, buffersize): - rv = self.fd.read(buffersize) - if rv: - self._sent_non_empty_bytes = True - if self._did_handshake and not self._sent_non_empty_bytes: - raise ssl.SSLWantReadError("The operation did not complete (read)") - return rv + return self.fd.read(buffersize) def recv_into(self, buffer, buffersize=None, flags=None): if hasattr(buffer, "write"):