From 64cfccd353ceeba8ef7c153ef4dd67747a35ec4d Mon Sep 17 00:00:00 2001
From: Siarhei Siniak
Date: Thu, 24 Jul 2025 11:00:16 +0300
Subject: [PATCH] [+] improve worker.py

---
 .../transform/worker.py | 43 +++++++++++++------
 1 file changed, 29 insertions(+), 14 deletions(-)

diff --git a/deps/test-task-2025-07-17-v2/python/online/fxreader/pr34/test_task_2025_07_17_v2/transform/worker.py b/deps/test-task-2025-07-17-v2/python/online/fxreader/pr34/test_task_2025_07_17_v2/transform/worker.py
index 5e2f69e..12bd191 100644
--- a/deps/test-task-2025-07-17-v2/python/online/fxreader/pr34/test_task_2025_07_17_v2/transform/worker.py
+++ b/deps/test-task-2025-07-17-v2/python/online/fxreader/pr34/test_task_2025_07_17_v2/transform/worker.py
@@ -1,10 +1,19 @@
 import transformers
 import transformers.pipelines
-from typing import (Any, cast, Callable, Protocol, Literal,)
+from typing import (
+	Any, cast, Callable, Protocol, Literal, TypedDict,
+	TypeAlias,
+)
 
 class SummarizerPipeline(Protocol):
-	def predict(self, data: str) -> str: ...
+	class predict_t:
+		class output_t(TypedDict):
+			summary_text: str
+
+		res_t : TypeAlias = list[output_t]
+
+	def predict(self, data: str) -> predict_t.res_t: ...
 
 class Pipeline(Protocol):
 	def __call__(
 		self,
@@ -14,34 +23,40 @@ class Pipeline(Protocol):
 		tokenizer: Any,
 	) -> 'SummarizerPipeline': ...
 
+
 class Summarizer:
 	def __init__(self) -> None:
-		self.model = cast(
-			Callable[[str], Any],
-			getattr(transformers.AutoTokenizer, 'from_pretrained')(
-				'sshleifer/distilbart-cnn-12-6',
-			)
-		)
 		self.tokenizer = cast(
 			Callable[[str], Any],
-			getattr(transformers.AutoModelForSeq2SeqLM, 'from_pretrained')(
-				'sshleifer/distilbart-cnn-12-6',
-			)
+			getattr(transformers.AutoTokenizer, 'from_pretrained')
+		)(
+			'sshleifer/distilbart-cnn-12-6',
+		)
+		self.model = cast(
+			Callable[[str], Any],
+			getattr(transformers.AutoModelForSeq2SeqLM, 'from_pretrained')
+		)(
+			'sshleifer/distilbart-cnn-12-6',
 		)
 		self.summarizer = cast(
 			Pipeline,
-			getattr(transformers.pipelines, 'pipeline')
+			# getattr(transformers.pipelines, 'pipeline')
+			getattr(transformers, 'pipeline')
 		)(
 			'summarization',
 			model=self.model,
 			tokenizer=self.tokenizer,
+			# framework='pt',
 		)
 
 	def summarize(
 		self,
 		data: list[str]
 	) -> list[str]:
-		return self.summarizer.predict(
+		res = self.summarizer.predict(
 			' '.join(data)
-		).split()
+		)
+		assert len(res) == 1
+
+		return res[0]['summary_text'].split()
 
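
A minimal usage sketch of the patched Summarizer, assuming the module is importable under the package path implied by the diff header (online.fxreader.pr34.test_task_2025_07_17_v2.transform.worker); the import path and the input sentences below are illustrative assumptions, not part of the patch.

    from online.fxreader.pr34.test_task_2025_07_17_v2.transform.worker import Summarizer

    # Loads sshleifer/distilbart-cnn-12-6 and builds the Hugging Face
    # summarization pipeline, as done in Summarizer.__init__.
    summarizer = Summarizer()

    # summarize() joins the chunks into one string, runs the pipeline once,
    # and returns the single summary split on whitespace (a list of words).
    words = summarizer.summarize([
        'The quick brown fox jumps over the lazy dog.',  # illustrative input
        'It then takes a long nap in the sun.',
    ])
    print(' '.join(words))

The predict_t.output_t typing and the assert len(res) == 1 check rely on the standard shape of a Hugging Face summarization pipeline result for a single input, e.g. [{'summary_text': '...'}], which is why summarize() now indexes res[0]['summary_text'] instead of calling .split() on the raw return value.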