diff --git a/src/requests/utils.py b/src/requests/utils.py index ae6c42f6cb..0363698659 100644 --- a/src/requests/utils.py +++ b/src/requests/utils.py @@ -679,18 +679,46 @@ def requote_uri(uri): return quote(uri, safe=safe_without_percent) +def _get_mask_bits(mask, totalbits=32): + """Converts a mask from /xx format to a int + to be used as a mask for IP's in int format + + Example: if mask is 24 function returns 0xFFFFFF00 + if mask is 24 and totalbits=128 function + returns 0xFFFFFF00000000000000000000000000 + + :rtype: int + """ + bits = ((1 << mask) - 1) << (totalbits - mask) + return bits + + def address_in_network(ip, net): """This function allows you to check if an IP belongs to a network subnet Example: returns True if ip = 192.168.1.1 and net = 192.168.1.0/24 returns False if ip = 192.168.1.1 and net = 192.168.100.0/24 + returns True if ip = 1:2:3:4::1 and net = 1:2:3:4::/64 :rtype: bool """ - ipaddr = struct.unpack("=L", socket.inet_aton(ip))[0] netaddr, bits = net.split("/") - netmask = struct.unpack("=L", socket.inet_aton(dotted_netmask(int(bits))))[0] - network = struct.unpack("=L", socket.inet_aton(netaddr))[0] & netmask + if is_ipv4_address(ip) and is_ipv4_address(netaddr): + ipaddr = struct.unpack(">L", socket.inet_aton(ip))[0] + netmask = _get_mask_bits(int(bits)) + network = struct.unpack(">L", socket.inet_aton(netaddr))[0] + elif is_ipv6_address(ip) and is_ipv6_address(netaddr): + ipaddr_msb, ipaddr_lsb = struct.unpack( + ">QQ", socket.inet_pton(socket.AF_INET6, ip) + ) + ipaddr = (ipaddr_msb << 64) ^ ipaddr_lsb + netmask = _get_mask_bits(int(bits), 128) + network_msb, network_lsb = struct.unpack( + ">QQ", socket.inet_pton(socket.AF_INET6, netaddr) + ) + network = (network_msb << 64) ^ network_lsb + else: + return False return (ipaddr & netmask) == (network & netmask) @@ -710,12 +738,39 @@ def is_ipv4_address(string_ip): :rtype: bool """ try: - socket.inet_aton(string_ip) + socket.inet_pton(socket.AF_INET, string_ip) + except OSError: + return False + return True + + +def is_ipv6_address(string_ip): + """ + :rtype: bool + """ + try: + socket.inet_pton(socket.AF_INET6, string_ip) except OSError: return False return True +def compare_ips(a, b): + """ + Compare 2 IP's, uses socket.inet_pton to normalize IPv6 IPs + + :rtype: bool + """ + if a == b: + return True + try: + return socket.inet_pton(socket.AF_INET6, a) == socket.inet_pton( + socket.AF_INET6, b + ) + except OSError: + return False + + def is_valid_cidr(string_network): """ Very simple check of the cidr format in no_proxy variable. @@ -723,17 +778,19 @@ def is_valid_cidr(string_network): :rtype: bool """ if string_network.count("/") == 1: + address, mask = string_network.split("/") try: - mask = int(string_network.split("/")[1]) + mask = int(mask) except ValueError: return False - if mask < 1 or mask > 32: - return False - - try: - socket.inet_aton(string_network.split("/")[0]) - except OSError: + if is_ipv4_address(address): + if mask < 1 or mask > 32: + return False + elif is_ipv6_address(address): + if mask < 1 or mask > 128: + return False + else: return False else: return False @@ -790,12 +847,12 @@ def get_proxy(key): # the end of the hostname, both with and without the port. no_proxy = (host for host in no_proxy.replace(" ", "").split(",") if host) - if is_ipv4_address(parsed.hostname): + if is_ipv4_address(parsed.hostname) or is_ipv6_address(parsed.hostname): for proxy_ip in no_proxy: if is_valid_cidr(proxy_ip): if address_in_network(parsed.hostname, proxy_ip): return True - elif parsed.hostname == proxy_ip: + elif compare_ips(parsed.hostname, proxy_ip): # If no_proxy ip was defined in plain IP notation instead of cidr notation & # matches the IP of the index return True diff --git a/tests/test_utils.py b/tests/test_utils.py index 5e9b56ea64..d7047d2c0d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -14,9 +14,11 @@ from requests.cookies import RequestsCookieJar from requests.structures import CaseInsensitiveDict from requests.utils import ( + _get_mask_bits, _parse_content_type_header, add_dict_to_cookiejar, address_in_network, + compare_ips, dotted_netmask, extract_zipped_paths, get_auth_from_url, @@ -263,8 +265,15 @@ def test_invalid(self, value): class TestIsValidCIDR: - def test_valid(self): - assert is_valid_cidr("192.168.1.0/24") + @pytest.mark.parametrize( + "value", + ( + "192.168.1.0/24", + "1:2:3:4::/64", + ), + ) + def test_valid(self, value): + assert is_valid_cidr(value) @pytest.mark.parametrize( "value", @@ -274,6 +283,11 @@ def test_valid(self): "192.168.1.0/128", "192.168.1.0/-1", "192.168.1.999/24", + "1:2:3:4::1", + "1:2:3:4::/a", + "1:2:3:4::0/321", + "1:2:3:4::/-1", + "1:2:3:4::12211/64", ), ) def test_invalid(self, value): @@ -287,6 +301,12 @@ def test_valid(self): def test_invalid(self): assert not address_in_network("172.16.0.1", "192.168.1.0/24") + def test_valid_v6(self): + assert address_in_network("1:2:3:4::1111", "1:2:3:4::/64") + + def test_invalid_v6(self): + assert not address_in_network("1:2:3:4:1111", "1:2:3:4::/124") + class TestGuessFilename: @pytest.mark.parametrize( @@ -722,6 +742,11 @@ def test_urldefragauth(url, expected): ("http://172.16.1.12:5000/", False), ("http://google.com:5000/v1.0/", False), ("file:///some/path/on/disk", True), + ("http://[1:2:3:4:5:6:7:8]:5000/", True), + ("http://[1:2:3:4::1]/", True), + ("http://[1:2:3:9::1]/", True), + ("http://[1:2:3:9:0:0:0:1]/", True), + ("http://[1:2:3:9::2]/", False), ), ) def test_should_bypass_proxies(url, expected, monkeypatch): @@ -730,11 +755,11 @@ def test_should_bypass_proxies(url, expected, monkeypatch): """ monkeypatch.setenv( "no_proxy", - "192.168.0.0/24,127.0.0.1,localhost.localdomain,172.16.1.1, google.com:6000", + "192.168.0.0/24,127.0.0.1,localhost.localdomain,1:2:3:4::/64,1:2:3:9::1,172.16.1.1, google.com:6000", ) monkeypatch.setenv( "NO_PROXY", - "192.168.0.0/24,127.0.0.1,localhost.localdomain,172.16.1.1, google.com:6000", + "192.168.0.0/24,127.0.0.1,localhost.localdomain,1:2:3:4::/64,1:2:3:9::1,172.16.1.1, google.com:6000", ) assert should_bypass_proxies(url, no_proxy=None) == expected @@ -956,3 +981,35 @@ def QueryValueEx(key, value_name): monkeypatch.setattr(winreg, "OpenKey", OpenKey) monkeypatch.setattr(winreg, "QueryValueEx", QueryValueEx) assert should_bypass_proxies("http://example.com/", None) is False + +@pytest.mark.parametrize( + "mask, totalbits, maskbits", + ( + (24, None, 0xFFFFFF00), + (31, None, 0xFFFFFFFE), + (0, None, 0x0), + (4, 4, 0xF), + (24, 128, 0xFFFFFF00000000000000000000000000), + ), +) +def test__get_mask_bits(mask, totalbits, maskbits): + args = {"mask": mask} + if totalbits: + args["totalbits"] = totalbits + assert _get_mask_bits(**args) == maskbits + + +@pytest.mark.parametrize( + "a, b, expected", + ( + ("1.2.3.4", "1.2.3.4", True), + ("1.2.3.4", "2.2.3.4", False), + ("1::4", "1.2.3.4", False), + ("1::4", "1::4", True), + ("1::4", "1:0:0:0:0:0:0:4", True), + ("1::4", "1:0:0:0:0:0::4", True), + ("1::4", "1:0:0:0:0:0:1:4", False), + ), +) +def test_compare_ips(a, b, expected): + assert compare_ips(a, b) == expected