import time
import glob
import io
import os
import numpy
import numpy.typing
import functools
import pathlib
import threading
import cython
import datetime

from typing import (Any, Optional, TypeVar, Type, cast)
# from scoping import scoping as s

def test(
    _id: int,
    T: float,
    a: numpy.ndarray[Any, numpy.dtype[numpy.int32]],
) -> None:
    """Increment ``a[_id]`` in batches until at least ``T`` seconds have passed.

    Prints ``'started'`` before the loop and, afterwards, a list holding
    ``'done'``, the start timestamp, the elapsed seconds and the final
    value of ``a[_id]``.

    Args:
        _id: Index of the slot in ``a`` that this call increments.
        T: Time budget in seconds; it is checked once per batch.
        a: Array whose element ``_id`` is incremented in place.
    """
    # Number of increments performed between two checks of the clock.
    batch_size = 1024 * 1024

    # NOTE(review): presumably `cython.nogil` acts as a plain context manager
    # when this file runs uncompiled -- confirm against the Cython version used.
    with cython.nogil:
        started_at = datetime.datetime.now()
        print('started')

        def elapsed() -> float:
            delta = datetime.datetime.now() - started_at
            return delta.total_seconds()

        while True:
            if not (elapsed() < T):
                break
            for _ in range(batch_size):
                a[_id] += 1

    print(['done', started_at, elapsed(), a[_id]])

# Type variable used by `build`: the stub class passed as `module` determines
# the static type of the extension module that `build` returns.
M = TypeVar('M', bound=Type[Any])

def build(content: str, module: M) -> M:
    """Compile Cython source text into an extension module and load it.

    The source is written to ``tmp/cython/<sha256>/_<sha256>.pyx`` under the
    current working directory, where ``<sha256>`` is the hex SHA-256 digest of
    the UTF-8 encoded ``content``. The extension is built into that same
    directory and then loaded.

    Args:
        content: Cython (``.pyx``) source code to compile.
        module: Stub describing the compiled module's interface. It is not
            used at runtime; the return value is only cast to its type.

    Returns:
        The loaded extension module, typed as ``M``.

    Raises:
        ImportError: If the build step completed but no ``_<sha256>*.so``
            file exists in the output directory.
    """
    import hashlib
    import Cython.Build.Inline

    sha256sum = hashlib.sha256(content.encode('utf-8')).hexdigest()
    module_name = '_%s' % sha256sum

    output_dir = (pathlib.Path('.') / 'tmp' / 'cython' / sha256sum).absolute()
    os.makedirs(str(output_dir), exist_ok=True)

    source_path = output_dir / ('%s.pyx' % module_name)
    if not source_path.exists():
        # Use the same encoding the hash was computed over, so the file on
        # disk does not depend on the locale's default encoding.
        with io.open(str(source_path), 'w', encoding='utf-8') as f:
            f.write(content)

    # The build step runs on every call, even when the directory already
    # existed (the original guard was always true).
    t1 = Cython.Build.Inline._get_build_extension()
    t1.extensions = Cython.Build.cythonize(str(source_path))
    # NOTE(review): build_temp is the filesystem root; presumably this makes
    # intermediate files for the absolute source path land next to the
    # source -- confirm before changing.
    t1.build_temp = str(pathlib.Path('/'))
    t1.build_lib = str(output_dir)
    t1.run()

    lib_pattern = str(output_dir / ('%s*.so' % module_name))
    lib_paths = glob.glob(lib_pattern)
    if not lib_paths:
        raise ImportError(
            'no compiled extension matching %s was produced' % lib_pattern
        )

    return cast(
        M,
        Cython.Build.Inline.load_dynamic(module_name, lib_paths[0]),
    )

