Skip to content

Commit

Permalink
add optional ssl support
Browse files Browse the repository at this point in the history
  • Loading branch information
fsagbuya committed Oct 2, 2024
1 parent 32ddd78 commit 40c951d
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 9 deletions.
9 changes: 8 additions & 1 deletion sipyco/asyncio_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import atexit
import collections
import logging
import ssl
from copy import copy

from sipyco import keepalive
Expand Down Expand Up @@ -44,8 +45,13 @@ class AsyncioServer:
:meth:`~sipyco.asyncio_server.AsyncioServer._handle_connection_cr`
method/coroutine.
"""
def __init__(self):
def __init__(self, certfile=None, keyfile=None):
self._client_tasks = set()
if certfile:
self.ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
self.ssl_context.load_cert_chain(certfile, keyfile)
else:
self.ssl_context = None

async def start(self, host, port):
"""Starts the server.
Expand All @@ -61,6 +67,7 @@ async def start(self, host, port):
"""
self.server = await asyncio.start_server(self._handle_connection,
host, port,
ssl=self.ssl_context,
limit=4*1024*1024)

async def stop(self):
Expand Down
30 changes: 22 additions & 8 deletions sipyco/pc_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import socket
import threading
import time
import ssl
from operator import itemgetter

from sipyco import keepalive, pyon
Expand Down Expand Up @@ -106,9 +107,12 @@ class Client:
client).
"""

def __init__(self, host, port, target_name=AutoTarget, timeout=None):
def __init__(self, host, port, target_name=AutoTarget, timeout=None, cafile=None):
self.__socket = socket.create_connection((host, port), timeout)

if cafile:
ssl_context = ssl.create_default_context(cafile=cafile)
self.__socket = ssl_context.wrap_socket(self.__socket, server_hostname=host)
try:
self.__socket.sendall(_init_string)

Expand Down Expand Up @@ -206,12 +210,16 @@ def __init__(self):
self.__description = None
self.__valid_methods = set()

async def connect_rpc(self, host, port, target_name=AutoTarget):
async def connect_rpc(self, host, port, target_name=AutoTarget, cafile=None):
"""Connects to the server. This cannot be done in __init__ because
this method is a coroutine. See :class:`sipyco.pc_rpc.Client` for a description of the
parameters."""
if cafile:
ssl_context = ssl.create_default_context(cafile=cafile)
else:
ssl_context = None
self.__reader, self.__writer = \
await keepalive.async_open_connection(host, port, limit=100 * 1024 * 1024)
await keepalive.async_open_connection(host, port, ssl=ssl_context, limit=100 * 1024 * 1024)
try:
self.__writer.write(_init_string)
server_identification = await self.__recv()
Expand Down Expand Up @@ -310,11 +318,12 @@ class BestEffortClient:
"""

def __init__(self, host, port, target_name,
firstcon_timeout=1.0, retry=5.0):
firstcon_timeout=1.0, retry=5.0, cafile=None):
self.__host = host
self.__port = port
self.__target_name = target_name
self.__retry = retry
self.__cafile = cafile

self.__conretry_terminate = False
self.__socket = None
Expand All @@ -337,6 +346,10 @@ def __coninit(self, timeout):
else:
self.__socket = socket.create_connection(
(self.__host, self.__port), timeout)
if self.__cafile:
ssl_context = ssl.create_default_context(cafile=self.__cafile)
self.__socket = ssl_context.wrap_socket(self.__socket,
server_hostname=self.__host)
self.__socket.sendall(_init_string)
server_identification = self.__recv()
target_name = _validate_target_name(self.__target_name,
Expand Down Expand Up @@ -488,8 +501,8 @@ class Server(_AsyncioServer):
"""

def __init__(self, targets, description=None, builtin_terminate=False,
allow_parallel=False):
_AsyncioServer.__init__(self)
allow_parallel=False, certfile=None, keyfile=None):
_AsyncioServer.__init__(self, certfile=certfile, keyfile=keyfile)
self.targets = targets
self.description = description
self.builtin_terminate = builtin_terminate
Expand Down Expand Up @@ -636,7 +649,8 @@ async def wait_terminate(self):
await self._terminate_request.wait()


def simple_server_loop(targets, host, port, description=None, allow_parallel=False, *, loop=None):
def simple_server_loop(targets, host, port, description=None, allow_parallel=False, *, loop=None,
certfile=None, keyfile=None):
"""Runs a server until an exception is raised (e.g. the user hits Ctrl-C)
or termination is requested by a client.
Expand All @@ -651,7 +665,7 @@ def simple_server_loop(targets, host, port, description=None, allow_parallel=Fal
signal_handler = SignalHandler()
signal_handler.setup()
try:
server = Server(targets, description, True, allow_parallel)
server = Server(targets, description, True, allow_parallel, certfile, keyfile)
used_loop.run_until_complete(server.start(host, port))
try:
_, pending = used_loop.run_until_complete(asyncio.wait(
Expand Down

0 comments on commit 40c951d

Please sign in to comment.