[+] improve worker.py

This commit is contained in:
Siarhei Siniak 2025-07-24 11:00:16 +03:00
parent 97e0270550
commit 64cfccd353

@ -1,10 +1,19 @@
import transformers
import transformers.pipelines
from typing import (Any, cast, Callable, Protocol, Literal,)
from typing import (
Any, cast, Callable, Protocol, Literal, TypedDict,
TypeAlias,
)
class SummarizerPipeline(Protocol):
def predict(self, data: str) -> str: ...
class predict_t:
class output_t(TypedDict):
summary_text: str
res_t : TypeAlias = list[output_t]
def predict(self, data: str) -> predict_t.res_t: ...
class Pipeline(Protocol):
def __call__(
@ -14,34 +23,40 @@ class Pipeline(Protocol):
tokenizer: Any,
) -> 'SummarizerPipeline': ...
class Summarizer:
    """Wrap a ``transformers`` summarization pipeline behind one method.

    The constructor loads a tokenizer and a seq2seq model from the same
    pretrained checkpoint and builds a ``'summarization'`` pipeline from
    them; ``summarize`` then feeds text through that pipeline.
    """

    # Checkpoint used for both the tokenizer and the model.
    MODEL_NAME = 'sshleifer/distilbart-cnn-12-6'

    def __init__(self) -> None:
        """Load the tokenizer and model, then build the pipeline."""
        # NOTE(review): the loaders are fetched with ``getattr`` and cast
        # to a callable, presumably to keep static type checkers quiet
        # about ``transformers`` attributes -- confirm before simplifying.
        load_tokenizer = cast(
            Callable[[str], Any],
            getattr(transformers.AutoTokenizer, 'from_pretrained'),
        )
        self.tokenizer = load_tokenizer(self.MODEL_NAME)

        load_model = cast(
            Callable[[str], Any],
            getattr(transformers.AutoModelForSeq2SeqLM, 'from_pretrained'),
        )
        self.model = load_model(self.MODEL_NAME)

        # ``pipeline`` is looked up on the top-level ``transformers``
        # package (not on ``transformers.pipelines``). No explicit
        # ``framework`` argument is passed.
        build_pipeline = cast(
            Pipeline,
            getattr(transformers, 'pipeline'),
        )
        self.summarizer = build_pipeline(
            'summarization',
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def summarize(
        self,
        data: list[str]
    ) -> list[str]:
        """Summarize ``data`` and return the summary as a list of words.

        The items of ``data`` are joined with single spaces into one
        string, which is passed to the pipeline's ``predict``. Exactly
        one result is expected; its ``'summary_text'`` is split on
        whitespace.

        Raises:
            AssertionError: if the pipeline does not return exactly one
                result.
        """
        res = self.summarizer.predict(' '.join(data))

        # One input string is submitted, so one result is expected.
        assert len(res) == 1, 'expected exactly one summarization result'

        return res[0]['summary_text'].split()