[+] improve worker.py
This commit is contained in:
parent
97e0270550
commit
64cfccd353
43
deps/test-task-2025-07-17-v2/python/online/fxreader/pr34/test_task_2025_07_17_v2/transform/worker.py
vendored
43
deps/test-task-2025-07-17-v2/python/online/fxreader/pr34/test_task_2025_07_17_v2/transform/worker.py
vendored
@ -1,10 +1,19 @@
|
|||||||
import transformers
|
import transformers
|
||||||
import transformers.pipelines
|
import transformers.pipelines
|
||||||
|
|
||||||
from typing import (Any, cast, Callable, Protocol, Literal,)
|
from typing import (
|
||||||
|
Any, cast, Callable, Protocol, Literal, TypedDict,
|
||||||
|
TypeAlias,
|
||||||
|
)
|
||||||
|
|
||||||
class SummarizerPipeline(Protocol):
|
class SummarizerPipeline(Protocol):
|
||||||
def predict(self, data: str) -> str: ...
|
class predict_t:
|
||||||
|
class output_t(TypedDict):
|
||||||
|
summary_text: str
|
||||||
|
|
||||||
|
res_t : TypeAlias = list[output_t]
|
||||||
|
|
||||||
|
def predict(self, data: str) -> predict_t.res_t: ...
|
||||||
|
|
||||||
class Pipeline(Protocol):
|
class Pipeline(Protocol):
|
||||||
def __call__(
|
def __call__(
|
||||||
@ -14,34 +23,40 @@ class Pipeline(Protocol):
|
|||||||
tokenizer: Any,
|
tokenizer: Any,
|
||||||
) -> 'SummarizerPipeline': ...
|
) -> 'SummarizerPipeline': ...
|
||||||
|
|
||||||
|
|
||||||
class Summarizer:
|
class Summarizer:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.model = cast(
|
|
||||||
Callable[[str], Any],
|
|
||||||
getattr(transformers.AutoTokenizer, 'from_pretrained')(
|
|
||||||
'sshleifer/distilbart-cnn-12-6',
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.tokenizer = cast(
|
self.tokenizer = cast(
|
||||||
Callable[[str], Any],
|
Callable[[str], Any],
|
||||||
getattr(transformers.AutoModelForSeq2SeqLM, 'from_pretrained')(
|
getattr(transformers.AutoTokenizer, 'from_pretrained')
|
||||||
'sshleifer/distilbart-cnn-12-6',
|
)(
|
||||||
)
|
'sshleifer/distilbart-cnn-12-6',
|
||||||
|
)
|
||||||
|
self.model = cast(
|
||||||
|
Callable[[str], Any],
|
||||||
|
getattr(transformers.AutoModelForSeq2SeqLM, 'from_pretrained')
|
||||||
|
)(
|
||||||
|
'sshleifer/distilbart-cnn-12-6',
|
||||||
)
|
)
|
||||||
|
|
||||||
self.summarizer = cast(
|
self.summarizer = cast(
|
||||||
Pipeline,
|
Pipeline,
|
||||||
getattr(transformers.pipelines, 'pipeline')
|
# getattr(transformers.pipelines, 'pipeline')
|
||||||
|
getattr(transformers, 'pipeline')
|
||||||
)(
|
)(
|
||||||
'summarization',
|
'summarization',
|
||||||
model=self.model,
|
model=self.model,
|
||||||
tokenizer=self.tokenizer,
|
tokenizer=self.tokenizer,
|
||||||
|
# framework='pt',
|
||||||
)
|
)
|
||||||
|
|
||||||
def summarize(
|
def summarize(
|
||||||
self,
|
self,
|
||||||
data: list[str]
|
data: list[str]
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
return self.summarizer.predict(
|
res = self.summarizer.predict(
|
||||||
' '.join(data)
|
' '.join(data)
|
||||||
).split()
|
)
|
||||||
|
assert len(res) == 1
|
||||||
|
|
||||||
|
return res[0]['summary_text'].split()
|
||||||
|
Loading…
Reference in New Issue
Block a user