[+] improve worker.py
This commit is contained in:
parent
97e0270550
commit
64cfccd353
43
deps/test-task-2025-07-17-v2/python/online/fxreader/pr34/test_task_2025_07_17_v2/transform/worker.py
vendored
43
deps/test-task-2025-07-17-v2/python/online/fxreader/pr34/test_task_2025_07_17_v2/transform/worker.py
vendored
@ -1,10 +1,19 @@
|
||||
import transformers
|
||||
import transformers.pipelines
|
||||
|
||||
from typing import (Any, cast, Callable, Protocol, Literal,)
|
||||
from typing import (
|
||||
Any, cast, Callable, Protocol, Literal, TypedDict,
|
||||
TypeAlias,
|
||||
)
|
||||
|
||||
class SummarizerPipeline(Protocol):
|
||||
def predict(self, data: str) -> str: ...
|
||||
class predict_t:
|
||||
class output_t(TypedDict):
|
||||
summary_text: str
|
||||
|
||||
res_t : TypeAlias = list[output_t]
|
||||
|
||||
def predict(self, data: str) -> predict_t.res_t: ...
|
||||
|
||||
class Pipeline(Protocol):
|
||||
def __call__(
|
||||
@ -14,34 +23,40 @@ class Pipeline(Protocol):
|
||||
tokenizer: Any,
|
||||
) -> 'SummarizerPipeline': ...
|
||||
|
||||
|
||||
class Summarizer:
    """Thin wrapper around a Hugging Face summarization pipeline.

    Loads the `sshleifer/distilbart-cnn-12-6` checkpoint for both the
    tokenizer and the seq2seq model, then builds a `summarization`
    pipeline from them.
    """

    def __init__(self) -> None:
        # The `getattr` indirection plus `cast` keeps static type
        # checkers happy with transformers' dynamically populated
        # module attributes.
        self.tokenizer = cast(
            Callable[[str], Any],
            getattr(transformers.AutoTokenizer, 'from_pretrained'),
        )(
            'sshleifer/distilbart-cnn-12-6',
        )
        self.model = cast(
            Callable[[str], Any],
            getattr(transformers.AutoModelForSeq2SeqLM, 'from_pretrained'),
        )(
            'sshleifer/distilbart-cnn-12-6',
        )

        self.summarizer = cast(
            Pipeline,
            getattr(transformers, 'pipeline'),
        )(
            'summarization',
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def summarize(
        self,
        data: list[str],
    ) -> list[str]:
        """Summarize `data` (joined into one document) and return the
        summary split into whitespace-separated tokens.
        """
        res = self.summarizer.predict(
            ' '.join(data),
        )
        # A single document is fed to the pipeline, so exactly one
        # summary entry is expected back.
        assert len(res) == 1

        return res[0]['summary_text'].split()
Loading…
Reference in New Issue
Block a user