Compare commits

...

48 Commits

SHA1 Message Date
017482857d [~] Refactor 2024-07-07 14:00:10 +03:00
dc83fd5772 [~] Refactor 2024-07-07 13:45:47 +03:00
aa9a270980 [~] Refactor 2024-07-07 13:33:02 +03:00
816572d6de [~] Refactor 2024-07-07 13:08:51 +03:00
5b6d2fab34 [~] Refactor 2024-07-07 12:57:46 +03:00
955b7bde41 [~] Refactor 2024-07-07 12:56:22 +03:00
8083545b90 [~] Refactor 2024-07-07 12:53:53 +03:00
3a7ed25c08 [~] Refactor 2024-07-07 12:39:23 +03:00
5db573c025 [~] Refactor 2024-07-07 12:19:59 +03:00
4d2e9a3fd7 [~] Refactor 2024-07-07 12:14:24 +03:00
a1a0a52e4d [~] Refactor 2024-07-07 12:11:38 +03:00
ea54ef5828 [~] Refactor 2024-07-07 11:51:53 +03:00
b240d598cb [~] Refactor 2024-07-07 11:51:43 +03:00
e27e5af370 [~] Refactor 2024-07-07 11:25:29 +03:00
bd883d810c [~] Refactor 2024-07-07 08:59:39 +03:00
98f99cc470 [~] Refactor 2024-07-07 00:57:00 +03:00
e3557ff8fd [~] Refactor 2024-07-07 00:56:49 +03:00
94e3d90df2 [~] Refactor 2024-07-07 00:45:50 +03:00
c350db8ee4 [~] Refactor 2024-07-07 00:41:01 +03:00
70cf5e6bad [~] Refactor 2024-07-06 23:12:26 +03:00
3a54f34f6c [~] Refactor 2024-07-06 23:09:43 +03:00
2f8274297e [~] Refactor 2024-07-06 23:07:57 +03:00
8bbb43f1ca [~] Refactor 2024-07-06 22:45:32 +03:00
86adcdf4da [~] Refactor 2024-07-06 22:25:39 +03:00
9ee8c962ca [~] Refactor 2024-07-06 22:10:09 +03:00
c99948eb95 [~] Refactor 2024-07-06 22:05:33 +03:00
081e49f3a5 [~] Refactor 2024-07-06 21:06:24 +03:00
d7832d7574 [~] Refactor 2024-07-06 21:06:09 +03:00
b2afd124a7 [~] Refactor 2024-07-06 20:47:12 +03:00
639e755dbd [~] Refactor 2024-07-06 20:24:19 +03:00
8b4d78ac53 [~] Refactor 2024-07-06 20:14:22 +03:00
ff16d48f86 [~] Refactor 2024-07-06 19:54:47 +03:00
6adf6106f1 [~] Refactor 2024-07-06 19:42:12 +03:00
cffbe712c6 [~] Refactor 2024-07-06 19:17:56 +03:00
b234330471 [~] Refactor 2024-07-06 18:19:27 +03:00
a9e0f80fb3 [~] Refactor 2024-07-06 18:10:43 +03:00
45e289cea7 [~] Refactor 2024-07-06 17:55:04 +03:00
7dff2a98f2 [~] Refactor 2024-07-06 17:49:44 +03:00
5cb6394e27 [~] Refactor 2024-07-06 17:43:37 +03:00
ec7e2712eb [~] Refactor 2024-07-06 16:12:35 +03:00
450cab746e [~] Refactor 2024-07-06 15:51:34 +03:00
4679b3b861 [~] Refactor 2024-07-06 15:50:09 +03:00
7ade65b678 [~] Refactor 2024-07-06 15:13:45 +03:00
c28d9a19cc [~] Refactor 2024-07-06 14:40:59 +03:00
858332d657 [~] Refactor 2024-07-06 14:30:27 +03:00
75f41b03db [~] Refactor 2024-07-06 14:27:14 +03:00
2f5a5d0e78 [~] Refactor 2024-07-06 12:02:05 +03:00
a317081fd3 [~] Refactor 2024-07-06 11:27:50 +03:00
11 changed files with 1109 additions and 0 deletions

19
docker/tiktok/.zshrc Normal file

@@ -0,0 +1,19 @@
# The following lines were added by compinstall
zstyle ':completion:*' completer _expand _complete _ignored _correct _approximate
zstyle :compinstall filename '~/.zshrc'
setopt INC_APPEND_HISTORY SHARE_HISTORY AUTO_PUSHD PUSHD_IGNORE_DUPS
setopt PROMPTSUBST
autoload -Uz compinit
compinit
# End of lines added by compinstall
# Lines configured by zsh-newuser-install
HISTFILE=~/.histfile
HISTSIZE=1000000
SAVEHIST=1000000
# End of lines configured by zsh-newuser-install
bindkey -d
bindkey -v

29
docker/tiktok/Dockerfile Normal file

@@ -0,0 +1,29 @@
FROM python:latest
RUN pip3 install ipython jupyter
RUN apt-get update -yy && apt-get install -yy zsh htop mc git
RUN pip3 install jupyterlab-vim
RUN pip3 install pyktok yt-dlp playwright==1.44.0 TikTokApi
RUN pip3 install numpy pandas browser_cookie3 ipdb asgiref
RUN python3 -m playwright install-deps
RUN python3 -m playwright install
RUN pip3 install tqdm
RUN apt-get install -yy ffmpeg
RUN pip3 install celery redis
RUN pip3 install dataclasses-json
RUN pip3 install rpdb
RUN apt-get install -yy netcat-traditional
RUN apt-get install -yy vim
RUN apt-get install -yy tini
RUN apt-get install -yy wkhtmltopdf graphviz
RUN pip3 install pandoc
RUN apt-get install -yy pandoc
RUN apt-get install -yy texlive-xetex texlive-fonts-recommended texlive-plain-generic
RUN pip3 install 'nbconvert[webpdf]'
RUN pip3 install pickleshare
RUN pip3 install networkx
WORKDIR /app
ENTRYPOINT ["tini", "--", "bash", "docker/tiktok/entry.sh"]
CMD ["zsh", "-l"]

59
docker/tiktok/Makefile Normal file

@@ -0,0 +1,59 @@
PROJECT_ROOT ?= ${PWD}
export PROJECT_ROOT
PORT ?= 8888
TOKEN ?= $(shell pwgen -n 20 1)
c:
cd ${PROJECT_ROOT} && \
sudo docker-compose \
-f docker/tiktok/docker-compose.yml $(ARGS)
build:
$(MAKE) c ARGS="pull"
$(MAKE) c ARGS="build --pull"
celery-up:
$(MAKE) c ARGS="up -d redis celery"
celery-stop:
$(MAKE) c ARGS="stop redis celery"
celery-cmd:
$(MAKE) c ARGS="exec celery celery -A python.tasks.tiktok.celery ${ARGS}"
deploy:
cd ${PROJECT_ROOT} && tar -cvf ${PROJECT_ROOT}/tmp/cache/tiktok/repo.tar \
docker/tiktok \
python/tasks/tiktok \
tmp/cache/tiktok/notebooks/tiktok.ipynb \
tmp/cache/tiktok/notebooks/*.pdf \
.dockerignore \
.gitignore
logs:
$(MAKE) c ARGS="logs --tail=100 -f"
celery-restart:
$(MAKE) c ARGS="restart celery"
run:
cd ${PROJECT_ROOT} && \
sudo docker-compose \
-f docker/tiktok/docker-compose.yml \
run \
--use-aliases \
--rm tiktok
jupyter:
cd ${PROJECT_ROOT} && \
sudo docker-compose \
-f docker/tiktok/docker-compose.yml \
run \
-p 127.0.0.1:${PORT}:8888 \
--rm tiktok \
jupyter-lab \
--allow-root \
--ip=0.0.0.0 \
--NotebookApp.token=${TOKEN}

28
docker/tiktok/docker-compose.yml Normal file

@@ -0,0 +1,28 @@
version: '3.7'
services:
redis:
image: redis:latest
volumes:
- ../../tmp/cache/tiktok/redis/data:/data:rw
tiktok: &tiktok
links:
- redis
build:
context: ../../
dockerfile: ./docker/tiktok/Dockerfile
volumes:
- ./../../docker/tiktok:/app/docker/tiktok:ro
- ./../../tmp/cache/tiktok:/app/tmp/cache/tiktok:rw
- ./../../python/tasks/tiktok:/app/python/tasks/tiktok:ro
celery:
build:
context: ../../
dockerfile: ./docker/tiktok/Dockerfile
depends_on:
- redis
volumes:
- ./../../docker/tiktok:/app/docker/tiktok:ro
- ./../../tmp/cache/tiktok:/app/tmp/cache/tiktok:rw
- ./../../python/tasks/tiktok:/app/python/tasks/tiktok:ro
command: celery -A python.tasks.tiktok.celery worker -c 2

10
docker/tiktok/entry.sh Normal file

@@ -0,0 +1,10 @@
ln -sf $PWD/docker/tiktok/.zshrc ~
mkdir -p tmp/cache/tiktok/zsh
mkdir -p tmp/cache/tiktok/ipython
mkdir -p tmp/cache/tiktok/jupyter
ln -sf $PWD/tmp/cache/tiktok/zsh/histfile ~/.histfile
ln -sf $PWD/tmp/cache/tiktok/jupyter ~/.jupyter
ln -sf $PWD/tmp/cache/tiktok/ipython ~/.ipython
ipython3 profile create
ln -sf $PWD/docker/tiktok/ipython_config.py ~/.ipython/profile_default/
exec "$@"

72
docker/tiktok/ipython_config.py Normal file

@@ -0,0 +1,72 @@
c.InteractiveShellApp.exec_lines = [
'%autoreload 2',
r'''
def ipython_update_shortcuts():
import IPython
import prompt_toolkit.filters
import prompt_toolkit.document
import functools
import tempfile
import io
import subprocess
def ipython_edit_in_vim(*args, pt_app):
content = pt_app.app.current_buffer.document.text
lines_count = lambda text: len(text.splitlines())
with tempfile.NamedTemporaryFile(
suffix='.py',
mode='w',
) as f:
with io.open(f.name, 'w') as f2:
f2.write(content)
f2.flush()
result = subprocess.call([
'vim',
'+%d' % lines_count(content),
f.name,
])
if result != 0:
return
f.seek(0, io.SEEK_SET)
with io.open(f.name, 'r') as f2:
new_content = f2.read()
pt_app.app.current_buffer.document = \
prompt_toolkit.document.Document(
new_content,
cursor_position=len(new_content.rstrip()),
)
t1 = IPython.get_ipython()
t2 = t1.pt_app
t3 = [o for o in t2.key_bindings.bindings if 'f2' in repr(o.keys).lower()]
assert len(t3) == 1
t4 = t3[0]
t2.key_bindings.remove(t4.handler)
t2.key_bindings.add(
'\\', 'e', filter=~prompt_toolkit.filters.vi_insert_mode,
)(
functools.partial(
ipython_edit_in_vim,
pt_app=t2,
)
#t4.handler
)
''',
'ipython_update_shortcuts()',
]
c.IPCompleter.use_jedi = False
c.InteractiveShellApp.extensions = ['autoreload']
c.InteractiveShell.history_length = 100 * 1000 * 1000
c.InteractiveShell.history_load_length = 100 * 1000 * 1000
#c.InteractiveShell.enable_history_search = False
#c.InteractiveShell.autosuggestions_provider = None
c.InteractiveShell.pdb = True
c.TerminalInteractiveShell.editing_mode = 'vi'
c.TerminalInteractiveShell.modal_cursor = False
c.TerminalInteractiveShell.emacs_bindings_in_vi_insert_mode = False

32
python/tasks/tiktok/celery.py Normal file

@@ -0,0 +1,32 @@
from .config import tiktok_config
from .utils import kombu_register_json_dataclass, Backend
import logging
import celery
import redis
c = tiktok_config()
app = celery.Celery(
__name__,
broker=c.celery_broker,
#result_backend=c.celery_result_backend,
#backend=Backend,
#result_backend=c.celery_result_backend,
accept_content=['json-dataclass'],
task_serializer='json-dataclass',
result_serializer='json-dataclass',
task_track_started=True,
)
kombu_register_json_dataclass()
app._backend = Backend(app=app, url=c.celery_result_backend)
app.log.setup(loglevel=c.celery_log_level)
app.autodiscover_tasks(c.celery_imports)
redis = dict(
broker=redis.Redis(host='redis', db=int(c.celery_broker.split('/')[-1])),
result_backend=redis.Redis(host='redis', db=int(c.celery_result_backend.split('/')[-1])),
)
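A minimal usage sketch, not part of the diff: assuming it runs inside one of the compose containers (so the broker hostname redis resolves) and that the tasks module below is importable, a task can be enqueued through this app roughly like so; the sample link is hypothetical:

# illustrative only; run inside the tiktok container so that
# the broker hostname `redis` (see docker-compose.yml) resolves
from python.tasks.tiktok.celery import app

result = app.send_task(
    'python.tasks.tiktok.tasks.tiktok_videos_meta',
    kwargs=dict(links=['https://www.tiktok.com/@user/video/123']),
)
print(result.get(timeout=60))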

75
python/tasks/tiktok/config.py Normal file

@@ -0,0 +1,75 @@
import logging
import enum
import dataclasses
import dataclasses_json
import multiprocessing
import traceback
import subprocess
import os
import sys
import json
from typing import (
Any,
Literal,
Optional,
Iterable,
)
logger = logging.getLogger(__name__)
#logging.getLogger().setLevel(logging.INFO)
class tiktok_config_t:
@dataclasses_json.dataclass_json
@dataclasses.dataclass
class res_t:
project_root: str=''
cache: str=''
videos: str=''
audios: str=''
celery_broker: str=''
celery_result_backend: str=''
celery_imports: Iterable[str]=tuple()
celery_log_level: int=logging.INFO
def tiktok_config() -> tiktok_config_t.res_t:
res = tiktok_config_t.res_t(
project_root=os.path.abspath(
os.path.join(
os.path.dirname(__file__),
'..', '..', '..',
),
),
)
res.celery_broker = 'redis://redis:6379/1'
res.celery_result_backend = 'redis://redis:6379/2'
res.celery_imports = ['python.tasks.tiktok.tasks']
res.cache = os.path.join(
res.project_root,
'tmp/cache/tiktok',
)
res.videos = os.path.join(
res.cache,
'videos',
)
res.audios = os.path.join(
res.cache,
'audios',
)
os.makedirs(res.videos, exist_ok=True)
os.makedirs(res.audios, exist_ok=True)
return res
def logger_setup(name: str) -> logging.Logger:
logger = logging.getLogger(name)
if len(logger.handlers) == 0:
handler = logging.StreamHandler(sys.stderr)
logger.addHandler(handler)
logger.setLevel(logging.INFO)
return logger
logger = logger_setup(__name__)
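For illustration, a short sketch of inspecting the resolved configuration; note that calling tiktok_config() creates the videos/audios cache directories as a side effect:

# illustrative only: print the derived paths and broker URLs
from python.tasks.tiktok.config import tiktok_config, logger_setup

logger = logger_setup(__name__)
c = tiktok_config()  # os.makedirs(videos)/os.makedirs(audios) run here
print(c.project_root)
print(c.videos, c.audios)
print(c.celery_broker, c.celery_result_backend)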

497
python/tasks/tiktok/tasks.py Normal file

@@ -0,0 +1,497 @@
import logging
import collections
import enum
import dataclasses
import dataclasses_json
import multiprocessing
import traceback
import subprocess
import os
import sys
import json
from typing import (
Any,
Literal,
Optional,
Callable,
Iterable,
)
import celery
from .config import tiktok_config, logger_setup
from .utils import Task, shared_task
logger = logger_setup(__name__)
#logging.getLogger().setLevel(logging.INFO)
@shared_task()
async def tiktok_videos_links_get(
query: Optional[str]=None,
screenshot_path: Optional[str]=None,
max_time: Optional[int | float]=None,
max_links: Optional[int]=None,
) -> Iterable[str]:
import datetime
import TikTokApi
import pyktok
import asyncio
import re
if max_links is None:
max_links = 100
if max_time is None:
max_time = 10
async with TikTokApi.TikTokApi() as client:
await client.create_sessions()
session = client.sessions[0]
if not query is None:
await session.page.goto(
'https://www.tiktok.com/search?q=%s' % query
)
if not screenshot_path is None:
await session.page.screenshot(
path=screenshot_path,
)
links = list()
links_set = set()
started_at = datetime.datetime.now()
while True:
content = await session.page.content()
new_links = re.compile(
r'https://www.tiktok.com/@\w+/video/\d+'
).findall(content)
old_size = len(links)
for o in new_links:
if not o in links_set:
links_set.add(o)
links.append(o)
await session.page.mouse.wheel(0, 100)
elapsed = (
datetime.datetime.now() - started_at
).total_seconds()
if elapsed > max_time:
break
if len(links_set) > max_links:
break
if old_size < len(links):
logger.info(json.dumps(dict(
total=len(links),
elapsed=elapsed,
scroll_y=await session.page.evaluate('window.scrollY'),
)))
return list(links)[:max_links]
@shared_task()
def tiktok_videos_meta(links: Iterable[str]) -> Iterable[dict[str, Any]]:
res = []
for o in links:
parts = o.split('/')
res.append(dict(
url=o,
id=int(parts[-1]),
fname='_'.join(parts[-3:]) +'.mp4',
result_dir=tiktok_config().videos,
))
return res
class tiktok_video_fetch_t:
class method_t(enum.Enum):
pyktok = 'pyktok'
tikcdn_io_curl = 'tikcdn.io-curl'
tikcdn_io_wget = 'tikcdn.io-wget'
@shared_task()
def tiktok_video_fetch(
id: int,
url: str,
fname: str,
result_dir: str,
method: Optional[tiktok_video_fetch_t.method_t]=None,
method_str: Optional[str]=None,
) -> None:
os.chdir(result_dir)
if not method_str is None:
method = tiktok_video_fetch_t.method_t(method_str)
if method is None:
method = tiktok_video_fetch_t.method_t.pyktok
if method == tiktok_video_fetch_t.method_t.pyktok:
import pyktok
pyktok.save_tiktok(url)
elif method == tiktok_video_fetch_t.method_t.tikcdn_io_curl:
subprocess.check_call([
'curl',
'-v',
'https://tikcdn.io/ssstik/%d' % id,
'-o', fname,
])
elif method == tiktok_video_fetch_t.method_t.tikcdn_io_wget:
subprocess.check_call([
'wget',
'https://tikcdn.io/ssstik/%d' % id,
'-O',
fname,
])
else:
raise NotImplementedError
mime_type = file_mime_type(fname)
if mime_type in ['empty']:
raise RuntimeError('notdownloaded')
def file_mime_type(path: str) -> Optional[str]:
if os.path.exists(path):
mime_type = subprocess.check_output([
'file',
'-b', path,
]).strip().decode('utf-8')
return mime_type
else:
return None
async def playwright_save(url: str):
import TikTokApi
async with TikTokApi.TikTokApi() as client:
await client.create_sessions()
session = client.sessions[0]
page = session.page
async with page.expect_download() as download_info:
await page.goto(url)
download = await download_info.value
path = await download.path()
await download.save_as(path)
print(path)
@shared_task()
def tiktok_videos_fetch(
meta: Iterable[dict[str, Any]],
method: Optional[tiktok_video_fetch_t.method_t]=None,
method_str: Optional[str]=None,
force: Optional[bool]=None,
) -> Iterable[dict[str, Any]]:
import tqdm
if force is None:
force = False
stats = dict(
saved=0,
total=0,
skipped=0,
error=0,
)
for o in tqdm.tqdm(meta):
stats['total'] += 1
path = os.path.join(
o['result_dir'],
o['fname'],
)
if (
not os.path.exists(path) or
file_mime_type(path) in ['empty'] or
force
):
try:
tiktok_video_fetch.s(
id=o['id'],
url=o['url'],
fname=o['fname'],
method=method,
method_str=method_str,
result_dir=o['result_dir'],
).apply_async().get(disable_sync_subtasks=False,)
stats['saved'] += 1
except KeyboardInterrupt:
break
except:
logger.error(json.dumps(dict(
msg=traceback.format_exc(),
)))
stats['error'] += 1
else:
stats['skipped'] += 1
return stats
class tiktok_videos_process_t:
@dataclasses_json.dataclass_json
@dataclasses.dataclass
class res_t:
@dataclasses_json.dataclass_json
@dataclasses.dataclass
class stats_t:
saved: int=0
total: int=0
skipped: int=0
error: int=0
@dataclasses_json.dataclass_json
@dataclasses.dataclass
class video_t:
meta: Optional[dict[str, Any]]=None
processed_path: Optional[str]=None
stats: stats_t=dataclasses.field(default_factory=stats_t)
videos: Iterable[video_t]=dataclasses.field(default_factory=list)
@shared_task()
def tiktok_videos_process(meta: Iterable[dict[str, Any]]) -> dict[str, Any]:
import tqdm
res = tiktok_videos_process_t.res_t(
videos=[],
)
song = audio_get()
for o in tqdm.tqdm(meta):
res.stats.total += 1
res.videos.append(tiktok_videos_process_t.res_t.video_t())
res.videos[-1].meta = o
path = os.path.join(
o['result_dir'],
o['fname'],
)
try:
path_parts = os.path.splitext(path)
processed_path = path_parts[0] + '-proc' + path_parts[1]
processed_path_tmp = path_parts[0] + '-proc.tmp' + path_parts[1]
if os.path.exists(processed_path):
res.videos[-1].processed_path = processed_path
if not os.path.exists(path) or os.path.exists(processed_path):
res.stats.skipped += 1
continue
if os.path.exists(processed_path_tmp):
os.unlink(processed_path_tmp)
ffmpeg = [
'ffmpeg',
'-i', path,
'-i', song.path_mp3,
'-shortest',
'-vf',
','.join([
'setpts=1.1*PTS',
'scale=trunc(iw/0.9):trunc(ow/a/2)*2',
]),
'-sws_flags', 'bilinear',
'-map', '0:v:0',
'-map', '1:a:0',
processed_path_tmp,
]
subprocess.check_call(
ffmpeg,
stdin=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
stdout=subprocess.DEVNULL
)
os.rename(processed_path_tmp, processed_path)
if os.path.exists(processed_path):
res.videos[-1].processed_path = processed_path
res.stats.saved += 1
except KeyboardInterrupt:
break
except:
logger.error(json.dumps(dict(
msg=traceback.format_exc(),
)))
res.stats.error += 1
return res
class audio_get_t:
@dataclasses_json.dataclass_json
@dataclasses.dataclass
class res_t:
file: str
file_mp3: str
path: str
path_mp3: str
url: str
@shared_task()
def audio_get() -> audio_get_t.res_t:
c = tiktok_config()
url = 'https://www.youtube.com/watch?v=dQw4w9WgXcQ'
file = 'song.dat'
file_mp3 = 'song.mp3'
path = os.path.join(c.audios, file)
path_mp3 = os.path.join(c.audios, file_mp3)
if not os.path.exists(path):
subprocess.check_call([
'yt-dlp',
'-f', 'bestaudio',
url,
'-o', path,
])
if not os.path.exists(path_mp3):
subprocess.check_call([
'ffmpeg',
'-i', path,
path_mp3,
])
return audio_get_t.res_t(
file=file,
file_mp3=file_mp3,
path=path,
path_mp3=path_mp3,
url=url,
)
class process_graph_t:
@dataclasses_json.dataclass_json
@dataclasses.dataclass
class res_t:
ordered_nodes: Iterable[str]=dataclasses.field(default_factory=list)
done_ids: Iterable[str]=dataclasses.field(default_factory=set)
error_ids: Iterable[str]=dataclasses.field(default_factory=set)
task_ids: dict[str, str]=dataclasses.field(default_factory=dict)
pending_ids: Iterable[str]=dataclasses.field(default_factory=set)
done_tasks: dict[str, celery.result.AsyncResult]=dataclasses.field(default_factory=dict)
@shared_task()
def process_graph(
nodes: dict[str, Any],
data_deps: dict[str, Iterable[str]],
data_preproc: dict[str, Callable[..., Any]],
execution_deps: dict[str, Iterable[str]],
) -> process_graph_t.res_t:
import networkx
g_data = networkx.DiGraph()
g_execution = networkx.DiGraph()
for v in nodes:
g_data.add_node(v)
g_execution.add_node(v)
for b, deps in data_deps.items():
for a in deps:
g_data.add_edge(a, b)
g_execution.add_edge(a, b)
for b, deps in execution_deps.items():
for a in deps:
g_execution.add_edge(a, b)
task_ids = dict()
done_ids = set()
error_ids = set()
pending_ids = set()
active_queue = collections.deque()
ordered_nodes = list(networkx.topological_sort(g_execution))
node_id = 0
def wait_task() -> bool:
task_id = active_queue.popleft()
task = celery.result.AsyncResult(task_id)
try:
task.wait()
if task.status == celery.states.SUCCESS:
done_ids.add(task_id)
return True
except:
error_ids.add(task_id)
logger.error(json.dumps(dict(
msg=traceback.format_exc(),
)))
return False
finally:
pending_ids.remove(task_id)
while node_id < len(ordered_nodes) or len(pending_ids) > 0:
if node_id < len(ordered_nodes):
node = ordered_nodes[node_id]
else:
node = None
if (
(len(pending_ids) > 0 and node_id == len(ordered_nodes)) or
any([
v in task_ids and task_ids[v] in pending_ids
for v in g_execution.predecessors(node)
])
):
if wait_task():
continue
else:
break
else:
args = [
celery.result.AsyncResult(
task_ids[v]
).result
for v in data_deps.get(node, tuple())
]
kwargs = dict()
if node in data_preproc:
args, kwargs = data_preproc[node](
nodes[node],
*args
)
task = nodes[node].clone(args=args, kwargs=kwargs).apply_async()
task_ids[node] = task.id
pending_ids.add(task.id)
active_queue.append(task.id)
del args
del task
node_id += 1
return process_graph_t.res_t(
ordered_nodes=ordered_nodes,
done_ids=done_ids,
done_tasks={
k : celery.result.AsyncResult(task_ids[k])
for k in nodes
if task_ids.get(k) in done_ids
},
task_ids=task_ids,
error_ids=error_ids,
pending_ids=pending_ids,
)
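A hedged sketch of chaining the tasks above through process_graph; the query, node names, and dependency layout are illustrative, not taken from the diff. Calling process_graph directly runs the scheduler in-process while the node tasks are dispatched to workers via apply_async:

# hypothetical wiring; assumes redis and a celery worker are up
import python.tasks.tiktok.celery  # configures the app and the json-dataclass codec
from python.tasks.tiktok import tasks

res = tasks.process_graph(
    nodes=dict(
        links=tasks.tiktok_videos_links_get.s(query='cats', max_links=10),
        meta=tasks.tiktok_videos_meta.s(),
        fetch=tasks.tiktok_videos_fetch.s(),
        process=tasks.tiktok_videos_process.s(),
    ),
    # each entry feeds the listed nodes' results in as positional args
    data_deps=dict(meta=['links'], fetch=['meta'], process=['meta']),
    data_preproc=dict(),
    # make `process` also wait for `fetch`, though it only consumes `meta`
    execution_deps=dict(process=['fetch']),
)
print(res.done_ids, res.error_ids)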

288
python/tasks/tiktok/utils.py Normal file

@@ -0,0 +1,288 @@
import celery.app.task
import celery.backends.redis
import json
import datetime
import os
import asyncio
import inspect
import importlib
import kombu.utils.json
from typing import (
Any,
Optional,
Callable,
Iterable,
)
from .config import logger_setup
logger = logger_setup(__name__)
def shared_task(func: Optional[Callable[..., Any]]=None, **kwargs: Any) -> Any:
#@celery.shared_task(
# base=Task,
# **kwargs,
#)
#def decorator2(*args, **kwargs):
# res = func(*args, **kwargs)
#
# if inspect.isawaitable(res):
# return asyncio.run(res)
# else:
# return res
#
#def decorator(func2: Callable[Any, Any]):
# nonlocal func
#
# if func is None:
# func = func2
#
# return decorator2
#
#if func is None:
# return decorator
#else:
# return decorator2
def decorator(func2):
nonlocal func
if func is None:
func = func2
for a in celery._state._get_active_apps():
name = a.gen_task_name(func.__name__, func.__module__)
if name in a.tasks:
logger.info(json.dumps(dict(
name=name,
a=str(a),
action='deregister_task',
)))
a.tasks.pop(name)
return celery.shared_task(
base=Task,
track_started=True,
**kwargs
)(func)
if func is None:
return decorator
else:
return decorator(func)
def is_async() -> bool:
try:
asyncio.get_running_loop()
return True
except RuntimeError:
return False
class Backend(celery.backends.redis.RedisBackend):
def __init__(self, *args, **kwargs):
return super().__init__(*args, **kwargs)
def _store_result(self, task_id, result, state,
traceback=None, request=None, **kwargs):
meta = self._get_result_meta(result=result, state=state,
traceback=traceback, request=request, task_id=task_id,)
meta['task_id'] = celery.backends.base.bytes_to_str(task_id)
# Retrieve metadata from the backend, if the status
# is a success then we ignore any following update to the state.
# This solves a task deduplication issue because of network
# partitioning or lost workers. This issue involved a race condition
# making a lost task overwrite the last successful result in the
# result backend.
current_meta = self._get_task_meta_for(task_id)
if current_meta['status'] == celery.states.SUCCESS:
return result
try:
self._set_with_state(self.get_key_for_task(task_id), self.encode(meta), state)
except celery.backends.base.BackendStoreError as ex:
raise celery.backends.base.BackendStoreError(str(ex), state=state, task_id=task_id) from ex
return result
def _get_result_meta(
self,
*args,
task_id: Optional[str]=None,
state: Optional[str]=None,
request: Optional[Any]=None,
**kwargs
):
current_meta = None
if not task_id is None:
current_meta = self._get_task_meta_for(task_id)
res = super()._get_result_meta(*args, state=state, request=request, **kwargs)
if not request is None:
#import pprint
#print(request)
if not '_task_name' in res:
res['_task_name'] = request.task
if state == celery.states.STARTED:
if not 'date_started' in res:
res['date_started'] = datetime.datetime.now()
if state in celery.states.READY_STATES:
if not '_date_done' in res:
res['_date_done'] = datetime.datetime.now()
for k in ['date_started', '_date_done', '_task_name']:
if not current_meta is None and k in current_meta:
res[k] = current_meta[k]
return res
class Task(celery.app.task.Task):
def __call__(self, *args, **kwargs) -> Any:
res = super().__call__(*args, **kwargs)
if inspect.isawaitable(res) and not is_async():
return asyncio.run(res)
else:
return res
#def apply(self, *args, **kwargs):
# return self.__call__(*args, **kwargs)
#def before_start(self, task_id: str, *args, **kwargs):
# self.update_state(None, celery.states.STARTED)
#
# meta = self.backend._get_task_meta_for(task_id)
#
# assert isinstance(meta, dict)
#
# if not 'date_started' in meta:
# meta['date_started']
#
# self._set_with_state(self.get_key_for_task(task_id), self.encode(meta), state)
#def update_state(
# self,
# *args,
# state: Optional[str]=None,
# meta: Optional[dict[str,Any]]=None,
# **kwargs
#):
# print(['blah', meta, state])
#
# if not meta is None:
# logger.info(json.dumps(dict(state=state)))
#
# if not 'date_started' in meta and state == celery.states.STARTED:
# meta['date_started'] = datetime.datetime.now()
#
# return super().update_stae(*args, state=state, meta=meta, **kwargs)
@classmethod
def _loads(
cls,
data_str: Optional[str]=None,
data: Optional[Any]=None,
) -> Any:
if not data_str is None:
data = kombu.utils.json.loads(data_str)
if isinstance(data, dict) and data.get('type') == 'dataclass_json':
module_name = data['module']
class_names = data['_class'].split('.')
m = importlib.import_module(module_name)
c = m
for current_name in class_names:
c = getattr(c, current_name)
return c.from_dict({
k : cls._loads(data=v)
for k, v in data['data'].items()
})
else:
if isinstance(data, list):
return [
cls._loads(data=o)
for o in data
]
elif isinstance(data, dict):
return {
k : cls._loads(data=v)
for k, v in data.items()
}
else:
return data
@classmethod
def _dumps(
cls,
data: Any,
need_native: Optional[bool]=None,
) -> Any:
if need_native is None:
need_native = False
native = None
if hasattr(data, 'to_dict'):
native = dict(
type='dataclass_json',
module=data.__class__.__module__,
_class=data.__class__.__qualname__,
data={
k : cls._dumps(v, need_native=True,)
for k, v in data.__dict__.items()
},
)
else:
if isinstance(data, (list, tuple)):
native = [
cls._dumps(o, need_native=True,)
for o in data
]
elif isinstance(data, dict):
native = {
k : cls._dumps(v, need_native=True,)
for k, v in data.items()
}
else:
native = data
if not need_native:
return kombu.utils.json.dumps(native)
else:
return native
def kombu_register_json_dataclass():
import kombu.serialization
kombu.serialization.register(
'json-dataclass',
Task._dumps,
Task._loads,
content_type='application/json',
content_encoding='utf-8',
)
def grid_of_videos(paths: Iterable[str]) -> Any:
from ipywidgets import Output, GridspecLayout
from IPython import display
grid = GridspecLayout(1, len(paths))
for i, path in enumerate(paths):
assert os.path.exists(path)
out = Output()
with out:
display.display(display.Video(
url='/files/%s' % path,
height=200,
#embed=True
))
grid[0, i] = out
return grid
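Finally, a small round-trip sketch of the json-dataclass codec registered above; the sample dataclass is hypothetical and only needs to be importable under the module path recorded by _dumps:

# hypothetical round-trip through Task._dumps / Task._loads
import dataclasses
import dataclasses_json

@dataclasses_json.dataclass_json
@dataclasses.dataclass
class sample_t:
    name: str = ''
    count: int = 0

payload = Task._dumps(sample_t(name='a', count=1))
# payload is a JSON string tagged with type/module/_class
restored = Task._loads(payload)
assert restored == sample_t(name='a', count=1)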