From 3a7ed25c08261a132547c326d5c594b56104bdb0 Mon Sep 17 00:00:00 2001 From: Siarhei Siniak Date: Sun, 7 Jul 2024 12:39:23 +0300 Subject: [PATCH] [~] Refactor --- python/tasks/tiktok/tasks.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/python/tasks/tiktok/tasks.py b/python/tasks/tiktok/tasks.py index e21b84c..9d98b3d 100644 --- a/python/tasks/tiktok/tasks.py +++ b/python/tasks/tiktok/tasks.py @@ -13,6 +13,7 @@ from typing import ( Any, Literal, Optional, + Callable, Iterable, ) import celery @@ -394,6 +395,7 @@ class process_graph_t: def process_graph( nodes: dict[str, Any], data_deps: dict[str, Iterable[str]], + data_preproc: dict[str, Callable[Any, Any]], execution_deps: dict[str, Iterable[str]], ) -> process_graph_t.res_t: import networkx @@ -463,9 +465,17 @@ def process_graph( celery.result.AsyncResult( task_ids[v] ).result - for v in g_data.predecessors(node) + for v in data_deps.get(node, tuple()) ] - task = nodes[node].apply_async(*args) + kwargs = dict() + + if node in data_preproc: + args, kwargs = data_preproc[node]( + nodes[node], + *args + ) + + task = nodes[node].clone(args=args, kwargs=kwargs).apply_async() task_ids[node] = task.id pending_ids.add(task.id) active_queue.append(task.id)