[~] Refactor

This commit is contained in:
Siarhei Siniak 2024-07-07 11:25:29 +03:00
parent bd883d810c
commit e27e5af370

@ -381,3 +381,79 @@ def audio_get() -> audio_get_t.res_t:
path_mp3=path_mp3, path_mp3=path_mp3,
url=url, url=url,
) )
@shared_task()
def process_graph(
nodes: dict[str, Any],
data_deps: dict(str, Iterable[str]),
execution_deps: dict(str, Iterable[str]),
) -> Any:
import networkx
g_data = networkx.DiGraph()
g_execution = networkx.DiGraph()
for v in nodes:
g_data.add_node(v)
g_execution.add_node(v)
for b, deps in data_deps.items():
for a in deps:
g_data.add_edge(a, b)
g_execution.add_edge(a, b)
for b, deps in execution_deps.items():
for a in deps:
g_execution.add_edge(a, b)
task_ids = dict()
done_ids = set()
error_ids = set()
pending_ids = set()
active_queue = collections.deque()
ordered_nodes = networkx.topological_sort(g_execution)
node_id = 0
while node_id < len(ordered_tasks):
node = ordered_tasks[node_id]
if any([
v in task_ids and task_ids[v] in pending_ids
for v in g_execution.predecessors(node)
]):
task_id = active_queue.popleft()
try:
result = task.backend.wait_for(task_id, interval=0.1,)
if result.state == celery.states.SUCCESS:
done_ids.add(result.id)
continue
except:
error_ids.add(task.id)
logger.error(json.dumps(dict(
msg=traceback.format_exc(),
)))
break
finally:
del pending_ids(task.id)
else:
args = [
celery.result.AsyncResult(
task_ids[v]
).result
for v in g_data.predecessors(node)
]
task = nodes[node].apply_async(*args)
task_ids[node] = task.id
pending_ids.add(task.id)
active
del args
del task
node_id += 1
return dict(
ordered_nodes=ordered_nodes,
done_ids=done_ids,
task_ids=task_ids,
error_ids=error_ids,
pending_ids=pending_ids,
)