def mypyc_build(file_path: pathlib.Path) -> Any:
    """Compile a Python file with mypyc when needed and load the extension.

    A build is triggered when no ``<stem>.cpython*.so`` exists next to
    ``file_path`` or when the existing library is older than the source.
    Intermediate files go under ``tmp/mypyc`` in the current working
    directory; the library is built with ``.`` as the library directory.

    Args:
        file_path: Path to the ``.py`` source. It should be relative to the
            current working directory: its parent is joined onto the build
            directory, and joining an absolute path would discard that prefix.

    Returns:
        The loaded extension module.

    Raises:
        ImportError: If no compiled library can be found after the build.
    """
    import mypyc.build
    import Cython.Build.Inline

    assert isinstance(file_path, pathlib.Path)

    output_dir = pathlib.Path('.') / 'tmp' / 'mypyc'
    module_name = file_path.stem
    # NOTE(review): the library is looked up beside the source while
    # build_lib is '.'; presumably mypyc derives the output location from the
    # source path relative to the working directory -- confirm.
    lib_pattern = file_path.parent / ('%s.cpython*.so' % module_name)
    lib_dir = pathlib.Path('.')

    def lib_path_glob(path: str | pathlib.Path) -> Optional[pathlib.Path]:
        # First match of the glob pattern, or None when nothing matches.
        matches = glob.glob(str(path))
        return pathlib.Path(matches[0]) if matches else None

    lib_path = lib_path_glob(lib_pattern)

    # Rebuild when the library is missing or older than the source file.
    need_build = (
        lib_path is None
        or lib_path.stat().st_mtime < file_path.stat().st_mtime
    )

    if need_build:
        for directory in (
            output_dir,
            output_dir / 'build' / file_path.parent,
        ):
            os.makedirs(str(directory), exist_ok=True)

        t1 = Cython.Build.Inline._get_build_extension()
        t1.extensions = mypyc.build.mypycify(
            [str(file_path)],
            target_dir=str(output_dir / 'build')
        )
        t1.build_temp = str(output_dir)
        t1.build_lib = str(lib_dir)
        t1.run()

        lib_path = lib_path_glob(lib_pattern)

    if lib_path is None:
        # Fail with a clear message instead of handing the string 'None' to
        # the loader.
        raise ImportError(
            'no compiled extension matching %s was produced' % lib_pattern
        )

    return Cython.Build.Inline.load_dynamic(
        module_name,
        str(lib_path),
    )

class Source:
    @staticmethod
    def test2(
        _a : numpy.ndarray[Any, numpy.dtype[numpy.int64]],
        _id : numpy.dtype[numpy.int32] | int,
        T : float=16
    ) -> int:
        raise NotImplementedError


# Compile the inline Cython source below when this module is imported and bind
# the loaded extension module to `source`. `Source` is passed only so that
# `build` can cast the result to that stub type. `build` hashes this string to
# name the build directory, so any edit to it (even whitespace) produces a new
# build.
source = build(r'''
cimport cython

@cython.boundscheck(False)
@cython.wraparound(False)
def test4(int[:] a, int[:] b):
    cdef int N = a.shape[0]
    assert N == b.shape[0]

    with cython.nogil:
        for i in range(N):
            a[i] += b[i]
    return N

import datetime

def elapsed(started_at: datetime.datetime):
    res = (datetime.datetime.now() - started_at).total_seconds()

    return res

@cython.boundscheck(False)  # Deactivate bounds checking
@cython.wraparound(False)   # Deactivate negative indexing.
def has_time(started_at: datetime.datetime, T: float):
    t1 = elapsed(started_at)

    res = t1 < T

    return res

@cython.boundscheck(False)
@cython.wraparound(False)
def test2(long long [:] _a,  int _id, double T=16) -> int:
    started_at = datetime.datetime.now()

    print('started')

    cdef int C = 1;

    cdef int cond;

    with cython.nogil:
    #if True:
        #a = 0
        while True:

            with cython.gil:
                cond = has_time(started_at, T)
                #cond = 0

                if cond != 1:
                    break

            #a += 1
            for k in range(1024 * 1024 * 1024):
                _a[_id] += C

    print(['done', started_at, elapsed(started_at), _a[_id]])

    return _a[_id]

''', Source)

def test_cython(N: int = 4, T: int = 16) -> None:
    """Run the compiled ``source.test2`` in ``N`` threads and wait for them.

    All threads receive the same int64 array of length ``N``; thread ``k`` is
    given index ``k`` and the time budget ``T``.

    Args:
        N: Number of threads, and length of the shared array.
        T: Value forwarded to ``source.test2`` as its ``T`` argument.
    """
    counters = numpy.zeros((N,), dtype=numpy.int64)

    workers = []
    for index in range(N):
        workers.append(
            threading.Thread(
                target=source.test2,
                args=(counters, index, T),
            )
        )

    # Start every worker before joining any of them.
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

def test_mypyc(N: int = 4, W: int = 35) -> None:
    """Build ``cython2.py`` with mypyc and run its ``fib`` in ``N`` threads.

    ``cython2.py`` is looked up next to this file and passed to
    ``mypyc_build`` as a path relative to the current working directory.

    Args:
        N: Number of threads to start.
        W: Argument passed to ``fib`` in every thread.
    """
    module_path = pathlib.Path(__file__).parent / 'cython2.py'
    cython2 = mypyc_build(module_path.relative_to(pathlib.Path.cwd()))

    workers = []
    for _ in range(N):
        workers.append(
            threading.Thread(target=cython2.fib, args=(W,))
        )

    # Start every worker before joining any of them.
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()