# Provenance: added to python/tasks/tiktok/tasks.py by commit
# e27e5af370665ece17d1737ec14834b3e1d40362 ("[~] Refactor",
# Siarhei Siniak, 2024-07-07).


@shared_task()
def process_graph(
    nodes: dict[str, Any],
    data_deps: dict[str, Iterable[str]],
    execution_deps: dict[str, Iterable[str]],
) -> Any:
    """Dispatch a dependency graph of tasks in topological order.

    Two directed graphs are built over the keys of *nodes*:

    * the data graph has an edge ``a -> b`` for every ``a`` listed in
      ``data_deps[b]``; the results of ``a`` are passed to ``b`` as
      positional arguments;
    * the execution graph contains every data edge plus an edge
      ``a -> b`` for every ``a`` listed in ``execution_deps[b]``; it only
      constrains ordering.

    Nodes are visited in a topological order of the execution graph. A
    node is dispatched with ``apply_async`` once none of its execution
    predecessors is pending; until then the oldest dispatched task is
    awaited. Scheduling stops at the first task that fails.

    Args:
        nodes: Maps a node name to an object exposing ``apply_async``
            (presumably a Celery task or signature -- confirm with the
            caller).
        data_deps: Maps a node name to the names of the nodes whose
            results it consumes. Arguments are passed in the order in
            which the data graph yields the predecessors.
        execution_deps: Maps a node name to the names of the nodes that
            must finish before it starts, without passing data.

    Returns:
        A dict with the keys ``ordered_nodes`` (list of node names in
        scheduling order), ``task_ids`` (node name -> task id),
        ``done_ids``, ``error_ids`` and ``pending_ids`` (sets of task
        ids). Tasks that were dispatched but never awaited remain in
        ``pending_ids``.

    Raises:
        KeyError: If a dependency mentions a name that is not in *nodes*.
        networkx.NetworkXUnfeasible: If the execution graph has a cycle.
    """
    import networkx

    g_data = networkx.DiGraph()
    g_execution = networkx.DiGraph()

    for name in nodes:
        g_data.add_node(name)
        g_execution.add_node(name)

    for target, deps in data_deps.items():
        for source in deps:
            g_data.add_edge(source, target)
            g_execution.add_edge(source, target)

    for target, deps in execution_deps.items():
        for source in deps:
            g_execution.add_edge(source, target)

    # add_edge() silently creates missing nodes; reject them up front so
    # that the failure does not happen halfway through scheduling.
    unknown = [name for name in g_execution.nodes if name not in nodes]
    if unknown:
        raise KeyError(
            'dependencies reference unknown nodes: %r' % (unknown,)
        )

    task_ids = dict()
    done_ids = set()
    error_ids = set()
    pending_ids = set()
    active_queue = collections.deque()

    # topological_sort() returns a generator; materialize it so that it
    # can be indexed here and returned to the caller.
    ordered_nodes = list(networkx.topological_sort(g_execution))

    node_id = 0
    while node_id < len(ordered_nodes):
        node = ordered_nodes[node_id]

        blocked = any(
            v in task_ids and task_ids[v] in pending_ids
            for v in g_execution.predecessors(node)
        )

        if blocked:
            # Wait for the oldest dispatched task, then re-check the same
            # node (node_id is intentionally not advanced).
            task_id = active_queue.popleft()
            try:
                # NOTE(review): this blocks a worker on its sub-tasks,
                # hence disable_sync_subtasks=False; confirm the worker
                # pool is sized so that this cannot deadlock.
                celery.result.AsyncResult(task_id).get(
                    interval=0.1,
                    disable_sync_subtasks=False,
                )
            except Exception:
                error_ids.add(task_id)
                logger.error(json.dumps(dict(
                    msg=traceback.format_exc(),
                )))
                break
            else:
                done_ids.add(task_id)
            finally:
                pending_ids.discard(task_id)
            continue

        # Every data predecessor is also an execution predecessor, so all
        # of them have finished successfully at this point.
        args = [
            celery.result.AsyncResult(task_ids[v]).result
            for v in g_data.predecessors(node)
        ]
        task = nodes[node].apply_async(args=args)
        task_ids[node] = task.id
        pending_ids.add(task.id)
        active_queue.append(task.id)
        node_id += 1

    return dict(
        ordered_nodes=ordered_nodes,
        done_ids=done_ids,
        task_ids=task_ids,
        error_ids=error_ids,
        pending_ids=pending_ids,
    )