diff --git a/python/tasks/tiktok/tasks.py b/python/tasks/tiktok/tasks.py index 58cfbf7..deed5c2 100644 --- a/python/tasks/tiktok/tasks.py +++ b/python/tasks/tiktok/tasks.py @@ -383,12 +383,23 @@ def audio_get() -> audio_get_t.res_t: url=url, ) +class process_graph_t: + @dataclasses_json.dataclass_json + @dataclasses.dataclass + class res_t: + ordered_nodes: Iterable[str]=dataclasses.field(default_factory=list) + done_ids: Iterable[str]=dataclasses.field(default_factory=set) + error_ids: Iterable[str]=dataclasses.field(default_factory=set) + task_ids: dict[str, str]=dataclasses.field(default_factory=dict) + pending_ids: Iterable[str]=dataclasses.field(default_factory=set) + done_tasks: Iterable[celery.result.AsyncResult]=dataclasses.field(default_factory=dict) + @shared_task() def process_graph( nodes: dict[str, Any], data_deps: dict[str, Iterable[str]], execution_deps: dict[str, Iterable[str]], -) -> Any: +) -> process_graph_t.res_t: import networkx g_data = networkx.DiGraph() @@ -466,7 +477,7 @@ def process_graph( del task node_id += 1 - return dict( + return process_graph_t.res_t( ordered_nodes=ordered_nodes, done_ids=done_ids, done_tasks={