Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ipv6 support to should_bypass_proxies #5953

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 70 additions & 13 deletions src/requests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,18 +679,46 @@ def requote_uri(uri):
return quote(uri, safe=safe_without_percent)


def _get_mask_bits(mask, totalbits=32):
"""Converts a mask from /xx format to a int
to be used as a mask for IP's in int format

Example: if mask is 24 function returns 0xFFFFFF00
if mask is 24 and totalbits=128 function
returns 0xFFFFFF00000000000000000000000000

:rtype: int
"""
bits = ((1 << mask) - 1) << (totalbits - mask)
return bits


def address_in_network(ip, net):
"""This function allows you to check if an IP belongs to a network subnet

Example: returns True if ip = 192.168.1.1 and net = 192.168.1.0/24
returns False if ip = 192.168.1.1 and net = 192.168.100.0/24
returns True if ip = 1:2:3:4::1 and net = 1:2:3:4::/64

:rtype: bool
"""
ipaddr = struct.unpack("=L", socket.inet_aton(ip))[0]
netaddr, bits = net.split("/")
netmask = struct.unpack("=L", socket.inet_aton(dotted_netmask(int(bits))))[0]
network = struct.unpack("=L", socket.inet_aton(netaddr))[0] & netmask
if is_ipv4_address(ip) and is_ipv4_address(netaddr):
ipaddr = struct.unpack(">L", socket.inet_aton(ip))[0]
netmask = _get_mask_bits(int(bits))
network = struct.unpack(">L", socket.inet_aton(netaddr))[0]
elif is_ipv6_address(ip) and is_ipv6_address(netaddr):
ipaddr_msb, ipaddr_lsb = struct.unpack(
">QQ", socket.inet_pton(socket.AF_INET6, ip)
)
ipaddr = (ipaddr_msb << 64) ^ ipaddr_lsb
netmask = _get_mask_bits(int(bits), 128)
network_msb, network_lsb = struct.unpack(
">QQ", socket.inet_pton(socket.AF_INET6, netaddr)
)
network = (network_msb << 64) ^ network_lsb
else:
return False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is returning False the correct thing to do when this function is called with IPv4 IP and IPv6 network or vice versa? What about raising an exception here?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is fine as is (i.e. returning false), the function is used to test if a IP address is a member of any of the CIDR's listed in no_proxy, having a no_proxy with mixed IPv4 and IPv6 CIDR's would be valid, I don't think I'd expect an exception to the raised in this instance.

return (ipaddr & netmask) == (network & netmask)


Expand All @@ -710,30 +738,59 @@ def is_ipv4_address(string_ip):
:rtype: bool
"""
try:
socket.inet_aton(string_ip)
socket.inet_pton(socket.AF_INET, string_ip)
except OSError:
return False
return True


def is_ipv6_address(string_ip):
"""
:rtype: bool
"""
try:
socket.inet_pton(socket.AF_INET6, string_ip)
except OSError:
return False
return True


def compare_ips(a, b):
"""
Compare 2 IP's, uses socket.inet_pton to normalize IPv6 IPs

:rtype: bool
"""
if a == b:
return True
try:
return socket.inet_pton(socket.AF_INET6, a) == socket.inet_pton(
socket.AF_INET6, b
)
except OSError:
return False


def is_valid_cidr(string_network):
"""
Very simple check of the cidr format in no_proxy variable.

