diff --git a/openaq_api/openaq_api/main.py b/openaq_api/openaq_api/main.py index 1c9a11c..09a0517 100644 --- a/openaq_api/openaq_api/main.py +++ b/openaq_api/openaq_api/main.py @@ -8,7 +8,7 @@ from typing import Any import orjson -from fastapi import FastAPI, Request +from fastapi import FastAPI, Request, Depends from fastapi.encoders import jsonable_encoder from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware @@ -21,10 +21,9 @@ from openaq_api.db import db_pool from openaq_api.middleware import ( + check_api_key, CacheControlMiddleware, LoggingMiddleware, - PrivatePathsMiddleware, - RateLimiterMiddleWare, ) from openaq_api.models.logging import ( InfrastructureErrorLog, @@ -33,6 +32,8 @@ WarnLog, ) + + # from openaq_api.routers.auth import router as auth_router from openaq_api.routers.averages import router as averages_router from openaq_api.routers.cities import router as cities_router @@ -105,8 +106,6 @@ def render(self, content: Any) -> bytes: return orjson.dumps(content, default=default) -redis_client = None # initialize for generalize_schema.py - @asynccontextmanager async def lifespan(app: FastAPI): @@ -119,7 +118,7 @@ async def lifespan(app: FastAPI): app.state.counter += 1 else: app.state.counter = 0 - app.state.redis_client = redis_client + yield if hasattr(app.state, "pool") and not settings.USE_SHARED_POOL: logger.debug("Closing connection") @@ -128,16 +127,21 @@ async def lifespan(app: FastAPI): logger.debug("Connection closed") + + + app = FastAPI( title="OpenAQ", description="OpenAQ API", version="2.0.0", default_response_class=ORJSONResponse, + dependencies=[Depends(check_api_key)], docs_url="/docs", lifespan=lifespan, ) +app.redis = None if settings.RATE_LIMITING is True: if settings.RATE_LIMITING: logger.debug("Connecting to redis") @@ -150,25 +154,15 @@ async def lifespan(app: FastAPI): decode_responses=True, socket_timeout=5, ) + # attach to the app so it can be retrieved via the request + app.redis = redis_client + logger.debug("Redis connected") + except Exception as e: logging.error( InfrastructureErrorLog(detail=f"failed to connect to redis: {e}") ) - print(redis_client) - logger.debug("Redis connected") - if redis_client: - app.add_middleware( - RateLimiterMiddleWare, - redis_client=redis_client, - rate_amount_key=settings.RATE_AMOUNT_KEY, - rate_time=datetime.timedelta(minutes=settings.RATE_TIME), - ) - else: - logger.warning( - WarnLog( - detail="valid redis client not provided but RATE_LIMITING set to TRUE" - ) - ) + app.add_middleware( CORSMiddleware, @@ -180,7 +174,6 @@ async def lifespan(app: FastAPI): app.add_middleware(CacheControlMiddleware, cachecontrol="public, max-age=900") app.add_middleware(LoggingMiddleware) app.add_middleware(GZipMiddleware, minimum_size=1000) -app.add_middleware(PrivatePathsMiddleware) class OpenAQValidationResponseDetail(BaseModel): @@ -198,31 +191,11 @@ async def openaq_request_validation_exception_handler( request: Request, exc: RequestValidationError ): return ORJSONResponse(status_code=422, content=jsonable_encoder(str(exc))) - # return PlainTextResponse(str(exc)) - # print("\n\n\n\n\n") - # print(str(exc)) - # print("\n\n\n\n\n") - # detail = orjson.loads(str(exc)) - # logger.debug(traceback.format_exc()) - # logger.info( - # UnprocessableEntityLog(request=request, detail=str(exc)).model_dump_json() - # ) - # detail = OpenAQValidationResponse(detail=detail) - # return ORJSONResponse(status_code=422, content=jsonable_encoder(detail)) @app.exception_handler(ValidationError) async def openaq_exception_handler(request: Request, exc: ValidationError): return ORJSONResponse(status_code=422, content=jsonable_encoder(str(exc))) - # detail = orjson.loads(exc.model_dump_json()) - # logger.debug(traceback.format_exc()) - # logger.error( - # ModelValidationError( - # request=request, detail=exc.jsmodel_dump_jsonon() - # ).model_dump_json() - # ) - # return ORJSONResponse(status_code=422, content=jsonable_encoder(detail)) - # return ORJSONResponse(status_code=500, content={"message": "internal server error"}) @app.get("/ping", include_in_schema=False) diff --git a/openaq_api/openaq_api/middleware.py b/openaq_api/openaq_api/middleware.py index 6a67d04..09af097 100644 --- a/openaq_api/openaq_api/middleware.py +++ b/openaq_api/openaq_api/middleware.py @@ -2,13 +2,17 @@ import time from datetime import timedelta, datetime from os import environ -from fastapi import Response, status +from fastapi import Response, status, Security, HTTPException from fastapi.responses import JSONResponse from redis.asyncio.cluster import RedisCluster from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import Request from starlette.types import ASGIApp +from fastapi.security import ( + APIKeyHeader, +) + from openaq_api.models.logging import ( HTTPLog, LogType, @@ -21,6 +25,118 @@ logger = logging.getLogger("middleware") +NOT_AUTHENTICATED_EXCEPTION = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid credentials", +) + +TOO_MANY_REQUESTS = HTTPException( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + detail="To many requests", +) + +def is_whitelisted_route(route: str) -> bool: + logger.debug(f"Checking if '{route}' is whitelisted") + allow_list = ["/", "/auth", "/openapi.json", "/docs", "/register"] + if route in allow_list: + return True + if "/v2/locations/tiles" in route: + return True + if "/v3/locations/tiles" in route: + return True + if "/assets" in route: + return True + if ".css" in route: + return True + if ".js" in route: + return True + return False + + +async def check_api_key( + request: Request, + response: Response, + api_key=Security(APIKeyHeader(name='X-API-Key', auto_error=False)), + ): + """ + Check for an api key and then to see if they are rate limited. Throws a + `not authenticated` or `to many reqests` error if appropriate. + Meant to be used as a dependency either at the app, router or function level + """ + route = request.url.path + # no checking or limiting for whitelistted routes + if is_whitelisted_route(route): + return api_key + elif api_key == settings.EXPLORER_API_KEY: + return api_key + else: + # check to see if we are limiting + redis = request.app.redis + + if redis is None: + logger.warning('No redis client found') + return api_key + elif api_key is None: + logger.debug('No api key provided') + raise NOT_AUTHENTICATED_EXCEPTION + else: + # check api key + limit = settings.RATE_AMOUNT_KEY + limited = False + # check valid key + if await redis.sismember("keys", api_key) == 0: + logger.debug('Api key not found') + raise NOT_AUTHENTICATED_EXCEPTION + + # check if its limited + now = datetime.now() + # Using a sliding window rate limiting algorithm + # we add the current time to the minute to the api key and use that as our check + key = f"{api_key}:{now.year}{now.month}{now.day}{now.hour}{now.minute}" + # if the that key is in our redis db it will return the number of requests + # that key has made during the current minute + value = await redis.get(key) + ttl = await redis.ttl(key) + + if value is None: + # if the value is none than we need to add that key to the redis db + # and set it, increment it and set it to timeout/delete is 60 seconds + logger.debug('redis no key for current minute so not limited') + async with redis.pipeline() as pipe: + [incr, _] = await pipe.incr(key).expire(key, 60).execute() + requests_used = limit - incr + elif int(value) < limit: + # if that key does exist and the value is below the allowed number of requests + # wea re going to increment it and move on + logger.debug(f'redis - has key for current minute value ({value}) < limit ({limit})') + async with redis.pipeline() as pipe: + [incr, _] = await pipe.incr(key).execute() + requests_used = limit - incr + else: + # otherwise the user is over their limit and so we are going to throw a 429 + # after we set the headers + logger.debug(f'redis - has key for current minute and value ({value}) >= limit ({limit})') + limited = True + requests_used = value + + response.headers["x-ratelimit-limit"] = str(limit) + response.headers["x-ratelimit-remaining"] = "0" + response.headers["x-ratelimit-used"] = str(requests_used) + response.headers["x-ratelimit-reset"] = str(ttl) + + if limited: + logging.info( + TooManyRequestsLog( + request=request, + rate_limiter=f"{key}/{limit}/{requests_used}", + ).model_dump_json() + ) + raise TOO_MANY_REQUESTS + + # it would be ideal if we were returing the user information right here + # even it was just an email address it might be useful + return api_key + class CacheControlMiddleware(BaseHTTPMiddleware): """MiddleWare to add CacheControl in response headers.""" @@ -44,16 +160,6 @@ async def dispatch(self, request: Request, call_next): return response -class GetHostMiddleware(BaseHTTPMiddleware): - """MiddleWare to set servers url on App with current url.""" - - async def dispatch(self, request: Request, call_next): - environ["BASE_URL"] = str(request.base_url) - response = await call_next(request) - - return response - - class Timer: def __init__(self): self.start_time = time.time() @@ -114,142 +220,3 @@ async def dispatch(self, request: Request, call_next): ).model_dump_json() ) return response - - -class PrivatePathsMiddleware(BaseHTTPMiddleware): - """ - Middleware to protect private endpoints with an API key - """ - - async def dispatch(self, request: Request, call_next): - route = request.url.path - if "/auth" in route: - auth = request.headers.get("x-api-key", None) - if auth != settings.EXPLORER_API_KEY: - return JSONResponse( - status_code=status.HTTP_401_UNAUTHORIZED, - content={"message": "invalid credentials"}, - ) - response = await call_next(request) - return response - - -class RateLimiterMiddleWare(BaseHTTPMiddleware): - def __init__( - self, - app: ASGIApp, - redis_client: RedisCluster, - rate_amount_key: int, # number of requests allowed with api key - rate_time: timedelta, # timedelta of rate limit expiration - ) -> None: - """Init Middleware.""" - super().__init__(app) - self.redis_client = redis_client - self.rate_amount_key = rate_amount_key - self.rate_time = rate_time - - async def request_is_limited(self, key: str, limit: int, request: Request) -> bool: - value = await self.redis_client.get(key) - if value is None: - async with self.redis_client.pipeline() as pipe: - [incr, _] = await pipe.incr(key).expire(key, 60).execute() - request.state.counter = limit - incr - return False - if int(value) < limit: - async with self.redis_client.pipeline() as pipe: - [incr] = await pipe.incr(key).execute() - request.state.counter = limit - incr - return False - else: - return True - - async def check_valid_key(self, key: str) -> bool: - if await self.redis_client.sismember("keys", key): - return True - return False - - @staticmethod - def limited_path(route: str) -> bool: - allow_list = ["/", "/openapi.json", "/docs", "/register"] - if route in allow_list: - return False - if "/v2/locations/tiles" in route: - return False - if "/v3/locations/tiles" in route: - return False - if "/assets" in route: - return False - if ".css" in route: - return False - if ".js" in route: - return False - return True - - async def dispatch( - self, request: Request, call_next: RequestResponseEndpoint - ) -> Response: - route = request.url.path - if not self.limited_path(route): - response = await call_next(request) - return response - auth = request.headers.get("x-api-key", None) - if auth == settings.EXPLORER_API_KEY: - response = await call_next(request) - return response - limit = self.rate_amount_key - now = datetime.now() - key = f"{request.client.host}:{now.year}{now.month}{now.day}{now.hour}{now.minute}" - if auth: - valid_key = await self.check_valid_key(auth) - if not valid_key: - logging.info( - UnauthorizedLog( - request=request, detail=f"invalid key used: {auth}" - ).model_dump_json() - ) - return JSONResponse( - status_code=status.HTTP_401_UNAUTHORIZED, - content={"message": "invalid credentials"}, - ) - key = f"{auth}:{now.year}{now.month}{now.day}{now.hour}{now.minute}" - limit = self.rate_amount_key - else: - logging.info( - UnauthorizedLog( - request=request, - ).model_dump_json() - ) - response = JSONResponse( - status_code=status.HTTP_401_UNAUTHORIZED, - content={"message": "API KEY missing from x-api-key header"}, - ) - return response - request.state.counter = limit - limited = False - ttl = 0 - if self.limited_path(route): - limited = await self.request_is_limited(key, limit, request) - ttl = await self.redis_client.ttl(key) - if self.limited_path(route) and limited: - logging.info( - TooManyRequestsLog( - request=request, - rate_limiter=f"{key}/{limit}/{request.state.counter}", - ).model_dump_json() - ) - response = JSONResponse( - status_code=status.HTTP_429_TOO_MANY_REQUESTS, - content={"message": "Too many requests"}, - ) - response.headers["x-ratelimit-limit"] = str(limit) - response.headers["x-ratelimit-remaining"] = "0" - response.headers["x-ratelimit-used"] = str(limit) - response.headers["x-ratelimit-reset"] = str(ttl) - return response - request.state.rate_limiter = f"{key}/{limit}/{request.state.counter}" - response = await call_next(request) - response.headers["x-ratelimit-limit"] = str(limit) - response.headers["x-ratelimit-remaining"] = str(request.state.counter) - response.headers["x-ratelimit-used"] = str(limit - request.state.counter) - response.headers["x-ratelimit-reset"] = str(ttl) - return response diff --git a/openaq_api/openaq_api/settings.py b/openaq_api/openaq_api/settings.py index f433e57..e3ae2ee 100644 --- a/openaq_api/openaq_api/settings.py +++ b/openaq_api/openaq_api/settings.py @@ -30,7 +30,6 @@ class Settings(BaseSettings): RATE_LIMITING: bool = False RATE_AMOUNT_KEY: int | None = None - RATE_TIME: int | None = None USER_AGENT: str | None = None ORIGIN: str | None = None diff --git a/openaq_api/tests/test_limiting.py b/openaq_api/tests/test_limiting.py new file mode 100644 index 0000000..88a6995 --- /dev/null +++ b/openaq_api/tests/test_limiting.py @@ -0,0 +1,103 @@ +from fastapi.testclient import TestClient +import json +import time +import os +import pytest +from openaq_api.main import app +from openaq_api.db import db_pool +import re + + + +class FakePipeline: + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, exc_tb): + ... + + def incr(self, key): + return self + + def expire(self, key, sec): + return self + + async def execute(self): + return [1, None] + + +class FakeRedisClient: + def __init__(self): + self.api_keys = [ + "limited-api-key", + "not-limited-api-key", + "new-api-key", + ]; + self.api_key_data = { + "limited-api-key": {"get":"10","ttl":5}, + "not-limited-api-key": {"get":"9","ttl":30}, + "missing-api-key": {"get":None,"ttl":-1} + } + + # is this key in the set + async def sismember(self, scope, key): + value = 1 if key in self.api_keys else 0 + print(f"redis sismember: {key} = {value}") + return value + + # number of requests made on this key + async def get(self, key): + key = re.sub('[\d+:]', '', key) + value = self.api_key_data.get(key, {}).get('get') + print(f"redis get: {key} = {value}") + return value + + # time to live + # how many seconds are left for this key + async def ttl(self, key): + key = re.sub('[\d+:]', '', key) + value = self.api_key_data.get(key, {}).get('ttl') + print(f"redis ttl: {key} = {value}") + return value + + # just a way to increment the number of requests + def pipeline(self): + return FakePipeline() + + +@pytest.fixture +def client(): + app.redis = FakeRedisClient() + with TestClient(app) as c: + yield c + + +def test_whitelisted_path_returns_200(client): + response = client.get("/openapi.json") + assert response.status_code == 200 + +def test_no_key_returns_401(client): + response = client.get("/ping") + assert response.status_code == 401 + +def test_empty_key_returns_401(client): + response = client.get("/ping", headers={"X-API-Key":""}) + assert response.status_code == 401 + + +def test_invalid_key_returns_401(client): + response = client.get("/ping", headers={"X-API-Key":"invalid-key"}) + assert response.status_code == 401 + +def test_limited_key_returns_429(client): + response = client.get("/ping", headers={"X-API-Key":"limited-api-key"}) + assert response.status_code == 429 + +def test_not_limited_key_returns_200(client): + response = client.get("/ping", headers={"X-API-Key":"not-limited-api-key"}) + assert response.status_code == 200 + +def test_new_api_key_returns_200(client): + response = client.get("/ping", headers={"X-API-Key":"not-limited-api-key"}) + assert response.status_code == 200 diff --git a/openaq_api/tests/test_sensors_latest.py b/openaq_api/tests/test_sensors_latest.py index cd312c3..3cf1b4e 100644 --- a/openaq_api/tests/test_sensors_latest.py +++ b/openaq_api/tests/test_sensors_latest.py @@ -47,7 +47,7 @@ def test_default_good(self, client): response = client.get(f"/v3/parameters/{measurands_id}/latest") assert response.status_code == 200 data = json.loads(response.content).get('results', []) - assert len(data) == 5 + assert len(data) == 6 def test_date_filter(self, client): response = client.get(f"/v3/parameters/{measurands_id}/latest?datetime_min=2024-08-27")