From 891a4454b2fe10c569528d1f814f5a8e5d7e8926 Mon Sep 17 00:00:00 2001 From: Russ Biggs Date: Wed, 7 Feb 2024 18:36:59 -0700 Subject: [PATCH] Release 20240207 (#331) * Feature/ratelimiter update (#330) * updated request_is_limited algorithm * Strip out numbers from paths This is so we can better aggregate by path in cloudwatch --------- Co-authored-by: Christian Parker --- openaq_api/openaq_api/middleware.py | 37 ++++++++++++------------- openaq_api/openaq_api/models/logging.py | 6 ++-- 2 files changed, 20 insertions(+), 23 deletions(-) diff --git a/openaq_api/openaq_api/middleware.py b/openaq_api/openaq_api/middleware.py index 0ab60fe..eff223a 100644 --- a/openaq_api/openaq_api/middleware.py +++ b/openaq_api/openaq_api/middleware.py @@ -1,6 +1,6 @@ import logging import time -from datetime import timedelta +from datetime import timedelta, datetime from os import environ from fastapi import Response, status from fastapi.responses import JSONResponse @@ -133,19 +133,16 @@ def __init__( self.rate_time = rate_time async def request_is_limited(self, key: str, limit: int, request: Request) -> bool: - if await self.redis_client.set(key, limit, nx=True): - await self.redis_client.expire(key, int(self.rate_time.total_seconds())) - count = await self.redis_client.get(key) - if count in ("-1", "-2"): - logger.error( - RedisErrorLog( - detail=f"redis has an invalid value for limit: {count} for key: {key}" - ) - ) - if count and int(count) > 0: - request.state.counter = await self.redis_client.decrby(key, 1) - return False - return True + now = datetime.now() + k = f"{key}:{now.year}{now.month}{now.day}{now.hour}{now.minute}" + value = await self.redis_client.get(k) + if value is None or int(value) < limit: + async with self.redis_client.pipeline() as pipe: + [incr, _] = await pipe.incr(k).expire(k, 60).execute() + request.state.counter = limit - incr + return False + else: + return True async def check_valid_key(self, key: str) -> bool: if await self.redis_client.sismember("keys", key): @@ -214,11 +211,11 @@ async def dispatch( response.headers["RateLimit-Reset"] = str(ttl) rate_time_seconds = int(self.rate_time.total_seconds()) if auth: - response.headers[ - "RateLimit-Policy" - ] = f"{self.rate_amount_key};w={rate_time_seconds}" + response.headers["RateLimit-Policy"] = ( + f"{self.rate_amount_key};w={rate_time_seconds}" + ) else: - response.headers[ - "RateLimit-Policy" - ] = f"{self.rate_amount};w={rate_time_seconds}" + response.headers["RateLimit-Policy"] = ( + f"{self.rate_amount};w={rate_time_seconds}" + ) return response diff --git a/openaq_api/openaq_api/models/logging.py b/openaq_api/openaq_api/models/logging.py index 2965060..bcbe7f5 100644 --- a/openaq_api/openaq_api/models/logging.py +++ b/openaq_api/openaq_api/models/logging.py @@ -3,7 +3,7 @@ from fastapi import Request, status from humps import camelize from pydantic import BaseModel, ConfigDict, Field, computed_field - +import re class LogType(StrEnum): SUCCESS = "SUCCESS" @@ -107,8 +107,8 @@ def user_agent(self) -> str: @computed_field(return_type=str) @property def path(self) -> str: - """str: returns URL path from request""" - return self.request.url.path + """str: returns URL path from request but replaces numbers in the path with :id""" + return re.sub(r'/[0-9]+', '/:id', self.request.url.path) @computed_field(return_type=str) @property