:rtype: bool
"""
if string_network.count("/") == 1:
address, mask = string_network.split("/")
try:
mask = int(string_network.split("/")[1])
mask = int(mask)
except ValueError:
return False

if mask < 1 or mask > 32:
return False

try:
socket.inet_aton(string_network.split("/")[0])
except OSError:
if is_ipv4_address(address):
if mask < 1 or mask > 32:
return False
elif is_ipv6_address(address):
if mask < 1 or mask > 128:
return False
else:
return False
else:
return False
Expand Down Expand Up @@ -790,12 +847,12 @@ def get_proxy(key):
# the end of the hostname, both with and without the port.
no_proxy = (host for host in no_proxy.replace(" ", "").split(",") if host)

if is_ipv4_address(parsed.hostname):
if is_ipv4_address(parsed.hostname) or is_ipv6_address(parsed.hostname):
for proxy_ip in no_proxy:
if is_valid_cidr(proxy_ip):
if address_in_network(parsed.hostname, proxy_ip):
return True
elif parsed.hostname == proxy_ip:
elif compare_ips(parsed.hostname, proxy_ip):
# If no_proxy ip was defined in plain IP notation instead of cidr notation &
# matches the IP of the index
return True
Expand Down
66 changes: 62 additions & 4 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
from requests.cookies import RequestsCookieJar
from requests.structures import CaseInsensitiveDict
from requests.utils import (
_get_mask_bits,
_parse_content_type_header,
add_dict_to_cookiejar,
address_in_network,
compare_ips,
dotted_netmask,
extract_zipped_paths,
get_auth_from_url,
Expand Down Expand Up @@ -263,8 +265,15 @@ def test_invalid(self, value):


class TestIsValidCIDR:
def test_valid(self):
assert is_valid_cidr("192.168.1.0/24")
@pytest.mark.parametrize(
"value",
(
"192.168.1.0/24",
"1:2:3:4::/64",
),
)
def test_valid(self, value):
assert is_valid_cidr(value)

@pytest.mark.parametrize(
"value",
Expand All @@ -274,6 +283,11 @@ def test_valid(self):
"192.168.1.0/128",
"192.168.1.0/-1",
"192.168.1.999/24",
"1:2:3:4::1",
"1:2:3:4::/a",
"1:2:3:4::0/321",
"1:2:3:4::/-1",
"1:2:3:4::12211/64",
),
)
def test_invalid(self, value):
Expand All @@ -287,6 +301,12 @@ def test_valid(self):
def test_invalid(self):
assert not address_in_network("172.16.0.1", "192.168.1.0/24")

def test_valid_v6(self):
assert address_in_network("1:2:3:4::1111", "1:2:3:4::/64")

def test_invalid_v6(self):
assert not address_in_network("1:2:3:4:1111", "1:2:3:4::/124")


class TestGuessFilename:
@pytest.mark.parametrize(
Expand Down Expand Up @@ -722,6 +742,11 @@ def test_urldefragauth(url, expected):
("http://172.16.1.12:5000/", False),
("http://google.com:5000/v1.0/", False),
("file:///some/path/on/disk", True),
("http://[1:2:3:4:5:6:7:8]:5000/", True),
("http://[1:2:3:4::1]/", True),
("http://[1:2:3:9::1]/", True),
("http://[1:2:3:9:0:0:0:1]/", True),
("http://[1:2:3:9::2]/", False),
),
)
def test_should_bypass_proxies(url, expected, monkeypatch):
Expand All @@ -730,11 +755,11 @@ def test_should_bypass_proxies(url, expected, monkeypatch):
"""
monkeypatch.setenv(
"no_proxy",
"192.168.0.0/24,127.0.0.1,localhost.localdomain,172.16.1.1, google.com:6000",
"192.168.0.0/24,127.0.0.1,localhost.localdomain,1:2:3:4::/64,1:2:3:9::1,172.16.1.1, google.com:6000",
)
monkeypatch.setenv(
"NO_PROXY",
"192.168.0.0/24,127.0.0.1,localhost.localdomain,172.16.1.1, google.com:6000",
"192.168.0.0/24,127.0.0.1,localhost.localdomain,1:2:3:4::/64,1:2:3:9::1,172.16.1.1, google.com:6000",
)
assert should_bypass_proxies(url, no_proxy=None) == expected

Expand Down Expand Up @@ -956,3 +981,36 @@ def QueryValueEx(key, value_name):
monkeypatch.setattr(winreg, "OpenKey", OpenKey)
monkeypatch.setattr(winreg, "QueryValueEx", QueryValueEx)
assert should_bypass_proxies("http://example.com/", None) is False


@pytest.mark.parametrize(
"mask, totalbits, maskbits",
(
(24, None, 0xFFFFFF00),
(31, None, 0xFFFFFFFE),
(0, None, 0x0),
(4, 4, 0xF),
(24, 128, 0xFFFFFF00000000000000000000000000),
),
)
def test__get_mask_bits(mask, totalbits, maskbits):
args = {"mask": mask}
if totalbits:
args["totalbits"] = totalbits
assert _get_mask_bits(**args) == maskbits


@pytest.mark.parametrize(
"a, b, expected",
(
("1.2.3.4", "1.2.3.4", True),
("1.2.3.4", "2.2.3.4", False),
("1::4", "1.2.3.4", False),
("1::4", "1::4", True),
("1::4", "1:0:0:0:0:0:0:4", True),
("1::4", "1:0:0:0:0:0::4", True),
("1::4", "1:0:0:0:0:0:1:4", False),
),
)
def test_compare_ips(a, b, expected):
assert compare_ips(a, b) == expected
Loading