From f9c4902475999b80070fb77b166e59eee2e41294 Mon Sep 17 00:00:00 2001 From: Sam Stenner Date: Mon, 2 Sep 2024 15:07:00 +0100 Subject: [PATCH] feat: add custom SSL context creation for Requests Co-authored-by: Genie --- src/requests/adapters.py | 20 +++--------------- src/requests/utils.py | 45 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 46 insertions(+), 19 deletions(-) diff --git a/src/requests/adapters.py b/src/requests/adapters.py index 9a58b16025..312241fc35 100644 --- a/src/requests/adapters.py +++ b/src/requests/adapters.py @@ -1,16 +1,3 @@ -""" -requests.adapters -~~~~~~~~~~~~~~~~~ - -This module contains the transport adapters that Requests uses to define -and maintain connections. -""" - -import os.path -import socket # noqa: F401 -import typing -import warnings - from urllib3.exceptions import ClosedPoolError, ConnectTimeoutError from urllib3.exceptions import HTTPError as _HTTPError from urllib3.exceptions import InvalidHeader as _InvalidHeader @@ -27,7 +14,6 @@ from urllib3.util import Timeout as TimeoutSauce from urllib3.util import parse_url from urllib3.util.retry import Retry -from urllib3.util.ssl_ import create_urllib3_context from .auth import _basic_auth_str from .compat import basestring, urlparse @@ -47,10 +33,10 @@ from .models import Response from .structures import CaseInsensitiveDict from .utils import ( - DEFAULT_CA_BUNDLE_PATH, extract_zipped_paths, get_auth_from_url, get_encoding_from_headers, + get_ssl_context, prepend_scheme_if_needed, select_proxy, urldefragauth, @@ -117,14 +103,14 @@ def _urllib3_request_context( pool_kwargs["ca_certs"] = verify else: pool_kwargs["ca_cert_dir"] = verify - pool_kwargs["cert_reqs"] = cert_reqs + not has_poolmanager_ssl_context if client_cert is not None: if isinstance(client_cert, tuple) and len(client_cert) == 2: pool_kwargs["cert_file"] = client_cert[0] pool_kwargs["key_file"] = client_cert[1] else: # According to our docs, we allow users to specify just the client - # cert path + pool_kwargs["ssl_context"] = get_ssl_context() pool_kwargs["cert_file"] = client_cert host_params = { "scheme": scheme, diff --git a/src/requests/utils.py b/src/requests/utils.py index 699683e5d9..f43b58521e 100644 --- a/src/requests/utils.py +++ b/src/requests/utils.py @@ -18,8 +18,9 @@ import warnings import zipfile from collections import OrderedDict +from typing import Optional -from urllib3.util import make_headers, parse_url +from urllib3.util import make_headers, parse_url, create_urllib3_context from . import certs from .__version__ import __version__ @@ -62,10 +63,12 @@ NETRC_FILES = (".netrc", "_netrc") DEFAULT_CA_BUNDLE_PATH = certs.where() - +DEFAULT_CA_BUNDLE_PATH = certs.where() DEFAULT_PORTS = {"http": 80, "https": 443} # Ensure that ', ' is used to preserve previous delimiter behavior. +_SSL_CONTEXT: Optional["ssl.SSLContext"] = None + DEFAULT_ACCEPT_ENCODING = ", ".join( re.split(r",\s*", make_headers(accept_encoding=True)["accept-encoding"]) ) @@ -124,6 +127,44 @@ def proxy_bypass(host): # noqa return proxy_bypass_registry(host) +def get_ssl_context() -> "ssl.SSLContext | None": + """ + Returns a custom ``SSLContext`` for Requests. + + This function should only be called once per interpreter instance because it can be expensive to call. + + :rtype: ssl.SSLContext | None + :returns: The custom ``SSLContext`` for Requests if one could be created, otherwise ``None``. + """ + global _SSL_CONTEXT + + if _SSL_CONTEXT is not None: + return _SSL_CONTEXT + + try: + # Import ssl here so if it fails we only error on first use of SSLContext creation. + # This allows users to disable SSL verification without third-party dependencies by setting verify=False. + from urllib3.util.ssl_ import create_urllib3_context # type: ignore[import] + + _SSL_CONTEXT = create_urllib3_context() + + # In some cases, the user may have already loaded a custom CA bundle path into their default SSL context. + # If this is true we want to skip over our default CA load because it may produce a warning or error. + context_has_custom_ca_path = ( + DEFAULT_CA_BUNDLE_PATH not in _SSL_CONTEXT.get_ca_certs([]) # type: ignore[attr-defined] + ) + + if not context_has_custom_ca_path: + _SSL_CONTEXT.load_verify_locations( + extract_zipped_paths(DEFAULT_CA_BUNDLE_PATH) + ) + + except ImportError: + pass + + return _SSL_CONTEXT + + def dict_to_sequence(d): """Returns an internal sequence dictionary update."""