Skip to content

Commit

Permalink
rate limiter logging fix (#293)
Browse files Browse the repository at this point in the history
  • Loading branch information
russbiggs authored Sep 28, 2023
1 parent 979dd3f commit 43bf202
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 68 deletions.
79 changes: 33 additions & 46 deletions openaq_api/openaq_api/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from contextlib import asynccontextmanager
import datetime
import logging
import time
Expand Down Expand Up @@ -51,7 +52,7 @@
locations,
manufacturers,
measurements,
owners,
owners,
parameters,
providers,
sensors,
Expand Down Expand Up @@ -109,18 +110,23 @@ def render(self, content: Any) -> bytes:
redis_client = None # initialize for generalize_schema.py


app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.add_middleware(CacheControlMiddleware, cachecontrol="public, max-age=900")
app.add_middleware(LoggingMiddleware)
app.add_middleware(GZipMiddleware, minimum_size=1000)

if settings.RATE_LIMITING is True:
if settings.RATE_LIMITING:
logger.debug("Connecting to redis")
from redis.asyncio.cluster import RedisCluster

try:
redis_client = RedisCluster(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
decode_responses=True,
socket_timeout=5,
)
except Exception as e:
logging.error(
InfrastructureErrorLog(detail=f"failed to connect to redis: {e}")
)
logger.debug("Redis connected")
if redis_client:
app.add_middleware(
RateLimiterMiddleWare,
Expand All @@ -136,6 +142,17 @@ def render(self, content: Any) -> bytes:
)
)

app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.add_middleware(CacheControlMiddleware, cachecontrol="public, max-age=900")
app.add_middleware(LoggingMiddleware)
app.add_middleware(GZipMiddleware, minimum_size=1000)

app.include_router(auth_router)


Expand Down Expand Up @@ -182,53 +199,23 @@ async def openaq_exception_handler(request: Request, exc: ValidationError):
# return ORJSONResponse(status_code=500, content={"message": "internal server error"})


@app.on_event("startup")
async def startup_event():
"""
Application startup:
register the database
"""
@asynccontextmanager
async def lifespan(app: FastAPI):
if not hasattr(app.state, "pool"):
logger.debug("initializing connection pool")
app.state.pool = await db_pool(None)
logger.debug("Connection pool established")
if not hasattr(app.state, "redis_client"):
if settings.RATE_LIMITING:
logger.debug("Connecting to redis")
from redis.asyncio.cluster import RedisCluster

try:
redis_client = RedisCluster(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
decode_responses=True,
socket_timeout=5,
)
app.state.redis_client = redis_client
except Exception as e:
logging.error(
InfrastructureErrorLog(detail=f"failed to connect to redis: {e}")
)
logger.debug("Redis connected")

if hasattr(app.state, "counter"):
app.state.counter += 1
else:
app.state.counter = 0


@app.on_event("shutdown")
async def shutdown_event():
"""Application shutdown: de-register the database connection."""
yield
if hasattr(app.state, "pool") and not settings.USE_SHARED_POOL:
logger.debug("Closing connection")
await app.state.pool.close()
delattr(app.state, "pool")
logger.debug("Connection closed")
if hasattr(app.state, "redis_client") and settings.RATE_LIMITING:
logger.debug("Closing redis connection")
await app.state.redis_client.close()
delattr(app.state, "redis_client")
logger.debug("redis connection closed")


@app.get("/ping", include_in_schema=False)
Expand Down
42 changes: 20 additions & 22 deletions openaq_api/openaq_api/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,23 +54,23 @@ async def dispatch(self, request: Request, call_next):
return response


class Timer():
def __init__(self):
self.start_time = time.time()
self.last_mark = self.start_time
self.marks = []

def mark(self, key: str, return_time: str = 'total') -> float:
now = time.time()
mrk = {
"key": key,
"since": round((now - self.last_mark)*1000, 1),
"total": round((now - self.start_time)*1000, 1),
}
self.last_make = now
self.marks.append(mrk)
logger.debug(f"TIMER ({key}): {mrk['since']}")
return mrk.get(return_time)
class Timer:
def __init__(self):
self.start_time = time.time()
self.last_mark = self.start_time
self.marks = []

def mark(self, key: str, return_time: str = "total") -> float:
now = time.time()
mrk = {
"key": key,
"since": round((now - self.last_mark) * 1000, 1),
"total": round((now - self.start_time) * 1000, 1),
}
self.last_make = now
self.marks.append(mrk)
logger.debug(f"TIMER ({key}): {mrk['since']}")
return mrk.get(return_time)


class LoggingMiddleware(BaseHTTPMiddleware):
Expand All @@ -79,12 +79,11 @@ class LoggingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
request.state.timer = Timer()
response = await call_next(request)
timing = request.state.timer.mark('process')
if hasattr(request.app.state, "rate_limiter"):
rate_limiter = request.app.state.rate_limiter
timing = request.state.timer.mark("process")
if hasattr(request.state, "rate_limiter"):
rate_limiter = request.state.rate_limiter
else:
rate_limiter = None

if hasattr(request.app.state, "counter"):
counter = request.app.state.counter
else:
Expand Down Expand Up @@ -207,7 +206,6 @@ async def dispatch(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
content={"message": "Too many requests"},
)

request.state.rate_limiter = f"{key}/{limit}/{request.state.counter}"
response = await call_next(request)
return response

0 comments on commit 43bf202

Please sign in to comment.