import logging
import logging.config
from functools import wraps
from time import time
from typing import Dict, List, Optional

import feedparser
import torch
from bs4 import BeautifulSoup
from pydantic import HttpUrl
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    pipeline,
)

from config import LANGUAGES, LANG_LEX_2_CODE
from logging_conf import LOGGING_CONFIG

logging.config.dictConfig(LOGGING_CONFIG)
logger = logging.getLogger("src.task_management")


def proc_timer(f):
    @wraps(f)
    def wrapper(*args, **kw):
        ts = time()
        result = f(*args, **kw)
        te = time()
        logger.info(f"func:{f.__name__} args:[{args}, {kw}] took: {te - ts:.4f} sec")
        return result

    return wrapper
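

# Usage sketch: apply the decorator to any callable whose wall-clock time
# should be logged (the function name below is hypothetical):
#
#     @proc_timer
#     def fetch_feed(url):
#         ...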


class TaskManager:
    """TaskManager class managing the summarization, translation,
    feed-parsing and other necessary processing tasks
    """

    def __init__(self):
        # The translation languages supported by our application
        self.supported_langs = LANGUAGES.values()
        # Load the bart-large-cnn model and tokenizer
        summarization_model_name = "facebook/bart-large-cnn"
        # Move the summarization model to the GPU if one is available
        self.summarization_device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.summarization_config = AutoConfig.from_pretrained(summarization_model_name)
        self.summarizer = AutoModelForSeq2SeqLM.from_pretrained(
            summarization_model_name
        ).to(self.summarization_device)
        self.summarization_tokenizer = AutoTokenizer.from_pretrained(
            summarization_model_name
        )
        # Check if CUDA is available and set the device for translation
        self.translation_device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        # Load the translation pipeline for facebook/nllb-200-distilled-1.3B
        self.translator = pipeline(
            "translation",
            model="facebook/nllb-200-distilled-1.3B",
            device=self.translation_device,
        )
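
        # Note: both checkpoints are downloaded from the Hugging Face Hub on
        # first use and cached locally, so the first instantiation can take
        # a while and needs several GB of disk space and memory.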

    # @proc_timer
    def summarize(
        self, txt_to_summarize: str, max_length: int = 30, min_length: int = 10
    ) -> str:
        """Summarization task, used for summarizing the provided text

        Args:
            txt_to_summarize (str): the text that needs to be summarized
            max_length (int, optional): the upper limit on the length of the summary. Defaults to 30.
            min_length (int, optional): the lower limit on the length of the summary. Defaults to 10.

        Returns:
            str: the summarized text
        """
        full_text_length = len(txt_to_summarize)
        # Adapt the max and min lengths for the summary if they are larger than they should be
        max_perc_init_length = round(full_text_length * 0.3)
        max_length = (
            max_perc_init_length
            if self.summarization_config.max_length > 0.5 * full_text_length
            else max(max_length, self.summarization_config.max_length)
        )
        # The min length is the smaller of the following two values:
        # - the default config min-to-max ratio, multiplied by the actual max length
        # - the default config minimum value
        min_to_max_perc = (
            self.summarization_config.min_length / self.summarization_config.max_length
        )
        min_length = min(
            round(min_to_max_perc * max_length), self.summarization_config.min_length
        )
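        # Worked example (assuming the stock bart-large-cnn generation
        # defaults, max_length=142 and min_length=56): for a 1000-character
        # text, 142 is not greater than 500, so max_length = max(30, 142) = 142
        # and min_length = min(round(56/142 * 142), 56) = 56.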
        # Tokenize the input, truncating it to the model's 1024-token window
        inputs = self.summarization_tokenizer(
            txt_to_summarize, return_tensors="pt", max_length=1024, truncation=True
        ).to(self.summarization_device)
        # Generate the summary with the adapted max_length/min_length
        summary_ids = self.summarizer.generate(
            inputs["input_ids"],
            max_length=max_length,
            min_length=min_length,
            num_beams=4,  # Optional: use beam search
            early_stopping=True,  # Optional: stop early once EOS is reached
        )
        # Decode the summary
        summary_txt = self.summarization_tokenizer.decode(
            summary_ids[0], skip_special_tokens=True
        )
        return summary_txt
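
    # Example (the text below is hypothetical):
    #     TaskManager().summarize("A long article body ...", max_length=60, min_length=20)
    # returns a single plain summary string, decoded without special tokens.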

    # @proc_timer
    def translate(self, txt_to_translate: str, src_lang: str, tgt_lang: str) -> str:
        """Translate the provided text from a source language to a target language

        Args:
            txt_to_translate (str): the text to translate
            src_lang (str): the source language of the initial text
            tgt_lang (str): the target language the initial text should be translated to

        Raises:
            RuntimeError: in case of an unsupported source language
            RuntimeError: in case of an unsupported target language
            RuntimeError: in case of a translation failure

        Returns:
            str: the translated text
        """
        # Raise an error in case of unsupported languages
        if src_lang not in self.supported_langs:
            raise RuntimeError("Unsupported source language.")
        if tgt_lang not in self.supported_langs:
            raise RuntimeError("Unsupported target language.")
        # Map the language names to NLLB language codes, then translate the text
        src_lang = LANG_LEX_2_CODE.get(src_lang, src_lang)
        tgt_lang = LANG_LEX_2_CODE.get(tgt_lang, tgt_lang)
        translated_text = self.translator(
            txt_to_translate, src_lang=src_lang, tgt_lang=tgt_lang, batch_size=10
        )[0]["translation_text"]
        # If something went wrong with the translation, raise an error
        if not translated_text:
            raise RuntimeError("Failed to generate translation.")
        return translated_text
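
    # Example (assuming LANGUAGES contains entries whose values are full
    # language names, e.g. "english"/"german", which LANG_LEX_2_CODE maps to
    # NLLB codes such as "eng_Latn"/"deu_Latn"):
    #     tm.translate("Good morning!", src_lang="english", tgt_lang="german")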

    def parse_and_process_feed(
        self,
        rss_url: HttpUrl,
        src_lang: str,
        tgt_lang: str,
        entries_limit: Optional[int] = None,
    ) -> List[Dict]:
        """Parse the input feed and process its entries, keeping the important
        information and summarizing and translating it

        Args:
            rss_url (HttpUrl): the feed url to parse
            src_lang (str): the feed's initial language
            tgt_lang (str): the target language to which the content will be translated
            entries_limit (int, optional): the number of feed entries to process. Defaults to None (process all).

        Returns:
            List[Dict]: a list of dictionaries, each one containing the processed info regarding
                title, author, content and link for the respective feed entry
        """
        src_lang = LANGUAGES.get(src_lang, src_lang)
        tgt_lang = LANGUAGES.get(tgt_lang, tgt_lang)
        default_lang = LANGUAGES.get("en", "en")
        feed = feedparser.parse(str(rss_url))
        # Slicing with None (or with a limit beyond the feed length) keeps all entries
        processed_entries = feed.entries[:entries_limit]
        # Iterate over each entry in the feed
        for entry in processed_entries:
            title = entry.get("title", "")
            author = entry.get("author", "")
            link = entry.get("link", "")
            content = entry.get(
                "summary", entry.get("content", entry.get("description", ""))
            )
            # Strip any HTML markup, keeping only the visible text
            soup = BeautifulSoup(content, features="html.parser")
            content = soup.get_text()
            # If the source language is not English, first translate the
            # content to English so it can be summarized
            if src_lang != default_lang:
                content = self.translate(
                    content, src_lang=src_lang, tgt_lang=default_lang
                )
            # Summarize the content
            summarized_content = self.summarize(content, max_length=30, min_length=10)
            # Translate the title and the summarized content
            translated_title = self.translate(
                title, src_lang=src_lang, tgt_lang=tgt_lang
            )
            # Unless the target language is already the default, translate the summary
            translated_content = (
                self.translate(
                    summarized_content, src_lang=default_lang, tgt_lang=tgt_lang
                )
                if tgt_lang != default_lang
                else summarized_content
            )
            # Update the entry with the processed fields
            entry.update(
                {
                    "title": translated_title,
                    "content": translated_content,
                    "author": author,
                    "link": link,
                }
            )
        return processed_entries
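

# Minimal usage sketch (the feed URL and language keys below are
# hypothetical; use values that exist in your config.LANGUAGES and
# LANG_LEX_2_CODE mappings):
if __name__ == "__main__":
    tm = TaskManager()
    entries = tm.parse_and_process_feed(
        "https://example.com/rss.xml",  # hypothetical feed URL
        src_lang="en",
        tgt_lang="de",
        entries_limit=3,
    )
    for e in entries:
        logger.info("%s -> %s", e["title"], e["content"])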