import os
import logging
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import gradio as gr
from typing import Optional
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Load model and tokenizer
model_name = "google/flan-t5-base"
logger.info(f"Loading {model_name}...")
tokenizer = T5Tokenizer.from_pretrained(model_name)
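# Note: T5Tokenizer is SentencePiece-based, so the `sentencepiece` package must be installed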
model = T5ForConditionalGeneration.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
logger.info(f"Model loaded, using device: {device}")
# FastAPI app
app = FastAPI()
# Pydantic model for request validation
class SummarizationRequest(BaseModel):
    text: str
    max_length: Optional[int] = 150
    min_length: Optional[int] = 30
# Summarization function
def summarize_text(text, max_length=150, min_length=30):
    logger.info(f"Summarizing text of length {len(text)}")
    # FLAN-T5 expects a task prefix; "summarize: " instructs the model to summarize the input
    inputs = tokenizer("summarize: " + text, return_tensors="pt", truncation=True, max_length=512).to(device)
    with torch.no_grad():  # inference only, no gradients needed
        outputs = model.generate(
            inputs.input_ids,
            max_length=max_length,
            min_length=min_length,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True
        )
    summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
    logger.info(f"Generated summary of length {len(summary)}")
    return summary
# REST API endpoint
@app.post("/summarize")
async def summarize(request: SummarizationRequest):
    try:
        summary = summarize_text(
            request.text,
            max_length=request.max_length,
            min_length=request.min_length
        )
        return {"summary": summary}
    except Exception as e:
        logger.error(f"Error in summarization: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
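#
# Example request against this endpoint (a sketch; assumes the server is running
# locally on port 7860, as configured at the bottom of this file):
#
#   curl -X POST http://localhost:7860/summarize \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Long article text to condense...", "max_length": 150, "min_length": 30}'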
# Gradio interface
def gradio_summarize(text, max_length=150, min_length=30):
    return summarize_text(text, max_length, min_length)
demo = gr.Interface(
    fn=gradio_summarize,
    inputs=[
        gr.Textbox(lines=10, placeholder="Enter text to summarize..."),
        gr.Slider(minimum=50, maximum=200, value=150, step=10, label="Maximum Length"),
        gr.Slider(minimum=10, maximum=100, value=30, step=5, label="Minimum Length")
    ],
    outputs="text",
    title="Text Summarization with FLAN-T5",
    description="This app summarizes text using Google's FLAN-T5 model."
)
# Mount the Gradio app at the root path
app = gr.mount_gradio_app(app, demo, path="/")
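# The Gradio UI is served at "/", while the JSON API remains available at "/summarize"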
# Start the server
if __name__ == "__main__":
    import uvicorn
    # Start server with both FastAPI and Gradio
    uvicorn.run(app, host="0.0.0.0", port=7860)