Add dummy permissions checks to endpoints
eudoxos committed Nov 26, 2024
1 parent fe10436 commit 8b59ba9
Showing 1 changed file with 66 additions and 23 deletions.
89 changes: 66 additions & 23 deletions mupifDB/api/main.py
@@ -15,7 +15,7 @@

import time

from fastapi import FastAPI, UploadFile, Depends, HTTPException
from fastapi import FastAPI, UploadFile, Depends, HTTPException, Request
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
@@ -46,6 +46,8 @@
# shorthands for common exceptions
def NotFoundError(detail):
return HTTPException(status_code=404, detail=detail)
def ForbiddenError(detail):
return HTTPException(status_code=403, detail=detail)

from mupifDB import models

@@ -62,11 +64,13 @@ def NotFoundError(detail):
# the validation would (hopefully) happen automatically
#
import contextlib
from typing import Generator
from typing import Iterator,TypeVar
from pymongo.client_session import ClientSession
from pymongo.collection import Collection
from pymongo.database import Database

@contextlib.contextmanager
def db_transaction() -> Generator[ClientSession|None]:
def db_transaction() -> Iterator[ClientSession|None]:
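# note: Iterator[...] needs only the yield type; Generator[...] would also require
# send/return type parameters (they only gained defaults in Python 3.13)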
if 1:
# return None as the session object, which makes the context a no-op
yield None
@@ -76,6 +80,32 @@ def db_transaction() -> Generator[ClientSession|None]:
yield session


_PermWhat=Literal['read','child','modify']
_PermOn=Literal['self','parent']
_PermObj=models.GridFSFile_Model|models.MongoObj_Model
T=TypeVar('T')

class Perms(object):
def __init__(self, db: Database):
self.db=db
def has(self, obj: Any, perm: _PermWhat='read', on: _PermOn='self') -> bool:
'TODO: to be implemented'
return True
def ensure(self, obj: T, perm: _PermWhat='read', on: _PermOn='self', diag: str|None=None) -> T:
'Check permissions on the object (read on obj by default) and return it. Raise ForbiddenError if the check fails.'
if not self.has(obj=obj,perm=perm, on=on):
raise ForbiddenError(f'Forbidden {perm} access to {"the parent of " if on=="parent" else ""}{obj.__class__.__name__}(dbID={obj.dbID}){": "+diag if diag else ""}.')
return obj
def TODO(self, *args, **kw): pass  # placeholder where the permission check is not decided yet
def filterSelfRead(self,objs: List[T]) -> List[T]: return [obj for obj in objs if self.has(obj,perm='read',on='self')]
def notRemote(self, request: Request, diag: str):
import ipaddress
if request.client is None: raise ForbiddenError(f'Client address unknown ({diag}).')
if not ipaddress.ip_address(request.client.host).is_loopback: raise ForbiddenError(f'Remote access (from {request.client.host}) forbidden ({diag}).')

perms = Perms(db=db)

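# Illustrative sketch only (not wired in anywhere): one possible shape of a real
# check once the dummy Perms.has is implemented. The per-document 'access' dict
# and the 'caller' argument are hypothetical and do not exist in the current
# models; the parent lookup needed for on='parent' is also omitted here.
class PermsWithACL(Perms):
    def has(self, obj, perm: _PermWhat='read', on: _PermOn='self', caller: str|None=None) -> bool:
        acl = getattr(obj, 'access', None)   # hypothetical field, e.g. {'read': ['alice'], 'modify': ['alice']}
        if acl is None:
            return True                      # no ACL stored: keep the current permissive behaviour
        return caller is not None and caller in acl.get(perm, [])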

tags_metadata = [
{
"name": "Users",
@@ -172,24 +202,25 @@ def read_root():
@app.get("/usecases/", tags=["Usecases"])
def get_usecases() -> List[models.UseCase_Model]:
res = db.UseCases.find()
return [models.UseCase_Model.model_validate(r) for r in res]
return perms.filterSelfRead([models.UseCase_Model.model_validate(r) for r in res])


@app.get("/usecases/{uid}", tags=["Usecases"])
def get_usecase(uid: str) -> models.UseCase_Model:
res = db.UseCases.find_one({"ucid": uid})
if res is None: raise NotFoundError(f'Database reports no usecase with ucid={uid}.')
return models.UseCase_Model.model_validate(res)
return perms.ensure(models.UseCase_Model.model_validate(res))


@app.get("/usecases/{uid}/workflows", tags=["Usecases"])
def get_usecase_workflows(uid: str) -> List[models.Workflow_Model]:
res = db.Workflows.find({"UseCase": uid})
return [models.Workflow_Model.model_validate(r) for r in res]
return perms.filterSelfRead([models.Workflow_Model.model_validate(r) for r in res])


@app.post("/usecases/", tags=["Usecases"])
def post_usecase(uc: models.UseCase_Model) -> str:
perms.ensure(uc,perm='child',on='parent')
res = db.UseCases.insert_one(uc.model_dump_db())
return str(res.inserted_id)

@@ -201,30 +232,34 @@ def post_usecase(uc: models.UseCase_Model) -> str:
@app.get("/workflows/", tags=["Workflows"])
def get_workflows() -> List[models.Workflow_Model]:
res = db.Workflows.find()
return [models.Workflow_Model.model_validate(r) for r in res]
return perms.filterSelfRead([models.Workflow_Model.model_validate(r) for r in res])

@app.get("/workflows/{workflow_id}", tags=["Workflows"])
def get_workflow(workflow_id: str) -> models.Workflow_Model:
res = db.Workflows.find_one({"wid": workflow_id})
if res is None: raise NotFoundError(f'Database reports no workflow with wid={workflow_id}.')
return models.Workflow_Model.model_validate(res)
return perms.ensure(models.Workflow_Model.model_validate(res))

@app.patch("/workflows/", tags=["Workflows"])
def update_workflow(wf: models.Workflow_Model) -> models.Workflow_Model:
perms.ensure(wf,perm='modify')
# don't write the update if the resulting document does not validate
with db_transaction() as session:
# PERM: self write
res = db.Workflows.find_one_and_update({'wid': wf.wid}, {'$set': wf.model_dump_db()}, return_document=ReturnDocument.AFTER, session=session)
return models.Workflow_Model.model_validate(res)


@app.post("/workflows/", tags=["Workflows"])
def insert_workflow(wf: models.Workflow_Model) -> str:
perms.ensure(wf,perm='child',on='parent')
res = db.Workflows.insert_one(wf.model_dump_db())
return str(res.inserted_id)


@app.post("/workflows_history/", tags=["Workflows"])
def insert_workflow_history(wf: models.Workflow_Model) -> str:
perms.ensure(wf,perm='child',on='parent')
res = db.WorkflowsHistory.insert_one(wf.model_dump_db())
return str(res.inserted_id)

@@ -235,10 +270,9 @@

@app.get("/workflows_history/{workflow_id}/{workflow_version}", tags=["Workflows"])
def get_workflow_history(workflow_id: str, workflow_version: int) -> models.Workflow_Model:
# print(f'AAA: {workflow_id=} {workflow_version=}')
res = db.WorkflowsHistory.find_one({"wid": workflow_id, "Version": workflow_version})
if res is None: raise NotFoundError(f'Database reports no workflow with wid={workflow_id} and Version={workflow_version}.')
return models.Workflow_Model.model_validate(res)
return perms.ensure(models.Workflow_Model.model_validate(res))

# --------------------------------------------------
# Executions
@@ -262,15 +296,14 @@ def get_executions(status: str = "", workflow_version: int = 0, workflow_id: str
#pprint(filtering)
res = db.WorkflowExecutions.find(filtering).sort('SubmittedDate', 1).limit(num_limit)
# pprint(res)
return [models.WorkflowExecution_Model.model_validate(r) for r in res]
return perms.filterSelfRead([models.WorkflowExecution_Model.model_validate(r) for r in res])


@app.get("/executions/{uid}", tags=["Executions"])
def get_execution(uid: str) -> models.WorkflowExecution_Model:
res = db.WorkflowExecutions.find_one({"_id": bson.objectid.ObjectId(uid)})
if res is None: raise NotFoundError(f'Database reports no execution with uid={uid}.')
return models.WorkflowExecution_Model.model_validate(res)

return perms.ensure(models.WorkflowExecution_Model.model_validate(res))

# FIXME: how is this different from get_execution??
@app.get("/edm_execution/{uid}", tags=["Executions"])
@@ -283,7 +316,7 @@ def get_edm_execution_uid(uid: str) -> models.WorkflowExecution_Model:

@app.get("/edm_execution/{uid}/{entity}/{iotype}", tags=["Executions"])
def get_edm_execution_uid_entity_iotype(uid: str, entity: str, iotype: Literal['input','output']) -> List[str]:
obj=get_edm_execution_uid(uid)
obj=get_edm_execution_uid(uid) # handles perms
for m in obj.EDMMapping:
T='input' if (m.createFrom or m.createNew) else 'output'
if T==iotype and m.EDMEntity==entity:
Expand All @@ -295,18 +328,20 @@ def get_edm_execution_uid_entity_iotype(uid: str, entity: str, iotype: Literal['

@app.post("/executions/create/", tags=["Executions"])
def create_execution(wec: models.WorkflowExecutionCreate_Model) -> str:
perms.TODO(wec)
c = mupifDB.workflowmanager.WorkflowExecutionContext.create(workflowID=wec.wid, workflowVer=wec.version, requestedBy='', ip=wec.ip, no_onto=wec.no_onto)
return str(c.executionID)


@app.post("/executions/", tags=["Executions"])
def insert_execution(data: models.WorkflowExecution_Model) -> str:
perms.ensure(data,perm='child',on='parent')
res = db.WorkflowExecutions.insert_one(data.model_dump_db())
return str(res.inserted_id)

@app.get("/executions/{uid}/inputs/", tags=["Executions"])
def get_execution_inputs(uid: str) -> List[models.IODataRecordItem_Model]:
ex = get_execution(uid)
ex = get_execution(uid) # checks perms already
if ex.Inputs: return models.IODataRecord_Model.model_validate(db.IOData.find_one({'_id': bson.objectid.ObjectId(ex.Inputs)})).DataSet
return []

@@ -372,7 +407,7 @@ class M_IODataSetContainer(BaseModel):
# FIXME: validation
def set_execution_io_item(uid: str, name: str, obj_id: str, inputs: bool, data_container):
we = get_execution(uid)

perms.ensure(we,perm='modify',on='self')
if (we.Status == 'Created' and inputs==True) or (we.Status == 'Running' and inputs==False):
with db_transaction() as session:
_id=we.Inputs if inputs else we.Outputs
@@ -414,6 +449,7 @@ class M_ModifyExecutionOntoBaseObjectID(BaseModel):

@app.patch("/executions/{uid}/set_onto_base_object_id/", tags=["Executions"])
def modify_execution_id(uid: str, data: M_ModifyExecutionOntoBaseObjectID):
perms.TODO()
with db_transaction() as session:
rec = db.WorkflowExecutions.find_one_and_update({'_id': bson.objectid.ObjectId(uid), "EDMMapping.Name": data.name}, {"$set": {"EDMMapping.$.id": data.value}}, return_document=ReturnDocument.AFTER, session=session)
models.WorkflowExecution_Model.model_validate(rec)
@@ -445,6 +481,7 @@ class M_ModifyExecution(BaseModel):

@app.patch("/executions/{uid}", tags=["Executions"])
def modify_execution(uid: str, data: M_ModifyExecution):
perms.TODO()
with db_transaction() as session:
rec=db.WorkflowExecutions.find_one_and_update({'_id': bson.objectid.ObjectId(uid)}, {"$set": {data.key: data.value}}, return_document=ReturnDocument.AFTER, session=session)
models.WorkflowExecution_Model.model_validate(rec)
@@ -453,7 +490,7 @@ def modify_execution(uid: str, data: M_ModifyExecution):

@app.patch("/executions/{uid}/schedule", tags=["Executions"])
def schedule_execution(uid: str):
execution_record = get_execution(uid)
execution_record = perms.ensure(get_execution(uid),perm='modify')
if execution_record.Status == 'Created' or True:
data = type('', (), {})()
mod=M_ModifyExecution(key = "Status",value = "Pending")
@@ -468,12 +505,13 @@
@app.get("/iodata/{uid}", tags=["IOData"])
def get_execution_iodata(uid: str) -> models.IODataRecord_Model:
res = db.IOData.find_one({'_id': bson.objectid.ObjectId(uid)})
if res is None: raise NotFoundError(f'Database reports no iodata with uid={uid}.')
return models.IODataRecord_Model.model_validate(res)
if res is None: raise NotFoundError(f'Database reports no IOData with uid={uid}.')
return perms.ensure(models.IODataRecord_Model.model_validate(res))

# TODO: pass and store parent data as well
@app.post("/iodata/", tags=["IOData"])
def insert_execution_iodata(data: models.IODataRecord_Model):
perms.ensure(data,perm='child',on='parent')
res = db.IOData.insert_one(data.model_dump_db())
return str(res.inserted_id)

@@ -501,13 +539,16 @@ def get_file(uid: str, tdir=Depends(get_temp_dir)):
fs = gridfs.GridFS(db)
foundfile = fs.get(bson.objectid.ObjectId(uid))
if not foundfile: raise NotFoundError(f'Database reports no file with {uid=}.')
# open the corresponding record in fs.files to check perms
perms.ensure(models.GridFSFile_Model.model_validate(db.get_collection('fs.files').find_one({'_id': bson.objectid.ObjectId(uid)})))
wfile = io.BytesIO(foundfile.read())
fn = foundfile.filename
return StreamingResponse(wfile, headers={"Content-Disposition": "attachment; filename=" + fn})

# TODO: store parent as metadata, validate the fs.files record as well
# TODO: needs parent as parameter, so that perms can be checked
@app.post("/file/", tags=["Files"])
def upload_file(file: UploadFile):
perms.TODO()
if file:
fs = gridfs.GridFS(db)
sourceID = fs.put(file.file, filename=file.filename)
@@ -518,7 +559,7 @@ def upload_file(file: UploadFile):
@app.get("/property_array_data/{fid}/{i_start}/{i_count}/", tags=["Additional"])
def get_property_array_data(fid: str, i_start: int, i_count: int):
# XXX: make a direct function call, no need to go through REST API again (or is that for granta?)
pfile, fn = mupifDB.restApiControl.getBinaryFileByID(fid)
pfile, fn = mupifDB.restApiControl.getBinaryFileByID(fid) # checks perms
with tempfile.TemporaryDirectory(dir="/tmp", prefix='mupifDB') as tempDir:
full_path = tempDir + "/file.h5"
f = open(full_path, 'wb')
@@ -555,7 +596,8 @@ def get_field_as_vtu(fid: str, tdir=Depends(get_temp_dir)):
# --------------------------------------------------

@app.post("/logs/", tags=["Logs"])
def insert_log(data: dict):
def insert_log(data: dict, request: Request):
perms.notRemote(request,'inserting logging data')
res = db.Logs.insert_one(data)
return str(res.inserted_id)

@@ -666,7 +708,8 @@ class M_ModifyStatistics(BaseModel):


@app.patch("/scheduler_statistics/", tags=["Stats"])
def set_scheduler_statistics(data: M_ModifyStatistics):
def set_scheduler_statistics(data: M_ModifyStatistics, request: Request):
perms.notRemote(request,'modifying scheduler statistics')
if data.key in ["scheduler.runningTasks", "scheduler.scheduledTasks", "scheduler.load", "scheduler.processedTasks"]:
res = db.Stat.update_one({}, {"$set": {data.key: int(data.value)}})
return True