Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 111 additions & 74 deletions packages/celery-library/src/celery_library/_task_manager.py

Large diffs are not rendered by default.

42 changes: 19 additions & 23 deletions packages/celery-library/src/celery_library/async_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
JobNotDoneError,
JobSchedulerError,
)
from models_library.celery import OwnerMetadata, TaskExecutionMetadata, TaskState, TaskStatus
from models_library.celery import TaskExecutionMetadata, TaskState, TaskStatus
from servicelib.celery.task_manager import TaskManager
from servicelib.logging_utils import log_catch
from tenacity import (
Expand All @@ -46,12 +46,10 @@
async def cancel_job(
task_manager: TaskManager,
*,
owner_metadata: OwnerMetadata,
job_id: AsyncJobId,
) -> None:
try:
await task_manager.cancel(
owner_metadata=owner_metadata,
task_or_group_uuid=job_id,
)
except TaskOrGroupNotFoundError as exc:
Expand All @@ -63,22 +61,18 @@ async def cancel_job(
async def get_job_result(
task_manager: TaskManager,
*,
owner_metadata: OwnerMetadata,
job_id: AsyncJobId,
) -> AsyncJobResult:
assert task_manager # nosec
assert job_id # nosec
assert owner_metadata # nosec

try:
task_status = await task_manager.get_status(
owner_metadata=owner_metadata,
task_or_group_uuid=job_id,
)
if not task_status.is_done:
raise JobNotDoneError(job_id=job_id)
task_result = await task_manager.get_result(
owner_metadata=owner_metadata,
task_or_group_uuid=job_id,
)
except TaskOrGroupNotFoundError as exc:
Expand Down Expand Up @@ -112,12 +106,10 @@ async def get_job_result(
async def get_job_status(
task_manager: TaskManager,
*,
owner_metadata: OwnerMetadata,
job_id: AsyncJobId,
) -> AsyncJobStatus:
try:
task_status = await task_manager.get_status(
owner_metadata=owner_metadata,
task_or_group_uuid=job_id,
)
except TaskOrGroupNotFoundError as exc:
Expand All @@ -135,12 +127,16 @@ async def get_job_status(
async def list_jobs(
task_manager: TaskManager,
*,
owner_metadata: OwnerMetadata,
owner: str,
user_id: int | None = None,
product_name: str | None = None,
) -> list[AsyncJobGet]:
assert task_manager # nosec
try:
tasks = await task_manager.list_tasks(
owner_metadata=owner_metadata,
owner=owner,
user_id=user_id,
product_name=product_name,
)
except TaskManagerError as exc:
raise JobSchedulerError(exc=f"{exc}") from exc
Expand All @@ -152,12 +148,16 @@ async def submit_job(
task_manager: TaskManager,
*,
execution_metadata: TaskExecutionMetadata,
owner_metadata: OwnerMetadata,
owner: str,
user_id: int | None = None,
product_name: str | None = None,
**kwargs,
) -> AsyncJobGet:
task_id = await task_manager.submit_task(
execution_metadata=execution_metadata,
owner_metadata=owner_metadata,
owner=owner,
user_id=user_id,
product_name=product_name,
**kwargs,
)
return AsyncJobGet(job_id=task_id, job_name=execution_metadata.name)
Expand All @@ -166,7 +166,6 @@ async def submit_job(
async def _wait_for_completion(
task_manager: TaskManager,
*,
owner_metadata: OwnerMetadata,
job_id: AsyncJobId,
stop_after: datetime.timedelta,
) -> AsyncGenerator[AsyncJobStatus]:
Expand All @@ -181,7 +180,6 @@ async def _wait_for_completion(
with attempt:
job_status = await get_job_status(
task_manager,
owner_metadata=owner_metadata,
job_id=job_id,
)
yield job_status
Expand Down Expand Up @@ -214,7 +212,6 @@ async def result(self) -> Any:
async def wait_and_get_job_result(
task_manager: TaskManager,
*,
owner_metadata: OwnerMetadata,
job_id: AsyncJobId,
stop_after: datetime.timedelta,
) -> AsyncGenerator[AsyncJobResultUpdate]:
Expand All @@ -225,7 +222,6 @@ async def wait_and_get_job_result(
async for job_status in _wait_for_completion(
task_manager,
job_id=job_id,
owner_metadata=owner_metadata,
stop_after=stop_after,
):
assert job_status is not None # nosec
Expand All @@ -237,15 +233,13 @@ async def wait_and_get_job_result(
job_status,
get_job_result(
task_manager,
owner_metadata=owner_metadata,
job_id=job_id,
),
)
except (TimeoutError, CancelledError) as error:
try:
await cancel_job(
task_manager,
owner_metadata=owner_metadata,
job_id=job_id,
)
except Exception as exc:
Expand All @@ -257,7 +251,9 @@ async def submit_job_and_wait(
task_manager: TaskManager,
*,
execution_metadata: TaskExecutionMetadata,
owner_metadata: OwnerMetadata,
owner: str,
user_id: int | None = None,
product_name: str | None = None,
stop_after: datetime.timedelta,
**kwargs,
) -> AsyncGenerator[AsyncJobResultUpdate]:
Expand All @@ -266,22 +262,22 @@ async def submit_job_and_wait(
async_job = await submit_job(
task_manager,
execution_metadata=execution_metadata,
owner_metadata=owner_metadata,
owner=owner,
user_id=user_id,
product_name=product_name,
**kwargs,
)
except (TimeoutError, CancelledError):
if async_job is not None:
await cancel_job(
task_manager,
owner_metadata=owner_metadata,
job_id=async_job.job_id,
)
raise

async for wait_and_ in wait_and_get_job_result(
task_manager,
job_id=async_job.job_id,
owner_metadata=owner_metadata,
stop_after=stop_after,
):
yield wait_and_
Loading
Loading