Remove engine reuse #5547

Open · wants to merge 5 commits into base: release-2.50.0
1 change: 1 addition & 0 deletions src/fides/api/schemas/saas/strategy_configuration.py
@@ -63,6 +63,7 @@ class LinkPaginationConfiguration(StrategyConfiguration):
source: LinkSource
rel: Optional[str] = None
path: Optional[str] = None
has_next: Optional[str] = None

@model_validator(mode="before")
@classmethod
7 changes: 7 additions & 0 deletions src/fides/api/service/pagination/pagination_strategy_link.py
@@ -23,6 +23,7 @@ def __init__(self, configuration: LinkPaginationConfiguration):
self.source = configuration.source
self.rel = configuration.rel
self.path = configuration.path
self.has_next = configuration.has_next

def get_next_request(
self,
@@ -40,6 +41,12 @@ def get_next_request(
if not response_data:
return None

if self.has_next:
has_next = pydash.get(response.json(), self.has_next)
logger.info(f"The {self.has_next} field has a value of {has_next}")
if str(has_next).lower() != "true":
return None

# read the next_link from the correct location based on the source value
next_link = None
if self.source == LinkSource.headers.value:
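For reference, a minimal sketch of how a connector could opt into the new conditional check; the links.next and links.hasNext paths mirror the test fixtures added later in this PR and are illustrative, not tied to any real connector:

from fides.api.schemas.saas.strategy_configuration import LinkPaginationConfiguration
from fides.api.service.pagination.pagination_strategy_link import LinkPaginationStrategy

# Pagination config: follow the URL at links.next in the response body,
# but only while the boolean at links.hasNext is truthy.
config = LinkPaginationConfiguration(
    source="body",
    path="links.next",
    has_next="links.hasNext",
)
strategy = LinkPaginationStrategy(config)

# Because the comparison is string-based (str(has_next).lower() != "true"),
# both a JSON true and the string "true" keep pagination going, while false,
# "false", None, or a missing field stop it.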
123 changes: 88 additions & 35 deletions src/fides/api/task/execute_request_tasks.py
@@ -2,7 +2,15 @@

from celery.app.task import Task
from loguru import logger
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import Query, Session
from tenacity import (
RetryCallState,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)

from fides.api.common_exceptions import (
PrivacyRequestCanceled,
@@ -35,14 +43,13 @@
# DSR 3.0 task functions


def run_prerequisite_task_checks(
def get_privacy_request_and_task(
session: Session, privacy_request_id: str, privacy_request_task_id: str
) -> Tuple[PrivacyRequest, RequestTask, Query]:
) -> Tuple[PrivacyRequest, RequestTask]:
"""
Upfront checks that run as soon as the RequestTask is executed by the worker.

Returns resources for use in executing a task
Retrieves and validates a privacy request and its associated task
"""

privacy_request: Optional[PrivacyRequest] = PrivacyRequest.get(
db=session, object_id=privacy_request_id
)
@@ -65,6 +72,22 @@ def run_prerequisite_task_checks(
f"Request Task with id {privacy_request_task_id} not found for privacy request {privacy_request_id}"
)

return privacy_request, request_task


def run_prerequisite_task_checks(
session: Session, privacy_request_id: str, privacy_request_task_id: str
) -> Tuple[PrivacyRequest, RequestTask, Query]:
"""
Upfront checks that run as soon as the RequestTask is executed by the worker.

Returns resources for use in executing a task
"""

privacy_request, request_task = get_privacy_request_and_task(
session, privacy_request_id, privacy_request_task_id
)

assert request_task # For mypy

upstream_results: Query = request_task.upstream_tasks_objects(session)
@@ -146,6 +169,43 @@ def can_run_task_body(
return True


def log_retry_attempt(retry_state: RetryCallState) -> None:
"""Log queue_downstream_tasks retry attempts."""

logger.warning(
"queue_downstream_tasks attempt {} failed. Retrying in {} seconds...",
retry_state.attempt_number,
retry_state.next_action.sleep, # type: ignore[union-attr]
)


@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=1),
retry=retry_if_exception_type(OperationalError),
before_sleep=log_retry_attempt,
)
def queue_downstream_tasks_with_retries(
database_task: DatabaseTask,
privacy_request_id: str,
privacy_request_task_id: str,
current_step: CurrentStep,
privacy_request_proceed: bool,
) -> None:
with database_task.get_new_session() as session:
privacy_request, request_task = get_privacy_request_and_task(
session, privacy_request_id, privacy_request_task_id
)
log_task_complete(request_task)
queue_downstream_tasks(
session,
request_task,
privacy_request,
current_step,
privacy_request_proceed,
)


def queue_downstream_tasks(
session: Session,
request_task: RequestTask,
@@ -233,16 +293,15 @@ def run_access_node(
]
# Run the main access function
graph_task.access_request(*upstream_access_data)
log_task_complete(request_task)

with self.get_new_session() as session:
queue_downstream_tasks(
session,
request_task,
privacy_request,
CurrentStep.upload_access,
privacy_request_proceed,
)
logger.info(f"Session ID - After get access data: {id(session)}")

queue_downstream_tasks_with_retries(
self,
privacy_request_id,
privacy_request_task_id,
CurrentStep.upload_access,
privacy_request_proceed,
)


@celery_app.task(base=DatabaseTask, bind=True)
@@ -285,16 +344,13 @@ def run_erasure_node(
# Run the main erasure function!
graph_task.erasure_request(retrieved_data)

log_task_complete(request_task)

with self.get_new_session() as session:
queue_downstream_tasks(
session,
request_task,
privacy_request,
CurrentStep.finalize_erasure,
privacy_request_proceed,
)
queue_downstream_tasks_with_retries(
self,
privacy_request_id,
privacy_request_task_id,
CurrentStep.finalize_erasure,
privacy_request_proceed,
)


@celery_app.task(base=DatabaseTask, bind=True)
@@ -339,16 +395,13 @@ def run_consent_node(

graph_task.consent_request(access_data[0] if access_data else {})

log_task_complete(request_task)

with self.get_new_session() as session:
queue_downstream_tasks(
session,
request_task,
privacy_request,
CurrentStep.finalize_consent,
privacy_request_proceed,
)
queue_downstream_tasks_with_retries(
self,
privacy_request_id,
privacy_request_task_id,
CurrentStep.finalize_consent,
privacy_request_proceed,
)


def logger_method(request_task: RequestTask) -> Callable:
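To illustrate the retry cadence of the new queue_downstream_tasks_with_retries wrapper in isolation, here is a self-contained sketch using the same tenacity policy; flaky_commit and its simulated failure are hypothetical stand-ins, not code from this PR:

from sqlalchemy.exc import OperationalError
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

attempts = {"count": 0}

@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=1),
    retry=retry_if_exception_type(OperationalError),
)
def flaky_commit() -> str:
    # Stand-in for the retried work: the first two calls raise the transient
    # error class the policy retries on, the third succeeds.
    attempts["count"] += 1
    if attempts["count"] < 3:
        raise OperationalError("SELECT 1", {}, Exception("connection dropped"))
    return "downstream tasks queued"

print(flaky_commit())  # succeeds on the third attempt after a short exponential backoff

Only OperationalError triggers a retry; any other exception propagates immediately, and after five failed attempts tenacity gives up and raises RetryError (the default when reraise is not set).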
4 changes: 3 additions & 1 deletion src/fides/api/tasks/__init__.py
@@ -2,6 +2,7 @@

from celery import Celery, Task
from loguru import logger
from sqlalchemy.engine import Engine
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import Session
from tenacity import (
@@ -64,10 +65,11 @@ def get_new_session(self) -> ContextManager[Session]:
if self._task_engine is None:
self._task_engine = get_db_engine(
config=CONFIG,
pool_size=CONFIG.database.task_engine_pool_size,
max_overflow=CONFIG.database.task_engine_max_overflow,
keepalives_idle=CONFIG.database.task_engine_keepalives_idle,
keepalives_interval=CONFIG.database.task_engine_keepalives_interval,
keepalives_count=CONFIG.database.task_engine_keepalives_count,
disable_pooling=True,
)

# same for the sessionmaker
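Here the pool-size and max-overflow arguments are replaced by a disable_pooling flag, in line with the PR title's goal of not reusing a shared engine and connection pool across worker tasks. As a rough illustration only, and assuming get_db_engine maps the flag to SQLAlchemy's NullPool (an assumption about its internals, not confirmed by this diff):

from sqlalchemy import create_engine
from sqlalchemy.pool import NullPool

# With NullPool, every checkout opens a fresh database connection and closing
# the connection really closes it, so no idle connections linger between tasks.
engine = create_engine(
    "postgresql+psycopg2://user:password@localhost:5432/fides",  # placeholder URL
    poolclass=NullPool,
)

with engine.connect() as conn:
    conn.exec_driver_sql("SELECT 1")  # opened here, closed when the block exits

The trade-off is a new connection per session, in exchange for never handing a long-lived Celery worker a stale or dropped pooled connection.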
80 changes: 80 additions & 0 deletions tests/ops/service/pagination/test_pagination_strategy_link.py
@@ -49,6 +49,42 @@ def response_with_empty_string_link():
return response


@pytest.fixture(scope="function")
def response_with_has_next_conditional_true():
response = Response()
response._content = bytes(
json.dumps(
{
"customers": [{"id": 1}, {"id": 2}, {"id": 3}],
"links": {
"next": "https://domain.com/customers?page=def",
"hasNext": True,
},
}
),
"utf-8",
)
return response


@pytest.fixture(scope="function")
def response_with_has_next_conditional_false():
response = Response()
response._content = bytes(
json.dumps(
{
"customers": [{"id": 1}, {"id": 2}, {"id": 3}],
"links": {
"next": "https://domain.com/customers?page=abc",
"hasNext": False,
},
}
),
"utf-8",
)
return response


def test_link_in_headers(response_with_header_link):
config = LinkPaginationConfiguration(source="headers", rel="next")
request_params: SaaSRequestParams = SaaSRequestParams(
@@ -132,6 +168,50 @@ def test_link_in_body_empty_string(response_with_empty_string_link):
assert next_request is None


# Tests for when the link exists but a conditional boolean indicates whether there is a next page
def test_link_in_body_with_conditional_boolean_true(
response_with_has_next_conditional_true,
):
config = LinkPaginationConfiguration(
source="body", path="links.next", has_next="links.hasNext"
)
request_params: SaaSRequestParams = SaaSRequestParams(
method=HTTPMethod.GET,
path="/customers",
query_params={"page": "abc"},
)

paginator = LinkPaginationStrategy(config)
next_request: Optional[SaaSRequestParams] = paginator.get_next_request(
request_params, {}, response_with_has_next_conditional_true, "customers"
)

assert next_request == SaaSRequestParams(
method=HTTPMethod.GET,
path="/customers",
query_params={"page": "def"},
)


def test_link_in_body_with_conditional_boolean_false(
response_with_has_next_conditional_false,
):
config = LinkPaginationConfiguration(
source="body", path="links.next", has_next="links.hasNext"
)
request_params: SaaSRequestParams = SaaSRequestParams(
method=HTTPMethod.GET,
path="/customers",
query_params={"page": "abc"},
)

paginator = LinkPaginationStrategy(config)
next_request: Optional[SaaSRequestParams] = paginator.get_next_request(
request_params, {}, response_with_has_next_conditional_false, "customers"
)
assert next_request is None


def test_wrong_source():
with pytest.raises(ValueError) as exc:
LinkPaginationConfiguration(source="somewhere", path="links.next")