import celery.app.task
import celery.backends.redis
import datetime
import os
import asyncio
import inspect
import importlib
import kombu.utils.json
from typing import (
    Any,
    Optional,
    Callable,
    Iterable,
)

def shared_task(func: Optional[Callable[..., Any]]=None, **kwargs: Any) -> Any:
    """Declare a Celery shared task that uses this module's ``Task`` base class.

    Works both bare and with options::

        @shared_task
        def a(): ...

        @shared_task(name='b')
        async def b(): ...

    Args:
        func: the function being decorated when used bare (``@shared_task``);
            ``None`` when called with options (``@shared_task(...)``).
        **kwargs: options forwarded to ``celery.shared_task``.
            ``track_started`` defaults to ``True`` but may be overridden.
            ``base`` is always ``Task``, so passing it is an error.

    Returns:
        The task when ``func`` is given, otherwise a decorator producing one.
    """
    # setdefault instead of a literal keyword so callers can override it
    # without a "multiple values for keyword argument" TypeError.
    kwargs.setdefault('track_started', True)
    decorator = celery.shared_task(base=Task, **kwargs)

    if func is None:
        return decorator
    # Bare usage: apply the decorator right away, otherwise the decorated
    # name would be bound to the decorator itself rather than a task.
    return decorator(func)

