Skip to content

Commit

Permalink
refactor: Move TLS and server files to src directory
Browse files Browse the repository at this point in the history
  • Loading branch information
drazisil committed Aug 15, 2024
1 parent cefd1b8 commit cbc9e38
Show file tree
Hide file tree
Showing 10 changed files with 212 additions and 70 deletions.
1 change: 0 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def main():
# We recommend adjusting this value in production.
profiles_sample_rate=1.0,
enable_tracing=True,
debug=True,
)

switch = PySwitch()
Expand Down
20 changes: 17 additions & 3 deletions pyswitch/ConnectionHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from ssl import create_default_context
import ssl
from loguru import logger
from pyswitch.src.ssl import SSLv2Record
from pyswitch.src.ssl.ssl_v2 import SSLv2Record
from pyswitch.src.tls import TLSProtocolVersion
from pyswitch.src.tls_constants import TLSContentType
from pyswitch.src.utils import bin2hex
from pyswitch.src.utils import assert_enough_data, bin2hex, is_msb_set
import socket
from socketserver import StreamRequestHandler

Expand All @@ -31,7 +31,21 @@ def handle(self):
except ValueError:
# Unable to parse as TLS, try SSL
try:
record_length = ((first_bytes[0] & 0x7F) << 8) | first_bytes[1]
if is_msb_set(first_bytes[0]):
logger.debug("Length MSB is set, record is 2 bytes")
record_length = ((first_bytes[0] & 0x7F) << 8) | first_bytes[1]
else:
logger.debug("Length MSB is not set, record is 3 bytes")
record_length = ((first_bytes[0] & 0x3F) << 8) | first_bytes[1]

logger.debug("Record length: {}".format(record_length))
try:
assert_enough_data(
len(self.rfile.peek(record_length)), record_length
)
except ValueError as e:
logger.error("Error: {}".format(e))

ssl_record = SSLv2Record(self.rfile.read(record_length))

logger.debug("SSL Record: {}".format(ssl_record))
Expand Down
6 changes: 6 additions & 0 deletions pyswitch/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import logging


DEFAULT_LOGGING_LEVEL = logging.DEBUG
LISTEN_HOST = "0.0.0.0"
LISTEN_PORT = 443
26 changes: 10 additions & 16 deletions pyswitch/server.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import logging
from pyswitch.ConnectionHandler import ConnectionHandler


from socketserver import ThreadingTCPServer

from pyswitch.config import DEFAULT_LOGGING_LEVEL, LISTEN_HOST, LISTEN_PORT
from pyswitch.src.setup_logging import setup_logging

DEFAULT_LOGGING_LEVEL = logging.DEBUG


class PySwitch:
server: ThreadingTCPServer
Expand All @@ -16,18 +14,14 @@ def __init__(self):

setup_logging(DEFAULT_LOGGING_LEVEL)

self.server = ThreadingTCPServer(("0.0.0.0", 443), ConnectionHandler)

def run(self):
try:
self.server.serve_forever()
self.server = ThreadingTCPServer(
(LISTEN_HOST, LISTEN_PORT), ConnectionHandler
)
except OSError as e:
print("Error: ", e)
print("Exiting")
exit(1)

except KeyboardInterrupt:
print("Caught KeyboardInterrupt")
print("Exiting")
self.server.server_close()
exit(0)
if e.errno == 98:
print("Port is already in use")
exit(1)

def run(self):
self.server.serve_forever()
61 changes: 17 additions & 44 deletions pyswitch/src/ssl.py → pyswitch/src/ssl/ssl_v2.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,25 @@
from enum import Enum
from loguru import logger
from pyswitch.src.ssl.ssl_v2_constants import SSLv2HandshakeType
from pyswitch.src.utils import bin2hex


class SSLv2State(Enum):
HANDSHAKE = "handshake"
DONE = "done"
from pyswitch.src.utils import assert_enough_data


class SSLv2HandshakeClientHello:
def __init__(self, data: bytes):
self.version = SSLProtocolVersion(data[:2])
self.client_version = SSLProtocolVersion(data[:2])
self.cipher_specs = data[2:22]
self.connection_id = data[22:24]
self.session_id = data[22:24]
self.challenge = data[24:32]

def __str__(self):
return "Version: {}, Cipher Specs: {}, Connection ID: {}, Challenge: {}".format(
self.version,
self.client_version,
bin2hex(self.cipher_specs),
bin2hex(self.connection_id),
bin2hex(self.session_id),
bin2hex(self.challenge),
)


