import collections
import enum
import dataclasses
import dataclasses_json
import multiprocessing
import traceback
import subprocess
import os
import json
import urllib.parse
from typing import (
    Any,
    Optional,
    Iterable,
)
import celery
import celery.result
import celery.states
from .config import tiktok_config, logger_setup
from .utils import Task, shared_task


logger = logger_setup(__name__)


@shared_task()
async def tiktok_videos_links_get(
    query: Optional[str] = None,
    screenshot_path: Optional[str] = None,
    max_time: Optional[int | float] = None,
    max_links: Optional[int] = None,
) -> Iterable[str]:
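    """Collect TikTok video links by scrolling a search/feed page.

    Opens a TikTokApi browser session, optionally navigates to the search
    results for `query`, then scrolls and harvests video URLs until
    `max_time` seconds (default 10) elapse or `max_links` links
    (default 100) are collected.
    """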
    import datetime
    import TikTokApi
    import re

    if max_links is None:
        max_links = 100

    if max_time is None:
        max_time = 10

    async with TikTokApi.TikTokApi() as client:
        await client.create_sessions()

        session = client.sessions[0]

        if query is not None:
            # URL-encode the query so spaces and special characters survive.
            await session.page.goto(
                'https://www.tiktok.com/search?q=%s'
                % urllib.parse.quote_plus(query)
            )

        if screenshot_path is not None:
            await session.page.screenshot(
                path=screenshot_path,
            )

        # Keep discovery order while deduplicating with a parallel set.
        links: list[str] = []
        links_set: set[str] = set()

        started_at = datetime.datetime.now()

        while True:
            content = await session.page.content()
            new_links = re.findall(
                r'https://www\.tiktok\.com/@[\w.]+/video/\d+',
                content,
            )

            old_size = len(links)

            for link in new_links:
                if link not in links_set:
                    links_set.add(link)
                    links.append(link)

            await session.page.mouse.wheel(0, 100)

            elapsed = (
                datetime.datetime.now() - started_at
            ).total_seconds()

            if elapsed > max_time:
                break

            if len(links_set) >= max_links:
                break

            if old_size < len(links):
                logger.info(json.dumps(dict(
                    total=len(links),
                    elapsed=elapsed,
                    scroll_y=await session.page.evaluate('window.scrollY'),
                )))

    return links[:max_links]

@shared_task()
def tiktok_videos_meta(links: Iterable[str]) -> Iterable[dict[str, Any]]:
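    """Derive per-video download metadata (url, id, filename, target dir)."""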
    res = []
    for o in links:
        parts = o.split('/')

        res.append(dict(
            url=o,
            id=int(parts[-1]),
            fname='_'.join(parts[-3:]) + '.mp4',
            result_dir=tiktok_config().videos,
        ))

    return res

class tiktok_video_fetch_t:
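    """Namespace for the download-method enum used by tiktok_video_fetch."""
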
    class method_t(enum.Enum):
        pyktok = 'pyktok'
        tikcdn_io_curl = 'tikcdn.io-curl'
        tikcdn_io_wget = 'tikcdn.io-wget'

@shared_task()
def tiktok_video_fetch(
    id: int,
    url: str,
    fname: str,
    result_dir: str,
    method: Optional[tiktok_video_fetch_t.method_t] = None,
    method_str: Optional[str] = None,
) -> None:
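    """Download one video into result_dir.

    The backend is pyktok by default, or tikcdn.io via curl/wget; raises
    RuntimeError('notdownloaded') if the fetched file is empty.
    """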
    # pyktok saves into the current working directory, so move there first.
    os.chdir(result_dir)

    if method_str is not None:
        method = tiktok_video_fetch_t.method_t(method_str)

    if method is None:
        method = tiktok_video_fetch_t.method_t.pyktok

    if method == tiktok_video_fetch_t.method_t.pyktok:
        import pyktok
        pyktok.save_tiktok(url)
    elif method == tiktok_video_fetch_t.method_t.tikcdn_io_curl:
        subprocess.check_call([
            'curl',
            '-v',
            'https://tikcdn.io/ssstik/%d' % id,
            '-o', fname,
        ])
    elif method == tiktok_video_fetch_t.method_t.tikcdn_io_wget:
        subprocess.check_call([
            'wget',
            'https://tikcdn.io/ssstik/%d' % id,
            '-O',
            fname,
        ])
    else:
        raise NotImplementedError

    mime_type = file_mime_type(fname)

    if mime_type == 'empty':
        raise RuntimeError('notdownloaded')

def file_mime_type(path: str) -> Optional[str]:
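    """Return the `file -b` description of path, or None if it is missing."""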
    if os.path.exists(path):
        mime_type = subprocess.check_output([
            'file',
            '-b', path,
        ]).strip().decode('utf-8')
        return mime_type
    else:
        return None

async def playwright_save(url: str):
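    """Open url in a TikTokApi browser session, save the download it
    triggers, and print the saved path."""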
    import TikTokApi

    async with TikTokApi.TikTokApi() as client:
        await client.create_sessions()
        session = client.sessions[0]
        page = session.page

        # In Playwright's async API the download handle and its
        # path()/save_as() operations must all be awaited.
        async with page.expect_download() as download_info:
            await page.goto(url)
        download = await download_info.value
        path = await download.path()
        await download.save_as(path)
        print(path)

@shared_task()
def tiktok_videos_fetch(
    meta: Iterable[dict[str, Any]],
    method: Optional[tiktok_video_fetch_t.method_t] = None,
    method_str: Optional[str] = None,
    force: Optional[bool] = None,
) -> dict[str, int]:
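    """Download every video in meta, skipping existing non-empty files
    unless force is set.

    Each download runs through a single-process multiprocessing pool,
    isolating it from the worker. Returns saved/total/skipped/error counts.
    """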
    import tqdm

    if force is None:
        force = False

    stats = dict(
        saved=0,
        total=0,
        skipped=0,
        error=0,
    )

    with multiprocessing.Pool(processes=1) as pool:
        for o in tqdm.tqdm(meta):
            stats['total'] += 1
            path = os.path.join(
                o['result_dir'],
                o['fname'],
            )

            if (
                force or
                not os.path.exists(path) or
                file_mime_type(path) == 'empty'
            ):
                try:
                    pool.apply(
                        tiktok_video_fetch,
                        kwds=dict(
                            id=o['id'],
                            url=o['url'],
                            fname=o['fname'],
                            method=method,
                            method_str=method_str,
                            result_dir=o['result_dir'],
                        ),
                    )
                    stats['saved'] += 1
                except KeyboardInterrupt:
                    break
                except Exception:
                    logger.error(json.dumps(dict(
                        msg=traceback.format_exc(),
                    )))
                    stats['error'] += 1
            else:
                stats['skipped'] += 1

    return stats

class tiktok_videos_process_t:
    @dataclasses_json.dataclass_json
    @dataclasses.dataclass
    class res_t:
        @dataclasses_json.dataclass_json
        @dataclasses.dataclass
        class stats_t:
            saved: int = 0
            total: int = 0
            skipped: int = 0
            error: int = 0

        @dataclasses_json.dataclass_json
        @dataclasses.dataclass
        class video_t:
            meta: Optional[dict[str, Any]] = None
            processed_path: Optional[str] = None


        stats: stats_t = dataclasses.field(default_factory=stats_t)
        videos: list[video_t] = dataclasses.field(default_factory=list)

@shared_task()
def tiktok_videos_process(meta: Iterable[dict[str, Any]]) -> tiktok_videos_process_t.res_t:
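    """Re-encode each downloaded video with the audio_get() track muxed in.

    Writes '<name>-proc.<ext>' next to the source file via ffmpeg; inputs
    that are missing or already processed are skipped.
    """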
    import tqdm

    res = tiktok_videos_process_t.res_t(
        videos=[],
    )

    song = audio_get()

    for o in tqdm.tqdm(meta):
        res.stats.total += 1
        res.videos.append(tiktok_videos_process_t.res_t.video_t())

        res.videos[-1].meta = o

        path = os.path.join(
            o['result_dir'],
            o['fname'],
        )

        try:
            path_parts = os.path.splitext(path)

            processed_path = path_parts[0] + '-proc' + path_parts[1]
            processed_path_tmp = path_parts[0] + '-proc.tmp' + path_parts[1]

            if os.path.exists(processed_path):
                res.videos[-1].processed_path = processed_path

            if not os.path.exists(path) or os.path.exists(processed_path):
                res.stats.skipped += 1
                continue

            if os.path.exists(processed_path_tmp):
                os.unlink(processed_path_tmp)

            # Re-encode: slow playback ~10% (setpts), upscale ~11% with an
            # even height (required by most encoders), trim to the shorter
            # input, and take video from the clip plus audio from the song.
            ffmpeg = [
                'ffmpeg',
                '-i', path,
                '-i', song.path_mp3,
                '-shortest',
                '-vf',
                ','.join([
                    'setpts=1.1*PTS',
                    'scale=trunc(iw/0.9):trunc(ow/a/2)*2',
                ]),
                '-sws_flags', 'bilinear',
                '-map', '0:v:0',
                '-map', '1:a:0',
                processed_path_tmp,
            ]

            subprocess.check_call(
                ffmpeg,
                stdin=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                stdout=subprocess.DEVNULL
            )

            os.rename(processed_path_tmp, processed_path)

            if os.path.exists(processed_path):
                res.videos[-1].processed_path = processed_path

            res.stats.saved += 1
        except KeyboardInterrupt:
            break
        except Exception:
            logger.error(json.dumps(dict(
                msg=traceback.format_exc(),
            )))
            res.stats.error += 1

    return res

class audio_get_t:
    @dataclasses_json.dataclass_json
    @dataclasses.dataclass
    class res_t:
        file: str
        file_mp3: str
        path: str
        path_mp3: str
        url: str

@shared_task()
def audio_get() -> audio_get_t.res_t:
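    """Fetch the backing track with yt-dlp and transcode it to mp3.

    Both files are cached under tiktok_config().audios and reused if present.
    """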
    c = tiktok_config()
    url = 'https://www.youtube.com/watch?v=dQw4w9WgXcQ'
    file = 'song.dat'
    file_mp3 = 'song.mp3'

    path = os.path.join(c.audios, file)
    path_mp3 = os.path.join(c.audios, file_mp3)

    if not os.path.exists(path):
        subprocess.check_call([
            'yt-dlp',
            '-f', 'bestaudio',
            url,
            '-o', path,
        ])

    if not os.path.exists(path_mp3):
        subprocess.check_call([
            'ffmpeg',
            '-i', path,
            path_mp3,
        ])

    return audio_get_t.res_t(
        file=file,
        file_mp3=file_mp3,
        path=path,
        path_mp3=path_mp3,
        url=url,
    )

class process_graph_t:
    @dataclasses_json.dataclass_json
    @dataclasses.dataclass
    class res_t:
        ordered_nodes: list[str] = dataclasses.field(default_factory=list)
        done_ids: set[str] = dataclasses.field(default_factory=set)
        error_ids: set[str] = dataclasses.field(default_factory=set)
        task_ids: dict[str, str] = dataclasses.field(default_factory=dict)
        pending_ids: set[str] = dataclasses.field(default_factory=set)
        done_tasks: dict[str, celery.result.AsyncResult] = dataclasses.field(default_factory=dict)

@shared_task()
def process_graph(
    nodes: dict[str, Any],
    data_deps: dict[str, Iterable[str]],
    execution_deps: dict[str, Iterable[str]],
) -> process_graph_t.res_t:
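    """Run a DAG of celery tasks, feeding results along data edges.

    nodes maps node names to celery tasks; data_deps[b] lists the nodes
    whose results become b's positional arguments; execution_deps[b] adds
    ordering-only constraints. Nodes are dispatched in topological order,
    and the call blocks until every task finishes or one fails.
    """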
    import networkx

    g_data = networkx.DiGraph()
    g_execution = networkx.DiGraph()

    for v in nodes:
        g_data.add_node(v)
        g_execution.add_node(v)

    for b, deps in data_deps.items():
        for a in deps:
            g_data.add_edge(a, b)
            g_execution.add_edge(a, b)

    for b, deps in execution_deps.items():
        for a in deps:
            g_execution.add_edge(a, b)

    task_ids = dict()
    done_ids = set()
    error_ids = set()
    pending_ids = set()
    active_queue = collections.deque()

    ordered_nodes = list(networkx.topological_sort(g_execution))
    node_id = 0

    def wait_task() -> bool:
        task_id = active_queue.popleft()
        task = celery.result.AsyncResult(task_id)

        try:
            # get() blocks until the task finishes and re-raises its
            # exception on failure.
            task.get()
            if task.status == celery.states.SUCCESS:
                done_ids.add(task_id)
                return True
            # Finished without raising but in a non-success state
            # (e.g. REVOKED).
            error_ids.add(task_id)
            return False
        except Exception:
            error_ids.add(task_id)
            logger.error(json.dumps(dict(
                msg=traceback.format_exc(),
            )))
            return False
        finally:
            pending_ids.remove(task_id)

    while node_id < len(ordered_nodes) or len(pending_ids) > 0:
        if node_id < len(ordered_nodes):
            node = ordered_nodes[node_id]
        else:
            node = None

        if (
            # Either everything has been dispatched and we are draining
            # the queue, or this node still has a pending execution
            # dependency; in both cases wait on the oldest pending task.
            (len(pending_ids) > 0 and node_id == len(ordered_nodes)) or
            any(
                v in task_ids and task_ids[v] in pending_ids
                for v in g_execution.predecessors(node)
            )
        ):
            if wait_task():
                continue
            else:
                break
        else:
            # Results of finished data predecessors become the node's
            # positional arguments; apply_async takes them via args=.
            args = [
                celery.result.AsyncResult(
                    task_ids[v]
                ).result
                for v in g_data.predecessors(node)
            ]
            task = nodes[node].apply_async(args=args)
            task_ids[node] = task.id
            pending_ids.add(task.id)
            active_queue.append(task.id)
            node_id += 1

    return process_graph_t.res_t(
        ordered_nodes=ordered_nodes,
        done_ids=done_ids,
        done_tasks={
            k : celery.result.AsyncResult(task_ids[k])
            for k in nodes
            if task_ids.get(k) in done_ids
        },
        task_ids=task_ids,
        error_ids=error_ids,
        pending_ids=pending_ids,
    )
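

# A minimal usage sketch (hypothetical node names and query; assumes a
# configured celery app with running workers, and calls process_graph
# in-process while the node tasks execute on workers). The .s() signature
# binds the query; every other task receives its data predecessors' results:
#
#   process_graph(
#       nodes=dict(
#           links=tiktok_videos_links_get.s(query='cats'),
#           meta=tiktok_videos_meta,
#           fetch=tiktok_videos_fetch,
#           process=tiktok_videos_process,
#       ),
#       data_deps=dict(
#           meta=['links'],
#           fetch=['meta'],
#           process=['meta'],
#       ),
#       execution_deps=dict(
#           process=['fetch'],
#       ),
#   )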