Spaces:
Sleeping
Sleeping
File size: 8,814 Bytes
58bde27 c9c338a 58bde27 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 |
import logging
import logging.config
from functools import wraps
from time import time
from typing import Dict, List, Optional

import feedparser
import torch
from bs4 import BeautifulSoup
from pydantic import HttpUrl
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    pipeline,
)

from config import LANGUAGES, LANG_LEX_2_CODE
from logging_conf import LOGGING_CONFIG
# Apply the project-wide logging configuration before any logger is created
logging.config.dictConfig(LOGGING_CONFIG)
# Module-level logger; the name matches an entry in LOGGING_CONFIG
logger = logging.getLogger("src.task_management")
def proc_timer(f):
    """Decorator that logs the wall-clock duration of each call to ``f``.

    Args:
        f: the function to time.

    Returns:
        The wrapped function; it forwards all arguments and the return value
        unchanged, logging the function name, arguments and elapsed seconds.
    """

    @wraps(f)
    def wrapper(*args, **kw):
        ts = time()
        result = f(*args, **kw)
        te = time()
        # Lazy %-style args: the message is only built if INFO is enabled.
        # (The original mixed an f-string with a leftover "%2.4f" printf
        # placeholder, so every message ended with a literal ":%2.4f sec".)
        logger.info(
            "func:%s args:[%s, %s] took: %.4f sec", f.__name__, args, kw, te - ts
        )
        return result

    return wrapper
