# Celery integration utilities: an async-aware shared_task decorator, a Redis
# result backend that preserves timing/name metadata, and a dataclass-aware
# JSON serializer for kombu.
import celery.app.task
|
|
import celery.backends.redis
|
|
import json
|
|
import datetime
|
|
import os
|
|
import asyncio
|
|
import inspect
|
|
import importlib
|
|
import kombu.utils.json
|
|
from typing import (
|
|
Any,
|
|
Optional,
|
|
Callable,
|
|
Iterable,
|
|
)
|
|
from .config import logger_setup
|
|
|
|
logger = logger_setup(__name__)
|
|
|
|
def shared_task(func: Optional[Callable[..., Any]]=None, **kwargs: Any) -> Any:
    """Async-aware drop-in replacement for ``celery.shared_task``.

    Registers the wrapped function as a Celery task with ``base=Task``
    (so coroutine results are awaited, see ``Task.__call__``) and
    ``track_started=True``.  Before registering, any task already present
    under the same generated name in an active app is removed, so
    re-importing / re-decorating a module replaces the task instead of
    keeping a stale registration.

    Usable both bare (``@shared_task``) and with options
    (``@shared_task(queue=...)``).
    """
    def decorator(func2: Callable[..., Any]) -> Any:
        nonlocal func

        if func is None:
            func = func2

        # Drop a previously registered task with the same generated name
        # from every active app so this (re-)decoration wins cleanly.
        for a in celery._state._get_active_apps():
            name = a.gen_task_name(func.__name__, func.__module__)
            if name in a.tasks:
                logger.info(json.dumps(dict(
                    name=name,
                    a=str(a),
                    action='derigester_task',
                )))
                a.tasks.pop(name)

        return celery.shared_task(
            base=Task,
            track_started=True,
            **kwargs
        )(func)

    if func is None:
        # Called with options: @shared_task(...) — hand back the decorator.
        return decorator
    else:
        # Called bare: @shared_task — decorate immediately.
        # (Original passed undefined *args/**kwargs here: NameError at runtime.)
        return decorator(func)
