# t5-small_cnn / main.py
# Author: Curative · commit 222eba8 "Update main.py" (verified) · 2.02 kB
import gradio as gr
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import T5ForConditionalGeneration, T5Tokenizer
import torch
import threading
import uvicorn
# 1. Load the fine-tuned T5 checkpoint and its tokenizer from the local
#    directory, then move the model to the GPU when one is available
#    (CPU otherwise).
model_path = "./t5-summarizer"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = T5Tokenizer.from_pretrained(model_path, legacy=False)
# nn.Module.to() moves parameters in place and returns the module itself.
model = T5ForConditionalGeneration.from_pretrained(model_path).to(device)
# 2. FastAPI setup: the ASGI application, plus the JSON body schema that
#    POST /summarize/ accepts.
app = FastAPI()


class TextInput(BaseModel):
    # Request body: {"text": "<raw text to summarize>"}.
    # (Comments rather than a docstring, so the generated OpenAPI schema
    # is unchanged.)
    text: str
@app.post("/summarize/")
def summarize_text(input: TextInput):
    """Summarize ``input.text`` with the fine-tuned T5 model.

    The text is prefixed with T5's ``"summarize: "`` task prompt and truncated
    to 512 tokens. It is then decoded with 4-beam search under the
    ``min_length=30`` / ``max_length=150`` generation limits.

    Args:
        input: Request body carrying the text to summarize. The parameter
            name shadows the ``input`` builtin but is kept so existing
            callers keep working.

    Returns:
        ``{"summary": <str>}``. The summary is ``""`` when the text is empty
        or whitespace-only.
    """
    # Collapse every whitespace run (newlines of any style, tabs, repeated
    # spaces) into a single space, so the prompt is one clean line.
    text = " ".join(input.text.split())
    if not text:
        # Nothing to summarize. Running generate() here would force the
        # model to emit at least `min_length` tokens from the bare prefix.
        return {"summary": ""}

    inputs = tokenizer(
        "summarize: " + text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
    ).to(device)
    summary_ids = model.generate(
        **inputs,  # forwards attention_mask together with input_ids
        max_length=150,
        min_length=30,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
    )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return {"summary": summary}
def run_fastapi():
    """Serve the FastAPI ``app`` with uvicorn on every interface, port 8000.

    Blocks until the server exits. This module starts it on a daemon thread
    so the Gradio UI can own the foreground.
    """
    bind_host, bind_port = "0.0.0.0", 8000
    uvicorn.run(app, host=bind_host, port=bind_port)
# 3. Gradio UI
def summarize_ui(text):
    """Gradio callback: summarize *text* through the same code path as the
    HTTP endpoint, and return only the summary string."""
    response = summarize_text(TextInput(text=text))
    return response["summary"]
# Single-textbox web UI that wraps `summarize_ui`, with a few prefilled
# example inputs.
iface = gr.Interface(
    fn=summarize_ui,
    inputs=gr.Textbox(lines=10, placeholder="Paste your text here..."),
    outputs=gr.Textbox(label="Summary"),
    title="Text Summarizer",
    description="Fine-tuned T5 summarizer on CNN/DailyMail v3.0.0",
    examples=[
        ["Scientists have recently discovered a new species of frog in the Amazon rainforest..."],
        ["The global economy is expected to grow at a slower pace this year..."],
        ["In a thrilling final match, the underdog team scored a last-minute goal..."],
    ],
    # Disable flagging (hides the "Flag" button).
    # NOTE(review): newer Gradio releases renamed this parameter to
    # `flagging_mode`; confirm against the installed Gradio version.
    allow_flagging="never",
)
# 4. Run both servers: the FastAPI API on a daemon background thread (it
#    dies with the process), then the Gradio UI, which blocks in the
#    foreground.
_api_thread = threading.Thread(target=run_fastapi, daemon=True)
_api_thread.start()
iface.launch(server_name="0.0.0.0", server_port=7860)