WolfeLeo2 commited on
Commit
11a16ac
·
1 Parent(s): 8dbb8ee

removed gradio and route mismatch fix

Browse files
Files changed (1) hide show
  1. app.py +23 -23
app.py CHANGED
@@ -4,8 +4,8 @@ import torch
4
  from transformers import T5Tokenizer, T5ForConditionalGeneration
5
  from fastapi import FastAPI, HTTPException
6
  from pydantic import BaseModel
7
- import gradio as gr
8
  from typing import Optional
 
9
 
10
  app = FastAPI()
11
 
@@ -28,8 +28,28 @@ class QuestionAnswerRequest(BaseModel):
28
  question: str
29
  context: str
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  @app.post("/question-answer")
32
- def answer_question(request: QuestionAnswerRequest):
33
  try:
34
  input_text = f"question: {request.question} context: {request.context}"
35
  inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True).to(device)
@@ -82,27 +102,7 @@ async def summarize(request: SummarizationRequest):
82
  except Exception as e:
83
  logger.error(f"Summarization error: {str(e)}")
84
  raise HTTPException(status_code=500, detail=str(e))
85
-
86
- # ---------- Gradio Interface ----------
87
-
88
- def gradio_summarize(text, max_length=150, min_length=30):
89
- return summarize_text(text, max_length, min_length)
90
-
91
- demo = gr.Interface(
92
- fn=gradio_summarize,
93
- inputs=[
94
- gr.Textbox(lines=10, placeholder="Enter text to summarize..."),
95
- gr.Slider(minimum=50, maximum=200, value=150, step=10, label="Maximum Length"),
96
- gr.Slider(minimum=10, maximum=100, value=30, step=5, label="Minimum Length")
97
- ],
98
- outputs="text",
99
- title="Text Summarization with FLAN-T5",
100
- description="This app summarizes text using Google's FLAN-T5 model."
101
- )
102
-
103
- # Mount the Gradio app at the root path
104
- app = gr.mount_gradio_app(app, demo, path="/")
105
-
106
  # ---------- Entry Point ----------
107
 
108
  if __name__ == "__main__":
 
4
  from transformers import T5Tokenizer, T5ForConditionalGeneration
5
  from fastapi import FastAPI, HTTPException
6
  from pydantic import BaseModel
 
7
  from typing import Optional
8
+ from contextlib import asynccontextmanager
9
 
10
  app = FastAPI()
11
 
 
28
  question: str
29
  context: str
30
 
31
+ @asynccontextmanager
32
+ async def lifespan(app: FastAPI):
33
+ # Startup
34
+ global model, tokenizer
35
+ try:
36
+ logger.info(f"Loading model: {model_name}")
37
+ tokenizer = T5Tokenizer.from_pretrained(model_name)
38
+ model = T5ForConditionalGeneration.from_pretrained(model_name)
39
+ model.to(device)
40
+ logger.info(f"Model loaded on device: {device}")
41
+ except Exception as e:
42
+ logger.error(f"Failed to load model: {e}")
43
+ raise
44
+ yield
45
+ # Shutdown
46
+ if torch.cuda.is_available():
47
+ torch.cuda.empty_cache()
48
+
49
+ app = FastAPI(lifespan=lifespan)
50
+
51
  @app.post("/question-answer")
52
+ async def answer_question(request: QuestionAnswerRequest):
53
  try:
54
  input_text = f"question: {request.question} context: {request.context}"
55
  inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True).to(device)
 
102
  except Exception as e:
103
  logger.error(f"Summarization error: {str(e)}")
104
  raise HTTPException(status_code=500, detail=str(e))
105
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  # ---------- Entry Point ----------
107
 
108
  if __name__ == "__main__":