# t5-small_cnn / main.py
import gradio as gr
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import T5ForConditionalGeneration, T5Tokenizer
import torch
import threading
import uvicorn
# Load model & tokenizer
model_path = "./t5-summarizer"
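# Assumes a fine-tuned checkpoint was saved locally to ./t5-summarizer
# (e.g. via model.save_pretrained / tokenizer.save_pretrained).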
tokenizer = T5Tokenizer.from_pretrained(model_path, legacy=False)
model = T5ForConditionalGeneration.from_pretrained(model_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# FastAPI setup
app = FastAPI()

# Request schema for /summarize/
class TextInput(BaseModel):
    text: str

@app.post("/summarize/")
def summarize_text(payload: TextInput):
    # Prepend the T5 task prefix and flatten newlines before tokenizing
    inputs = tokenizer("summarize: " + payload.text.replace("\n", " "),
                       return_tensors="pt", max_length=512, truncation=True).to(device)
    # Beam search with a length penalty encourages longer, complete summaries;
    # no_grad avoids tracking gradients during inference
    with torch.no_grad():
        outputs = model.generate(**inputs,
                                 max_length=150, min_length=30,
                                 length_penalty=2.0, num_beams=4, early_stopping=True)
    return {"summary": tokenizer.decode(outputs[0], skip_special_tokens=True)}

def run_fastapi():
    # Serve the REST API on port 8000 (the Gradio UI takes 7860)
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Gradio UI
iface = gr.Interface(
    fn=lambda text: summarize_text(TextInput(text=text))["summary"],
    inputs=gr.Textbox(lines=10, placeholder="Paste text here..."),
    outputs=gr.Textbox(label="Summary"),
    title="Text Summarizer",
    description="Fine-tuned T5 summarizer",
    flagging_mode="never",  # Disable flagging
    examples=[["Your example text here..."]],  # Pre-load examples
)

# Start FastAPI in a background thread, then launch Gradio in the foreground
if __name__ == "__main__":
    threading.Thread(target=run_fastapi, daemon=True).start()
    iface.launch(server_name="0.0.0.0", server_port=7860)
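
# Running this file directly (python main.py) starts both servers:
#   - REST API:  http://localhost:8000/summarize/
#   - Gradio UI: http://localhost:7860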