diff --git a/mupifDB/api/main.py b/mupifDB/api/main.py index 092fa97..bd02884 100644 --- a/mupifDB/api/main.py +++ b/mupifDB/api/main.py @@ -15,7 +15,7 @@ import time -from fastapi import FastAPI, UploadFile, Depends, HTTPException +from fastapi import FastAPI, UploadFile, Depends, HTTPException, Request from fastapi.responses import FileResponse, StreamingResponse from fastapi.responses import HTMLResponse from fastapi.middleware.cors import CORSMiddleware @@ -46,6 +46,8 @@ # shorthands for common exceptions def NotFoundError(detail): return HTTPException(status_code=404, detail=detail) +def ForbiddenError(detail): + return HTTPException(status_code=403, detail=detail) from mupifDB import models @@ -62,11 +64,13 @@ def NotFoundError(detail): # the validation would (hopefully) happen automatically # import contextlib -from typing import Generator +from typing import Iterator,TypeVar from pymongo.client_session import ClientSession +from pymongo.collection import Collection +from pymongo.database import Database @contextlib.contextmanager -def db_transaction() -> Generator[ClientSession|None]: +def db_transaction() -> Iterator[ClientSession|None]: if 1: # return None as session object, this makes the context no-op yield None @@ -76,6 +80,32 @@ def db_transaction() -> Generator[ClientSession|None]: yield session +_PermWhat=Literal['read','child','modify'] +_PermOn=Literal['self','parent'] +_PermObj=models.GridFSFile_Model|models.MongoObj_Model +T=TypeVar('T') + +class Perms(object): + def __init__(self, db: Database): + self.db=db + def has(self, obj: Any, perm: _PermWhat='read', on: _PermOn='self') -> bool: + 'TODO: to be implemented' + return True + def ensure(self, obj: T, perm:_PermWhat='read', on: _PermOn='self',diag: str|None=None) -> T: + 'Check permissions on the object (read on obj by default) and return it. Raise ForbiddenError if the check fails.' + if not self.has(obj=obj,perm=perm, on=on): + raise ForbiddenError('Forbidden {perm} access to {"the parent of" if on=="parent" else ""} {obj.__class__.__name__}(dbID={obj.dbID}){": "+diag if diag else ""}.') + return obj + def TODO(*args,**kw): pass + def filterSelfRead(self,objs: List[T]) -> List[T]: return [obj for obj in objs if self.has(obj,perm='read',on='self')] + def notRemote(self, request: Request, diag: str): + import ipaddress + if request.client is None: raise ForbiddenError('Client address unknown ({diag}).') + if not ipaddress.ip_address(request.client.host).is_loopback: raise ForbiddenError('Remote access (from {requests.client.host}) forbidden ({diag}).') + +perms = Perms(db=db) + + tags_metadata = [ { "name": "Users", @@ -172,24 +202,25 @@ def read_root(): @app.get("/usecases/", tags=["Usecases"]) def get_usecases() -> List[models.UseCase_Model]: res = db.UseCases.find() - return [models.UseCase_Model.model_validate(r) for r in res] + return perms.filterSelfRead([m:=models.UseCase_Model.model_validate(r) for r in res]) @app.get("/usecases/{uid}", tags=["Usecases"]) def get_usecase(uid: str) -> models.UseCase_Model: res = db.UseCases.find_one({"ucid": uid}) if res is None: raise NotFoundError(f'Database reports no workflow with ucid={uid}.') - return models.UseCase_Model.model_validate(res) + return perms.ensure(models.UseCase_Model.model_validate(res)) @app.get("/usecases/{uid}/workflows", tags=["Usecases"]) def get_usecase_workflows(uid: str) -> List[models.Workflow_Model]: res = db.Workflows.find({"UseCase": uid}) - return [models.Workflow_Model.model_validate(r) for r in res] + return perms.filterSelfRead([models.Workflow_Model.model_validate(r) for r in res]) @app.post("/usecases/", tags=["Usecases"]) def post_usecase(uc: models.UseCase_Model) -> str: + perms.ensure(uc,perm='child',on='parent') res = db.UseCases.insert_one(uc.model_dump_db()) return str(res.inserted_id) @@ -201,30 +232,34 @@ def post_usecase(uc: models.UseCase_Model) -> str: @app.get("/workflows/", tags=["Workflows"]) def get_workflows() -> List[models.Workflow_Model]: res = db.Workflows.find() - return [models.Workflow_Model.model_validate(r) for r in res] + return perms.filterSelfRead([models.Workflow_Model.model_validate(r) for r in res]) @app.get("/workflows/{workflow_id}", tags=["Workflows"]) def get_workflow(workflow_id: str) -> models.Workflow_Model: res = db.Workflows.find_one({"wid": workflow_id}) if res is None: raise NotFoundError(f'Database reports no workflow with wid={workflow_id}.') - return models.Workflow_Model.model_validate(res) + return perms.ensure(models.Workflow_Model.model_validate(res)) @app.patch("/workflows/", tags=["Workflows"]) def update_workflow(wf: models.Workflow_Model) -> models.Workflow_Model: + perms.ensure(wf,perm='modify') # don't write the result if the result after the update does not validate with db_transaction() as session: + # PERM: self write res = db.Workflows.find_one_and_update({'wid': wf.wid}, {'$set': wf.model_dump_db()}, return_document=ReturnDocument.AFTER, session=session) return models.Workflow_Model.model_validate(res) @app.post("/workflows/", tags=["Workflows"]) def insert_workflow(wf: models.Workflow_Model) -> str: + perms.ensure(wf,perm='child',on='parent') res = db.Workflows.insert_one(wf.model_dump_db()) return str(res.inserted_id) @app.post("/workflows_history/", tags=["Workflows"]) def insert_workflow_history(wf: models.Workflow_Model) -> str: + perms.ensure(wf,perm='child',on='parent') res = db.WorkflowsHistory.insert_one(wf.model_dump_db()) return str(res.inserted_id) @@ -235,10 +270,9 @@ def insert_workflow_history(wf: models.Workflow_Model) -> str: @app.get("/workflows_history/{workflow_id}/{workflow_version}", tags=["Workflows"]) def get_workflow_history(workflow_id: str, workflow_version: int) -> models.Workflow_Model: - # print(f'AAA: {workflow_id=} {workflow_version=}') res = db.WorkflowsHistory.find_one({"wid": workflow_id, "Version": workflow_version}) if res is None: raise NotFoundError(f'Database reports no workflow with wid={workflow_id} and Version={workflow_version}.') - return models.Workflow_Model.model_validate(res) + return perms.ensure(models.Workflow_Model.model_validate(res)) # -------------------------------------------------- # Executions @@ -262,15 +296,14 @@ def get_executions(status: str = "", workflow_version: int = 0, workflow_id: str #pprint(filtering) res = db.WorkflowExecutions.find(filtering).sort('SubmittedDate', 1).limit(num_limit) # pprint(res) - return [models.WorkflowExecution_Model.model_validate(r) for r in res] + return perms.filterSelfRead([models.WorkflowExecution_Model.model_validate(r) for r in res]) @app.get("/executions/{uid}", tags=["Executions"]) def get_execution(uid: str) -> models.WorkflowExecution_Model: res = db.WorkflowExecutions.find_one({"_id": bson.objectid.ObjectId(uid)}) if res is None: raise NotFoundError(f'Database reports no execution with uid={uid}.') - return models.WorkflowExecution_Model.model_validate(res) - + return perms.ensure(models.WorkflowExecution_Model.model_validate(res)) # FIXME: how is this different from get_execution?? @app.get("/edm_execution/{uid}", tags=["Executions"]) @@ -283,7 +316,7 @@ def get_edm_execution_uid(uid: str) -> models.WorkflowExecution_Model: @app.get("/edm_execution/{uid}/{entity}/{iotype}", tags=["Executions"]) def get_edm_execution_uid_entity_iotype(uid: str, entity: str, iotype: Literal['input','output']) -> List[str]: - obj=get_edm_execution_uid(uid) + obj=get_edm_execution_uid(uid) # handles perms for m in obj.EDMMapping: T='input' if (m.createFrom or m.createNew) else 'output' if T==iotype and m.EDMEntity==entity: @@ -295,18 +328,20 @@ def get_edm_execution_uid_entity_iotype(uid: str, entity: str, iotype: Literal[' @app.post("/executions/create/", tags=["Executions"]) def create_execution(wec: models.WorkflowExecutionCreate_Model) -> str: + perms.TODO(wec) c = mupifDB.workflowmanager.WorkflowExecutionContext.create(workflowID=wec.wid, workflowVer=wec.version, requestedBy='', ip=wec.ip, no_onto=wec.no_onto) return str(c.executionID) @app.post("/executions/", tags=["Executions"]) def insert_execution(data: models.WorkflowExecution_Model) -> str: + perms.ensure(data,perm='child',on='parent') res = db.WorkflowExecutions.insert_one(data.model_dump_db()) return str(res.inserted_id) @app.get("/executions/{uid}/inputs/", tags=["Executions"]) def get_execution_inputs(uid: str) -> List[models.IODataRecordItem_Model]: - ex = get_execution(uid) + ex = get_execution(uid) # checks perms already if ex.Inputs: return models.IODataRecord_Model.model_validate(db.IOData.find_one({'_id': bson.objectid.ObjectId(ex.Inputs)})).DataSet return [] @@ -372,7 +407,7 @@ class M_IODataSetContainer(BaseModel): # FIXME: validation def set_execution_io_item(uid: str, name: str, obj_id: str, inputs: bool, data_container): we = get_execution(uid) - + perms.ensure(we,perm='modify',on='self') if (we.Status == 'Created' and inputs==True) or (we.Status == 'Running' and inputs==False): with db_transaction() as session: _id=we.Inputs if inputs else we.Outputs @@ -414,6 +449,7 @@ class M_ModifyExecutionOntoBaseObjectID(BaseModel): @app.patch("/executions/{uid}/set_onto_base_object_id/", tags=["Executions"]) def modify_execution_id(uid: str, data: M_ModifyExecutionOntoBaseObjectID): + perms.TODO() with db_transaction() as session: rec = db.WorkflowExecutions.find_one_and_update({'_id': bson.objectid.ObjectId(uid), "EDMMapping.Name": data.name}, {"$set": {"EDMMapping.$.id": data.value}}, return_document=ReturnDocument.AFTER, session=session) models.WorkflowExecution_Model.model_validate(rec) @@ -445,6 +481,7 @@ class M_ModifyExecution(BaseModel): @app.patch("/executions/{uid}", tags=["Executions"]) def modify_execution(uid: str, data: M_ModifyExecution): + perms.TODO() with db_transaction() as session: rec=db.WorkflowExecutions.find_one_and_update({'_id': bson.objectid.ObjectId(uid)}, {"$set": {data.key: data.value}}, return_document=ReturnDocument.AFTER, session=session) models.WorkflowExecution_Model.model_validate(rec) @@ -453,7 +490,7 @@ def modify_execution(uid: str, data: M_ModifyExecution): @app.patch("/executions/{uid}/schedule", tags=["Executions"]) def schedule_execution(uid: str): - execution_record = get_execution(uid) + execution_record = perms.ensure(get_execution(uid),perm='modify') if execution_record.Status == 'Created' or True: data = type('', (), {})() mod=M_ModifyExecution(key = "Status",value = "Pending") @@ -468,12 +505,13 @@ def schedule_execution(uid: str): @app.get("/iodata/{uid}", tags=["IOData"]) def get_execution_iodata(uid: str) -> models.IODataRecord_Model: res = db.IOData.find_one({'_id': bson.objectid.ObjectId(uid)}) - if res is None: raise NotFoundError(f'Database reports no iodata with uid={uid}.') - return models.IODataRecord_Model.model_validate(res) + if res is None: raise NotFoundError(f'Database reports no IOData with uid={uid}.') + return perms.ensure(models.IODataRecord_Model.model_validate(res)) # TODO: pass and store parent data as well @app.post("/iodata/", tags=["IOData"]) def insert_execution_iodata(data: models.IODataRecord_Model): + perms.ensure(data,perm='child',on='parent') res = db.IOData.insert_one(data.model_dump_db()) return str(res.inserted_id) @@ -501,13 +539,16 @@ def get_file(uid: str, tdir=Depends(get_temp_dir)): fs = gridfs.GridFS(db) foundfile = fs.get(bson.objectid.ObjectId(uid)) if not foundfile: raise NotFoundError('Database reports no file with {uid=}.') + # open the corresponding record in fs.files to check perms + perms.ensure(models.GridFSFile_Model.model_validate(db.get_collection('fs.files').find_one({'_id': bson.objectid.ObjectId(uid)}))) wfile = io.BytesIO(foundfile.read()) fn = foundfile.filename return StreamingResponse(wfile, headers={"Content-Disposition": "attachment; filename=" + fn}) -# TODO: store parent as metadata, validate the fs.files record as well +# TODO: needs parent as parameter, so that perms can be checked @app.post("/file/", tags=["Files"]) def upload_file(file: UploadFile): + perms.TODO() if file: fs = gridfs.GridFS(db) sourceID = fs.put(file.file, filename=file.filename) @@ -518,7 +559,7 @@ def upload_file(file: UploadFile): @app.get("/property_array_data/{fid}/{i_start}/{i_count}/", tags=["Additional"]) def get_property_array_data(fid: str, i_start: int, i_count: int): # XXX: make a direct function call, no need to go through REST API again (or is that for granta?) - pfile, fn = mupifDB.restApiControl.getBinaryFileByID(fid) + pfile, fn = mupifDB.restApiControl.getBinaryFileByID(fid) # checks perms with tempfile.TemporaryDirectory(dir="/tmp", prefix='mupifDB') as tempDir: full_path = tempDir + "/file.h5" f = open(full_path, 'wb') @@ -555,7 +596,8 @@ def get_field_as_vtu(fid: str, tdir=Depends(get_temp_dir)): # -------------------------------------------------- @app.post("/logs/", tags=["Logs"]) -def insert_log(data: dict): +def insert_log(data: dict, request: Request): + perms.notRemote(request,'inserting logging data') res = db.Logs.insert_one(data) return str(res.inserted_id) @@ -666,7 +708,8 @@ class M_ModifyStatistics(BaseModel): @app.patch("/scheduler_statistics/", tags=["Stats"]) -def set_scheduler_statistics(data: M_ModifyStatistics): +def set_scheduler_statistics(data: M_ModifyStatistics, request: Request): + perms.notRemote(request,'modifying scheduler statistics') if data.key in ["scheduler.runningTasks", "scheduler.scheduledTasks", "scheduler.load", "scheduler.processedTasks"]: res = db.Stat.update_one({}, {"$set": {data.key: int(data.value)}}) return True