[~] Refactor

This commit is contained in:
Siarhei Siniak 2024-07-07 11:51:43 +03:00
parent e27e5af370
commit b240d598cb
2 changed files with 72 additions and 28 deletions

@ -1,4 +1,5 @@
import logging
import collections
import enum
import dataclasses
import dataclasses_json
@ -385,8 +386,8 @@ def audio_get() -> audio_get_t.res_t:
@shared_task()
def process_graph(
nodes: dict[str, Any],
data_deps: dict(str, Iterable[str]),
execution_deps: dict(str, Iterable[str]),
data_deps: dict[str, Iterable[str]],
execution_deps: dict[str, Iterable[str]],
) -> Any:
import networkx
@ -412,29 +413,47 @@ def process_graph(
pending_ids = set()
active_queue = collections.deque()
ordered_nodes = networkx.topological_sort(g_execution)
ordered_nodes = list(networkx.topological_sort(g_execution))
node_id = 0
while node_id < len(ordered_tasks):
node = ordered_tasks[node_id]
if any([
v in task_ids and task_ids[v] in pending_ids
for v in g_execution.predecessors(node)
]):
task_id = active_queue.popleft()
try:
result = task.backend.wait_for(task_id, interval=0.1,)
if result.state == celery.states.SUCCESS:
done_ids.add(result.id)
continue
except:
error_ids.add(task.id)
logger.error(json.dumps(dict(
msg=traceback.format_exc(),
)))
def wait_task() -> bool:
import ipdb
ipdb.set_trace()
task_id = active_queue.popleft()
task = celery.result.AsyncResult(task_id)
try:
task.wait()
if task.status == celery.states.SUCCESS:
done_ids.add(task_id)
return True
except:
error_ids.add(task_id)
logger.error(json.dumps(dict(
msg=traceback.format_exc(),
)))
return False
finally:
pending_ids.remove(task_id)
while node_id < len(ordered_nodes) or len(pending_ids) > 0:
if node_id < len(ordered_nodes):
node = ordered_nodes[node_id]
else:
node = None
if (
(len(pending_ids) > 0 and node_id == len(ordered_nodes)) or
any([
v in task_ids and task_ids[v] in pending_ids
for v in g_execution.predecessors(node)
])
):
if wait_task():
continue
else:
break
finally:
del pending_ids(task.id)
else:
args = [
celery.result.AsyncResult(
@ -445,7 +464,7 @@ def process_graph(
task = nodes[node].apply_async(*args)
task_ids[node] = task.id
pending_ids.add(task.id)
active
active_queue.append(task.id)
del args
del task
node_id += 1

@ -1,5 +1,6 @@
import celery.app.task
import celery.backends.redis
import json
import datetime
import os
import asyncio
@ -12,6 +13,9 @@ from typing import (
Callable,
Iterable,
)
from .config import logger_setup
logger = logger_setup(__name__)
def shared_task(func: Optional[Callable[Any, Any]]=None, **kwargs: Any) -> Any:
#@celery.shared_task(
@ -38,11 +42,32 @@ def shared_task(func: Optional[Callable[Any, Any]]=None, **kwargs: Any) -> Any:
# return decorator
#else:
# return decorator2
return celery.shared_task(
base=Task,
track_started=True,
**kwargs
)
def decorator(func2):
nonlocal func
if func is None:
func = func2
for a in celery._state._get_active_apps():
name = a.gen_task_name(func.__name__, func.__module__)
if name in a.tasks:
logger.info(json.dumps(dict(
name=name,
a=str(a),
action='derigester_task',
)))
a.tasks.pop(name)
return celery.shared_task(
base=Task,
track_started=True,
**kwargs
)(func)
if func is None:
return decorator
else:
return decorator(func, *args, **kwargs)
def is_async() -> bool:
try: