File size: 3,483 Bytes
8dbb8ee
716037e
af53a88
30e18b9
 
 
 
 
716037e
8dbb8ee
 
716037e
30e18b9
716037e
 
8dbb8ee
 
30e18b9
8dbb8ee
30e18b9
 
 
 
8dbb8ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71754ec
af53a88
30e18b9
af53a88
30e18b9
 
af53a88
30e18b9
 
 
8dbb8ee
af53a88
30e18b9
 
 
8dbb8ee
 
 
 
 
fce3660
8dbb8ee
af53a88
30e18b9
 
af53a88
30e18b9
 
 
 
 
 
 
 
 
 
8dbb8ee
30e18b9
 
8dbb8ee
 
30e18b9
 
71754ec
 
30e18b9
71754ec
30e18b9
 
 
71754ec
30e18b9
 
 
71754ec
716037e
30e18b9
 
 
8dbb8ee
 
30e18b9
 
8dbb8ee
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import sentencepiece  
import logging
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import gradio as gr
from typing import Optional

# HTTP application that the endpoints below register themselves on.
app = FastAPI()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Checkpoint name shared by the tokenizer and the model.
model_name = "google/flan-t5-large"

# Use the GPU when CUDA reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load model and tokenizer
logger.info(f"Loading model: {model_name}")
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
model.to(device)
logger.info(f"Model loaded on device: {device}")


# Request body accepted by the /question-answer endpoint.
class QuestionAnswerRequest(BaseModel):
    # Question to be answered.
    question: str
    # Passage placed after the question in the prompt sent to the model.
    context: str

@app.post("/question-answer")
def answer_question(request: QuestionAnswerRequest):
    """Answer a question about a context passage with the loaded T5 model.

    Args:
        request: Body carrying the ``question`` and the ``context`` that
            are combined into a single prompt.

    Returns:
        A dict of the form ``{"answer": <decoded text>}``.

    Raises:
        HTTPException: Status 500 when tokenization, generation or
            decoding fails; the original error is chained as the cause.
    """
    try:
        input_text = f"question: {request.question} context: {request.context}"
        # Prompts longer than 512 tokens are truncated by the tokenizer.
        inputs = tokenizer(
            input_text, return_tensors="pt", max_length=512, truncation=True
        ).to(device)
        outputs = model.generate(
            inputs.input_ids,
            # Hand over the tokenizer's mask explicitly instead of leaving
            # generate() to work one out from the input ids.
            attention_mask=inputs.attention_mask,
            max_length=64,
            num_beams=4,
            early_stopping=True,
        )
        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"answer": answer}
    except Exception as e:
        # logger.exception logs at ERROR level and appends the traceback,
        # which a plain logger.error call would drop.
        logger.exception(f"QA error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e


# Request body accepted by the /summarize endpoint.
class SummarizationRequest(BaseModel):
    # Text to summarize.
    text: str
    # Passed to summarize_text as max_length; 150 when the field is omitted.
    # NOTE(review): Optional also admits an explicit null, which the endpoint
    # forwards unchanged - confirm that is intended.
    max_length: Optional[int] = 150
    # Passed to summarize_text as min_length; 30 when the field is omitted.
    min_length: Optional[int] = 30

def summarize_text(text, max_length=150, min_length=30):
    """Summarize *text* with the loaded T5 model.

    Args:
        text: Text to summarize. It is prefixed with ``"summarize: "`` and
            truncated to 512 tokens before generation.
        max_length: ``max_length`` passed to ``model.generate``. ``None``
            is treated as the default of 150.
        min_length: ``min_length`` passed to ``model.generate``. ``None``
            is treated as the default of 30.

    Returns:
        The decoded summary string.
    """
    # SummarizationRequest declares both limits Optional, so None can arrive
    # here; fall back to the documented defaults instead of handing None to
    # generate().
    if max_length is None:
        max_length = 150
    if min_length is None:
        min_length = 30

    logger.info(f"Summarizing text of length {len(text)}")
    inputs = tokenizer(
        "summarize: " + text, return_tensors="pt", truncation=True, max_length=512
    ).to(device)

    outputs = model.generate(
        inputs.input_ids,
        # Hand over the tokenizer's mask explicitly instead of leaving
        # generate() to work one out from the input ids.
        attention_mask=inputs.attention_mask,
        max_length=max_length,
        min_length=min_length,
        num_beams=6,
        repetition_penalty=2.0,
        length_penalty=1.0,
        early_stopping=True,
        no_repeat_ngram_size=3,
    )

    summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
    logger.info(f"Generated summary of length {len(summary)}")
    return summary

@app.post("/summarize")
async def summarize(request: SummarizationRequest):
    """Summarize the text in *request* and return it as JSON.

    Args:
        request: Body carrying ``text`` plus the ``max_length`` and
            ``min_length`` limits forwarded to ``summarize_text``.

    Returns:
        A dict of the form ``{"summary": <summary text>}``.

    Raises:
        HTTPException: Status 500 when summarization fails; the original
            error is chained as the cause.
    """
    # Imported locally so this handler brings its own dependency into scope.
    from fastapi.concurrency import run_in_threadpool

    try:
        # summarize_text is synchronous model inference. Calling it directly
        # inside this coroutine would block the event loop for the whole
        # generation, so it is awaited on a worker thread instead.
        summary = await run_in_threadpool(
            summarize_text,
            request.text,
            max_length=request.max_length,
            min_length=request.min_length,
        )
        return {"summary": summary}
    except Exception as e:
        # logger.exception logs at ERROR level and appends the traceback,
        # which a plain logger.error call would drop.
        logger.exception(f"Summarization error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e

# ---------- Gradio Interface ----------

def gradio_summarize(text, max_length=150, min_length=30):
    """Callback for the Gradio interface; delegates to ``summarize_text``.

    Args:
        text: Text to summarize.
        max_length: Forwarded as ``max_length``.
        min_length: Forwarded as ``min_length``.

    Returns:
        Whatever ``summarize_text`` returns for these arguments.
    """
    summary = summarize_text(text, max_length=max_length, min_length=min_length)
    return summary

# Input widgets, listed in the order gradio_summarize receives its arguments.
_text_input = gr.Textbox(lines=10, placeholder="Enter text to summarize...")
_max_length_input = gr.Slider(
    minimum=50, maximum=200, value=150, step=10, label="Maximum Length"
)
_min_length_input = gr.Slider(
    minimum=10, maximum=100, value=30, step=5, label="Minimum Length"
)

# Web form wired to gradio_summarize.
demo = gr.Interface(
    fn=gradio_summarize,
    inputs=[_text_input, _max_length_input, _min_length_input],
    outputs="text",
    title="Text Summarization with FLAN-T5",
    description="This app summarizes text using Google's FLAN-T5 model.",
)

# Mount the Gradio app at the root path
app = gr.mount_gradio_app(app, demo, path="/")

# ---------- Entry Point ----------

if __name__ == "__main__":
    # uvicorn is only needed when the module is run directly as a script.
    import uvicorn

    # Listen on every interface, port 7860.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)