class TaskManager:
    """TaskManager class managing the summarization, translation,
    feed-parsing and other necessary processing tasks
    """

    def __init__(self):
        """Load the summarization model/tokenizer and the translation
        pipeline, placing both on GPU when CUDA is available."""
        # The translation languages supported by our application
        # (lexical names, e.g. values of the LANGUAGES mapping)
        self.supported_langs = LANGUAGES.values()
        # Summarization model: facebook/bart-large-cnn
        summarization_model_name = "facebook/bart-large-cnn"
        # Run summarization on GPU if available, otherwise CPU
        self.summarization_device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.summarization_config = AutoConfig.from_pretrained(summarization_model_name)
        self.summarizer = AutoModelForSeq2SeqLM.from_pretrained(
            summarization_model_name
        ).to(self.summarization_device)
        self.summarization_tokenizer = AutoTokenizer.from_pretrained(
            summarization_model_name
        )
        # Translation device is chosen the same way as the summarization one
        self.translation_device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        # Translation pipeline for model facebook/nllb-200-distilled-1.3B
        self.translator = pipeline(
            "translation",
            model="facebook/nllb-200-distilled-1.3B",
            device=self.translation_device,
        )

    # @proc_timer
    def summarize(
        self, txt_to_summarize: str, max_length: int = 30, min_length: int = 10
    ) -> str:
        """Summarization task, used for summarizing the provided text.

        Args:
            txt_to_summarize (str): the text that needs to be summarized
            max_length (int, optional): upper bound on the summary length. Defaults to 30.
            min_length (int, optional): lower bound on the summary length. Defaults to 10.

        Returns:
            str: the summarized text (empty string for empty input)
        """
        # Nothing to summarize: avoid calling generate() with degenerate lengths
        if not txt_to_summarize:
            return ""
        full_text_length = len(txt_to_summarize)
        # Adapt max and min lengths for the summary if larger than they should be.
        # NOTE(review): this mixes characters (full_text_length) with model
        # tokens (config.max_length); kept as-is to preserve behavior.
        max_perc_init_length = round(full_text_length * 0.3)
        max_length = (
            max_perc_init_length
            if self.summarization_config.max_length > 0.5 * full_text_length
            else max(max_length, self.summarization_config.max_length)
        )
        # Min length is the minimum of the following two:
        #  - the min-to-max default config ratio, applied to the effective max
        #  - the default config minimum value
        min_to_max_perc = (
            self.summarization_config.min_length / self.summarization_config.max_length
        )
        min_length = min(
            round(min_to_max_perc * max_length), self.summarization_config.min_length
        )
        # Guard: for very short inputs the computed min could exceed the
        # computed max, which would break generate(); clamp it.
        min_length = min(min_length, max_length)
        # Tokenize the input, truncating to the model's 1024-token window
        inputs = self.summarization_tokenizer(
            txt_to_summarize, return_tensors="pt", max_length=1024, truncation=True
        ).to(self.summarization_device)
        # Generate the summary with the adapted length bounds
        summary_ids = self.summarizer.generate(
            inputs["input_ids"],
            max_length=max_length,
            min_length=min_length,
            num_beams=4,  # beam search for better quality
            early_stopping=True,  # stop early once EOS is reached
        )
        # Decode the summary back to plain text
        summary_txt = self.summarization_tokenizer.decode(
            summary_ids[0], skip_special_tokens=True
        )
        return summary_txt

    # @proc_timer
    def translate(self, txt_to_translate: str, src_lang: str, tgt_lang: str) -> str:
        """Translate the provided text from a source language to a target language.

        Args:
            txt_to_translate (str): the text to translate
            src_lang (str): the source language of the initial text
            tgt_lang (str): the target language the initial text should be translated to

        Raises:
            RuntimeError: error in case of unsupported source language
            RuntimeError: error in case of unsupported target language
            RuntimeError: error in case of translation failure

        Returns:
            str: the translated text
        """
        # Raise an error in case of unsupported languages
        if src_lang not in self.supported_langs:
            raise RuntimeError("Unsupported source language.")
        if tgt_lang not in self.supported_langs:
            raise RuntimeError("Unsupported target language.")
        # Map lexical language names to NLLB language codes
        # (fall back to the given value if it is already a code)
        src_lang = LANG_LEX_2_CODE.get(src_lang, src_lang)
        tgt_lang = LANG_LEX_2_CODE.get(tgt_lang, tgt_lang)
        # Translate the text using the NLLB model
        translated_text = self.translator(
            txt_to_translate, src_lang=src_lang, tgt_lang=tgt_lang, batch_size=10
        )[0]["translation_text"]
        # If something goes wrong with the translation raise an error
        if not translated_text:
            raise RuntimeError("Failed to generate translation.")
        return translated_text

    def parse_and_process_feed(
        self,
        rss_url: HttpUrl,
        src_lang: str,
        tgt_lang: str,
        entries_limit: Optional[int] = None,
    ) -> List[Dict]:
        """Parse the input feed and process the feed entries, keeping the important
        information, summarizing and translating it.

        Args:
            rss_url (HttpUrl): the feed url to parse
            src_lang (str): the feed's initial language
            tgt_lang (str): the target language to which the content will be translated
            entries_limit (Optional[int], optional): the number of feed entries to be
                processed. Defaults to None (process all).

        Returns:
            List[Dict]: a list of dictionaries, each one containing the processed info
                regarding title, author, content and link for the respective feed entry
        """
        # Normalize language arguments (accept either codes or lexical names)
        src_lang = LANGUAGES.get(src_lang, src_lang)
        tgt_lang = LANGUAGES.get(tgt_lang, tgt_lang)
        default_lang = LANGUAGES.get("en", "en")
        # str() cast: pydantic v2 HttpUrl is not a str subclass, and
        # feedparser expects a plain string URL
        feed = feedparser.parse(str(rss_url))
        # A None limit (or one exceeding the entry count) keeps all entries
        processed_entries = feed.entries[:entries_limit]
        # Iterate over each entry in the feed
        for entry in processed_entries:
            title = entry.get("title", "")
            author = entry.get("author", "")
            link = entry.get("link", "")
            # Prefer "summary", then "content", then "description"
            content = entry.get(
                "summary", entry.get("content", entry.get("description", ""))
            )
            # Strip HTML markup, keeping only the text nodes
            # (get_text() replaces the deprecated findAll(text=True) join)
            soup = BeautifulSoup(content, features="html.parser")
            content = soup.get_text()
            # If the source language is not English, first translate the
            # content to English so the (English-only) summarizer can run
            if src_lang != default_lang:
                content = self.translate(
                    content, src_lang=src_lang, tgt_lang=default_lang
                )
            # Summarize the content
            summarized_content = self.summarize(content, max_length=30, min_length=10)
            # Translate the title directly from source to target language
            translated_title = self.translate(
                title, src_lang=src_lang, tgt_lang=tgt_lang
            )
            # Unless the target language is already the default (English,
            # which the summary is in), translate the summary to it
            translated_content = (
                self.translate(
                    summarized_content, src_lang=default_lang, tgt_lang=tgt_lang
                )
                if tgt_lang != default_lang
                else summarized_content
            )
            # Update the entry in place with the processed fields
            entry.update(
                {
                    "title": translated_title,
                    "content": translated_content,
                    "author": author,
                    "link": link,
                }
            )
        return processed_entries
|