diff --git a/python/tasks/tiktok/tasks.py b/python/tasks/tiktok/tasks.py index 18c5802..adde5a7 100644 --- a/python/tasks/tiktok/tasks.py +++ b/python/tasks/tiktok/tasks.py @@ -1,4 +1,5 @@ import logging +import collections import enum import dataclasses import dataclasses_json @@ -385,8 +386,8 @@ def audio_get() -> audio_get_t.res_t: @shared_task() def process_graph( nodes: dict[str, Any], - data_deps: dict(str, Iterable[str]), - execution_deps: dict(str, Iterable[str]), + data_deps: dict[str, Iterable[str]], + execution_deps: dict[str, Iterable[str]], ) -> Any: import networkx @@ -412,29 +413,47 @@ def process_graph( pending_ids = set() active_queue = collections.deque() - ordered_nodes = networkx.topological_sort(g_execution) + ordered_nodes = list(networkx.topological_sort(g_execution)) node_id = 0 - while node_id < len(ordered_tasks): - node = ordered_tasks[node_id] - if any([ - v in task_ids and task_ids[v] in pending_ids - for v in g_execution.predecessors(node) - ]): - task_id = active_queue.popleft() - try: - result = task.backend.wait_for(task_id, interval=0.1,) - if result.state == celery.states.SUCCESS: - done_ids.add(result.id) - continue - except: - error_ids.add(task.id) - logger.error(json.dumps(dict( - msg=traceback.format_exc(), - ))) + def wait_task() -> bool: + import ipdb + ipdb.set_trace() + + task_id = active_queue.popleft() + task = celery.result.AsyncResult(task_id) + + try: + task.wait() + if task.status == celery.states.SUCCESS: + done_ids.add(task_id) + return True + except: + error_ids.add(task_id) + logger.error(json.dumps(dict( + msg=traceback.format_exc(), + ))) + return False + finally: + pending_ids.remove(task_id) + + while node_id < len(ordered_nodes) or len(pending_ids) > 0: + if node_id < len(ordered_nodes): + node = ordered_nodes[node_id] + else: + node = None + + if ( + (len(pending_ids) > 0 and node_id == len(ordered_nodes)) or + any([ + v in task_ids and task_ids[v] in pending_ids + for v in g_execution.predecessors(node) + ]) + ): + if wait_task(): + continue + else: break - finally: - del pending_ids(task.id) else: args = [ celery.result.AsyncResult( @@ -445,7 +464,7 @@ def process_graph( task = nodes[node].apply_async(*args) task_ids[node] = task.id pending_ids.add(task.id) - active + active_queue.append(task.id) del args del task node_id += 1 diff --git a/python/tasks/tiktok/utils.py b/python/tasks/tiktok/utils.py index d76d46c..c27e38d 100644 --- a/python/tasks/tiktok/utils.py +++ b/python/tasks/tiktok/utils.py @@ -1,5 +1,6 @@ import celery.app.task import celery.backends.redis +import json import datetime import os import asyncio @@ -12,6 +13,9 @@ from typing import ( Callable, Iterable, ) +from .config import logger_setup + +logger = logger_setup(__name__) def shared_task(func: Optional[Callable[Any, Any]]=None, **kwargs: Any) -> Any: #@celery.shared_task( @@ -38,11 +42,32 @@ def shared_task(func: Optional[Callable[Any, Any]]=None, **kwargs: Any) -> Any: # return decorator #else: # return decorator2 - return celery.shared_task( - base=Task, - track_started=True, - **kwargs - ) + def decorator(func2): + nonlocal func + + if func is None: + func = func2 + + for a in celery._state._get_active_apps(): + name = a.gen_task_name(func.__name__, func.__module__) + if name in a.tasks: + logger.info(json.dumps(dict( + name=name, + a=str(a), + action='derigester_task', + ))) + a.tasks.pop(name) + + return celery.shared_task( + base=Task, + track_started=True, + **kwargs + )(func) + + if func is None: + return decorator + else: + return decorator(func, *args, **kwargs) def is_async() -> bool: try: