WolfeLeo2 committed on
Commit
30e18b9
·
1 Parent(s): ed4fa64

fixed app=app

Browse files
Files changed (1) hide show
  1. app.py +64 -61
app.py CHANGED
@@ -1,82 +1,85 @@
1
- import gradio as gr
2
  import logging
3
- from fastapi import FastAPI
4
- from pydantic import BaseModel
5
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
  import torch
7
- from fastapi.middleware.cors import CORSMiddleware
 
 
 
 
8
 
9
  # Configure logging
10
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
11
  logger = logging.getLogger(__name__)
12
 
13
- # Load FLAN-T5 model
14
  model_name = "google/flan-t5-base"
15
- logger.info(f"Loading {model_name} model...")
16
- tokenizer = AutoTokenizer.from_pretrained(model_name)
17
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
18
- logger.info("Model loaded successfully!")
 
 
19
 
20
- # -----------------------------
21
- # REST API SECTION
22
- # -----------------------------
23
- api = FastAPI()
24
-
25
- api.add_middleware(
26
- CORSMiddleware,
27
- allow_origins=["*"], # Change to your domain later
28
- allow_credentials=True,
29
- allow_methods=["*"],
30
- allow_headers=["*"],
31
- )
32
 
33
- class SummarizeRequest(BaseModel):
 
34
  text: str
35
- max_length: int = 150
36
- min_length: int = 30
37
-
38
- @api.post("/summarize")
39
- def summarize_endpoint(request: SummarizeRequest):
40
- text = request.text.strip()
41
- if not text or len(text) < 50:
42
- return {"summary": text}
43
-
44
- logger.info(f"Summarizing via API. Length: {len(text)}")
45
-
46
- input_text = f"summarize: {text}"
47
- inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=1024)
48
-
49
- # Safe dynamic length handling
50
- max_tokens = min(request.max_length, 512)
51
- min_tokens = min(request.min_length, max_tokens - 1)
52
 
 
 
 
 
53
  outputs = model.generate(
54
- **inputs,
55
- max_new_tokens=max_tokens,
56
- min_length=min_tokens
 
 
 
57
  )
58
  summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
59
- return {"summary": summary}
 
60
 
61
- # -----------------------------
62
- # GRADIO UI SECTION
63
- # -----------------------------
64
- def summarize_text(text, max_length=150, min_length=30):
65
- return summarize_endpoint(SummarizeRequest(text=text, max_length=max_length, min_length=min_length))["summary"]
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  demo = gr.Interface(
68
- fn=summarize_text,
69
  inputs=[
70
- gr.Textbox(lines=10, label="Text to Summarize"),
71
- gr.Slider(50, 512, value=150, label="Max Length"),
72
- gr.Slider(10, 300, value=30, label="Min Length")
73
  ],
74
- outputs=gr.Textbox(label="Summary"),
75
- title="StudAI Text Summarization",
76
- description="Powered by google/flan-t5-base model"
77
  )
78
 
79
- # Mount Gradio + API
80
- app = FastAPI()
81
- app.mount("/", api)
82
- demo.launch(server_name="0.0.0.0", server_port=7860, root_path="/", app=app)
 
 
 
 
 
1
import os  # NOTE(review): imported but not referenced anywhere in this file — confirm before removing
import logging
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import gradio as gr
from typing import Optional

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load model and tokenizer at import time: the first run downloads the
# weights from the Hugging Face hub, so startup can take a while.
model_name = "google/flan-t5-base"
logger.info(f"Loading {model_name}...")
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
# Prefer GPU when available; all inference below moves inputs to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
logger.info(f"Model loaded, using device: {device}")

# FastAPI app — the REST endpoint is registered on this instance, and the
# Gradio UI is mounted onto it further down in the file.
app = FastAPI()
 
 
 
 
 
 
 
 
 
 
25
 
26
# Pydantic model for request validation
class SummarizationRequest(BaseModel):
    """Request body for POST /summarize."""
    # Source text to summarize (required).
    text: str
    # NOTE(review): Optional[...] means an explicit JSON null is accepted and
    # None would be forwarded to model.generate() — confirm that is intended;
    # a plain `int = 150` would reject nulls while keeping the same default.
    max_length: Optional[int] = 150
    min_length: Optional[int] = 30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
# Summarization function
def summarize_text(text, max_length=150, min_length=30):
    """Summarize `text` with FLAN-T5 using beam search.

    Args:
        text: Source text; the tokenized prompt is truncated to 512 tokens.
        max_length: Upper bound (in tokens) on the generated summary.
        min_length: Lower bound on the generated summary; clamped to
            max_length so inconsistent caller values cannot break generation.

    Returns:
        The decoded summary string (special tokens stripped).
    """
    logger.info(f"Summarizing text of length {len(text)}")
    inputs = tokenizer("summarize: " + text, return_tensors="pt", truncation=True, max_length=512).to(device)
    # Guard against inconsistent bounds (e.g. UI sliders set min > max).
    if max_length is not None and min_length is not None:
        min_length = min(min_length, max_length)
    # no_grad: inference only — avoids building an autograd graph and
    # roughly halves peak memory during generation.
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            # Pass the attention mask explicitly instead of input_ids alone;
            # otherwise transformers has to guess padding and emits warnings.
            attention_mask=inputs.attention_mask,
            max_length=max_length,
            min_length=min_length,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True
        )
    summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
    logger.info(f"Generated summary of length {len(summary)}")
    return summary
47
 
48
# REST API endpoint
@app.post("/summarize")
async def summarize(request: SummarizationRequest):
    """Summarize `request.text`.

    Returns:
        {"summary": <str>} on success.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        summary = summarize_text(
            request.text,
            max_length=request.max_length,
            min_length=request.min_length
        )
    except Exception as e:
        # logger.exception keeps the full traceback (logger.error with an
        # f-string discarded it); `from e` chains the cause for debugging.
        logger.exception("Error in summarization")
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"summary": summary}
61
+
62
# Gradio interface
def gradio_summarize(text, max_length=150, min_length=30):
    """Adapter between the Gradio UI and summarize_text.

    Slider widgets can deliver floats; generation length arguments must be
    integers, so coerce before delegating.
    """
    return summarize_text(text, int(max_length), int(min_length))
65
 
66
# Browser UI: one text box plus two sliders feeding gradio_summarize.
demo = gr.Interface(
    fn=gradio_summarize,
    inputs=[
        gr.Textbox(lines=10, placeholder="Enter text to summarize..."),
        # Slider ranges are narrower than the API's (API allows any int);
        # values are forwarded as the max_length/min_length arguments.
        gr.Slider(minimum=50, maximum=200, value=150, step=10, label="Maximum Length"),
        gr.Slider(minimum=10, maximum=100, value=30, step=5, label="Minimum Length")
    ],
    outputs="text",
    title="Text Summarization with FLAN-T5",
    description="This app summarizes text using Google's FLAN-T5 model."
)
77
 
78
# Mount the Gradio app at the root path. Note: this rebinds `app` to the
# combined ASGI application returned by mount_gradio_app; the /summarize
# route registered earlier is preserved on it.
app = gr.mount_gradio_app(app, demo, path="/")

# Start the server
if __name__ == "__main__":
    import uvicorn
    # Start server with both FastAPI and Gradio.
    # Port 7860 is the conventional Hugging Face Spaces port; 0.0.0.0 makes
    # the server reachable from outside the container.
    uvicorn.run(app, host="0.0.0.0", port=7860)