"""Celery helpers: async-aware tasks, a Redis backend that records the
task start time, and a JSON serializer that round-trips ``to_dict`` objects.
"""

import asyncio
import datetime
import importlib
import inspect
import logging
import os
from typing import (
    Any,
    Callable,
    Iterable,
    Optional,
)

import celery.app.task
import celery.backends.base
import celery.backends.redis
import celery.states
import kombu.utils.json

logger = logging.getLogger(__name__)


def shared_task(
    func: Optional[Callable[..., Any]] = None,
    **kwargs: Any
) -> Any:
    """Create a Celery shared task based on :class:`Task`.

    Usable both as ``@shared_task`` and as ``@shared_task(**options)``.

    Args:
        func: The function to register when used as a bare decorator.
            When ``None`` a decorator is returned instead.
        **kwargs: Options forwarded to ``celery.shared_task``.
            ``track_started`` defaults to ``True`` unless given explicitly.

    Returns:
        The registered task when ``func`` is given, otherwise a decorator
        that registers the function it is applied to.
    """
    # A default rather than a hard-coded keyword, so that a caller passing
    # ``track_started`` does not hit a duplicate-keyword TypeError.
    kwargs.setdefault('track_started', True)

    decorator = celery.shared_task(
        base=Task,
        **kwargs
    )

    if func is None:
        return decorator

    return decorator(func)


def is_async() -> bool:
    """Return ``True`` when called while an asyncio event loop is running."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # get_running_loop raises RuntimeError when no loop is running.
        return False

    return True


class Backend(celery.backends.redis.RedisBackend):
    """Redis result backend that adds a ``date_started`` field to task meta."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _store_result(
        self,
        task_id,
        result,
        state,
        traceback=None,
        request=None,
        **kwargs
    ):
        """Store the result meta for ``task_id`` unless it already succeeded.

        Returns:
            ``result`` unchanged.

        Raises:
            celery.backends.base.BackendStoreError: When writing to the
                backend fails; re-raised with ``state`` and ``task_id``.
        """
        meta = self._get_result_meta(
            result=result,
            state=state,
            traceback=traceback,
            request=request,
            task_id=task_id,
        )
        meta['task_id'] = celery.backends.base.bytes_to_str(task_id)

        # Retrieve metadata from the backend, if the status
        # is a success then we ignore any following update to the state.
        # This solves a task deduplication issue because of network
        # partitioning or lost workers. This issue involved a race condition
        # making a lost task overwrite the last successful result in the
        # result backend.
        current_meta = self._get_task_meta_for(task_id)

        if current_meta['status'] == celery.states.SUCCESS:
            return result

        try:
            self._set_with_state(
                self.get_key_for_task(task_id),
                self.encode(meta),
                state,
            )
        except celery.backends.base.BackendStoreError as ex:
            raise celery.backends.base.BackendStoreError(
                str(ex),
                state=state,
                task_id=task_id,
            ) from ex

        return result

    def _get_result_meta(
        self,
        *args,
        task_id: Optional[str] = None,
        state: Optional[str] = None,
        **kwargs
    ):
        """Build the result meta and attach ``date_started``.

        When ``state`` is ``STARTED`` and the meta has no ``date_started``
        yet, the current local time is recorded. When ``task_id`` is given,
        a ``date_started`` already stored for that task takes precedence,
        so later state updates keep the original start time.

        Args:
            *args: Positional arguments forwarded to the parent method.
            task_id: Id of the task whose stored meta should be consulted;
                ``None`` skips the lookup.
            state: Task state, also forwarded to the parent method.
            **kwargs: Keyword arguments forwarded to the parent method.

        Returns:
            The meta dict produced by the parent, possibly extended with
            ``date_started``.
        """
        res = super()._get_result_meta(*args, state=state, **kwargs)

        if state == celery.states.STARTED and 'date_started' not in res:
            res['date_started'] = datetime.datetime.now()

        # Without a task id there is no stored meta to merge from.
        if task_id is not None:
            current_meta = self._get_task_meta_for(task_id)

            for k in ('date_started',):
                if k in current_meta:
                    res[k] = current_meta[k]

        logger.debug(
            'result meta: args=%r state=%r kwargs=%r res=%r',
            args,
            state,
            kwargs,
            res,
        )

        return res


