diff --git a/python/tasks/tiktok/celery.py b/python/tasks/tiktok/celery.py
index 27452ed..bbccb68 100644
--- a/python/tasks/tiktok/celery.py
+++ b/python/tasks/tiktok/celery.py
@@ -1,5 +1,5 @@
 from .config import tiktok_config
-from .utils import kombu_register_json_dataclass
+from .utils import kombu_register_json_dataclass, Backend
 import celery
 import redis
 
@@ -9,13 +9,17 @@ c = tiktok_config()
 app = celery.Celery(
     __name__,
     broker=c.celery_broker,
-    result_backend=c.celery_result_backend,
+    #result_backend=c.celery_result_backend,
+    #backend=Backend,
+    #result_backend=c.celery_result_backend,
     accept_content=['json-dataclass'],
     task_serializer='json-dataclass',
     result_serializer='json-dataclass',
+    task_track_started=True,
 )
 
 kombu_register_json_dataclass()
 
+app._backend = Backend(app=app, url=c.celery_result_backend)
 app.autodiscover_tasks(c.celery_imports)
 
diff --git a/python/tasks/tiktok/utils.py b/python/tasks/tiktok/utils.py
index e5f6e89..079cfcb 100644
--- a/python/tasks/tiktok/utils.py
+++ b/python/tasks/tiktok/utils.py
@@ -1,4 +1,6 @@
 import celery.app.task
+import celery.backends.redis
+import datetime
 import os
 import asyncio
 import inspect
@@ -36,7 +38,11 @@ def shared_task(func: Optional[Callable[Any, Any]]=None, **kwargs: Any) -> Any:
     # return decorator
     #else:
     #    return decorator2
-    return celery.shared_task(base=Task, **kwargs)
+    return celery.shared_task(
+        base=Task,
+        track_started=True,
+        **kwargs
+    )
 
 def is_async() -> bool:
     try:
@@ -45,6 +51,61 @@
     except RuntimeError:
         return False
 
+class Backend(celery.backends.redis.RedisBackend):
+    def __init__(self, *args, **kwargs):
+        return super().__init__(*args, **kwargs)
+
+    def _store_result(self, task_id, result, state,
+                      traceback=None, request=None, **kwargs):
+        meta = self._get_result_meta(result=result, state=state,
+            traceback=traceback, request=request, task_id=task_id,)
+        meta['task_id'] = celery.backends.base.bytes_to_str(task_id)
+
+        # Retrieve metadata from the backend, if the status
+        # is a success then we ignore any following update to the state.
+        # This solves a task deduplication issue because of network
+        # partitioning or lost workers. This issue involved a race condition
+        # making a lost task overwrite the last successful result in the
+        # result backend.
+        current_meta = self._get_task_meta_for(task_id)
+
+        if current_meta['status'] == celery.states.SUCCESS:
+            return result
+
+        try:
+            self._set_with_state(self.get_key_for_task(task_id), self.encode(meta), state)
+        except celery.backends.base.BackendStoreError as ex:
+            raise celery.backends.base.BackendStoreError(str(ex), state=state, task_id=task_id) from ex
+
+        return result
+
+
+    def _get_result_meta(
+        self,
+        *args,
+        task_id: Optional[str]=None,
+        state: Optional[str]=None,
+        **kwargs
+    ):
+        current_meta = None
+
+        if not task_id is None:
+            current_meta = self._get_task_meta_for(task_id)
+
+        res = super()._get_result_meta(*args, state=state, **kwargs)
+
+        if state == celery.states.STARTED:
+            if not 'date_started' in res:
+                res['date_started'] = datetime.datetime.now()
+
+        for k in ['date_started',]:
+            if k in current_meta:
+                res[k] = current_meta[k]
+
+        print([args, state, kwargs, res])
+
+        return res
+
 class Task(celery.app.task.Task):
     def __call__(self, *args, **kwargs) -> Any:
         res = super().__call__(*args, **kwargs)
@@ -56,6 +117,35 @@ class Task(celery.app.task.Task):
     #def apply(self, *args, **kwargs):
     #    return self.__call__(*args, **kwargs)
 
+    #def before_start(self, task_id: str, *args, **kwargs):
+    #    self.update_state(None, celery.states.STARTED)
+    #
+    #    meta = self.backend._get_task_meta_for(task_id)
+    #
+    #    assert isinstance(meta, dict)
+    #
+    #    if not 'date_started' in meta:
+    #        meta['date_started']
+    #
+    #    self._set_with_state(self.get_key_for_task(task_id), self.encode(meta), state)
+
+    #def update_state(
+    #    self,
+    #    *args,
+    #    state: Optional[str]=None,
+    #    meta: Optional[dict[str,Any]]=None,
+    #    **kwargs
+    #):
+    #    print(['blah', meta, state])
+    #
+    #    if not meta is None:
+    #        logger.info(json.dumps(dict(state=state)))
+    #
+    #        if not 'date_started' in meta and state == celery.states.STARTED:
+    #            meta['date_started'] = datetime.datetime.now()
+    #
+    #    return super().update_state(*args, state=state, meta=meta, **kwargs)
+
     @classmethod
     def _loads(
         cls,