diff --git a/main.py b/main.py index eb08909..3b2abd7 100644 --- a/main.py +++ b/main.py @@ -20,7 +20,6 @@ def main(): # We recommend adjusting this value in production. profiles_sample_rate=1.0, enable_tracing=True, - debug=True, ) switch = PySwitch() diff --git a/pyswitch/ConnectionHandler.py b/pyswitch/ConnectionHandler.py index 11d2997..b5e649e 100644 --- a/pyswitch/ConnectionHandler.py +++ b/pyswitch/ConnectionHandler.py @@ -2,10 +2,10 @@ from ssl import create_default_context import ssl from loguru import logger -from pyswitch.src.ssl import SSLv2Record +from pyswitch.src.ssl.ssl_v2 import SSLv2Record from pyswitch.src.tls import TLSProtocolVersion from pyswitch.src.tls_constants import TLSContentType -from pyswitch.src.utils import bin2hex +from pyswitch.src.utils import assert_enough_data, bin2hex, is_msb_set import socket from socketserver import StreamRequestHandler @@ -31,7 +31,21 @@ def handle(self): except ValueError: # Unable to parse as TLS, try SSL try: - record_length = ((first_bytes[0] & 0x7F) << 8) | first_bytes[1] + if is_msb_set(first_bytes[0]): + logger.debug("Length MSB is set, record is 2 bytes") + record_length = ((first_bytes[0] & 0x7F) << 8) | first_bytes[1] + else: + logger.debug("Length MSB is not set, record is 3 bytes") + record_length = ((first_bytes[0] & 0x3F) << 8) | first_bytes[1] + + logger.debug("Record length: {}".format(record_length)) + try: + assert_enough_data( + len(self.rfile.peek(record_length)), record_length + ) + except ValueError as e: + logger.error("Error: {}".format(e)) + ssl_record = SSLv2Record(self.rfile.read(record_length)) logger.debug("SSL Record: {}".format(ssl_record)) diff --git a/pyswitch/config.py b/pyswitch/config.py new file mode 100644 index 0000000..2eec19f --- /dev/null +++ b/pyswitch/config.py @@ -0,0 +1,6 @@ +import logging + + +DEFAULT_LOGGING_LEVEL = logging.DEBUG +LISTEN_HOST = "0.0.0.0" +LISTEN_PORT = 443 diff --git a/pyswitch/server.py b/pyswitch/server.py index 2c52e64..4e5cc8a 100644 --- a/pyswitch/server.py +++ b/pyswitch/server.py @@ -1,13 +1,11 @@ -import logging from pyswitch.ConnectionHandler import ConnectionHandler from socketserver import ThreadingTCPServer +from pyswitch.config import DEFAULT_LOGGING_LEVEL, LISTEN_HOST, LISTEN_PORT from pyswitch.src.setup_logging import setup_logging -DEFAULT_LOGGING_LEVEL = logging.DEBUG - class PySwitch: server: ThreadingTCPServer @@ -16,18 +14,14 @@ def __init__(self): setup_logging(DEFAULT_LOGGING_LEVEL) - self.server = ThreadingTCPServer(("0.0.0.0", 443), ConnectionHandler) - - def run(self): try: - self.server.serve_forever() + self.server = ThreadingTCPServer( + (LISTEN_HOST, LISTEN_PORT), ConnectionHandler + ) except OSError as e: - print("Error: ", e) - print("Exiting") - exit(1) - - except KeyboardInterrupt: - print("Caught KeyboardInterrupt") - print("Exiting") - self.server.server_close() - exit(0) + if e.errno == 98: + print("Port is already in use") + exit(1) + + def run(self): + self.server.serve_forever() diff --git a/pyswitch/src/ssl.py b/pyswitch/src/ssl/ssl_v2.py similarity index 55% rename from pyswitch/src/ssl.py rename to pyswitch/src/ssl/ssl_v2.py index 2a1ec8b..5880fb7 100644 --- a/pyswitch/src/ssl.py +++ b/pyswitch/src/ssl/ssl_v2.py @@ -1,40 +1,25 @@ -from enum import Enum +from loguru import logger +from pyswitch.src.ssl.ssl_v2_constants import SSLv2HandshakeType from pyswitch.src.utils import bin2hex - - -class SSLv2State(Enum): - HANDSHAKE = "handshake" - DONE = "done" +from pyswitch.src.utils import assert_enough_data class SSLv2HandshakeClientHello: def __init__(self, data: bytes): - self.version = SSLProtocolVersion(data[:2]) + self.client_version = SSLProtocolVersion(data[:2]) self.cipher_specs = data[2:22] - self.connection_id = data[22:24] + self.session_id = data[22:24] self.challenge = data[24:32] def __str__(self): return "Version: {}, Cipher Specs: {}, Connection ID: {}, Challenge: {}".format( - self.version, + self.client_version, bin2hex(self.cipher_specs), - bin2hex(self.connection_id), + bin2hex(self.session_id), bin2hex(self.challenge), ) -class SSLv2HandshakeType(Enum): - CLIENT_HELLO = 1 - CLIENT_MASTER_KEY = 2 - CLIENT_FINISHED = 3 - SERVER_HELLO = 4 - SERVER_VERIFY = 5 - SERVER_FINISHED = 6 - REQUEST_CERTIFICATE = 7 - CLIENT_CERTIFICATE = 8 - CLIENT_KEY_EXCHANGE = 9 - - class SSLv2Handshake: def __init__(self, data: bytes): self.handshake_type = SSLv2HandshakeType(data[0]) @@ -44,24 +29,6 @@ def __str__(self): return f"Handshake type: {self.handshake_type.name}, Data: {bin2hex(self.data)}" -class SSLv2StateMachine: - def __init__(self): - self.state = SSLv2State.HANDSHAKE - self.handshake = None - - def process(self, data: bytes): - if self.state == SSLv2State.HANDSHAKE: - self.handshake = SSLv2Handshake(data) - if self.handshake.handshake_type == SSLv2HandshakeType.CLIENT_HELLO: - client_hello = SSLv2HandshakeClientHello(self.handshake.data[1:]) - print(client_hello) - else: - raise ValueError("Invalid state") - - def __str__(self): - return f"State: {self.state}, Handshake: {self.handshake}" - - class SSLProtocolVersion: def __init__(self, data: bytes): self.major = data[1] @@ -77,7 +44,13 @@ def __init__(self, data: bytes): self.is_escape = data[2] & 0x80 if self.is_escape: raise ValueError("SSLv2 Escape Record, not supported") - self.record_type = data[2] + try: + logger.debug("Data length: {}".format(len(data))) + assert_enough_data(len(data), self.length) + except ValueError as e: + logger.error("Error: {}".format(e)) + logger.error("Data: {}".format(bin2hex(data))) + self.record_type = bin2hex(data[2:3]) def __str__(self): return "Length: {}, Is Escape: {}, Record Type: {}".format( @@ -87,9 +60,9 @@ def __str__(self): class SSLv2Record: def __init__(self, data: bytes): - self.header = SSLv2RecordHeader(data[:3]) - self.record_type = data[2] - self.data = data[3:] + self.header = SSLv2RecordHeader(data) + self.record_type = self.header.record_type + self.data = data def __str__(self): return "Length: {}, Is Escape: {}, Record Type: {}, Data: {}".format( diff --git a/pyswitch/src/ssl/ssl_v2_constants.py b/pyswitch/src/ssl/ssl_v2_constants.py new file mode 100644 index 0000000..8561f3a --- /dev/null +++ b/pyswitch/src/ssl/ssl_v2_constants.py @@ -0,0 +1,61 @@ +from enum import Enum + + +class SSLContentType(Enum): + HANDSHAKE = "handshake" + DATA = "data" + + +class SSLv2HandshakeType(Enum): + CLIENT_HELLO = 1 + CLIENT_MASTER_KEY = 2 + CLIENT_FINISHED = 3 + SERVER_HELLO = 4 + SERVER_VERIFY = 5 + SERVER_FINISHED = 6 + REQUEST_CERTIFICATE = 7 + CLIENT_CERTIFICATE = 8 + CLIENT_KEY_EXCHANGE = 9 + + +# SSLv2 Handshake Valid State Transition Map (from_state -> to_state) +SSLv2HandshakeTransitionMap = { + SSLv2HandshakeType.CLIENT_HELLO: [SSLv2HandshakeType.SERVER_HELLO], + SSLv2HandshakeType.SERVER_HELLO: [ + SSLv2HandshakeType.CLIENT_MASTER_KEY, + SSLv2HandshakeType.CLIENT_FINISHED, + ], + SSLv2HandshakeType.CLIENT_MASTER_KEY: [SSLv2HandshakeType.CLIENT_FINISHED], + SSLv2HandshakeType.CLIENT_FINISHED: [SSLv2HandshakeType.SERVER_VERIFY], + SSLv2HandshakeType.SERVER_VERIFY: [ + SSLv2HandshakeType.SERVER_FINISHED, + SSLv2HandshakeType.REQUEST_CERTIFICATE, + ], + SSLv2HandshakeType.SERVER_FINISHED: [], + SSLv2HandshakeType.REQUEST_CERTIFICATE: [SSLv2HandshakeType.CLIENT_CERTIFICATE], + SSLv2HandshakeType.CLIENT_CERTIFICATE: [SSLv2HandshakeType.SERVER_FINISHED], +} + + +class SLv2HandshakeState(Enum): + CLIENT_HELLO = "client_hello" + SERVER_HELLO = "server_hello" + CLIENT_MASTER_KEY = "client_master_key" + CLIENT_FINISHED = "client_finished" + SERVER_VERIFY = "server_verify" + REQUEST_CERTIFICATE = "request_certificate" + CLIENT_CERTIFICATE = "client_certificate" + SERVER_FINISHED = "server_finished" + + +class SSLv2State(Enum): + """ + Represents the state of the SSLv2 protocol. + + Attributes: + HANDSHAKE (str): Represents the state during the handshake phase. + DONE (str): Represents the state when the handshake is completed. + """ + + HANDSHAKE = "handshake" + DONE = "done" diff --git a/pyswitch/src/ssl/ssl_v2_state_machine.py b/pyswitch/src/ssl/ssl_v2_state_machine.py new file mode 100644 index 0000000..d8cdec1 --- /dev/null +++ b/pyswitch/src/ssl/ssl_v2_state_machine.py @@ -0,0 +1,55 @@ +from pyswitch.src.ssl.ssl_v2 import ( + SSLv2Handshake, + SSLv2HandshakeClientHello, +) +from pyswitch.src.ssl.ssl_v2_constants import ( + SSLv2HandshakeTransitionMap, + SSLv2HandshakeType, + SSLv2State, +) + + +def is_valid_sslv2_handshake_transition( + from_state: SSLv2HandshakeType, + to_state: SSLv2HandshakeType, +) -> bool: + """ + Checks if a transition from the current state to the given state is valid. + + Args: + from_state (SSLv2HandshakeType): The state from which the transition is made. + to_state (SSLv2HandshakeType): The state to which the transition is made. + + Returns: + bool: True if the transition is valid, False otherwise. + """ + return to_state in SSLv2HandshakeTransitionMap[from_state] + + +class SSLv2StateMachine: + """ + Represents the SSLv2 state machine. + + Attributes: + state (SSLv2State): The current state of the state machine. + handshake (SSLv2Handshake): The current handshake being processed. + + Methods: + process(data: bytes): Processes the given data based on the current state. + """ + + def __init__(self): + self.state = SSLv2State.HANDSHAKE + self.handshake = None + + def process(self, data: bytes): + if self.state == SSLv2State.HANDSHAKE: + self.handshake = SSLv2Handshake(data) + if self.handshake.handshake_type == SSLv2HandshakeType.CLIENT_HELLO: + client_hello = SSLv2HandshakeClientHello(self.handshake.data[1:]) + print(client_hello) + else: + raise ValueError("Invalid state") + + def __str__(self): + return f"State: {self.state}, Handshake: {self.handshake}" diff --git a/pyswitch/src/ssl_constants.py b/pyswitch/src/ssl_constants.py deleted file mode 100644 index e8c7116..0000000 --- a/pyswitch/src/ssl_constants.py +++ /dev/null @@ -1,6 +0,0 @@ -from enum import Enum - - -class SSLContentType(Enum): - HANDSHAKE = "handshake" - DATA = "data" diff --git a/pyswitch/src/utils.py b/pyswitch/src/utils.py index 82e131a..a07d7f4 100644 --- a/pyswitch/src/utils.py +++ b/pyswitch/src/utils.py @@ -1,2 +1,22 @@ def bin2hex(bin_str): return "".join(format(x, "02x") for x in bin_str) + + +def assert_enough_data(data_length: int, expected_length: int): + if data_length < expected_length: + raise ValueError( + f"Expected at least {expected_length} bytes, got {data_length} bytes" + ) + + +def is_msb_set(byte): + """ + Checks if the most significant bit (MSB) of the given byte is set. + + Parameters: + - byte: An integer representing the byte to check. + + Returns: + - A boolean value indicating whether the MSB is set (True) or not (False). + """ + return byte & 0x80 == 0x80 diff --git a/tests/test_is_valid_sslv2_handshake_transition.py b/tests/test_is_valid_sslv2_handshake_transition.py new file mode 100644 index 0000000..71ea695 --- /dev/null +++ b/tests/test_is_valid_sslv2_handshake_transition.py @@ -0,0 +1,26 @@ +import unittest + +from pyswitch.src.ssl.ssl_v2_constants import SSLv2HandshakeType +from pyswitch.src.ssl.ssl_v2_state_machine import ( + is_valid_sslv2_handshake_transition, +) + + +class SSLv2StateMachineTestCase(unittest.TestCase): + def test_valid_transition(self): + # Test a valid transition from SSLvHandshakeType.CLIENT_HELLO to SSLv2HandshakeType.SERVER_HELLO + from_state = SSLv2HandshakeType.CLIENT_HELLO + to_state = SSLv2HandshakeType.SERVER_HELLO + result = is_valid_sslv2_handshake_transition(from_state, to_state) + self.assertTrue(result) + + def test_invalid_transition(self): + # Test an invalid transition from SSLv2HandshakeType.CLIENT_HELLO to SSLv2HandshakeType.CLIENT_MASTER_KEY + from_state = SSLv2HandshakeType.CLIENT_HELLO + to_state = SSLv2HandshakeType.CLIENT_MASTER_KEY + result = is_valid_sslv2_handshake_transition(from_state, to_state) + self.assertFalse(result) + + +if __name__ == "__main__": + unittest.main()