File size: 3,408 Bytes
8dbb8ee
716037e
af53a88
30e18b9
 
 
 
11a16ac
716037e
8dbb8ee
 
716037e
30e18b9
716037e
 
8dbb8ee
 
30e18b9
8dbb8ee
30e18b9
 
 
 
8dbb8ee
 
 
 
 
 
 
11a16ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8dbb8ee
11a16ac
8dbb8ee
 
 
 
 
 
 
 
 
 
 
 
 
 
71754ec
af53a88
30e18b9
af53a88
30e18b9
 
af53a88
30e18b9
 
 
8dbb8ee
af53a88
30e18b9
 
 
8dbb8ee
 
 
 
 
fce3660
8dbb8ee
af53a88
30e18b9
 
af53a88
30e18b9
 
 
 
 
 
 
 
 
 
8dbb8ee
30e18b9
11a16ac
8dbb8ee
 
30e18b9
 
8dbb8ee
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import sentencepiece  
import logging
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
from contextlib import asynccontextmanager

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face checkpoint used by every endpoint in this module.
model_name = "google/flan-t5-large"

# Prefer the GPU when torch can see one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The tokenizer and model are populated by the ``lifespan`` startup hook.
# Loading them here as well would materialise the checkpoint twice in the
# same process, so only placeholders are declared at import time.
# NOTE: code that imports this module without running the application's
# lifespan will see ``None`` here until the startup hook has executed.
tokenizer: Optional[T5Tokenizer] = None
model: Optional[T5ForConditionalGeneration] = None

# The FastAPI application object is created once, further down, after
# ``lifespan`` has been defined; no throwaway instance is built here.


class QuestionAnswerRequest(BaseModel):
    """Request body for the ``/question-answer`` endpoint.

    Both fields are required strings; the endpoint combines them into a
    single ``"question: ... context: ..."`` prompt for the model.
    """

    # The question to be answered.
    question: str
    # The passage the answer should be drawn from.
    context: str

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage the model's lifetime around the application's serving period.

    On startup, ensures the module-level ``tokenizer`` and ``model`` are
    loaded from ``model_name`` and that the model has been moved to
    ``device``. On shutdown, asks torch to release its cached CUDA memory
    when a GPU is available.

    Args:
        app: The FastAPI application being started (not used in the body).

    Raises:
        Exception: Any error raised while loading the tokenizer or model is
            logged and re-raised, so application startup fails.
    """
    global model, tokenizer
    # Startup: load only when something is still missing, so the checkpoint
    # is never materialised a second time within the same process.
    if model is None or tokenizer is None:
        try:
            logger.info(f"Loading model: {model_name}")
            tokenizer = T5Tokenizer.from_pretrained(model_name)
            model = T5ForConditionalGeneration.from_pretrained(model_name)
            model.to(device)
            logger.info(f"Model loaded on device: {device}")
        except Exception as e:
            # logger.exception keeps the traceback alongside the message.
            logger.exception("Failed to load model: %s", e)
            raise
    try:
        yield
    finally:
        # Shutdown: run the cleanup even if an exception is thrown into the
        # generator at the ``yield`` point.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# The application instance that the route decorators below attach to.
app = FastAPI(lifespan=lifespan)

def _generate_answer(question: str, context: str) -> str:
    """Run one blocking tokenize -> generate -> decode pass for a QA prompt."""
    input_text = f"question: {question} context: {context}"
    inputs = tokenizer(
        input_text, return_tensors="pt", max_length=512, truncation=True
    ).to(device)
    outputs = model.generate(
        inputs.input_ids,
        # Pass the mask explicitly instead of letting generate() infer it.
        attention_mask=inputs.attention_mask,
        max_length=64,
        num_beams=4,
        early_stopping=True,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


@app.post("/question-answer")
async def answer_question(request: QuestionAnswerRequest):
    """Answer ``request.question`` using ``request.context``.

    Args:
        request: The validated request body holding the question and the
            context passage.

    Returns:
        A dict of the form ``{"answer": <decoded model output>}``.

    Raises:
        HTTPException: With status 500 and the error text as detail when
            tokenization, generation, or decoding fails.
    """
    # Local import keeps this block self-contained; the helper ships with
    # FastAPI itself.
    from fastapi.concurrency import run_in_threadpool

    try:
        # Generation is blocking, compute-bound work. Running it on a worker
        # thread keeps the event loop free to serve other requests.
        answer = await run_in_threadpool(
            _generate_answer, request.question, request.context
        )
    except Exception as e:
        logger.exception("QA error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"answer": answer}


class SummarizationRequest(BaseModel):
    """Request body for the ``/summarize`` endpoint.

    ``max_length`` and ``min_length`` are optional; the endpoint forwards
    them to ``summarize_text``, which passes them to ``model.generate``.
    """

    # The text to summarize.
    text: str
    # Forwarded as ``max_length`` to generation; defaults to 150.
    max_length: Optional[int] = 150
    # Forwarded as ``min_length`` to generation; defaults to 30.
    min_length: Optional[int] = 30

def summarize_text(
    text: str,
    max_length: Optional[int] = 150,
    min_length: Optional[int] = 30,
) -> str:
    """Summarize ``text`` with the module-level model and tokenizer.

    The input is prefixed with ``"summarize: "``, tokenized with truncation
    at 512, and decoded from the first sequence returned by beam search.

    Args:
        text: The text to summarize.
        max_length: Value passed to ``model.generate`` as ``max_length``.
            ``None`` is treated as the default of 150.
        min_length: Value passed to ``model.generate`` as ``min_length``.
            ``None`` is treated as the default of 30.

    Returns:
        The decoded summary with special tokens removed.
    """
    # The request schema permits explicit nulls for both lengths; fall back
    # to the documented defaults rather than handing None to generate().
    if max_length is None:
        max_length = 150
    if min_length is None:
        min_length = 30

    logger.info(f"Summarizing text of length {len(text)}")
    inputs = tokenizer(
        "summarize: " + text, return_tensors="pt", truncation=True, max_length=512
    ).to(device)

    outputs = model.generate(
        inputs.input_ids,
        # Pass the mask explicitly instead of letting generate() infer it.
        attention_mask=inputs.attention_mask,
        max_length=max_length,
        min_length=min_length,
        num_beams=6,
        repetition_penalty=2.0,
        length_penalty=1.0,
        early_stopping=True,
        no_repeat_ngram_size=3,
    )

    summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
    logger.info(f"Generated summary of length {len(summary)}")
    return summary

@app.post("/summarize")
async def summarize(request: SummarizationRequest):
    """Summarize ``request.text`` and return the result.

    Args:
        request: The validated request body holding the text and the
            optional length bounds.

    Returns:
        A dict of the form ``{"summary": <generated summary>}``.

    Raises:
        HTTPException: With status 500 and the error text as detail when
            summarization fails.
    """
    # Local import keeps this block self-contained; the helper ships with
    # FastAPI itself.
    from fastapi.concurrency import run_in_threadpool

    try:
        # summarize_text is blocking, compute-bound work. Running it on a
        # worker thread keeps the event loop free to serve other requests.
        summary = await run_in_threadpool(
            summarize_text,
            request.text,
            max_length=request.max_length,
            min_length=request.min_length,
        )
    except Exception as e:
        logger.exception("Summarization error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"summary": summary}
    
# ---------- Entry Point ----------

# When executed as a script, serve ``app`` with uvicorn on port 7860,
# listening on all network interfaces.
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when running as a script.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)