diff --git a/src/backend/base/langflow/alembic/versions/a6faa131285d_add_job_table.py b/src/backend/base/langflow/alembic/versions/a6faa131285d_add_job_table.py new file mode 100644 index 000000000000..7ad1ee6d1b1c --- /dev/null +++ b/src/backend/base/langflow/alembic/versions/a6faa131285d_add_job_table.py @@ -0,0 +1,67 @@ +"""add job table + +Revision ID: a6faa131285d +Revises: e3162c1804e6 +Create Date: 2024-12-23 10:54:57.844827 + +""" +from typing import Sequence, Union + +import sqlalchemy as sa +import sqlmodel +from alembic import op +from sqlalchemy.engine.reflection import Inspector + +from langflow.utils import migration + +# revision identifiers, used by Alembic. +revision: str = 'a6faa131285d' +down_revision: Union[str, None] = 'e3162c1804e6' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + conn = op.get_bind() + inspector = sa.inspect(conn) # type: ignore + table_names = inspector.get_table_names() + + # Create job table if it doesn't exist + if "job" not in table_names: + op.create_table( + "job", + sa.Column("id", sqlmodel.sql.sqltypes.AutoString(length=191), primary_key=True), + sa.Column("next_run_time", sa.DateTime(timezone=True), nullable=True), + sa.Column("job_state", sa.LargeBinary(), nullable=True), + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("flow_id", sa.Uuid(), nullable=False), + sa.Column("user_id", sa.Uuid(), nullable=False), + sa.Column("is_active", sa.Boolean(), nullable=False, server_default=sa.true()), + sa.Column("created_at", sa.DateTime(), nullable=False, server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime(), nullable=False, server_default=sa.func.now(), onupdate=sa.func.now()), + sa.ForeignKeyConstraint(["flow_id"], ["flow.id"], name="fk_job_flow_id_flow", ondelete="CASCADE"), + sa.ForeignKeyConstraint(["user_id"], ["user.id"], name="fk_job_user_id_user", ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id", name="pk_job"), + ) + + # Create indices + with op.batch_alter_table("job", schema=None) as batch_op: + batch_op.create_index(batch_op.f("ix_job_name"), ["name"], unique=False) + batch_op.create_index(batch_op.f("ix_job_flow_id"), ["flow_id"], unique=False) + batch_op.create_index(batch_op.f("ix_job_user_id"), ["user_id"], unique=False) + + +def downgrade() -> None: + conn = op.get_bind() + inspector = sa.inspect(conn) # type: ignore + table_names = inspector.get_table_names() + + if "job" in table_names: + # Drop indices first + with op.batch_alter_table("job", schema=None) as batch_op: + batch_op.drop_index("ix_job_name") + batch_op.drop_index("ix_job_flow_id") + batch_op.drop_index("ix_job_user_id") + + # Drop the table + op.drop_table("job") diff --git a/src/backend/base/langflow/api/router.py b/src/backend/base/langflow/api/router.py index d2ce1905ada0..d94201182214 100644 --- a/src/backend/base/langflow/api/router.py +++ b/src/backend/base/langflow/api/router.py @@ -8,6 +8,7 @@ files_router, flows_router, folders_router, + jobs_router, login_router, monitor_router, starter_projects_router, @@ -33,3 +34,4 @@ router.include_router(monitor_router) router.include_router(folders_router) router.include_router(starter_projects_router) +router.include_router(jobs_router) diff --git a/src/backend/base/langflow/api/v1/__init__.py b/src/backend/base/langflow/api/v1/__init__.py index 48383770ab77..3d65baf49029 100644 --- a/src/backend/base/langflow/api/v1/__init__.py +++ 
b/src/backend/base/langflow/api/v1/__init__.py @@ -4,6 +4,7 @@ from langflow.api.v1.files import router as files_router from langflow.api.v1.flows import router as flows_router from langflow.api.v1.folders import router as folders_router +from langflow.api.v1.jobs import router as jobs_router from langflow.api.v1.login import router as login_router from langflow.api.v1.monitor import router as monitor_router from langflow.api.v1.starter_projects import router as starter_projects_router @@ -19,6 +20,7 @@ "files_router", "flows_router", "folders_router", + "jobs_router", "login_router", "monitor_router", "starter_projects_router", diff --git a/src/backend/base/langflow/api/v1/endpoints.py b/src/backend/base/langflow/api/v1/endpoints.py index d150bdb1576d..8db7ffe45e24 100644 --- a/src/backend/base/langflow/api/v1/endpoints.py +++ b/src/backend/base/langflow/api/v1/endpoints.py @@ -310,6 +310,7 @@ async def webhook_run_flow( HTTPException: If the flow is not found or if there is an error processing the request. """ telemetry_service = get_telemetry_service() + task_service = get_task_service() start_time = time.perf_counter() logger.debug("Received webhook request") error_msg = "" @@ -339,12 +340,17 @@ async def webhook_run_flow( session_id=None, ) - logger.debug("Starting background task") - background_tasks.add_task( - simple_run_flow_task, - flow=flow, - input_request=input_request, - api_key_user=user, + logger.debug("Creating job") + job_id = await task_service.create_job( + task_func=simple_run_flow_task, + run_at=None, + name=f"webhook_{flow.name}_{time.time()}", + kwargs={ + "flow": flow, + "input_request": input_request, + "stream": False, + "api_key_user": user, + }, ) except Exception as exc: error_msg = str(exc) @@ -360,7 +366,7 @@ async def webhook_run_flow( ), ) - return {"message": "Task started in the background", "status": "in progress"} + return {"message": "Job created successfully", "status": "pending", "job_id": job_id} @router.post( @@ -507,28 +513,15 @@ async def process() -> None: @router.get("/task/{task_id}") -async def get_task_status(task_id: str) -> TaskStatusResponse: - task_service = get_task_service() - task = task_service.get_task(task_id) - result = None - if task is None: - raise HTTPException(status_code=404, detail="Task not found") - if task.ready(): - result = task.result - # If result isinstance of Exception, can we get the traceback? - if isinstance(result, Exception): - logger.exception(task.traceback) - - if isinstance(result, dict) and "result" in result: - result = result["result"] - elif hasattr(result, "result"): - result = result.result - - if task.status == "FAILURE": - result = str(task.result) - logger.error(f"Task {task_id} failed: {task.traceback}") - - return TaskStatusResponse(status=task.status, result=result) +async def get_task_status(task_id: str) -> TaskStatusResponse: # noqa: ARG001 + # Deprecate this endpoint + logger.warning( + "The /task endpoint is deprecated and will be removed in a future version. Please use /jobs instead." + ) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="The /task endpoint is deprecated and will be removed in a future version. 
Please use /jobs instead.", + ) @router.post( diff --git a/src/backend/base/langflow/api/v1/jobs.py b/src/backend/base/langflow/api/v1/jobs.py new file mode 100644 index 000000000000..eb94b9ce412e --- /dev/null +++ b/src/backend/base/langflow/api/v1/jobs.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +from typing import Annotated + +from fastapi import APIRouter, Depends, HTTPException +from loguru import logger +from pydantic import BaseModel, Field + +from langflow.api.utils import CurrentActiveUser +from langflow.api.v1.endpoints import simple_run_flow_task +from langflow.api.v1.schemas import SimplifiedAPIRequest +from langflow.helpers.flow import get_flow_by_id_or_endpoint_name +from langflow.services.database.models.flow import Flow +from langflow.services.deps import get_task_service +from langflow.services.task.service import TaskService + +router = APIRouter(prefix="/jobs", tags=["Jobs"]) + + +class CreateJobRequest(BaseModel): + """Request model for creating a task.""" + + name: str | None = None + input_request: SimplifiedAPIRequest = Field(..., description="Input request for the flow") + + +class TaskResponse(BaseModel): + """Response model for task operations.""" + + id: str + name: str + pending: bool + + +@router.post("/{flow_id_or_name}", response_model=str) +async def create_job( + request: CreateJobRequest, + user: CurrentActiveUser, + flow: Annotated[Flow, Depends(get_flow_by_id_or_endpoint_name)], +) -> str: + """Create a new job.""" + try: + task_service = get_task_service() + return await task_service.create_job( + task_func=simple_run_flow_task, + run_at=None, + name=request.name, + kwargs={ + "flow": flow, + "input_request": request.input_request, + "stream": False, + "api_key_user": user, + }, + ) + except Exception as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + + +@router.get("/{task_id}", response_model=TaskResponse) +async def get_task( + task_id: str, + user: CurrentActiveUser, +) -> TaskResponse: + """Get task information.""" + task_service: TaskService = get_task_service() + task = await task_service.get_job(task_id, user.id) + if not task: + raise HTTPException(status_code=404, detail="Task not found") + logger.info(f"Task: {task}") + return TaskResponse.model_validate(task, from_attributes=True) + + +@router.get("/", response_model=list[TaskResponse]) +async def get_tasks( + user: CurrentActiveUser, + task_service: Annotated[TaskService, Depends(get_task_service)], + pending: bool | None = None, +) -> list[TaskResponse]: + """Get all tasks for the current user.""" + tasks = await task_service.get_jobs(user_id=user.id, pending=pending) + return [TaskResponse.model_validate(task, from_attributes=True) for task in tasks] + + +@router.delete("/{task_id}") +async def cancel_task( + task_id: str, + user: CurrentActiveUser, + task_service: Annotated[TaskService, Depends(get_task_service)], +) -> bool: + """Cancel a task.""" + success = await task_service.cancel_job(task_id, user.id) + if not success: + raise HTTPException(status_code=404, detail="Task not found") + return True diff --git a/src/backend/base/langflow/services/base.py b/src/backend/base/langflow/services/base.py index a903332e1591..29b4bb4c889b 100644 --- a/src/backend/base/langflow/services/base.py +++ b/src/backend/base/langflow/services/base.py @@ -24,5 +24,8 @@ def get_schema(self): async def teardown(self) -> None: return + async def setup(self) -> None: + return + def set_ready(self) -> None: self.ready = True diff --git 
a/src/backend/base/langflow/services/database/models/__init__.py b/src/backend/base/langflow/services/database/models/__init__.py index 4419e7f1109e..d2cbe4181456 100644 --- a/src/backend/base/langflow/services/database/models/__init__.py +++ b/src/backend/base/langflow/services/database/models/__init__.py @@ -1,9 +1,10 @@ from .api_key import ApiKey from .flow import Flow from .folder import Folder +from .job import Job from .message import MessageTable from .transactions import TransactionTable from .user import User from .variable import Variable -__all__ = ["ApiKey", "Flow", "Folder", "MessageTable", "TransactionTable", "User", "Variable"] +__all__ = ["ApiKey", "Flow", "Folder", "Job", "MessageTable", "TransactionTable", "User", "Variable"] diff --git a/src/backend/base/langflow/services/database/models/job/__init__.py b/src/backend/base/langflow/services/database/models/job/__init__.py new file mode 100644 index 000000000000..73300556179c --- /dev/null +++ b/src/backend/base/langflow/services/database/models/job/__init__.py @@ -0,0 +1,3 @@ +from .model import Job + +__all__ = ["Job"] diff --git a/src/backend/base/langflow/services/database/models/job/model.py b/src/backend/base/langflow/services/database/models/job/model.py new file mode 100644 index 000000000000..d7d399928a32 --- /dev/null +++ b/src/backend/base/langflow/services/database/models/job/model.py @@ -0,0 +1,60 @@ +from datetime import datetime +from enum import Enum +from uuid import UUID + +import sqlalchemy as sa +from sqlmodel import JSON, Boolean, Column, DateTime, Field, SQLModel + + +class JobStatus(str, Enum): + """Job status enum.""" + + PENDING = "PENDING" + RUNNING = "RUNNING" + COMPLETED = "COMPLETED" + FAILED = "FAILED" + CANCELLED = "CANCELLED" + + +class Job(SQLModel, table=True): # type: ignore[call-arg] + """Model for storing scheduled jobs. + + This model extends APScheduler's job table with additional metadata for Langflow. + The core APScheduler fields (id, next_run_time, job_state) are used directly by APScheduler, + while the additional fields are used by Langflow for UI/API purposes. 
+ """ + + # APScheduler required fields + id: str = Field(max_length=191, primary_key=True) + next_run_time: datetime | None = Field(sa_column=Column(sa.DateTime(timezone=True)), default=None) + job_state: bytes | None = Field(sa_column=Column(sa.LargeBinary), default=None) + + # Additional Langflow metadata + status: str = Field(default=JobStatus.PENDING) + result: dict | None = Field(sa_column=Column(JSON), default=None) + error: str | None = Field(default=None) + name: str = Field(index=True) + flow_id: UUID = Field(foreign_key="flow.id", index=True) + user_id: UUID = Field(foreign_key="user.id", index=True) + is_active: bool = Field(default=True, sa_column=Column(Boolean, server_default="true", nullable=False)) + created_at: datetime = Field(sa_column=Column(DateTime, server_default=sa.func.now(), nullable=False)) + updated_at: datetime = Field( + sa_column=Column(DateTime, server_default=sa.func.now(), nullable=False, onupdate=sa.func.now()) + ) + + +class JobRead(SQLModel): + """Model for reading scheduled jobs.""" + + id: str + job_state: bytes | None + next_run_time: datetime | None + status: str + + name: str + flow_id: UUID + user_id: UUID + is_active: bool + created_at: datetime + updated_at: datetime + result: dict | None diff --git a/src/backend/base/langflow/services/manager.py b/src/backend/base/langflow/services/manager.py index 8c66f1ae853e..a1e1eb6dd16f 100644 --- a/src/backend/base/langflow/services/manager.py +++ b/src/backend/base/langflow/services/manager.py @@ -87,6 +87,13 @@ def update(self, service_name: ServiceType) -> None: self.services.pop(service_name, None) self.get(service_name) + async def setup(self) -> None: + """Initialize all the services.""" + for service in self.services.values(): + result = service.setup() + if inspect.iscoroutine(result): + await result + async def teardown(self) -> None: """Teardown all the services.""" for service in self.services.values(): diff --git a/src/backend/base/langflow/services/task/factory.py b/src/backend/base/langflow/services/task/factory.py index c030776de5ea..5c7ea09d39ec 100644 --- a/src/backend/base/langflow/services/task/factory.py +++ b/src/backend/base/langflow/services/task/factory.py @@ -1,4 +1,5 @@ from langflow.services.factory import ServiceFactory +from langflow.services.settings.service import SettingsService from langflow.services.task.service import TaskService @@ -6,6 +7,6 @@ class TaskServiceFactory(ServiceFactory): def __init__(self) -> None: super().__init__(TaskService) - def create(self): - # Here you would have logic to create and configure a TaskService - return TaskService() + def create(self, settings_service: SettingsService) -> TaskService: + """Create a new TaskService instance with the required dependencies.""" + return TaskService(settings_service=settings_service) diff --git a/src/backend/base/langflow/services/task/jobstore.py b/src/backend/base/langflow/services/task/jobstore.py new file mode 100644 index 000000000000..972a4863ab72 --- /dev/null +++ b/src/backend/base/langflow/services/task/jobstore.py @@ -0,0 +1,222 @@ +import pickle +from datetime import datetime, timezone +from uuid import UUID + +from apscheduler.job import Job as APSJob +from apscheduler.jobstores.base import BaseJobStore, JobLookupError +from apscheduler.triggers.date import DateTrigger +from loguru import logger +from sqlmodel import select + +from langflow.services.database.models.job import Job +from langflow.services.deps import session_scope + + +class AsyncSQLModelJobStore(BaseJobStore): + """A job store 
that uses SQLModel to store jobs in the Langflow database. + + Currently only supports one-off tasks. + """ + + def __init__(self): + super().__init__() + self._jobs = {} + + async def get_all_jobs(self) -> list[Job]: + """Get all jobs in the store.""" + async with session_scope() as session: + stmt = select(Job) + tasks = (await session.exec(stmt)).all() + + jobs = [] + for task in tasks: + try: + job_state = pickle.loads(task.job_state) # noqa: S301 + job = self._reconstitute_job(job_state) + self._jobs[job.id] = job + jobs.append(job) + except Exception: # noqa: BLE001 + logger.exception(f"Unable to restore job {task.id}") + await session.delete(task) + + await session.commit() + return jobs + + async def lookup_job(self, job_id: str, user_id: UUID | None = None) -> APSJob | None: + """Get job by ID.""" + async with session_scope() as session: + stmt = select(Job).where(Job.id == job_id) + if user_id: + if isinstance(user_id, str): + user_id = UUID(user_id) + stmt = stmt.where(Job.user_id == user_id) + db_job = (await session.exec(stmt)).first() + if not db_job: + return None + + try: + job: APSJob = self._reconstitute_job(db_job.job_state) + self._jobs[job_id] = job + except Exception: # noqa: BLE001 + logger.exception(f"Unable to restore job {job_id}") + await session.delete(db_job) + await session.commit() + return None + return job + + async def get_due_jobs(self, now: datetime) -> list[Job]: + """Get all jobs that should be run at the given time.""" + async with session_scope() as session: + stmt = select(Job).where( + Job.next_run_time <= now, + Job.is_active == True, # noqa: E712 + ) + tasks = (await session.exec(stmt)).all() + + jobs = [] + for task in tasks: + try: + job_state = pickle.loads(task.job_state) # noqa: S301 + job = self._reconstitute_job(job_state) + self._jobs[job.id] = job + jobs.append(job) + except Exception: # noqa: BLE001 + logger.exception(f"Unable to restore job {task.id}") + await session.delete(task) + + await session.commit() + return jobs + + async def get_next_run_time(self) -> datetime | None: + """Get the earliest timestamp of all scheduled jobs.""" + async with session_scope() as session: + stmt = ( + select(Job) + .where(Job.is_active == True) # noqa: E712 + .order_by(Job.next_run_time) + ) + task = (await session.exec(stmt)).first() + return task.next_run_time if task else None + + async def add_job(self, job: APSJob) -> None: + """Add a one-off job.""" + if not isinstance(job.trigger, DateTrigger): + msg = "Only one-off tasks are supported" + raise TypeError(msg) + + job_state = pickle.dumps(job.__getstate__()) + + async with session_scope() as session: + if "flow" not in job.kwargs or "api_key_user" not in job.kwargs: + msg = f"Job invalid: {job}" + raise ValueError(msg) + + flow = job.kwargs.get("flow") + api_key_user = job.kwargs.get("api_key_user") + + flow_id = flow.get("id") if isinstance(flow, dict) else flow.id + if isinstance(flow_id, str): + flow_id = UUID(flow_id) + + api_key_user_id = api_key_user.get("id") if isinstance(api_key_user, dict) else api_key_user.id + + if isinstance(api_key_user_id, str): + api_key_user_id = UUID(api_key_user_id) + + # Check for ids + if not isinstance(flow_id, UUID) or not isinstance(api_key_user_id, UUID): + msg = f"Job invalid: {job}" + raise TypeError(msg) + + try: + task = Job( + id=job.id, + name=job.name, + flow_id=flow_id, + user_id=api_key_user_id, + is_active=True, + next_run_time=job.next_run_time, + job_state=job_state, + ) + + session.add(task) + await session.commit() + await 
session.refresh(task) + self._jobs[job.id] = job + except Exception as exc: + logger.exception(f"Unable to add job {job.id}") + msg = f"Job invalid: {job}" + raise ValueError(msg) from exc + + async def update_job(self, job: APSJob) -> None: + """Update a job in the store.""" + async with session_scope() as session: + stmt = select(Job).where(Job.id == job.id) + task = (await session.exec(stmt)).first() + if not task: + raise JobLookupError(job.id) + + job_state = job.__getstate__() + task.name = job.name + task.next_run_time = job.next_run_time + task.job_state = job_state + task.updated_at = datetime.now(timezone.utc) + + session.add(task) + await session.commit() + await session.refresh(task) + self._jobs[job.id] = job + + async def remove_job(self, job_id: str) -> None: + """Remove a job.""" + async with session_scope() as session: + stmt = select(Job).where(Job.id == job_id) + task = (await session.exec(stmt)).first() + if not task: + raise JobLookupError(job_id) + + await session.delete(task) + await session.commit() + self._jobs.pop(job_id, None) + + async def remove_all_jobs(self) -> None: + """Remove all jobs.""" + async with session_scope() as session: + stmt = select(Job) + tasks = (await session.exec(stmt)).all() + for task in tasks: + await session.delete(task) + await session.commit() + self._jobs.clear() + + async def get_user_jobs(self, user_id: UUID, pending: bool | None = None) -> list[Job]: + """Get all jobs for a specific user.""" + async with session_scope() as session: + stmt = select(Job).where(Job.user_id == user_id) + tasks = (await session.exec(stmt)).all() + + jobs = [] + for task in tasks: + try: + job_state = pickle.loads(task.job_state) # noqa: S301 + job = self._reconstitute_job(job_state) + if pending is not None and job.pending != pending: + continue + self._jobs[job.id] = job + jobs.append(job) + except Exception: # noqa: BLE001 + logger.exception(f"Unable to restore job {task.id}") + await session.delete(task) + + await session.commit() + return jobs + + def _reconstitute_job(self, job_state): + """Reconstitute a job from its serialized state.""" + job_state_dict = job_state if isinstance(job_state, dict) else pickle.loads(job_state) # noqa: S301 + job_state_dict["jobstore"] = self + job = APSJob.__new__(APSJob) + job.__setstate__(job_state_dict) + job._scheduler = self._scheduler + job._jobstore_alias = self._alias + return job diff --git a/src/backend/base/langflow/services/task/scheduler.py b/src/backend/base/langflow/services/task/scheduler.py new file mode 100644 index 000000000000..4e026bfd858e --- /dev/null +++ b/src/backend/base/langflow/services/task/scheduler.py @@ -0,0 +1,363 @@ +import asyncio +from datetime import datetime, timedelta, timezone +from typing import TYPE_CHECKING + +from apscheduler.events import EVENT_ALL_JOBS_REMOVED, EVENT_JOB_ADDED, EVENT_JOB_REMOVED, JobEvent, SchedulerEvent +from apscheduler.job import Job as APSJob +from apscheduler.jobstores.base import ConflictingIdError +from apscheduler.schedulers.asyncio import AsyncIOScheduler +from apscheduler.schedulers.base import ( + EVENT_JOB_MAX_INSTANCES, + EVENT_JOB_SUBMITTED, + EVENT_SCHEDULER_STARTED, + STATE_PAUSED, + STATE_RUNNING, + STATE_STOPPED, + TIMEOUT_MAX, + JobLookupError, + JobSubmissionEvent, + MaxInstancesReachedError, + SchedulerAlreadyRunningError, +) +from apscheduler.util import undefined +from loguru import logger + +if TYPE_CHECKING: + from langflow.services.task.jobstore import AsyncSQLModelJobStore + + +class AsyncScheduler(AsyncIOScheduler): + """An 
improved version of AsyncIOScheduler that supports async jobstores.""" + + def __init__(self, *args, **kwargs): + self.timezone = timezone.utc + super().__init__(*args, **kwargs) + + async def wakeup(self): + self._stop_timer() + wait_seconds = await self._process_jobs() + self._start_timer(wait_seconds) + + async def start(self, *, paused: bool = False): + """Start the scheduler. + + Args: + paused (bool, optional): If True, start in paused state. Defaults to False. + """ + self._eventloop: asyncio.AbstractEventLoop | None = self._eventloop or asyncio.get_running_loop() + await self._start(paused=paused) + + async def _start(self, *, paused: bool = False): + """Start the configured executors and job stores and begin processing scheduled jobs. + + :param bool paused: if ``True``, don't start job processing until :meth:`resume` is called + :raises SchedulerAlreadyRunningError: if the scheduler is already running + :raises RuntimeError: if running under uWSGI with threads disabled + + """ + if self.state != STATE_STOPPED: # type: ignore[has-type] + raise SchedulerAlreadyRunningError + + self._check_uwsgi() + + with self._executors_lock: + # Create a default executor if nothing else is configured + if "default" not in self._executors: + self.add_executor(self._create_default_executor(), "default") + + # Start all the executors + for alias, executor in self._executors.items(): + executor.start(self, alias) + + with self._jobstores_lock: + # Create a default job store if nothing else is configured + if "default" not in self._jobstores: + self.add_jobstore(self._create_default_jobstore(), "default") + + # Start all the job stores + for alias, store in self._jobstores.items(): + store.start(self, alias) + + # Schedule all pending jobs + for job, jobstore_alias, replace_existing in self._pending_jobs: + self._real_add_job(job, jobstore_alias, replace_existing) + del self._pending_jobs[:] + + self.state = STATE_PAUSED if paused else STATE_RUNNING + self._logger.info("Scheduler started") + self._dispatch_event(SchedulerEvent(EVENT_SCHEDULER_STARTED)) + + if not paused: + result = self.wakeup() + if asyncio.iscoroutine(result): + await result + + async def _process_jobs(self): + """Iterates through jobs in every jobstore. + + Starts jobs that are due and figures out how long to wait for the next round. + + If the ``get_due_jobs()`` call raises an exception, a new wakeup is scheduled in at least + ``jobstore_retry_interval`` seconds. 
+ """ + if self.state == STATE_PAUSED: + self._logger.debug("Scheduler is paused -- not processing jobs") + return None + + self._logger.debug("Looking for jobs to run") + now = datetime.now(self.timezone) + next_wakeup_time = None + events = [] + + with self._jobstores_lock: + for jobstore_alias, jobstore in self._jobstores.items(): + try: + due_jobs = jobstore.get_due_jobs(now) + if asyncio.iscoroutine(due_jobs): + due_jobs = await due_jobs + except Exception as e: # noqa: BLE001 + # Schedule a wakeup at least in jobstore_retry_interval seconds + self._logger.warning( + "Error getting due jobs from job store %r: %s", + jobstore_alias, + e, + ) + retry_wakeup_time = now + timedelta(seconds=self.jobstore_retry_interval) + if not next_wakeup_time or next_wakeup_time > retry_wakeup_time: + next_wakeup_time = retry_wakeup_time + + continue + + for job in due_jobs: + # Look up the job's executor + try: + executor = self._lookup_executor(job.executor) + except BaseException: + self._logger.exception( + 'Executor lookup ("%s") failed for job "%s" -- removing it from the ' "job store", + job.executor, + job, + ) + result = self.remove_job(job.id, jobstore_alias) + if asyncio.iscoroutine(result): + await result + continue + + run_times = job._get_run_times(now) + run_times = run_times[-1:] if run_times and job.coalesce else run_times + if run_times: + try: + result = executor.submit_job(job, run_times) + if asyncio.iscoroutine(result): + await result + except MaxInstancesReachedError: + self._logger.warning( + 'Execution of job "%s" skipped: maximum number of running ' "instances reached (%d)", + job, + job.max_instances, + ) + event = JobSubmissionEvent( + EVENT_JOB_MAX_INSTANCES, + job.id, + jobstore_alias, + run_times, + ) + events.append(event) + except BaseException: + self._logger.exception( + 'Error submitting job "%s" to executor "%s"', + job, + job.executor, + ) + else: + event = JobSubmissionEvent(EVENT_JOB_SUBMITTED, job.id, jobstore_alias, run_times) + events.append(event) + + # Update the job if it has a next execution time. + # Otherwise remove it from the job store. 
+ job_next_run = job.trigger.get_next_fire_time(run_times[-1], now) + if job_next_run: + job._modify(next_run_time=job_next_run) + result = jobstore.update_job(job) + if asyncio.iscoroutine(result): + await result + else: + result = self.remove_job(job.id, jobstore_alias) + if asyncio.iscoroutine(result): + await result + + # Set a new next wakeup time if there isn't one yet or + # the jobstore has an even earlier one + jobstore_next_run_time = jobstore.get_next_run_time() + if asyncio.iscoroutine(jobstore_next_run_time): + jobstore_next_run_time = await jobstore_next_run_time + if jobstore_next_run_time and (next_wakeup_time is None or jobstore_next_run_time < next_wakeup_time): + next_wakeup_time = jobstore_next_run_time.astimezone(self.timezone) + + # Dispatch collected events + for event in events: + result = self._dispatch_event(event) + if asyncio.iscoroutine(result): + await result + + # Determine the delay until this method should be called again + if self.state == STATE_PAUSED: + wait_seconds = None + self._logger.debug("Scheduler is paused; waiting until resume() is called") + elif next_wakeup_time is None: + wait_seconds = None + self._logger.debug("No jobs; waiting until a job is added") + else: + now = datetime.now(self.timezone) + wait_seconds = min(max((next_wakeup_time - now).total_seconds(), 0), TIMEOUT_MAX) + self._logger.debug( + "Next wakeup is due at %s (in %f seconds)", + next_wakeup_time, + wait_seconds, + ) + + return wait_seconds + + async def remove_job(self, job_id, jobstore=None): + """Removes a job, preventing it from being run any more. + + :param str|unicode job_id: the identifier of the job + :param str|unicode jobstore: alias of the job store that contains the job + :raises JobLookupError: if the job was not found + + """ + jobstore_alias = None + with self._jobstores_lock: + # Check if the job is among the pending jobs + if self.state == STATE_STOPPED: + for i, (job, alias, _replace_existing) in enumerate(self._pending_jobs): + if job.id == job_id and jobstore in (None, alias): + del self._pending_jobs[i] + jobstore_alias = alias + break + else: + # Otherwise, try to remove it from each store until it succeeds or we run out of + # stores to check + for alias, store in self._jobstores.items(): + if jobstore in (None, alias): + try: + result = store.remove_job(job_id) + if asyncio.iscoroutine(result): + await result + jobstore_alias = alias + break + except JobLookupError: + continue + + if jobstore_alias is None: + raise JobLookupError(job_id) + + # Notify listeners that a job has been removed + event = JobEvent(EVENT_JOB_REMOVED, job_id, jobstore_alias) + self._dispatch_event(event) + + self._logger.info("Removed job %s", job_id) + + async def remove_all_jobs(self, jobstore=None): + """Removes all jobs from the specified job store, or all job stores if none is given. 
+ + :param str|unicode jobstore: alias of the job store + + """ + with self._jobstores_lock: + if self.state == STATE_STOPPED: + if jobstore: + self._pending_jobs = [pending for pending in self._pending_jobs if pending[1] != jobstore] + else: + self._pending_jobs = [] + else: + for alias, store in self._jobstores.items(): + if jobstore in (None, alias): + store.remove_all_jobs() + + self._dispatch_event(SchedulerEvent(EVENT_ALL_JOBS_REMOVED, jobstore)) + + async def _real_add_job(self, job, jobstore_alias, replace_existing): + """Override to make async-compatible.""" + # Fill in undefined values with defaults + replacements = {key: value for key, value in self._job_defaults.items() if not hasattr(job, key)} + + # Calculate the next run time if there is none defined + if not hasattr(job, "next_run_time"): + now = datetime.now(timezone.utc) + replacements["next_run_time"] = job.trigger.get_next_fire_time(None, now) + + # Apply any replacements + job._modify(**replacements) + + # Add the job to the given job store + store: AsyncSQLModelJobStore = self._lookup_jobstore(jobstore_alias) + try: + await store.add_job(job) + except ConflictingIdError: + if replace_existing: + await store.update_job(job) + else: + raise + + # Mark the job as no longer pending + job._jobstore_alias = jobstore_alias + + # Notify listeners that a new job has been added + event = JobEvent(EVENT_JOB_ADDED, job.id, jobstore_alias) + self._dispatch_event(event) + + logger.info(f"Added job {job.name} to job store {jobstore_alias}") + + # Notify the scheduler about the new job + if self.state == STATE_RUNNING: + await self.wakeup() + + async def add_job( + self, + func, + trigger=None, + args=None, + kwargs=None, + id=None, # noqa: A002 + name=None, + misfire_grace_time=undefined, + coalesce=undefined, + max_instances=undefined, + next_run_time=undefined, + jobstore="default", + executor="default", + *, + replace_existing=False, + **trigger_args, + ): + """Add a job to the scheduler. + + Any option that defaults to undefined will be replaced with the corresponding default + value when the job is scheduled. 
+ """ + job_kwargs = { + "trigger": self._create_trigger(trigger, trigger_args), + "executor": executor, + "func": func, + "args": tuple(args) if args is not None else (), + "kwargs": dict(kwargs) if kwargs is not None else {}, + "id": id, + "name": name, + "misfire_grace_time": misfire_grace_time, + "coalesce": coalesce, + "max_instances": max_instances, + "next_run_time": next_run_time, + } + job_kwargs = {key: value for key, value in job_kwargs.items() if value is not undefined} + job = APSJob(self, **job_kwargs) + + # Don't really add jobs to job stores before the scheduler is up and running + with self._jobstores_lock: + if self.state == STATE_STOPPED: + self._pending_jobs.append((job, jobstore, replace_existing)) + logger.info("Adding job tentatively -- it will be properly scheduled when the scheduler starts") + else: + await self._real_add_job(job, jobstore, replace_existing) + + return job diff --git a/src/backend/base/langflow/services/task/service.py b/src/backend/base/langflow/services/task/service.py index b113cdc5dd0c..207d77726208 100644 --- a/src/backend/base/langflow/services/task/service.py +++ b/src/backend/base/langflow/services/task/service.py @@ -1,87 +1,196 @@ from __future__ import annotations -from collections.abc import Callable, Coroutine +import asyncio from typing import TYPE_CHECKING, Any +from uuid import UUID, uuid4 +from apscheduler.events import EVENT_JOB_ERROR, EVENT_JOB_EXECUTED, JobEvent, JobExecutionEvent +from apscheduler.schedulers.base import SchedulerAlreadyRunningError +from apscheduler.triggers.date import DateTrigger from loguru import logger +from sqlmodel import select from langflow.services.base import Service -from langflow.services.task.backends.anyio import AnyIOBackend -from langflow.services.task.utils import get_celery_worker_status +from langflow.services.database.models.job.model import Job, JobStatus +from langflow.services.deps import session_scope +from langflow.services.task.jobstore import AsyncSQLModelJobStore +from langflow.services.task.scheduler import AsyncScheduler if TYPE_CHECKING: - from langflow.services.settings.service import SettingsService - from langflow.services.task.backends.base import TaskBackend - + from collections.abc import Callable + from datetime import datetime -def check_celery_availability(): - try: - from langflow.worker import celery_app + from apscheduler.job import Job as APSJob - status = get_celery_worker_status(celery_app) - logger.debug(f"Celery status: {status}") - except Exception: # noqa: BLE001 - logger.opt(exception=True).debug("Celery not available") - status = {"availability": None} - return status + from langflow.services.settings.service import SettingsService class TaskService(Service): + """Service for managing tasks and scheduled flows.""" + name = "task_service" def __init__(self, settings_service: SettingsService): self.settings_service = settings_service + self._started = False + self.scheduler: AsyncScheduler | None = None + self.job_store: AsyncSQLModelJobStore | None = None + + async def setup(self): + """Initialize the scheduler.""" + self.scheduler = await asyncio.to_thread(AsyncScheduler) + self.job_store = AsyncSQLModelJobStore() + self.scheduler.add_jobstore(self.job_store, "default") + + # Add event listeners + self.scheduler.add_listener(self._handle_job_executed, EVENT_JOB_EXECUTED) + self.scheduler.add_listener(self._handle_job_error, EVENT_JOB_ERROR) + self._started = False + + async def _ensure_scheduler_running(self): + """Ensure the scheduler is running.""" + if not 
self._started: + if self.scheduler is None: + await self.setup() + try: + await self.scheduler.start(paused=False) + self._started = True + except SchedulerAlreadyRunningError: + pass + + async def _handle_job_executed(self, event: JobExecutionEvent) -> None: + """Handle job executed event.""" + await self._ensure_scheduler_running() + async with session_scope() as session: + stmt = select(Job).where(Job.id == event.job_id) + job = (await session.exec(stmt)).first() + if job: + job.status = JobStatus.COMPLETED + job.result = event.retval if isinstance(event.retval, dict) else {"output": str(event.retval)} + session.add(job) + await session.commit() + + async def _handle_job_error(self, event: JobEvent) -> None: + """Handle job error event.""" + await self._ensure_scheduler_running() + async with session_scope() as session: + stmt = select(Job).where(Job.id == event.job_id) + job = (await session.exec(stmt)).first() + if job: + job.status = JobStatus.FAILED + job.error = str(event.exception) + session.add(job) + await session.commit() + + async def create_job( + self, + task_func: str | Callable[..., Any], + run_at: datetime | None = None, + name: str | None = None, + args: list | None = None, + kwargs: dict | None = None, + ) -> str: + """Create a new job.""" + await self._ensure_scheduler_running() + if self.scheduler is None or self.job_store is None: + msg = "Scheduler or job store not initialized" + logger.error(msg) + raise ValueError(msg) + task_id = str(uuid4()) try: - if self.settings_service.settings.celery_enabled: - status = check_celery_availability() - - use_celery = status.get("availability") is not None - else: - use_celery = False - except ImportError: - use_celery = False - - self.use_celery = use_celery - self.backend = self.get_backend() - - @property - def backend_name(self) -> str: - return self.backend.name - - def get_backend(self) -> TaskBackend: - if self.use_celery: - from langflow.services.task.backends.celery import CeleryBackend - - logger.debug("Using Celery backend") - return CeleryBackend() - logger.debug("Using AnyIO backend") - return AnyIOBackend() - - # In your TaskService class - async def launch_and_await_task( + trigger = DateTrigger(run_date=run_at) if run_at is not None else None + + await self.scheduler.add_job( + task_func, + trigger=trigger, + args=args or [], + kwargs=kwargs or {}, + id=task_id, + name=name or f"task_{task_id}", + misfire_grace_time=None, # Run immediately when missed + coalesce=True, # Only run once if multiple are due + max_instances=1, # Only one instance at a time + replace_existing=True, + ) + + except Exception as exc: + logger.error(f"Error creating task: {exc}") + raise + return task_id + + async def get_job(self, job_id: str, user_id: UUID | None = None) -> APSJob | None: + """Get job information.""" + await self._ensure_scheduler_running() + if self.job_store is None: + msg = "Job store not initialized" + logger.error(msg) + raise ValueError(msg) + try: + job = await self.job_store.lookup_job(job_id, user_id) + except Exception as exc: + logger.error(f"Error getting job {job_id}: {exc}") + raise + return job + + async def cancel_job(self, job_id: str, user_id: UUID | None = None) -> bool: + """Cancel a job.""" + await self._ensure_scheduler_running() + if self.scheduler is None or self.job_store is None: + msg = "Scheduler or job store not initialized" + logger.error(msg) + raise ValueError(msg) + try: + # Get the job from jobstore + job = await self.job_store.lookup_job(job_id, user_id) + if not job: + return False + 
+ # Remove from scheduler if not yet executed + scheduler_job = await self.scheduler.get_job(job_id) + if scheduler_job is not None: + await self.scheduler.remove_job(job_id) + + except Exception as exc: + logger.error(f"Error cancelling job {job_id}: {exc}") + raise + return True + + async def get_jobs( self, - task_func: Callable[..., Any], - *args: Any, - **kwargs: Any, - ) -> Any: - if not self.use_celery: - return None, await task_func(*args, **kwargs) - if not hasattr(task_func, "apply"): - msg = f"Task function {task_func} does not have an apply method" + user_id: UUID | None = None, + pending: bool | None = None, + ) -> list[dict]: + """Get tasks with optional filters.""" + await self._ensure_scheduler_running() + if self.job_store is None: + msg = "Job store not initialized" + logger.error(msg) + raise ValueError(msg) + try: + if user_id: + return await self.job_store.get_user_jobs(user_id, pending) + # For other filters, we'll need to implement corresponding methods in the jobstore + # For now, we'll just get all jobs if no user_id is provided + return await self.job_store.get_all_jobs() + except Exception as exc: + logger.error(f"Error getting tasks: {exc}") + raise + + async def get_user_jobs(self, user_id: UUID) -> list[dict]: + """Get all jobs for a specific user.""" + await self._ensure_scheduler_running() + if self.job_store is None: + msg = "Job store not initialized" + logger.error(msg) raise ValueError(msg) - task = task_func.apply(args=args, kwargs=kwargs) - - result = task.get() - # if result is coroutine - if isinstance(result, Coroutine): - result = await result - return task.id, result - - async def launch_task(self, task_func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: - logger.debug(f"Launching task {task_func} with args {args} and kwargs {kwargs}") - logger.debug(f"Using backend {self.backend}") - task = self.backend.launch_task(task_func, *args, **kwargs) - return await task if isinstance(task, Coroutine) else task - - def get_task(self, task_id: str) -> Any: - return self.backend.get_task(task_id) + try: + return await self.job_store.get_user_jobs(user_id) + except Exception as exc: + logger.error(f"Error getting jobs for user {user_id}: {exc}") + raise + + async def stop(self): + """Stop the scheduler.""" + if self.scheduler.running: + self.scheduler.shutdown() + logger.info("Task scheduler stopped") diff --git a/src/backend/base/langflow/services/utils.py b/src/backend/base/langflow/services/utils.py index 9bf41e3531a6..b8bf654a23db 100644 --- a/src/backend/base/langflow/services/utils.py +++ b/src/backend/base/langflow/services/utils.py @@ -234,6 +234,8 @@ async def clean_vertex_builds(settings_service: SettingsService, session: AsyncS async def initialize_services(*, fix_migration: bool = False) -> None: """Initialize all the services needed.""" + from langflow.services.manager import service_manager + # Test cache connection get_service(ServiceType.CACHE_SERVICE, default=CacheServiceFactory()) # Setup the superuser @@ -249,3 +251,4 @@ async def initialize_services(*, fix_migration: bool = False) -> None: logger.warning(f"Error assigning orphaned flows to the superuser: {exc!s}") await clean_transactions(settings_service, session) await clean_vertex_builds(settings_service, session) + await service_manager.setup() diff --git a/src/backend/base/pyproject.toml b/src/backend/base/pyproject.toml index 57bc691e7f04..0ef1c646df45 100644 --- a/src/backend/base/pyproject.toml +++ b/src/backend/base/pyproject.toml @@ -173,6 +173,7 @@ dependencies = [ 
"defusedxml>=0.7.1,<1.0.0", "pypdf~=5.1.0", "validators>=0.34.0", + "apscheduler==3.11.0", ] [project.urls] diff --git a/src/backend/tests/conftest.py b/src/backend/tests/conftest.py index b854d682015b..a3f0d75cccaa 100644 --- a/src/backend/tests/conftest.py +++ b/src/backend/tests/conftest.py @@ -57,6 +57,10 @@ def blockbuster(request): "os.path.abspath", ]: bb.functions[func].can_block_in("settings/service.py", "initialize") + if func == "os.path.abspath": + bb.functions[func].can_block_in( + "pydevd/pydevd_file_utils.py", "get_abs_path_real_path_and_base_from_file" + ) for func in [ "io.BufferedReader.read", "io.TextIOWrapper.read", diff --git a/src/backend/tests/data/WebhookTest.json b/src/backend/tests/data/WebhookTest.json index 71ca54183842..4f3c0c24b7b1 100644 --- a/src/backend/tests/data/WebhookTest.json +++ b/src/backend/tests/data/WebhookTest.json @@ -21,7 +21,7 @@ "list": false, "show": true, "multiline": true, - "value": "# from langflow.field_typing import Data\nfrom langflow.custom import Component\nfrom langflow.io import StrInput\nfrom langflow.schema import Data\nfrom langflow.io import Output\nfrom pathlib import Path\nimport aiofiles\n\nclass CustomComponent(Component):\n display_name = \"Async Component\"\n description = \"Use as a template to create your own component.\"\n documentation: str = \"http://docs.langflow.org/components/custom\"\n icon = \"custom_components\"\n\n inputs = [\n StrInput(name=\"input_value\", display_name=\"Input Value\", value=\"Hello, World!\", input_types=[\"Data\"]),\n ]\n\n outputs = [\n Output(display_name=\"Output\", name=\"output\", method=\"build_output\"),\n ]\n\n async def build_output(self) -> Data:\n if isinstance(self.input_value, Data):\n data = self.input_value\n else:\n data = Data(value=self.input_value)\n \n if \"path\" in data:\n path = self.resolve_path(data.path)\n path_obj = Path(path)\n async with aiofiles.open(path, \"w\") as f:\n await f.write(data.model_dump())\n \n self.status = data\n return data", + "value": "# from langflow.field_typing import Data\nfrom langflow.custom import Component\nfrom langflow.io import StrInput\nfrom langflow.schema import Data\nfrom langflow.io import Output\nfrom pathlib import Path\nimport aiofiles\n\nclass CustomComponent(Component):\n display_name = \"Async Component\"\n description = \"Use as a template to create your own component.\"\n documentation: str = \"http://docs.langflow.org/components/custom\"\n icon = \"custom_components\"\n\n inputs = [\n StrInput(name=\"input_value\", display_name=\"Input Value\", value=\"Hello, World!\", input_types=[\"Data\"]),\n ]\n\n outputs = [\n Output(display_name=\"Output\", name=\"output\", method=\"build_output\"),\n ]\n\n async def build_output(self) -> Data:\n if isinstance(self.input_value, Data):\n data = self.input_value\n else:\n data = Data(value=self.input_value)\n \n if \"path\" in data:\n path = self.resolve_path(data.path)\n path_obj = Path(path)\n async with aiofiles.open(path, \"w\") as f:\n await f.write(data.model_dump_json())\n \n self.status = data\n return data", "fileTypes": [], "file_path": "", "password": false, diff --git a/src/backend/tests/unit/api/v1/test_tasks.py b/src/backend/tests/unit/api/v1/test_tasks.py new file mode 100644 index 000000000000..7108b62cf19d --- /dev/null +++ b/src/backend/tests/unit/api/v1/test_tasks.py @@ -0,0 +1,299 @@ +import asyncio +from uuid import uuid4 + +import pytest +from fastapi import status +from httpx import AsyncClient +from langflow.services.auth.utils import get_password_hash 
+from langflow.services.database.models.user.model import User, UserRead
+from langflow.services.deps import get_db_service
+from sqlalchemy.orm import selectinload
+from sqlmodel import select
+
+from tests.conftest import _delete_transactions_and_vertex_builds
+
+
+@pytest.fixture
+def create_task_request():
+    """Fixture for creating a task request."""
+    return {
+        "name": "Test Task",
+        "input_request": {
+            "input_value": "test input",
+            "input_type": "text",
+            "output_type": "text",
+            "tweaks": {},
+        },
+    }
+
+
+@pytest.fixture
+async def another_active_user(client):  # noqa: ARG001
+    db_manager = get_db_service()
+    async with db_manager.with_session() as session:
+        user = User(
+            username="another_active_user",
+            password=get_password_hash("testpassword"),
+            is_active=True,
+            is_superuser=False,
+        )
+        stmt = select(User).where(User.username == user.username)
+        if active_user := (await session.exec(stmt)).first():
+            user = active_user
+        else:
+            session.add(user)
+            await session.commit()
+            await session.refresh(user)
+        user = UserRead.model_validate(user, from_attributes=True)
+    yield user
+    # Clean up
+    # Now cleanup transactions, vertex_build
+    async with db_manager.with_session() as session:
+        user = await session.get(User, user.id, options=[selectinload(User.flows)])
+        await _delete_transactions_and_vertex_builds(session, user.flows)
+        await session.delete(user)
+
+        await session.commit()
+
+
+@pytest.fixture
+async def another_user_headers(client, another_active_user):
+    login_data = {"username": another_active_user.username, "password": "testpassword"}
+    response = await client.post("api/v1/login", data=login_data)
+    assert response.status_code == 200
+    tokens = response.json()
+    a_token = tokens["access_token"]
+    return {"Authorization": f"Bearer {a_token}"}
+
+
+async def test_create_task(client: AsyncClient, logged_in_headers, simple_api_test, create_task_request):
+    """Test creating a task."""
+    response = await client.post(
+        f"/api/v1/tasks/{simple_api_test['id']}", headers=logged_in_headers, json=create_task_request
+    )
+    assert (
+        response.status_code == status.HTTP_200_OK
+    ), f"Failed to create task. Status: {response.status_code}. Response: {response.text}"
+    task_id = response.json()
+    assert isinstance(task_id, str), f"Expected task_id to be a string, got {type(task_id)}"
+
+    # Verify task was created by getting it
+    response = await client.get(f"/api/v1/tasks/{task_id}", headers=logged_in_headers)
+    assert (
+        response.status_code == status.HTTP_200_OK
+    ), f"Failed to get created task. Status: {response.status_code}. Response: {response.text}"
+    task = response.json()
+    assert task["id"] == task_id, f"Task ID mismatch. Expected: {task_id}, got: {task['id']}"
+    assert (
+        task["name"] == create_task_request["name"]
+    ), f"Task name mismatch. Expected: {create_task_request['name']}, got: {task['name']}"
+    assert task["pending"] is False, f"Expected task to not be pending, got: {task['pending']}"
+
+
+async def test_create_task_invalid_flow(client: AsyncClient, logged_in_headers, create_task_request):
+    """Test creating a task with an invalid flow ID."""
+    some_flow_id = uuid4()
+    response = await client.post(f"/api/v1/tasks/{some_flow_id}", headers=logged_in_headers, json=create_task_request)
+    assert (
+        response.status_code == status.HTTP_404_NOT_FOUND
+    ), f"Expected 404 error for invalid flow ID. Got: {response.status_code}. 
Response: {response.text}" + + +async def test_get_task_not_found(client: AsyncClient, logged_in_headers): + """Test getting a non-existent task.""" + response = await client.get("/api/v1/tasks/nonexistent", headers=logged_in_headers) + assert ( + response.status_code == status.HTTP_404_NOT_FOUND + ), f"Expected 404 for non-existent task. Got: {response.status_code}. Response: {response.text}" + + +async def test_get_tasks(client: AsyncClient, logged_in_headers, simple_api_test, create_task_request): + """Test getting all tasks.""" + # Create a task first + response = await client.post( + f"/api/v1/tasks/{simple_api_test['id']}", headers=logged_in_headers, json=create_task_request + ) + response.json() + + assert ( + response.status_code == status.HTTP_200_OK + ), f"Failed to create task. Status: {response.status_code}. Response: {response.text}" + # Get all tasks + response = await client.get("/api/v1/tasks/", headers=logged_in_headers) + assert ( + response.status_code == status.HTTP_200_OK + ), f"Failed to get tasks. Status: {response.status_code}. Response: {response.text}" + tasks = response.json() + assert isinstance(tasks, list), f"Expected tasks to be a list, got {type(tasks)}" + assert len(tasks) >= 1, f"Expected at least 1 task, got {len(tasks)}" + assert all(isinstance(task["id"], str) for task in tasks), "Some task IDs are not strings" + + +async def test_get_tasks_with_status(client: AsyncClient, logged_in_headers, simple_api_test, create_task_request): + """Test getting tasks filtered by status.""" + # Create a task first + await client.post(f"/api/v1/tasks/{simple_api_test['id']}", headers=logged_in_headers, json=create_task_request) + + # Get all tasks + response = await client.get("/api/v1/tasks/", headers=logged_in_headers) + assert ( + response.status_code == status.HTTP_200_OK + ), f"Failed to get tasks. Status: {response.status_code}. Response: {response.text}" + tasks = response.json() + assert isinstance(tasks, list), f"Expected tasks to be a list, got {type(tasks)}" + assert len(tasks) > 0, "Expected at least one task" + + +async def test_cancel_task(client: AsyncClient, logged_in_headers, simple_api_test, create_task_request): + """Test canceling a task.""" + # Create a task first + response = await client.post( + f"/api/v1/tasks/{simple_api_test['id']}", headers=logged_in_headers, json=create_task_request + ) + task_id = response.json() + + # Cancel the task + response = await client.delete(f"/api/v1/tasks/{task_id}", headers=logged_in_headers) + assert ( + response.status_code == status.HTTP_200_OK + ), f"Failed to cancel task. Status: {response.status_code}. Response: {response.text}" + assert response.json() is True, f"Expected True response for task cancellation, got: {response.json()}" + + # Verify task was canceled by trying to get it + response = await client.get(f"/api/v1/tasks/{task_id}", headers=logged_in_headers) + assert ( + response.status_code == status.HTTP_404_NOT_FOUND + ), f"Expected task to be not found after cancellation. Status: {response.status_code}. Response: {response.text}" + + +async def test_cancel_nonexistent_task(client: AsyncClient, logged_in_headers): + """Test canceling a non-existent task.""" + response = await client.delete("/api/v1/tasks/nonexistent", headers=logged_in_headers) + assert ( + response.status_code == status.HTTP_404_NOT_FOUND + ), f"Expected 404 for non-existent task cancellation. Got: {response.status_code}. 
Response: {response.text}" + + +async def test_create_task_invalid_request(client: AsyncClient, logged_in_headers, simple_api_test): + """Test creating a task with invalid request data.""" + invalid_request = { + "name": "Test Task", + # Missing required input_request field + } + response = await client.post( + f"/api/v1/tasks/{simple_api_test['id']}", headers=logged_in_headers, json=invalid_request + ) + assert ( + response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + ), f"Expected 422 for invalid request. Got: {response.status_code}. Response: {response.text}" + + +async def test_task_access_control( + client: AsyncClient, logged_in_headers, another_user_headers, simple_api_test, create_task_request +): + """Test that a user cannot access another user's tasks.""" + # assert headers are different + assert logged_in_headers["Authorization"] != another_user_headers["Authorization"], "Headers are the same" + # User A creates a task + response = await client.post( + f"/api/v1/tasks/{simple_api_test['id']}", headers=logged_in_headers, json=create_task_request + ) + assert response.status_code == status.HTTP_200_OK, f"Failed to create task. Response: {response.text}" + task_id = response.json() + + # User B tries to access User A's task + response = await client.get(f"/api/v1/tasks/{task_id}", headers=another_user_headers) + assert response.status_code == status.HTTP_404_NOT_FOUND, ( + f"Expected 404 when accessing another user's task. " + f"Got status {response.status_code}. Response: {response.text}" + ) + + # User B tries to cancel User A's task + response = await client.delete(f"/api/v1/tasks/{task_id}", headers=another_user_headers) + assert response.status_code == status.HTTP_404_NOT_FOUND, ( + f"Expected 404 when canceling another user's task. " + f"Got status {response.status_code}. Response: {response.text}" + ) + + +async def test_create_multiple_tasks(client: AsyncClient, logged_in_headers, simple_api_test, create_task_request): + """Test creating multiple tasks concurrently.""" + num_tasks = 5 + tasks = [ + client.post(f"/api/v1/tasks/{simple_api_test['id']}", headers=logged_in_headers, json=create_task_request) + for _ in range(num_tasks) + ] + responses = await asyncio.gather(*tasks) + + for i, response in enumerate(responses): + assert response.status_code == status.HTTP_200_OK, ( + f"Failed to create task {i + 1}/{num_tasks}. " f"Status: {response.status_code}. Response: {response.text}" + ) + task_id = response.json() + assert isinstance(task_id, str), f"Task {i + 1}/{num_tasks}: Expected string ID, got {type(task_id)}" + + # Verify all tasks were created + response = await client.get("/api/v1/tasks/", headers=logged_in_headers) + assert ( + response.status_code == status.HTTP_200_OK + ), f"Failed to get tasks list. Status: {response.status_code}. Response: {response.text}" + tasks = response.json() + assert len(tasks) >= num_tasks, ( + f"Expected at least {num_tasks} tasks, but found {len(tasks)}. " f"Some tasks may have failed to create." + ) + + +async def test_task_status_transitions(client: AsyncClient, logged_in_headers, simple_api_test, create_task_request): + """Test task status transitions.""" + # Create a task + response = await client.post( + f"/api/v1/tasks/{simple_api_test['id']}", headers=logged_in_headers, json=create_task_request + ) + assert response.status_code == status.HTTP_200_OK, ( + f"Failed to create task for status test. " f"Status: {response.status_code}. 
Response: {response.text}" + ) + task_id = response.json() + + # Get task status + response = await client.get(f"/api/v1/tasks/{task_id}", headers=logged_in_headers) + assert ( + response.status_code == status.HTTP_200_OK + ), f"Failed to get task status. Status: {response.status_code}. Response: {response.text}" + task = response.json() + + # Verify task has a valid pending status + assert "pending" in task, f"Task response missing 'pending' field. Response: {task}" + assert isinstance(task["pending"], bool), ( + f"Expected boolean for task.pending, got {type(task['pending'])}. " f"Value: {task['pending']}" + ) + + +async def test_create_task_malicious_input(client: AsyncClient, logged_in_headers, simple_api_test): + """Test task creation with potentially malicious input.""" + malicious_request = { + "name": "'; DROP TABLE tasks; --", + "input_request": { + "input_value": "", + "input_type": "text", + "output_type": "text", + "tweaks": {"malicious": "'; DROP TABLE users; --"}, + }, + } + + response = await client.post( + f"/api/v1/tasks/{simple_api_test['id']}", headers=logged_in_headers, json=malicious_request + ) + + # Should either sanitize and accept (200) or reject invalid input (422) + assert response.status_code in [status.HTTP_200_OK, status.HTTP_422_UNPROCESSABLE_ENTITY], ( + f"Expected status 200 or 422 for malicious input, got {response.status_code}. " f"Response: {response.text}" + ) + + if response.status_code == status.HTTP_200_OK: + task_id = response.json() + # Verify the task was created and can be retrieved + response = await client.get(f"/api/v1/tasks/{task_id}", headers=logged_in_headers) + assert response.status_code == status.HTTP_200_OK, ( + f"Failed to retrieve task created with sanitized malicious input. " + f"Status: {response.status_code}. Response: {response.text}" + ) diff --git a/src/backend/tests/unit/models/job/__init__.py b/src/backend/tests/unit/models/job/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/src/backend/tests/unit/models/job/test_job_models.py b/src/backend/tests/unit/models/job/test_job_models.py new file mode 100644 index 000000000000..ee32078f975d --- /dev/null +++ b/src/backend/tests/unit/models/job/test_job_models.py @@ -0,0 +1,590 @@ +"""Tests for the Job model. + +Note: These tests intentionally use pickle.loads to test serialization behavior. +The S301 warnings are suppressed because this is a test file and we're testing +the pickle functionality itself, not using it to deserialize untrusted data in +production code. 
+""" + +import threading +import uuid +from contextlib import suppress +from datetime import datetime, timezone + +import pytest +from langflow.services.database.models.job.model import Job, JobRead, JobStatus + + +def test_job_status_enum(): + """Test JobStatus enum values.""" + assert JobStatus.PENDING == "PENDING" + assert JobStatus.RUNNING == "RUNNING" + assert JobStatus.COMPLETED == "COMPLETED" + assert JobStatus.FAILED == "FAILED" + assert JobStatus.CANCELLED == "CANCELLED" + + +def test_create_job(): + """Test creating a job with required fields.""" + job_id = "test-job-1" + flow_id = uuid.uuid4() + user_id = uuid.uuid4() + name = "Test Job" + + job = Job( + id=job_id, + flow_id=flow_id, + user_id=user_id, + name=name, + ) + + assert job.id == job_id + assert job.flow_id == flow_id + assert job.user_id == user_id + assert job.name == name + assert job.status == JobStatus.PENDING + assert job.is_active is True + assert job.result is None + assert job.error is None + assert job.job_state is None + assert job.next_run_time is None + + +def test_create_job_with_all_fields(): + """Test creating a job with all fields.""" + job_id = "test-job-2" + flow_id = uuid.uuid4() + user_id = uuid.uuid4() + name = "Test Job Complete" + next_run_time = datetime.now(timezone.utc) + job_state = b"serialized_state" + result = {"output": "test_result"} + + job = Job( + id=job_id, + flow_id=flow_id, + user_id=user_id, + name=name, + next_run_time=next_run_time, + job_state=job_state, + status=JobStatus.RUNNING, + result=result, + error="No error", + is_active=False, + ) + + assert job.id == job_id + assert job.flow_id == flow_id + assert job.user_id == user_id + assert job.name == name + assert job.next_run_time == next_run_time + assert job.job_state == job_state + assert job.status == JobStatus.RUNNING + assert job.result == result + assert job.error == "No error" + assert job.is_active is False + + +def test_job_read_model(): + """Test JobRead model creation and field mapping.""" + job_id = "test-job-3" + flow_id = uuid.uuid4() + user_id = uuid.uuid4() + name = "Test Job Read" + created_at = datetime.now(timezone.utc) + updated_at = datetime.now(timezone.utc) + result = {"status": "success"} + + job_read = JobRead( + id=job_id, + flow_id=flow_id, + user_id=user_id, + name=name, + status=JobStatus.COMPLETED, + is_active=True, + created_at=created_at, + updated_at=updated_at, + job_state=None, + next_run_time=None, + result=result, + ) + + assert job_read.id == job_id + assert job_read.flow_id == flow_id + assert job_read.user_id == user_id + assert job_read.name == name + assert job_read.status == JobStatus.COMPLETED + assert job_read.is_active is True + assert job_read.created_at == created_at + assert job_read.updated_at == updated_at + assert job_read.result == result + + +def test_job_status_validation(): + """Test job status field behavior.""" + # Test default status + job = Job( + id="test-job-4", + flow_id=uuid.uuid4(), + user_id=uuid.uuid4(), + name="Status Test Job", + ) + assert job.status == JobStatus.PENDING + + # Test setting valid status + job.status = JobStatus.RUNNING + assert job.status == JobStatus.RUNNING + + # Test setting raw string value + job.status = "COMPLETED" + assert job.status == "COMPLETED" + + # Test that invalid status is accepted (no validation at model level) + job.status = "INVALID_STATUS" + assert job.status == "INVALID_STATUS" + + +def test_job_status_transitions(): + """Test typical job status transitions.""" + job = Job( + id="test-job-8", + flow_id=uuid.uuid4(), + 
user_id=uuid.uuid4(), + name="Status Transition Test", + ) + + # Test initial state + assert job.status == JobStatus.PENDING + + # Test running transition + job.status = JobStatus.RUNNING + assert job.status == JobStatus.RUNNING + + # Test completion transition + job.status = JobStatus.COMPLETED + assert job.status == JobStatus.COMPLETED + + # Test failure transition + job = Job( + id="test-job-9", flow_id=uuid.uuid4(), user_id=uuid.uuid4(), name="Failed Job Test", status=JobStatus.RUNNING + ) + job.status = JobStatus.FAILED + assert job.status == JobStatus.FAILED + + # Test cancellation + job = Job( + id="test-job-10", + flow_id=uuid.uuid4(), + user_id=uuid.uuid4(), + name="Cancelled Job Test", + status=JobStatus.RUNNING, + ) + job.status = JobStatus.CANCELLED + assert job.status == JobStatus.CANCELLED + + +def test_job_timestamps(): + """Test that created_at and updated_at are properly set.""" + job = Job( + id="test-job-5", + flow_id=uuid.uuid4(), + user_id=uuid.uuid4(), + name="Timestamp Test Job", + # Explicitly set timestamps for testing + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + + assert isinstance(job.created_at, datetime) + assert isinstance(job.updated_at, datetime) + assert job.created_at.tzinfo is not None # Ensure timezone is set + assert job.updated_at.tzinfo is not None # Ensure timezone is set + + +def test_job_state_serialization(): + """Test job state serialization with bytes.""" + job_state = b"serialized_state_data" + job = Job(id="test-job-6", flow_id=uuid.uuid4(), user_id=uuid.uuid4(), name="State Test Job", job_state=job_state) + + assert isinstance(job.job_state, bytes) + assert job.job_state == job_state + + +def test_job_result_json(): + """Test job result JSON field.""" + result_data = {"output": "test output", "metrics": {"time": 1.23}, "nested": {"key": ["value1", "value2"]}} + job = Job(id="test-job-7", flow_id=uuid.uuid4(), user_id=uuid.uuid4(), name="Result Test Job", result=result_data) + + assert job.result == result_data + assert isinstance(job.result, dict) + + +def test_job_result_edge_cases(): + """Test edge cases for job result field.""" + # Test non-dict return value (service converts to dict) + job = Job( + id="test-job-11", + flow_id=uuid.uuid4(), + user_id=uuid.uuid4(), + name="String Result Job", + result={"output": "plain string result"}, + ) + assert job.result == {"output": "plain string result"} + + # Test complex nested result + complex_result = { + "output": {"text": "Generated text", "tokens": 150, "model": "gpt-3.5-turbo"}, + "metrics": {"time_taken": 2.5, "tokens_per_second": 60, "cost": 0.002}, + "metadata": { + "version": "1.0", + "timestamp": "2024-01-01T00:00:00Z", + "settings": {"temperature": 0.7, "max_tokens": 200}, + }, + } + job = Job( + id="test-job-12", flow_id=uuid.uuid4(), user_id=uuid.uuid4(), name="Complex Result Job", result=complex_result + ) + assert job.result == complex_result + assert isinstance(job.result["output"], dict) + assert isinstance(job.result["metrics"], dict) + assert isinstance(job.result["metadata"], dict) + + # Test empty result + job = Job(id="test-job-13", flow_id=uuid.uuid4(), user_id=uuid.uuid4(), name="Empty Result Job", result={}) + assert job.result == {} + + +def test_job_error_edge_cases(): + """Test edge cases for job error field.""" + # Test long error message + long_error = "Error: " + "very long error message " * 100 + job = Job( + id="test-job-14", + flow_id=uuid.uuid4(), + user_id=uuid.uuid4(), + name="Long Error Job", + error=long_error, + 
status=JobStatus.FAILED, + ) + assert job.error == long_error + assert job.status == JobStatus.FAILED + + # Test error with special characters + special_error = "Error: Something went wrong!\n\tDetails: {'key': 'value'}\n\tStack trace: [...]" + job = Job( + id="test-job-15", + flow_id=uuid.uuid4(), + user_id=uuid.uuid4(), + name="Special Error Job", + error=special_error, + status=JobStatus.FAILED, + ) + assert job.error == special_error + assert job.status == JobStatus.FAILED + + # Test empty error + job = Job( + id="test-job-16", + flow_id=uuid.uuid4(), + user_id=uuid.uuid4(), + name="Empty Error Job", + error="", + status=JobStatus.FAILED, + ) + assert job.error == "" + assert job.status == JobStatus.FAILED + + +def test_job_state_edge_cases(): + """Test edge cases for job state field.""" + # Test large job state + large_state = b"x" * 1000000 # 1MB of data + job = Job( + id="test-job-17", flow_id=uuid.uuid4(), user_id=uuid.uuid4(), name="Large State Job", job_state=large_state + ) + assert job.job_state == large_state + assert len(job.job_state) == 1000000 + + # Test job state with special bytes + special_state = bytes([0, 1, 2, 3, 255, 254, 253, 252]) + job = Job( + id="test-job-18", flow_id=uuid.uuid4(), user_id=uuid.uuid4(), name="Special State Job", job_state=special_state + ) + assert job.job_state == special_state + assert len(job.job_state) == 8 + + # Test empty job state + job = Job(id="test-job-19", flow_id=uuid.uuid4(), user_id=uuid.uuid4(), name="Empty State Job", job_state=b"") + assert job.job_state == b"" + assert len(job.job_state) == 0 + + +def test_job_state_result_error_combinations(): + """Test various combinations of job state, result, and error fields.""" + # Test job with all fields populated + job = Job( + id="test-job-20", + flow_id=uuid.uuid4(), + user_id=uuid.uuid4(), + name="Complete Job", + job_state=b"state data", + result={"output": "success"}, + error=None, + status=JobStatus.COMPLETED, + ) + assert job.job_state == b"state data" + assert job.result == {"output": "success"} + assert job.error is None + assert job.status == JobStatus.COMPLETED + + # Test failed job with error but no result + job = Job( + id="test-job-21", + flow_id=uuid.uuid4(), + user_id=uuid.uuid4(), + name="Failed Job", + job_state=b"state data", + result=None, + error="Task failed successfully", + status=JobStatus.FAILED, + ) + assert job.job_state == b"state data" + assert job.result is None + assert job.error == "Task failed successfully" + assert job.status == JobStatus.FAILED + + # Test cancelled job + job = Job( + id="test-job-22", + flow_id=uuid.uuid4(), + user_id=uuid.uuid4(), + name="Cancelled Job", + job_state=None, + result=None, + error=None, + status=JobStatus.CANCELLED, + ) + assert job.job_state is None + assert job.result is None + assert job.error is None + assert job.status == JobStatus.CANCELLED + + +# Test classes need to be at module level for pickling +class CustomReduceObject: + def __init__(self, value): + self.value = value + + def __reduce__(self): + return (self.__class__, (self.value,)) + + +class BrokenReduceObject: + def __reduce__(self): + return (self.__class__, ()) # Missing required arguments + + +class ObjectWithLock: + """Test class with a lock attribute.""" + + def __init__(self): + self.lock = threading.Lock() + self.data = {"key": "value"} + + +def unpickleable_func(): + """Test function at module level.""" + + +def test_job_state_serialization_edge_cases(): + """Test problematic serialization cases for job_state.""" + import pickle + import types 
+ from io import StringIO + + # Test corrupted pickle data + corrupted_pickle = b"invalid pickle data" + job = Job( + id="test-job-24", + flow_id=uuid.uuid4(), + user_id=uuid.uuid4(), + name="Corrupted Pickle Job", + job_state=corrupted_pickle, + ) + assert job.job_state == corrupted_pickle + with pytest.raises((pickle.UnpicklingError, EOFError)): + pickle.loads(job.job_state) # noqa: S301 + + # Test recursive data structure + recursive_dict = {} + recursive_dict["self"] = recursive_dict + # Python's pickle can handle recursive structures! + job_state = pickle.dumps(recursive_dict) + job = Job( + id="test-job-25", flow_id=uuid.uuid4(), user_id=uuid.uuid4(), name="Recursive Data Job", job_state=job_state + ) + unpickled = pickle.loads(job.job_state) # noqa: S301 + assert unpickled["self"] is unpickled # Verify recursion is preserved + + # Test extremely nested data that might exceed pickle's recursion limit + deep_list = [] + current = deep_list + for _ in range(2000): # Create an extremely deeply nested list + current.append([]) + current = current[0] + + with pytest.raises((RecursionError, pickle.PicklingError)): + pickle.dumps(deep_list, protocol=pickle.HIGHEST_PROTOCOL) + + # Test lambda function that can't be pickled + # Note: Lambda functions raise AttributeError when trying to pickle + with pytest.raises(AttributeError): + pickle.dumps(lambda x: x) + + # Test file-like objects + text_io = StringIO("test data") # StringIO can be pickled + job_state = pickle.dumps(text_io) + job = Job(id="test-job-28", flow_id=uuid.uuid4(), user_id=uuid.uuid4(), name="Text IO Job", job_state=job_state) + assert isinstance(pickle.loads(job.job_state), StringIO) # noqa: S301 + + # Test generator function that can't be pickled + def generator_func(): + yield "test" + + gen = generator_func() + with pytest.raises(TypeError): # Python raises TypeError for generators + pickle.dumps(gen) + + # Test module object that shouldn't be pickled + with pytest.raises(TypeError): # Python raises TypeError for modules + pickle.dumps(types) + + +def test_job_state_with_custom_objects(): + """Test job state with custom objects that implement __reduce__.""" + import pickle + + # Test pickling custom object that implements __reduce__ + custom_obj = CustomReduceObject("test value") + job_state = pickle.dumps(custom_obj) + job = Job( + id="test-job-31", flow_id=uuid.uuid4(), user_id=uuid.uuid4(), name="Custom Reduce Job", job_state=job_state + ) + assert job.job_state == job_state + unpickled_obj = pickle.loads(job.job_state) # noqa: S301 + assert isinstance(unpickled_obj, CustomReduceObject) + assert unpickled_obj.value == "test value" + + # Test object with broken __reduce__ implementation + broken_obj = BrokenReduceObject() + job_state = pickle.dumps(broken_obj) + job = Job( + id="test-job-32", flow_id=uuid.uuid4(), user_id=uuid.uuid4(), name="Broken Reduce Job", job_state=job_state + ) + # The object will be pickled but will be empty when unpickled + unpickled = pickle.loads(job.job_state) # noqa: S301 + assert isinstance(unpickled, BrokenReduceObject) + + +def test_job_state_security_risks(): + """Test potential security risks with job state serialization.""" + import os + import pickle + import sys + + # Test attempting to pickle a dangerous command + class DangerousPickle: + def __reduce__(self): + return (os.system, ('echo "DANGER"',)) + + # Python's pickle will allow this dangerous payload! 
+ dangerous_obj = DangerousPickle() + job_state = pickle.dumps(dangerous_obj) + job = Job( + id="test-job-33", flow_id=uuid.uuid4(), user_id=uuid.uuid4(), name="Security Risk Job", job_state=job_state + ) + + # This is why it's crucial to NEVER unpickle untrusted data! + assert job.job_state == job_state + # We won't unpickle it as it would execute the command + + # Test attempting to pickle system objects + with pytest.raises(TypeError): # Python raises TypeError for modules + pickle.dumps(sys.modules["os"]) + + +def test_job_state_pickle_protocol_compatibility(): + """Test job state compatibility with different pickle protocols.""" + import pickle + + test_data = {"name": "test", "value": 123, "nested": {"key": "value"}} + + # Test all available pickle protocols + for protocol in range(pickle.HIGHEST_PROTOCOL + 1): + job_state = pickle.dumps(test_data, protocol=protocol) + job = Job( + id=f"test-job-protocol-{protocol}", + flow_id=uuid.uuid4(), + user_id=uuid.uuid4(), + name=f"Protocol {protocol} Job", + job_state=job_state, + ) + assert job.job_state == job_state + # Verify we can unpickle the data + unpickled_data = pickle.loads(job.job_state) # noqa: S301 + assert unpickled_data == test_data + + +def test_job_state_thread_synchronization(): + """Test serialization of thread synchronization primitives.""" + import asyncio + import multiprocessing + import pickle + import threading + from concurrent.futures import ThreadPoolExecutor + + # Test threading.Lock + lock = threading.Lock() + with pytest.raises(TypeError): # Python raises TypeError for thread locks + pickle.dumps(lock) + + # Test threading.RLock + rlock = threading.RLock() + with pytest.raises(TypeError): + pickle.dumps(rlock) + + # Test threading.Event + event = threading.Event() + with pytest.raises(TypeError): + pickle.dumps(event) + + # Test threading.Condition + condition = threading.Condition() + with pytest.raises(TypeError): + pickle.dumps(condition) + + # Test multiprocessing.Lock + mp_lock = multiprocessing.Lock() + with pytest.raises(RuntimeError): # Multiprocessing raises RuntimeError + pickle.dumps(mp_lock) + + # Test ThreadPoolExecutor + executor = ThreadPoolExecutor(max_workers=1) + with pytest.raises(TypeError): + pickle.dumps(executor) + executor.shutdown() + + # Test asyncio.Lock - Note: asyncio locks can actually be pickled in some cases + async_lock = asyncio.Lock() + with suppress(TypeError, AttributeError): + pickle.dumps(async_lock) + + # Test object containing a lock + obj_with_lock = ObjectWithLock() + with pytest.raises(TypeError): + pickle.dumps(obj_with_lock) + + # Test dictionary containing a lock + dict_with_lock = {"lock": threading.Lock(), "data": {"key": "value"}} + with pytest.raises(TypeError): + pickle.dumps(dict_with_lock) diff --git a/src/backend/tests/unit/services/tasks/__init__.py b/src/backend/tests/unit/services/tasks/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/src/backend/tests/unit/services/tasks/test_tasks_service.py b/src/backend/tests/unit/services/tasks/test_tasks_service.py new file mode 100644 index 000000000000..6ecf4bff911d --- /dev/null +++ b/src/backend/tests/unit/services/tasks/test_tasks_service.py @@ -0,0 +1,183 @@ +import asyncio +import datetime + +import pytest +from apscheduler.events import JobExecutionEvent +from langflow.services.database.models.job.model import Job, JobStatus +from langflow.services.deps import get_settings_service, session_scope +from langflow.services.task.service import TaskService +from sqlmodel import 
select + + +@pytest.fixture +async def task_service(): + """Create a task service for testing.""" + service = TaskService(get_settings_service()) + await service.setup() + yield service + await service.teardown() + + +# Create a mock task function that has all the kwargs +def mock_task_func(**kwargs): + return kwargs + + +@pytest.fixture +async def sample_job(task_service: TaskService, active_user, simple_api_test): + """Create a sample job for testing.""" + task_id = await task_service.create_job( + task_func=mock_task_func, + run_at=None, + name="Test Task", + kwargs={ + "flow": simple_api_test, + "input_request": { + "input_value": "test input", + "input_type": "text", + "output_type": "text", + "tweaks": {}, + }, + "stream": False, + "api_key_user": active_user, + }, + ) + async with session_scope() as session: + stmt = select(Job).where(Job.id == task_id) + job = (await session.exec(stmt)).first() + assert job is not None, "Job was not created" + return job + + +async def test_handle_job_executed(task_service: TaskService, sample_job: Job): + """Test handling of successful job execution.""" + # Create a JobExecutionEvent + event = JobExecutionEvent( + code=0, # Success code + job_id=sample_job.id, + jobstore="default", + retval={"output": "Test result"}, + scheduled_run_time=sample_job.next_run_time, + ) + + # Handle the event + await task_service._handle_job_executed(event) + + # Verify the job status was updated + async with session_scope() as session: + stmt = select(Job).where(Job.id == sample_job.id) + updated_job = (await session.exec(stmt)).first() + assert updated_job is not None, "Job not found" + assert updated_job.status == JobStatus.COMPLETED, "Job status not updated to COMPLETED" + assert updated_job.result == {"output": "Test result"}, "Job result not saved correctly" + + +async def test_handle_job_error(task_service: TaskService, sample_job: Job): + """Test handling of job execution error.""" + # Create a JobEvent with an error + test_error = ValueError("Test error message") + event = JobExecutionEvent( + code=1, # Error code + job_id=sample_job.id, + jobstore="default", + exception=test_error, + scheduled_run_time=sample_job.next_run_time, + ) + + # Handle the error event + await task_service._handle_job_error(event) + + # Verify the job status and error were updated + async with session_scope() as session: + stmt = select(Job).where(Job.id == sample_job.id) + updated_job = (await session.exec(stmt)).first() + assert updated_job is not None, "Job not found" + assert updated_job.status == JobStatus.FAILED, "Job status not updated to FAILED" + assert updated_job.error == str(test_error), "Job error not saved correctly" + + +async def test_job_lifecycle(task_service: TaskService, sample_job: Job): + """Test the complete lifecycle of a job from creation to completion.""" + # Verify initial state + async with session_scope() as session: + stmt = select(Job).where(Job.id == sample_job.id) + job = (await session.exec(stmt)).first() + assert job is not None, "Job not found" + assert job.status == JobStatus.PENDING, "Initial job status should be PENDING" + assert job.result is None, "Initial job result should be None" + assert job.error is None, "Initial job error should be None" + + # Simulate successful execution + success_event = JobExecutionEvent( + code=0, + job_id=sample_job.id, + jobstore="default", + retval={"output": "Success result"}, + scheduled_run_time=sample_job.next_run_time, + ) + await task_service._handle_job_executed(success_event) + + # Verify successful 
completion +    async with session_scope() as session: +        stmt = select(Job).where(Job.id == sample_job.id) +        completed_job = (await session.exec(stmt)).first() +        assert completed_job is not None, "Job not found" +        assert completed_job.status == JobStatus.COMPLETED, "Job should be marked as completed" +        assert completed_job.result == {"output": "Success result"}, "Job result should be saved" +        assert completed_job.error is None, "Completed job should not have an error" + + +async def test_concurrent_job_updates(task_service: TaskService, sample_job: Job): +    """Test handling concurrent updates to the same job.""" +    # Create multiple events for the same job +    success_event = JobExecutionEvent( +        code=0, +        job_id=sample_job.id, +        scheduled_run_time=sample_job.next_run_time, +        jobstore="default", +        retval="Success result", +    ) +    error_event = JobExecutionEvent( +        code=1, +        job_id=sample_job.id, +        jobstore="default", +        exception=ValueError("Test error"), +        scheduled_run_time=sample_job.next_run_time, +    ) + +    # Handle events concurrently +    await asyncio.gather( +        task_service._handle_job_executed(success_event), +        task_service._handle_job_error(error_event), +    ) + +    # Verify final state (one of the updates should succeed, the other should fail gracefully) +    async with session_scope() as session: +        stmt = select(Job).where(Job.id == sample_job.id) +        final_job = (await session.exec(stmt)).first() +        assert final_job is not None, "Job not found" +        assert final_job.status in [JobStatus.COMPLETED, JobStatus.FAILED], "Job should be either completed or failed" + + +@pytest.mark.usefixtures("client") +async def test_invalid_job_id(task_service: TaskService): +    """Test handling events for non-existent jobs.""" +    # Create events with invalid job ID +    invalid_success_event = JobExecutionEvent( +        code=0, +        job_id="nonexistent_id", +        jobstore="default", +        retval="Success result", +        scheduled_run_time=datetime.datetime.now(datetime.timezone.utc), +    ) +    invalid_error_event = JobExecutionEvent( +        code=1, +        job_id="nonexistent_id", +        jobstore="default", +        exception=ValueError("Test error"), +        scheduled_run_time=datetime.datetime.now(datetime.timezone.utc), +    ) + +    # Both handlers should handle non-existent jobs gracefully +    await task_service._handle_job_executed(invalid_success_event) +    await task_service._handle_job_error(invalid_error_event) diff --git a/src/backend/tests/unit/test_webhook.py b/src/backend/tests/unit/test_webhook.py index dd75a3370442..67472ae67f7f 100644 --- a/src/backend/tests/unit/test_webhook.py +++ b/src/backend/tests/unit/test_webhook.py @@ -1,3 +1,5 @@ +import asyncio + import aiofiles import anyio import pytest @@ -8,7 +10,19 @@ def _check_openai_api_key_in_environment_variables(): pass -async def test_webhook_endpoint(client, added_webhook_test): +async def poll_until_job_is_completed(client, job_id, headers, times=10): +    """Poll the job endpoint until the job is no longer pending; fail if it is still pending after `times` attempts.""" +    job_endpoint = f"api/v1/jobs/{job_id}" +    for _ in range(times): +        response = await client.get(job_endpoint, headers=headers) +        assert response.status_code == 200 +        if response.json()["pending"] is False: +            return +        await asyncio.sleep(1) +    pytest.fail(f"Job {job_id} did not complete after {times} polling attempts") + + +async def test_webhook_endpoint(client, added_webhook_test, logged_in_headers): # The test is as follows: # 1. The flow when run will get a "path" from the payload and save a file with the path as the name. # We will create a temporary file path and send it to the webhook endpoint, then check if the file exists.
@@ -23,9 +37,11 @@ async def test_webhook_endpoint(client, added_webhook_test): response = await client.post(endpoint, json=payload) assert response.status_code == 202 +    job_id = response.json()["job_id"] +    await poll_until_job_is_completed(client, job_id, logged_in_headers) +    assert await file_path.exists() -    assert not await file_path.exists() # Send an invalid payload payload = {"invalid_key": "invalid_value"} diff --git a/uv.lock b/uv.lock index 07b5d5c3f9fe..ad720a302212 100644 --- a/uv.lock +++ b/uv.lock @@ -266,6 +266,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/81/29/5ecc3a15d5a33e31b26c11426c45c501e439cb865d0bff96315d86443b78/appnope-0.1.4-py2.py3-none-any.whl", hash = "sha256:502575ee11cd7a28c0205f379b525beefebab9d161b7c964670864014ed7213c", size = 4321 }, ] +[[package]] +name = "apscheduler" +version = "3.11.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "tzlocal" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4e/00/6d6814ddc19be2df62c8c898c4df6b5b1914f3bd024b780028caa392d186/apscheduler-3.11.0.tar.gz", hash = "sha256:4c622d250b0955a65d5d0eb91c33e6d43fd879834bf541e0a18661ae60460133", size = 107347 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d0/ae/9a053dd9229c0fde6b1f1f33f609ccff1ee79ddda364c756a924c6d8563b/APScheduler-3.11.0-py3-none-any.whl", hash = "sha256:fc134ca32e50f5eadcc4938e3a4545ab19131435e851abb40b34d63d5141c6da", size = 64004 }, +] + [[package]] name = "arize-phoenix-otel" version = "0.6.1" source = { registry = "https://pypi.org/simple" } @@ -532,7 +544,7 @@ name = "blessed" version = "1.20.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "jinxed", marker = "sys_platform == 'win32'" }, + { name = "jinxed", marker = "platform_system == 'Windows'" }, { name = "six" }, { name = "wcwidth" }, ] @@ -962,7 +974,7 @@ name = "click" version = "8.1.7" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "platform_system == 'Windows'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 } wheels = [ @@ -3143,7 +3155,7 @@ name = "ipykernel" version = "6.29.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "appnope", marker = "sys_platform == 'darwin'" }, + { name = "appnope", marker = "platform_system == 'Darwin'" }, { name = "comm" }, { name = "debugpy" }, { name = "ipython" }, @@ -3234,7 +3246,7 @@ name = "jinxed" version = "1.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "ansicon", marker = "sys_platform == 'win32'" }, + { name = "ansicon", marker = "platform_system == 'Windows'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/20/d0/59b2b80e7a52d255f9e0ad040d2e826342d05580c4b1d7d7747cfb8db731/jinxed-1.3.0.tar.gz", hash = "sha256:1593124b18a41b7a3da3b078471442e51dbad3d77b4d4f2b0c26ab6f7d660dbf", size = 80981 } wheels = [ @@ -4219,6 +4231,7 @@ source = { editable = "src/backend/base" } dependencies = [ { name = "aiofiles" }, { name = "alembic" }, + { name = "apscheduler" }, { name = "assemblyai" }, { name = "asyncer" }, { name = "bcrypt" }, @@ -4341,6 +4354,7 @@ dev = [ requires-dist = [ { name = "aiofiles", specifier = ">=24.1.0,<25.0.0" }, { name = "alembic", specifier = ">=1.13.0,<2.0.0"
}, + { name = "apscheduler", specifier = "==3.11.0" }, { name = "assemblyai", specifier = ">=0.33.0,<1.0.0" }, { name = "asyncer", specifier = ">=0.0.5,<1.0.0" }, { name = "bcrypt", specifier = "==4.0.1" }, @@ -6098,7 +6112,7 @@ name = "portalocker" version = "2.10.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pywin32", marker = "sys_platform == 'win32'" }, + { name = "pywin32", marker = "platform_system == 'Windows'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/ed/d3/c6c64067759e87af98cc668c1cc75171347d0f1577fab7ca3749134e3cd4/portalocker-2.10.1.tar.gz", hash = "sha256:ef1bf844e878ab08aee7e40184156e1151f228f103aa5c6bd0724cc330960f8f", size = 40891 } wheels = [ @@ -8575,19 +8589,19 @@ dependencies = [ { name = "fsspec" }, { name = "jinja2" }, { name = "networkx" }, - { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, { name = "sympy" }, - { name = "triton", marker = "python_full_version < '3.13' and platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "triton", marker = "python_full_version < '3.13' and platform_machine == 'x86_64' and platform_system == 'Linux'" }, { name = "typing-extensions" }, ] wheels = [ @@ -8628,7 +8642,7 @@ name = "tqdm" version = "4.67.1" source 
= { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "platform_system == 'Windows'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737 } wheels = [ @@ -8985,6 +8999,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a6/ab/7e5f53c3b9d14972843a647d8d7a853969a58aecc7559cb3267302c94774/tzdata-2024.2-py2.py3-none-any.whl", hash = "sha256:a48093786cdcde33cad18c2555e8532f34422074448fbc874186f0abd79565cd", size = 346586 }, ] +[[package]] +name = "tzlocal" +version = "5.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "tzdata", marker = "platform_system == 'Windows'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/04/d3/c19d65ae67636fe63953b20c2e4a8ced4497ea232c43ff8d01db16de8dc0/tzlocal-5.2.tar.gz", hash = "sha256:8d399205578f1a9342816409cc1e46a93ebd5755e39ea2d85334bea911bf0e6e", size = 30201 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/97/3f/c4c51c55ff8487f2e6d0e618dba917e3c3ee2caae6cf0fbb59c9b1876f2e/tzlocal-5.2-py3-none-any.whl", hash = "sha256:49816ef2fe65ea8ac19d19aa7a1ae0551c834303d5014c6d5a62e4cbda8047b8", size = 17859 }, +] + [[package]] name = "ujson" version = "5.10.0"