class SSLv2HandshakeType(Enum):
CLIENT_HELLO = 1
CLIENT_MASTER_KEY = 2
CLIENT_FINISHED = 3
SERVER_HELLO = 4
SERVER_VERIFY = 5
SERVER_FINISHED = 6
REQUEST_CERTIFICATE = 7
CLIENT_CERTIFICATE = 8
CLIENT_KEY_EXCHANGE = 9


class SSLv2Handshake:
def __init__(self, data: bytes):
self.handshake_type = SSLv2HandshakeType(data[0])
Expand All @@ -44,24 +29,6 @@ def __str__(self):
return f"Handshake type: {self.handshake_type.name}, Data: {bin2hex(self.data)}"


class SSLv2StateMachine:
def __init__(self):
self.state = SSLv2State.HANDSHAKE
self.handshake = None

def process(self, data: bytes):
if self.state == SSLv2State.HANDSHAKE:
self.handshake = SSLv2Handshake(data)
if self.handshake.handshake_type == SSLv2HandshakeType.CLIENT_HELLO:
client_hello = SSLv2HandshakeClientHello(self.handshake.data[1:])
print(client_hello)
else:
raise ValueError("Invalid state")

def __str__(self):
return f"State: {self.state}, Handshake: {self.handshake}"


class SSLProtocolVersion:
def __init__(self, data: bytes):
self.major = data[1]
Expand All @@ -77,7 +44,13 @@ def __init__(self, data: bytes):
self.is_escape = data[2] & 0x80
if self.is_escape:
raise ValueError("SSLv2 Escape Record, not supported")
self.record_type = data[2]
try:
logger.debug("Data length: {}".format(len(data)))
assert_enough_data(len(data), self.length)
except ValueError as e:
logger.error("Error: {}".format(e))
logger.error("Data: {}".format(bin2hex(data)))
self.record_type = bin2hex(data[2:3])