def is_async() -> bool:
    """Report whether an asyncio event loop is running in the current thread.

    Returns:
        ``True`` when called from code executing inside a running loop,
        ``False`` otherwise.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # get_running_loop() raises RuntimeError when no loop is running.
        return False
    return True

class Backend(celery.backends.redis.RedisBackend):
    """Redis result backend that records task start and finish timestamps.

    On top of ``celery.backends.redis.RedisBackend`` it:

    * stamps ``date_started`` on metadata stored with state ``STARTED``;
    * stamps ``_date_done`` on metadata stored with a ready state;
    * carries both timestamps over from the metadata already stored for the
      task, so the first recorded value wins on later state updates;
    * never overwrites a stored ``SUCCESS`` result.
    """

    def _store_result(self, task_id, result, state,
                      traceback=None, request=None, **kwargs):
        """Persist ``result``/``state`` for ``task_id`` unless it already succeeded.

        Returns:
            ``result`` unchanged, whether or not anything was written.

        Raises:
            celery.backends.base.BackendStoreError: re-raised with ``state``
                and ``task_id`` attached when the underlying write fails.
        """
        # Retrieve metadata from the backend, if the status
        # is a success then we ignore any following update to the state.
        # This solves a task deduplication issue because of network
        # partitioning or lost workers. This issue involved a race condition
        # making a lost task overwrite the last successful result in the
        # result backend.
        # The same snapshot is handed to _get_result_meta below so the
        # timestamps are carried over without a second Redis round-trip.
        current_meta = self._get_task_meta_for(task_id)

        if current_meta['status'] == celery.states.SUCCESS:
            return result

        meta = self._get_result_meta(
            result=result,
            state=state,
            traceback=traceback,
            request=request,
            task_id=task_id,
            current_meta=current_meta,
        )
        meta['task_id'] = celery.backends.base.bytes_to_str(task_id)

        try:
            self._set_with_state(self.get_key_for_task(task_id), self.encode(meta), state)
        except celery.backends.base.BackendStoreError as ex:
            raise celery.backends.base.BackendStoreError(
                str(ex), state=state, task_id=task_id) from ex

        return result

    def _get_result_meta(
        self,
        *args,
        task_id: Optional[str]=None,
        state: Optional[str]=None,
        current_meta: Optional[dict]=None,
        **kwargs
    ):
        """Build result metadata, adding and preserving start/finish timestamps.

        Args:
            *args: forwarded to the base implementation.
            task_id: when given (and ``current_meta`` is not), the metadata
                currently stored for this task is fetched so its timestamps
                can be carried over.
            state: the state being stored; forwarded to the base class.
            current_meta: already-fetched stored metadata for the task, used
                instead of reading it again.
            **kwargs: forwarded to the base implementation.

        Returns:
            The base class's metadata dict with ``date_started`` and/or
            ``_date_done`` added or carried over.
        """
        if current_meta is None and task_id is not None:
            current_meta = self._get_task_meta_for(task_id)

        res = super()._get_result_meta(*args, state=state, **kwargs)

        # NOTE(review): these are naive local-time datetimes — confirm that
        # is intended rather than the app's timezone-aware clock.
        if state == celery.states.STARTED and 'date_started' not in res:
            res['date_started'] = datetime.datetime.now()

        if state in celery.states.READY_STATES and '_date_done' not in res:
            res['_date_done'] = datetime.datetime.now()

        # Previously stored timestamps win, so a task keeps its first
        # start/finish time across later state updates. Without a task id
        # there is nothing stored to carry over (previously a TypeError).
        if current_meta is not None:
            for key in ('date_started', '_date_done'):
                if key in current_meta:
                    res[key] = current_meta[key]

        return res

class Task(celery.app.task.Task):
    """Celery task base class used by ``shared_task``.

    Adds two things on top of ``celery.app.task.Task``:

    * ``async def`` task bodies: when the body returns an awaitable and no
      event loop is running in the current thread, it is run to completion
      with ``asyncio.run``.
    * ``_dumps`` / ``_loads``: a JSON encoding that round-trips objects
      exposing ``to_dict()`` (registered with kombu by
      ``kombu_register_json_dataclass``).
    """

    def __call__(self, *args, **kwargs) -> Any:
        """Invoke the task body, driving awaitable results to completion.

        If the body returns an awaitable and no event loop is running in this
        thread, it is executed with ``asyncio.run`` and its result returned.
        Inside a running loop the awaitable is returned unchanged so the
        caller can await it.
        """
        res = super().__call__(*args, **kwargs)
        if inspect.isawaitable(res) and not is_async():
            if not inspect.iscoroutine(res):
                # asyncio.run() accepts only coroutine objects; adapt other
                # awaitables (objects defining __await__) into one.
                res = self._as_coroutine(res)
            return asyncio.run(res)
        return res

    @staticmethod
    async def _as_coroutine(awaitable: Any) -> Any:
        """Await ``awaitable`` from inside a genuine coroutine."""
        return await awaitable

    @classmethod
    def _loads(
        cls,
        data_str: Optional[str]=None,
        data: Optional[Any]=None,
    ) -> Any:
        """Decode a payload produced by ``_dumps``.

        Args:
            data_str: JSON text to parse. When given it takes precedence
                over ``data``.
            data: an already-parsed value (used for recursion).

        Returns:
            The decoded value. ``dataclass_json`` markers are turned back into
            instances by calling the recorded class with the stored dict as
            keyword arguments. Lists and dicts are decoded recursively.

        Raises:
            TypeError: if a ``dataclass_json`` marker names something that is
                not a class.
        """
        if data_str is not None:
            data = kombu.utils.json.loads(data_str)

        if isinstance(data, dict) and data.get('type') == 'dataclass_json':
            # SECURITY NOTE(review): the module and class come from the message
            # itself, so whoever can publish to the broker can import any
            # installed module and construct any class with chosen keyword
            # arguments. Only classes are accepted (all that _dumps emits),
            # which blocks calling arbitrary functions such as os.system.
            # Consider an allow-list if the broker is not fully trusted.
            target = importlib.import_module(data['module'])
            for attr_name in data['_class'].split('.'):
                target = getattr(target, attr_name)

            if not isinstance(target, type):
                raise TypeError(
                    'dataclass_json payload must name a class, got %s.%s'
                    % (data['module'], data['_class'])
                )

            return target(**data['data'])

        if isinstance(data, list):
            return [cls._loads(data=item) for item in data]
        if isinstance(data, dict):
            return {key: cls._loads(data=value) for key, value in data.items()}
        return data

    @classmethod
    def _dumps(
        cls,
        data: Any,
        need_native: Optional[bool]=None,
    ) -> Any:
        """Encode ``data`` for the ``json-dataclass`` serializer.

        Objects with a ``to_dict()`` method become a marker dict holding
        ``type='dataclass_json'``, their class's module and qualified name,
        and the ``to_dict()`` output. Lists, tuples and dicts are encoded
        recursively (tuples become lists). Anything else passes through to
        kombu's JSON encoder.

        Args:
            data: value to encode.
            need_native: when true, return the JSON-compatible Python structure
                instead of a JSON string. ``None`` (the default) means false.

        Returns:
            A JSON string, or the native structure when ``need_native``.
        """
        if need_native is None:
            need_native = False

        if hasattr(data, 'to_dict'):
            # to_dict() output is stored as-is (not recursed into); _loads
            # hands it back to the class constructor as keyword arguments.
            native = dict(
                type='dataclass_json',
                module=data.__class__.__module__,
                _class=data.__class__.__qualname__,
                data=data.to_dict(),
            )
        elif isinstance(data, (list, tuple)):
            native = [cls._dumps(o, need_native=True) for o in data]
        elif isinstance(data, dict):
            native = {k: cls._dumps(v, need_native=True) for k, v in data.items()}
        else:
            native = data

        if need_native:
            return native
        return kombu.utils.json.dumps(native)

def kombu_register_json_dataclass():
    """Register ``Task._dumps`` / ``Task._loads`` with kombu as ``json-dataclass``.

    After this call the name ``'json-dataclass'`` can be selected as a
    serializer. The payload is tagged ``application/json`` with ``utf-8``
    encoding.

    NOTE(review): kombu presumably keys decoders by content type, so this
    likely also takes over decoding of every ``application/json`` payload
    (``Task._loads`` accepts plain JSON too) — confirm that is intended.
    """
    import kombu.serialization as kombu_serialization

    encoder, decoder = Task._dumps, Task._loads
    kombu_serialization.register(
        'json-dataclass',
        encoder,
        decoder,
        content_type='application/json',
        content_encoding='utf-8',
    )

def grid_of_videos(paths: Iterable[str]) -> Any:
    """Lay out one video player per path in a single-row ipywidgets grid.

    Intended for Jupyter. Each video is referenced through the server's
    ``/files/`` URL prefix rather than embedded.

    Args:
        paths: video file paths. Any iterable works; it is materialized once,
            so generators are fine.

    Returns:
        A ``GridspecLayout`` with one row and one column per path. Each cell is
        an ``Output`` widget displaying the corresponding video.

    Raises:
        FileNotFoundError: if any path does not exist on the local filesystem.
    """
    from urllib.parse import quote

    # The grid needs the count up front, which a generic Iterable (e.g. a
    # generator) cannot provide via len().
    paths = list(paths)

    # Validate before importing or building any widgets. Use an explicit
    # exception rather than ``assert`` so the check survives ``python -O``.
    for path in paths:
        if not os.path.exists(path):
            raise FileNotFoundError('video not found: %s' % path)

    from ipywidgets import Output, GridspecLayout
    from IPython import display

    grid = GridspecLayout(1, len(paths))

    for i, path in enumerate(paths):
        out = Output()
        with out:
            # Quote so characters such as '#' or '?' in file names do not
            # truncate the URL. NOTE(review): '/files/' resolves against the
            # notebook server root, so paths presumably must be relative to
            # that root — confirm.
            display.display(display.Video(url='/files/%s' % quote(path)))

        grid[0, i] = out

    return grid