From 40c951d8bef8c34cbddba91790afbe5b446587fe Mon Sep 17 00:00:00 2001 From: Florian Agbuya Date: Wed, 2 Oct 2024 17:16:22 +0800 Subject: [PATCH] add optional ssl support --- sipyco/asyncio_tools.py | 9 ++++++++- sipyco/pc_rpc.py | 30 ++++++++++++++++++++++-------- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/sipyco/asyncio_tools.py b/sipyco/asyncio_tools.py index 4a96dcb..572f793 100644 --- a/sipyco/asyncio_tools.py +++ b/sipyco/asyncio_tools.py @@ -4,6 +4,7 @@ import atexit import collections import logging +import ssl from copy import copy from sipyco import keepalive @@ -44,8 +45,13 @@ class AsyncioServer: :meth:`~sipyco.asyncio_server.AsyncioServer._handle_connection_cr` method/coroutine. """ - def __init__(self): + def __init__(self, certfile=None, keyfile=None): self._client_tasks = set() + if certfile: + self.ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + self.ssl_context.load_cert_chain(certfile, keyfile) + else: + self.ssl_context = None async def start(self, host, port): """Starts the server. @@ -61,6 +67,7 @@ async def start(self, host, port): """ self.server = await asyncio.start_server(self._handle_connection, host, port, + ssl=self.ssl_context, limit=4*1024*1024) async def stop(self): diff --git a/sipyco/pc_rpc.py b/sipyco/pc_rpc.py index aa2bf19..26233e2 100644 --- a/sipyco/pc_rpc.py +++ b/sipyco/pc_rpc.py @@ -17,6 +17,7 @@ import socket import threading import time +import ssl from operator import itemgetter from sipyco import keepalive, pyon @@ -106,9 +107,12 @@ class Client: client). """ - def __init__(self, host, port, target_name=AutoTarget, timeout=None): + def __init__(self, host, port, target_name=AutoTarget, timeout=None, cafile=None): self.__socket = socket.create_connection((host, port), timeout) + if cafile: + ssl_context = ssl.create_default_context(cafile=cafile) + self.__socket = ssl_context.wrap_socket(self.__socket, server_hostname=host) try: self.__socket.sendall(_init_string) @@ -206,12 +210,16 @@ def __init__(self): self.__description = None self.__valid_methods = set() - async def connect_rpc(self, host, port, target_name=AutoTarget): + async def connect_rpc(self, host, port, target_name=AutoTarget, cafile=None): """Connects to the server. This cannot be done in __init__ because this method is a coroutine. See :class:`sipyco.pc_rpc.Client` for a description of the parameters.""" + if cafile: + ssl_context = ssl.create_default_context(cafile=cafile) + else: + ssl_context = None self.__reader, self.__writer = \ - await keepalive.async_open_connection(host, port, limit=100 * 1024 * 1024) + await keepalive.async_open_connection(host, port, ssl=ssl_context, limit=100 * 1024 * 1024) try: self.__writer.write(_init_string) server_identification = await self.__recv() @@ -310,11 +318,12 @@ class BestEffortClient: """ def __init__(self, host, port, target_name, - firstcon_timeout=1.0, retry=5.0): + firstcon_timeout=1.0, retry=5.0, cafile=None): self.__host = host self.__port = port self.__target_name = target_name self.__retry = retry + self.__cafile = cafile self.__conretry_terminate = False self.__socket = None @@ -337,6 +346,10 @@ def __coninit(self, timeout): else: self.__socket = socket.create_connection( (self.__host, self.__port), timeout) + if self.__cafile: + ssl_context = ssl.create_default_context(cafile=self.__cafile) + self.__socket = ssl_context.wrap_socket(self.__socket, + server_hostname=self.__host) self.__socket.sendall(_init_string) server_identification = self.__recv() target_name = _validate_target_name(self.__target_name, @@ -488,8 +501,8 @@ class Server(_AsyncioServer): """ def __init__(self, targets, description=None, builtin_terminate=False, - allow_parallel=False): - _AsyncioServer.__init__(self) + allow_parallel=False, certfile=None, keyfile=None): + _AsyncioServer.__init__(self, certfile=certfile, keyfile=keyfile) self.targets = targets self.description = description self.builtin_terminate = builtin_terminate @@ -636,7 +649,8 @@ async def wait_terminate(self): await self._terminate_request.wait() -def simple_server_loop(targets, host, port, description=None, allow_parallel=False, *, loop=None): +def simple_server_loop(targets, host, port, description=None, allow_parallel=False, *, loop=None, + certfile=None, keyfile=None): """Runs a server until an exception is raised (e.g. the user hits Ctrl-C) or termination is requested by a client. @@ -651,7 +665,7 @@ def simple_server_loop(targets, host, port, description=None, allow_parallel=Fal signal_handler = SignalHandler() signal_handler.setup() try: - server = Server(targets, description, True, allow_parallel) + server = Server(targets, description, True, allow_parallel, certfile, keyfile) used_loop.run_until_complete(server.start(host, port)) try: _, pending = used_loop.run_until_complete(asyncio.wait(