Skip to content

Commit

Permalink
feat: Move TLS and server files to src directory
Browse files Browse the repository at this point in the history
  • Loading branch information
drazisil committed Aug 13, 2024
1 parent c54b0dd commit ede62ef
Show file tree
Hide file tree
Showing 13 changed files with 186 additions and 59 deletions.
7 changes: 5 additions & 2 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
{
"python.analysis.autoImportCompletions": true,
"python.analysis.typeCheckingMode": "basic"
}
"python.analysis.typeCheckingMode": "basic",
"python.terminal.activateEnvironment": true,
"python.terminal.activateEnvInCurrentTerminal": true,
"python.envFile": "${workspaceFolder}/.venv"
}
17 changes: 17 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
lint:
# stop the build if there are Python syntax errors or undefined names
@pdm run flake8 pyswitch --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
@pdm run flake8 pyswitch --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics


test:
@pdm run pytest --verbose --cov=pyswitch --cov-report=term-missing && pdm run python -m coverage xml

install:
@pdm install

start:
@python main.py

.PHONY: lint test install start
6 changes: 3 additions & 3 deletions server.py → main.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from os import getenv
from pyswitch.PySwitch import PySwitch
from dotenv import load_dotenv, find_dotenv
from pyswitch.server import PySwitch
from dotenv import load_dotenv

import sentry_sdk


def main():
load_dotenv(find_dotenv())
load_dotenv(verbose=True)

sentry_dsn = getenv("SENTRY_DSN")

Expand Down
36 changes: 24 additions & 12 deletions pyswitch/ConnectionHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from ssl import create_default_context
import ssl
from loguru import logger
from pyswitch.ssl import SSLv2Record
from pyswitch.tls import TLSProtocolVersion
from pyswitch.tls_constants import TLSContentType
from pyswitch.utils import bin2hex
from pyswitch.src.ssl import SSLv2Record
from pyswitch.src.tls import TLSProtocolVersion
from pyswitch.src.tls_constants import TLSContentType
from pyswitch.src.utils import bin2hex
import socket
from socketserver import StreamRequestHandler

Expand All @@ -16,17 +16,12 @@ class ConnectionHandler(StreamRequestHandler):

def handle(self):
logger.debug(
"Connection from: ", self.client_address, "id: ", self.request.fileno()
)
print(
"Connection from: {}, id: {}".format(
self.client_address, self.request.fileno()
)
)

first_bytes = self.peek_data(len=32814)

print("first_bytes_len: ", len(first_bytes))
first_bytes = peek_data(self.request, len=32814)

logger.debug("First bytes: {}".format(bin2hex(first_bytes)))

Expand All @@ -40,10 +35,12 @@ def handle(self):
ssl_record = SSLv2Record(self.rfile.read(record_length))

logger.debug("SSL Record: {}".format(ssl_record))
self.request.close()
return

except Exception as e:
print("Error: ", e)
self.request.close()
return

protocol_version = TLSProtocolVersion(first_bytes[1:3])
Expand All @@ -54,5 +51,20 @@ def handle(self):
with ssl_context.wrap_socket(self.request, server_side=True) as ssl_socket:
print("SSL version: ", ssl_socket.version())

def peek_data(self, len: int):
return self.request.recv(len, socket.MSG_PEEK)

def peek_data(sock: socket.socket, len: int):
"""
Receive data from the socket without removing it from the receive buffer.
Args:
sock (socket.socket): The socket object to receive data from.
len (int): The maximum number of bytes to receive.
Returns:
bytes: The received data as bytes.
Raises:
OSError: If an error occurs while receiving data.
"""
return sock.recv(len, socket.MSG_PEEK)
2 changes: 1 addition & 1 deletion pyswitch/PySwitch.py → pyswitch/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from socketserver import ThreadingTCPServer

from pyswitch.setup_logging import setup_logging
from pyswitch.src.setup_logging import setup_logging

DEFAULT_LOGGING_LEVEL = logging.DEBUG

Expand Down
File renamed without changes.
100 changes: 100 additions & 0 deletions pyswitch/src/ssl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from enum import Enum
from pyswitch.src.utils import bin2hex


class SSLv2State(Enum):
HANDSHAKE = "handshake"
DONE = "done"


