WolfeLeo2 committed
Commit 8dbb8ee · 1 parent: e943634

changed to flan t5 large

Files changed (1)
app.py +42 -17
app.py CHANGED
@@ -1,4 +1,4 @@
-import sentencepiece
+import sentencepiece
 import logging
 import torch
 from transformers import T5Tokenizer, T5ForConditionalGeneration
@@ -7,45 +7,69 @@ from pydantic import BaseModel
 import gradio as gr
 from typing import Optional
 
+app = FastAPI()
+
 # Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+model_name = "google/flan-t5-large"
+
 # Load model and tokenizer
-model_name = "google/flan-t5-base"
-logger.info(f"Loading {model_name}...")
+logger.info(f"Loading model: {model_name}")
 tokenizer = T5Tokenizer.from_pretrained(model_name)
 model = T5ForConditionalGeneration.from_pretrained(model_name)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model.to(device)
-logger.info(f"Model loaded, using device: {device}")
+logger.info(f"Model loaded on device: {device}")
+
+
+class QuestionAnswerRequest(BaseModel):
+    question: str
+    context: str
+
+@app.post("/question-answer")
+def answer_question(request: QuestionAnswerRequest):
+    try:
+        input_text = f"question: {request.question} context: {request.context}"
+        inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True).to(device)
+        outputs = model.generate(
+            inputs.input_ids,
+            max_length=64,
+            num_beams=4,
+            early_stopping=True
+        )
+        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        return {"answer": answer}
+    except Exception as e:
+        logger.error(f"QA error: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
 
-# FastAPI app
-app = FastAPI()
 
-# Pydantic model for request validation
 class SummarizationRequest(BaseModel):
     text: str
     max_length: Optional[int] = 150
     min_length: Optional[int] = 30
 
-# Summarization function
 def summarize_text(text, max_length=150, min_length=30):
     logger.info(f"Summarizing text of length {len(text)}")
     inputs = tokenizer("summarize: " + text, return_tensors="pt", truncation=True, max_length=512).to(device)
+
     outputs = model.generate(
         inputs.input_ids,
         max_length=max_length,
         min_length=min_length,
-        length_penalty=2.0,
-        num_beams=4,
-        early_stopping=True
+        num_beams=6,
+        repetition_penalty=2.0,
+        length_penalty=1.0,
+        early_stopping=True,
+        no_repeat_ngram_size=3
    )
+
     summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
     logger.info(f"Generated summary of length {len(summary)}")
     return summary
 
-# REST API endpoint
 @app.post("/summarize")
 async def summarize(request: SummarizationRequest):
     try:
@@ -56,10 +80,11 @@ async def summarize(request: SummarizationRequest):
         )
         return {"summary": summary}
     except Exception as e:
-        logger.error(f"Error in summarization: {str(e)}")
+        logger.error(f"Summarization error: {str(e)}")
         raise HTTPException(status_code=500, detail=str(e))
 
-# Gradio interface
+# ---------- Gradio Interface ----------
+
 def gradio_summarize(text, max_length=150, min_length=30):
     return summarize_text(text, max_length, min_length)
 
@@ -78,8 +103,8 @@ demo = gr.Interface(
 # Mount the Gradio app at the root path
 app = gr.mount_gradio_app(app, demo, path="/")
 
-# Start the server
+# ---------- Entry Point ----------
+
 if __name__ == "__main__":
     import uvicorn
-    # Start server with both FastAPI and Gradio
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+    uvicorn.run(app, host="0.0.0.0", port=7860)
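
For quick verification of the two REST endpoints this commit touches, here is a minimal client sketch. It assumes the app is running locally on port 7860 (as set in `uvicorn.run` above); the question, context, and text values are placeholder examples, not part of the commit.

```python
# Minimal client sketch for the /question-answer and /summarize endpoints.
# Assumes the server above is running at localhost:7860; example payloads are hypothetical.
import requests

BASE = "http://localhost:7860"  # port configured in uvicorn.run above

# New QA endpoint added in this commit; expects "question" and "context" fields.
qa = requests.post(
    f"{BASE}/question-answer",
    json={
        "question": "Which model does the app serve?",
        "context": "The app loads google/flan-t5-large for summarization and QA.",
    },
)
print(qa.json()["answer"])

# Existing summarization endpoint; max_length and min_length are optional
# (defaults 150 and 30, per SummarizationRequest).
summary = requests.post(
    f"{BASE}/summarize",
    json={
        "text": "FLAN-T5 is an instruction-tuned variant of T5.",
        "max_length": 80,
        "min_length": 20,
    },
)
print(summary.json()["summary"])
```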