# Hugging Face Space (status: Running)
import os
import re
from collections import Counter

import fitz
import gradio as gr
import matplotlib.pyplot as plt
import torch
from langdetect import detect
from langdetect.lang_detect_exception import LangDetectException
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Model identifier passed to ``from_pretrained`` for both the tokenizer and
# the seq2seq model.
MODEL_PATH = "hideosnes/Bart-T2T-Distill_GildaBot"
def load_model():
    """Load the tokenizer and the seq2seq model referenced by ``MODEL_PATH``.

    The weights are requested in float16 when CUDA is available and in
    float32 otherwise.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH, torch_dtype=dtype)
    return tokenizer, model
# Load once at import time; ``summarize`` reads these module-level objects on
# every request.
tokenizer, model = load_model()
# Legacy Streamlit header, kept for reference only:
# st.title("PDF Summarizer")
# st.markdown("CPU-optimized model for text-to-text transformation (T2T), facilitating efficient and accurate language processing. Multilingual input, but the target language is English. Please be gentle, it runs on CPU!")
# (max_length, min_length) generation budgets per length choice. "Middle" is
# kept as an alias of "Medium" so either label selects the medium budget.
_LENGTH_BUDGETS = {
    "Short": (100, 85),
    "Medium": (200, 185),
    "Middle": (200, 185),
    "Long": (300, 285),
}

# Instruction prefixes keyed by detected language code, then by style.
# Further languages can be added next to "en" with the same style keys.
_PROMPT_MAP = {
    "en": {
        "Precise": "In English, distill the following text into a concise summary, utilizing formal and academic language to convey the essential information:",
        "Sloppy": "In English, provide a brief and informal summary of the following text, using straightforward language to facilitate easy comprehension:",
        "Keywords": "In English, condense the following text into a list of keywords, highlighting key points and main ideas in a clear and objective manner:",
    },
}


def _read_uploaded_text(file):
    """Return the text content of an uploaded PDF or plain-text file."""
    # The upload may be a path string or an object exposing its path via
    # ``.name`` (which of the two presumably depends on the Gradio version);
    # both are resolved to a path and read from disk.
    path = file if isinstance(file, str) else getattr(file, "name", None)
    if not isinstance(path, str):
        # No usable path: fall back to whatever the object itself provides.
        raw = file.read() if hasattr(file, "read") else str(file)
        if isinstance(raw, bytes):
            return raw.decode("utf-8", errors="replace")
        return str(raw)
    if path.lower().endswith(".pdf"):
        with fitz.open(path) as doc:
            return " ".join(page.get_text() for page in doc)
    with open(path, encoding="utf-8", errors="replace") as handle:
        return handle.read()


def summarize(file, text, style, length):
    """Summarize an uploaded file or pasted text and report basic statistics.

    Args:
        file: Uploaded PDF/TXT file (path string or object with ``.name``),
            or ``None``. Takes priority over ``text`` when given.
        text: Text pasted by the user; used only when ``file`` is ``None``.
        style: One of "Precise", "Sloppy" or "Keywords"; unknown or missing
            values fall back to "Precise".
        length: "Short", "Medium" (alias "Middle") or "Long"; anything else
            is treated as "Long".

    Returns:
        A 6-tuple ``(summary, keywords, original_len, summary_len,
        reduction, fig)``. When there is no usable input, the first element
        is a hint for the user and the remaining five are ``None`` so the
        number of outputs stays constant.
    """
    text_input = _read_uploaded_text(file) if file is not None else (text or "")

    # Always return six values so every bound output component gets one.
    if not text_input.strip():
        return "Maybe try uploading a file or typing some text?", None, None, None, None, None

    # Language detection is best-effort; fall back to English when the
    # detector cannot classify the text.
    try:
        lang_code = detect(text_input)
    except LangDetectException:
        lang_code = "en"

    max_token, min_token = _LENGTH_BUDGETS.get(length, _LENGTH_BUDGETS["Long"])

    # Instruction prefix chosen by language and style, with English/"Precise"
    # as the last-resort fallback.
    english_prompts = _PROMPT_MAP["en"]
    prompts = _PROMPT_MAP.get(lang_code, english_prompts)
    instruction = prompts.get(style) or english_prompts.get(style, english_prompts["Precise"])
    prompt = instruction + " " + text_input

    # Input is truncated to 1024 tokens, which typically corresponds to about
    # 750-800 English words depending on the tokenizer and language.
    inputs = tokenizer.encode(prompt, return_tensors="pt", truncation=True, max_length=1024)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    inputs = inputs.to(device)

    # generate() returns token IDs, not text; decode the first sequence.
    summary_ids = model.generate(
        inputs,
        max_length=max_token,
        min_length=min_token,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
        no_repeat_ngram_size=3,
    )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    # Word counts and percentage reduction. original_len is at least 1 here
    # because blank input has already been rejected above.
    original_len = len(text_input.split())
    summary_len = len(summary.split())
    reduction = 100 - (summary_len / original_len * 100)

    # Keywords: the 5 most frequent summary words longer than 3 characters.
    # Short words are dropped before ranking so they cannot use up the slots.
    words = re.findall(r"\w+", summary.lower())
    frequent = Counter(word for word in words if len(word) > 3).most_common(5)
    keywords = [word for word, _ in frequent]

    # Bar chart comparing the two word counts.
    fig, ax = plt.subplots()
    ax.bar(
        ["Original", "Summary"],
        [original_len, summary_len],
        color=["coral", "purple"],
    )
    ax.set_ylabel("Word Count")
    # Detach the figure from pyplot's registry so figures do not pile up
    # across requests; the Figure object itself is still returned.
    plt.close(fig)

    return summary, ", ".join(keywords), original_len, summary_len, f"{reduction:.2f}%", fig
# Gradio UI: inputs, trigger button, then one output component per value
# returned by ``summarize`` (in the same order).
with gr.Blocks() as app:
    gr.Markdown("Summarizer (T2T)")
    file = gr.File(label="Upload a PDF or a TXT file")
    text = gr.Textbox(label="Paste text from clipboard.", lines=10)
    # A default style is preselected so ``summarize`` never receives None.
    style = gr.Dropdown(["Precise", "Sloppy", "Keywords"], value="Precise", label="Style")
    # Labels must match the values ``summarize`` checks ("Medium", not
    # "Middle"). "Long" is preselected because that is what an empty
    # selection already resolved to.
    length = gr.Radio(["Short", "Medium", "Long"], value="Long", label="Length")
    btn = gr.Button("Transform")
    summary = gr.Textbox(label="Summary")
    keywords = gr.Textbox(label="Important Keywords")
    original_len = gr.Number(label="Original Text Length")
    summary_len = gr.Number(label="Summary Length")
    reduction = gr.Textbox(label="Summary Efficiency")
    plot = gr.Plot(label="Summary Statistics")
    btn.click(
        summarize,
        inputs=[file, text, style, length],
        outputs=[summary, keywords, original_len, summary_len, reduction, plot],
    )
app.launch()