import logging
import logging.config

from typing import Dict, List, Optional

import feedparser
import torch

from bs4 import BeautifulSoup
from functools import wraps
from time import time
from pydantic import HttpUrl
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    pipeline,
)

from config import LANGUAGES, LANG_LEX_2_CODE
from logging_conf import LOGGING_CONFIG


logging.config.dictConfig(LOGGING_CONFIG)
logger = logging.getLogger("src.task_management")


def proc_timer(f):
    @wraps(f)
    def wrapper(*args, **kw):
        ts = time()
        result = f(*args, **kw)
        te = time()
        logger.info(f"func:{f.__name__} args:[{args}, {kw}] took: {te - ts}:%2.4f sec")
        return result

    return wrapper
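

# Example log line emitted by proc_timer (illustrative values only; for a
# method, args includes the instance as its first positional argument):
#   func:summarize args:[(<TaskManager ...>, '...'), {}] took: 1.2345 sec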


class TaskManager:
    """TaskManager class managing the summarization, translation,
    feed-parsing and other necessary processing tasks
    """

    def __init__(self):
        # Translation languages supported by our application
        self.supported_langs = LANGUAGES.values()

        # Load the bart-large-cnn model and tokenizer
        summarization_model_name = "facebook/bart-large-cnn"

        # Move the summarization model to the GPU if available, otherwise CPU
        self.summarization_device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

        self.summarization_config = AutoConfig.from_pretrained(summarization_model_name)

        self.summarizer = AutoModelForSeq2SeqLM.from_pretrained(
            summarization_model_name
        ).to(self.summarization_device)

        self.summarization_tokenizer = AutoTokenizer.from_pretrained(
            summarization_model_name
        )

        # Check if CUDA is available and set the device
        self.translation_device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

        # Load translation pipeline for model facebook/nllb-200-distilled-1.3B
        self.translator = pipeline(
            "translation",
            model="facebook/nllb-200-distilled-1.3B",
            device=self.translation_device,
        )

    # @proc_timer
    def summarize(
        self, txt_to_summarize: str, max_length: int = 30, min_length: int = 10
    ) -> str:
        """Summarization task, used for summarizing the provided text

        Args:
            txt_to_summarize (str): the text that needs to be summarized
            max_length (int, optional): upper bound on the summary length. Defaults to 30.
            min_length (int, optional): lower bound on the summary length. Defaults to 10.

        Returns:
            str: the summarized text
        """

        # Note: this is a character count, while the model's max_length and
        # min_length are token counts; the ratios below are rough heuristics
        full_text_length = len(txt_to_summarize)

        # Cap the summary length: if the model's default max_length exceeds
        # half the input length, use 30% of the input length; otherwise keep
        # the larger of the requested and default max_length
        max_perc_init_length = round(full_text_length * 0.3)
        max_length = (
            max_perc_init_length
            if self.summarization_config.max_length > 0.5 * full_text_length
            else max(max_length, self.summarization_config.max_length)
        )

        # min_length is the smaller of the two:
        # the config's min-to-max ratio applied to the effective max_length,
        # and the config's default min_length
        min_to_max_perc = (
            self.summarization_config.min_length / self.summarization_config.max_length
        )
        min_length = min(
            round(min_to_max_perc * max_length), self.summarization_config.min_length
        )
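
        # Worked example (assuming config defaults of max_length=142 and
        # min_length=56, as shipped with bart-large-cnn): for a 200-character
        # input, 142 > 0.5 * 200, so max_length becomes round(200 * 0.3) = 60,
        # and min_length becomes min(round(56 / 142 * 60), 56) = 24.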

        # Tokenize input
        inputs = self.summarization_tokenizer(
            txt_to_summarize, return_tensors="pt", max_length=1024, truncation=True
        ).to(self.summarization_device)

        # Generate summary with custom max_length
        summary_ids = self.summarizer.generate(
            inputs["input_ids"],
            max_length=max_length,  # Set max_length here
            min_length=min_length,  # Set min_length here
            num_beams=4,  # Optional: Use beam search
            early_stopping=True,  # Optional: Stop early if EOS is reached
        )

        # Decode the summary
        summary_txt = self.summarization_tokenizer.decode(
            summary_ids[0], skip_special_tokens=True
        )

        return summary_txt

    # @proc_timer
    def translate(self, txt_to_translate: str, src_lang: str, tgt_lang: str) -> str:
        """Translate the provided text from a source language to a target language

        Args:
            txt_to_translate (str): the text to translate
            src_lang (str): the source language of the initial text
            tgt_lang (str): the target language the initial text should be translated to

        Raises:
            RuntimeError: error in case of unsupported source language
            RuntimeError: error in case of unsupported target language
            RuntimeError: error in case of translation failure

        Returns:
            str: the translated text
        """

        # Raise error in case of unsupported languages
        if src_lang not in self.supported_langs:
            raise RuntimeError("Unsupported source language.")
        if tgt_lang not in self.supported_langs:
            raise RuntimeError("Unsupported target language.")

        # Map the language names to NLLB language codes, then translate
        src_lang = LANG_LEX_2_CODE.get(src_lang, src_lang)
        tgt_lang = LANG_LEX_2_CODE.get(tgt_lang, tgt_lang)
        translated_text = self.translator(
            txt_to_translate, src_lang=src_lang, tgt_lang=tgt_lang, batch_size=10
        )[0]["translation_text"]

        # Raise an error if the translation came back empty
        if not translated_text:
            raise RuntimeError("Failed to generate translation.")

        return translated_text

    def parse_and_process_feed(
        self,
        rss_url: HttpUrl,
        src_lang: str,
        tgt_lang: str,
        entries_limit: Optional[int] = None,
    ) -> List[Dict]:
        """Parse the input feed, and process the feed entries keeping the important information,
        summarizing and translating it

        Args:
            rss_url (HttpUrl): the feed url to parse
            src_lang (str): the feed's initial language
            tgt_lang (str): the target language to which the content will be translated
            entries_limit (int, optional): the number of feed-entries to be processed. Defaults to None (process all).

        Returns:
            List[Dict]: a list of dictionaries, each one containing the processed info regarding
            title, author, content and link for the respective feed entry
        """

        src_lang = LANGUAGES.get(src_lang, src_lang)
        tgt_lang = LANGUAGES.get(tgt_lang, tgt_lang)
        default_lang = LANGUAGES.get("en", "en")

        # Cast to str in case rss_url is a pydantic HttpUrl instance
        feed = feedparser.parse(str(rss_url))

        # A None limit keeps all entries; a limit beyond the feed length is
        # clamped automatically by the slice
        processed_entries = feed.entries[:entries_limit]

        # Iterate over each entry in the feed
        for entry in processed_entries:
            title = entry.get("title", "")
            author = entry.get("author", "")
            link = entry.get("link", "")
            content = entry.get(
                "summary", entry.get("content", entry.get("description", ""))
            )

            # Strip HTML tags, keeping only the text content
            soup = BeautifulSoup(content, features="html.parser")
            content = soup.get_text()

            # The summarization model works on English text: if the source
            # language is not English, translate the content to English first
            if src_lang != default_lang:
                content = self.translate(
                    content, src_lang=src_lang, tgt_lang=default_lang
                )

            # Summarize the content
            summarized_content = self.summarize(content, max_length=30, min_length=10)

            # Translate the title and summarized content
            translated_title = self.translate(
                title, src_lang=src_lang, tgt_lang=tgt_lang
            )

            # Unless the target language is already the default, translate it
            translated_content = (
                self.translate(
                    summarized_content, src_lang=default_lang, tgt_lang=tgt_lang
                )
                if tgt_lang != default_lang
                else summarized_content
            )

            # Update entry
            entry.update(
                {
                    "title": translated_title,
                    "content": translated_content,
                    "author": author,
                    "link": link,
                }
            )

        return processed_entries
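

if __name__ == "__main__":
    # Minimal usage sketch, not part of the service itself. Assumptions:
    # "english" and "german" stand in for values actually present in this
    # project's LANGUAGES mapping, and the feed URL is a placeholder.
    manager = TaskManager()

    sample_text = (
        "The quick brown fox jumps over the lazy dog. "
        "It then naps in the sun for the rest of the afternoon. "
    ) * 10

    summary = manager.summarize(sample_text, max_length=30, min_length=10)
    logger.info(f"Summary: {summary}")

    translation = manager.translate(summary, src_lang="english", tgt_lang="german")
    logger.info(f"Translation: {translation}")

    entries = manager.parse_and_process_feed(
        "https://example.com/feed.xml",  # placeholder feed URL
        src_lang="english",
        tgt_lang="german",
        entries_limit=2,
    )
    logger.info(f"Processed {len(entries)} feed entries")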