def is_async() -> bool:
    """Report whether the caller is executing inside a running event loop."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No event loop is running in this thread.
        return False
    return True
class Backend(celery.backends.redis.RedisBackend):
    """Redis result backend that preserves custom timing/name metadata.

    Extends the stock Redis backend so that ``date_started``,
    ``_date_done`` and ``_task_name`` fields, once stored in a task's
    result meta, survive later state updates — and so that a recorded
    SUCCESS result is never overwritten by a late/duplicate update.
    """

    def _store_result(self, task_id, result, state,
            traceback=None, request=None, **kwargs):
        """Store *result* for *task_id* unless SUCCESS is already recorded.

        Mirrors the upstream implementation, but builds the meta through
        the overridden ``_get_result_meta`` (passing ``task_id`` so fields
        already persisted for this task can be merged back in).
        """
        meta = self._get_result_meta(result=result, state=state,
            traceback=traceback, request=request, task_id=task_id,)
        meta['task_id'] = celery.backends.base.bytes_to_str(task_id)

        # Retrieve metadata from the backend: if the status is already a
        # success we ignore any following update to the state.  This solves
        # a task deduplication issue caused by network partitioning or lost
        # workers — a race condition could let a lost task overwrite the
        # last successful result in the result backend.
        current_meta = self._get_task_meta_for(task_id)

        if current_meta['status'] == celery.states.SUCCESS:
            return result

        try:
            self._set_with_state(self.get_key_for_task(task_id), self.encode(meta), state)
        except celery.backends.base.BackendStoreError as ex:
            raise celery.backends.base.BackendStoreError(str(ex), state=state, task_id=task_id) from ex

        return result

    def _get_result_meta(
        self,
        *args,
        task_id: Optional[str]=None,
        state: Optional[str]=None,
        request: Optional[Any]=None,
        **kwargs
    ):
        """Build the result meta dict, adding task-name and timing fields.

        On top of the upstream meta this:
        - records the task name from *request* as ``_task_name``,
        - stamps ``date_started`` when the task enters STARTED,
        - stamps ``_date_done`` when the task reaches a ready state,
        - re-applies any of those fields already persisted for *task_id*,
          so earlier timestamps are not clobbered by later updates.
        """
        current_meta = None
        if task_id is not None:
            current_meta = self._get_task_meta_for(task_id)

        res = super()._get_result_meta(*args, state=state, request=request, **kwargs)

        if request is not None and '_task_name' not in res:
            res['_task_name'] = request.task

        # NOTE(review): naive local-time timestamps; confirm whether
        # consumers expect timezone-aware (UTC) datetimes.
        if state == celery.states.STARTED and 'date_started' not in res:
            res['date_started'] = datetime.datetime.now()

        if state in celery.states.READY_STATES and '_date_done' not in res:
            res['_date_done'] = datetime.datetime.now()

        # Original iterated membership tests against current_meta even when
        # it was still None (TypeError whenever task_id is None) — guard it.
        if current_meta is not None:
            for k in ['date_started', '_date_done', '_task_name']:
                if k in current_meta:
                    res[k] = current_meta[k]

        return res
|
class Task(celery.app.task.Task):
    """Celery task base class with coroutine support and dataclass JSON.

    ``__call__`` transparently drives coroutine task bodies to completion;
    ``_dumps`` / ``_loads`` implement the 'json-dataclass' serializer that
    round-trips objects exposing ``to_dict`` (dataclasses_json-style)
    through plain JSON.
    """

    def __call__(self, *args, **kwargs) -> Any:
        """Invoke the task body, awaiting it when it returns a coroutine.

        When no event loop is running, an awaitable result is executed via
        ``asyncio.run``; inside a running loop the awaitable is returned
        untouched for the caller to await.
        """
        res = super().__call__(*args, **kwargs)
        if inspect.isawaitable(res) and not is_async():
            return asyncio.run(res)
        return res

    @classmethod
    def _loads(
        cls,
        data_str: Optional[str]=None,
        data: Optional[Any]=None,
    ) -> Any:
        """Inverse of ``_dumps``: decode JSON and revive tagged dataclasses.

        Accepts either raw JSON text (*data_str*) or an already-decoded
        value (*data*).  Dicts tagged ``type == 'dataclass_json'`` are
        re-instantiated by importing their recorded ``module`` and walking
        the dotted ``_class`` path; lists and dicts decode recursively.
        """
        if data_str is not None:
            data = kombu.utils.json.loads(data_str)

        if isinstance(data, dict) and data.get('type') == 'dataclass_json':
            module_name = data['module']
            class_names = data['_class'].split('.')

            m = importlib.import_module(module_name)

            # Resolve nested classes: 'Outer.Inner' -> getattr chain.
            c = m
            for current_name in class_names:
                c = getattr(c, current_name)

            return c(**data['data'])
        elif isinstance(data, list):
            return [cls._loads(data=o) for o in data]
        elif isinstance(data, dict):
            return {k: cls._loads(data=v) for k, v in data.items()}
        else:
            return data

    @classmethod
    def _dumps(
        cls,
        data: Any,
        need_native: Optional[bool]=None,
    ) -> Any:
        """Serialize *data* to JSON, tagging dataclass-like objects.

        Objects exposing ``to_dict`` become a tagged dict recording their
        module and qualified class name so ``_loads`` can revive them.
        Containers are converted recursively.  When *need_native* is true
        the still-decoded native structure is returned instead of a JSON
        string — used by the recursive calls.
        """
        if need_native is None:
            need_native = False

        if hasattr(data, 'to_dict'):
            native = dict(
                type='dataclass_json',
                module=data.__class__.__module__,
                _class=data.__class__.__qualname__,
                data=data.to_dict(),
            )
        elif isinstance(data, (list, tuple)):
            # NOTE: tuples deliberately round-trip as lists (JSON has no tuple).
            native = [cls._dumps(o, need_native=True,) for o in data]
        elif isinstance(data, dict):
            native = {k: cls._dumps(v, need_native=True,) for k, v in data.items()}
        else:
            native = data

        if need_native:
            return native
        return kombu.utils.json.dumps(native)
def kombu_register_json_dataclass():
    """Make the 'json-dataclass' serializer available to kombu/celery.

    Wires ``Task._dumps`` / ``Task._loads`` in as the encoder/decoder
    pair under the name 'json-dataclass', transported as UTF-8 JSON.
    """
    import kombu.serialization

    encode, decode = Task._dumps, Task._loads

    kombu.serialization.register(
        'json-dataclass',
        encode,
        decode,
        content_type='application/json',
        content_encoding='utf-8',
    )
def grid_of_videos(paths: Iterable[str]) -> Any:
    """Return an ipywidgets grid with one Jupyter video player per path.

    Every path must exist on disk (asserted).  Videos are served through
    the notebook's ``/files/`` endpoint rather than embedded, laid out in
    a single row.
    """
    from ipywidgets import Output, GridspecLayout
    from IPython import display

    # Materialize first: the annotation admits any iterable (e.g. a
    # generator), but len() below requires a sequence — the original
    # crashed on non-sized iterables.
    paths = list(paths)

    grid = GridspecLayout(1, len(paths))

    for i, path in enumerate(paths):
        assert os.path.exists(path)

        out = Output()
        with out:
            display.display(display.Video(
                url='/files/%s' % path,
                #embed=True
            ))

        grid[0, i] = out

    return grid