"""Summarization worker backed by a Hugging Face ``transformers`` pipeline."""

from typing import (Any, cast, Callable, Protocol, Literal,)

import transformers
import transformers.pipelines

# Checkpoint loaded when the caller does not request a different one.
DEFAULT_MODEL_NAME = 'sshleifer/distilbart-cnn-12-6'


class SummarizerPipeline(Protocol):
    """Structural type of the summarization pipeline object.

    Only the ``predict`` entry point used by ``Summarizer`` is described.
    """

    def predict(self, data: str) -> list[dict[str, str]]:
        """Summarize ``data``.

        Per the ``transformers`` documentation, a summarization pipeline
        returns one dict per input, with the text under ``'summary_text'``.
        """
        ...


class Pipeline(Protocol):
    """Structural type of the ``transformers.pipelines.pipeline`` factory."""

    def __call__(
        self,
        task: Literal['summarization'],
        model: Any,
        tokenizer: Any,
    ) -> SummarizerPipeline: ...


class Summarizer:
    """Load a seq2seq checkpoint once and summarize text with it.

    Attributes:
        tokenizer: tokenizer loaded via ``AutoTokenizer.from_pretrained``.
        model: model loaded via ``AutoModelForSeq2SeqLM.from_pretrained``.
        summarizer: the ``'summarization'`` pipeline built from both.
    """

    def __init__(self, model_name: str = DEFAULT_MODEL_NAME) -> None:
        """Load the tokenizer and model and build the pipeline.

        Args:
            model_name: checkpoint identifier passed to ``from_pretrained``;
                defaults to the checkpoint this module was written for.
        """
        # ``getattr`` is kept from the original code: it reaches the
        # ``transformers`` entry points without a direct attribute call.
        self.tokenizer = cast(
            Callable[..., Any],
            getattr(transformers.AutoTokenizer, 'from_pretrained')(
                model_name,
            )
        )
        self.model = cast(
            Callable[..., Any],
            getattr(transformers.AutoModelForSeq2SeqLM, 'from_pretrained')(
                model_name,
            )
        )

        # Cast the factory (not its result) to ``Pipeline`` so the object it
        # returns is typed as ``SummarizerPipeline``.
        make_pipeline = cast(
            Pipeline,
            getattr(transformers.pipelines, 'pipeline'),
        )
        self.summarizer: SummarizerPipeline = make_pipeline(
            'summarization',
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def summarize(
        self,
        data: list[str]
    ) -> list[str]:
        """Summarize the space-joined ``data`` and return the summary's words.

        Args:
            data: text fragments; they are joined with single spaces before
                being handed to the pipeline.

        Returns:
            The summary split on whitespace, or ``[]`` when ``data`` holds no
            non-whitespace text or the pipeline returns no result.
        """
        text = ' '.join(data)
        if not text.strip():
            # Nothing to summarize; skip the model call entirely.
            return []

        results = self.summarizer.predict(text)
        if not results:
            return []

        # One input string yields one result dict.
        return results[0]['summary_text'].split()