diff --git a/requests_unixsocket/__init__.py b/requests_unixsocket/__init__.py index 0fb5e1f..dcf581b 100644 --- a/requests_unixsocket/__init__.py +++ b/requests_unixsocket/__init__.py @@ -1,19 +1,33 @@ -import requests import sys +import requests +from requests.compat import urlparse, unquote + from .adapters import UnixAdapter -DEFAULT_SCHEME = 'http+unix://' + +def default_urlparse(url): + parsed_url = urlparse(url) + return UnixAdapter.Settings.ParseResult( + sockpath=unquote(parsed_url.netloc), + reqpath=parsed_url.path + '?' + parsed_url.query, + ) + + +default_scheme = 'http+unix://' +default_settings = UnixAdapter.Settings(urlparse=default_urlparse) class Session(requests.Session): - def __init__(self, url_scheme=DEFAULT_SCHEME, *args, **kwargs): + def __init__(self, url_scheme=default_scheme, settings=None, + *args, **kwargs): super(Session, self).__init__(*args, **kwargs) - self.mount(url_scheme, UnixAdapter()) + self.settings = settings or default_settings + self.mount(url_scheme, UnixAdapter(settings=self.settings)) class monkeypatch(object): - def __init__(self, url_scheme=DEFAULT_SCHEME): + def __init__(self, url_scheme=default_scheme): self.session = Session() requests = self._get_global_requests_module() diff --git a/requests_unixsocket/adapters.py b/requests_unixsocket/adapters.py index a2c1564..a9159ed 100644 --- a/requests_unixsocket/adapters.py +++ b/requests_unixsocket/adapters.py @@ -1,7 +1,8 @@ import socket +from collections import namedtuple from requests.adapters import HTTPAdapter -from requests.compat import urlparse, unquote +from requests.compat import urlparse try: import http.client as httplib @@ -18,16 +19,12 @@ # https://github.com/docker/docker-py/blob/master/docker/transport/unixconn.py class UnixHTTPConnection(httplib.HTTPConnection, object): - def __init__(self, unix_socket_url, timeout=60): - """Create an HTTP connection to a unix domain socket - - :param unix_socket_url: A URL with a scheme of 'http+unix' and the - netloc is a percent-encoded path to a unix domain socket. E.g.: - 'http+unix://%2Ftmp%2Fprofilesvc.sock/status/pid' - """ + def __init__(self, url, timeout=60, settings=None): + """Create an HTTP connection to a unix domain socket""" super(UnixHTTPConnection, self).__init__('localhost', timeout=timeout) - self.unix_socket_url = unix_socket_url + self.url = url self.timeout = timeout + self.settings = settings self.sock = None def __del__(self): # base class does not have d'tor @@ -37,27 +34,39 @@ def __del__(self): # base class does not have d'tor def connect(self): sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) sock.settimeout(self.timeout) - socket_path = unquote(urlparse(self.unix_socket_url).netloc) - sock.connect(socket_path) + sockpath = self.settings.urlparse(self.url).sockpath + sock.connect(sockpath) self.sock = sock class UnixHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool): - def __init__(self, socket_path, timeout=60): + def __init__(self, socket_path, timeout=60, settings=None): super(UnixHTTPConnectionPool, self).__init__( 'localhost', timeout=timeout) self.socket_path = socket_path self.timeout = timeout + self.settings = settings def _new_conn(self): - return UnixHTTPConnection(self.socket_path, self.timeout) + return UnixHTTPConnection( + url=self.socket_path, + timeout=self.timeout, + settings=self.settings, + ) class UnixAdapter(HTTPAdapter): + class Settings(object): + class ParseResult(namedtuple('ParseResult', 'sockpath reqpath')): + pass + + def __init__(self, urlparse=None): + self.urlparse = urlparse - def __init__(self, timeout=60, pool_connections=25): + def __init__(self, timeout=60, pool_connections=25, settings=None): super(UnixAdapter, self).__init__() + self.settings = settings self.timeout = timeout self.pools = urllib3._collections.RecentlyUsedContainer( pool_connections, dispose_func=lambda p: p.close() @@ -77,13 +86,17 @@ def get_connection(self, url, proxies=None): if pool: return pool - pool = UnixHTTPConnectionPool(url, self.timeout) + pool = UnixHTTPConnectionPool( + socket_path=url, + settings=self.settings, + timeout=self.timeout, + ) self.pools[url] = pool return pool def request_url(self, request, proxies): - return request.path_url + return self.settings.urlparse(request.url).reqpath def close(self): self.pools.clear() diff --git a/requests_unixsocket/tests/test_requests_unixsocket.py b/requests_unixsocket/tests/test_requests_unixsocket.py index 733aa87..943dc96 100755 --- a/requests_unixsocket/tests/test_requests_unixsocket.py +++ b/requests_unixsocket/tests/test_requests_unixsocket.py @@ -4,9 +4,12 @@ """Tests for requests_unixsocket""" import logging +import os +import stat import pytest import requests +from requests.compat import urlparse import requests_unixsocket from requests_unixsocket.testutils import UnixSocketServerThread @@ -15,6 +18,35 @@ logger = logging.getLogger(__name__) +def is_socket(path): + try: + mode = os.stat(path).st_mode + return stat.S_ISSOCK(mode) + except OSError: + return False + + +def get_sock_prefix(path): + """Keep going up directory tree until we find a socket""" + + sockpath = path + reqpath_parts = [] + + while not is_socket(sockpath): + sockpath, tail = os.path.split(sockpath) + reqpath_parts.append(tail) + + return requests_unixsocket.UnixAdapter.Settings.ParseResult( + sockpath=sockpath, + reqpath='/' + os.path.join(*reversed(reqpath_parts)), + ) + + +alt_settings_1 = requests_unixsocket.UnixAdapter.Settings( + urlparse=lambda url: get_sock_prefix(urlparse(url).path), +) + + def test_unix_domain_adapter_ok(): with UnixSocketServerThread() as usock_thread: session = requests_unixsocket.Session('http+unix://') @@ -41,6 +73,34 @@ def test_unix_domain_adapter_ok(): assert r.text == 'Hello world!' +def test_unix_domain_adapter_alt_settings_1_ok(): + with UnixSocketServerThread() as usock_thread: + session = requests_unixsocket.Session( + url_scheme='http+unix://', + settings=alt_settings_1, + ) + url = 'http+unix://localhost%s/path/to/page' % usock_thread.usock + + for method in ['get', 'post', 'head', 'patch', 'put', 'delete', + 'options']: + logger.debug('Calling session.%s(%r) ...', method, url) + r = getattr(session, method)(url) + logger.debug( + 'Received response: %r with text: %r and headers: %r', + r, r.text, r.headers) + assert r.status_code == 200 + assert r.headers['server'] == 'waitress' + assert r.headers['X-Transport'] == 'unix domain socket' + assert r.headers['X-Requested-Path'] == '/path/to/page' + assert r.headers['X-Socket-Path'] == usock_thread.usock + assert isinstance(r.connection, requests_unixsocket.UnixAdapter) + assert r.url.lower() == url.lower() + if method == 'head': + assert r.text == '' + else: + assert r.text == 'Hello world!' + + def test_unix_domain_adapter_url_with_query_params(): with UnixSocketServerThread() as usock_thread: session = requests_unixsocket.Session('http+unix://') @@ -69,6 +129,33 @@ def test_unix_domain_adapter_url_with_query_params(): assert r.text == 'Hello world!' +def test_unix_domain_adapter_url_with_fragment(): + with UnixSocketServerThread() as usock_thread: + session = requests_unixsocket.Session('http+unix://') + urlencoded_usock = requests.compat.quote_plus(usock_thread.usock) + url = ('http+unix://%s' + '/containers/nginx/logs#some-fragment' % urlencoded_usock) + + for method in ['get', 'post', 'head', 'patch', 'put', 'delete', + 'options']: + logger.debug('Calling session.%s(%r) ...', method, url) + r = getattr(session, method)(url) + logger.debug( + 'Received response: %r with text: %r and headers: %r', + r, r.text, r.headers) + assert r.status_code == 200 + assert r.headers['server'] == 'waitress' + assert r.headers['X-Transport'] == 'unix domain socket' + assert r.headers['X-Requested-Path'] == '/containers/nginx/logs' + assert r.headers['X-Socket-Path'] == usock_thread.usock + assert isinstance(r.connection, requests_unixsocket.UnixAdapter) + assert r.url.lower() == url.lower() + if method == 'head': + assert r.text == '' + else: + assert r.text == 'Hello world!' + + def test_unix_domain_adapter_connection_error(): session = requests_unixsocket.Session('http+unix://')