From b02111a2c16b50f7408a21f4157c6eb0e0b0953a Mon Sep 17 00:00:00 2001 From: Giancarlo Romeo Date: Tue, 14 Apr 2026 06:06:02 +0200 Subject: [PATCH 01/14] add indexes --- .../src/celery_library/backends/_redis.py | 105 +++++++++++++----- .../tests/unit/test_task_manager.py | 24 ++++ 2 files changed, 102 insertions(+), 27 deletions(-) diff --git a/packages/celery-library/src/celery_library/backends/_redis.py b/packages/celery-library/src/celery_library/backends/_redis.py index 948f169adde3..93d3995665c1 100644 --- a/packages/celery-library/src/celery_library/backends/_redis.py +++ b/packages/celery-library/src/celery_library/backends/_redis.py @@ -2,8 +2,10 @@ import logging from dataclasses import dataclass from datetime import UTC, datetime, timedelta +from itertools import product from typing import TYPE_CHECKING, Final +from common_library.json_serialization import json_dumps from models_library.celery import ( WILDCARD, ExecutionMetadata, @@ -25,9 +27,10 @@ _CELERY_TASK_PREFIX: Final[str] = "celery-task-" _CELERY_TASK_ID_KEY_ENCODING: Final[str] = "utf-8" -_CELERY_TASK_SCAN_COUNT_PER_BATCH: Final[int] = 1000 _CELERY_TASK_EXEC_METADATA_KEY: Final[str] = "exec-meta" _CELERY_TASK_PROGRESS_KEY: Final[str] = "progress" +_CELERY_TASK_INDEX_PREFIX: Final[str] = "celery-task-index-" +_UUID_KEY_PREFIX: Final[str] = "uuid=" # Redis list to store streamed results _CELERY_TASK_STREAM_PREFIX: Final[str] = "celery-task-stream-" @@ -51,6 +54,37 @@ def _build_redis_stream_meta_key(task_key: TaskKey) -> str: return f"{_build_redis_stream_key(task_key)}{_CELERY_TASK_DELIMTATOR}{_CELERY_TASK_STREAM_METADATA}" +def _without_uuid_token(task_or_group_key: TaskKey | GroupKey) -> str: + return _CELERY_TASK_DELIMTATOR.join( + token for token in task_or_group_key.split(_CELERY_TASK_DELIMTATOR) if not token.startswith(_UUID_KEY_PREFIX) + ) + + +def _build_redis_owner_index_key(owner_key_without_uuid: str) -> str: + return f"{_CELERY_TASK_INDEX_PREFIX}{owner_key_without_uuid}" + + +def _build_redis_owner_index_key_for_query(owner_metadata: OwnerMetadata) -> str: + owner_key = owner_metadata.model_dump_key(task_or_group_uuid=WILDCARD) + return _build_redis_owner_index_key(_without_uuid_token(owner_key)) + + +def _build_redis_owner_index_keys_for_task(task_or_group_key: TaskKey | GroupKey) -> list[str]: + owner_tokens = [ + token.split("=", maxsplit=1) for token in _without_uuid_token(task_or_group_key).split(_CELERY_TASK_DELIMTATOR) + ] + wildcard_value = json_dumps(WILDCARD) + + keys: list[str] = [] + for mask in product((False, True), repeat=len(owner_tokens)): + query_owner_key = _CELERY_TASK_DELIMTATOR.join( + f"{key}={wildcard_value if use_wildcard else value}" + for (key, value), use_wildcard in zip(owner_tokens, mask, strict=True) + ) + keys.append(_build_redis_owner_index_key(query_owner_key)) + return keys + + @dataclass(frozen=True) class RedisTaskStore: _redis_client_sdk: RedisClientSDK @@ -62,15 +96,19 @@ async def create_group( task_keys: list[TaskKey], expiry: timedelta, ) -> None: - group_key = _build_redis_task_or_group_key(group_key) + redis_group_key = _build_redis_task_or_group_key(group_key) pipe = self._redis_client_sdk.redis.pipeline() + index_score = datetime.now(tz=UTC).timestamp() + pipe.hset( - name=group_key, + name=redis_group_key, key=_CELERY_TASK_EXEC_METADATA_KEY, value=execution_metadata.model_dump_json(), ) + for index_key in _build_redis_owner_index_keys_for_task(group_key): + pipe.zadd(index_key, {group_key: index_score}) - # group tasks + # group sub-tasks: store hash only, no ZSET index (filtered out in list_tasks) for task_key, (task_execution_metadata, _) in zip(task_keys, execution_metadata.tasks, strict=True): pipe.hset( name=_build_redis_task_or_group_key(task_key), @@ -79,7 +117,7 @@ async def create_group( ) await pipe.execute() await self._redis_client_sdk.redis.expire( - group_key, + redis_group_key, expiry, ) @@ -90,23 +128,29 @@ async def create_task( expiry: timedelta, ) -> None: redis_key = _build_redis_task_or_group_key(task_key) - await handle_redis_returns_union_types( - self._redis_client_sdk.redis.hset( - name=redis_key, - key=_CELERY_TASK_EXEC_METADATA_KEY, - value=execution_metadata.model_dump_json(), - ) + index_score = datetime.now(tz=UTC).timestamp() + + pipe = self._redis_client_sdk.redis.pipeline() + pipe.hset( + name=redis_key, + key=_CELERY_TASK_EXEC_METADATA_KEY, + value=execution_metadata.model_dump_json(), ) + for index_key in _build_redis_owner_index_keys_for_task(task_key): + pipe.zadd(index_key, {task_key: index_score}) + await pipe.execute() + await self._redis_client_sdk.redis.expire( redis_key, expiry, ) async def get_task_metadata(self, task_key: TaskKey) -> ExecutionMetadata | None: + redis_key = _build_redis_task_or_group_key(task_key) raw_result = await handle_redis_returns_union_types( self._redis_client_sdk.redis.hget( - _build_redis_task_or_group_key(task_key), - _CELERY_TASK_EXEC_METADATA_KEY, + name=redis_key, + key=_CELERY_TASK_EXEC_METADATA_KEY, ) ) if not raw_result: @@ -143,23 +187,25 @@ async def get_task_progress(self, task_key: TaskKey) -> ProgressReport | None: return None async def list_tasks(self, owner_metadata: OwnerMetadata) -> list[Task]: - search_key = _CELERY_TASK_PREFIX + owner_metadata.model_dump_key(task_or_group_uuid=WILDCARD) + owner_index_key = _build_redis_owner_index_key_for_query(owner_metadata) + + raw_members = await self._redis_client_sdk.redis.zrange(owner_index_key, 0, -1) + if not raw_members: + return [] + + members = [m.decode(_CELERY_TASK_ID_KEY_ENCODING) if isinstance(m, bytes) else m for m in raw_members] - keys: list[str] = [] pipe = self._redis_client_sdk.redis.pipeline() - async for key in self._redis_client_sdk.redis.scan_iter( - match=search_key, count=_CELERY_TASK_SCAN_COUNT_PER_BATCH - ): - # fake redis (tests) returns bytes, real redis returns str - dec_key = key.decode(_CELERY_TASK_ID_KEY_ENCODING) if isinstance(key, bytes) else key - keys.append(dec_key) - pipe.hget(dec_key, _CELERY_TASK_EXEC_METADATA_KEY) + for member in members: + pipe.hget(_build_redis_task_or_group_key(member), _CELERY_TASK_EXEC_METADATA_KEY) results = await pipe.execute() tasks = [] - for key, raw_metadata in zip(keys, results, strict=True): + stale_members: list[str] = [] + for member, raw_metadata in zip(members, results, strict=True): if raw_metadata is None: + stale_members.append(member) continue with contextlib.suppress(ValidationError): @@ -169,17 +215,22 @@ async def list_tasks(self, owner_metadata: OwnerMetadata) -> list[Task]: tasks.append( Task( - uuid=OwnerMetadata.get_task_or_group_uuid(key), + uuid=OwnerMetadata.get_task_or_group_uuid(member), metadata=execution_metadata, ) ) + if stale_members: + await self._redis_client_sdk.redis.zrem(owner_index_key, *stale_members) + return tasks async def remove_task(self, task_key: TaskKey) -> None: - await self._redis_client_sdk.redis.delete( - _build_redis_task_or_group_key(task_key), - ) + pipe = self._redis_client_sdk.redis.pipeline() + pipe.delete(_build_redis_task_or_group_key(task_key)) + for index_key in _build_redis_owner_index_keys_for_task(task_key): + pipe.zrem(index_key, task_key) + await pipe.execute() async def set_task_progress(self, task_key: TaskKey, report: ProgressReport) -> None: await handle_redis_returns_union_types( diff --git a/packages/celery-library/tests/unit/test_task_manager.py b/packages/celery-library/tests/unit/test_task_manager.py index 4b57e64d8e8a..50c875aee096 100644 --- a/packages/celery-library/tests/unit/test_task_manager.py +++ b/packages/celery-library/tests/unit/test_task_manager.py @@ -251,6 +251,30 @@ async def test_listing_task_uuids_contains_submitted_task( assert any(task.uuid == task_uuid for task in tasks) +async def test_listing_tasks_uses_zset_index_and_not_scan( + task_manager: CeleryTaskManager, + with_celery_worker: WorkController, + fake_owner_metadata: OwnerMetadata, + monkeypatch: pytest.MonkeyPatch, +): + task_uuid = await task_manager.submit_task( + TaskExecutionMetadata( + name=dreamer_task.__name__, + ), + owner_metadata=fake_owner_metadata, + ) + + def _forbid_scan_iter(*args, **kwargs): + msg = "list_tasks must not use redis.scan_iter" + raise AssertionError(msg) + + redis_client = task_manager._task_store._redis_client_sdk.redis # noqa: SLF001 + monkeypatch.setattr(redis_client, "scan_iter", _forbid_scan_iter) + + tasks = await task_manager.list_tasks(fake_owner_metadata) + assert any(task.uuid == task_uuid for task in tasks) + + async def test_filtering_listing_tasks( task_manager: CeleryTaskManager, with_celery_worker: WorkController, From 42ca4d66d7b944c5e0396e31b1de9ea012500c86 Mon Sep 17 00:00:00 2001 From: Giancarlo Romeo Date: Tue, 14 Apr 2026 07:14:27 +0200 Subject: [PATCH 02/14] move test --- .../tests/unit/test_redis_store.py | 151 ++++++++++++++++++ .../tests/unit/test_task_manager.py | 24 --- 2 files changed, 151 insertions(+), 24 deletions(-) create mode 100644 packages/celery-library/tests/unit/test_redis_store.py diff --git a/packages/celery-library/tests/unit/test_redis_store.py b/packages/celery-library/tests/unit/test_redis_store.py new file mode 100644 index 000000000000..d728fc19ee06 --- /dev/null +++ b/packages/celery-library/tests/unit/test_redis_store.py @@ -0,0 +1,151 @@ +# pylint: disable=redefined-outer-name + +from collections.abc import AsyncIterator +from datetime import timedelta + +import pytest +from celery_library.backends import RedisTaskStore +from celery_library.backends._redis import _build_redis_task_or_group_key +from faker import Faker +from models_library.celery import ( + OwnerMetadata, + Task, + TaskExecutionMetadata, + Wildcard, +) +from models_library.users import UserID +from servicelib.redis import RedisClientSDK +from settings_library.redis import RedisDatabase, RedisSettings + +_faker = Faker() + +pytest_simcore_core_services_selection = ["redis"] +pytest_simcore_ops_services_selection = [] + + +class _TestOwnerMetadata(OwnerMetadata): + user_id: UserID + product_name: str | Wildcard + + +@pytest.fixture +async def redis_task_store( + use_in_memory_redis: RedisSettings, +) -> AsyncIterator[RedisTaskStore]: + redis_client_sdk = RedisClientSDK( + use_in_memory_redis.build_redis_dsn(RedisDatabase.CELERY_TASKS), + client_name="pytest_redis_store", + ) + await redis_client_sdk.setup() + try: + yield RedisTaskStore(redis_client_sdk) + finally: + await redis_client_sdk.shutdown() + + +async def test_list_tasks_uses_zset_index_not_scan( + redis_task_store: RedisTaskStore, + monkeypatch: pytest.MonkeyPatch, +): + owner = _TestOwnerMetadata(user_id=10001, product_name="osparc", owner="test-svc") + task_key = owner.model_dump_key(task_or_group_uuid=_faker.uuid4()) + + await redis_task_store.create_task( + task_key, + TaskExecutionMetadata(name="my_task"), + expiry=timedelta(minutes=5), + ) + + def _forbid_scan_iter(*_args, **_kwargs): + msg = "list_tasks must not use redis.scan_iter" + raise AssertionError(msg) + + monkeypatch.setattr( + redis_task_store._redis_client_sdk.redis, # noqa: SLF001 + "scan_iter", + _forbid_scan_iter, + ) + + tasks = await redis_task_store.list_tasks(owner) + assert len(tasks) == 1 + assert tasks[0].uuid == OwnerMetadata.get_task_or_group_uuid(task_key) + + +async def test_list_tasks_with_wildcard_filtering( + redis_task_store: RedisTaskStore, +): + user_id = 42 + owner = "test-svc" + expected_tasks: list[Task] = [] + + for _ in range(5): + om = _TestOwnerMetadata(user_id=user_id, product_name=_faker.word(), owner=owner) + task_key = om.model_dump_key(task_or_group_uuid=_faker.uuid4()) + await redis_task_store.create_task( + task_key, + TaskExecutionMetadata(name="my_task"), + expiry=timedelta(minutes=5), + ) + expected_tasks.append( + Task( + uuid=OwnerMetadata.get_task_or_group_uuid(task_key), + metadata=TaskExecutionMetadata(name="my_task"), + ) + ) + + for _ in range(3): + om = _TestOwnerMetadata( + user_id=_faker.pyint(min_value=100, max_value=200), + product_name=_faker.word(), + owner=owner, + ) + task_key = om.model_dump_key(task_or_group_uuid=_faker.uuid4()) + await redis_task_store.create_task( + task_key, + TaskExecutionMetadata(name="my_task"), + expiry=timedelta(minutes=5), + ) + + search = _TestOwnerMetadata(user_id=user_id, product_name="*", owner=owner) + tasks = await redis_task_store.list_tasks(search) + assert {t.uuid for t in tasks} == {t.uuid for t in expected_tasks} + + +async def test_remove_task_cleans_up_zset_indexes( + redis_task_store: RedisTaskStore, +): + owner = _TestOwnerMetadata(user_id=10003, product_name="osparc", owner="test-svc") + task_key = owner.model_dump_key(task_or_group_uuid=_faker.uuid4()) + + await redis_task_store.create_task( + task_key, + TaskExecutionMetadata(name="my_task"), + expiry=timedelta(minutes=5), + ) + assert len(await redis_task_store.list_tasks(owner)) == 1 + + await redis_task_store.remove_task(task_key) + assert len(await redis_task_store.list_tasks(owner)) == 0 + + +async def test_stale_zset_entries_are_pruned_on_list( + redis_task_store: RedisTaskStore, +): + owner = _TestOwnerMetadata(user_id=10004, product_name="osparc", owner="test-svc") + task_key = owner.model_dump_key(task_or_group_uuid=_faker.uuid4()) + + await redis_task_store.create_task( + task_key, + TaskExecutionMetadata(name="my_task"), + expiry=timedelta(minutes=5), + ) + + # Simulate hash expiry by deleting the hash directly (bypass remove_task) + redis = redis_task_store._redis_client_sdk.redis # noqa: SLF001 + + await redis.delete(_build_redis_task_or_group_key(task_key)) + + # First list should return empty and prune the stale entry + assert await redis_task_store.list_tasks(owner) == [] + # Second list confirms the ZSET is clean + assert await redis_task_store.list_tasks(owner) == [] diff --git a/packages/celery-library/tests/unit/test_task_manager.py b/packages/celery-library/tests/unit/test_task_manager.py index 50c875aee096..4b57e64d8e8a 100644 --- a/packages/celery-library/tests/unit/test_task_manager.py +++ b/packages/celery-library/tests/unit/test_task_manager.py @@ -251,30 +251,6 @@ async def test_listing_task_uuids_contains_submitted_task( assert any(task.uuid == task_uuid for task in tasks) -async def test_listing_tasks_uses_zset_index_and_not_scan( - task_manager: CeleryTaskManager, - with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, - monkeypatch: pytest.MonkeyPatch, -): - task_uuid = await task_manager.submit_task( - TaskExecutionMetadata( - name=dreamer_task.__name__, - ), - owner_metadata=fake_owner_metadata, - ) - - def _forbid_scan_iter(*args, **kwargs): - msg = "list_tasks must not use redis.scan_iter" - raise AssertionError(msg) - - redis_client = task_manager._task_store._redis_client_sdk.redis # noqa: SLF001 - monkeypatch.setattr(redis_client, "scan_iter", _forbid_scan_iter) - - tasks = await task_manager.list_tasks(fake_owner_metadata) - assert any(task.uuid == task_uuid for task in tasks) - - async def test_filtering_listing_tasks( task_manager: CeleryTaskManager, with_celery_worker: WorkController, From 126472187f6a276a0ba19417b5ee5958ad113f74 Mon Sep 17 00:00:00 2001 From: Giancarlo Romeo Date: Tue, 14 Apr 2026 21:51:27 +0200 Subject: [PATCH 03/14] fix Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- packages/celery-library/tests/unit/test_redis_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/celery-library/tests/unit/test_redis_store.py b/packages/celery-library/tests/unit/test_redis_store.py index d728fc19ee06..46dba475d4b4 100644 --- a/packages/celery-library/tests/unit/test_redis_store.py +++ b/packages/celery-library/tests/unit/test_redis_store.py @@ -56,7 +56,7 @@ async def test_list_tasks_uses_zset_index_not_scan( expiry=timedelta(minutes=5), ) - def _forbid_scan_iter(*_args, **_kwargs): + def _forbid_scan_iter(*_args: object, **_kwargs: object) -> None: msg = "list_tasks must not use redis.scan_iter" raise AssertionError(msg) From 6393a07eed5370d3108c828084e0aad2c9da9bb0 Mon Sep 17 00:00:00 2001 From: Giancarlo Romeo Date: Wed, 15 Apr 2026 12:59:11 +0200 Subject: [PATCH 04/14] fix --- .../tests/unit/test_redis_store.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/packages/celery-library/tests/unit/test_redis_store.py b/packages/celery-library/tests/unit/test_redis_store.py index 46dba475d4b4..edffd9b72cf7 100644 --- a/packages/celery-library/tests/unit/test_redis_store.py +++ b/packages/celery-library/tests/unit/test_redis_store.py @@ -29,22 +29,30 @@ class _TestOwnerMetadata(OwnerMetadata): @pytest.fixture -async def redis_task_store( +async def redis_client_sdk( use_in_memory_redis: RedisSettings, -) -> AsyncIterator[RedisTaskStore]: +) -> AsyncIterator[RedisClientSDK]: redis_client_sdk = RedisClientSDK( use_in_memory_redis.build_redis_dsn(RedisDatabase.CELERY_TASKS), client_name="pytest_redis_store", ) await redis_client_sdk.setup() try: - yield RedisTaskStore(redis_client_sdk) + yield redis_client_sdk finally: await redis_client_sdk.shutdown() +@pytest.fixture +async def redis_task_store( + redis_client_sdk: RedisClientSDK, +) -> RedisTaskStore: + return RedisTaskStore(redis_client_sdk) + + async def test_list_tasks_uses_zset_index_not_scan( redis_task_store: RedisTaskStore, + redis_client_sdk: RedisClientSDK, monkeypatch: pytest.MonkeyPatch, ): owner = _TestOwnerMetadata(user_id=10001, product_name="osparc", owner="test-svc") @@ -61,7 +69,7 @@ def _forbid_scan_iter(*_args: object, **_kwargs: object) -> None: raise AssertionError(msg) monkeypatch.setattr( - redis_task_store._redis_client_sdk.redis, # noqa: SLF001 + redis_client_sdk.redis, "scan_iter", _forbid_scan_iter, ) @@ -130,6 +138,7 @@ async def test_remove_task_cleans_up_zset_indexes( async def test_stale_zset_entries_are_pruned_on_list( redis_task_store: RedisTaskStore, + redis_client_sdk: RedisClientSDK, ): owner = _TestOwnerMetadata(user_id=10004, product_name="osparc", owner="test-svc") task_key = owner.model_dump_key(task_or_group_uuid=_faker.uuid4()) @@ -141,7 +150,7 @@ async def test_stale_zset_entries_are_pruned_on_list( ) # Simulate hash expiry by deleting the hash directly (bypass remove_task) - redis = redis_task_store._redis_client_sdk.redis # noqa: SLF001 + redis = redis_client_sdk.redis await redis.delete(_build_redis_task_or_group_key(task_key)) From 4fd590fa0b3d92c2f395bd5375c31c1135f94fc3 Mon Sep 17 00:00:00 2001 From: Giancarlo Romeo Date: Thu, 16 Apr 2026 13:16:26 +0200 Subject: [PATCH 05/14] use tasks uuid --- .../src/celery_library/_task_manager.py | 180 ++++++++------- .../src/celery_library/async_jobs.py | 42 ++-- .../src/celery_library/backends/_redis.py | 110 +++++++--- .../src/celery_library/errors.py | 2 +- .../celery-library/tests/unit/conftest.py | 11 +- .../tests/unit/task_manager/conftest.py | 10 +- .../task_manager/test_task_manager_cancel.py | 55 +++-- .../test_task_manager_description.py | 41 ++-- .../task_manager/test_task_manager_groups.py | 206 +++++++++++------- .../task_manager/test_task_manager_streams.py | 29 +-- .../task_manager/test_task_manager_tasks.py | 85 ++++---- .../tests/unit/test_async_jobs.py | 48 ++-- .../tests/unit/test_redis_store.py | 63 +++--- .../src/models_library/celery.py | 146 +++---------- .../notifications/rpc/_message.py | 26 +-- .../helpers/async_jobs_server.py | 15 +- .../src/pytest_simcore/helpers/storage_rpc.py | 11 +- .../celery/async_jobs/notifications.py | 17 +- .../celery/async_jobs/storage/paths.py | 10 +- .../celery/async_jobs/storage/simcore_s3.py | 5 +- .../src/servicelib/celery/task_manager.py | 30 ++- .../rpc_interfaces/notifications/_message.py | 17 +- packages/service-library/tests/test_celery.py | 155 ++----------- .../_service_function_jobs_task_client.py | 17 +- .../api/routes/tasks.py | 38 +--- .../models/domain/celery_models.py | 13 -- .../modules/celery/worker/_functions_tasks.py | 3 + .../services_rpc/async_jobs.py | 22 +- .../services_rpc/storage.py | 10 +- .../celery/test_functions_celery.py | 14 +- .../test_api_routers_function_jobs.py | 4 +- .../test_service_function_jobs_task_client.py | 3 +- .../api/celery/_email.py | 2 + .../api/rpc/_message.py | 8 +- .../services/_message.py | 25 ++- .../tests/unit/test_api_celery_send_email.py | 10 +- .../tests/unit/test_api_rpc_message.py | 22 +- .../api/_worker_tasks/_files.py | 2 + .../api/_worker_tasks/_paths.py | 2 + .../api/_worker_tasks/_simcore_s3.py | 2 +- .../api/rest/_files.py | 18 +- .../unit/test_async_jobs_handlers_paths.py | 19 +- .../test_async_jobs_handlers_simcore_s3.py | 27 +-- .../src/simcore_service_webserver/models.py | 8 - .../notifications/_service.py | 16 +- .../storage/_rest.py | 34 +-- .../simcore_service_webserver/storage/api.py | 14 +- .../tasks/_controller/_rest.py | 38 +--- .../tasks/_tasks_service.py | 18 +- .../test_notifications_service.py | 8 +- 50 files changed, 745 insertions(+), 966 deletions(-) diff --git a/packages/celery-library/src/celery_library/_task_manager.py b/packages/celery-library/src/celery_library/_task_manager.py index 2d2eeae2d09d..4ee0c3ebb6ce 100644 --- a/packages/celery-library/src/celery_library/_task_manager.py +++ b/packages/celery-library/src/celery_library/_task_manager.py @@ -17,7 +17,6 @@ GroupStatus, GroupTaskExecutionMetadata, GroupUUID, - OwnerMetadata, Task, TaskExecutionMetadata, TaskKey, @@ -63,9 +62,16 @@ def _get_task_expiry( else self._settings.CELERY_RESULT_EXPIRES ) - async def _cleanup_task(self, task_key: TaskKey) -> None: + async def _cleanup_task( + self, + task_key: TaskKey, + *, + owner: str, + user_id: int | None = None, + product_name: str | None = None, + ) -> None: try: - await self._task_store.remove_task(task_key) + await self._task_store.remove_task(task_key, owner=owner, user_id=user_id, product_name=product_name) except CeleryError: _logger.warning( "Unable to cleanup task '%s' during error handling", @@ -74,10 +80,10 @@ async def _cleanup_task(self, task_key: TaskKey) -> None: ) @staticmethod - def _create_task_ids(owner_metadata: OwnerMetadata) -> tuple[TaskUUID, TaskKey]: - """Generate task UUID and task key.""" + def _create_task_ids() -> tuple[TaskUUID, TaskKey]: + """Generate task UUID and task key (plain UUID string).""" task_uuid = uuid4() - task_key = owner_metadata.model_dump_key(task_or_group_uuid=task_uuid) + task_key = str(task_uuid) return task_uuid, task_key def _get_rate_limit_interval(self, task_name: str) -> float | None: @@ -94,7 +100,9 @@ async def submit_group( self, execution_metadata: GroupExecutionMetadata, *, - owner_metadata: OwnerMetadata, + owner: str, + user_id: int | None = None, + product_name: str | None = None, ) -> tuple[GroupUUID, list[TaskUUID]]: """ Submit a group of tasks in parallel. @@ -104,7 +112,7 @@ async def submit_group( with log_context( _logger, logging.DEBUG, - msg=f"Submit group: {owner_metadata=} items={len(execution_metadata.tasks)}", + msg=f"Submit group: {owner=} {user_id=} {product_name=} items={len(execution_metadata.tasks)}", ): created: list[tuple[str, TaskUUID]] = [] group_key: GroupKey | None = None @@ -116,7 +124,7 @@ async def submit_group( expiries: list[timedelta] = [] for idx, (group_task_execution_metadata, task_params) in enumerate(execution_metadata.tasks): - task_uuid, task_key = self._create_task_ids(owner_metadata) + task_uuid, task_key = self._create_task_ids() expiry = self._get_task_expiry(group_task_execution_metadata) expiries.append(expiry) @@ -150,6 +158,9 @@ async def submit_group( queue=group_task_meta.queue, ephemeral=group_task_meta.ephemeral, ), + owner=owner, + user_id=user_id, + product_name=product_name, expiry=expiry, ) @@ -157,18 +168,21 @@ async def submit_group( group_result.save() assert group_result.id is not None # nosec - group_key = owner_metadata.model_dump_key(task_or_group_uuid=group_result.id) + group_key = str(group_result.id) await self._task_store.create_group( group_key, execution_metadata, [task_key for task_key, _ in task_metadata_pairs], + owner=owner, + user_id=user_id, + product_name=product_name, expiry=group_expiry, ) except CeleryError as exc: for task_key, _ in created: - await self._cleanup_task(task_key) + await self._cleanup_task(task_key, owner=owner, user_id=user_id, product_name=product_name) raise GroupSubmissionError( group_name=execution_metadata.name, @@ -184,27 +198,42 @@ async def submit_task( self, execution_metadata: TaskExecutionMetadata, *, - owner_metadata: OwnerMetadata, + owner: str, + user_id: int | None = None, + product_name: str | None = None, **task_params, ) -> TaskUUID: with log_context( _logger, logging.DEBUG, - msg=f"Submit {execution_metadata.name=}: {owner_metadata=} {task_params=}", + msg=f"Submit {execution_metadata.name=}: {owner=} {user_id=} {product_name=} {task_params=}", ): - task_uuid, task_key = self._create_task_ids(owner_metadata) + task_uuid, task_key = self._create_task_ids() expiry = self._get_task_expiry(execution_metadata) try: - await self._task_store.create_task(task_key, execution_metadata, expiry=expiry) + await self._task_store.create_task( + task_key, + execution_metadata, + owner=owner, + user_id=user_id, + product_name=product_name, + expiry=expiry, + ) + # Forward non-None owner fields so workers can access user_id/product_name + _owner_kwargs: dict[str, Any] = {} + if user_id is not None: + _owner_kwargs["user_id"] = user_id + if product_name is not None: + _owner_kwargs["product_name"] = product_name self._app.send_task( execution_metadata.name, task_id=task_key, - kwargs={"task_key": task_key} | task_params, + kwargs={"task_key": task_key} | _owner_kwargs | task_params, queue=execution_metadata.queue, ) except CeleryError as exc: - await self._cleanup_task(task_key) + await self._cleanup_task(task_key, owner=owner, user_id=user_id, product_name=product_name) raise TaskSubmissionError( task_name=execution_metadata.name, task_key=task_key, @@ -214,45 +243,45 @@ async def submit_task( return task_uuid @handle_celery_errors - async def cancel(self, owner_metadata: OwnerMetadata, task_or_group_uuid: TaskUUID | GroupUUID) -> None: - if await self._is_group(owner_metadata, task_or_group_uuid): - await self._cancel_group(owner_metadata, task_or_group_uuid) + async def cancel(self, task_or_group_uuid: TaskUUID | GroupUUID) -> None: + if await self._is_group(task_or_group_uuid): + await self._cancel_group(task_or_group_uuid) else: - await self._cancel_task(owner_metadata, task_or_group_uuid) + await self._cancel_task(task_or_group_uuid) @handle_celery_errors - async def _cancel_group(self, owner_metadata: OwnerMetadata, group_uuid: GroupUUID) -> None: + async def _cancel_group(self, group_uuid: GroupUUID) -> None: with log_context( _logger, logging.DEBUG, - msg=f"group cancellation: {owner_metadata=} {group_uuid=}", + msg=f"group cancellation: {group_uuid=}", ): - group_key = owner_metadata.model_dump_key(task_or_group_uuid=group_uuid) + group_key: GroupKey = str(group_uuid) if not await self.task_or_group_exists(group_key): - raise TaskOrGroupNotFoundError(task_or_group_uuid=group_uuid, owner_metadata=owner_metadata) + raise TaskOrGroupNotFoundError(task_or_group_uuid=group_uuid) group_result = await self._restore_group_result(group_uuid) if group_result is not None: for async_result in group_result.results or []: task_key: TaskKey = async_result.id - await self._task_store.remove_task(task_key) + await self._task_store.remove_task_hash(task_key) await self._revoke_and_forget_task(task_key) group_result.forget() - await self._task_store.remove_task(group_key) + await self._task_store.remove_task_hash(group_key) @handle_celery_errors - async def _cancel_task(self, owner_metadata: OwnerMetadata, task_uuid: TaskUUID) -> None: + async def _cancel_task(self, task_uuid: TaskUUID) -> None: with log_context( _logger, logging.DEBUG, - msg=f"task cancellation: {owner_metadata=} {task_uuid=}", + msg=f"task cancellation: {task_uuid=}", ): - task_key = owner_metadata.model_dump_key(task_or_group_uuid=task_uuid) + task_key: TaskKey = str(task_uuid) if not await self.task_or_group_exists(task_key): - raise TaskOrGroupNotFoundError(task_uuid=task_uuid, owner_metadata=owner_metadata) + raise TaskOrGroupNotFoundError(task_uuid=task_uuid) - await self._task_store.remove_task(task_key) + await self._task_store.remove_task_hash(task_key) await self._revoke_and_forget_task(task_key) async def task_or_group_exists(self, task_or_group_key: TaskKey | GroupKey) -> bool: @@ -269,22 +298,22 @@ def _revoke_and_forget_task(self, task_key: TaskKey) -> None: async_result.forget() @handle_celery_errors - async def _get_task_result(self, owner_metadata: OwnerMetadata, task_uuid: TaskUUID) -> Any: + async def _get_task_result(self, task_uuid: TaskUUID) -> Any: with log_context( _logger, logging.DEBUG, - msg=f"Get task result: {owner_metadata=} {task_uuid=}", + msg=f"Get task result: {task_uuid=}", ): - task_key = owner_metadata.model_dump_key(task_or_group_uuid=task_uuid) + task_key: TaskKey = str(task_uuid) if not await self.task_or_group_exists(task_key): - raise TaskOrGroupNotFoundError(task_uuid=task_uuid, owner_metadata=owner_metadata) + raise TaskOrGroupNotFoundError(task_uuid=task_uuid) async_result = self._app.AsyncResult(task_key) result = async_result.result if async_result.ready(): task_metadata = await self._task_store.get_task_metadata(task_key) if task_metadata is not None and task_metadata.ephemeral: - await self._task_store.remove_task(task_key) + await self._task_store.remove_task_hash(task_key) await self._forget_task(task_key) return result @@ -328,48 +357,48 @@ async def _get_progress_with_description( def _get_task_celery_state(self, task_key: TaskKey) -> TaskState: return TaskState(self._app.AsyncResult(task_key).state) - async def _get_group_result(self, owner_metadata: OwnerMetadata, group_uuid: GroupUUID) -> list[Any]: + async def _get_group_result(self, group_uuid: GroupUUID) -> list[Any]: with log_context( _logger, logging.DEBUG, - msg=f"Get group result: {owner_metadata=} {group_uuid=}", + msg=f"Get group result: {group_uuid=}", ): - group_key = owner_metadata.model_dump_key(task_or_group_uuid=group_uuid) + group_key: GroupKey = str(group_uuid) if not await self.task_or_group_exists(group_key): - raise TaskOrGroupNotFoundError(task_or_group_uuid=group_uuid, owner_metadata=owner_metadata) + raise TaskOrGroupNotFoundError(task_or_group_uuid=group_uuid) group_result = await self._restore_group_result(group_uuid) if group_result is None: - raise TaskOrGroupNotFoundError(task_or_group_uuid=group_uuid, owner_metadata=owner_metadata) + raise TaskOrGroupNotFoundError(task_or_group_uuid=group_uuid) results: list[Any] = [async_result.result for async_result in (group_result.results or [])] if group_result.ready(): task_metadata = await self._task_store.get_task_metadata(group_key) if task_metadata is not None and task_metadata.ephemeral: - await self._cancel_group(owner_metadata, group_uuid) + await self._cancel_group(group_uuid) return results - async def _get_group_status(self, owner_metadata: OwnerMetadata, group_uuid: GroupUUID) -> GroupStatus: + async def _get_group_status(self, group_uuid: GroupUUID) -> GroupStatus: with log_context( _logger, logging.DEBUG, - msg=f"Getting group status: {owner_metadata=} {group_uuid=}", + msg=f"Getting group status: {group_uuid=}", ): - group_key = owner_metadata.model_dump_key(task_or_group_uuid=group_uuid) + group_key: GroupKey = str(group_uuid) if not await self.task_or_group_exists(group_key): - raise TaskOrGroupNotFoundError(task_or_group_uuid=group_uuid, owner_metadata=owner_metadata) + raise TaskOrGroupNotFoundError(task_or_group_uuid=group_uuid) group_result = await self._restore_group_result(group_uuid) if group_result is None: - raise TaskOrGroupNotFoundError(task_or_group_uuid=group_uuid, owner_metadata=owner_metadata) + raise TaskOrGroupNotFoundError(task_or_group_uuid=group_uuid) # Get task UUIDs from the group result - # AsyncResult objects have .id attribute containing the task key + # AsyncResult objects have .id attribute containing the task key (UUID string) task_uuids = [ - OwnerMetadata.get_task_or_group_uuid(async_result.id) for async_result in (group_result.results or []) + TypeAdapter(TaskUUID).validate_python(async_result.id) for async_result in (group_result.results or []) ] # Check group status @@ -395,15 +424,15 @@ async def _get_group_status(self, owner_metadata: OwnerMetadata, group_uuid: Gro progress_report=progress_report, ) - async def _get_task_status(self, owner_metadata: OwnerMetadata, task_uuid: TaskUUID) -> TaskStatus: + async def _get_task_status(self, task_uuid: TaskUUID) -> TaskStatus: with log_context( _logger, logging.DEBUG, - msg=f"Getting task status: {owner_metadata=} {task_uuid=}", + msg=f"Getting task status: {task_uuid=}", ): - task_key = owner_metadata.model_dump_key(task_or_group_uuid=task_uuid) + task_key: TaskKey = str(task_uuid) if not await self.task_or_group_exists(task_key): - raise TaskOrGroupNotFoundError(task_uuid=task_uuid, owner_metadata=owner_metadata) + raise TaskOrGroupNotFoundError(task_uuid=task_uuid) task_state = await self._get_task_celery_state(task_key) return TaskStatus( @@ -414,29 +443,26 @@ async def _get_task_status(self, owner_metadata: OwnerMetadata, task_uuid: TaskU async def _is_group( self, - owner_metadata: OwnerMetadata, task_or_group_uuid: TaskUUID | GroupUUID, ) -> bool: - task_or_group_key = owner_metadata.model_dump_key(task_or_group_uuid=task_or_group_uuid) + task_or_group_key: TaskKey | GroupKey = str(task_or_group_uuid) if not await self.task_or_group_exists(task_or_group_key): - raise TaskOrGroupNotFoundError(task_uuid=task_or_group_uuid, owner_metadata=owner_metadata) + raise TaskOrGroupNotFoundError(task_uuid=task_or_group_uuid) task_metadata = await self._task_store.get_task_metadata(task_or_group_key) return task_metadata is not None and task_metadata.type == ExecutorType.GROUP @handle_celery_errors - async def get_result(self, owner_metadata: OwnerMetadata, task_or_group_uuid: TaskUUID | GroupUUID) -> Any: - if await self._is_group(owner_metadata, task_or_group_uuid): - return await self._get_group_result(owner_metadata, task_or_group_uuid) - return await self._get_task_result(owner_metadata, task_or_group_uuid) + async def get_result(self, task_or_group_uuid: TaskUUID | GroupUUID) -> Any: + if await self._is_group(task_or_group_uuid): + return await self._get_group_result(task_or_group_uuid) + return await self._get_task_result(task_or_group_uuid) @handle_celery_errors - async def get_status( - self, owner_metadata: OwnerMetadata, task_or_group_uuid: TaskUUID | GroupUUID - ) -> TaskStatus | GroupStatus: - if await self._is_group(owner_metadata, task_or_group_uuid): - return await self._get_group_status(owner_metadata, task_or_group_uuid) - return await self._get_task_status(owner_metadata, task_or_group_uuid) + async def get_status(self, task_or_group_uuid: TaskUUID | GroupUUID) -> TaskStatus | GroupStatus: + if await self._is_group(task_or_group_uuid): + return await self._get_group_status(task_or_group_uuid) + return await self._get_task_status(task_or_group_uuid) @make_async() def _restore_group_result(self, group_uuid: GroupUUID) -> GroupResult | None: @@ -448,9 +474,17 @@ def _restore_group_result(self, group_uuid: GroupUUID) -> GroupResult | None: return None @handle_celery_errors - async def list_tasks(self, owner_metadata: OwnerMetadata) -> list[Task]: - with log_context(_logger, logging.DEBUG, "Listing tasks: owner_metadata=%s", owner_metadata): - return await self._task_store.list_tasks(owner_metadata) + async def list_tasks( + self, + *, + owner: str, + user_id: int | None = None, + product_name: str | None = None, + ) -> list[Task]: + with log_context( + _logger, logging.DEBUG, "Listing tasks: owner=%s user_id=%s product_name=%s", owner, user_id, product_name + ): + return await self._task_store.list_tasks(owner=owner, user_id=user_id, product_name=product_name) @handle_celery_errors async def set_task_progress(self, task_key: TaskKey, report: ProgressReport) -> None: @@ -492,7 +526,6 @@ async def push_task_stream_items(self, task_key: TaskKey, *items: TaskStreamItem @handle_celery_errors async def pull_task_stream_items( self, - owner_metadata: OwnerMetadata, task_uuid: TaskUUID, offset: int = 0, limit: int = 50, @@ -500,13 +533,12 @@ async def pull_task_stream_items( with log_context( _logger, logging.DEBUG, - "Pull task results: owner_metadata=%s task_uuid=%s offset=%s limit=%s", - owner_metadata, + "Pull task results: task_uuid=%s offset=%s limit=%s", task_uuid, offset, limit, ): - task_key = owner_metadata.model_dump_key(task_or_group_uuid=task_uuid) + task_key: TaskKey = str(task_uuid) if not await self.task_or_group_exists(task_key): raise TaskOrGroupNotFoundError(task_key=task_key) diff --git a/packages/celery-library/src/celery_library/async_jobs.py b/packages/celery-library/src/celery_library/async_jobs.py index 0b90c629233a..30a22b3f2d55 100644 --- a/packages/celery-library/src/celery_library/async_jobs.py +++ b/packages/celery-library/src/celery_library/async_jobs.py @@ -19,7 +19,7 @@ JobNotDoneError, JobSchedulerError, ) -from models_library.celery import OwnerMetadata, TaskExecutionMetadata, TaskState, TaskStatus +from models_library.celery import TaskExecutionMetadata, TaskState, TaskStatus from servicelib.celery.task_manager import TaskManager from servicelib.logging_utils import log_catch from tenacity import ( @@ -46,12 +46,10 @@ async def cancel_job( task_manager: TaskManager, *, - owner_metadata: OwnerMetadata, job_id: AsyncJobId, ) -> None: try: await task_manager.cancel( - owner_metadata=owner_metadata, task_or_group_uuid=job_id, ) except TaskOrGroupNotFoundError as exc: @@ -63,22 +61,18 @@ async def cancel_job( async def get_job_result( task_manager: TaskManager, *, - owner_metadata: OwnerMetadata, job_id: AsyncJobId, ) -> AsyncJobResult: assert task_manager # nosec assert job_id # nosec - assert owner_metadata # nosec try: task_status = await task_manager.get_status( - owner_metadata=owner_metadata, task_or_group_uuid=job_id, ) if not task_status.is_done: raise JobNotDoneError(job_id=job_id) task_result = await task_manager.get_result( - owner_metadata=owner_metadata, task_or_group_uuid=job_id, ) except TaskOrGroupNotFoundError as exc: @@ -112,12 +106,10 @@ async def get_job_result( async def get_job_status( task_manager: TaskManager, *, - owner_metadata: OwnerMetadata, job_id: AsyncJobId, ) -> AsyncJobStatus: try: task_status = await task_manager.get_status( - owner_metadata=owner_metadata, task_or_group_uuid=job_id, ) except TaskOrGroupNotFoundError as exc: @@ -135,12 +127,16 @@ async def get_job_status( async def list_jobs( task_manager: TaskManager, *, - owner_metadata: OwnerMetadata, + owner: str, + user_id: int | None = None, + product_name: str | None = None, ) -> list[AsyncJobGet]: assert task_manager # nosec try: tasks = await task_manager.list_tasks( - owner_metadata=owner_metadata, + owner=owner, + user_id=user_id, + product_name=product_name, ) except TaskManagerError as exc: raise JobSchedulerError(exc=f"{exc}") from exc @@ -152,12 +148,16 @@ async def submit_job( task_manager: TaskManager, *, execution_metadata: TaskExecutionMetadata, - owner_metadata: OwnerMetadata, + owner: str, + user_id: int | None = None, + product_name: str | None = None, **kwargs, ) -> AsyncJobGet: task_id = await task_manager.submit_task( execution_metadata=execution_metadata, - owner_metadata=owner_metadata, + owner=owner, + user_id=user_id, + product_name=product_name, **kwargs, ) return AsyncJobGet(job_id=task_id, job_name=execution_metadata.name) @@ -166,7 +166,6 @@ async def submit_job( async def _wait_for_completion( task_manager: TaskManager, *, - owner_metadata: OwnerMetadata, job_id: AsyncJobId, stop_after: datetime.timedelta, ) -> AsyncGenerator[AsyncJobStatus]: @@ -181,7 +180,6 @@ async def _wait_for_completion( with attempt: job_status = await get_job_status( task_manager, - owner_metadata=owner_metadata, job_id=job_id, ) yield job_status @@ -214,7 +212,6 @@ async def result(self) -> Any: async def wait_and_get_job_result( task_manager: TaskManager, *, - owner_metadata: OwnerMetadata, job_id: AsyncJobId, stop_after: datetime.timedelta, ) -> AsyncGenerator[AsyncJobResultUpdate]: @@ -225,7 +222,6 @@ async def wait_and_get_job_result( async for job_status in _wait_for_completion( task_manager, job_id=job_id, - owner_metadata=owner_metadata, stop_after=stop_after, ): assert job_status is not None # nosec @@ -237,7 +233,6 @@ async def wait_and_get_job_result( job_status, get_job_result( task_manager, - owner_metadata=owner_metadata, job_id=job_id, ), ) @@ -245,7 +240,6 @@ async def wait_and_get_job_result( try: await cancel_job( task_manager, - owner_metadata=owner_metadata, job_id=job_id, ) except Exception as exc: @@ -257,7 +251,9 @@ async def submit_job_and_wait( task_manager: TaskManager, *, execution_metadata: TaskExecutionMetadata, - owner_metadata: OwnerMetadata, + owner: str, + user_id: int | None = None, + product_name: str | None = None, stop_after: datetime.timedelta, **kwargs, ) -> AsyncGenerator[AsyncJobResultUpdate]: @@ -266,14 +262,15 @@ async def submit_job_and_wait( async_job = await submit_job( task_manager, execution_metadata=execution_metadata, - owner_metadata=owner_metadata, + owner=owner, + user_id=user_id, + product_name=product_name, **kwargs, ) except (TimeoutError, CancelledError): if async_job is not None: await cancel_job( task_manager, - owner_metadata=owner_metadata, job_id=async_job.job_id, ) raise @@ -281,7 +278,6 @@ async def submit_job_and_wait( async for wait_and_ in wait_and_get_job_result( task_manager, job_id=async_job.job_id, - owner_metadata=owner_metadata, stop_after=stop_after, ): yield wait_and_ diff --git a/packages/celery-library/src/celery_library/backends/_redis.py b/packages/celery-library/src/celery_library/backends/_redis.py index 93d3995665c1..515ecf0a9ee7 100644 --- a/packages/celery-library/src/celery_library/backends/_redis.py +++ b/packages/celery-library/src/celery_library/backends/_redis.py @@ -5,19 +5,17 @@ from itertools import product from typing import TYPE_CHECKING, Final -from common_library.json_serialization import json_dumps from models_library.celery import ( - WILDCARD, ExecutionMetadata, ExecutorType, GroupExecutionMetadata, GroupKey, - OwnerMetadata, Task, TaskExecutionMetadata, TaskKey, TaskStore, TaskStreamItem, + TaskUUID, ) from models_library.progress_bar import ProgressReport from pydantic import TypeAdapter, ValidationError @@ -30,7 +28,8 @@ _CELERY_TASK_EXEC_METADATA_KEY: Final[str] = "exec-meta" _CELERY_TASK_PROGRESS_KEY: Final[str] = "progress" _CELERY_TASK_INDEX_PREFIX: Final[str] = "celery-task-index-" -_UUID_KEY_PREFIX: Final[str] = "uuid=" + +_TASK_UUID_ADAPTER: Final[TypeAdapter[TaskUUID]] = TypeAdapter(TaskUUID) # Redis list to store streamed results _CELERY_TASK_STREAM_PREFIX: Final[str] = "celery-task-stream-" @@ -42,6 +41,9 @@ _logger = logging.getLogger(__name__) +# --- key builders --- + + def _build_redis_task_or_group_key(key: TaskKey | GroupKey) -> str: return f"{_CELERY_TASK_PREFIX}{key}" @@ -54,34 +56,56 @@ def _build_redis_stream_meta_key(task_key: TaskKey) -> str: return f"{_build_redis_stream_key(task_key)}{_CELERY_TASK_DELIMTATOR}{_CELERY_TASK_STREAM_METADATA}" -def _without_uuid_token(task_or_group_key: TaskKey | GroupKey) -> str: - return _CELERY_TASK_DELIMTATOR.join( - token for token in task_or_group_key.split(_CELERY_TASK_DELIMTATOR) if not token.startswith(_UUID_KEY_PREFIX) - ) +def _build_redis_index_key(suffix: str) -> str: + return f"{_CELERY_TASK_INDEX_PREFIX}{suffix}" -def _build_redis_owner_index_key(owner_key_without_uuid: str) -> str: - return f"{_CELERY_TASK_INDEX_PREFIX}{owner_key_without_uuid}" +def _concrete_owner_fields( + owner: str, + user_id: int | None, + product_name: str | None, +) -> list[tuple[str, str | int]]: + """Return (field_name, value) pairs for non-None owner fields, sorted by name.""" + pairs: list[tuple[str, str | int]] = [("owner", owner)] + if user_id is not None: + pairs.append(("user_id", user_id)) + if product_name is not None: + pairs.append(("product_name", product_name)) + return sorted(pairs) -def _build_redis_owner_index_key_for_query(owner_metadata: OwnerMetadata) -> str: - owner_key = owner_metadata.model_dump_key(task_or_group_uuid=WILDCARD) - return _build_redis_owner_index_key(_without_uuid_token(owner_key)) +def _build_redis_index_key_for_query( + owner: str, + user_id: int | None, + product_name: str | None, +) -> str: + """Build the single sorted-set key used to answer a list_tasks query. + Concrete fields are kept; missing (None) fields are omitted — + the sorted set was pre-populated for every field subset at creation time. + """ + parts = [f"{k}={v}" for k, v in _concrete_owner_fields(owner, user_id, product_name)] + return _build_redis_index_key(_CELERY_TASK_DELIMTATOR.join(parts)) -def _build_redis_owner_index_keys_for_task(task_or_group_key: TaskKey | GroupKey) -> list[str]: - owner_tokens = [ - token.split("=", maxsplit=1) for token in _without_uuid_token(task_or_group_key).split(_CELERY_TASK_DELIMTATOR) - ] - wildcard_value = json_dumps(WILDCARD) + +def _build_redis_index_keys_for_creation( + owner: str, + user_id: int | None, + product_name: str | None, +) -> list[str]: + """Generate all 2^n sorted-set index keys for the given owner fields. + + Every subset of the concrete fields gets its own key so that any query + specifying a subset of those fields can be answered with a single + sorted-set lookup. + """ + fields = _concrete_owner_fields(owner, user_id, product_name) keys: list[str] = [] - for mask in product((False, True), repeat=len(owner_tokens)): - query_owner_key = _CELERY_TASK_DELIMTATOR.join( - f"{key}={wildcard_value if use_wildcard else value}" - for (key, value), use_wildcard in zip(owner_tokens, mask, strict=True) - ) - keys.append(_build_redis_owner_index_key(query_owner_key)) + for mask in product((False, True), repeat=len(fields)): + selected = [(k, v) for (k, v), include in zip(fields, mask, strict=True) if include] + suffix = _CELERY_TASK_DELIMTATOR.join(f"{k}={v}" for k, v in selected) if selected else "" + keys.append(_build_redis_index_key(suffix)) return keys @@ -94,6 +118,10 @@ async def create_group( group_key: GroupKey, execution_metadata: GroupExecutionMetadata, task_keys: list[TaskKey], + *, + owner: str, + user_id: int | None = None, + product_name: str | None = None, expiry: timedelta, ) -> None: redis_group_key = _build_redis_task_or_group_key(group_key) @@ -105,7 +133,7 @@ async def create_group( key=_CELERY_TASK_EXEC_METADATA_KEY, value=execution_metadata.model_dump_json(), ) - for index_key in _build_redis_owner_index_keys_for_task(group_key): + for index_key in _build_redis_index_keys_for_creation(owner, user_id, product_name): pipe.zadd(index_key, {group_key: index_score}) # group sub-tasks: store hash only, no ZSET index (filtered out in list_tasks) @@ -125,6 +153,10 @@ async def create_task( self, task_key: TaskKey, execution_metadata: TaskExecutionMetadata, + *, + owner: str, + user_id: int | None = None, + product_name: str | None = None, expiry: timedelta, ) -> None: redis_key = _build_redis_task_or_group_key(task_key) @@ -136,7 +168,7 @@ async def create_task( key=_CELERY_TASK_EXEC_METADATA_KEY, value=execution_metadata.model_dump_json(), ) - for index_key in _build_redis_owner_index_keys_for_task(task_key): + for index_key in _build_redis_index_keys_for_creation(owner, user_id, product_name): pipe.zadd(index_key, {task_key: index_score}) await pipe.execute() @@ -186,8 +218,14 @@ async def get_task_progress(self, task_key: TaskKey) -> ProgressReport | None: ) return None - async def list_tasks(self, owner_metadata: OwnerMetadata) -> list[Task]: - owner_index_key = _build_redis_owner_index_key_for_query(owner_metadata) + async def list_tasks( + self, + *, + owner: str, + user_id: int | None = None, + product_name: str | None = None, + ) -> list[Task]: + owner_index_key = _build_redis_index_key_for_query(owner, user_id, product_name) raw_members = await self._redis_client_sdk.redis.zrange(owner_index_key, 0, -1) if not raw_members: @@ -215,7 +253,7 @@ async def list_tasks(self, owner_metadata: OwnerMetadata) -> list[Task]: tasks.append( Task( - uuid=OwnerMetadata.get_task_or_group_uuid(member), + uuid=_TASK_UUID_ADAPTER.validate_python(member), metadata=execution_metadata, ) ) @@ -225,13 +263,23 @@ async def list_tasks(self, owner_metadata: OwnerMetadata) -> list[Task]: return tasks - async def remove_task(self, task_key: TaskKey) -> None: + async def remove_task( + self, + task_key: TaskKey, + *, + owner: str, + user_id: int | None = None, + product_name: str | None = None, + ) -> None: pipe = self._redis_client_sdk.redis.pipeline() pipe.delete(_build_redis_task_or_group_key(task_key)) - for index_key in _build_redis_owner_index_keys_for_task(task_key): + for index_key in _build_redis_index_keys_for_creation(owner, user_id, product_name): pipe.zrem(index_key, task_key) await pipe.execute() + async def remove_task_hash(self, task_key: TaskKey) -> None: + await self._redis_client_sdk.redis.delete(_build_redis_task_or_group_key(task_key)) + async def set_task_progress(self, task_key: TaskKey, report: ProgressReport) -> None: await handle_redis_returns_union_types( self._redis_client_sdk.redis.hset( diff --git a/packages/celery-library/src/celery_library/errors.py b/packages/celery-library/src/celery_library/errors.py index 87603cb3d5bd..52caa7a63ca8 100644 --- a/packages/celery-library/src/celery_library/errors.py +++ b/packages/celery-library/src/celery_library/errors.py @@ -39,7 +39,7 @@ class TaskSubmissionError(OsparcErrorMixin, Exception): class TaskOrGroupNotFoundError(OsparcErrorMixin, Exception): - msg_template = "Task or group with uuid '{task_uuid}' and owner_metadata '{owner_metadata}' was not found" + msg_template = "Task or group with uuid '{task_uuid}' was not found" class TaskManagerError(OsparcErrorMixin, Exception): diff --git a/packages/celery-library/tests/unit/conftest.py b/packages/celery-library/tests/unit/conftest.py index 1c341a75a165..e4d7627b48e7 100644 --- a/packages/celery-library/tests/unit/conftest.py +++ b/packages/celery-library/tests/unit/conftest.py @@ -1,12 +1,11 @@ import pytest -from models_library.celery import OwnerMetadata -from models_library.users import UserID -class MyOwnerMetadata(OwnerMetadata): - user_id: UserID +@pytest.fixture +def fake_owner() -> str: + return "test-owner" @pytest.fixture -def fake_owner_metadata() -> OwnerMetadata: - return MyOwnerMetadata(user_id=42, owner="test-owner") +def fake_user_id() -> int: + return 42 diff --git a/packages/celery-library/tests/unit/task_manager/conftest.py b/packages/celery-library/tests/unit/task_manager/conftest.py index becd4f657005..56fc4ef335d9 100644 --- a/packages/celery-library/tests/unit/task_manager/conftest.py +++ b/packages/celery-library/tests/unit/task_manager/conftest.py @@ -15,7 +15,6 @@ from common_library.errors_classes import OsparcErrorMixin from models_library.celery import ( TASK_DONE_STATES, - OwnerMetadata, TaskKey, TaskState, TaskStatus, @@ -143,38 +142,35 @@ def _(celery_app: Celery) -> None: async def wait_for_task_success( task_manager: TaskManager, - owner_metadata: OwnerMetadata, task_uuid: TaskUUID, ) -> None: """Wait for a task to reach SUCCESS state.""" async for attempt in AsyncRetrying(**_TENACITY_RETRY_PARAMS): with attempt: - status = await task_manager.get_status(owner_metadata, task_uuid) + status = await task_manager.get_status(task_uuid) assert isinstance(status, TaskStatus) assert status.task_state == TaskState.SUCCESS async def wait_for_task_not_pending( task_manager: TaskManager, - owner_metadata: OwnerMetadata, task_uuid: TaskUUID, ) -> None: """Wait for a task to leave PENDING state (i.e. the worker picked it up).""" async for attempt in AsyncRetrying(**_TENACITY_RETRY_PARAMS): with attempt: - status = await task_manager.get_status(owner_metadata, task_uuid) + status = await task_manager.get_status(task_uuid) assert isinstance(status, TaskStatus) assert status.task_state != TaskState.PENDING async def wait_for_task_done( task_manager: TaskManager, - owner_metadata: OwnerMetadata, task_uuid: TaskUUID, ) -> None: """Wait for a task to reach any DONE state.""" async for attempt in AsyncRetrying(**_TENACITY_RETRY_PARAMS): with attempt: - status = await task_manager.get_status(owner_metadata, task_uuid) + status = await task_manager.get_status(task_uuid) assert isinstance(status, TaskStatus) assert status.task_state in TASK_DONE_STATES diff --git a/packages/celery-library/tests/unit/task_manager/test_task_manager_cancel.py b/packages/celery-library/tests/unit/task_manager/test_task_manager_cancel.py index c7535aad3539..af1031581491 100644 --- a/packages/celery-library/tests/unit/task_manager/test_task_manager_cancel.py +++ b/packages/celery-library/tests/unit/task_manager/test_task_manager_cancel.py @@ -15,7 +15,6 @@ from models_library.celery import ( GroupExecutionMetadata, GroupTaskExecutionMetadata, - OwnerMetadata, TaskExecutionMetadata, TaskState, TaskStatus, @@ -31,28 +30,31 @@ async def test_cancel_single_task_calls_revoke( task_manager: CeleryTaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): """Cancelling a single task must call revoke() on the AsyncResult so the worker skips it entirely. """ task_uuid = await task_manager.submit_task( TaskExecutionMetadata(name=noop_task.__name__), - owner_metadata=fake_owner_metadata, + owner=fake_owner, + user_id=fake_user_id, ) with patch("celery.result.AsyncResult.revoke") as mock_revoke: - await task_manager.cancel(fake_owner_metadata, task_uuid) + await task_manager.cancel(task_uuid) mock_revoke.assert_called_once() with pytest.raises(TaskOrGroupNotFoundError): - await task_manager.get_status(fake_owner_metadata, task_uuid) + await task_manager.get_status(task_uuid) async def test_cancel_group_calls_revoke_for_each_task( task_manager: CeleryTaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): """Cancelling a group must call revoke() on every sub-task's AsyncResult. @@ -71,25 +73,27 @@ async def test_cancel_group_calls_revoke_for_each_task( name="rate_limited_group", tasks=group_tasks, ), - owner_metadata=fake_owner_metadata, + owner=fake_owner, + user_id=fake_user_id, ) with patch("celery.result.AsyncResult.revoke") as mock_revoke: - await task_manager.cancel(fake_owner_metadata, group_uuid) + await task_manager.cancel(group_uuid) assert mock_revoke.call_count == num_tasks with pytest.raises(TaskOrGroupNotFoundError): - await task_manager.get_status(fake_owner_metadata, group_uuid) + await task_manager.get_status(group_uuid) for task_uuid in task_uuids: with pytest.raises(TaskOrGroupNotFoundError): - await task_manager.get_status(fake_owner_metadata, task_uuid) + await task_manager.get_status(task_uuid) async def test_new_task_succeeds_after_cancelling_rate_limited_group( task_manager: CeleryTaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): """After cancelling a group of rate-limited tasks, a newly submitted rate-limited task still completes successfully. @@ -108,19 +112,21 @@ async def test_new_task_succeeds_after_cancelling_rate_limited_group( name="rate_limited_group", tasks=group_tasks, ), - owner_metadata=fake_owner_metadata, + owner=fake_owner, + user_id=fake_user_id, ) - await task_manager.cancel(fake_owner_metadata, group_uuid) + await task_manager.cancel(group_uuid) new_task_uuid = await task_manager.submit_task( TaskExecutionMetadata(name=noop_task.__name__), - owner_metadata=fake_owner_metadata, + owner=fake_owner, + user_id=fake_user_id, ) - await wait_for_task_success(task_manager, fake_owner_metadata, new_task_uuid) + await wait_for_task_success(task_manager, new_task_uuid) - status = await task_manager.get_status(fake_owner_metadata, new_task_uuid) + status = await task_manager.get_status(new_task_uuid) assert isinstance(status, TaskStatus) assert status.task_state == TaskState.SUCCESS @@ -128,26 +134,29 @@ async def test_new_task_succeeds_after_cancelling_rate_limited_group( async def test_new_task_succeeds_after_cancelling_single_rate_limited_task( task_manager: CeleryTaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): """Cancel a single rate-limited task, then verify a new submission completes successfully. """ task_uuid = await task_manager.submit_task( TaskExecutionMetadata(name=noop_task.__name__), - owner_metadata=fake_owner_metadata, + owner=fake_owner, + user_id=fake_user_id, ) - await wait_for_task_not_pending(task_manager, fake_owner_metadata, task_uuid) - await task_manager.cancel(fake_owner_metadata, task_uuid) + await wait_for_task_not_pending(task_manager, task_uuid) + await task_manager.cancel(task_uuid) new_task_uuid = await task_manager.submit_task( TaskExecutionMetadata(name=noop_task.__name__), - owner_metadata=fake_owner_metadata, + owner=fake_owner, + user_id=fake_user_id, ) - await wait_for_task_success(task_manager, fake_owner_metadata, new_task_uuid) + await wait_for_task_success(task_manager, new_task_uuid) - status = await task_manager.get_status(fake_owner_metadata, new_task_uuid) + status = await task_manager.get_status(new_task_uuid) assert isinstance(status, TaskStatus) assert status.task_state == TaskState.SUCCESS diff --git a/packages/celery-library/tests/unit/task_manager/test_task_manager_description.py b/packages/celery-library/tests/unit/task_manager/test_task_manager_description.py index a0719dc8d2b6..500ee15915f7 100644 --- a/packages/celery-library/tests/unit/task_manager/test_task_manager_description.py +++ b/packages/celery-library/tests/unit/task_manager/test_task_manager_description.py @@ -1,12 +1,10 @@ # pylint: disable=redefined-outer-name # pylint: disable=unused-argument - from celery.worker.worker import WorkController # pylint: disable=no-name-in-module from models_library.celery import ( GroupExecutionMetadata, GroupStatus, GroupTaskExecutionMetadata, - OwnerMetadata, TaskExecutionMetadata, TaskStatus, ) @@ -24,7 +22,8 @@ async def test_task_description_is_returned_in_progress_message( task_manager: TaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): description = "Processing important files" task_uuid = await task_manager.submit_task( @@ -32,22 +31,20 @@ async def test_task_description_is_returned_in_progress_message( name=fake_file_processor.__name__, description=description, ), - owner_metadata=fake_owner_metadata, + owner=fake_owner, + user_id=fake_user_id, files=[f"file{n}" for n in range(3)], ) - # Check that the description appears in progress while task is running async for attempt in AsyncRetrying(**_TENACITY_RETRY_PARAMS): with attempt: - status = await task_manager.get_status(fake_owner_metadata, task_uuid) + status = await task_manager.get_status(task_uuid) assert isinstance(status, TaskStatus) assert status.progress_report.message is not None assert status.progress_report.message.description == description - - await wait_for_task_success(task_manager, fake_owner_metadata, task_uuid) - + await wait_for_task_success(task_manager, task_uuid) # Check that the description is still present after completion - final_status = await task_manager.get_status(fake_owner_metadata, task_uuid) + final_status = await task_manager.get_status(task_uuid) assert isinstance(final_status, TaskStatus) assert final_status.progress_report.message is not None assert final_status.progress_report.message.description == description @@ -56,17 +53,18 @@ async def test_task_description_is_returned_in_progress_message( async def test_task_without_description_has_no_message_in_progress( task_manager: TaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): task_uuid = await task_manager.submit_task( TaskExecutionMetadata( name=dreamer_task.__name__, ), - owner_metadata=fake_owner_metadata, + owner=fake_owner, + user_id=fake_user_id, ) - # Check initial status has no message - status = await task_manager.get_status(fake_owner_metadata, task_uuid) + status = await task_manager.get_status(task_uuid) assert isinstance(status, TaskStatus) assert status.progress_report.message is None @@ -74,7 +72,8 @@ async def test_task_without_description_has_no_message_in_progress( async def test_group_description_is_returned_in_progress_message( task_manager: TaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): description = "Processing files group" group_uuid, task_uuids = await task_manager.submit_group( @@ -88,19 +87,17 @@ async def test_group_description_is_returned_in_progress_message( ) ], ), - owner_metadata=fake_owner_metadata, + owner=fake_owner, + user_id=fake_user_id, ) - async for attempt in AsyncRetrying(**_TENACITY_RETRY_PARAMS): with attempt: - status = await task_manager.get_status(fake_owner_metadata, group_uuid) + status = await task_manager.get_status(group_uuid) assert isinstance(status, GroupStatus) assert status.progress_report.message is not None assert status.progress_report.message.description == description - - await wait_for_task_success(task_manager, fake_owner_metadata, task_uuids[0]) - - final_status = await task_manager.get_status(fake_owner_metadata, group_uuid) + await wait_for_task_success(task_manager, task_uuids[0]) + final_status = await task_manager.get_status(group_uuid) assert isinstance(final_status, GroupStatus) assert final_status.progress_report.message is not None assert final_status.progress_report.message.description == description diff --git a/packages/celery-library/tests/unit/task_manager/test_task_manager_groups.py b/packages/celery-library/tests/unit/task_manager/test_task_manager_groups.py index 23fb87b0e0a0..effad7970124 100644 --- a/packages/celery-library/tests/unit/task_manager/test_task_manager_groups.py +++ b/packages/celery-library/tests/unit/task_manager/test_task_manager_groups.py @@ -14,7 +14,6 @@ GroupStatus, GroupTaskExecutionMetadata, GroupUUID, - OwnerMetadata, TaskExecutionMetadata, TaskUUID, ) @@ -37,7 +36,8 @@ async def test_submit_group_all_tasks_complete_successfully( task_manager: CeleryTaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): # Submit a group of tasks num_tasks = 3 @@ -54,7 +54,8 @@ async def test_submit_group_all_tasks_complete_successfully( name="fake_file_processing_group", tasks=group_tasks, ), - owner_metadata=fake_owner_metadata, + owner=fake_owner, + user_id=fake_user_id, ) assert group_id is not None @@ -62,18 +63,19 @@ async def test_submit_group_all_tasks_complete_successfully( # Wait for all tasks to complete for task_uuid in task_uuids: - await wait_for_task_success(task_manager, fake_owner_metadata, task_uuid) + await wait_for_task_success(task_manager, task_uuid) # Verify all results for task_uuid in task_uuids: - result = await task_manager.get_result(fake_owner_metadata, task_uuid) + result = await task_manager.get_result(task_uuid) assert result == "archive.zip" async def test_submit_group_tasks_appear_in_listing( task_manager: CeleryTaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): # Submit a group of tasks num_tasks = 4 @@ -93,26 +95,28 @@ async def test_submit_group_tasks_appear_in_listing( name="tasks_group", tasks=group_tasks, ), - owner_metadata=fake_owner_metadata, + owner=fake_owner, + user_id=fake_user_id, ) # Verify none of group tasks appear in listing async for attempt in AsyncRetrying(**_TENACITY_RETRY_PARAMS): with attempt: - tasks = await task_manager.list_tasks(fake_owner_metadata) + tasks = await task_manager.list_tasks(owner=fake_owner, user_id=fake_user_id) task_uuids_from_list = {task.uuid for task in tasks} assert all(uuid not in task_uuids_from_list for uuid in task_uuids) finally: # Clean up for task_uuid in task_uuids: with contextlib.suppress(TaskOrGroupNotFoundError): - await task_manager.cancel(fake_owner_metadata, task_uuid) + await task_manager.cancel(task_uuid) async def test_submit_group_with_mixed_task_types( task_manager: CeleryTaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): # Submit a group with different task types group_tasks = [ @@ -135,28 +139,30 @@ async def test_submit_group_with_mixed_task_types( name="mixed_tasks_group", tasks=group_tasks, ), - owner_metadata=fake_owner_metadata, + owner=fake_owner, + user_id=fake_user_id, ) assert len(task_uuids) == 3 # Wait for all tasks to complete for task_uuid in task_uuids: - await wait_for_task_success(task_manager, fake_owner_metadata, task_uuid) + await wait_for_task_success(task_manager, task_uuid) # Verify first two tasks return "archive.zip" - assert await task_manager.get_result(fake_owner_metadata, task_uuids[0]) == "archive.zip" - assert await task_manager.get_result(fake_owner_metadata, task_uuids[1]) == "archive.zip" + assert await task_manager.get_result(task_uuids[0]) == "archive.zip" + assert await task_manager.get_result(task_uuids[1]) == "archive.zip" # Verify streaming task result - result = await task_manager.get_result(fake_owner_metadata, task_uuids[2]) + result = await task_manager.get_result(task_uuids[2]) assert result == "completed-2-results" async def test_submit_group_can_cancel_individual_tasks( task_manager: CeleryTaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): # Submit a group of long-running tasks num_tasks = 3 @@ -173,28 +179,30 @@ async def test_submit_group_can_cancel_individual_tasks( name="cancellable_tasks_group", tasks=group_tasks, ), - owner_metadata=fake_owner_metadata, + owner=fake_owner, + user_id=fake_user_id, ) # Wait a bit to ensure tasks are running await asyncio.sleep(2.0) # Cancel the first task - await task_manager.cancel(fake_owner_metadata, task_uuids[0]) + await task_manager.cancel(task_uuids[0]) # Verify first task is gone with pytest.raises(TaskOrGroupNotFoundError): - await task_manager.get_status(fake_owner_metadata, task_uuids[0]) + await task_manager.get_status(task_uuids[0]) # Cancel remaining tasks for task_uuid in task_uuids[1:]: - await task_manager.cancel(fake_owner_metadata, task_uuid) + await task_manager.cancel(task_uuid) async def test_cancelling_a_group_cancels_all_tasks( task_manager: CeleryTaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): num_tasks = 3 group_tasks = [ @@ -210,28 +218,30 @@ async def test_cancelling_a_group_cancels_all_tasks( name="cancellable_group", tasks=group_tasks, ), - owner_metadata=fake_owner_metadata, + owner=fake_owner, + user_id=fake_user_id, ) # Wait a bit to ensure tasks are running await asyncio.sleep(2.0) - await task_manager.cancel(fake_owner_metadata, group_uuid) + await task_manager.cancel(group_uuid) # Group itself should no longer exist with pytest.raises(TaskOrGroupNotFoundError): - await task_manager.get_status(fake_owner_metadata, group_uuid) + await task_manager.get_status(group_uuid) # All individual tasks should also be gone for task_uuid in task_uuids: with pytest.raises(TaskOrGroupNotFoundError): - await task_manager.get_status(fake_owner_metadata, task_uuid) + await task_manager.get_status(task_uuid) async def test_submit_group_with_failures( task_manager: CeleryTaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): # Submit a group with some failing tasks group_tasks = [ @@ -254,21 +264,22 @@ async def test_submit_group_with_failures( name="group_with_failures", tasks=group_tasks, ), - owner_metadata=fake_owner_metadata, + owner=fake_owner, + user_id=fake_user_id, ) assert len(task_uuids) == 3 # Wait for all tasks to finish for task_uuid in task_uuids: - await wait_for_task_done(task_manager, fake_owner_metadata, task_uuid) + await wait_for_task_done(task_manager, task_uuid) # Verify successful tasks - assert await task_manager.get_result(fake_owner_metadata, task_uuids[0]) == "archive.zip" - assert await task_manager.get_result(fake_owner_metadata, task_uuids[2]) == "archive.zip" + assert await task_manager.get_result(task_uuids[0]) == "archive.zip" + assert await task_manager.get_result(task_uuids[2]) == "archive.zip" # Verify failed task - result = await task_manager.get_result(fake_owner_metadata, task_uuids[1]) + result = await task_manager.get_result(task_uuids[1]) assert isinstance(result, TransferableCeleryError) assert "Something strange happened: BOOM!" in f"{result}" @@ -276,7 +287,8 @@ async def test_submit_group_with_failures( async def test_submit_group_with_ephemeral_tasks( task_manager: CeleryTaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): # Submit a group with ephemeral tasks num_tasks = 2 @@ -293,36 +305,39 @@ async def test_submit_group_with_ephemeral_tasks( name="ephemeral_tasks_group", tasks=group_tasks, ), - owner_metadata=fake_owner_metadata, + owner=fake_owner, + user_id=fake_user_id, ) assert len(task_uuids) == num_tasks # Wait for all tasks to complete and get results (which should clean them up) for task_uuid in task_uuids: - await wait_for_task_success(task_manager, fake_owner_metadata, task_uuid) + await wait_for_task_success(task_manager, task_uuid) # Getting the result should trigger cleanup for ephemeral tasks - result = await task_manager.get_result(fake_owner_metadata, task_uuid) + result = await task_manager.get_result(task_uuid) assert result == "archive.zip" for task_uuid in task_uuids: # Second attempt to get result should fail as ephemeral tasks are cleaned up with pytest.raises(TaskOrGroupNotFoundError): - await task_manager.get_status(fake_owner_metadata, task_uuid) + await task_manager.get_status(task_uuid) async def test_submit_empty_group( task_manager: CeleryTaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): _, task_uuids = await task_manager.submit_group( GroupExecutionMetadata( name="empty_group", tasks=[], ), - owner_metadata=fake_owner_metadata, + owner=fake_owner, + user_id=fake_user_id, ) assert task_uuids == [] @@ -331,7 +346,8 @@ async def test_submit_empty_group( async def test_get_group_status_returns_status_for_running_group( task_manager: CeleryTaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): # Submit a group of long-running tasks num_tasks = 3 @@ -348,7 +364,8 @@ async def test_get_group_status_returns_status_for_running_group( name="running_tasks_group", tasks=group_tasks, ), - owner_metadata=fake_owner_metadata, + owner=fake_owner, + user_id=fake_user_id, ) try: @@ -356,7 +373,7 @@ async def test_get_group_status_returns_status_for_running_group( await asyncio.sleep(1.0) # Get group status while tasks are running - group_status = await task_manager.get_status(fake_owner_metadata, group_id) + group_status = await task_manager.get_status(group_id) assert isinstance(group_status, GroupStatus) assert group_status.group_uuid == group_id @@ -368,13 +385,14 @@ async def test_get_group_status_returns_status_for_running_group( # Clean up for task_uuid in task_uuids: with contextlib.suppress(TaskOrGroupNotFoundError): - await task_manager.cancel(fake_owner_metadata, task_uuid) + await task_manager.cancel(task_uuid) async def test_get_group_status_returns_done_when_all_tasks_complete( task_manager: CeleryTaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): # Submit a group of fast tasks num_tasks = 2 @@ -391,15 +409,16 @@ async def test_get_group_status_returns_done_when_all_tasks_complete( name="fast_tasks_group", tasks=group_tasks, ), - owner_metadata=fake_owner_metadata, + owner=fake_owner, + user_id=fake_user_id, ) # Wait for all tasks to complete for task_uuid in task_uuids: - await wait_for_task_success(task_manager, fake_owner_metadata, task_uuid) + await wait_for_task_success(task_manager, task_uuid) # Get group status - group_status = await task_manager.get_status(fake_owner_metadata, group_id) + group_status = await task_manager.get_status(group_id) assert isinstance(group_status, GroupStatus) assert group_status.group_uuid == group_id @@ -413,7 +432,8 @@ async def test_get_group_status_returns_done_when_all_tasks_complete( async def test_get_group_status_successful_false_when_task_fails( task_manager: CeleryTaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): # Submit a group with one failing task group_tasks = [ @@ -432,15 +452,16 @@ async def test_get_group_status_successful_false_when_task_fails( name="failing_tasks_group", tasks=group_tasks, ), - owner_metadata=fake_owner_metadata, + owner=fake_owner, + user_id=fake_user_id, ) # Wait for all tasks to finish for task_uuid in task_uuids: - await wait_for_task_done(task_manager, fake_owner_metadata, task_uuid) + await wait_for_task_done(task_manager, task_uuid) # Get group status - group_status = await task_manager.get_status(fake_owner_metadata, group_id) + group_status = await task_manager.get_status(group_id) assert isinstance(group_status, GroupStatus) assert group_status.group_uuid == group_id @@ -455,18 +476,20 @@ async def test_get_group_status_successful_false_when_task_fails( async def test_get_group_status_with_nonexistent_group_raises_error( task_manager: CeleryTaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): fake_group_uuid = TypeAdapter(GroupUUID).validate_python(_faker.uuid4()) with pytest.raises(TaskOrGroupNotFoundError): - await task_manager.get_status(fake_owner_metadata, fake_group_uuid) + await task_manager.get_status(fake_group_uuid) async def test_get_group_status_tracks_progress( task_manager: CeleryTaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): # Submit a group of longer-running tasks num_tasks = 4 @@ -483,7 +506,8 @@ async def test_get_group_status_tracks_progress( name="long_running_tasks_group", tasks=group_tasks, ), - owner_metadata=fake_owner_metadata, + owner=fake_owner, + user_id=fake_user_id, ) try: @@ -492,7 +516,7 @@ async def test_get_group_status_tracks_progress( group_status = None async for attempt in AsyncRetrying(**_TENACITY_RETRY_PARAMS): with attempt: - group_status = await task_manager.get_status(fake_owner_metadata, group_id) + group_status = await task_manager.get_status(group_id) assert isinstance(group_status, GroupStatus) # Progress should never go backwards @@ -510,13 +534,14 @@ async def test_get_group_status_tracks_progress( # Clean up for task_uuid in task_uuids: with contextlib.suppress(TaskOrGroupNotFoundError): - await task_manager.cancel(fake_owner_metadata, task_uuid) + await task_manager.cancel(task_uuid) async def test_get_group_status_with_empty_group( task_manager: CeleryTaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): # Submit an empty group group_id, _ = await task_manager.submit_group( @@ -524,11 +549,12 @@ async def test_get_group_status_with_empty_group( name="empty_group", tasks=[], ), - owner_metadata=fake_owner_metadata, + owner=fake_owner, + user_id=fake_user_id, ) # Get group status - group_status = await task_manager.get_status(fake_owner_metadata, group_id) + group_status = await task_manager.get_status(group_id) assert isinstance(group_status, GroupStatus) assert group_status.group_uuid == group_id @@ -542,23 +568,26 @@ async def test_get_group_status_with_empty_group( async def test_get_result_dispatches_to_task_result( task_manager: CeleryTaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): task_uuid = await task_manager.submit_task( TaskExecutionMetadata(name=fake_file_processor.__name__), - owner_metadata=fake_owner_metadata, + owner=fake_owner, + user_id=fake_user_id, files=["file1"], ) - await wait_for_task_success(task_manager, fake_owner_metadata, task_uuid) + await wait_for_task_success(task_manager, task_uuid) - result = await task_manager.get_result(fake_owner_metadata, task_uuid) + result = await task_manager.get_result(task_uuid) assert result == "archive.zip" async def test_get_result_dispatches_to_group_result( task_manager: CeleryTaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): group_tasks = [ ( @@ -570,13 +599,14 @@ async def test_get_result_dispatches_to_group_result( group_id, task_uuids = await task_manager.submit_group( GroupExecutionMetadata(name="result_dispatch_group", tasks=group_tasks), - owner_metadata=fake_owner_metadata, + owner=fake_owner, + user_id=fake_user_id, ) for task_uuid in task_uuids: - await wait_for_task_success(task_manager, fake_owner_metadata, task_uuid) + await wait_for_task_success(task_manager, task_uuid) - result = await task_manager.get_result(fake_owner_metadata, group_id) + result = await task_manager.get_result(group_id) assert isinstance(result, list) assert len(result) == 2 assert all(r == "archive.zip" for r in result) @@ -585,17 +615,19 @@ async def test_get_result_dispatches_to_group_result( async def test_get_result_with_nonexistent_uuid_raises_error( task_manager: CeleryTaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): fake_uuid = TypeAdapter(TaskUUID).validate_python(_faker.uuid4()) with pytest.raises(TaskOrGroupNotFoundError): - await task_manager.get_result(fake_owner_metadata, fake_uuid) + await task_manager.get_result(fake_uuid) async def test_get_group_result_returns_all_results( task_manager: CeleryTaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): num_tasks = 3 group_tasks = [ @@ -608,20 +640,22 @@ async def test_get_group_result_returns_all_results( group_id, task_uuids = await task_manager.submit_group( GroupExecutionMetadata(name="all_results_group", tasks=group_tasks), - owner_metadata=fake_owner_metadata, + owner=fake_owner, + user_id=fake_user_id, ) for task_uuid in task_uuids: - await wait_for_task_success(task_manager, fake_owner_metadata, task_uuid) + await wait_for_task_success(task_manager, task_uuid) - results = await task_manager.get_result(fake_owner_metadata, group_id) + results = await task_manager.get_result(group_id) assert results == ["archive.zip"] * num_tasks async def test_get_group_result_with_failures( task_manager: CeleryTaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): group_tasks = [ ( @@ -636,13 +670,14 @@ async def test_get_group_result_with_failures( group_id, task_uuids = await task_manager.submit_group( GroupExecutionMetadata(name="failures_result_group", tasks=group_tasks), - owner_metadata=fake_owner_metadata, + owner=fake_owner, + user_id=fake_user_id, ) for task_uuid in task_uuids: - await wait_for_task_done(task_manager, fake_owner_metadata, task_uuid) + await wait_for_task_done(task_manager, task_uuid) - results = await task_manager.get_result(fake_owner_metadata, group_id) + results = await task_manager.get_result(group_id) assert len(results) == 2 assert results[0] == "archive.zip" assert isinstance(results[1], TransferableCeleryError) @@ -652,7 +687,8 @@ async def test_get_group_result_with_failures( async def test_get_group_result_with_ephemeral_cleans_up( task_manager: CeleryTaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): num_tasks = 2 group_tasks = [ @@ -665,26 +701,28 @@ async def test_get_group_result_with_ephemeral_cleans_up( group_id, task_uuids = await task_manager.submit_group( GroupExecutionMetadata(name="ephemeral_result_group", tasks=group_tasks, ephemeral=True), - owner_metadata=fake_owner_metadata, + owner=fake_owner, + user_id=fake_user_id, ) for task_uuid in task_uuids: - await wait_for_task_success(task_manager, fake_owner_metadata, task_uuid) + await wait_for_task_success(task_manager, task_uuid) # First call returns results and triggers cleanup - results = await task_manager.get_result(fake_owner_metadata, group_id) + results = await task_manager.get_result(group_id) assert results == ["archive.zip"] * num_tasks # Second call should fail because the group was cleaned up with pytest.raises(TaskOrGroupNotFoundError): - await task_manager.get_result(fake_owner_metadata, group_id) + await task_manager.get_result(group_id) async def test_get_group_result_with_nonexistent_group_raises_error( task_manager: CeleryTaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): fake_group_uuid = TypeAdapter(GroupUUID).validate_python(_faker.uuid4()) with pytest.raises(TaskOrGroupNotFoundError): - await task_manager.get_result(fake_owner_metadata, fake_group_uuid) + await task_manager.get_result(fake_group_uuid) diff --git a/packages/celery-library/tests/unit/task_manager/test_task_manager_streams.py b/packages/celery-library/tests/unit/task_manager/test_task_manager_streams.py index 222125363275..0482118330e9 100644 --- a/packages/celery-library/tests/unit/task_manager/test_task_manager_streams.py +++ b/packages/celery-library/tests/unit/task_manager/test_task_manager_streams.py @@ -7,7 +7,6 @@ from celery_library.errors import TaskOrGroupNotFoundError from faker import Faker from models_library.celery import ( - OwnerMetadata, TaskExecutionMetadata, TaskStreamItem, TaskUUID, @@ -27,7 +26,8 @@ async def test_push_task_result_streams_data_during_execution( task_manager: CeleryTaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): num_results = 3 @@ -36,7 +36,8 @@ async def test_push_task_result_streams_data_during_execution( name=streaming_results_task.__name__, ephemeral=False, # Keep task available after completion for result pulling ), - owner_metadata=fake_owner_metadata, + owner=fake_owner, + user_id=fake_user_id, num_results=num_results, ) @@ -44,7 +45,7 @@ async def test_push_task_result_streams_data_during_execution( results = [] async for attempt in AsyncRetrying(**_TENACITY_RETRY_PARAMS): with attempt: - result, is_done, _ = await task_manager.pull_task_stream_items(fake_owner_metadata, task_uuid, limit=10) + result, is_done, _ = await task_manager.pull_task_stream_items(task_uuid, limit=10) results.extend(result) assert is_done @@ -52,14 +53,14 @@ async def test_push_task_result_streams_data_during_execution( assert results == [TaskStreamItem(data=f"result-{i}") for i in range(num_results)] # Wait for task completion - await wait_for_task_success(task_manager, fake_owner_metadata, task_uuid) + await wait_for_task_success(task_manager, task_uuid) # Final task result should be available - final_result = await task_manager.get_result(fake_owner_metadata, task_uuid) + final_result = await task_manager.get_result(task_uuid) assert final_result == f"completed-{num_results}-results" # After task completion, try to pull any remaining results - remaining_results, is_done, _ = await task_manager.pull_task_stream_items(fake_owner_metadata, task_uuid, limit=10) + remaining_results, is_done, _ = await task_manager.pull_task_stream_items(task_uuid, limit=10) assert remaining_results == [] assert is_done @@ -67,7 +68,8 @@ async def test_push_task_result_streams_data_during_execution( async def test_pull_task_stream_items_with_limit( task_manager: CeleryTaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): # Submit task with fewer results to make it more predictable task_uuid = await task_manager.submit_task( @@ -75,16 +77,16 @@ async def test_pull_task_stream_items_with_limit( name=streaming_results_task.__name__, ephemeral=False, # Keep task available after completion for result pulling ), - owner_metadata=fake_owner_metadata, + owner=fake_owner, + user_id=fake_user_id, num_results=5, ) # Wait for task to complete - await wait_for_task_success(task_manager, fake_owner_metadata, task_uuid) + await wait_for_task_success(task_manager, task_uuid) # Pull all results in one go to avoid consumption issues all_results, is_done_final, _last_update_final = await task_manager.pull_task_stream_items( - fake_owner_metadata, task_uuid, limit=20, # High limit to get all items ) @@ -102,12 +104,13 @@ async def test_pull_task_stream_items_with_limit( async def test_pull_task_stream_items_from_nonexistent_task_raises_error( task_manager: CeleryTaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): fake_task_uuid = TypeAdapter(TaskUUID).validate_python(_faker.uuid4()) with pytest.raises(TaskOrGroupNotFoundError): - await task_manager.pull_task_stream_items(fake_owner_metadata, fake_task_uuid) + await task_manager.pull_task_stream_items(fake_task_uuid) async def test_push_task_stream_items_to_nonexistent_task_raises_error( diff --git a/packages/celery-library/tests/unit/task_manager/test_task_manager_tasks.py b/packages/celery-library/tests/unit/task_manager/test_task_manager_tasks.py index f604b978ad70..5859d44c1fc5 100644 --- a/packages/celery-library/tests/unit/task_manager/test_task_manager_tasks.py +++ b/packages/celery-library/tests/unit/task_manager/test_task_manager_tasks.py @@ -9,12 +9,10 @@ from celery_library.errors import TaskOrGroupNotFoundError, TransferableCeleryError from faker import Faker from models_library.celery import ( - OwnerMetadata, TaskExecutionMetadata, TaskState, TaskStatus, TaskUUID, - Wildcard, ) from servicelib.celery.task_manager import TaskManager from tenacity import AsyncRetrying @@ -33,84 +31,92 @@ async def test_submitting_task_calling_async_function_results_with_success_state( task_manager: TaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): task_uuid = await task_manager.submit_task( TaskExecutionMetadata( name=fake_file_processor.__name__, ), - owner_metadata=fake_owner_metadata, + owner=fake_owner, + user_id=fake_user_id, files=[f"file{n}" for n in range(5)], ) - await wait_for_task_success(task_manager, fake_owner_metadata, task_uuid) - task_status = await task_manager.get_status(fake_owner_metadata, task_uuid) + await wait_for_task_success(task_manager, task_uuid) + task_status = await task_manager.get_status(task_uuid) assert isinstance(task_status, TaskStatus) assert task_status.task_state == TaskState.SUCCESS - assert (await task_manager.get_result(fake_owner_metadata, task_uuid)) == "archive.zip" + assert (await task_manager.get_result(task_uuid)) == "archive.zip" async def test_submitting_task_with_failure_results_with_error( task_manager: TaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): task_uuid = await task_manager.submit_task( TaskExecutionMetadata( name=failure_task.__name__, ), - owner_metadata=fake_owner_metadata, + owner=fake_owner, + user_id=fake_user_id, ) async for attempt in AsyncRetrying(**_TENACITY_RETRY_PARAMS): with attempt: - raw_result = await task_manager.get_result(fake_owner_metadata, task_uuid) + raw_result = await task_manager.get_result(task_uuid) assert isinstance(raw_result, TransferableCeleryError) - raw_result = await task_manager.get_result(fake_owner_metadata, task_uuid) + raw_result = await task_manager.get_result(task_uuid) assert f"{raw_result}" == "Something strange happened: BOOM!" async def test_cancelling_a_running_task_aborts_and_deletes( task_manager: TaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): task_uuid = await task_manager.submit_task( TaskExecutionMetadata( name=dreamer_task.__name__, ), - owner_metadata=fake_owner_metadata, + owner=fake_owner, + user_id=fake_user_id, ) await asyncio.sleep(3.0) - await task_manager.cancel(fake_owner_metadata, task_uuid) + await task_manager.cancel(task_uuid) with pytest.raises(TaskOrGroupNotFoundError): - await task_manager.get_status(fake_owner_metadata, task_uuid) + await task_manager.get_status(task_uuid) - assert task_uuid not in await task_manager.list_tasks(fake_owner_metadata) + assert task_uuid not in await task_manager.list_tasks(owner=fake_owner, user_id=fake_user_id) async def test_listing_task_uuids_contains_submitted_task( task_manager: CeleryTaskManager, with_celery_worker: WorkController, - fake_owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): task_uuid = await task_manager.submit_task( TaskExecutionMetadata( name=dreamer_task.__name__, ), - owner_metadata=fake_owner_metadata, + owner=fake_owner, + user_id=fake_user_id, ) async for attempt in AsyncRetrying(**_TENACITY_RETRY_PARAMS): with attempt: - tasks = await task_manager.list_tasks(fake_owner_metadata) + tasks = await task_manager.list_tasks(owner=fake_owner, user_id=fake_user_id) assert any(task.uuid == task_uuid for task in tasks) - tasks = await task_manager.list_tasks(fake_owner_metadata) + tasks = await task_manager.list_tasks(owner=fake_owner, user_id=fake_user_id) assert any(task.uuid == task_uuid for task in tasks) @@ -118,49 +124,38 @@ async def test_filtering_listing_tasks( task_manager: CeleryTaskManager, with_celery_worker: WorkController, ): - class MyOwnerMetadata(OwnerMetadata): - user_id: int - product_name: str | Wildcard - user_id = 42 owner = "test-owner" expected_task_uuids: set[TaskUUID] = set() - all_tasks: list[tuple[TaskUUID, MyOwnerMetadata]] = [] + all_task_uuids: list[TaskUUID] = [] try: for _ in range(5): - owner_metadata = MyOwnerMetadata(user_id=user_id, product_name=_faker.word(), owner=owner) task_uuid = await task_manager.submit_task( TaskExecutionMetadata( name=dreamer_task.__name__, ), - owner_metadata=owner_metadata, + owner=owner, + user_id=user_id, + product_name=_faker.word(), ) expected_task_uuids.add(task_uuid) - all_tasks.append((task_uuid, owner_metadata)) + all_task_uuids.append(task_uuid) for _ in range(3): - owner_metadata = MyOwnerMetadata( - user_id=_faker.pyint(min_value=100, max_value=200), - product_name=_faker.word(), - owner=owner, - ) task_uuid = await task_manager.submit_task( TaskExecutionMetadata( name=dreamer_task.__name__, ), - owner_metadata=owner_metadata, + owner=owner, + user_id=_faker.pyint(min_value=100, max_value=200), + product_name=_faker.word(), ) - all_tasks.append((task_uuid, owner_metadata)) - - search_owner_metadata = MyOwnerMetadata( - user_id=user_id, - product_name="*", - owner=owner, - ) - tasks = await task_manager.list_tasks(search_owner_metadata) + all_task_uuids.append(task_uuid) + + # Query by owner + user_id only (product_name=None acts as wildcard) + tasks = await task_manager.list_tasks(owner=owner, user_id=user_id) assert expected_task_uuids == {task.uuid for task in tasks} finally: - # clean up all tasks. this should ideally be done in the fixture - for task_uuid, owner_metadata in all_tasks: - await task_manager.cancel(owner_metadata, task_uuid) + for task_uuid in all_task_uuids: + await task_manager.cancel(task_uuid) diff --git a/packages/celery-library/tests/unit/test_async_jobs.py b/packages/celery-library/tests/unit/test_async_jobs.py index 8499e9754996..3c8d4661c32a 100644 --- a/packages/celery-library/tests/unit/test_async_jobs.py +++ b/packages/celery-library/tests/unit/test_async_jobs.py @@ -22,7 +22,7 @@ JobError, JobMissingError, ) -from models_library.celery import OwnerMetadata, TaskExecutionMetadata, TaskKey +from models_library.celery import TaskExecutionMetadata, TaskKey from servicelib.celery.task_manager import TaskManager from tenacity import ( AsyncRetrying, @@ -37,12 +37,13 @@ class AccessRightError(OsparcErrorMixin, RuntimeError): @pytest.fixture -def owner_metadata(faker: Faker) -> OwnerMetadata: - return OwnerMetadata( - user_id=faker.pyint(min_value=1), - product_name=faker.word(), - owner="pytest_client", - ) +def fake_owner() -> str: + return "pytest-client" + + +@pytest.fixture +def fake_user_id(faker: Faker) -> int: + return faker.pyint(min_value=1) class Action(str, Enum): @@ -101,7 +102,6 @@ def _(celery_app: Celery) -> None: async def _wait_for_job( task_manager: TaskManager, *, - owner_metadata: OwnerMetadata, job_id: AsyncJobId, stop_after: timedelta = timedelta(seconds=5), ) -> None: @@ -114,7 +114,6 @@ async def _wait_for_job( with attempt: status = await get_job_status( task_manager, - owner_metadata=owner_metadata, job_id=job_id, ) assert status.done is True, "Please check logs above, something went wrong with task execution" @@ -142,32 +141,33 @@ async def test_async_jobs_workflow( task_manager: TaskManager, with_celery_worker: WorkController, execution_metadata: TaskExecutionMetadata, - owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, payload: Any, ): async_job = await submit_job( task_manager, execution_metadata=execution_metadata, - owner_metadata=owner_metadata, + owner=fake_owner, + user_id=fake_user_id, action=Action.ECHO, payload=payload, ) jobs = await list_jobs( task_manager, - owner_metadata=owner_metadata, + owner=fake_owner, + user_id=fake_user_id, ) assert len(jobs) > 0 await _wait_for_job( task_manager, - owner_metadata=owner_metadata, job_id=async_job.job_id, ) async_job_result = await get_job_result( task_manager, - owner_metadata=owner_metadata, job_id=async_job.job_id, ) assert async_job_result.result == payload @@ -183,39 +183,39 @@ async def test_async_jobs_cancel( task_manager: TaskManager, with_celery_worker: WorkController, execution_metadata: TaskExecutionMetadata, - owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, ): async_job = await submit_job( task_manager, execution_metadata=execution_metadata, - owner_metadata=owner_metadata, + owner=fake_owner, + user_id=fake_user_id, action=Action.SLEEP, payload=60 * 10, # test hangs if not cancelled properly ) await cancel_job( task_manager, - owner_metadata=owner_metadata, job_id=async_job.job_id, ) jobs = await list_jobs( task_manager, - owner_metadata=owner_metadata, + owner=fake_owner, + user_id=fake_user_id, ) assert async_job.job_id not in [job.job_id for job in jobs] with pytest.raises(JobMissingError): await get_job_status( task_manager, - owner_metadata=owner_metadata, job_id=async_job.job_id, ) with pytest.raises(JobMissingError): await get_job_result( task_manager, - owner_metadata=owner_metadata, job_id=async_job.job_id, ) @@ -241,20 +241,21 @@ async def test_async_jobs_raises( task_manager: TaskManager, with_celery_worker: WorkController, execution_metadata: TaskExecutionMetadata, - owner_metadata: OwnerMetadata, + fake_owner: str, + fake_user_id: int, error: Exception, ): async_job = await submit_job( task_manager, execution_metadata=execution_metadata, - owner_metadata=owner_metadata, + owner=fake_owner, + user_id=fake_user_id, action=Action.RAISE, payload=pickle.dumps(error), ) await _wait_for_job( task_manager, - owner_metadata=owner_metadata, job_id=async_job.job_id, stop_after=timedelta(minutes=1), ) @@ -262,7 +263,6 @@ async def test_async_jobs_raises( with pytest.raises(JobError) as exc: await get_job_result( task_manager, - owner_metadata=owner_metadata, job_id=async_job.job_id, ) assert exc.value.exc_type == type(error).__name__ diff --git a/packages/celery-library/tests/unit/test_redis_store.py b/packages/celery-library/tests/unit/test_redis_store.py index edffd9b72cf7..313afbfd83a1 100644 --- a/packages/celery-library/tests/unit/test_redis_store.py +++ b/packages/celery-library/tests/unit/test_redis_store.py @@ -2,18 +2,16 @@ from collections.abc import AsyncIterator from datetime import timedelta +from uuid import UUID import pytest from celery_library.backends import RedisTaskStore from celery_library.backends._redis import _build_redis_task_or_group_key from faker import Faker from models_library.celery import ( - OwnerMetadata, Task, TaskExecutionMetadata, - Wildcard, ) -from models_library.users import UserID from servicelib.redis import RedisClientSDK from settings_library.redis import RedisDatabase, RedisSettings @@ -23,11 +21,6 @@ pytest_simcore_ops_services_selection = [] -class _TestOwnerMetadata(OwnerMetadata): - user_id: UserID - product_name: str | Wildcard - - @pytest.fixture async def redis_client_sdk( use_in_memory_redis: RedisSettings, @@ -55,12 +48,14 @@ async def test_list_tasks_uses_zset_index_not_scan( redis_client_sdk: RedisClientSDK, monkeypatch: pytest.MonkeyPatch, ): - owner = _TestOwnerMetadata(user_id=10001, product_name="osparc", owner="test-svc") - task_key = owner.model_dump_key(task_or_group_uuid=_faker.uuid4()) + task_key = _faker.uuid4() await redis_task_store.create_task( task_key, TaskExecutionMetadata(name="my_task"), + owner="test-svc", + user_id=10001, + product_name="osparc", expiry=timedelta(minutes=5), ) @@ -74,9 +69,9 @@ def _forbid_scan_iter(*_args: object, **_kwargs: object) -> None: _forbid_scan_iter, ) - tasks = await redis_task_store.list_tasks(owner) + tasks = await redis_task_store.list_tasks(owner="test-svc", user_id=10001, product_name="osparc") assert len(tasks) == 1 - assert tasks[0].uuid == OwnerMetadata.get_task_or_group_uuid(task_key) + assert tasks[0].uuid == UUID(task_key) async def test_list_tasks_with_wildcard_filtering( @@ -87,65 +82,69 @@ async def test_list_tasks_with_wildcard_filtering( expected_tasks: list[Task] = [] for _ in range(5): - om = _TestOwnerMetadata(user_id=user_id, product_name=_faker.word(), owner=owner) - task_key = om.model_dump_key(task_or_group_uuid=_faker.uuid4()) + task_key = _faker.uuid4() await redis_task_store.create_task( task_key, TaskExecutionMetadata(name="my_task"), + owner=owner, + user_id=user_id, + product_name=_faker.word(), expiry=timedelta(minutes=5), ) expected_tasks.append( Task( - uuid=OwnerMetadata.get_task_or_group_uuid(task_key), + uuid=UUID(task_key), metadata=TaskExecutionMetadata(name="my_task"), ) ) for _ in range(3): - om = _TestOwnerMetadata( - user_id=_faker.pyint(min_value=100, max_value=200), - product_name=_faker.word(), - owner=owner, - ) - task_key = om.model_dump_key(task_or_group_uuid=_faker.uuid4()) + task_key = _faker.uuid4() await redis_task_store.create_task( task_key, TaskExecutionMetadata(name="my_task"), + owner=owner, + user_id=_faker.pyint(min_value=100, max_value=200), + product_name=_faker.word(), expiry=timedelta(minutes=5), ) - search = _TestOwnerMetadata(user_id=user_id, product_name="*", owner=owner) - tasks = await redis_task_store.list_tasks(search) + # Query by owner + user_id only (product_name=None acts as wildcard) + tasks = await redis_task_store.list_tasks(owner=owner, user_id=user_id) assert {t.uuid for t in tasks} == {t.uuid for t in expected_tasks} async def test_remove_task_cleans_up_zset_indexes( redis_task_store: RedisTaskStore, ): - owner = _TestOwnerMetadata(user_id=10003, product_name="osparc", owner="test-svc") - task_key = owner.model_dump_key(task_or_group_uuid=_faker.uuid4()) + task_key = _faker.uuid4() await redis_task_store.create_task( task_key, TaskExecutionMetadata(name="my_task"), + owner="test-svc", + user_id=10003, + product_name="osparc", expiry=timedelta(minutes=5), ) - assert len(await redis_task_store.list_tasks(owner)) == 1 + assert len(await redis_task_store.list_tasks(owner="test-svc", user_id=10003, product_name="osparc")) == 1 - await redis_task_store.remove_task(task_key) - assert len(await redis_task_store.list_tasks(owner)) == 0 + await redis_task_store.remove_task(task_key, owner="test-svc", user_id=10003, product_name="osparc") + assert len(await redis_task_store.list_tasks(owner="test-svc", user_id=10003, product_name="osparc")) == 0 async def test_stale_zset_entries_are_pruned_on_list( redis_task_store: RedisTaskStore, redis_client_sdk: RedisClientSDK, ): - owner = _TestOwnerMetadata(user_id=10004, product_name="osparc", owner="test-svc") - task_key = owner.model_dump_key(task_or_group_uuid=_faker.uuid4()) + task_key = _faker.uuid4() await redis_task_store.create_task( task_key, TaskExecutionMetadata(name="my_task"), + owner="test-svc", + user_id=10004, + product_name="osparc", expiry=timedelta(minutes=5), ) @@ -155,6 +154,6 @@ async def test_stale_zset_entries_are_pruned_on_list( await redis.delete(_build_redis_task_or_group_key(task_key)) # First list should return empty and prune the stale entry - assert await redis_task_store.list_tasks(owner) == [] + assert await redis_task_store.list_tasks(owner="test-svc", user_id=10004, product_name="osparc") == [] # Second list confirms the ZSET is clean - assert await redis_task_store.list_tasks(owner) == [] + assert await redis_task_store.list_tasks(owner="test-svc", user_id=10004, product_name="osparc") == [] diff --git a/packages/models-library/src/models_library/celery.py b/packages/models-library/src/models_library/celery.py index 4645fb61bcdb..0ef6d7b18f5e 100644 --- a/packages/models-library/src/models_library/celery.py +++ b/packages/models-library/src/models_library/celery.py @@ -1,11 +1,9 @@ from datetime import datetime, timedelta from enum import auto -from typing import Annotated, Any, Final, Literal, Protocol, Self, TypeVar +from typing import Annotated, Any, Final, Literal, Protocol, TypeVar from uuid import UUID -import orjson -from common_library.json_serialization import json_dumps, json_loads -from pydantic import BaseModel, ConfigDict, Field, NonNegativeInt, StringConstraints, TypeAdapter, model_validator +from pydantic import BaseModel, ConfigDict, Field, NonNegativeInt, StringConstraints, TypeAdapter from pydantic.config import JsonDict from .progress_bar import ProgressReport @@ -26,116 +24,9 @@ DEFAULT_QUEUE: Final[str] = "default" - -_KEY_DELIMITATOR: Final[str] = ":" -_FORBIDDEN_KEY_CHARS = ("*", _KEY_DELIMITATOR, "=") -_FORBIDDEN_VALUE_CHARS = (_KEY_DELIMITATOR, "=") -type AllowedTypes = int | float | bool | str | list[str] | list[int] | list[float] | list[bool] | None - -type Wildcard = Literal["*"] -WILDCARD: Final[Wildcard] = "*" - -_UUID_KEY: Final[str] = "uuid" _TASK_UUID_ADAPTER: Final[TypeAdapter[TaskUUID]] = TypeAdapter(TaskUUID) -class _TypeValidationModel(BaseModel): - filters: dict[str, AllowedTypes] - - -class OwnerMetadata(BaseModel): - """ - Class for associating metadata with a celery task. - The implementation is very flexible and allows the task owner to define their own metadata. - This could be metadata for validating if a user has access to a given task (e.g. user_id or product_name) or - metadata for keeping track of how to handle a task, - e.g. which schema will the result of the task have. - - The class exposes a filtering mechanism to list tasks using wildcards. - - Example usage: - class StorageOwnerMetadata(OwnerMetadata): - user_id: int | Wildcard - product_name: int | Wildcard - owner = APP_NAME - - Where APP_NAME is the name of the service. Listing tasks using the filter - `StorageOwnerMetadata(user_id=123, product_name=WILDCARD)` will return all tasks with - user_id 123, any product_name submitted from the service. - - If the metadata schema is known, the class allows deserializing the metadata (recreate_as_model). - I.e. one can recover the metadata from the task: - metadata -> task_uuid -> metadata - - """ - - model_config = ConfigDict(extra="allow", frozen=True) - owner: Annotated[ - str, - StringConstraints(min_length=1, pattern=r"^[a-z_-]+$"), - Field(description='Identifies the service owning the task. Should be the "APP_NAME" of the service.'), - ] - - @model_validator(mode="after") - def _check_valid_filters(self) -> Self: - def _validate_type() -> None: - try: - _TypeValidationModel.model_validate({"filters": self.model_dump()}) - except ValueError as err: - msg = "Invalid filter type" - raise TypeError(msg) from err - - for key, value in self.model_dump().items(): - # forbidden key chars - if any(x in key for x in _FORBIDDEN_KEY_CHARS): - msg = f"Invalid filter key: '{key}'" - raise ValueError(msg) - # forbidden value chars - if any(x in json_dumps(value) for x in _FORBIDDEN_VALUE_CHARS): - msg = f"Invalid filter value for key '{key}': '{value}'" - raise ValueError(msg) - - if _UUID_KEY in self.model_dump(): - msg = f"'{_UUID_KEY}' is a reserved key" - raise ValueError(msg) - - _validate_type() - return self - - def model_dump_key(self, task_or_group_uuid: TaskUUID | GroupUUID | Wildcard) -> TaskKey | GroupKey: - data = self.model_dump(mode="json") - data.update({_UUID_KEY: f"{task_or_group_uuid}"}) - return _KEY_DELIMITATOR.join([f"{k}={json_dumps(v)}" for k, v in sorted(data.items())]) - - @classmethod - def model_validate_key(cls, task_or_group_key: TaskKey | GroupKey) -> Self: - data = cls._deserialize_task_or_group_key(task_or_group_key) - data.pop(_UUID_KEY, None) - return cls.model_validate(data) - - @classmethod - def _deserialize_task_or_group_key(cls, task_or_group_key: TaskKey | GroupKey) -> dict[str, AllowedTypes]: - key_value_pairs = [item.split("=") for item in task_or_group_key.split(_KEY_DELIMITATOR)] - try: - return {key: json_loads(value) for key, value in key_value_pairs} - except orjson.JSONDecodeError as err: - msg = f"Invalid key format: {task_or_group_key}" - raise ValueError(msg) from err - - @classmethod - def get_task_or_group_uuid(cls, task_or_group_key: TaskKey | GroupKey) -> TaskUUID | GroupUUID: - data = cls._deserialize_task_or_group_key(task_or_group_key) - try: - uuid_string = data.get(_UUID_KEY) - if not isinstance(uuid_string, str): - msg = f"Invalid task_id format: {task_or_group_key}" - raise TypeError(msg) - return _TASK_UUID_ADAPTER.validate_python(uuid_string) - except ValueError as err: - msg = f"Invalid task_id format: {task_or_group_key}" - raise ValueError(msg) from err - - class TaskState(StrAutoEnum): PENDING = auto() STARTED = auto() @@ -234,6 +125,10 @@ async def create_group( group_key: GroupKey, execution_metadata: GroupExecutionMetadata, task_keys: list[TaskKey], + *, + owner: str, + user_id: int | None = None, + product_name: str | None = None, expiry: timedelta, ) -> None: ... @@ -241,6 +136,10 @@ async def create_task( self, task_key: TaskKey, execution_metadata: TaskExecutionMetadata, + *, + owner: str, + user_id: int | None = None, + product_name: str | None = None, expiry: timedelta, ) -> None: ... @@ -250,9 +149,30 @@ async def get_task_metadata(self, task_key: TaskKey) -> ExecutionMetadata | None async def get_task_progress(self, task_key: TaskKey) -> ProgressReport | None: ... - async def list_tasks(self, owner_metadata: OwnerMetadata) -> list[Task]: ... + async def list_tasks( + self, + *, + owner: str, + user_id: int | None = None, + product_name: str | None = None, + ) -> list[Task]: ... + + async def remove_task( + self, + task_key: TaskKey, + *, + owner: str, + user_id: int | None = None, + product_name: str | None = None, + ) -> None: ... + + async def remove_task_hash(self, task_key: TaskKey) -> None: + """Remove only the task hash from the store, without cleaning sorted-set indexes. - async def remove_task(self, task_key: TaskKey) -> None: ... + Stale index entries are cleaned lazily by ``list_tasks``. + Use this when the owner info is unavailable (e.g. cancel, ephemeral cleanup). + """ + ... async def set_task_progress( self, diff --git a/packages/models-library/src/models_library/notifications/rpc/_message.py b/packages/models-library/src/models_library/notifications/rpc/_message.py index 435085b3d712..4f33f9bf33f7 100644 --- a/packages/models-library/src/models_library/notifications/rpc/_message.py +++ b/packages/models-library/src/models_library/notifications/rpc/_message.py @@ -3,7 +3,7 @@ from pydantic import BaseModel, ConfigDict, Field from pydantic.config import JsonDict -from ...celery import GroupUUID, OwnerMetadata, TaskUUID +from ...celery import GroupUUID, TaskUUID from ._email import Addressing, Message from ._template import TemplateRef @@ -15,7 +15,9 @@ class SendMessageRequest(BaseModel): description="Channel-specific message payload (e.g. EmailMessage for email).", ), ] - owner_metadata: OwnerMetadata | None = None + owner: str | None = None + user_id: int | None = None + product_name: str | None = None @staticmethod def _update_json_schema_extra(schema: JsonDict) -> None: @@ -43,11 +45,9 @@ def _update_json_schema_extra(schema: JsonDict) -> None: "body_text": "Welcome to osparc!", }, }, - "owner_metadata": { - "user_id": 123, - "product_name": "osparc", - "owner": "notification-service", - }, + "owner": "notification-service", + "user_id": 123, + "product_name": "osparc", }, ] } @@ -73,7 +73,9 @@ class SendMessageFromTemplateRequest(BaseModel): description="Template context variables. Must conform to the context_schema of the referenced template.", ), ] - owner_metadata: OwnerMetadata | None = None + owner: str | None = None + user_id: int | None = None + product_name: str | None = None @staticmethod def _update_json_schema_extra(schema: JsonDict) -> None: @@ -102,11 +104,9 @@ def _update_json_schema_extra(schema: JsonDict) -> None: "user": {"first_name": "John"}, "link": "https://osparc.io", }, - "owner_metadata": { - "user_id": 123, - "product_name": "osparc", - "owner": "notification-service", - }, + "owner": "notification-service", + "user_id": 123, + "product_name": "osparc", }, ] } diff --git a/packages/pytest-simcore/src/pytest_simcore/helpers/async_jobs_server.py b/packages/pytest-simcore/src/pytest_simcore/helpers/async_jobs_server.py index 8a45090d83d3..d3f2b9e6d392 100644 --- a/packages/pytest-simcore/src/pytest_simcore/helpers/async_jobs_server.py +++ b/packages/pytest-simcore/src/pytest_simcore/helpers/async_jobs_server.py @@ -10,7 +10,6 @@ AsyncJobStatus, ) from models_library.api_schemas_async_jobs.exceptions import BaseAsyncjobRpcError -from models_library.celery import OwnerMetadata from models_library.progress_bar import ProgressReport from models_library.rabbitmq_basic_types import RPCNamespace from pydantic import validate_call @@ -29,12 +28,10 @@ async def cancel( *, rpc_namespace: RPCNamespace, job_id: AsyncJobId, - owner_metadata: OwnerMetadata, ) -> None: assert rabbitmq_rpc_client assert rpc_namespace assert job_id - assert owner_metadata if self.exception is not None: raise self.exception @@ -46,12 +43,10 @@ async def status( *, rpc_namespace: RPCNamespace, job_id: AsyncJobId, - owner_metadata: OwnerMetadata, ) -> AsyncJobStatus: assert rabbitmq_rpc_client assert rpc_namespace assert job_id - assert owner_metadata if self.exception is not None: raise self.exception @@ -73,12 +68,10 @@ async def result( *, rpc_namespace: RPCNamespace, job_id: AsyncJobId, - owner_metadata: OwnerMetadata, ) -> AsyncJobResult: assert rabbitmq_rpc_client assert rpc_namespace assert job_id - assert owner_metadata if self.exception is not None: raise self.exception @@ -90,12 +83,16 @@ async def list_jobs( rabbitmq_rpc_client: RabbitMQRPCClient | MockType, *, rpc_namespace: RPCNamespace, - owner_metadata: OwnerMetadata, + owner: str, + user_id: int | None = None, + product_name: str | None = None, filter_: str = "", ) -> list[AsyncJobGet]: assert rabbitmq_rpc_client assert rpc_namespace - assert owner_metadata + assert owner + assert user_id is not None or user_id is None + assert product_name is not None or product_name is None assert filter_ is not None if self.exception is not None: diff --git a/packages/pytest-simcore/src/pytest_simcore/helpers/storage_rpc.py b/packages/pytest-simcore/src/pytest_simcore/helpers/storage_rpc.py index 99ed6d7012a1..34e2d6b9d75c 100644 --- a/packages/pytest-simcore/src/pytest_simcore/helpers/storage_rpc.py +++ b/packages/pytest-simcore/src/pytest_simcore/helpers/storage_rpc.py @@ -12,7 +12,6 @@ AsyncJobGet, ) from models_library.api_schemas_webserver.storage import PathToExport -from models_library.celery import OwnerMetadata from models_library.products import ProductName from models_library.users import UserID from pydantic import TypeAdapter, validate_call @@ -29,19 +28,17 @@ async def start_export_data( *, paths_to_export: list[PathToExport], export_as: Literal["path", "download_link"], - owner_metadata: OwnerMetadata, + owner: str, user_id: UserID, product_name: ProductName, - ) -> tuple[AsyncJobGet, OwnerMetadata]: + ) -> AsyncJobGet: assert rabbitmq_rpc_client - assert owner_metadata + assert owner assert paths_to_export assert export_as assert user_id assert product_name - async_job_get = TypeAdapter(AsyncJobGet).validate_python( + return TypeAdapter(AsyncJobGet).validate_python( AsyncJobGet.model_json_schema()["examples"][0], ) - - return async_job_get, owner_metadata diff --git a/packages/service-library/src/servicelib/celery/async_jobs/notifications.py b/packages/service-library/src/servicelib/celery/async_jobs/notifications.py index 43a69ff1770d..f5e9d13633ca 100644 --- a/packages/service-library/src/servicelib/celery/async_jobs/notifications.py +++ b/packages/service-library/src/servicelib/celery/async_jobs/notifications.py @@ -4,7 +4,6 @@ GroupExecutionMetadata, GroupTaskExecutionMetadata, GroupUUID, - OwnerMetadata, TaskExecutionMetadata, TaskName, TaskUUID, @@ -19,7 +18,9 @@ async def submit_send_message_task( task_manager: TaskManager, *, - owner_metadata: OwnerMetadata, + owner: str, + user_id: int | None = None, + product_name: str | None = None, message: dict[str, Any], # NOTE: validated internally description: str | None = None, ) -> tuple[TaskUUID, TaskName]: @@ -29,7 +30,9 @@ async def submit_send_message_task( queue=NOTIFICATIONS_SERVICE_QUEUE_NAME, description=description, ), - owner_metadata=owner_metadata, + owner=owner, + user_id=user_id, + product_name=product_name, message=message, ), SEND_MESSAGE_TASK_NAME_TEMPLATE.format(message["channel"]) @@ -37,7 +40,9 @@ async def submit_send_message_task( async def submit_send_messages_task( task_manager: TaskManager, *, - owner_metadata: OwnerMetadata, + owner: str, + user_id: int | None = None, + product_name: str | None = None, messages: list[dict[str, Any]], # NOTE: validated internally description: str | None = None, ) -> tuple[GroupUUID, list[TaskUUID], TaskName]: @@ -57,6 +62,8 @@ async def submit_send_messages_task( for message in messages ], ), - owner_metadata=owner_metadata, + owner=owner, + user_id=user_id, + product_name=product_name, ) return group_uuid, task_uuids, SEND_MESSAGE_TASK_NAME_TEMPLATE.format(messages[0]["channel"]) diff --git a/packages/service-library/src/servicelib/celery/async_jobs/storage/paths.py b/packages/service-library/src/servicelib/celery/async_jobs/storage/paths.py index a030b8f9132e..fa1f1487c56d 100644 --- a/packages/service-library/src/servicelib/celery/async_jobs/storage/paths.py +++ b/packages/service-library/src/servicelib/celery/async_jobs/storage/paths.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import Final -from models_library.celery import OwnerMetadata, TaskExecutionMetadata, TaskName, TaskUUID +from models_library.celery import TaskExecutionMetadata, TaskName, TaskUUID from models_library.products import ProductName from models_library.projects_nodes_io import LocationID from models_library.users import UserID @@ -15,7 +15,7 @@ async def submit_compute_path_size_task( task_manager: TaskManager, - owner_metadata: OwnerMetadata, + owner: str, user_id: UserID, product_name: ProductName, location_id: LocationID, @@ -25,7 +25,7 @@ async def submit_compute_path_size_task( TaskExecutionMetadata( name=COMPUTE_PATH_SIZE_TASK_NAME, ), - owner_metadata=owner_metadata, + owner=owner, user_id=user_id, product_name=product_name, location_id=location_id, @@ -35,7 +35,7 @@ async def submit_compute_path_size_task( async def submit_delete_paths_task( task_manager: TaskManager, - owner_metadata: OwnerMetadata, + owner: str, user_id: UserID, location_id: LocationID, paths: set[Path], @@ -44,7 +44,7 @@ async def submit_delete_paths_task( TaskExecutionMetadata( name=DELETE_PATHS_TASK_NAME, ), - owner_metadata=owner_metadata, + owner=owner, user_id=user_id, location_id=location_id, paths=paths, diff --git a/packages/service-library/src/servicelib/celery/async_jobs/storage/simcore_s3.py b/packages/service-library/src/servicelib/celery/async_jobs/storage/simcore_s3.py index ac3906017ba6..09245c642675 100644 --- a/packages/service-library/src/servicelib/celery/async_jobs/storage/simcore_s3.py +++ b/packages/service-library/src/servicelib/celery/async_jobs/storage/simcore_s3.py @@ -4,7 +4,6 @@ from models_library.api_schemas_async_jobs.async_jobs import AsyncJobGet from models_library.api_schemas_webserver.storage import PathToExport from models_library.celery import ( - OwnerMetadata, TaskExecutionMetadata, ) from models_library.products import ProductName @@ -22,7 +21,7 @@ class TaskQueueNames(StrEnum): async def submit_export_data( task_manager: TaskManager, - owner_metadata: OwnerMetadata, + owner: str, user_id: UserID, product_name: ProductName, paths_to_export: list[PathToExport], @@ -42,7 +41,7 @@ async def submit_export_data( ephemeral=False, queue=TaskQueueNames.CPU_BOUND, ), - owner_metadata=owner_metadata, + owner=owner, user_id=user_id, product_name=product_name, paths_to_export=paths_to_export, diff --git a/packages/service-library/src/servicelib/celery/task_manager.py b/packages/service-library/src/servicelib/celery/task_manager.py index f9c51f8a646e..1c0c74321231 100644 --- a/packages/service-library/src/servicelib/celery/task_manager.py +++ b/packages/service-library/src/servicelib/celery/task_manager.py @@ -6,7 +6,6 @@ GroupKey, GroupStatus, GroupUUID, - OwnerMetadata, Task, TaskExecutionMetadata, TaskKey, @@ -23,22 +22,34 @@ async def submit_group( self, execution_metadata: GroupExecutionMetadata, *, - owner_metadata: OwnerMetadata, + owner: str, + user_id: int | None = None, + product_name: str | None = None, ) -> tuple[GroupUUID, list[TaskUUID]]: ... async def submit_task( - self, execution_metadata: TaskExecutionMetadata, *, owner_metadata: OwnerMetadata, **task_params + self, + execution_metadata: TaskExecutionMetadata, + *, + owner: str, + user_id: int | None = None, + product_name: str | None = None, + **task_params, ) -> TaskUUID: ... - async def cancel(self, owner_metadata: OwnerMetadata, task_or_group_uuid: TaskUUID | GroupUUID) -> None: ... + async def cancel(self, task_or_group_uuid: TaskUUID | GroupUUID) -> None: ... - async def get_result(self, owner_metadata: OwnerMetadata, task_or_group_uuid: TaskUUID | GroupUUID) -> Any: ... + async def get_result(self, task_or_group_uuid: TaskUUID | GroupUUID) -> Any: ... - async def get_status( - self, owner_metadata: OwnerMetadata, task_or_group_uuid: TaskUUID | GroupUUID - ) -> TaskStatus | GroupStatus: ... + async def get_status(self, task_or_group_uuid: TaskUUID | GroupUUID) -> TaskStatus | GroupStatus: ... - async def list_tasks(self, owner_metadata: OwnerMetadata) -> list[Task]: ... + async def list_tasks( + self, + *, + owner: str, + user_id: int | None = None, + product_name: str | None = None, + ) -> list[Task]: ... async def set_task_progress(self, task_key: TaskKey, report: ProgressReport) -> None: ... @@ -46,7 +57,6 @@ async def push_task_stream_items(self, task_key: TaskKey, *items: TaskStreamItem async def pull_task_stream_items( self, - owner_metadata: OwnerMetadata, task_uuid: TaskUUID, offset: int = 0, limit: int = 20, diff --git a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/notifications/_message.py b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/notifications/_message.py index a26c03f53090..77b770f893db 100644 --- a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/notifications/_message.py +++ b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/notifications/_message.py @@ -1,7 +1,6 @@ import logging from typing import Any -from models_library.celery import OwnerMetadata from models_library.notifications.rpc import ( NOTIFICATIONS_RPC_NAMESPACE, Addressing, @@ -26,14 +25,18 @@ async def send_message( rabbitmq_rpc_client: RabbitMQRPCClient, *, message: Message, - owner_metadata: OwnerMetadata | None = None, + owner: str | None = None, + user_id: int | None = None, + product_name: str | None = None, ) -> SendMessageResponse: result = await rabbitmq_rpc_client.request( NOTIFICATIONS_RPC_NAMESPACE, TypeAdapter(RPCMethodName).validate_python("send_message"), request=SendMessageRequest( message=message, - owner_metadata=OwnerMetadata.model_validate(owner_metadata.model_dump()) if owner_metadata else None, + owner=owner, + user_id=user_id, + product_name=product_name, ), ) assert isinstance(result, SendMessageResponse) # nosec @@ -48,7 +51,9 @@ async def send_message_from_template( addressing: Addressing, template_ref: TemplateRef, context: dict[str, Any], - owner_metadata: OwnerMetadata | None = None, + owner: str | None = None, + user_id: int | None = None, + product_name: str | None = None, ) -> SendMessageResponse: result = await rabbitmq_rpc_client.request( NOTIFICATIONS_RPC_NAMESPACE, @@ -57,7 +62,9 @@ async def send_message_from_template( template_ref=template_ref, addressing=addressing, context=context, - owner_metadata=OwnerMetadata.model_validate(owner_metadata.model_dump()) if owner_metadata else None, + owner=owner, + user_id=user_id, + product_name=product_name, ), ) assert isinstance(result, SendMessageResponse) # nosec diff --git a/packages/service-library/tests/test_celery.py b/packages/service-library/tests/test_celery.py index 230839126244..f44a1d2b72c2 100644 --- a/packages/service-library/tests/test_celery.py +++ b/packages/service-library/tests/test_celery.py @@ -1,148 +1,25 @@ -from types import NoneType -from typing import Annotated - # pylint: disable=redefined-outer-name # pylint: disable=protected-access -import pydantic -import pytest -from common_library.json_serialization import json_dumps -from faker import Faker -from models_library.celery import ( - OwnerMetadata, - TaskUUID, - Wildcard, -) -from pydantic import StringConstraints, TypeAdapter - -_faker = Faker() - - -class _TestOwnerMetadata(OwnerMetadata): - string_: str - int_: int - bool_: bool - none_: None - uuid_: str - - -@pytest.fixture -def test_owner_metadata() -> dict[str, str | int | bool | None | list[str]]: - data = { - "string_": _faker.word(), - "int_": _faker.random_int(), - "bool_": _faker.boolean(), - "none_": None, - "uuid_": _faker.uuid4(), - "owner": _faker.word().lower(), - } - _TestOwnerMetadata.model_validate(data) # ensure it's valid - return data - - -async def test_task_filter_serialization( - test_owner_metadata: dict[str, str | int | bool | None | list[str]], -): - task_filter = _TestOwnerMetadata.model_validate(test_owner_metadata) - assert task_filter.model_dump() == test_owner_metadata - - -async def test_task_filter_sorting_key_not_serialized(): - class _OwnerMetadata(OwnerMetadata): - a: int | Wildcard - b: str | Wildcard - - owner_metadata = _OwnerMetadata.model_validate( - { - "a": _faker.random_int(), - "b": _faker.word(), - "owner": _faker.word().lower(), - } - ) - task_uuid = TypeAdapter(TaskUUID).validate_python(_faker.uuid4()) - copy_owner_metadata = owner_metadata.model_dump() - copy_owner_metadata.update({"uuid": f"{task_uuid}"}) - - expected_key = ":".join([f"{k}={json_dumps(v)}" for k, v in sorted(copy_owner_metadata.items())]) - assert owner_metadata.model_dump_key(task_or_group_uuid=task_uuid) == expected_key - - -async def test_task_filter_task_uuid( - test_owner_metadata: dict[str, str | int | bool | None | list[str]], -): - task_filter = _TestOwnerMetadata.model_validate(test_owner_metadata) - task_uuid = TypeAdapter(TaskUUID).validate_python(_faker.uuid4()) - task_key = task_filter.model_dump_key(task_uuid) - assert OwnerMetadata.get_task_or_group_uuid(task_or_group_key=task_key) == task_uuid - - -async def test_owner_metadata_task_key_dump_and_validate(): - class MyModel(OwnerMetadata): - int_: int - bool_: bool - str_: str - float_: float - none_: NoneType - list_s: list[str] - list_i: list[int] - list_f: list[float] - list_b: list[bool] - - mymodel = MyModel( - int_=1, - none_=None, - bool_=True, - str_="test", - float_=1.0, - owner="myowner", - list_b=[True, False], - list_f=[1.0, 2.0], - list_i=[1, 2], - list_s=["a", "b"], - ) - task_uuid = TypeAdapter(TaskUUID).validate_python(_faker.uuid4()) - task_key = mymodel.model_dump_key(task_uuid) - mymodel_recreated = MyModel.model_validate_key(task_or_group_key=task_key) - assert mymodel_recreated == mymodel - - -@pytest.mark.parametrize( - "bad_data", - [ - {"foo": "bar:baz"}, - {"foo": "bar=baz"}, - {"foo:bad": "bar"}, - {"foo=bad": "bar"}, - {"foo": ":baz"}, - {"foo": "=baz"}, - ], -) -def test_task_filter_validator_raises_on_forbidden_chars(bad_data): - with pytest.raises(pydantic.ValidationError): - OwnerMetadata.model_validate(bad_data) - -async def test_task_owner(): - class MyOwnerMetadata(OwnerMetadata): - extra_field: str +import inspect - with pytest.raises(pydantic.ValidationError): - MyOwnerMetadata(owner="", extra_field="value") +from servicelib.celery.task_manager import TaskManager - with pytest.raises(pydantic.ValidationError): - MyOwnerMetadata(owner="UPPER_CASE", extra_field="value") - class MyNextFilter(OwnerMetadata): - owner: Annotated[str, StringConstraints(strip_whitespace=True, pattern=r"^the_task_owner$")] +def test_task_manager_protocol_has_plain_owner_params(): + """Verify the TaskManager protocol uses plain owner/user_id/product_name params.""" - with pytest.raises(pydantic.ValidationError): - MyNextFilter(owner="wrong_owner") + sig = inspect.signature(TaskManager.submit_task) + assert "owner" in sig.parameters + assert "user_id" in sig.parameters + assert "product_name" in sig.parameters + sig = inspect.signature(TaskManager.list_tasks) + assert "owner" in sig.parameters + assert "user_id" in sig.parameters + assert "product_name" in sig.parameters -def test_owner_metadata_serialize_deserialize(test_owner_metadata): - test_owner_metadata = _TestOwnerMetadata.model_validate(test_owner_metadata) - data = test_owner_metadata.model_dump() - deserialized_data = OwnerMetadata.model_validate(data) - assert len(_TestOwnerMetadata.model_fields) > len( - OwnerMetadata.model_fields - ) # ensure extra data is available in _TestOwnerMetadata -> needed for RPC - assert deserialized_data.model_dump() == data + sig = inspect.signature(TaskManager.cancel) + # cancel should NOT have owner params + assert "owner" not in sig.parameters + assert "user_id" not in sig.parameters diff --git a/services/api-server/src/simcore_service_api_server/_service_function_jobs_task_client.py b/services/api-server/src/simcore_service_api_server/_service_function_jobs_task_client.py index 642a28371d06..4fbecd6c4905 100644 --- a/services/api-server/src/simcore_service_api_server/_service_function_jobs_task_client.py +++ b/services/api-server/src/simcore_service_api_server/_service_function_jobs_task_client.py @@ -47,7 +47,6 @@ StudyJobOutputRequestButNotSucceededError, ) from .models.api_resources import JobLinks -from .models.domain.celery_models import ApiServerOwnerMetadata from .models.domain.functions import FunctionJobPatch from .models.schemas.functions import FunctionJobCreationTaskStatus from .models.schemas.jobs import JobInputs, JobPricingSpecification @@ -82,13 +81,9 @@ async def _celery_task_status( ) -> FunctionJobCreationTaskStatus: if job_creation_task_id is None: return FunctionJobCreationTaskStatus.NOT_YET_SCHEDULED - owner_metadata = ApiServerOwnerMetadata( - user_id=user_id, - product_name=product_name, - ) task_uuid: TaskUUID = TypeAdapter(TaskUUID).validate_python(f"{job_creation_task_id}") try: - task_status = await task_manager.get_status(owner_metadata=owner_metadata, task_or_group_uuid=task_uuid) + task_status = await task_manager.get_status(task_or_group_uuid=task_uuid) assert isinstance(task_status, TaskStatus) # nosec return FunctionJobCreationTaskStatus[task_status.task_state] except TaskOrGroupNotFoundError as err: @@ -99,7 +94,6 @@ async def _celery_task_status( error=err, error_context={ "task_uuid": task_uuid, - "owner_metadata": owner_metadata, "user_id": user_id, "product_name": product_name, }, @@ -324,11 +318,6 @@ async def create_function_job_creation_tasks( job_input_list=[JobInputs(values=_ or {}) for _ in uncached_inputs], ) - owner_metadata = ApiServerOwnerMetadata( - user_id=user_identity.user_id, - product_name=user_identity.product_name, - owner=APP_NAME, - ) task_uuids = await logged_gather( *( self._celery_task_manager.submit_task( @@ -337,7 +326,9 @@ async def create_function_job_creation_tasks( ephemeral=False, queue=API_SERVER_CELERY_QUEUE_DEFAULT, ), - owner_metadata=owner_metadata, + owner=APP_NAME, + user_id=user_identity.user_id, + product_name=user_identity.product_name, user_identity=user_identity, function=function, pre_registered_function_job_data=pre_registered_function_job_data, diff --git a/services/api-server/src/simcore_service_api_server/api/routes/tasks.py b/services/api-server/src/simcore_service_api_server/api/routes/tasks.py index 8e6b006503a7..ec86070884e6 100644 --- a/services/api-server/src/simcore_service_api_server/api/routes/tasks.py +++ b/services/api-server/src/simcore_service_api_server/api/routes/tasks.py @@ -22,10 +22,8 @@ from pydantic import TypeAdapter from servicelib.fastapi.dependencies import get_app +from ..._meta import APP_NAME from ...exceptions.backend_errors import CeleryTaskNotFoundError -from ...models.domain.celery_models import ( - ApiServerOwnerMetadata, -) from ...models.schemas.base import ApiServerEnvelope from ...models.schemas.errors import ErrorGet from ..dependencies.authentication import get_current_user_id, get_product_name @@ -73,13 +71,11 @@ async def list_tasks( product_name: Annotated[ProductName, Depends(get_product_name)], ): task_manager = get_task_manager(app) - owner_metadata = ApiServerOwnerMetadata( + tasks = await task_manager.list_tasks( + owner=APP_NAME, user_id=user_id, product_name=product_name, ) - tasks = await task_manager.list_tasks( - owner_metadata=owner_metadata, - ) app_router = app.router data = [ @@ -110,17 +106,12 @@ async def list_tasks( async def get_task_status( task_uuid: AsyncJobId, app: Annotated[FastAPI, Depends(get_app)], - user_id: Annotated[UserID, Depends(get_current_user_id)], - product_name: Annotated[ProductName, Depends(get_product_name)], + _user_id: Annotated[UserID, Depends(get_current_user_id)], + _product_name: Annotated[ProductName, Depends(get_product_name)], ): task_manager = get_task_manager(app) - owner_metadata = ApiServerOwnerMetadata( - user_id=user_id, - product_name=product_name, - ) with _exception_mapper(task_uuid=task_uuid): task_status = await task_manager.get_status( - owner_metadata=owner_metadata, task_or_group_uuid=TypeAdapter(TaskUUID).validate_python(f"{task_uuid}"), ) @@ -150,17 +141,12 @@ async def get_task_status( async def cancel_task( task_uuid: AsyncJobId, app: Annotated[FastAPI, Depends(get_app)], - user_id: Annotated[UserID, Depends(get_current_user_id)], - product_name: Annotated[ProductName, Depends(get_product_name)], + _user_id: Annotated[UserID, Depends(get_current_user_id)], + _product_name: Annotated[ProductName, Depends(get_product_name)], ): task_manager = get_task_manager(app) - owner_metadata = ApiServerOwnerMetadata( - user_id=user_id, - product_name=product_name, - ) with _exception_mapper(task_uuid=task_uuid): await task_manager.cancel( - owner_metadata=owner_metadata, task_or_group_uuid=TypeAdapter(TaskUUID).validate_python(f"{task_uuid}"), ) @@ -186,18 +172,13 @@ async def cancel_task( async def get_task_result( task_uuid: AsyncJobId, app: Annotated[FastAPI, Depends(get_app)], - user_id: Annotated[UserID, Depends(get_current_user_id)], - product_name: Annotated[ProductName, Depends(get_product_name)], + _user_id: Annotated[UserID, Depends(get_current_user_id)], + _product_name: Annotated[ProductName, Depends(get_product_name)], ): task_manager = get_task_manager(app) - owner_metadata = ApiServerOwnerMetadata( - user_id=user_id, - product_name=product_name, - ) with _exception_mapper(task_uuid=task_uuid): task_status = await task_manager.get_status( - owner_metadata=owner_metadata, task_or_group_uuid=TypeAdapter(TaskUUID).validate_python(f"{task_uuid}"), ) @@ -208,7 +189,6 @@ async def get_task_result( ) task_result = await task_manager.get_result( - owner_metadata=owner_metadata, task_or_group_uuid=TypeAdapter(TaskUUID).validate_python(f"{task_uuid}"), ) diff --git a/services/api-server/src/simcore_service_api_server/models/domain/celery_models.py b/services/api-server/src/simcore_service_api_server/models/domain/celery_models.py index 9ce2f5af9b28..7f75a03906da 100644 --- a/services/api-server/src/simcore_service_api_server/models/domain/celery_models.py +++ b/services/api-server/src/simcore_service_api_server/models/domain/celery_models.py @@ -1,6 +1,3 @@ -from typing import Annotated - -from models_library.celery import OwnerMetadata from models_library.functions import ( RegisteredProjectFunction, RegisteredProjectFunctionJob, @@ -8,11 +5,7 @@ RegisteredSolverFunction, RegisteredSolverFunctionJob, ) -from models_library.products import ProductName -from models_library.users import UserID -from pydantic import Field, StringConstraints -from ..._meta import APP_NAME from ...api.dependencies.authentication import Identity from ...models.api_resources import JobLinks from ...models.domain.functions import PreRegisteredFunctionJobData @@ -31,9 +24,3 @@ RegisteredSolverFunction, RegisteredSolverFunctionJob, ) - - -class ApiServerOwnerMetadata(OwnerMetadata): - user_id: UserID - product_name: ProductName - owner: Annotated[str, StringConstraints(pattern=rf"^{APP_NAME}$"), Field(frozen=True)] = APP_NAME diff --git a/services/api-server/src/simcore_service_api_server/modules/celery/worker/_functions_tasks.py b/services/api-server/src/simcore_service_api_server/modules/celery/worker/_functions_tasks.py index 517938813643..b296207fb1d0 100644 --- a/services/api-server/src/simcore_service_api_server/modules/celery/worker/_functions_tasks.py +++ b/services/api-server/src/simcore_service_api_server/modules/celery/worker/_functions_tasks.py @@ -1,3 +1,5 @@ +from typing import Any + from celery import ( # type: ignore[import-untyped] # pylint: disable=no-name-in-module Task, ) @@ -114,6 +116,7 @@ async def run_function( job_links: JobLinks, x_simcore_parent_project_uuid: ProjectID | None, x_simcore_parent_node_id: NodeID | None, + **_kwargs: Any, ) -> RegisteredFunctionJob: assert task_key # nosec app = get_app_server(task.app).app diff --git a/services/api-server/src/simcore_service_api_server/services_rpc/async_jobs.py b/services/api-server/src/simcore_service_api_server/services_rpc/async_jobs.py index eb89aef2fd6c..539af60abf1c 100644 --- a/services/api-server/src/simcore_service_api_server/services_rpc/async_jobs.py +++ b/services/api-server/src/simcore_service_api_server/services_rpc/async_jobs.py @@ -14,7 +14,6 @@ JobNotDoneError, JobSchedulerError, ) -from models_library.celery import OwnerMetadata from servicelib.celery.task_manager import TaskManager from ..exceptions.service_errors_utils import service_exception_mapper @@ -37,10 +36,9 @@ class AsyncJobClient: JobSchedulerError: TaskSchedulerError, } ) - async def cancel(self, *, job_id: AsyncJobId, owner_metadata: OwnerMetadata) -> None: + async def cancel(self, *, job_id: AsyncJobId) -> None: return await cancel_job( self._task_manager, - owner_metadata=owner_metadata, job_id=job_id, ) @@ -49,10 +47,9 @@ async def cancel(self, *, job_id: AsyncJobId, owner_metadata: OwnerMetadata) -> JobSchedulerError: TaskSchedulerError, } ) - async def status(self, *, job_id: AsyncJobId, owner_metadata: OwnerMetadata) -> AsyncJobStatus: + async def status(self, *, job_id: AsyncJobId) -> AsyncJobStatus: return await get_job_status( self._task_manager, - owner_metadata=owner_metadata, job_id=job_id, ) @@ -64,10 +61,9 @@ async def status(self, *, job_id: AsyncJobId, owner_metadata: OwnerMetadata) -> JobError: TaskError, } ) - async def result(self, *, job_id: AsyncJobId, owner_metadata: OwnerMetadata) -> AsyncJobResult: + async def result(self, *, job_id: AsyncJobId) -> AsyncJobResult: return await get_job_result( self._task_manager, - owner_metadata=owner_metadata, job_id=job_id, ) @@ -76,8 +72,16 @@ async def result(self, *, job_id: AsyncJobId, owner_metadata: OwnerMetadata) -> JobSchedulerError: TaskSchedulerError, } ) - async def list_jobs(self, *, owner_metadata: OwnerMetadata) -> list[AsyncJobGet]: + async def list_jobs( + self, + *, + owner: str, + user_id: int | None = None, + product_name: str | None = None, + ) -> list[AsyncJobGet]: return await list_jobs( self._task_manager, - owner_metadata=owner_metadata, + owner=owner, + user_id=user_id, + product_name=product_name, ) diff --git a/services/api-server/src/simcore_service_api_server/services_rpc/storage.py b/services/api-server/src/simcore_service_api_server/services_rpc/storage.py index 2a2dd5fb9876..41ac5d44d848 100644 --- a/services/api-server/src/simcore_service_api_server/services_rpc/storage.py +++ b/services/api-server/src/simcore_service_api_server/services_rpc/storage.py @@ -6,15 +6,13 @@ AsyncJobGet, ) from models_library.api_schemas_webserver.storage import PathToExport -from models_library.celery import OwnerMetadata, TaskExecutionMetadata +from models_library.celery import TaskExecutionMetadata from models_library.products import ProductName from models_library.users import UserID from servicelib.celery.task_manager import TaskManager +from .._meta import APP_NAME from ..exceptions.service_errors_utils import service_exception_mapper -from ..models.domain.celery_models import ( - ApiServerOwnerMetadata, -) _exception_mapper = partial(service_exception_mapper, service_name="Storage") @@ -33,9 +31,7 @@ async def start_data_export( return await submit_job( self._task_manager, execution_metadata=TaskExecutionMetadata(name="export_data_as_download_link"), - owner_metadata=OwnerMetadata.model_validate( - ApiServerOwnerMetadata(user_id=self._user_id, product_name=self._product_name).model_dump() - ), + owner=APP_NAME, user_id=self._user_id, product_name=self._product_name, paths_to_export=paths_to_export, diff --git a/services/api-server/tests/unit/api_functions/celery/test_functions_celery.py b/services/api-server/tests/unit/api_functions/celery/test_functions_celery.py index 57bb61fc2af0..455920f20175 100644 --- a/services/api-server/tests/unit/api_functions/celery/test_functions_celery.py +++ b/services/api-server/tests/unit/api_functions/celery/test_functions_celery.py @@ -57,16 +57,13 @@ X_SIMCORE_PARENT_NODE_ID, X_SIMCORE_PARENT_PROJECT_UUID, ) -from simcore_service_api_server._meta import API_VTAG +from simcore_service_api_server._meta import API_VTAG, APP_NAME from simcore_service_api_server.api.dependencies.authentication import Identity from simcore_service_api_server.api.dependencies.celery import ( get_task_manager, ) from simcore_service_api_server.exceptions.backend_errors import BaseBackEndError from simcore_service_api_server.models.api_resources import JobLinks -from simcore_service_api_server.models.domain.celery_models import ( - ApiServerOwnerMetadata, -) from simcore_service_api_server.models.domain.functions import ( PreRegisteredFunctionJobData, ) @@ -130,6 +127,7 @@ async def run_function( job_links: JobLinks, x_simcore_parent_project_uuid: NodeID | None, x_simcore_parent_node_id: NodeID | None, + **_kwargs: Any, ) -> RegisteredFunctionJob: return RegisteredProjectFunctionJob( title=_faker.sentence(), @@ -338,14 +336,12 @@ async def test_celery_error_propagation( user_identity: Identity, with_api_server_celery_worker: TestWorkController, ): - owner_metadata = ApiServerOwnerMetadata( - user_id=user_identity.user_id, - product_name=user_identity.product_name, - ) task_manager = get_task_manager(app=app) task_uuid = await task_manager.submit_task( TaskExecutionMetadata(name="exception_task", queue=API_SERVER_CELERY_QUEUE_DEFAULT), - owner_metadata=owner_metadata, + owner=APP_NAME, + user_id=user_identity.user_id, + product_name=user_identity.product_name, ) with pytest.raises(HTTPStatusError) as exc_info: diff --git a/services/api-server/tests/unit/api_functions/test_api_routers_function_jobs.py b/services/api-server/tests/unit/api_functions/test_api_routers_function_jobs.py index 30235391f46c..da383dc14446 100644 --- a/services/api-server/tests/unit/api_functions/test_api_routers_function_jobs.py +++ b/services/api-server/tests/unit/api_functions/test_api_routers_function_jobs.py @@ -18,7 +18,7 @@ ProjectFunctionJob, RegisteredProjectFunctionJob, ) -from models_library.celery import OwnerMetadata, TaskState, TaskStatus, TaskUUID +from models_library.celery import TaskState, TaskStatus, TaskUUID from models_library.functions import ( FunctionJob, FunctionJobStatus, @@ -263,7 +263,7 @@ async def test_get_function_job_status( ) -> None: _expected_return_status = status.HTTP_200_OK - async def _get_task_status(task_or_group_uuid: TaskUUID, owner_metadata: OwnerMetadata) -> TaskStatus: + async def _get_task_status(task_or_group_uuid: TaskUUID) -> TaskStatus: assert f"{task_or_group_uuid}" == job_creation_task_id return TaskStatus( task_uuid=task_or_group_uuid, diff --git a/services/api-server/tests/unit/service/test_service_function_jobs_task_client.py b/services/api-server/tests/unit/service/test_service_function_jobs_task_client.py index f47bf1e24e41..2d5cf8585d5a 100644 --- a/services/api-server/tests/unit/service/test_service_function_jobs_task_client.py +++ b/services/api-server/tests/unit/service/test_service_function_jobs_task_client.py @@ -1,7 +1,6 @@ # pylint: disable=redefined-outer-name import datetime -import json from collections.abc import Callable from uuid import uuid4 @@ -70,7 +69,7 @@ async def _raise(*args, **kwargs): ) for state in list(TaskState) ] - + [TaskOrGroupNotFoundError(task_uuid=_faker.uuid4(), owner_metadata=json.dumps({"owner": "test-owner"}))], + + [TaskOrGroupNotFoundError(task_uuid=_faker.uuid4())], ) @pytest.mark.parametrize("job_creation_task_id", [_faker.uuid4(), None]) async def test_celery_status_conversion( diff --git a/services/notifications/src/simcore_service_notifications/api/celery/_email.py b/services/notifications/src/simcore_service_notifications/api/celery/_email.py index 789976161e59..60761126d3c0 100644 --- a/services/notifications/src/simcore_service_notifications/api/celery/_email.py +++ b/services/notifications/src/simcore_service_notifications/api/celery/_email.py @@ -2,6 +2,7 @@ import logging from email.headerregistry import Address +from typing import Any from celery import ( # type: ignore[import-untyped] Task, @@ -26,6 +27,7 @@ async def send_email_message( task: Task, task_key: TaskKey, message: EmailMessage, + **_kwargs: Any, ) -> None: assert task # nosec assert task_key # nosec diff --git a/services/notifications/src/simcore_service_notifications/api/rpc/_message.py b/services/notifications/src/simcore_service_notifications/api/rpc/_message.py index 02f817a8e0af..f50ed1de0597 100644 --- a/services/notifications/src/simcore_service_notifications/api/rpc/_message.py +++ b/services/notifications/src/simcore_service_notifications/api/rpc/_message.py @@ -28,7 +28,9 @@ async def send_message( message_service = get_message_service(app) task_or_group_uuid, task_name = await message_service.send_message( message=request.message, - owner_metadata=request.owner_metadata, + owner=request.owner, + user_id=request.user_id, + product_name=request.product_name, ) return SendMessageResponse(task_or_group_uuid=task_or_group_uuid, task_name=task_name) @@ -51,6 +53,8 @@ async def send_message_from_template( addressing=request.addressing, ref=TemplateRef(**request.template_ref.model_dump()), context=request.context, - owner_metadata=request.owner_metadata, + owner=request.owner, + user_id=request.user_id, + product_name=request.product_name, ) return SendMessageResponse(task_or_group_uuid=task_or_group_uuid, task_name=task_name) diff --git a/services/notifications/src/simcore_service_notifications/services/_message.py b/services/notifications/src/simcore_service_notifications/services/_message.py index 83485ba03e70..bea5218e1ba3 100644 --- a/services/notifications/src/simcore_service_notifications/services/_message.py +++ b/services/notifications/src/simcore_service_notifications/services/_message.py @@ -4,7 +4,6 @@ from models_library.celery import ( GroupUUID, - OwnerMetadata, TaskName, TaskUUID, ) @@ -27,8 +26,6 @@ _logger = logging.getLogger(__name__) -_OWNER_METADATA = OwnerMetadata(owner=APP_NAME) - def _prepare_celery_messages(message: Message) -> list[dict[str, Any]]: """Dispatches to channel handler to fan out into per-recipient celery payloads. @@ -58,9 +55,11 @@ async def send_message( self, *, message: Message, - owner_metadata: OwnerMetadata | None = None, + owner: str | None = None, + user_id: int | None = None, + product_name: str | None = None, ) -> tuple[TaskUUID | GroupUUID, TaskName]: - resolved_owner = owner_metadata or _OWNER_METADATA + resolved_owner = owner or APP_NAME messages = _prepare_celery_messages(message) num_recipients = len(messages) @@ -69,7 +68,9 @@ async def send_message( if num_recipients == 1: task_uuid, task_name = await submit_send_message_task( self.task_manager, - owner_metadata=resolved_owner, + owner=resolved_owner, + user_id=user_id, + product_name=product_name, message=messages[0], description=description, ) @@ -84,7 +85,9 @@ async def send_message( group_uuid, _, task_name = await submit_send_messages_task( self.task_manager, - owner_metadata=resolved_owner, + owner=resolved_owner, + user_id=user_id, + product_name=product_name, messages=messages, description=description, ) @@ -96,7 +99,9 @@ async def send_message_from_template( addressing: Addressing, ref: TemplateRef, context: dict[str, Any], - owner_metadata: OwnerMetadata | None = None, + owner: str | None = None, + user_id: int | None = None, + product_name: str | None = None, ) -> tuple[TaskUUID | GroupUUID, TaskName]: preview = self.template_service.preview_template(ref=ref, context=context) message = EmailMessage( @@ -105,5 +110,7 @@ async def send_message_from_template( ) return await self.send_message( message=message, - owner_metadata=owner_metadata, + owner=owner, + user_id=user_id, + product_name=product_name, ) diff --git a/services/notifications/tests/unit/test_api_celery_send_email.py b/services/notifications/tests/unit/test_api_celery_send_email.py index 0dcf0414cab1..07370e95f9f6 100644 --- a/services/notifications/tests/unit/test_api_celery_send_email.py +++ b/services/notifications/tests/unit/test_api_celery_send_email.py @@ -2,7 +2,7 @@ import pytest from faker import Faker -from models_library.celery import OwnerMetadata, TaskExecutionMetadata, TaskState, TaskStatus +from models_library.celery import TaskExecutionMetadata, TaskState, TaskStatus from models_library.notifications.celery import EmailContact, EmailContent, EmailMessage from servicelib.celery.task_manager import TaskManager from simcore_service_notifications.api.celery.tasks import ( @@ -37,16 +37,12 @@ async def test_send_mail( smtp_mock_or_none: AsyncMock | None, faker: Faker, ): - owner_metadata = OwnerMetadata( - owner="test_service", - ) - user_email = faker.email() task_uuid = await task_manager.submit_task( TaskExecutionMetadata( name=send_email_message.__name__, ), - owner_metadata=owner_metadata, + owner="test_service", message=EmailMessage( from_=EmailContact(email=faker.email()), to=EmailContact(email=user_email), @@ -60,7 +56,7 @@ async def test_send_mail( async for attempt in AsyncRetrying(**_TENACITY_RETRY_PARAMS): with attempt: - status = await task_manager.get_status(owner_metadata, task_uuid) + status = await task_manager.get_status(task_uuid) assert isinstance(status, TaskStatus) # nosec assert status.task_state == TaskState.SUCCESS diff --git a/services/notifications/tests/unit/test_api_rpc_message.py b/services/notifications/tests/unit/test_api_rpc_message.py index 14e8b0b24453..575de74f6a49 100644 --- a/services/notifications/tests/unit/test_api_rpc_message.py +++ b/services/notifications/tests/unit/test_api_rpc_message.py @@ -4,7 +4,6 @@ import pytest from faker import Faker -from models_library.celery import OwnerMetadata from models_library.notifications import Channel from models_library.notifications.errors import ( NotificationsTemplateContextValidationError, @@ -103,19 +102,11 @@ async def test_send_message_single_recipient( assert response.task_name == "send_email_message" -async def test_send_message_with_owner_metadata( +async def test_send_message_with_owner_params( rpc_client: RabbitMQRPCClient, single_recipient_email_message: EmailMessage, mocker: MockerFixture, ): - owner_metadata = OwnerMetadata.model_validate( - { - "owner": "webserver", - "user_id": 42, - "product_name": "osparc", - } - ) - spy = mocker.patch( f"{_message_module.__name__}.submit_send_message_task", wraps=_message_module.submit_send_message_task, @@ -124,7 +115,9 @@ async def test_send_message_with_owner_metadata( response = await send_message( rpc_client, message=single_recipient_email_message, - owner_metadata=owner_metadata, + owner="webserver", + user_id=42, + product_name="osparc", ) assert isinstance(response, SendMessageResponse) assert response.task_or_group_uuid @@ -132,10 +125,9 @@ async def test_send_message_with_owner_metadata( spy.assert_awaited_once() call_kwargs = spy.call_args.kwargs - assert call_kwargs["owner_metadata"] == owner_metadata - assert call_kwargs["owner_metadata"].owner == "webserver" - assert call_kwargs["owner_metadata"].model_dump()["user_id"] == 42 - assert call_kwargs["owner_metadata"].model_dump()["product_name"] == "osparc" + assert call_kwargs["owner"] == "webserver" + assert call_kwargs["user_id"] == 42 + assert call_kwargs["product_name"] == "osparc" async def test_send_message_multiple_recipients( diff --git a/services/storage/src/simcore_service_storage/api/_worker_tasks/_files.py b/services/storage/src/simcore_service_storage/api/_worker_tasks/_files.py index 638a33da6149..215755707a51 100644 --- a/services/storage/src/simcore_service_storage/api/_worker_tasks/_files.py +++ b/services/storage/src/simcore_service_storage/api/_worker_tasks/_files.py @@ -1,4 +1,5 @@ import logging +from typing import Any from celery import Task # type: ignore[import-untyped] from celery_library.worker.app_server import get_app_server @@ -23,6 +24,7 @@ async def complete_upload_file( location_id: LocationID, file_id: StorageFileID, body: FileUploadCompletionBody, + **_kwargs: Any, ) -> FileMetaData: assert task_key # nosec with log_context( diff --git a/services/storage/src/simcore_service_storage/api/_worker_tasks/_paths.py b/services/storage/src/simcore_service_storage/api/_worker_tasks/_paths.py index c9c22440ea67..c39c76ea2004 100644 --- a/services/storage/src/simcore_service_storage/api/_worker_tasks/_paths.py +++ b/services/storage/src/simcore_service_storage/api/_worker_tasks/_paths.py @@ -1,5 +1,6 @@ import logging from pathlib import Path +from typing import Any from celery import Task # type: ignore[import-untyped] from celery_library.worker.app_server import get_app_server @@ -43,6 +44,7 @@ async def delete_paths( user_id: UserID, location_id: LocationID, paths: set[Path], + **_kwargs: Any, ) -> None: assert task_key # nosec with log_context( diff --git a/services/storage/src/simcore_service_storage/api/_worker_tasks/_simcore_s3.py b/services/storage/src/simcore_service_storage/api/_worker_tasks/_simcore_s3.py index b1ddd31b73ac..ecabd7f7e952 100644 --- a/services/storage/src/simcore_service_storage/api/_worker_tasks/_simcore_s3.py +++ b/services/storage/src/simcore_service_storage/api/_worker_tasks/_simcore_s3.py @@ -43,7 +43,7 @@ async def _task_progress_cb(task: Task, task_key: TaskKey, report: ProgressRepor async def deep_copy_files_from_project( - task: Task, task_key: TaskKey, user_id: UserID, body: FoldersBody + task: Task, task_key: TaskKey, user_id: UserID, body: FoldersBody, **_kwargs: Any ) -> dict[str, Any]: with log_context( _logger, diff --git a/services/storage/src/simcore_service_storage/api/rest/_files.py b/services/storage/src/simcore_service_storage/api/rest/_files.py index 9906fe09707d..394a08724e7a 100644 --- a/services/storage/src/simcore_service_storage/api/rest/_files.py +++ b/services/storage/src/simcore_service_storage/api/rest/_files.py @@ -15,11 +15,10 @@ FileUploadSchema, SoftCopyBody, ) -from models_library.celery import OwnerMetadata, TaskExecutionMetadata, TaskUUID +from models_library.celery import TaskExecutionMetadata, TaskUUID from models_library.generics import Envelope from models_library.projects_nodes_io import LocationID, StorageFileID from models_library.rabbitmq_messages import FileNotificationEventType -from models_library.users import UserID from pydantic import AnyUrl, ByteSize, TypeAdapter from servicelib.aiohttp import status from servicelib.celery.task_manager import TaskManager @@ -44,15 +43,6 @@ from .._worker_tasks._files import complete_upload_file as remote_complete_upload_file from .dependencies.celery import get_task_manager - -def _get_owner_metadata(*, user_id: UserID) -> OwnerMetadata: - _data = { - "owner": APP_NAME, - "user_id": user_id, - } - return OwnerMetadata.model_validate(_data) - - _logger = logging.getLogger(__name__) router = APIRouter( @@ -290,12 +280,11 @@ async def complete_upload_file( # if it returns slow we return a 202 - Accepted, the client will have to check later # for completeness - owner_metadata = _get_owner_metadata(user_id=query_params.user_id) task_uuid = await task_manager.submit_task( TaskExecutionMetadata( name=remote_complete_upload_file.__name__, ), - owner_metadata=owner_metadata, + owner=APP_NAME, user_id=query_params.user_id, location_id=location_id, file_id=file_id, @@ -342,16 +331,13 @@ async def is_completed_upload_file( # therefore we wait a bit to see if it completes fast and return a 204 # if it returns slow we return a 202 - Accepted, the client will have to check later # for completeness - owner_metadata = _get_owner_metadata(user_id=query_params.user_id) task_status = await task_manager.get_status( - owner_metadata=owner_metadata, task_or_group_uuid=TypeAdapter(TaskUUID).validate_python(future_id), ) # first check if the task is in the app if task_status.is_done: task_result = TypeAdapter(FileMetaData).validate_python( await task_manager.get_result( - owner_metadata=owner_metadata, task_or_group_uuid=TypeAdapter(TaskUUID).validate_python(future_id), ) ) diff --git a/services/storage/tests/unit/test_async_jobs_handlers_paths.py b/services/storage/tests/unit/test_async_jobs_handlers_paths.py index 54d12e0b0209..5e536f462e0a 100644 --- a/services/storage/tests/unit/test_async_jobs_handlers_paths.py +++ b/services/storage/tests/unit/test_async_jobs_handlers_paths.py @@ -23,7 +23,7 @@ from models_library.api_schemas_async_jobs.async_jobs import ( AsyncJobResult, ) -from models_library.celery import OwnerMetadata, TaskExecutionMetadata, Wildcard +from models_library.celery import TaskExecutionMetadata from models_library.products import ProductName from models_library.projects_nodes_io import LocationID, NodeID, SimcoreS3FileID from models_library.users import UserID @@ -38,11 +38,6 @@ type _IsFile = bool -class TestOwnerMetadata(OwnerMetadata): - user_id: int | Wildcard - product_name: str | Wildcard - - def _filter_and_group_paths_one_level_deeper(paths: list[Path], prefix: Path) -> list[tuple[Path, _IsFile]]: relative_paths = (path for path in paths if path.is_relative_to(prefix)) return sorted( @@ -70,15 +65,14 @@ async def _assert_compute_path_size( async_job = await submit_job( task_manager, execution_metadata=TaskExecutionMetadata(name="compute_path_size"), - owner_metadata=TestOwnerMetadata(user_id=user_id, product_name=product_name, owner="pytest_client_name"), - location_id=location_id, - path=path, + owner="pytest_client_name", user_id=user_id, product_name=product_name, + location_id=location_id, + path=path, ) async for job_composed_result in wait_and_get_job_result( task_manager, - owner_metadata=TestOwnerMetadata(user_id=user_id, product_name=product_name, owner="pytest_client_name"), job_id=async_job.job_id, stop_after=datetime.timedelta(seconds=120), ): @@ -104,14 +98,13 @@ async def _assert_delete_paths( async_job = await submit_job( task_manager, execution_metadata=TaskExecutionMetadata(name="delete_paths"), - owner_metadata=TestOwnerMetadata(user_id=user_id, product_name=product_name, owner="pytest_client_name"), - location_id=location_id, + owner="pytest_client_name", user_id=user_id, + location_id=location_id, paths=paths, ) async for job_composed_result in wait_and_get_job_result( task_manager, - owner_metadata=TestOwnerMetadata(user_id=user_id, product_name=product_name, owner="pytest_client_name"), job_id=async_job.job_id, stop_after=datetime.timedelta(seconds=120), ): diff --git a/services/storage/tests/unit/test_async_jobs_handlers_simcore_s3.py b/services/storage/tests/unit/test_async_jobs_handlers_simcore_s3.py index 20c3ea9082c6..51458f3b521c 100644 --- a/services/storage/tests/unit/test_async_jobs_handlers_simcore_s3.py +++ b/services/storage/tests/unit/test_async_jobs_handlers_simcore_s3.py @@ -36,7 +36,7 @@ ) from models_library.api_schemas_webserver.storage import PathToExport from models_library.basic_types import SHA256Str -from models_library.celery import OwnerMetadata, TaskExecutionMetadata +from models_library.celery import TaskExecutionMetadata from models_library.products import ProductName from models_library.projects_nodes_io import NodeID, NodeIDStr, SimcoreS3FileID from models_library.users import UserID @@ -65,12 +65,6 @@ pytest_simcore_ops_services_selection = ["adminer"] -class _TestOwnerMetadata(OwnerMetadata): - user_id: UserID - product_name: ProductName - owner: str = "PYTEST_CLIENT_NAME" - - async def _request_copy_folders( task_manager: TaskManager, user_id: UserID, @@ -85,23 +79,17 @@ async def _request_copy_folders( logging.INFO, f"Copying folders from {source_project['uuid']} to {dst_project['uuid']}", ) as ctx: - owner_metadata = _TestOwnerMetadata( - user_id=user_id, - product_name=product_name, - owner="PYTEST_CLIENT_NAME", - ) - async_job = await submit_job( task_manager, execution_metadata=TaskExecutionMetadata(name="deep_copy_files_from_project"), - owner_metadata=owner_metadata, + owner="PYTEST_CLIENT_NAME", user_id=user_id, + product_name=product_name, body=FoldersBody(source=source_project, destination=dst_project, nodes_map=nodes_map), ) async for async_job_result in wait_and_get_job_result( task_manager, - owner_metadata=owner_metadata, job_id=async_job.job_id, stop_after=stop_after, ): @@ -509,16 +497,10 @@ async def _request_start_export_data( logging.INFO, f"Data export form {paths_to_export=}", ) as ctx: - owner_metadata = _TestOwnerMetadata( - user_id=user_id, - product_name=product_name, - owner="PYTEST_CLIENT_NAME", - ) - async_job = await submit_job( task_manager, execution_metadata=TaskExecutionMetadata(name=task_name), - owner_metadata=owner_metadata, + owner="PYTEST_CLIENT_NAME", user_id=user_id, product_name=product_name, paths_to_export=paths_to_export, @@ -526,7 +508,6 @@ async def _request_start_export_data( async for async_job_result in wait_and_get_job_result( task_manager, - owner_metadata=owner_metadata, job_id=async_job.job_id, stop_after=stop_after, ): diff --git a/services/web/server/src/simcore_service_webserver/models.py b/services/web/server/src/simcore_service_webserver/models.py index 23a7f8082b7d..a9fa5c010a8c 100644 --- a/services/web/server/src/simcore_service_webserver/models.py +++ b/services/web/server/src/simcore_service_webserver/models.py @@ -1,6 +1,5 @@ from typing import Annotated -from models_library.celery import OwnerMetadata from models_library.products import ProductName from models_library.rest_base import RequestParameters from models_library.users import UserID @@ -9,7 +8,6 @@ from servicelib.aiohttp.request_keys import RQT_USERID_KEY from servicelib.rest_constants import X_CLIENT_SESSION_ID_HEADER -from ._meta import APP_NAME from .constants import RQ_PRODUCT_KEY type PhoneNumberStr = Annotated[ @@ -57,9 +55,3 @@ class ClientSessionHeaderParams(RequestParameters): model_config = ConfigDict( validate_by_name=True, ) - - -class WebServerOwnerMetadata(OwnerMetadata): - user_id: UserID | None - product_name: ProductName - owner: Annotated[str, StringConstraints(pattern=rf"^{APP_NAME}$"), Field(frozen=True)] = APP_NAME diff --git a/services/web/server/src/simcore_service_webserver/notifications/_service.py b/services/web/server/src/simcore_service_webserver/notifications/_service.py index 4ca4416b9ef3..3b9f150ece21 100644 --- a/services/web/server/src/simcore_service_webserver/notifications/_service.py +++ b/services/web/server/src/simcore_service_webserver/notifications/_service.py @@ -30,7 +30,7 @@ send_message_from_template as remote_send_message_from_template, ) -from ..models import WebServerOwnerMetadata +from .._meta import APP_NAME from ..products import products_service from ..rabbitmq import get_rabbitmq_rpc_client from ..users import users_service @@ -201,10 +201,9 @@ async def send_message( response = await remote_send_message( get_rabbitmq_rpc_client(app), message=_RPC_MESSAGE_ADAPTER.validate_python(message.model_dump()), - owner_metadata=WebServerOwnerMetadata( - user_id=user_id, - product_name=product_name, - ), + owner=APP_NAME, + user_id=user_id, + product_name=product_name, ) return response.task_or_group_uuid, response.task_name @@ -244,10 +243,9 @@ async def send_message_from_template( TemplateRef(channel=channel, template_name=template_name).model_dump() ), context=enriched_context, - owner_metadata=WebServerOwnerMetadata( - user_id=user_id, - product_name=product_name, - ), + owner=APP_NAME, + user_id=user_id, + product_name=product_name, ) return response.task_or_group_uuid, response.task_name diff --git a/services/web/server/src/simcore_service_webserver/storage/_rest.py b/services/web/server/src/simcore_service_webserver/storage/_rest.py index c2fc17dce157..9a450020976a 100644 --- a/services/web/server/src/simcore_service_webserver/storage/_rest.py +++ b/services/web/server/src/simcore_service_webserver/storage/_rest.py @@ -30,7 +30,7 @@ StorageLocationPathParams, StoragePathComputeSizeParams, ) -from models_library.celery import OwnerMetadata, TaskExecutionMetadata +from models_library.celery import TaskExecutionMetadata from models_library.products import ProductName from models_library.projects_nodes_io import LocationID from models_library.utils.change_case import camel_to_snake @@ -57,11 +57,11 @@ from servicelib.rest_responses import unwrap_envelope from yarl import URL -from .._meta import API_VTAG +from .._meta import API_VTAG, APP_NAME from ..celery import get_task_manager from ..constants import RQ_PRODUCT_KEY from ..login.decorators import login_required -from ..models import AuthenticatedRequestContext, WebServerOwnerMetadata +from ..models import AuthenticatedRequestContext from ..security.decorators import permission_required from ..tasks._controller._rest_exceptions import handle_rest_requests_exceptions from .schemas import StorageFileIDStr @@ -211,12 +211,7 @@ async def compute_path_size(request: web.Request) -> web.Response: execution_metadata=TaskExecutionMetadata( name=COMPUTE_PATH_SIZE_TASK_NAME, ), - owner_metadata=OwnerMetadata.model_validate( - WebServerOwnerMetadata( - user_id=req_ctx.user_id, - product_name=req_ctx.product_name, - ).model_dump() - ), + owner=APP_NAME, user_id=req_ctx.user_id, product_name=req_ctx.product_name, location_id=path_params.location_id, @@ -242,12 +237,7 @@ async def batch_delete_paths(request: web.Request): execution_metadata=TaskExecutionMetadata( name=DELETE_PATHS_TASK_NAME, ), - owner_metadata=OwnerMetadata.model_validate( - WebServerOwnerMetadata( - user_id=req_ctx.user_id, - product_name=req_ctx.product_name, - ).model_dump() - ), + owner=APP_NAME, user_id=req_ctx.user_id, location_id=path_params.location_id, paths=body.paths, @@ -503,12 +493,7 @@ class _PathParams(BaseModel): async_job_get = await submit_export_data( task_manager=get_task_manager(request.app), - owner_metadata=OwnerMetadata.model_validate( - WebServerOwnerMetadata( - user_id=req_ctx.user_id, - product_name=req_ctx.product_name, - ).model_dump() - ), + owner=APP_NAME, user_id=req_ctx.user_id, product_name=req_ctx.product_name, paths_to_export=body.paths, @@ -545,12 +530,7 @@ class _PathParams(BaseModel): execution_metadata=TaskExecutionMetadata( name=SEARCH_TASK_NAME, ), - owner_metadata=OwnerMetadata.model_validate( - WebServerOwnerMetadata( - user_id=req_ctx.user_id, - product_name=req_ctx.product_name, - ).model_dump() - ), + owner=APP_NAME, user_id=req_ctx.user_id, product_name=req_ctx.product_name, name_pattern=search_body.filters.name_pattern, diff --git a/services/web/server/src/simcore_service_webserver/storage/api.py b/services/web/server/src/simcore_service_webserver/storage/api.py index 1837dea15c8c..dcab2a1d5907 100644 --- a/services/web/server/src/simcore_service_webserver/storage/api.py +++ b/services/web/server/src/simcore_service_webserver/storage/api.py @@ -19,7 +19,7 @@ FoldersBody, PresignedLink, ) -from models_library.celery import OwnerMetadata, TaskExecutionMetadata +from models_library.celery import TaskExecutionMetadata from models_library.generics import Envelope from models_library.products import ProductName from models_library.projects import ProjectID @@ -30,8 +30,8 @@ from servicelib.logging_utils import log_context from yarl import URL +from .._meta import APP_NAME from ..celery import get_task_manager -from ..models import WebServerOwnerMetadata from ..projects.models import ProjectDict from ..projects.utils import NodesMap from .settings import StorageSettings, get_plugin_settings @@ -107,12 +107,9 @@ async def copy_data_folders_from_project( async for job_composed_result in submit_job_and_wait( task_manager, execution_metadata=TaskExecutionMetadata(name="deep_copy_files_from_project"), - owner_metadata=OwnerMetadata.model_validate( - WebServerOwnerMetadata( - user_id=user_id, - product_name=product_name, - ).model_dump() - ), + owner=APP_NAME, + user_id=user_id, + product_name=product_name, body=TypeAdapter(FoldersBody).validate_python( { "source": source_project, @@ -121,7 +118,6 @@ async def copy_data_folders_from_project( } ), stop_after=datetime.timedelta(seconds=_TOTAL_TIMEOUT_TO_COPY_DATA_SECS), - user_id=user_id, ): yield job_composed_result diff --git a/services/web/server/src/simcore_service_webserver/tasks/_controller/_rest.py b/services/web/server/src/simcore_service_webserver/tasks/_controller/_rest.py index c8ea7a6f031f..a279f9d54ced 100644 --- a/services/web/server/src/simcore_service_webserver/tasks/_controller/_rest.py +++ b/services/web/server/src/simcore_service_webserver/tasks/_controller/_rest.py @@ -8,7 +8,6 @@ TaskResult, TaskStatus, ) -from models_library.celery import OwnerMetadata from servicelib.aiohttp import status from servicelib.aiohttp.long_running_tasks.server import ( get_long_running_manager, @@ -22,11 +21,11 @@ ) from servicelib.long_running_tasks import lrt_api -from ..._meta import API_VTAG +from ..._meta import API_VTAG, APP_NAME from ...celery import get_task_manager from ...login.decorators import login_required from ...long_running_tasks.plugin import webserver_request_context_decorator -from ...models import AuthenticatedRequestContext, WebServerOwnerMetadata +from ...models import AuthenticatedRequestContext from .. import _tasks_service from ._rest_exceptions import handle_rest_requests_exceptions from ._rest_schemas import TaskPathParams, TaskStreamQueryParams, TaskStreamResponse @@ -58,12 +57,9 @@ async def get_async_jobs(request: web.Request) -> web.Response: tasks = await _tasks_service.list_tasks( get_task_manager(request.app), - owner_metadata=OwnerMetadata.model_validate( - WebServerOwnerMetadata( - user_id=_req_ctx.user_id, - product_name=_req_ctx.product_name, - ).model_dump() - ), + owner=APP_NAME, + user_id=_req_ctx.user_id, + product_name=_req_ctx.product_name, ) return create_data_response( @@ -102,12 +98,6 @@ async def get_async_job_status(request: web.Request) -> web.Response: task_status = await _tasks_service.get_task_status( get_task_manager(request.app), - owner_metadata=OwnerMetadata.model_validate( - WebServerOwnerMetadata( - user_id=_req_ctx.user_id, - product_name=_req_ctx.product_name, - ).model_dump() - ), task_uuid=_path_params.task_id, ) @@ -139,12 +129,6 @@ async def cancel_async_job(request: web.Request) -> web.Response: await _tasks_service.cancel_task( get_task_manager(request.app), - owner_metadata=OwnerMetadata.model_validate( - WebServerOwnerMetadata( - user_id=_req_ctx.user_id, - product_name=_req_ctx.product_name, - ).model_dump() - ), task_uuid=_path_params.task_id, ) @@ -163,12 +147,6 @@ async def get_async_job_result(request: web.Request) -> web.Response: task_result = await _tasks_service.get_task_result( get_task_manager(request.app), - owner_metadata=OwnerMetadata.model_validate( - WebServerOwnerMetadata( - user_id=_req_ctx.user_id, - product_name=_req_ctx.product_name, - ).model_dump() - ), task_uuid=_path_params.task_id, ) @@ -191,12 +169,6 @@ async def get_async_job_stream(request: web.Request) -> web.Response: task_result, end = await _tasks_service.pull_task_stream_items( get_task_manager(request.app), - owner_metadata=OwnerMetadata.model_validate( - WebServerOwnerMetadata( - user_id=_req_ctx.user_id, - product_name=_req_ctx.product_name, - ).model_dump() - ), task_uuid=_path_params.task_id, limit=_query_params.limit, ) diff --git a/services/web/server/src/simcore_service_webserver/tasks/_tasks_service.py b/services/web/server/src/simcore_service_webserver/tasks/_tasks_service.py index 2ed318aecafa..06583ca7b20c 100644 --- a/services/web/server/src/simcore_service_webserver/tasks/_tasks_service.py +++ b/services/web/server/src/simcore_service_webserver/tasks/_tasks_service.py @@ -20,7 +20,6 @@ JobSchedulerError, ) from models_library.celery import ( - OwnerMetadata, TaskState, TaskStatus, TaskStreamItem, @@ -39,12 +38,10 @@ async def cancel_task( task_manager: TaskManager, *, - owner_metadata: OwnerMetadata, task_uuid: TaskUUID, ): try: await task_manager.cancel( - owner_metadata=owner_metadata, task_or_group_uuid=task_uuid, ) except TaskOrGroupNotFoundError as exc: @@ -56,18 +53,15 @@ async def cancel_task( async def get_task_result( task_manager: TaskManager, *, - owner_metadata: OwnerMetadata, task_uuid: TaskUUID, ) -> AsyncJobResult: try: status = await task_manager.get_status( - owner_metadata=owner_metadata, task_or_group_uuid=task_uuid, ) if not status.is_done: raise JobNotDoneError(job_id=task_uuid) result = await task_manager.get_result( - owner_metadata=owner_metadata, task_or_group_uuid=task_uuid, ) except TaskOrGroupNotFoundError as exc: @@ -99,12 +93,10 @@ async def get_task_result( async def get_task_status( task_manager: TaskManager, *, - owner_metadata: OwnerMetadata, task_uuid: TaskUUID, ) -> AsyncJobStatus: try: task_status = await task_manager.get_status( - owner_metadata=owner_metadata, task_or_group_uuid=task_uuid, ) except TaskOrGroupNotFoundError as exc: @@ -122,13 +114,11 @@ async def get_task_status( async def pull_task_stream_items( task_manager: TaskManager, *, - owner_metadata: OwnerMetadata, task_uuid: TaskUUID, limit: int = 50, ) -> tuple[list[TaskStreamItem], bool]: try: results, end, last_update = await task_manager.pull_task_stream_items( - owner_metadata=owner_metadata, task_uuid=task_uuid, limit=limit, ) @@ -148,11 +138,15 @@ async def pull_task_stream_items( async def list_tasks( task_manager: TaskManager, *, - owner_metadata: OwnerMetadata, + owner: str, + user_id: int | None = None, + product_name: str | None = None, ) -> list[AsyncJobGet]: try: tasks = await task_manager.list_tasks( - owner_metadata=owner_metadata, + owner=owner, + user_id=user_id, + product_name=product_name, ) except TaskManagerError as exc: raise JobSchedulerError(exc=f"{exc}") from exc diff --git a/services/web/server/tests/unit/with_dbs/04/notifications/test_notifications_service.py b/services/web/server/tests/unit/with_dbs/04/notifications/test_notifications_service.py index 3b9ea87f4b88..7cd58f44f58a 100644 --- a/services/web/server/tests/unit/with_dbs/04/notifications/test_notifications_service.py +++ b/services/web/server/tests/unit/with_dbs/04/notifications/test_notifications_service.py @@ -220,10 +220,10 @@ async def test_send_message_from_template_passes_correct_template_ref( assert len(addressing.to) == 1 assert addressing.to[0].email == external_contacts[0].email - # Verify owner_metadata - owner_metadata = call_kwargs["owner_metadata"] - assert owner_metadata.user_id == logged_user["id"] - assert owner_metadata.product_name == "osparc" + # Verify owner params + assert call_kwargs["owner"] is not None + assert call_kwargs["user_id"] == logged_user["id"] + assert call_kwargs["product_name"] == "osparc" async def test_send_message_from_template_unsupported_channel( From b4e4bf1871568beeac7efb98881c79e186846d45 Mon Sep 17 00:00:00 2001 From: Giancarlo Romeo Date: Thu, 16 Apr 2026 13:27:19 +0200 Subject: [PATCH 06/14] fix --- .../src/pytest_simcore/helpers/async_jobs_server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/pytest-simcore/src/pytest_simcore/helpers/async_jobs_server.py b/packages/pytest-simcore/src/pytest_simcore/helpers/async_jobs_server.py index d3f2b9e6d392..71cd2835e846 100644 --- a/packages/pytest-simcore/src/pytest_simcore/helpers/async_jobs_server.py +++ b/packages/pytest-simcore/src/pytest_simcore/helpers/async_jobs_server.py @@ -91,8 +91,8 @@ async def list_jobs( assert rabbitmq_rpc_client assert rpc_namespace assert owner - assert user_id is not None or user_id is None - assert product_name is not None or product_name is None + assert user_id is None or isinstance(user_id, int) + assert product_name is None or isinstance(product_name, str) assert filter_ is not None if self.exception is not None: From 283c8c4102f8384332453c0faa0ff98851ab5d3c Mon Sep 17 00:00:00 2001 From: Giancarlo Romeo Date: Thu, 16 Apr 2026 15:55:25 +0200 Subject: [PATCH 07/14] fix --- .../src/celery_library/backends/_redis.py | 44 +++++-------------- .../tests/unit/test_redis_store.py | 41 ++++++++++++++--- 2 files changed, 46 insertions(+), 39 deletions(-) diff --git a/packages/celery-library/src/celery_library/backends/_redis.py b/packages/celery-library/src/celery_library/backends/_redis.py index 515ecf0a9ee7..0d04a692f709 100644 --- a/packages/celery-library/src/celery_library/backends/_redis.py +++ b/packages/celery-library/src/celery_library/backends/_redis.py @@ -2,7 +2,6 @@ import logging from dataclasses import dataclass from datetime import UTC, datetime, timedelta -from itertools import product from typing import TYPE_CHECKING, Final from models_library.celery import ( @@ -74,41 +73,20 @@ def _concrete_owner_fields( return sorted(pairs) -def _build_redis_index_key_for_query( +def _build_redis_index_key_for_owner( owner: str, user_id: int | None, product_name: str | None, ) -> str: - """Build the single sorted-set key used to answer a list_tasks query. + """Build a single sorted-set index key from the concrete owner fields. - Concrete fields are kept; missing (None) fields are omitted — - the sorted set was pre-populated for every field subset at creation time. + Used for both creation and querying — each task lives in exactly one + index, so the caller must supply the same fields at query time. """ parts = [f"{k}={v}" for k, v in _concrete_owner_fields(owner, user_id, product_name)] return _build_redis_index_key(_CELERY_TASK_DELIMTATOR.join(parts)) -def _build_redis_index_keys_for_creation( - owner: str, - user_id: int | None, - product_name: str | None, -) -> list[str]: - """Generate all 2^n sorted-set index keys for the given owner fields. - - Every subset of the concrete fields gets its own key so that any query - specifying a subset of those fields can be answered with a single - sorted-set lookup. - """ - fields = _concrete_owner_fields(owner, user_id, product_name) - - keys: list[str] = [] - for mask in product((False, True), repeat=len(fields)): - selected = [(k, v) for (k, v), include in zip(fields, mask, strict=True) if include] - suffix = _CELERY_TASK_DELIMTATOR.join(f"{k}={v}" for k, v in selected) if selected else "" - keys.append(_build_redis_index_key(suffix)) - return keys - - @dataclass(frozen=True) class RedisTaskStore: _redis_client_sdk: RedisClientSDK @@ -133,8 +111,8 @@ async def create_group( key=_CELERY_TASK_EXEC_METADATA_KEY, value=execution_metadata.model_dump_json(), ) - for index_key in _build_redis_index_keys_for_creation(owner, user_id, product_name): - pipe.zadd(index_key, {group_key: index_score}) + index_key = _build_redis_index_key_for_owner(owner, user_id, product_name) + pipe.zadd(index_key, {group_key: index_score}) # group sub-tasks: store hash only, no ZSET index (filtered out in list_tasks) for task_key, (task_execution_metadata, _) in zip(task_keys, execution_metadata.tasks, strict=True): @@ -168,8 +146,8 @@ async def create_task( key=_CELERY_TASK_EXEC_METADATA_KEY, value=execution_metadata.model_dump_json(), ) - for index_key in _build_redis_index_keys_for_creation(owner, user_id, product_name): - pipe.zadd(index_key, {task_key: index_score}) + index_key = _build_redis_index_key_for_owner(owner, user_id, product_name) + pipe.zadd(index_key, {task_key: index_score}) await pipe.execute() await self._redis_client_sdk.redis.expire( @@ -225,7 +203,7 @@ async def list_tasks( user_id: int | None = None, product_name: str | None = None, ) -> list[Task]: - owner_index_key = _build_redis_index_key_for_query(owner, user_id, product_name) + owner_index_key = _build_redis_index_key_for_owner(owner, user_id, product_name) raw_members = await self._redis_client_sdk.redis.zrange(owner_index_key, 0, -1) if not raw_members: @@ -273,8 +251,8 @@ async def remove_task( ) -> None: pipe = self._redis_client_sdk.redis.pipeline() pipe.delete(_build_redis_task_or_group_key(task_key)) - for index_key in _build_redis_index_keys_for_creation(owner, user_id, product_name): - pipe.zrem(index_key, task_key) + index_key = _build_redis_index_key_for_owner(owner, user_id, product_name) + pipe.zrem(index_key, task_key) await pipe.execute() async def remove_task_hash(self, task_key: TaskKey) -> None: diff --git a/packages/celery-library/tests/unit/test_redis_store.py b/packages/celery-library/tests/unit/test_redis_store.py index 313afbfd83a1..ed9eccd9050a 100644 --- a/packages/celery-library/tests/unit/test_redis_store.py +++ b/packages/celery-library/tests/unit/test_redis_store.py @@ -74,13 +74,15 @@ def _forbid_scan_iter(*_args: object, **_kwargs: object) -> None: assert tasks[0].uuid == UUID(task_key) -async def test_list_tasks_with_wildcard_filtering( +async def test_list_tasks_filters_by_exact_owner_fields( redis_task_store: RedisTaskStore, ): - user_id = 42 owner = "test-svc" + product = "osparc" + user_id = 42 expected_tasks: list[Task] = [] + # 5 tasks with same owner + product_name + user_id for _ in range(5): task_key = _faker.uuid4() await redis_task_store.create_task( @@ -88,7 +90,7 @@ async def test_list_tasks_with_wildcard_filtering( TaskExecutionMetadata(name="my_task"), owner=owner, user_id=user_id, - product_name=_faker.word(), + product_name=product, expiry=timedelta(minutes=5), ) expected_tasks.append( @@ -98,6 +100,7 @@ async def test_list_tasks_with_wildcard_filtering( ) ) + # 3 tasks with a different user id for _ in range(3): task_key = _faker.uuid4() await redis_task_store.create_task( @@ -105,15 +108,41 @@ async def test_list_tasks_with_wildcard_filtering( TaskExecutionMetadata(name="my_task"), owner=owner, user_id=_faker.pyint(min_value=100, max_value=200), - product_name=_faker.word(), + product_name=product, expiry=timedelta(minutes=5), ) - # Query by owner + user_id only (product_name=None acts as wildcard) - tasks = await redis_task_store.list_tasks(owner=owner, user_id=user_id) + tasks = await redis_task_store.list_tasks(owner=owner, user_id=user_id, product_name=product) assert {t.uuid for t in tasks} == {t.uuid for t in expected_tasks} +async def test_list_tasks_with_no_user_id( + redis_task_store: RedisTaskStore, +): + """Internal notifications have no user_id.""" + owner = "notifications-svc" + product = "osparc" + + task_key = _faker.uuid4() + await redis_task_store.create_task( + task_key, + TaskExecutionMetadata(name="send_notification"), + owner=owner, + user_id=None, + product_name=product, + expiry=timedelta(minutes=5), + ) + + # Query without user_id matches + tasks = await redis_task_store.list_tasks(owner=owner, product_name=product) + assert len(tasks) == 1 + assert tasks[0].uuid == UUID(task_key) + + # Query with a user_id does NOT match + tasks = await redis_task_store.list_tasks(owner=owner, user_id=1, product_name=product) + assert len(tasks) == 0 + + async def test_remove_task_cleans_up_zset_indexes( redis_task_store: RedisTaskStore, ): From 1d2bf43a6355f7c5636f407f7aa0517562267cdc Mon Sep 17 00:00:00 2001 From: Giancarlo Romeo Date: Fri, 17 Apr 2026 16:03:42 +0200 Subject: [PATCH 08/14] fix --- packages/models-library/src/models_library/celery.py | 1 - 1 file changed, 1 deletion(-) diff --git a/packages/models-library/src/models_library/celery.py b/packages/models-library/src/models_library/celery.py index 0ef6d7b18f5e..5b1e5dca2906 100644 --- a/packages/models-library/src/models_library/celery.py +++ b/packages/models-library/src/models_library/celery.py @@ -172,7 +172,6 @@ async def remove_task_hash(self, task_key: TaskKey) -> None: Stale index entries are cleaned lazily by ``list_tasks``. Use this when the owner info is unavailable (e.g. cancel, ephemeral cleanup). """ - ... async def set_task_progress( self, From b4ed19cc0778806540844a69e10e0a9b6f6b9c6a Mon Sep 17 00:00:00 2001 From: Giancarlo Romeo Date: Thu, 23 Apr 2026 12:47:14 +0200 Subject: [PATCH 09/14] fix --- .../src/celery_library/backends/_redis.py | 15 +++++ .../tests/unit/test_redis_store.py | 56 ++++++++++++++++++- 2 files changed, 70 insertions(+), 1 deletion(-) diff --git a/packages/celery-library/src/celery_library/backends/_redis.py b/packages/celery-library/src/celery_library/backends/_redis.py index 0d04a692f709..0f089146a682 100644 --- a/packages/celery-library/src/celery_library/backends/_redis.py +++ b/packages/celery-library/src/celery_library/backends/_redis.py @@ -91,6 +91,19 @@ def _build_redis_index_key_for_owner( class RedisTaskStore: _redis_client_sdk: RedisClientSDK + async def _refresh_index_key_ttl(self, index_key: str, expiry: timedelta) -> None: + """Ensure the index key TTL is at least ``expiry``. + + ``EXPIRE ... GT`` cannot be used because Redis treats a key with no TTL + as having infinite TTL for the purpose of the GT comparison, so it would + leave the index key persistent on first creation. We instead read the + current TTL and only extend it. + """ + current_ttl = await self._redis_client_sdk.redis.ttl(index_key) + # ttl returns: -2 (no key), -1 (no TTL), or remaining seconds. + if current_ttl < int(expiry.total_seconds()): + await self._redis_client_sdk.redis.expire(index_key, expiry) + async def create_group( self, group_key: GroupKey, @@ -126,6 +139,7 @@ async def create_group( redis_group_key, expiry, ) + await self._refresh_index_key_ttl(index_key, expiry) async def create_task( self, @@ -154,6 +168,7 @@ async def create_task( redis_key, expiry, ) + await self._refresh_index_key_ttl(index_key, expiry) async def get_task_metadata(self, task_key: TaskKey) -> ExecutionMetadata | None: redis_key = _build_redis_task_or_group_key(task_key) diff --git a/packages/celery-library/tests/unit/test_redis_store.py b/packages/celery-library/tests/unit/test_redis_store.py index ed9eccd9050a..21270503e4ac 100644 --- a/packages/celery-library/tests/unit/test_redis_store.py +++ b/packages/celery-library/tests/unit/test_redis_store.py @@ -6,7 +6,10 @@ import pytest from celery_library.backends import RedisTaskStore -from celery_library.backends._redis import _build_redis_task_or_group_key +from celery_library.backends._redis import ( + _build_redis_index_key_for_owner, + _build_redis_task_or_group_key, +) from faker import Faker from models_library.celery import ( Task, @@ -186,3 +189,54 @@ async def test_stale_zset_entries_are_pruned_on_list( assert await redis_task_store.list_tasks(owner="test-svc", user_id=10004, product_name="osparc") == [] # Second list confirms the ZSET is clean assert await redis_task_store.list_tasks(owner="test-svc", user_id=10004, product_name="osparc") == [] + + +async def test_index_key_has_ttl_and_only_grows( + redis_task_store: RedisTaskStore, + redis_client_sdk: RedisClientSDK, +): + """The ZSET index key must have a TTL bounded by the longest member expiry, + so it cannot grow unboundedly when ``list_tasks`` is never called. + """ + owner, user_id, product = "test-svc", 10005, "osparc" + index_key = _build_redis_index_key_for_owner(owner, user_id, product) + redis = redis_client_sdk.redis + + short_expiry = timedelta(minutes=1) + long_expiry = timedelta(hours=1) + + # First add with a short expiry: index key gets a TTL ~ short_expiry + await redis_task_store.create_task( + _faker.uuid4(), + TaskExecutionMetadata(name="my_task"), + owner=owner, + user_id=user_id, + product_name=product, + expiry=short_expiry, + ) + ttl_after_short = await redis.ttl(index_key) + assert 0 < ttl_after_short <= int(short_expiry.total_seconds()) + + # Add a longer-lived task: TTL must be extended to cover it + await redis_task_store.create_task( + _faker.uuid4(), + TaskExecutionMetadata(name="my_task"), + owner=owner, + user_id=user_id, + product_name=product, + expiry=long_expiry, + ) + ttl_after_long = await redis.ttl(index_key) + assert ttl_after_long >= int(long_expiry.total_seconds()) - 1 + + # Add another short-lived task: TTL must NOT shrink below the long task's lifetime + await redis_task_store.create_task( + _faker.uuid4(), + TaskExecutionMetadata(name="my_task"), + owner=owner, + user_id=user_id, + product_name=product, + expiry=short_expiry, + ) + ttl_after_second_short = await redis.ttl(index_key) + assert ttl_after_second_short >= int(long_expiry.total_seconds()) - 5 From 1dfdd010914c15d4445eb85feceb2bf32489677f Mon Sep 17 00:00:00 2001 From: Giancarlo Romeo Date: Thu, 23 Apr 2026 12:50:47 +0200 Subject: [PATCH 10/14] fix --- .../src/celery_library/_task_manager.py | 3 ++ .../src/celery_library/backends/_redis.py | 15 +++++--- .../tests/unit/test_redis_store.py | 36 +++++++++++++++++++ .../src/models_library/celery.py | 1 + 4 files changed, 51 insertions(+), 4 deletions(-) diff --git a/packages/celery-library/src/celery_library/_task_manager.py b/packages/celery-library/src/celery_library/_task_manager.py index 4ee0c3ebb6ce..08d7eb37c9d4 100644 --- a/packages/celery-library/src/celery_library/_task_manager.py +++ b/packages/celery-library/src/celery_library/_task_manager.py @@ -162,6 +162,9 @@ async def submit_group( user_id=user_id, product_name=product_name, expiry=expiry, + # Group sub-tasks are listed via their parent group, so they + # must not appear in the owner's task index. + index=False, ) group_result: GroupResult = group(sigs).apply_async() diff --git a/packages/celery-library/src/celery_library/backends/_redis.py b/packages/celery-library/src/celery_library/backends/_redis.py index 0f089146a682..ed14abe472c9 100644 --- a/packages/celery-library/src/celery_library/backends/_redis.py +++ b/packages/celery-library/src/celery_library/backends/_redis.py @@ -127,7 +127,10 @@ async def create_group( index_key = _build_redis_index_key_for_owner(owner, user_id, product_name) pipe.zadd(index_key, {group_key: index_score}) - # group sub-tasks: store hash only, no ZSET index (filtered out in list_tasks) + # Sub-task hashes are stored so they can be looked up by UUID, but they + # are intentionally NOT added to the owner index: the parent group is + # the listable unit. ``create_task`` is called with ``index=False`` for + # each sub-task in ``TaskManager.submit_group``. for task_key, (task_execution_metadata, _) in zip(task_keys, execution_metadata.tasks, strict=True): pipe.hset( name=_build_redis_task_or_group_key(task_key), @@ -150,6 +153,7 @@ async def create_task( user_id: int | None = None, product_name: str | None = None, expiry: timedelta, + index: bool = True, ) -> None: redis_key = _build_redis_task_or_group_key(task_key) index_score = datetime.now(tz=UTC).timestamp() @@ -160,15 +164,18 @@ async def create_task( key=_CELERY_TASK_EXEC_METADATA_KEY, value=execution_metadata.model_dump_json(), ) - index_key = _build_redis_index_key_for_owner(owner, user_id, product_name) - pipe.zadd(index_key, {task_key: index_score}) + index_key: str | None = None + if index: + index_key = _build_redis_index_key_for_owner(owner, user_id, product_name) + pipe.zadd(index_key, {task_key: index_score}) await pipe.execute() await self._redis_client_sdk.redis.expire( redis_key, expiry, ) - await self._refresh_index_key_ttl(index_key, expiry) + if index_key is not None: + await self._refresh_index_key_ttl(index_key, expiry) async def get_task_metadata(self, task_key: TaskKey) -> ExecutionMetadata | None: redis_key = _build_redis_task_or_group_key(task_key) diff --git a/packages/celery-library/tests/unit/test_redis_store.py b/packages/celery-library/tests/unit/test_redis_store.py index 21270503e4ac..583cb6a4c255 100644 --- a/packages/celery-library/tests/unit/test_redis_store.py +++ b/packages/celery-library/tests/unit/test_redis_store.py @@ -240,3 +240,39 @@ async def test_index_key_has_ttl_and_only_grows( ) ttl_after_second_short = await redis.ttl(index_key) assert ttl_after_second_short >= int(long_expiry.total_seconds()) - 5 + + +async def test_create_task_with_index_false_skips_owner_index( + redis_task_store: RedisTaskStore, + redis_client_sdk: RedisClientSDK, +): + """Group sub-tasks must not appear in the owner index (they are listed via + their parent group), so ``create_task(index=False)`` must skip the zadd. + """ + owner, user_id, product = "test-svc", 10006, "osparc" + indexed_key = _faker.uuid4() + sub_task_key = _faker.uuid4() + + await redis_task_store.create_task( + indexed_key, + TaskExecutionMetadata(name="my_task"), + owner=owner, + user_id=user_id, + product_name=product, + expiry=timedelta(minutes=5), + ) + await redis_task_store.create_task( + sub_task_key, + TaskExecutionMetadata(name="my_sub_task"), + owner=owner, + user_id=user_id, + product_name=product, + expiry=timedelta(minutes=5), + index=False, + ) + + # Sub-task hash exists (so status/result lookups by UUID still work)... + assert await redis_client_sdk.redis.exists(_build_redis_task_or_group_key(sub_task_key)) == 1 + # ...but only the indexed task appears in the owner listing. + listed = await redis_task_store.list_tasks(owner=owner, user_id=user_id, product_name=product) + assert {t.uuid for t in listed} == {UUID(indexed_key)} diff --git a/packages/models-library/src/models_library/celery.py b/packages/models-library/src/models_library/celery.py index 5b1e5dca2906..f369f98967be 100644 --- a/packages/models-library/src/models_library/celery.py +++ b/packages/models-library/src/models_library/celery.py @@ -141,6 +141,7 @@ async def create_task( user_id: int | None = None, product_name: str | None = None, expiry: timedelta, + index: bool = True, ) -> None: ... async def task_or_group_exists(self, task_or_group_key: TaskKey | GroupKey) -> bool: ... From a9f2d739054cc2d297d34e77009ad4befc8d9957 Mon Sep 17 00:00:00 2001 From: Giancarlo Romeo Date: Thu, 23 Apr 2026 13:00:04 +0200 Subject: [PATCH 11/14] type --- packages/celery-library/src/celery_library/_task_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/celery-library/src/celery_library/_task_manager.py b/packages/celery-library/src/celery_library/_task_manager.py index 08d7eb37c9d4..cebe3eb2616e 100644 --- a/packages/celery-library/src/celery_library/_task_manager.py +++ b/packages/celery-library/src/celery_library/_task_manager.py @@ -400,7 +400,7 @@ async def _get_group_status(self, group_uuid: GroupUUID) -> GroupStatus: # Get task UUIDs from the group result # AsyncResult objects have .id attribute containing the task key (UUID string) - task_uuids = [ + task_uuids: list[TaskUUID] = [ TypeAdapter(TaskUUID).validate_python(async_result.id) for async_result in (group_result.results or []) ] From 0c525bf8bf40ad8f4f6f1eeb9f0523fcf10e6190 Mon Sep 17 00:00:00 2001 From: Giancarlo Romeo Date: Thu, 23 Apr 2026 13:25:44 +0200 Subject: [PATCH 12/14] fix --- .../src/celery_library/_task_manager.py | 18 ++++++++---------- .../tests/unit/task_manager/conftest.py | 12 ++++++------ .../src/servicelib/celery/task_manager.py | 2 -- 3 files changed, 14 insertions(+), 18 deletions(-) diff --git a/packages/celery-library/src/celery_library/_task_manager.py b/packages/celery-library/src/celery_library/_task_manager.py index cebe3eb2616e..a1e8d943ea5e 100644 --- a/packages/celery-library/src/celery_library/_task_manager.py +++ b/packages/celery-library/src/celery_library/_task_manager.py @@ -202,14 +202,18 @@ async def submit_task( execution_metadata: TaskExecutionMetadata, *, owner: str, - user_id: int | None = None, - product_name: str | None = None, **task_params, ) -> TaskUUID: + # NOTE: ``user_id`` and ``product_name`` are not dedicated parameters of + # this method. When present in ``task_params`` they are observed (not + # consumed) to build the owner index key, then forwarded to the worker + # as part of ``task_params`` like any other task argument. + user_id: int | None = task_params.get("user_id") + product_name: str | None = task_params.get("product_name") with log_context( _logger, logging.DEBUG, - msg=f"Submit {execution_metadata.name=}: {owner=} {user_id=} {product_name=} {task_params=}", + msg=f"Submit {execution_metadata.name=}: {owner=} {task_params=}", ): task_uuid, task_key = self._create_task_ids() expiry = self._get_task_expiry(execution_metadata) @@ -223,16 +227,10 @@ async def submit_task( product_name=product_name, expiry=expiry, ) - # Forward non-None owner fields so workers can access user_id/product_name - _owner_kwargs: dict[str, Any] = {} - if user_id is not None: - _owner_kwargs["user_id"] = user_id - if product_name is not None: - _owner_kwargs["product_name"] = product_name self._app.send_task( execution_metadata.name, task_id=task_key, - kwargs={"task_key": task_key} | _owner_kwargs | task_params, + kwargs={"task_key": task_key} | task_params, queue=execution_metadata.queue, ) except CeleryError as exc: diff --git a/packages/celery-library/tests/unit/task_manager/conftest.py b/packages/celery-library/tests/unit/task_manager/conftest.py index 56fc4ef335d9..32cf5c8b4c61 100644 --- a/packages/celery-library/tests/unit/task_manager/conftest.py +++ b/packages/celery-library/tests/unit/task_manager/conftest.py @@ -6,7 +6,7 @@ import time from collections.abc import Callable from random import randint -from typing import Final +from typing import Any, Final import pytest from celery import Celery, Task # pylint: disable=no-name-in-module @@ -59,7 +59,7 @@ def sleep_for(seconds: float) -> None: return "archive.zip" -def fake_file_processor(task: Task, task_key: TaskKey, files: list[str]) -> str: +def fake_file_processor(task: Task, task_key: TaskKey, files: list[str], **_kwargs: Any) -> str: assert task_key assert task.name _logger.info("Calling _fake_file_processor") @@ -73,14 +73,14 @@ class MyError(OsparcErrorMixin, Exception): msg_template = "Something strange happened: {msg}" -def failure_task(task: Task, task_key: TaskKey) -> None: +def failure_task(task: Task, task_key: TaskKey, **_kwargs: Any) -> None: assert task_key assert task msg = "BOOM!" raise MyError(msg=msg) -async def dreamer_task(task: Task, task_key: TaskKey) -> list[int]: +async def dreamer_task(task: Task, task_key: TaskKey, **_kwargs: Any) -> list[int]: numbers = [] for _ in range(30): numbers.append(randint(1, 90)) # noqa: S311 @@ -88,7 +88,7 @@ async def dreamer_task(task: Task, task_key: TaskKey) -> list[int]: return numbers -def streaming_results_task(task: Task, task_key: TaskKey, num_results: int = 5) -> str: +def streaming_results_task(task: Task, task_key: TaskKey, num_results: int = 5, **_kwargs: Any) -> str: assert task_key assert task.name @@ -117,7 +117,7 @@ async def _stream_results(sleep_interval: float) -> None: _RATE_LIMITED_NOOP_RATE: Final[str] = "6/m" # NOTE: 6 tasks per minute -def noop_task(task: Task, task_key: TaskKey) -> str: +def noop_task(task: Task, task_key: TaskKey, **_kwargs: Any) -> str: assert task_key return "done" diff --git a/packages/service-library/src/servicelib/celery/task_manager.py b/packages/service-library/src/servicelib/celery/task_manager.py index 1c0c74321231..318d6eac1365 100644 --- a/packages/service-library/src/servicelib/celery/task_manager.py +++ b/packages/service-library/src/servicelib/celery/task_manager.py @@ -32,8 +32,6 @@ async def submit_task( execution_metadata: TaskExecutionMetadata, *, owner: str, - user_id: int | None = None, - product_name: str | None = None, **task_params, ) -> TaskUUID: ... From 863b2ed9d39ba2e7b9d413ac7cc2f53e78d528d5 Mon Sep 17 00:00:00 2001 From: Giancarlo Romeo Date: Thu, 23 Apr 2026 13:31:07 +0200 Subject: [PATCH 13/14] fix --- .../src/celery_library/_task_manager.py | 18 +++++++++++------- .../src/servicelib/celery/task_manager.py | 2 ++ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/packages/celery-library/src/celery_library/_task_manager.py b/packages/celery-library/src/celery_library/_task_manager.py index a1e8d943ea5e..e649a3de7e97 100644 --- a/packages/celery-library/src/celery_library/_task_manager.py +++ b/packages/celery-library/src/celery_library/_task_manager.py @@ -202,22 +202,26 @@ async def submit_task( execution_metadata: TaskExecutionMetadata, *, owner: str, + user_id: int | None = None, + product_name: str | None = None, **task_params, ) -> TaskUUID: - # NOTE: ``user_id`` and ``product_name`` are not dedicated parameters of - # this method. When present in ``task_params`` they are observed (not - # consumed) to build the owner index key, then forwarded to the worker - # as part of ``task_params`` like any other task argument. - user_id: int | None = task_params.get("user_id") - product_name: str | None = task_params.get("product_name") with log_context( _logger, logging.DEBUG, - msg=f"Submit {execution_metadata.name=}: {owner=} {task_params=}", + msg=f"Submit {execution_metadata.name=}: {owner=} {user_id=} {product_name=} {task_params=}", ): task_uuid, task_key = self._create_task_ids() expiry = self._get_task_expiry(execution_metadata) + # Merge owner fields into the worker payload alongside ``task_key``. + # This avoids forcing every worker to declare them while still + # forwarding them to those that do. + if user_id is not None: + task_params.setdefault("user_id", user_id) + if product_name is not None: + task_params.setdefault("product_name", product_name) + try: await self._task_store.create_task( task_key, diff --git a/packages/service-library/src/servicelib/celery/task_manager.py b/packages/service-library/src/servicelib/celery/task_manager.py index 318d6eac1365..1c0c74321231 100644 --- a/packages/service-library/src/servicelib/celery/task_manager.py +++ b/packages/service-library/src/servicelib/celery/task_manager.py @@ -32,6 +32,8 @@ async def submit_task( execution_metadata: TaskExecutionMetadata, *, owner: str, + user_id: int | None = None, + product_name: str | None = None, **task_params, ) -> TaskUUID: ... From f2fbbdac03ec8d1f582db385fa8f0fb480ae5e11 Mon Sep 17 00:00:00 2001 From: Giancarlo Romeo Date: Thu, 23 Apr 2026 13:42:15 +0200 Subject: [PATCH 14/14] fix --- packages/celery-library/tests/unit/test_async_jobs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/celery-library/tests/unit/test_async_jobs.py b/packages/celery-library/tests/unit/test_async_jobs.py index 3c8d4661c32a..0879e9a8f1da 100644 --- a/packages/celery-library/tests/unit/test_async_jobs.py +++ b/packages/celery-library/tests/unit/test_async_jobs.py @@ -63,13 +63,13 @@ async def _process_action(action: str, payload: Any) -> Any: return None -def sync_job(task: Task, task_key: TaskKey, action: Action, payload: Any) -> Any: +def sync_job(task: Task, task_key: TaskKey, action: Action, payload: Any, **_kwargs: Any) -> Any: _ = task _ = task_key return asyncio.run(_process_action(action, payload)) -async def async_job(task: Task, task_key: TaskKey, action: Action, payload: Any) -> Any: +async def async_job(task: Task, task_key: TaskKey, action: Action, payload: Any, **_kwargs: Any) -> Any: _ = task _ = task_key return await _process_action(action, payload)