"""Celery helpers: a task base class that drives ``async def`` task bodies to
completion, plus a kombu JSON serializer that tags objects exposing
``to_dict()`` so they can be rebuilt on the consumer side."""

import asyncio
import importlib
import inspect
from typing import (
    Any,
    Callable,
    Optional,
    Union,
)

import celery.app.task
import kombu.utils.json


def shared_task(func: Optional[Callable[..., Any]] = None, **kwargs: Any) -> Any:
    """Create a Celery shared task that uses :class:`Task` as its base class.

    Both decorator forms are supported::

        @shared_task
        async def a(): ...

        @shared_task(name='b')
        def b(): ...

    :param func: the decorated function when used as bare ``@shared_task``;
        ``None`` when used as ``@shared_task(...)``.
    :param kwargs: options forwarded to ``celery.shared_task``. ``base``
        defaults to :class:`Task`; an explicit ``base`` overrides it (it should
        subclass :class:`Task` to keep the coroutine handling).
    :returns: whatever ``celery.shared_task`` returns: the task when ``func``
        is given, otherwise a decorator.
    """
    # Default the base class but let callers override it (previously an
    # explicit ``base=`` raised TypeError for a duplicate keyword argument).
    options = {'base': Task, **kwargs}
    if func is not None:
        # Bare ``@shared_task``: the function must be handed to Celery,
        # otherwise the decorated name would be bound to a decorator instead
        # of a task.
        return celery.shared_task(func, **options)
    return celery.shared_task(**options)


def is_async() -> bool:
    """Return True if an asyncio event loop is running in the current thread."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        return False
    return True


async def _await_any(awaitable: Any) -> Any:
    """Wrap an arbitrary awaitable in a coroutine so ``asyncio.run`` accepts it."""
    return await awaitable


class Task(celery.app.task.Task):
    """Celery task base class that transparently runs awaitable results.

    When the task body returns an awaitable and no event loop is running in
    the current thread, the awaitable is run to completion with
    :func:`asyncio.run` and its result is returned. Inside a running loop the
    awaitable is returned unchanged for the caller to await.

    Also hosts the encoder/decoder pair registered by
    :func:`kombu_register_json_dataclass`.
    """

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        res = super().__call__(*args, **kwargs)
        if inspect.isawaitable(res) and not is_async():
            # asyncio.run() is documented to take a coroutine; wrap other
            # awaitables (e.g. objects defining __await__) so they run too.
            if not inspect.iscoroutine(res):
                res = _await_any(res)
            return asyncio.run(res)
        return res

    @classmethod
    def _loads(
        cls,
        data_str: Optional[Union[str, bytes]] = None,
        data: Optional[Any] = None,
    ) -> Any:
        """Decode a payload produced by :meth:`_dumps`.

        :param data_str: serialized JSON text. When given, it is parsed with
            ``kombu.utils.json.loads`` and replaces ``data``.
        :param data: an already-parsed value (used by the recursive calls).
        :returns: the decoded value. Every ``'dataclass_json'`` marker dict is
            replaced by ``<resolved class>(**marker['data'])``. Lists, and the
            values of other dicts, are decoded recursively.

        SECURITY(review): the marker names an arbitrary module and attribute
        path, which is imported and then called with the payload's fields.
        Anyone able to publish messages can therefore invoke any importable
        callable inside the worker. Use this decoder only with trusted
        producers, or restrict the allowed modules/classes.
        """
        if data_str is not None:
            data = kombu.utils.json.loads(data_str)

        if isinstance(data, dict) and data.get('type') == 'dataclass_json':
            target: Any = importlib.import_module(data['module'])
            # ``_class`` holds a __qualname__, so nested classes are reached by
            # walking the dotted path one attribute at a time.
            for attr in data['_class'].split('.'):
                target = getattr(target, attr)
            # NOTE(review): the fields are passed exactly as to_dict() emitted
            # them, so nested dataclasses arrive as plain dicts.
            return target(**data['data'])
        if isinstance(data, list):
            return [cls._loads(data=item) for item in data]
        if isinstance(data, dict):
            return {key: cls._loads(data=value) for key, value in data.items()}
        return data

    @classmethod
    def _dumps(
        cls,
        data: Any,
        need_native: Optional[bool] = None,
    ) -> Any:
        """Encode ``data`` for transport, tagging objects that have ``to_dict()``.

        An object exposing ``to_dict()`` becomes a marker dict recording its
        module, ``__qualname__`` and ``to_dict()`` output, which :meth:`_loads`
        uses to rebuild it. Lists and tuples (both emitted as lists) and dict
        values are encoded recursively. Anything else is passed through.

        :param data: the value to encode.
        :param need_native: when truthy, return the JSON-compatible Python
            structure. Otherwise (the default) return it serialized with
            ``kombu.utils.json.dumps``.
        """
        if hasattr(data, 'to_dict'):
            # NOTE(review): duck-typing on ``to_dict`` also matches
            # non-dataclass types, and classes defined inside functions
            # (``<locals>`` in the qualname) cannot be resolved by _loads.
            # Confirm this is acceptable for the payloads in use.
            native: Any = dict(
                type='dataclass_json',
                module=data.__class__.__module__,
                _class=data.__class__.__qualname__,
                data=data.to_dict(),
            )
        elif isinstance(data, (list, tuple)):
            native = [cls._dumps(item, need_native=True) for item in data]
        elif isinstance(data, dict):
            native = {
                key: cls._dumps(value, need_native=True)
                for key, value in data.items()
            }
        else:
            native = data

        return native if need_native else kombu.utils.json.dumps(native)


def kombu_register_json_dataclass() -> None:
    """Register the ``'json-dataclass'`` kombu serializer.

    The serializer uses :meth:`Task._dumps` as encoder and :meth:`Task._loads`
    as decoder, with content type ``application/json`` and encoding
    ``utf-8``.

    NOTE(review): kombu's built-in ``json`` serializer uses the same content
    type. If kombu keys decoders by content type, this registration would make
    ``Task._loads`` decode every ``application/json`` message. Confirm that is
    intended.
    """
    import kombu.serialization

    kombu.serialization.register(
        'json-dataclass',
        Task._dumps,
        Task._loads,
        content_type='application/json',
        content_encoding='utf-8',
    )