[+] improve worker.py

This commit is contained in:
Siarhei Siniak 2025-07-24 11:00:16 +03:00
parent 97e0270550
commit 64cfccd353

@ -1,10 +1,19 @@
import transformers import transformers
import transformers.pipelines import transformers.pipelines
from typing import (Any, cast, Callable, Protocol, Literal,) from typing import (
Any, cast, Callable, Protocol, Literal, TypedDict,
TypeAlias,
)
class SummarizerPipeline(Protocol): class SummarizerPipeline(Protocol):
def predict(self, data: str) -> str: ... class predict_t:
class output_t(TypedDict):
summary_text: str
res_t : TypeAlias = list[output_t]
def predict(self, data: str) -> predict_t.res_t: ...
class Pipeline(Protocol): class Pipeline(Protocol):
def __call__( def __call__(
@ -14,34 +23,40 @@ class Pipeline(Protocol):
tokenizer: Any, tokenizer: Any,
) -> 'SummarizerPipeline': ... ) -> 'SummarizerPipeline': ...
class Summarizer:
    """Text summarizer backed by a ``transformers`` summarization pipeline.

    The tokenizer and the seq2seq model are both loaded from the same
    pretrained checkpoint and handed to ``transformers.pipeline``.
    """

    def __init__(
        self,
        model_name: str = 'sshleifer/distilbart-cnn-12-6',
    ) -> None:
        """Load the tokenizer and model, then build the pipeline.

        Args:
            model_name: Checkpoint identifier passed to both
                ``from_pretrained`` calls.
        """
        # ``getattr`` + ``cast`` gives the untyped ``from_pretrained``
        # factories an explicit callable signature.
        self.tokenizer = cast(
            Callable[[str], Any],
            getattr(transformers.AutoTokenizer, 'from_pretrained'),
        )(
            model_name,
        )
        self.model = cast(
            Callable[[str], Any],
            getattr(transformers.AutoModelForSeq2SeqLM, 'from_pretrained'),
        )(
            model_name,
        )
        self.summarizer = cast(
            Pipeline,
            getattr(transformers, 'pipeline'),
        )(
            'summarization',
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def summarize(
        self,
        data: list[str]
    ) -> list[str]:
        """Summarize ``data`` and return the summary split on whitespace.

        Args:
            data: Text fragments; they are joined with single spaces before
                being passed to the pipeline.

        Returns:
            The whitespace-separated tokens of the summary text.

        Raises:
            AssertionError: If the pipeline does not return exactly one
                result entry.
        """
        res = self.summarizer.predict(
            ' '.join(data)
        )
        # Raised explicitly (not via ``assert``) so the check is still
        # performed when Python runs with ``-O``.
        if len(res) != 1:
            raise AssertionError(
                'expected exactly one summary, got %d' % len(res)
            )
        return res[0]['summary_text'].split()