diff --git a/.gitignore b/.gitignore
index c47828c..0bfc0eb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,3 +4,4 @@ dist
 *.egg-info
 docs/_build/*
 .tox/
+.idea/
diff --git a/CODEOWNERS b/CODEOWNERS
new file mode 100644
index 0000000..860d1e1
--- /dev/null
+++ b/CODEOWNERS
@@ -0,0 +1 @@
+.fabric/ @unacademy/sre
\ No newline at end of file
diff --git a/ratelimit/__init__.py b/ratelimit/__init__.py
index 9babbfb..5362284 100644
--- a/ratelimit/__init__.py
+++ b/ratelimit/__init__.py
@@ -1,4 +1,4 @@
-VERSION = (1, 1, 0)
+VERSION = (1, 3, 7)
 __version__ = '.'.join(map(str, VERSION))
 
 ALL = (None,)  # Sentinel value for all HTTP methods.
diff --git a/ratelimit/decorators.py b/ratelimit/decorators.py
index 663bce6..0904d69 100644
--- a/ratelimit/decorators.py
+++ b/ratelimit/decorators.py
@@ -12,7 +12,7 @@
 __all__ = ['ratelimit']
 
 
-def ratelimit(group=None, key=None, rate=None, method=ALL, block=False):
+def ratelimit(group=None, key=None, rate=None, method=ALL, block=False, reset=None):
     def decorator(fn):
         @wraps(fn)
         def _wrapped(*args, **kw):
@@ -24,9 +24,9 @@ def _wrapped(*args, **kw):
             request.limited = getattr(request, 'limited', False)
             ratelimited = is_ratelimited(request=request, group=group, fn=fn,
                                          key=key, rate=rate, method=method,
-                                         increment=True)
+                                         increment=True, reset=reset)
             if ratelimited and block:
-                raise Ratelimited()
+                raise Ratelimited("Too many requests", 429)
             return fn(*args, **kw)
         return _wrapped
     return decorator
diff --git a/ratelimit/exceptions.py b/ratelimit/exceptions.py
index f39a0f4..e169baa 100644
--- a/ratelimit/exceptions.py
+++ b/ratelimit/exceptions.py
@@ -3,3 +3,20 @@
 
 class Ratelimited(PermissionDenied):
     pass
+
+
+class CustomException(Exception):
+    def __init__(self, message):
+        super(CustomException, self).__init__(message)
+
+
+class DatastoreConnectionError(CustomException):
+    def __init__(self):
+        message = 'Could not establish connection to data store'
+        super(DatastoreConnectionError, self).__init__(message)
+
+
+class RateLimited(CustomException):
+    def __init__(self):
+        message = 'Rate limited'
+        super(RateLimited, self).__init__(message)
diff --git a/ratelimit/redis_rate_limit.py b/ratelimit/redis_rate_limit.py
new file mode 100644
index 0000000..643f1ca
--- /dev/null
+++ b/ratelimit/redis_rate_limit.py
@@ -0,0 +1,141 @@
+import time
+import redis
+
+from functools import wraps
+from django.conf import settings
+from rediscluster import StrictRedisCluster
+from .exceptions import RateLimited, DatastoreConnectionError
+
+__author__ = 'vikaschahal'
+
+
+class RateLimiter(object):
+    """Base class for Rate Limiting"""
+
+    def __init__(self, limit, window, connection, key):
+        """
+        :param limit: number of requests allowed
+        :type limit: int
+        :param window: window in seconds within which :limit requests are allowed
+        :type window: int
+        """
+        self._connection = connection
+        self._key = key
+        self._limit = limit
+        self._window = window
+
+    def is_allowed(self, log_current_request=True):
+        """
+        :param log_current_request: Consider the call for rate limiting
+        :type log_current_request: bool
+        :return: Whether a request is allowed or rate limited.
+        :rtype: bool
+        """
+        raise NotImplementedError
+
+    @property
+    def remaining_requests(self):
+        raise NotImplementedError
+
+    def limit(self, func):
+        """Decorator to check the rate limit."""
+
+        @wraps(func)
+        def decorated(*args, **kwargs):
+            if self.is_allowed():
+                return func(*args, **kwargs)
+            raise RateLimited
+
+        return decorated
+
+
+class RedisRateLimiterConnection(object):
+    def __init__(self, host=None, port=None, db=0, connection=None):
+        self.connection = None
+        if host:
+            if settings.REDIS_HOST_INTERNAL_NEW_IS_CLUSTER:
+                connection = StrictRedisCluster(startup_nodes=[{'host': host, 'port': 6379}], skip_full_coverage_check=True)
+            else:
+                connection = redis.StrictRedis(host, port, db)
+            if not connection.ping():
+                raise DatastoreConnectionError
+            self.connection = connection
+        elif connection:
+            if not connection.ping():
+                raise DatastoreConnectionError
+            self.connection = connection
+        else:
+            raise DatastoreConnectionError
+
+
+class RedisRateLimiter(RateLimiter):
+    def __init__(self, limit, window, connection, key):
+        super(RedisRateLimiter, self).__init__(limit, window, connection, key)
+        self._pipeline = self._connection.connection.pipeline()
+
+    def _increment_request(self):
+        key_value = int(time.time()) + self._window
+        self._pipeline.zadd(
+            self._key, key_value, key_value
+        )
+        self._pipeline.expire(self._key, self._window)  # set key expiry
+        self._pipeline.execute()
+
+    def is_allowed(self, log_current_request=True):
+        if log_current_request:
+            self._increment_request()
+        current_time = time.time()
+        self._pipeline.zremrangebyscore(self._key, '-inf', current_time)
+        self._pipeline.zcount(self._key, '-inf', '+inf')
+        result = self._pipeline.execute()
+        return result[-1] <= self._limit
+
+    def count(self, log_current_request=True):
+        if log_current_request:
+            self._increment_request()
+        current_time = time.time()
+        self._pipeline.zremrangebyscore(self._key, '-inf', current_time)
+        self._pipeline.zcount(self._key, '-inf', '+inf')
+        result = self._pipeline.execute()
+        return result[-1]
+
+
+class IpRateLimiter(RateLimiter):
+    def __init__(self, limit, window, connection, key):
+        super(IpRateLimiter, self).__init__(limit, window, connection, key)
+        self._pipeline = self._connection.connection.pipeline()
+
+    def add(self, value):
+        self._pipeline.exists(self._key)
+        results = self._pipeline.execute()
+        self._pipeline.sadd(self._key, value)
+        if not results[0]:
+            self._pipeline.expire(self._key, self._window)
+        self._pipeline.execute()
+
+    def count(self):
+        self._pipeline.scard(self._key)
+        result = self._pipeline.execute()
+        return result[0]
+
+    def delete(self):
+        self._pipeline.delete(self._key)
+        self._pipeline.execute()
+
+    def is_allowed(self):
+        self._pipeline.exists(self._key)
+        self._pipeline.scard(self._key)
+        results = self._pipeline.execute()
+        if not results[0]:
+            return True
+        return results[1] < self._limit
+
+    @property
+    def remaining_requests(self):
+        return self._limit - self.count()
+
+
+if settings.REDIS_HOST_INTERNAL_NEW_IS_CLUSTER:
+    redis_connection = RedisRateLimiterConnection(host=settings.REDIS_HOST_INTERNAL_NEW)
+else:
+    redis_connection = RedisRateLimiterConnection(host=settings.REDIS_HOST_INTERNAL, port=6379, db=0)
\ No newline at end of file
diff --git a/ratelimit/utils.py b/ratelimit/utils.py
index 285e999..9d1ed90 100644
--- a/ratelimit/utils.py
+++ b/ratelimit/utils.py
@@ -1,4 +1,5 @@
 import hashlib
+import json
 import re
 import time
 import zlib
@@ -8,10 +9,12 @@
 from django.core.cache import caches
 from django.core.exceptions import ImproperlyConfigured
 
-from ratelimit import ALL, UNSAFE
+from redis_rate_limit import redis_connection, RedisRateLimiter, IpRateLimiter
+from ratelimit import ALL, UNSAFE
 
 
-__all__ = ['is_ratelimited']
+__all__ = ['is_ratelimited', 'block_ip', 'is_request_allowed', 'get_custom_ip_from_request',
+           'get_region_code_from_request', 'is_request_from_region_allowed', 'block_region']
 
 _PERIODS = {
     's': 1,
@@ -24,14 +27,21 @@
 EXPIRATION_FUDGE = 5
 
 
+def get_ip(request):
+    custom_ip = get_custom_ip_from_request(request)
+    if custom_ip:
+        return custom_ip
+    return request.META['HTTP_X_FORWARDED_FOR']
+
+
 def user_or_ip(request):
     if is_authenticated(request.user):
         return str(request.user.pk)
-    return request.META['REMOTE_ADDR']
+    return get_ip(request)
 
 
 _SIMPLE_KEYS = {
-    'ip': lambda r: r.META['REMOTE_ADDR'],
+    'ip': get_ip,
     'user': lambda r: str(r.user.pk),
     'user_or_ip': user_or_ip,
 }
@@ -58,6 +68,7 @@ def _method_match(request, method=ALL):
 
 
 rate_re = re.compile('([\d]+)/([\d]*)([smhd])?')
+private_ip = re.compile("^172\.(1[6-9]|2[0-9]|3[0-1])\.[0-9]{1,3}\.[0-9]{1,3}$")
 
 
 def _split_rate(rate):
@@ -85,10 +96,13 @@ def _get_window(value, period):
     return w
 
 
-def _make_cache_key(group, rate, value, methods):
+def _make_cache_key(group, rate, value, methods, sliding_window=False):
     count, period = _split_rate(rate)
     safe_rate = '%d/%ds' % (count, period)
-    window = _get_window(value, period)
+    if sliding_window:
+        window = ''
+    else:
+        window = _get_window(value, period)
     parts = [group + safe_rate, value, str(window)]
     if methods is not None:
         if methods == ALL:
@@ -100,8 +114,52 @@ def _make_cache_key(group, rate, value, methods):
     return prefix + hashlib.md5(u''.join(parts).encode('utf-8')).hexdigest()
 
 
+def _get_value_from_key(request, group=None, key=None):
+    if not key:
+        raise ImproperlyConfigured('Ratelimit key must be specified')
+    if callable(key):
+        value = key(group, request)
+    elif key in _SIMPLE_KEYS:
+        print(_SIMPLE_KEYS[key](request))
+        value = _SIMPLE_KEYS[key](request)
+    elif ':' in key:
+        accessor, k = key.split(':', 1)
+        if accessor not in _ACCESSOR_KEYS:
+            raise ImproperlyConfigured('Unknown ratelimit key: %s' % key)
+        value = _ACCESSOR_KEYS[accessor](request, k)
+    elif '.' in key:
+        mod, attr = key.rsplit('.', 1)
+        keyfn = getattr(import_module(mod), attr)
+        value = keyfn(group, request)
+    else:
+        raise ImproperlyConfigured(
+            'Could not understand ratelimit key: %s' % key)
+    return value
+
+
+def _get_usage_count(request, group=None, fn=None, key=None, rate=None,
+                     method=ALL, increment=False, reset=None, sliding_window=True):
+    value = _get_value_from_key(request, group=group, key=key)
+    limit, period = _split_rate(rate)
+    cache_key = _make_cache_key(group, rate, value, method, sliding_window)
+    redis_limiter = RedisRateLimiter(limit=limit, window=period, connection=redis_connection, key=cache_key)
+    count = redis_limiter.count()
+    return {'count': count, 'limit': limit}
+
+
+def get_offence_count(request, group=None, max_offence_rate=None,
+                      key=None, method=ALL, sliding_window=True, count_current_request=False):
+    value = _get_value_from_key(request, group=group, key=key)
+    limit, period = _split_rate(max_offence_rate)
+    cache_key = _make_cache_key(group, max_offence_rate, value, method, sliding_window)
+    redis_limiter = RedisRateLimiter(limit=limit, window=period, connection=redis_connection, key=cache_key)
+    count = redis_limiter.count(log_current_request=count_current_request)
+    return {'count': count, 'limit': limit}
+
+
 def is_ratelimited(request, group=None, fn=None, key=None, rate=None,
-                   method=ALL, increment=False):
+                   method=ALL, increment=False, reset=None, sliding_window=True,
+                   max_offence_rate=None):
     if group is None:
         if hasattr(fn, '__self__'):
             parts = fn.__module__, fn.__self__.__class__.__name__, fn.__name__
@@ -124,7 +182,21 @@
     if rate is None:
         request.limited = old_limited
        return False
-    usage = get_usage_count(request, group, fn, key, rate, method, increment)
+
+    if max_offence_rate is not None:
+        offence_report = get_offence_count(request, group, max_offence_rate, key, method, sliding_window)
+        offence_count = offence_report.get('count')
+        if offence_count is not None:
+            max_offence_count = offence_report.get('limit')
+            if offence_count >= max_offence_count:
+                get_offence_count(request, group, max_offence_rate, key, method,
+                                  sliding_window, count_current_request=True)
+                return True
+
+    if sliding_window:
+        usage = _get_usage_count(request, group, fn, key, rate, method, increment, reset, sliding_window)
+    else:
+        usage = get_usage_count(request, group, fn, key, rate, method, increment, reset)
 
     fail_open = getattr(settings, 'RATELIMIT_FAIL_OPEN', False)
 
@@ -137,11 +209,16 @@
 
     if increment:
         request.limited = old_limited or limited
+
+    if max_offence_rate is not None and limited:
+        get_offence_count(request, group, max_offence_rate, key, method,
+                          sliding_window, count_current_request=True)
+
     return limited
 
 
 def get_usage_count(request, group=None, fn=None, key=None, rate=None,
-                    method=ALL, increment=False):
+                    method=ALL, increment=False, reset=None):
     if not key:
         raise ImproperlyConfigured('Ratelimit key must be specified')
     limit, period = _split_rate(rate)
@@ -151,6 +228,7 @@ def get_usage_count(request, group=None, fn=None, key=None, rate=None,
     if callable(key):
         value = key(group, request)
     elif key in _SIMPLE_KEYS:
+        print(_SIMPLE_KEYS[key](request))
         value = _SIMPLE_KEYS[key](request)
     elif ':' in key:
         accessor, k = key.split(':', 1)
@@ -166,6 +244,12 @@ def get_usage_count(request, group=None, fn=None, key=None, rate=None,
             'Could not understand ratelimit key: %s' % key)
 
     cache_key = _make_cache_key(group, rate, value, method)
+
+    if reset and callable(reset):
+        should_reset = reset(request)
+        if should_reset:
+            cache.delete(cache_key)
+
     time_left = _get_window(value, period) - int(time.time())
     initial_value = 1 if increment else 0
     added = cache.add(cache_key, initial_value, period + EXPIRATION_FUDGE)
@@ -192,3 +276,73 @@ def is_authenticated(user):
         return user.is_authenticated()
     else:
         return user.is_authenticated
+
+
+def get_cache_key_for_ip_blocking(request, func):
+    ip = get_custom_ip_from_request(request)
+    name = func.__name__
+    url = request.path
+    keys = [ip, name, url]
+    return 'ip_rl_v2:' + hashlib.md5(u''.join(keys).encode('utf-8')).hexdigest()
+
+
+def is_request_allowed(request, func, rate):
+    limit, period = _split_rate(rate)
+    cache_key = get_cache_key_for_ip_blocking(request, func)
+    redis_set = IpRateLimiter(limit=limit, window=period, connection=redis_connection, key=cache_key)
+    return redis_set.is_allowed()
+
+
+def block_ip(request, func, function_to_get_attributes, rate):
+    limit, period = _split_rate(rate)
+    cache_key = get_cache_key_for_ip_blocking(request, func)
+    redis_set = IpRateLimiter(limit=limit, window=period, connection=redis_connection, key=cache_key)
+    hash_value = hashlib.md5(json.dumps(function_to_get_attributes(request))).hexdigest()
+    redis_set.add(hash_value)
+
+
+def get_right_most_public_ip(ips):
+    index = len(ips)
+    while index > 0:
+        ip = ips[index - 1]
+        if not private_ip.match(ip):
+            return ip
+        index -= 1
+    return None
+
+
+def get_custom_ip_from_request(request):
+    ips = request.META.get('HTTP_X_' + settings.PROXY_PASS_CUSTOM_HEADER_NAME.upper() + "_CLIENT_IP", None)
+    if ips is None:
+        return None
+    ips = ips.split(",")
+    if len(ips) == 0:
+        return None
+    return get_right_most_public_ip(ips)
+
+
+def get_cache_key_for_region_blocking(request, func):
+    country_code = get_region_code_from_request(request)
+    name = func.__name__
+    url = request.path
+    keys = [country_code, name, url]
+    return 'rl_region:' + hashlib.md5(u''.join(keys).encode('utf-8')).hexdigest()
+
+
+def get_region_code_from_request(request):
+    return request.META.get('HTTP_X_' + settings.PROXY_PASS_CUSTOM_HEADER_NAME.upper() + "_COUNTRY_CODE", None)
+
+
+def is_request_from_region_allowed(request, func, rate):
+    limit, period = _split_rate(rate)
+    cache_key = get_cache_key_for_region_blocking(request, func)
+    redis_set = IpRateLimiter(limit=limit, window=period, connection=redis_connection, key=cache_key)
+    return redis_set.is_allowed()
+
+
+def block_region(request, func, function_to_get_attributes, rate):
+    limit, period = _split_rate(rate)
+    cache_key = get_cache_key_for_region_blocking(request, func)
+    redis_set = IpRateLimiter(limit=limit, window=period, connection=redis_connection, key=cache_key)
+    hash_value = hashlib.md5(json.dumps(function_to_get_attributes(request))).hexdigest()
+    redis_set.add(hash_value)
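
Below is a usage sketch (not part of the diff) of how the extended decorator and the new IP-blocking helpers from ratelimit.utils might be wired into a Django view. It assumes the settings this change reads (REDIS_HOST_INTERNAL / REDIS_HOST_INTERNAL_NEW, REDIS_HOST_INTERNAL_NEW_IS_CLUSTER, PROXY_PASS_CUSTOM_HEADER_NAME) are configured and that the upstream proxy sets the forwarded-IP headers; the view, the abuse heuristic, and the attribute function are hypothetical.

    # Usage sketch only -- not part of the diff.
    from django.http import HttpResponse, JsonResponse

    from ratelimit.decorators import ratelimit
    from ratelimit.utils import block_ip, is_request_allowed


    def _request_attributes(request):
        # Hypothetical: attributes hashed and stored in the per-IP Redis set by block_ip().
        return {'path': request.path, 'ua': request.META.get('HTTP_USER_AGENT', '')}


    def _looks_abusive(request):
        # Hypothetical stand-in for a real abuse heuristic.
        return request.method == 'POST' and not request.POST.get('comment')


    # block=True now raises Ratelimited("Too many requests", 429) when the rate is exceeded.
    @ratelimit(key='user_or_ip', rate='10/m', block=True)
    def submit_comment(request):
        # The per-IP offence set allows requests until it holds `limit` distinct entries.
        if not is_request_allowed(request, submit_comment, '5/h'):
            return JsonResponse({'detail': 'Too many requests'}, status=429)
        if _looks_abusive(request):
            block_ip(request, submit_comment, _request_attributes, '5/h')
        return HttpResponse('ok')

The two helpers are meant to be used together: block_ip records a hash of the request's attributes in a Redis set keyed by client IP, view name, and path, and is_request_allowed starts rejecting only once that set reaches the configured limit.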