From 8b4d78ac53e89661fbadddc9d2b4d9734fcc950c Mon Sep 17 00:00:00 2001 From: Siarhei Siniak Date: Sat, 6 Jul 2024 20:14:22 +0300 Subject: [PATCH] [~] Refactor --- python/tasks/tiktok/celery.py | 1 + python/tasks/tiktok/utils.py | 45 ++++++++++++++++++++++++++++++++--- 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/python/tasks/tiktok/celery.py b/python/tasks/tiktok/celery.py index 99a145e..50746e4 100644 --- a/python/tasks/tiktok/celery.py +++ b/python/tasks/tiktok/celery.py @@ -9,6 +9,7 @@ app = celery.Celery( __name__, broker=c.celery_broker, result_backend=c.celery_result_backend, + accept_content=['pickle'], ) app.autodiscover_tasks(c.celery_imports) diff --git a/python/tasks/tiktok/utils.py b/python/tasks/tiktok/utils.py index 9c224fe..191633a 100644 --- a/python/tasks/tiktok/utils.py +++ b/python/tasks/tiktok/utils.py @@ -1,12 +1,51 @@ import celery.app.task +import importlib +import kombu.serialization +import kombu.utils.json +from typing import ( + Any, +) + class Task(celery.app.task.Task): def __call__(self, *args, **kwargs): res = super().__call__(*args, **kwargs) return self._to_native(res) - def _to_native(self, data): - if hasattr(data, 'to_dict'): - return data.to_dict() + @classmethod + def _loads(self, data_str: str) -> Any: + data = kombu.utils.json.loads(data_str) + + if isinstance(data, dict) and data.get('type') == 'dataclass_json': + module_name = data['module'] + class_names = data['_class'].split('.') + + m = importlib.import_module(module_name) + + c = m + for current_name in class_names: + c = getattr(c, current_name) + + return c(**data['data']) else: return data + + @classmethod + def _dumps(self, data: Any) -> str: + if hasattr(data, 'to_dict'): + return kombu.utils.json.dumps(dict( + type='dataclass_json', + module=data.__class__.__module__, + _class=data.__class__.__qualname__, + data=data.to_dict(), + )) + else: + return kombu.utils.json.dumps(data) + +kombu.serialization.register( + 'json-dataclass', + Task._dumps, + Task._loads, + content_type='application/json', + content_encoding='text', +)