class SSLv2HandshakeClientHello:
def __init__(self, data: bytes):
self.version = SSLProtocolVersion(data[:2])
self.cipher_specs = data[2:22]
self.connection_id = data[22:24]
self.challenge = data[24:32]

def __str__(self):
return "Version: {}, Cipher Specs: {}, Connection ID: {}, Challenge: {}".format(
self.version,
bin2hex(self.cipher_specs),
bin2hex(self.connection_id),
bin2hex(self.challenge),
)


class SSLv2HandshakeType(Enum):
CLIENT_HELLO = 1
CLIENT_MASTER_KEY = 2
CLIENT_FINISHED = 3
SERVER_HELLO = 4
SERVER_VERIFY = 5
SERVER_FINISHED = 6
REQUEST_CERTIFICATE = 7
CLIENT_CERTIFICATE = 8
CLIENT_KEY_EXCHANGE = 9


class SSLv2Handshake:
def __init__(self, data: bytes):
self.handshake_type = SSLv2HandshakeType(data[0])
self.data = data

def __str__(self):
return f"Handshake type: {self.handshake_type.name}, Data: {bin2hex(self.data)}"


class SSLv2StateMachine:
def __init__(self):
self.state = SSLv2State.HANDSHAKE
self.handshake = None

def process(self, data: bytes):
if self.state == SSLv2State.HANDSHAKE:
self.handshake = SSLv2Handshake(data)
if self.handshake.handshake_type == SSLv2HandshakeType.CLIENT_HELLO:
client_hello = SSLv2HandshakeClientHello(self.handshake.data[1:])
print(client_hello)
else:
raise ValueError("Invalid state")

def __str__(self):
return f"State: {self.state}, Handshake: {self.handshake}"


class SSLProtocolVersion:
def __init__(self, data: bytes):
self.major = data[1]
self.minor = data[0]

def __str__(self):
return f"Major: {self.major}, Minor: {self.minor}"


class SSLv2RecordHeader:
def __init__(self, data: bytes):
self.length = ((data[0] & 0x7F) << 8) | data[1]
self.is_escape = data[2] & 0x80
if self.is_escape:
raise ValueError("SSLv2 Escape Record, not supported")
self.record_type = data[2]

def __str__(self):
return "Length: {}, Is Escape: {}, Record Type: {}".format(
self.length, self.is_escape, self.record_type
)


class SSLv2Record:
def __init__(self, data: bytes):
self.header = SSLv2RecordHeader(data[:3])
self.record_type = data[2]
self.data = data[3:]

def __str__(self):
return "Length: {}, Is Escape: {}, Record Type: {}, Data: {}".format(
self.header.length,
self.header.is_escape,
self.header.record_type,
bin2hex(self.data),
)
File renamed without changes.
4 changes: 2 additions & 2 deletions pyswitch/tls.py → pyswitch/src/tls.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pyswitch.tls_constants import TLSHandshakeType
from pyswitch.utils import bin2hex
from pyswitch.src.tls_constants import TLSHandshakeType
from pyswitch.src.utils import bin2hex


class TLSProtocolVersion:
Expand Down
File renamed without changes.
File renamed without changes.
39 changes: 0 additions & 39 deletions pyswitch/ssl.py

This file was deleted.

34 changes: 34 additions & 0 deletions tests/test_peek_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import socket
import unittest
from unittest.mock import MagicMock

from pyswitch.ConnectionHandler import peek_data


class PeekDataTestCase(unittest.TestCase):
def test_peek_data_returns_bytes(self):
# Create a mock socket object
sock = MagicMock(spec=socket.socket)

# Set the return value of the recv method to a byte string
sock.recv.return_value = b"Hello, World!"

# Call the peek_data function
result = peek_data(sock, 10)

# Assert that the result is a byte string
self.assertIsInstance(result, bytes)

def test_peek_data_calls_recv_with_correct_arguments(self):
# Create a mock socket object
sock = MagicMock(spec=socket.socket)

# Call the peek_data function
peek_data(sock, 10)

# Assert that the recv method was called with the correct arguments
sock.recv.assert_called_once_with(10, socket.MSG_PEEK)


if __name__ == "__main__":
unittest.main()

0 comments on commit ede62ef

Please sign in to comment.