diff --git a/README.rst b/README.rst index bab78c4..d2e8917 100644 --- a/README.rst +++ b/README.rst @@ -64,6 +64,17 @@ Example: 100 requests per hour except TooManyRequests: return '429 Too Many Requests' +Example: you can also pass an optional list of ignored_clients to bypass Rate Limit + +.. code-block:: python + + from redis_rate_limit import RateLimit, TooManyRequests + try: + with RateLimit(resource='users_list', client='192.168.0.10', max_requests=100, ignored_clients=['192.168.0.10'], expire=3600): + return '200 OK' + except TooManyRequests: + return '429 Too Many Requests' + Example: you can also setup a factory to use it later .. code-block:: python diff --git a/redis_rate_limit/__init__.py b/redis_rate_limit/__init__.py index 147db8f..4becb75 100644 --- a/redis_rate_limit/__init__.py +++ b/redis_rate_limit/__init__.py @@ -43,7 +43,7 @@ class RateLimit(object): This class offers an abstraction of a Rate Limit algorithm implemented on top of Redis >= 2.6.0. """ - def __init__(self, resource, client, max_requests, expire=None, redis_pool=REDIS_POOL): + def __init__(self, resource, client, max_requests, ignored_clients=None, expire=None, redis_pool=REDIS_POOL): """ Class initialization method checks if the Rate Limit algorithm is actually supported by the installed Redis version and sets some @@ -54,10 +54,14 @@ def __init__(self, resource, client, max_requests, expire=None, redis_pool=REDIS :param resource: resource identifier string (i.e. ‘user_pictures’) :param client: client identifier string (i.e. ‘192.168.0.10’) :param max_requests: integer (i.e. ‘10’) + :param ignored_clients: list of ip addresses (i.e. ['127.0.0.1']) :param expire: seconds to wait before resetting counters (i.e. ‘60’) :param redis_pool: instance of redis.ConnectionPool. Default: ConnectionPool(host='127.0.0.1', port=6379, db=0) """ + self.client = client + self.ignored_clients = ignored_clients + self._redis = Redis(connection_pool=redis_pool) if not self._is_rate_limit_supported(): raise RedisVersionNotSupported() @@ -129,6 +133,9 @@ def increment_usage(self, increment_by=1): :return: integer: current usage """ + if self.ignored_clients and self.client in self.ignored_clients: + return 0 + if increment_by > self._max_requests: raise ValueError('increment_by {increment_by} overflows ' 'max_requests of {max_requests}' @@ -172,17 +179,19 @@ def _reset(self): class RateLimiter(object): - def __init__(self, resource, max_requests, expire=None, redis_pool=REDIS_POOL): + def __init__(self, resource, max_requests, ignored_clients=None, expire=None, redis_pool=REDIS_POOL): """ Rate limit factory. Checks if RateLimit is supported when limit is called. :param resource: resource identifier string (i.e. ‘user_pictures’) :param max_requests: integer (i.e. ‘10’) + :param ignored_clients: list of ip addresses (i.e. ['127.0.0.1']) :param expire: seconds to wait before resetting counters (i.e. ‘60’) :param redis_pool: instance of redis.ConnectionPool. Default: ConnectionPool(host='127.0.0.1', port=6379, db=0) """ self.resource = resource self.max_requests = max_requests + self.ignored_clients = ignored_clients self.expire = expire self.redis_pool = redis_pool @@ -194,6 +203,7 @@ def limit(self, client): resource=self.resource, client=client, max_requests=self.max_requests, + ignored_clients=self.ignored_clients, expire=self.expire, redis_pool=self.redis_pool, ) diff --git a/tests/rate_limit_test.py b/tests/rate_limit_test.py index 58c5fe6..f4cde87 100644 --- a/tests/rate_limit_test.py +++ b/tests/rate_limit_test.py @@ -41,6 +41,19 @@ def test_limit_10_max_request(self): self.assertEqual(self.rate_limit.get_usage(), 11) self.assertEqual(self.rate_limit.has_been_reached(), True) + def test_ignored_clients(self): + """ + Should not increment counter if client is part of ignored_clients list. + """ + self.rate_limit = RateLimit(resource='test', client='localhost', ignored_clients=['localhost'], + max_requests=10, expire=2) + self.assertEqual(self.rate_limit.get_usage(), 0) + self.assertEqual(self.rate_limit.has_been_reached(), False) + + self._make_10_requests() + self.assertEqual(self.rate_limit.get_usage(), 0) + self.assertEqual(self.rate_limit.has_been_reached(), False) + def test_expire(self): """ Should not raise TooManyRequests Exception when trying to increment for