Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support for expire in milliseconds #31

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,17 @@ Example: 600 requests per minute
except TooManyRequests:
return '429 Too Many Requests'

Example: 1 request per 50 milliseconds

.. code-block:: python

from redis_rate_limit import RateLimit, TooManyRequests, TimeUnit
try:
with RateLimit(resource='users_list', client='192.168.0.10', max_requests=1, expire=50, time_unit=TimeUnit.MILLISECOND):
return '200 OK'
except TooManyRequests:
return '429 Too Many Requests'

Example: 100 requests per hour

.. code-block:: python
Expand Down
59 changes: 45 additions & 14 deletions redis_rate_limit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,34 @@
from redis.exceptions import NoScriptError
from redis import Redis, ConnectionPool

from enum import Enum


class TimeUnit(Enum):
SECOND = "second"
MILLISECOND = "millisecond"


# Adapted from http://redis.io/commands/incr#pattern-rate-limiter-2
INCREMENT_SCRIPT = b"""
local current
current = tonumber(redis.call("incrby", KEYS[1], ARGV[2]))
if current == tonumber(ARGV[2]) then
redis.call("expire", KEYS[1], ARGV[1])
redis.call("PEXPIRE", KEYS[1], ARGV[1])
end
return current
"""
INCREMENT_SCRIPT_HASH = sha1(INCREMENT_SCRIPT).hexdigest()

REDIS_POOL = ConnectionPool(host='127.0.0.1', port=6379, db=0)
REDIS_POOL = ConnectionPool(host="127.0.0.1", port=6379, db=0)


class RedisVersionNotSupported(Exception):
"""
Rate Limit depends on Redis’ commands EVALSHA and EVAL which are
only available since the version 2.6.0 of the database.
"""

pass


Expand All @@ -35,6 +43,7 @@ class TooManyRequests(Exception):
Occurs when the maximum number of requests is reached for a given resource
of an specific user.
"""

pass


Expand All @@ -43,7 +52,16 @@ class RateLimit(object):
This class offers an abstraction of a Rate Limit algorithm implemented on
top of Redis >= 2.6.0.
"""
def __init__(self, resource, client, max_requests, expire=None, redis_pool=REDIS_POOL):

def __init__(
self,
resource,
client,
max_requests,
expire=None,
redis_pool=REDIS_POOL,
time_unit: TimeUnit = TimeUnit.SECOND,
):
"""
Class initialization method checks if the Rate Limit algorithm is
actually supported by the installed Redis version and sets some
Expand All @@ -65,6 +83,7 @@ def __init__(self, resource, client, max_requests, expire=None, redis_pool=REDIS
self._rate_limit_key = "rate_limit:{0}_{1}".format(resource, client)
self._max_requests = max_requests
self._expire = expire or 1
self._expire = expire * 1000 if time_unit == TimeUnit.SECOND else expire

def __call__(self, func):
"""
Expand Down Expand Up @@ -130,21 +149,32 @@ def increment_usage(self, increment_by=1):
:return: integer: current usage
"""
if increment_by > self._max_requests:
raise ValueError('increment_by {increment_by} overflows '
'max_requests of {max_requests}'
.format(increment_by=increment_by,
max_requests=self._max_requests))
raise ValueError(
"increment_by {increment_by} overflows "
"max_requests of {max_requests}".format(
increment_by=increment_by, max_requests=self._max_requests
)
)
elif increment_by <= 0:
raise ValueError('{increment_by} is not a valid increment, '
'should be greater than or equal to zero.'
.format(increment_by=increment_by))
raise ValueError(
"{increment_by} is not a valid increment, "
"should be greater than or equal to zero.".format(
increment_by=increment_by
)
)

try:
current_usage = self._redis.evalsha(
INCREMENT_SCRIPT_HASH, 1, self._rate_limit_key, self._expire, increment_by)
INCREMENT_SCRIPT_HASH,
1,
self._rate_limit_key,
self._expire,
increment_by,
)
except NoScriptError:
current_usage = self._redis.eval(
INCREMENT_SCRIPT, 1, self._rate_limit_key, self._expire, increment_by)
INCREMENT_SCRIPT, 1, self._rate_limit_key, self._expire, increment_by
)

if int(current_usage) > self._max_requests:
raise TooManyRequests()
Expand All @@ -160,13 +190,14 @@ def _is_rate_limit_supported(self):
"""
redis_version = self._redis.info()['redis_version']
is_supported = Version(redis_version) >= Version('2.6.0')

return bool(is_supported)

def _reset(self):
"""
Deletes all keys that start with ‘rate_limit:’.
"""
matching_keys = self._redis.scan_iter(match='{0}*'.format('rate_limit:*'))
matching_keys = self._redis.scan_iter(match="{0}*".format("rate_limit:*"))
for rate_limit_key in matching_keys:
self._redis.delete(rate_limit_key)

Expand All @@ -180,7 +211,7 @@ def __init__(self, resource, max_requests, expire=None, redis_pool=REDIS_POOL):
:param expire: seconds to wait before resetting counters (i.e. ‘60’)
:param redis_pool: instance of redis.ConnectionPool.
Default: ConnectionPool(host='127.0.0.1', port=6379, db=0)
"""
"""
self.resource = resource
self.max_requests = max_requests
self.expire = expire
Expand Down