def __str__(self):
return "Length: {}, Is Escape: {}, Record Type: {}".format(
Expand All @@ -87,9 +60,9 @@ def __str__(self):

class SSLv2Record:
def __init__(self, data: bytes):
self.header = SSLv2RecordHeader(data[:3])
self.record_type = data[2]
self.data = data[3:]
self.header = SSLv2RecordHeader(data)
self.record_type = self.header.record_type
self.data = data

def __str__(self):
return "Length: {}, Is Escape: {}, Record Type: {}, Data: {}".format(
Expand Down
61 changes: 61 additions & 0 deletions pyswitch/src/ssl/ssl_v2_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from enum import Enum


class SSLContentType(Enum):
HANDSHAKE = "handshake"
DATA = "data"


class SSLv2HandshakeType(Enum):
CLIENT_HELLO = 1
CLIENT_MASTER_KEY = 2
CLIENT_FINISHED = 3
SERVER_HELLO = 4
SERVER_VERIFY = 5
SERVER_FINISHED = 6
REQUEST_CERTIFICATE = 7
CLIENT_CERTIFICATE = 8
CLIENT_KEY_EXCHANGE = 9


# SSLv2 Handshake Valid State Transition Map (from_state -> to_state)
SSLv2HandshakeTransitionMap = {
SSLv2HandshakeType.CLIENT_HELLO: [SSLv2HandshakeType.SERVER_HELLO],
SSLv2HandshakeType.SERVER_HELLO: [
SSLv2HandshakeType.CLIENT_MASTER_KEY,
SSLv2HandshakeType.CLIENT_FINISHED,
],
SSLv2HandshakeType.CLIENT_MASTER_KEY: [SSLv2HandshakeType.CLIENT_FINISHED],
SSLv2HandshakeType.CLIENT_FINISHED: [SSLv2HandshakeType.SERVER_VERIFY],
SSLv2HandshakeType.SERVER_VERIFY: [
SSLv2HandshakeType.SERVER_FINISHED,
SSLv2HandshakeType.REQUEST_CERTIFICATE,
],
SSLv2HandshakeType.SERVER_FINISHED: [],
SSLv2HandshakeType.REQUEST_CERTIFICATE: [SSLv2HandshakeType.CLIENT_CERTIFICATE],
SSLv2HandshakeType.CLIENT_CERTIFICATE: [SSLv2HandshakeType.SERVER_FINISHED],
}


class SLv2HandshakeState(Enum):
CLIENT_HELLO = "client_hello"
SERVER_HELLO = "server_hello"
CLIENT_MASTER_KEY = "client_master_key"
CLIENT_FINISHED = "client_finished"
SERVER_VERIFY = "server_verify"
REQUEST_CERTIFICATE = "request_certificate"
CLIENT_CERTIFICATE = "client_certificate"
SERVER_FINISHED = "server_finished"


class SSLv2State(Enum):
"""
Represents the state of the SSLv2 protocol.
Attributes:
HANDSHAKE (str): Represents the state during the handshake phase.
DONE (str): Represents the state when the handshake is completed.
"""

HANDSHAKE = "handshake"
DONE = "done"
55 changes: 55 additions & 0 deletions pyswitch/src/ssl/ssl_v2_state_machine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from pyswitch.src.ssl.ssl_v2 import (
SSLv2Handshake,
SSLv2HandshakeClientHello,
)
from pyswitch.src.ssl.ssl_v2_constants import (
SSLv2HandshakeTransitionMap,
SSLv2HandshakeType,
SSLv2State,
)


def is_valid_sslv2_handshake_transition(
from_state: SSLv2HandshakeType,
to_state: SSLv2HandshakeType,
) -> bool:
"""
Checks if a transition from the current state to the given state is valid.
Args:
from_state (SSLv2HandshakeType): The state from which the transition is made.
to_state (SSLv2HandshakeType): The state to which the transition is made.
Returns:
bool: True if the transition is valid, False otherwise.
"""
return to_state in SSLv2HandshakeTransitionMap[from_state]


class SSLv2StateMachine:
"""
Represents the SSLv2 state machine.
Attributes:
state (SSLv2State): The current state of the state machine.
handshake (SSLv2Handshake): The current handshake being processed.
Methods:
process(data: bytes): Processes the given data based on the current state.
"""

def __init__(self):
self.state = SSLv2State.HANDSHAKE
self.handshake = None

def process(self, data: bytes):
if self.state == SSLv2State.HANDSHAKE:
self.handshake = SSLv2Handshake(data)
if self.handshake.handshake_type == SSLv2HandshakeType.CLIENT_HELLO:
client_hello = SSLv2HandshakeClientHello(self.handshake.data[1:])
print(client_hello)
else:
raise ValueError("Invalid state")

def __str__(self):
return f"State: {self.state}, Handshake: {self.handshake}"
6 changes: 0 additions & 6 deletions pyswitch/src/ssl_constants.py

This file was deleted.

20 changes: 20 additions & 0 deletions pyswitch/src/utils.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,22 @@
def bin2hex(bin_str):
return "".join(format(x, "02x") for x in bin_str)


def assert_enough_data(data_length: int, expected_length: int):
if data_length < expected_length:
raise ValueError(
f"Expected at least {expected_length} bytes, got {data_length} bytes"
)


def is_msb_set(byte):
"""
Checks if the most significant bit (MSB) of the given byte is set.
Parameters:
- byte: An integer representing the byte to check.
Returns:
- A boolean value indicating whether the MSB is set (True) or not (False).
"""
return byte & 0x80 == 0x80
26 changes: 26 additions & 0 deletions tests/test_is_valid_sslv2_handshake_transition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import unittest

from pyswitch.src.ssl.ssl_v2_constants import SSLv2HandshakeType
from pyswitch.src.ssl.ssl_v2_state_machine import (
is_valid_sslv2_handshake_transition,
)


class SSLv2StateMachineTestCase(unittest.TestCase):
def test_valid_transition(self):
# Test a valid transition from SSLvHandshakeType.CLIENT_HELLO to SSLv2HandshakeType.SERVER_HELLO
from_state = SSLv2HandshakeType.CLIENT_HELLO
to_state = SSLv2HandshakeType.SERVER_HELLO
result = is_valid_sslv2_handshake_transition(from_state, to_state)
self.assertTrue(result)

def test_invalid_transition(self):
# Test an invalid transition from SSLv2HandshakeType.CLIENT_HELLO to SSLv2HandshakeType.CLIENT_MASTER_KEY
from_state = SSLv2HandshakeType.CLIENT_HELLO
to_state = SSLv2HandshakeType.CLIENT_MASTER_KEY
result = is_valid_sslv2_handshake_transition(from_state, to_state)
self.assertFalse(result)


if __name__ == "__main__":
unittest.main()

0 comments on commit cbc9e38

Please sign in to comment.