From ede62ef249acdcdf4323c26d4a0a0671e1a469da Mon Sep 17 00:00:00 2001 From: Molly Draven Date: Tue, 13 Aug 2024 07:38:29 -0400 Subject: [PATCH] feat: Move TLS and server files to src directory --- .vscode/settings.json | 7 +- Makefile | 17 +++++ server.py => main.py | 6 +- pyswitch/ConnectionHandler.py | 36 ++++++---- pyswitch/{PySwitch.py => server.py} | 2 +- pyswitch/{ => src}/setup_logging.py | 0 pyswitch/src/ssl.py | 100 ++++++++++++++++++++++++++++ pyswitch/{ => src}/ssl_constants.py | 0 pyswitch/{ => src}/tls.py | 4 +- pyswitch/{ => src}/tls_constants.py | 0 pyswitch/{ => src}/utils.py | 0 pyswitch/ssl.py | 39 ----------- tests/test_peek_data.py | 34 ++++++++++ 13 files changed, 186 insertions(+), 59 deletions(-) create mode 100644 Makefile rename server.py => main.py (84%) rename pyswitch/{PySwitch.py => server.py} (93%) rename pyswitch/{ => src}/setup_logging.py (100%) create mode 100644 pyswitch/src/ssl.py rename pyswitch/{ => src}/ssl_constants.py (100%) rename pyswitch/{ => src}/tls.py (85%) rename pyswitch/{ => src}/tls_constants.py (100%) rename pyswitch/{ => src}/utils.py (100%) delete mode 100644 pyswitch/ssl.py create mode 100644 tests/test_peek_data.py diff --git a/.vscode/settings.json b/.vscode/settings.json index 29536fd..0ecfcdf 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,4 +1,7 @@ { "python.analysis.autoImportCompletions": true, - "python.analysis.typeCheckingMode": "basic" -} + "python.analysis.typeCheckingMode": "basic", + "python.terminal.activateEnvironment": true, + "python.terminal.activateEnvInCurrentTerminal": true, + "python.envFile": "${workspaceFolder}/.venv" +} \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..97d4dbb --- /dev/null +++ b/Makefile @@ -0,0 +1,17 @@ +lint: + # stop the build if there are Python syntax errors or undefined names + @pdm run flake8 pyswitch --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + @pdm run flake8 pyswitch --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + + +test: + @pdm run pytest --verbose --cov=pyswitch --cov-report=term-missing && pdm run python -m coverage xml + +install: + @pdm install + +start: + @python main.py + +.PHONY: lint test install start \ No newline at end of file diff --git a/server.py b/main.py similarity index 84% rename from server.py rename to main.py index e0d5aa5..eb08909 100644 --- a/server.py +++ b/main.py @@ -1,12 +1,12 @@ from os import getenv -from pyswitch.PySwitch import PySwitch -from dotenv import load_dotenv, find_dotenv +from pyswitch.server import PySwitch +from dotenv import load_dotenv import sentry_sdk def main(): - load_dotenv(find_dotenv()) + load_dotenv(verbose=True) sentry_dsn = getenv("SENTRY_DSN") diff --git a/pyswitch/ConnectionHandler.py b/pyswitch/ConnectionHandler.py index e4fc5cf..11d2997 100644 --- a/pyswitch/ConnectionHandler.py +++ b/pyswitch/ConnectionHandler.py @@ -2,10 +2,10 @@ from ssl import create_default_context import ssl from loguru import logger -from pyswitch.ssl import SSLv2Record -from pyswitch.tls import TLSProtocolVersion -from pyswitch.tls_constants import TLSContentType -from pyswitch.utils import bin2hex +from pyswitch.src.ssl import SSLv2Record +from pyswitch.src.tls import TLSProtocolVersion +from pyswitch.src.tls_constants import TLSContentType +from pyswitch.src.utils import bin2hex import socket from socketserver import StreamRequestHandler @@ -16,17 +16,12 @@ class ConnectionHandler(StreamRequestHandler): def handle(self): logger.debug( - "Connection from: ", self.client_address, "id: ", self.request.fileno() - ) - print( "Connection from: {}, id: {}".format( self.client_address, self.request.fileno() ) ) - first_bytes = self.peek_data(len=32814) - - print("first_bytes_len: ", len(first_bytes)) + first_bytes = peek_data(self.request, len=32814) logger.debug("First bytes: {}".format(bin2hex(first_bytes))) @@ -40,10 +35,12 @@ def handle(self): ssl_record = SSLv2Record(self.rfile.read(record_length)) logger.debug("SSL Record: {}".format(ssl_record)) + self.request.close() return except Exception as e: print("Error: ", e) + self.request.close() return protocol_version = TLSProtocolVersion(first_bytes[1:3]) @@ -54,5 +51,20 @@ def handle(self): with ssl_context.wrap_socket(self.request, server_side=True) as ssl_socket: print("SSL version: ", ssl_socket.version()) - def peek_data(self, len: int): - return self.request.recv(len, socket.MSG_PEEK) + +def peek_data(sock: socket.socket, len: int): + """ + Receive data from the socket without removing it from the receive buffer. + + Args: + sock (socket.socket): The socket object to receive data from. + len (int): The maximum number of bytes to receive. + + Returns: + bytes: The received data as bytes. + + Raises: + OSError: If an error occurs while receiving data. + + """ + return sock.recv(len, socket.MSG_PEEK) diff --git a/pyswitch/PySwitch.py b/pyswitch/server.py similarity index 93% rename from pyswitch/PySwitch.py rename to pyswitch/server.py index d3baeba..2c52e64 100644 --- a/pyswitch/PySwitch.py +++ b/pyswitch/server.py @@ -4,7 +4,7 @@ from socketserver import ThreadingTCPServer -from pyswitch.setup_logging import setup_logging +from pyswitch.src.setup_logging import setup_logging DEFAULT_LOGGING_LEVEL = logging.DEBUG diff --git a/pyswitch/setup_logging.py b/pyswitch/src/setup_logging.py similarity index 100% rename from pyswitch/setup_logging.py rename to pyswitch/src/setup_logging.py diff --git a/pyswitch/src/ssl.py b/pyswitch/src/ssl.py new file mode 100644 index 0000000..2a1ec8b --- /dev/null +++ b/pyswitch/src/ssl.py @@ -0,0 +1,100 @@ +from enum import Enum +from pyswitch.src.utils import bin2hex + + +class SSLv2State(Enum): + HANDSHAKE = "handshake" + DONE = "done" + + +class SSLv2HandshakeClientHello: + def __init__(self, data: bytes): + self.version = SSLProtocolVersion(data[:2]) + self.cipher_specs = data[2:22] + self.connection_id = data[22:24] + self.challenge = data[24:32] + + def __str__(self): + return "Version: {}, Cipher Specs: {}, Connection ID: {}, Challenge: {}".format( + self.version, + bin2hex(self.cipher_specs), + bin2hex(self.connection_id), + bin2hex(self.challenge), + ) + + +class SSLv2HandshakeType(Enum): + CLIENT_HELLO = 1 + CLIENT_MASTER_KEY = 2 + CLIENT_FINISHED = 3 + SERVER_HELLO = 4 + SERVER_VERIFY = 5 + SERVER_FINISHED = 6 + REQUEST_CERTIFICATE = 7 + CLIENT_CERTIFICATE = 8 + CLIENT_KEY_EXCHANGE = 9 + + +class SSLv2Handshake: + def __init__(self, data: bytes): + self.handshake_type = SSLv2HandshakeType(data[0]) + self.data = data + + def __str__(self): + return f"Handshake type: {self.handshake_type.name}, Data: {bin2hex(self.data)}" + + +class SSLv2StateMachine: + def __init__(self): + self.state = SSLv2State.HANDSHAKE + self.handshake = None + + def process(self, data: bytes): + if self.state == SSLv2State.HANDSHAKE: + self.handshake = SSLv2Handshake(data) + if self.handshake.handshake_type == SSLv2HandshakeType.CLIENT_HELLO: + client_hello = SSLv2HandshakeClientHello(self.handshake.data[1:]) + print(client_hello) + else: + raise ValueError("Invalid state") + + def __str__(self): + return f"State: {self.state}, Handshake: {self.handshake}" + + +class SSLProtocolVersion: + def __init__(self, data: bytes): + self.major = data[1] + self.minor = data[0] + + def __str__(self): + return f"Major: {self.major}, Minor: {self.minor}" + + +class SSLv2RecordHeader: + def __init__(self, data: bytes): + self.length = ((data[0] & 0x7F) << 8) | data[1] + self.is_escape = data[2] & 0x80 + if self.is_escape: + raise ValueError("SSLv2 Escape Record, not supported") + self.record_type = data[2] + + def __str__(self): + return "Length: {}, Is Escape: {}, Record Type: {}".format( + self.length, self.is_escape, self.record_type + ) + + +class SSLv2Record: + def __init__(self, data: bytes): + self.header = SSLv2RecordHeader(data[:3]) + self.record_type = data[2] + self.data = data[3:] + + def __str__(self): + return "Length: {}, Is Escape: {}, Record Type: {}, Data: {}".format( + self.header.length, + self.header.is_escape, + self.header.record_type, + bin2hex(self.data), + ) diff --git a/pyswitch/ssl_constants.py b/pyswitch/src/ssl_constants.py similarity index 100% rename from pyswitch/ssl_constants.py rename to pyswitch/src/ssl_constants.py diff --git a/pyswitch/tls.py b/pyswitch/src/tls.py similarity index 85% rename from pyswitch/tls.py rename to pyswitch/src/tls.py index 6c7993b..1e55183 100644 --- a/pyswitch/tls.py +++ b/pyswitch/src/tls.py @@ -1,5 +1,5 @@ -from pyswitch.tls_constants import TLSHandshakeType -from pyswitch.utils import bin2hex +from pyswitch.src.tls_constants import TLSHandshakeType +from pyswitch.src.utils import bin2hex class TLSProtocolVersion: diff --git a/pyswitch/tls_constants.py b/pyswitch/src/tls_constants.py similarity index 100% rename from pyswitch/tls_constants.py rename to pyswitch/src/tls_constants.py diff --git a/pyswitch/utils.py b/pyswitch/src/utils.py similarity index 100% rename from pyswitch/utils.py rename to pyswitch/src/utils.py diff --git a/pyswitch/ssl.py b/pyswitch/ssl.py deleted file mode 100644 index 7c18717..0000000 --- a/pyswitch/ssl.py +++ /dev/null @@ -1,39 +0,0 @@ -from pyswitch.utils import bin2hex - - -class SSLProtocolVersion: - def __init__(self, data: bytes): - self.major = data[1] - self.minor = data[0] - - def __str__(self): - return f"Major: {self.major}, Minor: {self.minor}" - - -class SSLv2RecordHeader: - def __init__(self, data: bytes): - self.length = ((data[0] & 0x7F) << 8) | data[1] - self.is_escape = data[2] & 0x80 - if self.is_escape: - raise ValueError("SSLv2 Escape Record, not supported") - self.record_type = data[2] - - def __str__(self): - return "Length: {}, Is Escape: {}, Record Type: {}".format( - self.length, self.is_escape, self.record_type - ) - - -class SSLv2Record: - def __init__(self, data: bytes): - self.header = SSLv2RecordHeader(data[:3]) - self.record_type = data[2] - self.data = data[3:] - - def __str__(self): - return "Length: {}, Is Escape: {}, Record Type: {}, Data: {}".format( - self.header.length, - self.header.is_escape, - self.header.record_type, - bin2hex(self.data), - ) diff --git a/tests/test_peek_data.py b/tests/test_peek_data.py new file mode 100644 index 0000000..406bbb0 --- /dev/null +++ b/tests/test_peek_data.py @@ -0,0 +1,34 @@ +import socket +import unittest +from unittest.mock import MagicMock + +from pyswitch.ConnectionHandler import peek_data + + +class PeekDataTestCase(unittest.TestCase): + def test_peek_data_returns_bytes(self): + # Create a mock socket object + sock = MagicMock(spec=socket.socket) + + # Set the return value of the recv method to a byte string + sock.recv.return_value = b"Hello, World!" + + # Call the peek_data function + result = peek_data(sock, 10) + + # Assert that the result is a byte string + self.assertIsInstance(result, bytes) + + def test_peek_data_calls_recv_with_correct_arguments(self): + # Create a mock socket object + sock = MagicMock(spec=socket.socket) + + # Call the peek_data function + peek_data(sock, 10) + + # Assert that the recv method was called with the correct arguments + sock.recv.assert_called_once_with(10, socket.MSG_PEEK) + + +if __name__ == "__main__": + unittest.main()