Skip to content

Commit

Permalink
Feature/auth api (#340)
Browse files Browse the repository at this point in the history
* auth api updates
  • Loading branch information
russbiggs authored Feb 21, 2024
1 parent abdbfa7 commit d31a3fd
Show file tree
Hide file tree
Showing 9 changed files with 372 additions and 34 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/deploy-prod.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ env:

EMAIL_SENDER: ${{ secrets.EMAIL_SENDER }}

EXPLORER_API_KEY: ${{ secrets.EXPLORER_API_KEY }}


jobs:
deploy:
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/deploy-staging.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ env:
RATE_TIME: 1

EMAIL_SENDER: ${{ secrets.EMAIL_SENDER }}

EXPLORER_API_KEY: ${{ secrets.EXPLORER_API_KEY }}


jobs:
deploy:
Expand Down
62 changes: 59 additions & 3 deletions openaq_api/openaq_api/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
allowed_config_params = ["work_mem"]



DEFAULT_CONNECTION_TIMEOUT = 6
MAX_CONNECTION_TIMEOUT = 15

Expand All @@ -34,6 +33,7 @@ def default(obj):
# function is used in the `cached` decorator and without it
# we will get a number of arguments error


def dbkey(m, f, query, args, timeout=None, config=None):
j = orjson.dumps(
args, option=orjson.OPT_OMIT_MICROSECONDS, default=default
Expand Down Expand Up @@ -115,7 +115,9 @@ async def fetch(
q = f"SELECT set_config('{param}', $1, TRUE)"
s = await con.execute(q, str(value))
if not isinstance(timeout, (str, int)):
logger.warning(f"Non int or string timeout value passed - {timeout}")
logger.warning(
f"Non int or string timeout value passed - {timeout}"
)
timeout = DEFAULT_CONNECTION_TIMEOUT
r = await wait_for(con.fetch(rquery, *args), timeout=timeout)
await tr.commit()
Expand Down Expand Up @@ -193,9 +195,63 @@ async def create_user(self, user: User) -> str:
await conn.close()
return verification_token[0][0]

async def get_user(self, users_id: int) -> str:
"""
gets user info from users table and entities table
"""
query = """
SELECT
e.full_name
, u.email_address
, u.verification_code
FROM
users u
JOIN
users_entities USING (users_id)
JOIN
entities e USING (entities_id)
WHERE
u.users_id = :users_id
"""
conn = await asyncpg.connect(settings.DATABASE_READ_URL)
rquery, args = render(query, **{"users_id": users_id})
user = await conn.fetch(rquery, *args)
await conn.close()
return user[0]

async def regenerate_user_token(self, users_id: int, token: str) -> str:
"""
calls the get_user_token plpgsql function to verify user email and generate API token
"""
query = """
UPDATE
user_keys
SET
token = generate_token()
WHERE
users_id = :users_id
AND
token = :token
"""
conn = await asyncpg.connect(settings.DATABASE_WRITE_URL)
rquery, args = render(query, **{"users_id": users_id, "token": token})
await conn.fetch(rquery, *args)
await conn.close()

async def get_user_token(self, users_id: int) -> str:
""" """
query = """
SELECT token FROM user_keys WHERE users_id = :users_id
"""
conn = await asyncpg.connect(settings.DATABASE_WRITE_URL)
rquery, args = render(query, **{"users_id": users_id})
api_token = await conn.fetch(rquery, *args)
await conn.close()
return api_token[0][0]

async def generate_user_token(self, users_id: int) -> str:
"""
calls the get_user_token plpgsql function to vefiry user email and generate API token
calls the get_user_token plpgsql function to verify user email and generate API token
"""
query = """
SELECT * FROM get_user_token(:users_id)
Expand Down
57 changes: 30 additions & 27 deletions openaq_api/openaq_api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from openaq_api.middleware import (
CacheControlMiddleware,
LoggingMiddleware,
PrivatePathsMiddleware,
RateLimiterMiddleWare,
)
from openaq_api.models.logging import (
Expand All @@ -47,6 +48,7 @@

# V3 routers
from openaq_api.v3.routers import (
auth,
countries,
instruments,
locations,
Expand Down Expand Up @@ -99,16 +101,38 @@ def render(self, content: Any) -> bytes:
return orjson.dumps(content, default=default)


redis_client = None # initialize for generalize_schema.py


@asynccontextmanager
async def lifespan(app: FastAPI):
if not hasattr(app.state, "pool"):
logger.debug("initializing connection pool")
app.state.pool = await db_pool(None)
logger.debug("Connection pool established")

if hasattr(app.state, "counter"):
app.state.counter += 1
else:
app.state.counter = 0
app.state.redis_client = redis_client
yield
if hasattr(app.state, "pool") and not settings.USE_SHARED_POOL:
logger.debug("Closing connection")
await app.state.pool.close()
delattr(app.state, "pool")
logger.debug("Connection closed")


app = FastAPI(
title="OpenAQ",
description="OpenAQ API",
version="2.0.0",
default_response_class=ORJSONResponse,
docs_url="/docs",
lifespan=lifespan,
)

redis_client = None # initialize for generalize_schema.py


if settings.RATE_LIMITING is True:
if settings.RATE_LIMITING:
Expand All @@ -126,6 +150,7 @@ def render(self, content: Any) -> bytes:
logging.error(
InfrastructureErrorLog(detail=f"failed to connect to redis: {e}")
)
print(redis_client)
logger.debug("Redis connected")
if redis_client:
app.add_middleware(
Expand All @@ -152,8 +177,7 @@ def render(self, content: Any) -> bytes:
app.add_middleware(CacheControlMiddleware, cachecontrol="public, max-age=900")
app.add_middleware(LoggingMiddleware)
app.add_middleware(GZipMiddleware, minimum_size=1000)

app.include_router(auth_router)
app.add_middleware(PrivatePathsMiddleware)


class OpenAQValidationResponseDetail(BaseModel):
Expand Down Expand Up @@ -199,29 +223,6 @@ async def openaq_exception_handler(request: Request, exc: ValidationError):
# return ORJSONResponse(status_code=500, content={"message": "internal server error"})


@app.on_event("startup")
async def startup_event():
if not hasattr(app.state, "pool"):
logger.debug("initializing connection pool")
app.state.pool = await db_pool(None)
logger.debug("Connection pool established")

if hasattr(app.state, "counter"):
app.state.counter += 1
else:
app.state.counter = 0
app.state.redis_client = redis_client


@app.on_event("shutdown")
async def shutdown_event():
if hasattr(app.state, "pool") and not settings.USE_SHARED_POOL:
logger.debug("Closing connection")
await app.state.pool.close()
delattr(app.state, "pool")
logger.debug("Connection closed")


@app.get("/ping", include_in_schema=False)
def pong():
"""
Expand All @@ -239,6 +240,7 @@ def favico():


# v3
app.include_router(auth.router)
app.include_router(instruments.router)
app.include_router(locations.router)
app.include_router(parameters.router)
Expand Down Expand Up @@ -267,6 +269,7 @@ def favico():

static_dir = Path.joinpath(Path(__file__).resolve().parent, "static")


app.mount("/", StaticFiles(directory=str(static_dir), html=True))


Expand Down
22 changes: 22 additions & 0 deletions openaq_api/openaq_api/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,24 @@ async def dispatch(self, request: Request, call_next):
return response


class PrivatePathsMiddleware(BaseHTTPMiddleware):
"""
Middleware to protect private endpoints with an API key
"""

async def dispatch(self, request: Request, call_next):
route = request.url.path
if "/auth" in route:
auth = request.headers.get("x-api-key", None)
if auth != settings.EXPLORER_API_KEY:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"message": "invalid credentials"},
)
response = await call_next(request)
return response


class RateLimiterMiddleWare(BaseHTTPMiddleware):
def __init__(
self,
Expand Down Expand Up @@ -167,8 +185,12 @@ def limited_path(route: str) -> bool:
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
print("RATE LIMIT\n\n\n")
route = request.url.path
auth = request.headers.get("x-api-key", None)
if auth == settings.EXPLORER_API_KEY:
response = await call_next(request)
return response
limit = self.rate_amount
now = datetime.now()
key = f"{request.client.host}:{now.year}{now.month}{now.day}{now.hour}{now.minute}"
Expand Down
46 changes: 45 additions & 1 deletion openaq_api/openaq_api/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
from email.message import EmailMessage

import boto3
from fastapi import APIRouter, Depends, Form, HTTPException, Request, status
from fastapi import APIRouter, Body, Depends, Form, HTTPException, Request, status
from fastapi.responses import RedirectResponse
from fastapi.templating import Jinja2Templates
from passlib.hash import pbkdf2_sha256
from pydantic import BaseModel

from ..db import DB
from ..forms.register import RegisterForm, UserExistsException
Expand Down Expand Up @@ -220,6 +221,49 @@ async def verify(request: Request, verification_code: str, db: DB = Depends()):
)


class RegenerateTokenBody(BaseModel):
users_id: int
token: str


@router.post("/regenerate-token")
async def regenerate_token(
token: int = Body(..., embed=True)
# request: Request,
# db: DB = Depends(),
):
""" """
_token = token
print(_token)
try:
# db.get_user_token
# await db.regenerate_user_token(body.users_id, _token)
# token = await db.get_user_token(body.users_id)
# redis_client = getattr(request.app.state, "redis_client")
# print("REDIS", redis_client)
# if redis_client:
# await redis_client.srem("keys", _token)
# await redis_client.sadd("keys", token)
return {"success"}
except Exception as e:
return e


@router.post("/send-verification")
async def get_register(
request: Request,
users_id: int,
db: DB = Depends(),
):
user = db.get_user(users_id=users_id)
full_name = user[0]
email_address = user[1]
verification_code = user[2]
response = send_verification_email(verification_code, full_name, email_address)
logger.info(InfoLog(detail=json.dumps(response)).model_dump_json())
return RedirectResponse("/check-email", status_code=status.HTTP_303_SEE_OTHER)


@router.get("/register")
async def get_register(request: Request):
return templates.TemplateResponse("register/index.html", {"request": request})
Expand Down
2 changes: 2 additions & 0 deletions openaq_api/openaq_api/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class Settings(BaseSettings):

EMAIL_SENDER: str | None = None

EXPLORER_API_KEY: str

@computed_field(return_type=str, alias="DATABASE_READ_URL")
@property
def DATABASE_READ_URL(self):
Expand Down
Loading

0 comments on commit d31a3fd

Please sign in to comment.