import os
import logging
from typing import Optional

import sentencepiece as spm  # SentencePiece provides the tokenizer backend required by T5Tokenizer
import torch
import gradio as gr
from transformers import T5Tokenizer, T5ForConditionalGeneration
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the tokenizer and model once at startup and move the model to the GPU when available.
model_name = "google/flan-t5-base"
logger.info(f"Loading {model_name}...")
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
logger.info(f"Model loaded, using device: {device}")

app = FastAPI()


class SummarizationRequest(BaseModel):
    text: str
    max_length: Optional[int] = 150
    min_length: Optional[int] = 30

def summarize_text(text, max_length=150, min_length=30):
    """Summarize `text` with FLAN-T5 using beam search."""
    logger.info(f"Summarizing text of length {len(text)}")
    # T5-style models expect a task prefix; truncate long inputs to 512 tokens.
    inputs = tokenizer("summarize: " + text, return_tensors="pt", truncation=True, max_length=512).to(device)
    outputs = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_length=max_length,
        min_length=min_length,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
    )
    summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
    logger.info(f"Generated summary of length {len(summary)}")
    return summary

@app.post("/summarize")
async def summarize(request: SummarizationRequest):
    try:
        summary = summarize_text(
            request.text,
            max_length=request.max_length,
            min_length=request.min_length,
        )
        return {"summary": summary}
    except Exception as e:
        logger.error(f"Error in summarization: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

def gradio_summarize(text, max_length=150, min_length=30):
    return summarize_text(text, max_length, min_length)


demo = gr.Interface(
    fn=gradio_summarize,
    inputs=[
        gr.Textbox(lines=10, placeholder="Enter text to summarize..."),
        gr.Slider(minimum=50, maximum=200, value=150, step=10, label="Maximum Length"),
        gr.Slider(minimum=10, maximum=100, value=30, step=5, label="Minimum Length"),
    ],
    outputs="text",
    title="Text Summarization with FLAN-T5",
    description="This app summarizes text using Google's FLAN-T5 model.",
)

# Mount the Gradio UI on the FastAPI app at the root path; the /summarize API
# route registered above is still served alongside it.
app = gr.mount_gradio_app(app, demo, path="/")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
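
# Example client call against the JSON endpoint (a sketch, assuming the server
# above is running locally on port 7860 and the `requests` package is installed;
# the payload fields mirror SummarizationRequest):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/summarize",
#       json={"text": "Long article text ...", "max_length": 120, "min_length": 40},
#   )
#   print(resp.json()["summary"])
#
# The response is a JSON object of the form {"summary": "..."}.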