File size: 2,759 Bytes
30e18b9
716037e
af53a88
30e18b9
 
 
 
 
716037e
 
30e18b9
716037e
 
30e18b9
af53a88
30e18b9
 
 
 
 
 
71754ec
30e18b9
 
af53a88
30e18b9
 
af53a88
30e18b9
 
af53a88
30e18b9
 
 
 
af53a88
30e18b9
 
 
 
 
 
fce3660
af53a88
30e18b9
 
af53a88
30e18b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71754ec
 
30e18b9
71754ec
30e18b9
 
 
71754ec
30e18b9
 
 
71754ec
716037e
30e18b9
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import logging
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import gradio as gr
from typing import Optional

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load model and tokenizer
# NOTE: this runs at import time — importing the module downloads/loads the
# model, so first startup may be slow and requires network access.
model_name = "google/flan-t5-base"
logger.info(f"Loading {model_name}...")
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
# Prefer GPU when available; all inference below uses this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
logger.info(f"Model loaded, using device: {device}")

# FastAPI app
app = FastAPI()

# Pydantic model for request validation
class SummarizationRequest(BaseModel):
    """Request body for POST /summarize."""
    text: str  # input text to summarize
    max_length: Optional[int] = 150  # upper bound (tokens) for the generated summary
    min_length: Optional[int] = 30  # lower bound (tokens) for the generated summary

# Summarization function
def summarize_text(text: str, max_length: int = 150, min_length: int = 30) -> str:
    """Summarize *text* with FLAN-T5 beam search.

    Args:
        text: Input text; prefixed with the "summarize: " task prompt and
            truncated to a 512-token input limit.
        max_length: Upper bound (in tokens) for the generated summary.
        min_length: Lower bound (in tokens) for the generated summary.

    Returns:
        The decoded summary string with special tokens stripped.
    """
    logger.info(f"Summarizing text of length {len(text)}")
    inputs = tokenizer("summarize: " + text, return_tensors="pt", truncation=True, max_length=512).to(device)
    # Inference only: no_grad skips autograd bookkeeping (saves memory/time).
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            # Pass the attention mask explicitly so truncated/padded inputs are
            # handled correctly instead of relying on the pad-token fallback.
            attention_mask=inputs.attention_mask,
            max_length=max_length,
            min_length=min_length,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True
        )
    summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
    logger.info(f"Generated summary of length {len(summary)}")
    return summary

# REST API endpoint
@app.post("/summarize")
def summarize(request: SummarizationRequest):
    """Summarize the posted text; returns {"summary": ...} or HTTP 500.

    Declared sync (not async) deliberately: FastAPI runs sync endpoints in a
    worker thread, so the blocking, CPU/GPU-heavy model.generate call does not
    stall the event loop that also serves the Gradio UI.
    """
    try:
        summary = summarize_text(
            request.text,
            max_length=request.max_length,
            min_length=request.min_length
        )
        return {"summary": summary}
    except Exception as e:
        # logger.exception records the full traceback, not just the message.
        logger.exception("Error in summarization")
        raise HTTPException(status_code=500, detail=str(e)) from e

# Gradio interface adapter
def gradio_summarize(text, max_length=150, min_length=30):
    """Adapt Gradio inputs to summarize_text.

    gr.Slider delivers its value as a float; model.generate expects integer
    length bounds, so cast before delegating.
    """
    return summarize_text(text, int(max_length), int(min_length))

# Gradio web UI: textbox for input, sliders bound to the summary length limits.
demo = gr.Interface(
    fn=gradio_summarize,
    inputs=[
        gr.Textbox(lines=10, placeholder="Enter text to summarize..."),
        gr.Slider(minimum=50, maximum=200, value=150, step=10, label="Maximum Length"),
        gr.Slider(minimum=10, maximum=100, value=30, step=5, label="Minimum Length")
    ],
    outputs="text",
    title="Text Summarization with FLAN-T5",
    description="This app summarizes text using Google's FLAN-T5 model."
)

# Mount the Gradio app at the root path
# NOTE: rebinds `app`; routes registered above (e.g. POST /summarize) are
# kept on the returned ASGI app, so both the UI and the REST API are served.
app = gr.mount_gradio_app(app, demo, path="/")

# Entry point: serve the combined FastAPI + Gradio ASGI app.
if __name__ == "__main__":
    import uvicorn

    # Bind on all interfaces; 7860 is the conventional Gradio/Spaces port.
    host, port = "0.0.0.0", 7860
    uvicorn.run(app, host=host, port=port)