# Studai / app.py
import logging
from contextlib import asynccontextmanager
from typing import Optional

import sentencepiece  # noqa: F401  (needed by T5Tokenizer; import early to fail fast if missing)
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
model_name = "google/flan-t5-large"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The model and tokenizer are loaded exactly once, at startup, by the lifespan
# handler below; loading them again at import time would duplicate the work.
tokenizer: Optional[T5Tokenizer] = None
model: Optional[T5ForConditionalGeneration] = None

class QuestionAnswerRequest(BaseModel):
    question: str
    context: str

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model at startup; release cached GPU memory at shutdown."""
    global model, tokenizer
    try:
        logger.info(f"Loading model: {model_name}")
        tokenizer = T5Tokenizer.from_pretrained(model_name)
        model = T5ForConditionalGeneration.from_pretrained(model_name)
        model.to(device)
        logger.info(f"Model loaded on device: {device}")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise
    yield
    # Shutdown
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

app = FastAPI(lifespan=lifespan)

@app.post("/question-answer")
async def answer_question(request: QuestionAnswerRequest):
try:
input_text = f"question: {request.question} context: {request.context}"
inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True).to(device)
outputs = model.generate(
inputs.input_ids,
max_length=64,
num_beams=4,
early_stopping=True
)
answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
return {"answer": answer}
except Exception as e:
logger.error(f"QA error: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
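# A minimal client sketch for this endpoint (hypothetical usage; assumes the
# server runs on port 7860 as in the entry point below, and that the
# `requests` package is installed):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/question-answer",
#       json={"question": "Where is the Eiffel Tower?",
#             "context": "The Eiffel Tower is in Paris."},
#   )
#   print(resp.json())  # e.g. {"answer": "Paris"}; exact text depends on the model
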
class SummarizationRequest(BaseModel):
    text: str
    max_length: Optional[int] = 150
    min_length: Optional[int] = 30

def summarize_text(text, max_length=150, min_length=30):
    logger.info(f"Summarizing text of length {len(text)}")
    # FLAN-T5 uses the "summarize:" task prefix
    inputs = tokenizer("summarize: " + text, return_tensors="pt", truncation=True, max_length=512).to(device)
    outputs = model.generate(
        inputs.input_ids,
        max_length=max_length,
        min_length=min_length,
        num_beams=6,
        repetition_penalty=2.0,
        length_penalty=1.0,
        early_stopping=True,
        no_repeat_ngram_size=3
    )
    summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
    logger.info(f"Generated summary of length {len(summary)}")
    return summary

@app.post("/summarize")
async def summarize(request: SummarizationRequest):
try:
summary = summarize_text(
request.text,
max_length=request.max_length,
min_length=request.min_length
)
return {"summary": summary}
except Exception as e:
logger.error(f"Summarization error: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
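# A minimal client sketch for the summarizer (hypothetical usage; same
# assumptions as the /question-answer example above):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/summarize",
#       json={"text": "<long article text>", "max_length": 100, "min_length": 20},
#   )
#   print(resp.json())  # {"summary": "..."}
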
# ---------- Entry Point ----------
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
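
# Equivalently, the app can be served with the uvicorn CLI, with no code changes
# (assuming this file is named app.py):
#   uvicorn app:app --host 0.0.0.0 --port 7860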