[~] Refactor

This commit is contained in:
Siarhei Siniak 2024-07-07 12:39:23 +03:00
parent 5db573c025
commit 3a7ed25c08

@ -13,6 +13,7 @@ from typing import (
Any, Any,
Literal, Literal,
Optional, Optional,
Callable,
Iterable, Iterable,
) )
import celery import celery
@ -394,6 +395,7 @@ class process_graph_t:
def process_graph( def process_graph(
nodes: dict[str, Any], nodes: dict[str, Any],
data_deps: dict[str, Iterable[str]], data_deps: dict[str, Iterable[str]],
data_preproc: dict[str, Callable[Any, Any]],
execution_deps: dict[str, Iterable[str]], execution_deps: dict[str, Iterable[str]],
) -> process_graph_t.res_t: ) -> process_graph_t.res_t:
import networkx import networkx
@ -463,9 +465,17 @@ def process_graph(
celery.result.AsyncResult( celery.result.AsyncResult(
task_ids[v] task_ids[v]
).result ).result
for v in g_data.predecessors(node) for v in data_deps.get(node, tuple())
] ]
task = nodes[node].apply_async(*args) kwargs = dict()
if node in data_preproc:
args, kwargs = data_preproc[node](
nodes[node],
*args
)
task = nodes[node].clone(args=args, kwargs=kwargs).apply_async()
task_ids[node] = task.id task_ids[node] = task.id
pending_ids.add(task.id) pending_ids.add(task.id)
active_queue.append(task.id) active_queue.append(task.id)