[~] Refactor
This commit is contained in:
parent
e27e5af370
commit
b240d598cb
@ -1,4 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import collections
|
||||||
import enum
|
import enum
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import dataclasses_json
|
import dataclasses_json
|
||||||
@ -385,8 +386,8 @@ def audio_get() -> audio_get_t.res_t:
|
|||||||
@shared_task()
|
@shared_task()
|
||||||
def process_graph(
|
def process_graph(
|
||||||
nodes: dict[str, Any],
|
nodes: dict[str, Any],
|
||||||
data_deps: dict(str, Iterable[str]),
|
data_deps: dict[str, Iterable[str]],
|
||||||
execution_deps: dict(str, Iterable[str]),
|
execution_deps: dict[str, Iterable[str]],
|
||||||
) -> Any:
|
) -> Any:
|
||||||
import networkx
|
import networkx
|
||||||
|
|
||||||
@ -412,29 +413,47 @@ def process_graph(
|
|||||||
pending_ids = set()
|
pending_ids = set()
|
||||||
active_queue = collections.deque()
|
active_queue = collections.deque()
|
||||||
|
|
||||||
ordered_nodes = networkx.topological_sort(g_execution)
|
ordered_nodes = list(networkx.topological_sort(g_execution))
|
||||||
node_id = 0
|
node_id = 0
|
||||||
while node_id < len(ordered_tasks):
|
|
||||||
node = ordered_tasks[node_id]
|
|
||||||
|
|
||||||
if any([
|
def wait_task() -> bool:
|
||||||
v in task_ids and task_ids[v] in pending_ids
|
import ipdb
|
||||||
for v in g_execution.predecessors(node)
|
ipdb.set_trace()
|
||||||
]):
|
|
||||||
task_id = active_queue.popleft()
|
task_id = active_queue.popleft()
|
||||||
try:
|
task = celery.result.AsyncResult(task_id)
|
||||||
result = task.backend.wait_for(task_id, interval=0.1,)
|
|
||||||
if result.state == celery.states.SUCCESS:
|
try:
|
||||||
done_ids.add(result.id)
|
task.wait()
|
||||||
continue
|
if task.status == celery.states.SUCCESS:
|
||||||
except:
|
done_ids.add(task_id)
|
||||||
error_ids.add(task.id)
|
return True
|
||||||
logger.error(json.dumps(dict(
|
except:
|
||||||
msg=traceback.format_exc(),
|
error_ids.add(task_id)
|
||||||
)))
|
logger.error(json.dumps(dict(
|
||||||
|
msg=traceback.format_exc(),
|
||||||
|
)))
|
||||||
|
return False
|
||||||
|
finally:
|
||||||
|
pending_ids.remove(task_id)
|
||||||
|
|
||||||
|
while node_id < len(ordered_nodes) or len(pending_ids) > 0:
|
||||||
|
if node_id < len(ordered_nodes):
|
||||||
|
node = ordered_nodes[node_id]
|
||||||
|
else:
|
||||||
|
node = None
|
||||||
|
|
||||||
|
if (
|
||||||
|
(len(pending_ids) > 0 and node_id == len(ordered_nodes)) or
|
||||||
|
any([
|
||||||
|
v in task_ids and task_ids[v] in pending_ids
|
||||||
|
for v in g_execution.predecessors(node)
|
||||||
|
])
|
||||||
|
):
|
||||||
|
if wait_task():
|
||||||
|
continue
|
||||||
|
else:
|
||||||
break
|
break
|
||||||
finally:
|
|
||||||
del pending_ids(task.id)
|
|
||||||
else:
|
else:
|
||||||
args = [
|
args = [
|
||||||
celery.result.AsyncResult(
|
celery.result.AsyncResult(
|
||||||
@ -445,7 +464,7 @@ def process_graph(
|
|||||||
task = nodes[node].apply_async(*args)
|
task = nodes[node].apply_async(*args)
|
||||||
task_ids[node] = task.id
|
task_ids[node] = task.id
|
||||||
pending_ids.add(task.id)
|
pending_ids.add(task.id)
|
||||||
active
|
active_queue.append(task.id)
|
||||||
del args
|
del args
|
||||||
del task
|
del task
|
||||||
node_id += 1
|
node_id += 1
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import celery.app.task
|
import celery.app.task
|
||||||
import celery.backends.redis
|
import celery.backends.redis
|
||||||
|
import json
|
||||||
import datetime
|
import datetime
|
||||||
import os
|
import os
|
||||||
import asyncio
|
import asyncio
|
||||||
@ -12,6 +13,9 @@ from typing import (
|
|||||||
Callable,
|
Callable,
|
||||||
Iterable,
|
Iterable,
|
||||||
)
|
)
|
||||||
|
from .config import logger_setup
|
||||||
|
|
||||||
|
logger = logger_setup(__name__)
|
||||||
|
|
||||||
def shared_task(func: Optional[Callable[Any, Any]]=None, **kwargs: Any) -> Any:
|
def shared_task(func: Optional[Callable[Any, Any]]=None, **kwargs: Any) -> Any:
|
||||||
#@celery.shared_task(
|
#@celery.shared_task(
|
||||||
@ -38,11 +42,32 @@ def shared_task(func: Optional[Callable[Any, Any]]=None, **kwargs: Any) -> Any:
|
|||||||
# return decorator
|
# return decorator
|
||||||
#else:
|
#else:
|
||||||
# return decorator2
|
# return decorator2
|
||||||
return celery.shared_task(
|
def decorator(func2):
|
||||||
base=Task,
|
nonlocal func
|
||||||
track_started=True,
|
|
||||||
**kwargs
|
if func is None:
|
||||||
)
|
func = func2
|
||||||
|
|
||||||
|
for a in celery._state._get_active_apps():
|
||||||
|
name = a.gen_task_name(func.__name__, func.__module__)
|
||||||
|
if name in a.tasks:
|
||||||
|
logger.info(json.dumps(dict(
|
||||||
|
name=name,
|
||||||
|
a=str(a),
|
||||||
|
action='derigester_task',
|
||||||
|
)))
|
||||||
|
a.tasks.pop(name)
|
||||||
|
|
||||||
|
return celery.shared_task(
|
||||||
|
base=Task,
|
||||||
|
track_started=True,
|
||||||
|
**kwargs
|
||||||
|
)(func)
|
||||||
|
|
||||||
|
if func is None:
|
||||||
|
return decorator
|
||||||
|
else:
|
||||||
|
return decorator(func, *args, **kwargs)
|
||||||
|
|
||||||
def is_async() -> bool:
|
def is_async() -> bool:
|
||||||
try:
|
try:
|
||||||
|
Loading…
Reference in New Issue
Block a user