Spaces:
Running
Running
File size: 6,587 Bytes
3a361ef e2e5815 3a361ef 2224978 8ef2c76 2224978 e2e5815 9192e24 3a361ef e2e5815 2224978 3a361ef d3959f2 3a361ef 3d68164 3a361ef 2224978 cf271ff 3a361ef d3959f2 3a361ef 8f28339 3a361ef 258987a 3a361ef e2e5815 3a361ef 3d68164 f976b0a 3d68164 f976b0a d3959f2 2ab6e28 6587e7e 3d68164 258987a 3d68164 3a361ef e2e5815 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
import fitz
from langdetect import detect
import matplotlib.pyplot as plt
from collections import Counter
import re
import os
# App-wide Gradio theme: the built-in "Ocean" theme with its hues overridden
# by Gradio palette names.
theme = gr.themes.Ocean(
    neutral_hue="neutral",
    secondary_hue="indigo",
    primary_hue="sky",
)
# Checkpoint identifier handed to ``from_pretrained`` (a Hub repo id or local path).
MODEL_PATH = "hideosnes/Bart-T2T-Distill_GildaBot"


def load_model():
    """Load the tokenizer and seq2seq model identified by ``MODEL_PATH``.

    Model weights are requested in float16 when CUDA is available and in
    float32 otherwise.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    if torch.cuda.is_available():
        weight_dtype = torch.float16
    else:
        weight_dtype = torch.float32
    loaded_model = AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=weight_dtype,
    )
    return loaded_tokenizer, loaded_model


# Loaded once at import time; summarize() reads these module-level objects.
tokenizer, model = load_model()
# Instruction prefixes, keyed by detected language code and then by style.
# Only English is defined; any other detected language falls back to "en".
# To add a language, add another entry such as
# "de": {"Precise": "...", "Sloppy": "...", "Keywords": "..."}.
_PROMPT_MAP = {
    "en": {
        "Precise": "Summarize concisely in a dry, scientific and academic style:",
        "Sloppy": "Summarize comprehensively for a child:",
        "Keywords": "Write only a list of important keywords of main ideas separated by a comma:",
    },
}

# (max_length, min_length) of the generated summary in tokens, keyed by the
# "Length" choice. "Middle" is an alias of "Medium" so either label works;
# any other value (including "Long" and None) uses _LONG_BOUNDS.
_LENGTH_BOUNDS = {
    "Short": (100, 85),
    "Medium": (200, 185),
    "Middle": (200, 185),
}
_LONG_BOUNDS = (300, 285)

_EMPTY_INPUT_MESSAGE = "Maybe try uploading a file or typing some proper text?"


def _read_input(file, text):
    """Return ``(text_input, error_message)`` from the uploaded file or pasted text.

    ``file`` may be a path string or an object exposing the path as ``.name``.
    ``.pdf`` files are read page by page with PyMuPDF (``fitz``); ``.txt`` files
    are read as UTF-8 with undecodable bytes replaced. The extension check is
    case-insensitive. When the file is unsupported or yields only whitespace,
    the pasted ``text`` is used instead.

    ``error_message`` is a user-facing string when a supported file could not be
    read, otherwise ``None``.
    """
    if file is not None:
        path = getattr(file, "name", file)
        suffix = os.path.splitext(str(path))[1].lower()
        if suffix == ".pdf":
            try:
                with fitz.open(path) as doc:
                    content = " ".join(page.get_text() for page in doc)
            except Exception:  # any failure to open/parse the PDF is reported, not summarized
                return "", "Error: Could not read PDF file"
            if content.strip():
                return content, None
        elif suffix == ".txt":
            try:
                with open(path, "r", encoding="utf-8", errors="replace") as f:
                    content = f.read()
            except OSError:
                return "", "Error: Could not read text file"
            if content.strip():
                return content, None
    return text or "", None


def _extract_keywords(summary, top_n=5, min_chars=6):
    """Return up to ``top_n`` most frequent words with at least ``min_chars`` characters."""
    # Filter BEFORE ranking so short filler words don't use up the top_n slots.
    words = [w for w in re.findall(r"\w+", summary.lower()) if len(w) >= min_chars]
    return [word for word, _ in Counter(words).most_common(top_n)]


def _plot_lengths(original_len, summary_len):
    """Return a Matplotlib bar chart comparing original and summary word counts."""
    # A standalone Figure (not plt.subplots) is not registered with pyplot,
    # so figures do not accumulate in pyplot's global state across requests.
    from matplotlib.figure import Figure

    fig = Figure()
    ax = fig.subplots()
    # "skyblue" is a valid Matplotlib named color; plain "sky" is not and raises ValueError.
    ax.bar(["Original", "Summary"], [original_len, summary_len], color=["indigo", "skyblue"])
    ax.set_ylabel("Word Count")
    return fig


def summarize(file, text, style, length):
    """Summarize an uploaded PDF/.txt file or pasted text with the seq2seq model.

    Args:
        file: Uploaded file (path string or object with ``.name``) or ``None``.
        text: Pasted text; used when no usable file text is available.
        style: Prompt style key ("Precise", "Sloppy", "Keywords"); unset or
            unknown values fall back to "Precise".
        length: "Short", "Medium"/"Middle", or anything else for the long setting.

    Returns:
        A 6-tuple ``(summary, keywords, original_word_count, summary_word_count,
        reduction_percent_str, figure)``. On empty input or a file read error,
        the first item is a user message and the other five are ``None`` —
        Gradio expects the same number of outputs on every path.
    """
    text_input, error = _read_input(file, text)
    if error is not None:
        return error, None, None, None, None, None
    if not text_input.strip():
        return _EMPTY_INPUT_MESSAGE, None, None, None, None, None

    try:
        lang_code = detect(text_input)
    except Exception:  # fall back to English whenever detection fails
        lang_code = "en"

    max_token, min_token = _LENGTH_BOUNDS.get(length, _LONG_BOUNDS)

    prompts = _PROMPT_MAP.get(lang_code, _PROMPT_MAP["en"])
    instruction = prompts.get(style, prompts["Precise"])
    prompt = instruction + " " + text_input

    # Input is truncated to 1024 tokens (typically ~750-800 English words,
    # depending on tokenizer and language).
    inputs = tokenizer.encode(prompt, return_tensors="pt", truncation=True, max_length=1024)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    inputs = inputs.to(device)
    # generate() returns token IDs; they are decoded to text below.
    summary_ids = model.generate(
        inputs,
        max_length=max_token,
        min_length=min_token,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
        no_repeat_ngram_size=3,
    )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    # Strip an echoed instruction prefix and any verbatim copy of the input.
    if summary.startswith(instruction):
        summary = summary[len(instruction):].strip()
    if text_input in summary:
        summary = summary.replace(text_input, "").strip()

    original_len = len(text_input.split())
    summary_len = len(summary.split())
    # original_len >= 1 here (blank input returned early); guard anyway.
    reduction = 100 - (summary_len / original_len * 100) if original_len else 0.0

    keywords = _extract_keywords(summary)
    fig = _plot_lengths(original_len, summary_len)
    return summary, ", ".join(keywords), original_len, summary_len, f"{reduction:.2f}%", fig
with gr.Blocks(theme=theme) as demo:
    gr.Markdown("Summarizer (T2T)")
    with gr.Row():
        # Left column: inputs and controls.
        with gr.Column():
            # summarize() reads .pdf and .txt uploads; restrict the picker accordingly.
            file = gr.File(label="Upload a PDF or .txt file", file_types=[".pdf", ".txt"])
            text = gr.Textbox(label="Paste text from clipboard.", lines=10)
            with gr.Row():  # inline horizontal layout for the two dropdowns
                # Explicit defaults so summarize() always receives a known style/length.
                style = gr.Dropdown(
                    ["Precise", "Sloppy", "Keywords"],
                    value="Precise",
                    label="Style (experimental)",
                )
                # "Medium" is the label summarize() maps to the medium length bounds.
                length = gr.Dropdown(["Short", "Medium", "Long"], value="Short", label="Length")
            token_info = gr.Text(
                label="Speed",
                value="750-800 words (1024 tokens) will take 4 minutes (~4t/s).",
                interactive=False,
            )
            btn = gr.Button("Transform", variant="primary")
        # Right column: outputs, in the order summarize() returns them.
        with gr.Column():
            summary = gr.Textbox(label="Summary")
            keywords = gr.Textbox(label="Keywords (experimental)")
            original_len = gr.Number(label="Original Text Length")
            summary_len = gr.Number(label="Summary Length")
            reduction = gr.Textbox(label="Summary Efficiency")
            plot = gr.Plot(label="Summary Statistics")
    # Event listeners must be attached inside the Blocks context.
    btn.click(
        summarize,
        inputs=[file, text, style, length],
        outputs=[summary, keywords, original_len, summary_len, reduction, plot],
    )
demo.launch(share=True)
|