[+] improve typing for worker.py

This commit is contained in:
Siarhei Siniak 2025-07-24 10:26:47 +03:00
parent c4eb8b5568
commit 97e0270550
2 changed files with 9 additions and 8 deletions

@ -35,6 +35,7 @@ venv:
pyright: pyright:
$(ENV_PATH)/bin/python3 -m pyright \ $(ENV_PATH)/bin/python3 -m pyright \
-p pyproject.toml \ -p pyproject.toml \
--threads 3 \
--pythonpath $(PYTHON_PATH) --pythonpath $(PYTHON_PATH)
ruff_check: ruff_check:

@ -4,15 +4,15 @@ import transformers.pipelines
from typing import (Any, cast, Callable, Protocol, Literal,) from typing import (Any, cast, Callable, Protocol, Literal,)
class SummarizerPipeline(Protocol): class SummarizerPipeline(Protocol):
def predict(data: str) -> str: ... def predict(self, data: str) -> str: ...
class Pipeline(Protocol): class Pipeline(Protocol):
def __call__( def __call__(
self, self,
task: Literal['summarizer'], task: Literal['summarization'],
model: Any, model: Any,
tokenizer: Any, tokenizer: Any,
) -> Summarizer: ... ) -> 'SummarizerPipeline': ...
class Summarizer: class Summarizer:
def __init__(self) -> None: def __init__(self) -> None:
@ -31,11 +31,11 @@ class Summarizer:
self.summarizer = cast( self.summarizer = cast(
Pipeline, Pipeline,
getattr(transformers.pipelines, 'pipeline')( getattr(transformers.pipelines, 'pipeline')
)(
'summarization', 'summarization',
model=model, model=self.model,
tokenizer=tokenizer, tokenizer=self.tokenizer,
)
) )
def summarize( def summarize(