Skip to content

Commit

Permalink
Add cookie-backed sessions
Browse files Browse the repository at this point in the history
This should be more secure than just a plain IP check.
  • Loading branch information
ToucheSir committed Sep 26, 2020
1 parent 37fd007 commit e723d58
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 49 deletions.
51 changes: 30 additions & 21 deletions backend/app/api.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,31 @@
import csv
import json
from io import TextIOWrapper
from collections import defaultdict
import secrets
import socket
from datetime import datetime
from typing import DefaultDict, List, Optional, Set, Tuple
from typing import List, Optional

from bson import ObjectId
from fastapi import (
APIRouter,
Request,
BackgroundTasks,
Depends,
Query,
HTTPException,
status,
Form,
File,
Form,
HTTPException,
Query,
Request,
Response,
UploadFile,
status,
)
from bson import ObjectId
from fastapi.param_functions import Cookie
from fastapi.responses import HTMLResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from passlib.hash import bcrypt

from app.config import Settings, get_settings
from app.models import *
from app.database import DatabaseContext

from fastapi.security import HTTPBasic, HTTPBasicCredentials
from passlib.hash import bcrypt
from fastapi.responses import HTMLResponse
import socket
from app.models import *

security = HTTPBasic()
router = APIRouter()
Expand All @@ -42,15 +41,17 @@ def verify_password(plain_password, hashed_password: SecretStr):


async def get_current_user(
response: Response,
credentials: HTTPBasicCredentials = Depends(security),
session_id: Optional[str] = Cookie(None),
db: DatabaseContext = Depends(get_db),
) -> Annotator:
user = await db.get_annotator(credentials.username)
# Get users IP address
hostname = socket.gethostname()
ip_address = socket.gethostbyname(hostname)
# If user has an active session skip bcrypt validation:
if await db.active_session(ip_address):
if session_id and await db.active_session(ip_address, session_id):
return user
# Else try to login user
if not user or not verify_password(credentials.password, user.hashed_password):
Expand All @@ -60,7 +61,16 @@ async def get_current_user(
headers={"WWW-Authenticate": "Basic"},
)
# Create active session for newly logged in user
await db.create_session(ip_address)
session_id = secrets.token_urlsafe()
await db.create_session(ip_address, session_id)
response.set_cookie(
"session_id",
session_id,
secure=True,
httponly=True,
samesite="strict",
max_age=15 * 60,
)
return user


Expand All @@ -87,7 +97,6 @@ async def audit_handler(
query_params=request.query_params,
body=body,
)

background.add_task(db.add_audit_event, audit_event)


Expand Down Expand Up @@ -206,7 +215,7 @@ async def get_segment(
segment_id: str,
annotator_username: str = Query(None, alias="annotator"),
db: DatabaseContext = Depends(get_db),
username: str = Depends(get_current_user),
_=Depends(get_current_user),
):
segment: SegmentRecord = await db.get_segment(ObjectId(segment_id))
annotation = segment.annotations.get(annotator_username)
Expand All @@ -222,7 +231,7 @@ async def update_segment_annotations(
annotator_username: str,
annotation: Annotation,
db: DatabaseContext = Depends(get_db),
username: str = Depends(get_current_user),
_=Depends(get_current_user),
):
try:
await db.update_annotation(ObjectId(segment_id), annotator_username, annotation)
Expand Down
26 changes: 13 additions & 13 deletions backend/app/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,20 @@ def __enter__(self):
def __exit__(self, *_):
self.client.close()

async def active_session(self, user_IP: str):
async def active_session(self, user_ip: str, session_id: str):
return await self.sessions.find_one_and_update(
{"user_IP": user_IP},
{"$set":
{"lastLoginAt": datetime.now()}
}
)

async def create_session(self, user_IP: str):
await self.sessions.insert_one({
"lastLoginAt": datetime.now(),
"user_IP": user_IP
})
return
{"user_ip": user_ip, "session_id": session_id},
{"$set": {"last_login": datetime.utcnow()}},
)

async def create_session(self, user_ip: str, session_id: str):
await self.sessions.insert_one(
{
"user_ip": user_ip,
"session_id": session_id,
"last_login": datetime.utcnow(),
}
)

async def set_campaign(self, username: str, campaign: AnnotationCampaign):
return await self.annotators.update_one(
Expand Down
6 changes: 0 additions & 6 deletions backend/main.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
from pathlib import Path

from fastapi import FastAPI, Depends, Form
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.gzip import GZipMiddleware

from app.api import router as api_router, audit_handler


app = FastAPI()
app.add_middleware(GZipMiddleware)
app.include_router(api_router, dependencies=[Depends(audit_handler)])

# Dev runner
Expand Down
19 changes: 10 additions & 9 deletions backend/seed_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from app.config import get_settings
from app.models import SegmentRecord
from pymongo import MongoClient
from pymongo import ASCENDING, HASHED, MongoClient
from bson import ObjectId

# bcrypt.hash("12345")
Expand Down Expand Up @@ -68,24 +68,25 @@ def generate_annotation(i: int, n_samples=SIGNAL_HZ * SECONDS):
db = client.get_database(settings.db_name)
annotators = db.get_collection("annotators")
segments = db.get_collection("segment_records")
sessions = db.get_collection("active_sessions")

with client.start_session() as sess:
annotators.drop(session=sess)
segments.drop(session=sess)
sessions.drop(session=sess)
db.get_collection("audit_events").drop(session=sess)

sessions.create_index(
[("user_ip", ASCENDING), ("session_id", HASHED)],
session=sess,
)
sessions.create_index("last_login", expireAfterSeconds=15 * 60, session=sess)
with sess.start_transaction():
annotators.insert_many(ANNOTATORS)
segments.insert_many([generate_annotation(i) for i in range(N)])

segment_ids = [s["_id"] for s in segments.find(projection=[])]
annotators.update_many(
{},
{
"$set": {
"current_campaign": {
"name": "training",
"segments": segment_ids
}
}
},
{"$set": {"current_campaign": {"name": "training", "segments": segment_ids}}},
)

0 comments on commit e723d58

Please sign in to comment.