# Studai / app.py — FLAN-T5 question-answering + summarization service (FastAPI + Gradio).
# NOTE: model switched to google/flan-t5-large (commit 8dbb8ee).
import sentencepiece
import logging
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import gradio as gr
from typing import Optional
app = FastAPI()
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Single model serves both endpoints (QA and summarization) via task prefixes.
model_name = "google/flan-t5-large"
# Load model and tokenizer once at import time so requests share one instance.
logger.info(f"Loading model: {model_name}")
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
# Prefer GPU when available; all request tensors are moved to this device too.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
logger.info(f"Model loaded on device: {device}")
class QuestionAnswerRequest(BaseModel):
    """Request body for the /question-answer endpoint."""

    # The question to answer.
    question: str
    # The passage the answer should be extracted from.
    context: str
@app.post("/question-answer")
def answer_question(request: QuestionAnswerRequest):
    """Answer a question given a supporting context passage.

    Builds the FLAN-T5 ``question: ... context: ...`` prompt, generates an
    answer with beam search, and returns it as JSON.

    Args:
        request: Parsed JSON body with ``question`` and ``context`` fields.

    Returns:
        dict: ``{"answer": <generated answer string>}``.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        input_text = f"question: {request.question} context: {request.context}"
        # Truncate long prompts to the encoder's 512-token budget.
        inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True).to(device)
        # Inference only: disable autograd to avoid gradient bookkeeping
        # (saves memory and time; output is unchanged).
        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                max_length=64,
                num_beams=4,
                early_stopping=True,
            )
        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"answer": answer}
    except Exception as e:
        logger.error(f"QA error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
class SummarizationRequest(BaseModel):
    """Request body for the /summarize endpoint."""

    # Text to summarize.
    text: str
    # Upper bound on the generated summary length, in tokens.
    max_length: Optional[int] = 150
    # Lower bound on the generated summary length, in tokens.
    min_length: Optional[int] = 30
def summarize_text(text: str, max_length: int = 150, min_length: int = 30) -> str:
    """Summarize *text* with FLAN-T5 using beam search.

    Prepends the ``summarize:`` task prefix, truncates the input to the
    encoder's 512-token budget, and decodes the top beam.

    Args:
        text: Raw text to summarize.
        max_length: Upper bound on summary length, in tokens.
        min_length: Lower bound on summary length, in tokens.

    Returns:
        The generated summary string.
    """
    logger.info(f"Summarizing text of length {len(text)}")
    inputs = tokenizer("summarize: " + text, return_tensors="pt", truncation=True, max_length=512).to(device)
    # Inference only: disable autograd to avoid gradient bookkeeping
    # (saves memory and time; output is unchanged).
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            max_length=max_length,
            min_length=min_length,
            num_beams=6,
            # Penalize repetition: flan-t5 tends to loop on long inputs.
            repetition_penalty=2.0,
            length_penalty=1.0,
            early_stopping=True,
            no_repeat_ngram_size=3,
        )
    summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
    logger.info(f"Generated summary of length {len(summary)}")
    return summary
@app.post("/summarize")
def summarize(request: SummarizationRequest):
    """Summarize the posted text.

    Declared as a plain ``def`` (not ``async def``) on purpose: FastAPI runs
    sync endpoints in a worker threadpool, so the blocking model inference in
    ``summarize_text`` no longer stalls the event loop for every other
    concurrent request (the original ``async def`` awaited nothing and ran
    the blocking call directly on the loop).

    Args:
        request: Parsed JSON body with ``text`` and optional length bounds.

    Returns:
        dict: ``{"summary": <generated summary string>}``.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        summary = summarize_text(
            request.text,
            max_length=request.max_length,
            min_length=request.min_length,
        )
        return {"summary": summary}
    except Exception as e:
        logger.error(f"Summarization error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
# ---------- Gradio Interface ----------
def gradio_summarize(text, max_length=150, min_length=30):
    """Gradio-facing wrapper: forwards the UI inputs to summarize_text()."""
    return summarize_text(text, max_length=max_length, min_length=min_length)
# Gradio UI exposing only the summarization path (QA stays API-only).
demo = gr.Interface(
    fn=gradio_summarize,
    inputs=[
        gr.Textbox(lines=10, placeholder="Enter text to summarize..."),
        # Slider bounds mirror summarize_text's defaults (150 / 30 tokens).
        gr.Slider(minimum=50, maximum=200, value=150, step=10, label="Maximum Length"),
        gr.Slider(minimum=10, maximum=100, value=30, step=5, label="Minimum Length")
    ],
    outputs="text",
    title="Text Summarization with FLAN-T5",
    description="This app summarizes text using Google's FLAN-T5 model."
)
# Mount the Gradio app at the root path; the FastAPI JSON routes
# (/summarize, /question-answer) remain reachable alongside it.
app = gr.mount_gradio_app(app, demo, path="/")
# ---------- Entry Point ----------
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when running this file directly.
    import uvicorn
    # 0.0.0.0:7860 — the conventional Hugging Face Spaces host/port.
    uvicorn.run(app, host="0.0.0.0", port=7860)