From 906890f9143546f8af93c18993a4fa5f0e37d31a Mon Sep 17 00:00:00 2001 From: Florian Agbuya Date: Wed, 2 Oct 2024 17:16:22 +0800 Subject: [PATCH] add optional ssl support Signed-off-by: Florian Agbuya --- doc/index.rst | 55 +++++++++++++++++++++++++++++++++++++ sipyco/asyncio_tools.py | 18 ++++++++++++- sipyco/pc_rpc.py | 58 +++++++++++++++++++++++++++++++++------- sipyco/sipyco_rpctool.py | 8 +++++- 4 files changed, 127 insertions(+), 12 deletions(-) diff --git a/doc/index.rst b/doc/index.rst index 6042cd0..4063c43 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -127,3 +127,58 @@ Command-line details: .. argparse:: :ref: sipyco.sipyco_rpctool.get_argparser :prog: sipyco_rpctool + + +SSL Setup +========= + +SiPyCo supports SSL/TLS encryption with mutual authentication for secure communication, but it is disabled by default. To enable and use SSL, follow these steps: + +**Generate CA certificate:** + +.. code-block:: bash + + openssl req -x509 -newkey rsa:2048 -keyout ca.key -nodes -out ca.pem -sha256 -days 1095 + +**Generate server and client certificates:** + +.. code-block:: bash + + openssl req -x509 -newkey rsa:2048 -keyout server.key -nodes -out server.pem -sha256 -days 1095 + + openssl req -x509 -newkey rsa:2048 -keyout client.key -nodes -out client.pem -sha256 -days 1095 + +When prompted, enter appropriate information for your certificates like organization name, common name (CN), etc. + +This creates: + +- A CA certificate (``ca.pem``) and key (``ca.key``) +- A server certificate (``server.pem``) and key (``server.key``) +- A client certificate (``client.pem``) and key (``client.key``) + + +Enabling SSL +------------ + +To enable SSL, both server and client need their respective certificates and keys, plus the CA certificate for verification: + +**For servers:** + +.. code-block:: python + + simple_server_loop(targets, host, port, + certfile="path/to/server.pem", + keyfile="path/to/server.key", + cafile="path/to/ca.pem") + +**For clients:** + +.. code-block:: python + + client = Client(host, port, + certfile="path/to/client.pem", + keyfile="path/to/client.key", + cafile="path/to/ca.pem") + +.. note:: + When SSL is enabled, mutual TLS authentication is mandatory. Both server and client must provide valid certificates signed by the same CA for the connection to be established. \ No newline at end of file diff --git a/sipyco/asyncio_tools.py b/sipyco/asyncio_tools.py index 4a96dcb..cf39c31 100644 --- a/sipyco/asyncio_tools.py +++ b/sipyco/asyncio_tools.py @@ -4,6 +4,7 @@ import atexit import collections import logging +import ssl from copy import copy from sipyco import keepalive @@ -43,9 +44,23 @@ class AsyncioServer: Users of this class must derive from it and define the :meth:`~sipyco.asyncio_server.AsyncioServer._handle_connection_cr` method/coroutine. + + :param certfile: Server's SSL certificate file. Providing this enables SSL. + :param keyfile: Server's private key file. Required when certfile is provided. + :param cafile: CA cert file for verifying client certificates. Required when SSL is enabled. """ - def __init__(self): + def __init__(self, certfile=None, keyfile=None, cafile=None): self._client_tasks = set() + self.ssl_context = None + if certfile: + if not keyfile: + raise ValueError("keyfile is required when certfile is provided") + if not cafile: + raise ValueError("cafile is required when SSL is enabled") + self.ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + self.ssl_context.load_cert_chain(certfile, keyfile) + self.ssl_context.verify_mode = ssl.CERT_REQUIRED + self.ssl_context.load_verify_locations(cafile) async def start(self, host, port): """Starts the server. @@ -61,6 +76,7 @@ async def start(self, host, port): """ self.server = await asyncio.start_server(self._handle_connection, host, port, + ssl=self.ssl_context, limit=4*1024*1024) async def stop(self): diff --git a/sipyco/pc_rpc.py b/sipyco/pc_rpc.py index 69a90d8..cd711c2 100644 --- a/sipyco/pc_rpc.py +++ b/sipyco/pc_rpc.py @@ -17,6 +17,7 @@ import socket import threading import time +import ssl from operator import itemgetter from sipyco import keepalive, pyon @@ -97,6 +98,9 @@ class Client: Use ``None`` to skip selecting a target. The list of targets can then be retrieved using :meth:`~sipyco.pc_rpc.Client.get_rpc_id` and then one can be selected later using :meth:`~sipyco.pc_rpc.Client.select_rpc_target`. + :param certfile: Client's certificate file. Providing this enables SSL. + :param keyfile: Client's private key file. Required when certfile is provided. + :param cafile: CA cert file for verifying server certificate. Required when SSL is enabled. :param timeout: Socket operation timeout. Use ``None`` for blocking (default), ``0`` for non-blocking, and a finite value to raise ``socket.timeout`` if an operation does not complete within the @@ -106,9 +110,17 @@ class Client: client). """ - def __init__(self, host, port, target_name=AutoTarget, timeout=None): + def __init__(self, host, port, target_name=AutoTarget, + certfile=None, keyfile=None, cafile=None, timeout=None): self.__socket = socket.create_connection((host, port), timeout) - + if certfile: + if not keyfile: + raise ValueError("keyfile is required when certfile is provided") + if not cafile: + raise ValueError("cafile is required when SSL is enabled") + ssl_context = ssl.create_default_context(cafile=cafile) + ssl_context.load_cert_chain(certfile, keyfile) + self.__socket = ssl_context.wrap_socket(self.__socket, server_hostname=host) try: self.__socket.sendall(_init_string) @@ -206,12 +218,21 @@ def __init__(self): self.__description = None self.__valid_methods = set() - async def connect_rpc(self, host, port, target_name=AutoTarget): + async def connect_rpc(self, host, port, target_name=AutoTarget, + certfile=None, keyfile=None, cafile=None): """Connects to the server. This cannot be done in __init__ because this method is a coroutine. See :class:`sipyco.pc_rpc.Client` for a description of the parameters.""" + ssl_context = None + if certfile: + if not keyfile: + raise ValueError("keyfile is required when certfile is provided") + if not cafile: + raise ValueError("cafile is required when SSL is enabled") + ssl_context = ssl.create_default_context(cafile=cafile) + ssl_context.load_cert_chain(certfile, keyfile) self.__reader, self.__writer = \ - await keepalive.async_open_connection(host, port, limit=100 * 1024 * 1024) + await keepalive.async_open_connection(host, port, ssl=ssl_context, limit=100 * 1024 * 1024) try: self.__writer.write(_init_string) server_identification = await self.__recv() @@ -309,11 +330,14 @@ class BestEffortClient: in the background. """ - def __init__(self, host, port, target_name, - firstcon_timeout=1.0, retry=5.0): + def __init__(self, host, port, target_name, certfile=None, + keyfile=None, cafile=None, firstcon_timeout=1.0, retry=5.0): self.__host = host self.__port = port self.__target_name = target_name + self.__certfile = certfile + self.__keyfile = keyfile + self.__cafile = cafile self.__retry = retry self.__conretry_terminate = False @@ -337,6 +361,15 @@ def __coninit(self, timeout): else: self.__socket = socket.create_connection( (self.__host, self.__port), timeout) + if self.__certfile: + if not self.__keyfile: + raise ValueError("keyfile is required when certfile is provided") + if not self.__cafile: + raise ValueError("cafile is required when SSL is enabled") + ssl_context = ssl.create_default_context(cafile=self.__cafile) + ssl_context.load_cert_chain(self.__certfile, self.__keyfile) + self.__socket = ssl_context.wrap_socket(self.__socket, + server_hostname=self.__host) self.__socket.sendall(_init_string) server_identification = self.__recv() target_name = _validate_target_name(self.__target_name, @@ -485,11 +518,15 @@ class Server(_AsyncioServer): requests from clients. :param allow_parallel: Allow concurrent asyncio calls to the target's methods. + :param certfile: Path to the server's SSL certificate file. If provided along + with ``keyfile``, the server will use SSL encryption. + :param keyfile: Path to the server's SSL private key file. + :param cafile: Path to CA certificate file for verifying client certificates. """ def __init__(self, targets, description=None, builtin_terminate=False, - allow_parallel=False): - _AsyncioServer.__init__(self) + allow_parallel=False, certfile=None, keyfile=None, cafile=None): + _AsyncioServer.__init__(self, certfile=certfile, keyfile=keyfile, cafile=cafile) self.targets = targets self.description = description self.builtin_terminate = builtin_terminate @@ -635,7 +672,8 @@ async def wait_terminate(self): await self._terminate_request.wait() -def simple_server_loop(targets, host, port, description=None, allow_parallel=False, *, loop=None): +def simple_server_loop(targets, host, port, description=None, allow_parallel=False, *, loop=None, + certfile=None, keyfile=None, cafile=None): """Runs a server until an exception is raised (e.g. the user hits Ctrl-C) or termination is requested by a client. @@ -650,7 +688,7 @@ def simple_server_loop(targets, host, port, description=None, allow_parallel=Fal signal_handler = SignalHandler() signal_handler.setup() try: - server = Server(targets, description, True, allow_parallel) + server = Server(targets, description, True, allow_parallel, certfile, keyfile, cafile) used_loop.run_until_complete(server.start(host, port)) try: _, pending = used_loop.run_until_complete(asyncio.wait( diff --git a/sipyco/sipyco_rpctool.py b/sipyco/sipyco_rpctool.py index cb9106c..b511d5e 100755 --- a/sipyco/sipyco_rpctool.py +++ b/sipyco/sipyco_rpctool.py @@ -18,6 +18,12 @@ def get_argparser(): help="hostname or IP of the controller to connect to") parser.add_argument("port", metavar="PORT", type=int, help="TCP port to use to connect to the controller") + parser.add_argument("--ca-cert", default=None, + help="CA certificate file for verifying server certificate") + parser.add_argument("--cert", default=None, + help="SSL/TLS certificate file. Providing this enables SSL/TLS (default: %(default)s)") + parser.add_argument("--key", default=None, + help="Client private key file for mutual TLS authentication") subparsers = parser.add_subparsers(dest="action") subparsers.add_parser("list-targets", help="list existing targets") parser_list_methods = subparsers.add_parser("list-methods", @@ -98,7 +104,7 @@ def main(): if not args.action: args.target = None - remote = Client(args.server, args.port, None) + remote = Client(args.server, args.port, None, cafile=args.ca_cert, certfile=args.cert, keyfile=args.key) targets, description = remote.get_rpc_id() if args.action != "list-targets": if not args.target: