import gradio as gr
from transformers import pipeline, AutoTokenizer
import torch
# —— Lazy-loaded pipelines & tokenizers —— #
summarizer = sentiment = ner = classifier = None
ner_tokenizer = None
def get_summarizer():
    global summarizer
    if summarizer is None:
        summarizer = pipeline(
            "summarization",
            model="Curative/t5-summarizer-cnn",
            framework="pt"
        )
    return summarizer


def get_sentiment():
    global sentiment
    if sentiment is None:
        sentiment = pipeline(
            "sentiment-analysis",
            model="distilbert-base-uncased-finetuned-sst-2-english",
            framework="pt"
        )
    return sentiment


def get_classifier():
    global classifier
    if classifier is None:
        classifier = pipeline(
            "zero-shot-classification",
            model="facebook/bart-large-mnli",
            framework="pt"
        )
    return classifier


def get_ner():
    global ner, ner_tokenizer
    if ner is None:
        # Load the fast tokenizer explicitly for proper aggregation
        ner_tokenizer = AutoTokenizer.from_pretrained(
            "elastic/distilbert-base-uncased-finetuned-conll03-english",
            use_fast=True
        )
        ner = pipeline(
            "ner",
            model="elastic/distilbert-base-uncased-finetuned-conll03-english",
            tokenizer=ner_tokenizer,
            aggregation_strategy="simple",
            framework="pt"
        )
    return ner

# —— Helper functions —— #
def chunk_and_summarize(text: str) -> str:
    """Split on sentences into ≤1,000-char chunks, summarize each, then join."""
    summarizer = get_summarizer()
    max_chunk = 1000
    sentences = text.split(". ")
    chunks, current = [], ""
    # Greedily pack whole sentences into chunks of at most max_chunk characters
    for sent in sentences:
        # +2 accounts for the period and space re-added below
        if len(current) + len(sent) + 2 <= max_chunk:
            current += sent + ". "
        else:
            chunks.append(current.strip())
            current = sent + ". "
    if current:
        chunks.append(current.strip())
    summaries = []
    for chunk in chunks:
        part = summarizer(
            chunk,
            max_length=150,
            min_length=40,
            do_sample=False
        )[0]["summary_text"]
        summaries.append(part)
    return " ".join(summaries)
def merge_entities(ents):
    """Merge sub-word tokens (##…) into full words."""
    merged = []
    for e in ents:
        w, t = e["word"], e["entity_group"]
        # A "##" prefix marks a WordPiece continuation of the previous token
        if w.startswith("##") and merged:
            merged[-1]["word"] += w.replace("##", "")
        else:
            merged.append({"word": w, "type": t})
    return merged
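
# Illustrative example: aggregation_strategy="simple" already groups most
# sub-words, so this is a safety net for stray "##" fragments, e.g.
#   [{"word": "Lon", "entity_group": "LOC"}, {"word": "##don", "entity_group": "LOC"}]
#   -> [{"word": "London", "type": "LOC"}]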
def process(text, features):
    out = {}
    if "Summarization" in features:
        out["summary"] = chunk_and_summarize(text)
    if "Sentiment" in features:
        s = get_sentiment()(text)[0]
        out["sentiment"] = {"label": s["label"], "score": s["score"]}
    if "Classification" in features:
        labels = ["technology", "sports", "business", "politics",
                  "health", "science", "travel", "entertainment"]
        cls = get_classifier()(text, candidate_labels=labels)
        # Pair each label with its score and sort by confidence, descending
        pairs = sorted(
            zip(cls["labels"], cls["scores"]),
            key=lambda x: x[1],
            reverse=True
        )
        out["classification"] = [
            {"label": lbl, "score": scr} for lbl, scr in pairs
        ]
    if "Entities" in features:
        ents = get_ner()(text)
        out["entities"] = merge_entities(ents)
    return out
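
# Illustrative result shape (keys appear only for the selected features):
#   {
#       "summary": "...",
#       "sentiment": {"label": "POSITIVE", "score": 0.99},
#       "classification": [{"label": "technology", "score": 0.57}, ...],
#       "entities": [{"word": "London", "type": "LOC"}, ...]
#   }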
# —— Gradio UI —— #
with gr.Blocks() as demo:
    gr.Markdown("## 🛠️ Multi-Feature NLP Service")
    inp = gr.Textbox(lines=8, placeholder="Enter your text here…")
    feats = gr.CheckboxGroup(
        ["Summarization", "Sentiment", "Classification", "Entities"],
        label="Select features to run"
    )
    btn = gr.Button("Run")
    out = gr.JSON(label="Results")
    btn.click(process, [inp, feats], out)

demo.queue(api_open=True).launch()
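
# A minimal client-side sketch (assumptions: the app is running locally and
# gradio_client is installed separately via `pip install gradio_client`):
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860/")
#   result = client.predict(
#       "Your text here…",               # inp
#       ["Summarization", "Sentiment"],  # feats
#       api_name="/process"              # verify the route on the Space's API page
#   )
#   print(result)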