class Task(celery.app.task.Task):
    """Task base class that runs coroutine tasks to completion and provides
    the ``json-dataclass`` encode and decode functions.
    """

    def __call__(self, *args, **kwargs) -> Any:
        """Run the task; drive an awaitable result with ``asyncio.run``.

        When an event loop is already running the awaitable is returned
        as is, since ``asyncio.run`` cannot be used from a running loop.
        """
        res = super().__call__(*args, **kwargs)

        if inspect.isawaitable(res) and not is_async():
            return asyncio.run(res)

        return res

    @classmethod
    def _loads(
        cls,
        data_str: Optional[str] = None,
        data: Optional[Any] = None,
    ) -> Any:
        """Decode a payload produced by :meth:`_dumps`.

        Args:
            data_str: JSON text to parse. When given it replaces ``data``.
            data: Already parsed native data, used for recursion.

        Returns:
            The decoded value. Dicts tagged with
            ``type == 'dataclass_json'`` are rebuilt by calling the named
            class with the stored ``data`` as keyword arguments; lists and
            other dicts are decoded element by element.
        """
        if data_str is not None:
            data = kombu.utils.json.loads(data_str)

        if isinstance(data, dict) and data.get('type') == 'dataclass_json':
            # SECURITY NOTE(review): the module and class come from the
            # payload, so this imports and instantiates whatever the sender
            # names. Only safe if the broker content is trusted.
            c = importlib.import_module(data['module'])

            # ``_class`` is a qualified name, so nested classes are
            # resolved attribute by attribute.
            for current_name in data['_class'].split('.'):
                c = getattr(c, current_name)

            return c(**data['data'])

        if isinstance(data, list):
            return [cls._loads(data=o) for o in data]

        if isinstance(data, dict):
            return {k: cls._loads(data=v) for k, v in data.items()}

        return data

    @classmethod
    def _dumps(
        cls,
        data: Any,
        need_native: Optional[bool] = None,
    ) -> Any:
        """Encode ``data``, tagging objects that expose ``to_dict``.

        Args:
            data: Value to encode. Objects with a ``to_dict`` attribute are
                stored as a tagged dict holding their module, qualified
                class name and ``to_dict()`` output. Lists, tuples and
                dicts are encoded recursively; tuples become lists.
            need_native: When true, return the native structure instead of
                JSON text. ``None`` is treated as ``False``.

        Returns:
            JSON text, or the native structure when ``need_native`` is true.
        """
        if need_native is None:
            need_native = False

        if hasattr(data, 'to_dict'):
            native = dict(
                type='dataclass_json',
                module=data.__class__.__module__,
                _class=data.__class__.__qualname__,
                data=data.to_dict(),
            )
        elif isinstance(data, (list, tuple)):
            native = [cls._dumps(o, need_native=True) for o in data]
        elif isinstance(data, dict):
            native = {
                k: cls._dumps(v, need_native=True)
                for k, v in data.items()
            }
        else:
            native = data

        if need_native:
            return native

        return kombu.utils.json.dumps(native)


def kombu_register_json_dataclass():
    """Register the ``json-dataclass`` serializer with kombu."""
    import kombu.serialization

    kombu.serialization.register(
        'json-dataclass',
        Task._dumps,
        Task._loads,
        content_type='application/json',
        content_encoding='utf-8',
    )


def grid_of_videos(paths: Iterable[str]) -> Any:
    """Build a one-row ipywidgets grid with one video per path.

    Args:
        paths: Paths of existing video files. Each is embedded by URL as
            ``/files/<path>``.

    Returns:
        A ``GridspecLayout`` with one ``Output`` widget per path.

    Raises:
        AssertionError: When one of the paths does not exist.
    """
    # Imported lazily so the module does not require the notebook stack.
    from ipywidgets import Output, GridspecLayout
    from IPython import display

    # Materialize once: the column count is needed up front, and the
    # argument may be a one-shot iterable without ``len``.
    paths = list(paths)

    grid = GridspecLayout(1, len(paths))

    for i, path in enumerate(paths):
        assert os.path.exists(path)

        out = Output()

        with out:
            display.display(display.Video(
                url='/files/%s' % path,
            ))

        grid[0, i] = out

    return grid