[~] Refactor

This commit is contained in:
Siarhei Siniak 2024-07-07 11:51:43 +03:00
parent e27e5af370
commit b240d598cb
2 changed files with 72 additions and 28 deletions

@ -1,4 +1,5 @@
import logging import logging
import collections
import enum import enum
import dataclasses import dataclasses
import dataclasses_json import dataclasses_json
@ -385,8 +386,8 @@ def audio_get() -> audio_get_t.res_t:
@shared_task() @shared_task()
def process_graph( def process_graph(
nodes: dict[str, Any], nodes: dict[str, Any],
data_deps: dict(str, Iterable[str]), data_deps: dict[str, Iterable[str]],
execution_deps: dict(str, Iterable[str]), execution_deps: dict[str, Iterable[str]],
) -> Any: ) -> Any:
import networkx import networkx
@ -412,29 +413,47 @@ def process_graph(
pending_ids = set() pending_ids = set()
active_queue = collections.deque() active_queue = collections.deque()
ordered_nodes = networkx.topological_sort(g_execution) ordered_nodes = list(networkx.topological_sort(g_execution))
node_id = 0 node_id = 0
while node_id < len(ordered_tasks):
node = ordered_tasks[node_id]
if any([ def wait_task() -> bool:
v in task_ids and task_ids[v] in pending_ids import ipdb
for v in g_execution.predecessors(node) ipdb.set_trace()
]):
task_id = active_queue.popleft() task_id = active_queue.popleft()
try: task = celery.result.AsyncResult(task_id)
result = task.backend.wait_for(task_id, interval=0.1,)
if result.state == celery.states.SUCCESS: try:
done_ids.add(result.id) task.wait()
continue if task.status == celery.states.SUCCESS:
except: done_ids.add(task_id)
error_ids.add(task.id) return True
logger.error(json.dumps(dict( except:
msg=traceback.format_exc(), error_ids.add(task_id)
))) logger.error(json.dumps(dict(
msg=traceback.format_exc(),
)))
return False
finally:
pending_ids.remove(task_id)
while node_id < len(ordered_nodes) or len(pending_ids) > 0:
if node_id < len(ordered_nodes):
node = ordered_nodes[node_id]
else:
node = None
if (
(len(pending_ids) > 0 and node_id == len(ordered_nodes)) or
any([
v in task_ids and task_ids[v] in pending_ids
for v in g_execution.predecessors(node)
])
):
if wait_task():
continue
else:
break break
finally:
del pending_ids(task.id)
else: else:
args = [ args = [
celery.result.AsyncResult( celery.result.AsyncResult(
@ -445,7 +464,7 @@ def process_graph(
task = nodes[node].apply_async(*args) task = nodes[node].apply_async(*args)
task_ids[node] = task.id task_ids[node] = task.id
pending_ids.add(task.id) pending_ids.add(task.id)
active active_queue.append(task.id)
del args del args
del task del task
node_id += 1 node_id += 1

@ -1,5 +1,6 @@
import celery.app.task import celery.app.task
import celery.backends.redis import celery.backends.redis
import json
import datetime import datetime
import os import os
import asyncio import asyncio
@ -12,6 +13,9 @@ from typing import (
Callable, Callable,
Iterable, Iterable,
) )
from .config import logger_setup
logger = logger_setup(__name__)
def shared_task(func: Optional[Callable[Any, Any]]=None, **kwargs: Any) -> Any: def shared_task(func: Optional[Callable[Any, Any]]=None, **kwargs: Any) -> Any:
#@celery.shared_task( #@celery.shared_task(
@ -38,11 +42,32 @@ def shared_task(func: Optional[Callable[Any, Any]]=None, **kwargs: Any) -> Any:
# return decorator # return decorator
#else: #else:
# return decorator2 # return decorator2
return celery.shared_task( def decorator(func2):
base=Task, nonlocal func
track_started=True,
**kwargs if func is None:
) func = func2
for a in celery._state._get_active_apps():
name = a.gen_task_name(func.__name__, func.__module__)
if name in a.tasks:
logger.info(json.dumps(dict(
name=name,
a=str(a),
                        action='deregister_task',
)))
a.tasks.pop(name)
return celery.shared_task(
base=Task,
track_started=True,
**kwargs
)(func)
if func is None:
return decorator
else:
return decorator(func, *args, **kwargs)
def is_async() -> bool: def is_async() -> bool:
try: try: