# Hugging Face Space app (page status at capture time: Running)
import os
import re
from collections import Counter

import fitz
import gradio as gr
import matplotlib.pyplot as plt
import torch
from langdetect import detect
from matplotlib.figure import Figure
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Model repo id handed to `from_pretrained` when loading tokenizer and model.
MODEL_PATH = "hideosnes/Bart-T2T-Distill_GildaBot"

# UI theme: Gradio's built-in "Ocean" theme with sky/indigo accent hues.
_THEME_HUES = {
    "primary_hue": "sky",
    "secondary_hue": "indigo",
    "neutral_hue": "neutral",
}
theme = gr.themes.Ocean(**_THEME_HUES)
def load_model():
    """Load the tokenizer and the seq2seq model stored at MODEL_PATH.

    Weights are requested in float16 only when a CUDA device is available;
    on CPU-only hosts they stay in float32.

    Returns:
        A ``(tokenizer, model)`` pair.
    """
    weight_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    tok = AutoTokenizer.from_pretrained(MODEL_PATH)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH, torch_dtype=weight_dtype)
    return tok, seq2seq


# Loaded once at import time; every request below reuses these objects.
tokenizer, model = load_model()
# Generation limits (max_length, min_length) in tokens for each "Length" choice.
# "Middle" is accepted as an alias of "Medium" because that is the label the UI
# dropdown has used; any other value (including None) gets the "Long" limits.
_LENGTH_LIMITS = {
    "Short": (100, 85),
    "Medium": (200, 185),
    "Middle": (200, 185),
    "Long": (300, 285),
}
_DEFAULT_LENGTH_LIMITS = (300, 285)

# Instruction prefixes keyed by detected language code, then by style.
# Only English prompts exist, so every detected language falls back to "en".
# Any language added here must define all styles, including _DEFAULT_STYLE.
_PROMPT_MAP = {
    "en": {
        "Precise": "Summarize concisely in a dry, scientific and academic style:",
        "Sloppy": "Summarize comprehensively for a child:",
        "Keywords": "Write only a list of important keywords of main ideas separated by a comma:",
    },
}
_DEFAULT_STYLE = "Precise"

_EMPTY_INPUT_MESSAGE = "Maybe try uploading a file or typing some proper text?"


def _read_uploaded_file(file):
    """Extract the text of an uploaded ``.pdf`` or ``.txt`` file.

    ``file`` is either a path string or an object exposing its path as
    ``.name`` (what Gradio passes differs by version). The extension check
    is case-insensitive.

    Returns:
        ``(text, error)`` where ``error`` is a user-facing message if the file
        could not be read, else ``None``. Other extensions yield ``("", None)``.
    """
    path = getattr(file, "name", file)
    extension = os.path.splitext(str(path))[1].lower()
    if extension == ".pdf":
        try:
            with fitz.open(path) as doc:
                return " ".join(page.get_text() for page in doc), None
        except Exception:
            # Any failure to open or parse the PDF is reported to the user
            # rather than summarizing an error string.
            return "", "Error: Could not read PDF file"
    if extension == ".txt":
        try:
            with open(path, "r", encoding="utf-8") as f:
                return f.read(), None
        except (OSError, UnicodeDecodeError):
            return "", "Error: Could not read text file"
    return "", None


def summarize(file, text, style, length):
    """Summarize an uploaded file or pasted text with the seq2seq model.

    Args:
        file: Value of the ``gr.File`` input (path string or object with a
            ``.name`` path) or ``None``. Only ``.pdf`` and ``.txt`` are read.
        text: Pasted text; used when no file is given or the file yields no text.
        style: "Precise", "Sloppy" or "Keywords"; anything else uses "Precise".
        length: "Short", "Medium"/"Middle" or "Long"; anything else uses the
            "Long" limits.

    Returns:
        A 6-tuple matching the Gradio outputs: (summary, comma-joined keywords,
        original word count, summary word count, reduction as "xx.xx%",
        matplotlib Figure). On empty input or an unreadable file the first
        element is a message and the remaining five are ``None``.
    """
    text_input = ""
    if file is not None:
        text_input, error = _read_uploaded_file(file)
        if error:
            return error, None, None, None, None, None
    # Fall back to pasted text when no file was given or it contained no text.
    if not text_input.strip() and text:
        text_input = text

    # Gradio expects one value per output component on every call, so the
    # early exit fills the non-summary outputs with None.
    if not text_input.strip():
        return _EMPTY_INPUT_MESSAGE, None, None, None, None, None

    # Language detection; fall back to English if detect() fails
    # (e.g. on text with no detectable features).
    try:
        lang_code = detect(text_input)
    except Exception:
        lang_code = "en"

    max_token, min_token = _LENGTH_LIMITS.get(length, _DEFAULT_LENGTH_LIMITS)

    prompts = _PROMPT_MAP.get(lang_code, _PROMPT_MAP["en"])
    prompt_text = prompts.get(style, prompts[_DEFAULT_STYLE])
    prompt = prompt_text + " " + text_input

    # Input is truncated to 1024 tokens, roughly 750-800 English words
    # depending on tokenizer and language (the UI shows this to the user).
    # return_tensors takes 'pt', 'tf', 'np', 'jax' or 'mlx' ('pt' = PyTorch).
    inputs = tokenizer.encode(prompt, return_tensors="pt", truncation=True, max_length=1024)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)  # effectively a no-op once the model already lives on `device`
    inputs = inputs.to(device)

    # generate() returns token ids, not text.
    summary_ids = model.generate(
        inputs,
        max_length=max_token,
        min_length=min_token,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
        no_repeat_ngram_size=3,
    )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    # Strip any echo of the instruction prefix or of the input text.
    if summary.startswith(prompt_text):
        summary = summary[len(prompt_text):].strip()
    if text_input in summary:
        summary = summary.replace(text_input, "").strip()

    # Word counts and relative length reduction. original_len > 0 here because
    # text_input is known to be non-blank.
    original_len = len(text_input.split())
    summary_len = len(summary.split())
    reduction = 100 - (summary_len / original_len * 100)

    # Keywords: the 5 most frequent summary words longer than 5 characters.
    # Filter before counting so short frequent words ("the", "and") cannot
    # push longer words out of the top 5.
    long_words = [w for w in re.findall(r"\w+", summary.lower()) if len(w) > 5]
    keywords = [kw for kw, _ in Counter(long_words).most_common(5)]

    # Object-oriented Figure instead of pyplot: pyplot keeps every figure in
    # global state (one leaked figure per request) and is not thread-safe.
    fig = Figure()
    ax = fig.subplots()
    ax.bar(
        ["Original", "Summary"],
        [original_len, summary_len],
        color=["indigo", "skyblue"],  # "sky" is not a valid matplotlib color name
    )
    ax.set_ylabel("Word Count")
    return summary, ", ".join(keywords), original_len, summary_len, f"{reduction:.2f}%", fig
# Two-column layout: inputs/options on the left, results on the right.
# Output components are listed in the same order summarize() returns them.
with gr.Blocks(theme=theme) as demo:
    gr.Markdown("Summarizer (T2T)")
    with gr.Row():
        # Left column: inputs and generation options.
        with gr.Column():
            # summarize() reads .pdf and .txt uploads, so restrict the picker to those.
            file = gr.File(label="Upload a PDF or TXT file", file_types=[".pdf", ".txt"])
            text = gr.Textbox(label="Paste text from clipboard.", lines=10)
            with gr.Row():  # inline horizontal layout for the two dropdowns
                # Defaults avoid passing None to summarize(), which would
                # fail the style lookup and silently pick the longest length.
                style = gr.Dropdown(
                    ["Precise", "Sloppy", "Keywords"],
                    value="Precise",
                    label="Style (experimental)",
                )
                # "Medium" is the key summarize() maps to the 200-token limits.
                length = gr.Dropdown(
                    ["Short", "Medium", "Long"],
                    value="Medium",
                    label="Length",
                )
            # Informational only; not wired to any event.
            token_info = gr.Text(
                label="Speed",
                value="750-800 words (1024 tokens) will take 4 minutes (~4t/s).",
                interactive=False,
            )
            btn = gr.Button("Transform", variant="primary")
        # Right column: results.
        with gr.Column():
            summary = gr.Textbox(label="Summary")
            keywords = gr.Textbox(label="Keywords (experimental)")
            original_len = gr.Number(label="Original Text Length")
            summary_len = gr.Number(label="Summary Length")
            reduction = gr.Textbox(label="Summary Efficiency")
            plot = gr.Plot(label="Summary Statistics")
    btn.click(
        summarize,
        inputs=[file, text, style, length],
        outputs=[summary, keywords, original_len, summary_len, reduction, plot],
    )
demo.launch(share=True)