WolfeLeo2 committed on
Commit
af53a88
·
1 Parent(s): fe123c2

change to fastAPI

Browse files
Files changed (1) hide show
  1. app.py +62 -22
app.py CHANGED
@@ -1,42 +1,82 @@
1
  import gradio as gr
2
  import logging
3
- import sys
4
- from transformers import pipeline
 
 
 
5
 
6
  # Configure logging
7
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
8
  logger = logging.getLogger(__name__)
9
 
10
- # Load the model
11
- logger.info("Loading bart-large-cnn model...")
12
- summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
 
 
13
  logger.info("Model loaded successfully!")
14
 
15
- def summarize_text(text, max_length=150, min_length=30):
16
- if not text or len(text.strip()) < 50:
17
- return text
18
-
19
- logger.info(f"Summarizing text of length {len(text)}")
20
- result = summarizer(
21
- text,
22
- max_length=max_length,
23
- min_length=min_length,
24
- truncation=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  )
26
- summary = result[0]["summary_text"]
27
- return summary
 
 
 
 
 
 
28
 
29
- # Create Gradio interface
30
  demo = gr.Interface(
31
  fn=summarize_text,
32
  inputs=[
33
  gr.Textbox(lines=10, label="Text to Summarize"),
34
- gr.Slider(50, 500, value=150, label="Max Length"),
35
- gr.Slider(10, 200, value=30, label="Min Length")
36
  ],
37
  outputs=gr.Textbox(label="Summary"),
38
  title="StudAI Text Summarization",
39
- description="Powered by facebook/bart-large-cnn model"
40
  )
41
 
42
- demo.launch(share=True)
 
 
 
 
1
  import gradio as gr
2
  import logging
3
+ from fastapi import FastAPI
4
+ from pydantic import BaseModel
5
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
+ import torch
7
+ from fastapi.middleware.cors import CORSMiddleware
8
 
9
  # Configure logging
10
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
11
  logger = logging.getLogger(__name__)
12
 
13
# Load the FLAN-T5 tokenizer and seq2seq model once, at import time; the
# request handler below reads these module-level objects.
model_name = "google/flan-t5-base"
logger.info(f"Loading {model_name} model...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
logger.info("Model loaded successfully!")

# -----------------------------
# REST API SECTION
# -----------------------------
api = FastAPI()

# CORS: any origin may call the API.  Credentials are disabled because a
# wildcard origin must not be combined with credentialed requests; enable
# allow_credentials again only together with an explicit list of origins.
api.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Change to your domain later
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
32
+
33
class SummarizeRequest(BaseModel):
    """Payload describing one summarization request."""

    # Raw text to summarize (required).
    text: str
    # Requested upper bound for the summary length; 150 when omitted.
    max_length: int = 150
    # Requested lower bound for the summary length; 30 when omitted.
    min_length: int = 30
37
+
38
@api.post("/summarize")
def summarize_endpoint(request: SummarizeRequest):
    """Summarize ``request.text`` with the module-level model.

    Args:
        request: Text to summarize plus the requested ``max_length`` and
            ``min_length`` bounds.

    Returns:
        A dict with a single ``"summary"`` key.  Text shorter than 50
        characters after stripping is returned as-is (stripped) without
        running the model.
    """
    text = request.text.strip()
    # Too short to be worth summarizing (this also covers the empty string).
    if len(text) < 50:
        return {"summary": text}

    logger.info(f"Summarizing via API. Length: {len(text)}")

    input_text = f"summarize: {text}"
    # Inputs longer than 1024 tokens are truncated by the tokenizer.
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=1024)

    # Clamp the caller-supplied bounds to a usable range: the generation
    # budget stays within 1..512 new tokens, and the minimum is kept
    # non-negative and strictly below that budget.  Without the lower
    # clamps, a zero or negative max_length would be passed straight to
    # generate() and min_tokens could go negative.
    max_tokens = max(1, min(request.max_length, 512))
    min_tokens = max(0, min(request.min_length, max_tokens - 1))

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        min_length=min_tokens,
    )
    summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"summary": summary}
60
+
61
# -----------------------------
# GRADIO UI SECTION
# -----------------------------
def summarize_text(text, max_length=150, min_length=30):
    """Gradio callback: build a request and reuse the REST handler.

    Args:
        text: Text entered in the UI textbox.
        max_length: Value of the "Max Length" slider.
        min_length: Value of the "Min Length" slider.

    Returns:
        The ``"summary"`` string produced by ``summarize_endpoint``.
    """
    payload = SummarizeRequest(text=text, max_length=max_length, min_length=min_length)
    response = summarize_endpoint(payload)
    return response["summary"]


# UI definition: one textbox plus two length sliders, wired to summarize_text.
demo = gr.Interface(
    fn=summarize_text,
    inputs=[
        gr.Textbox(lines=10, label="Text to Summarize"),
        gr.Slider(50, 512, value=150, label="Max Length"),
        gr.Slider(10, 300, value=30, label="Min Length"),
    ],
    outputs=gr.Textbox(label="Summary"),
    title="StudAI Text Summarization",
    description="Powered by google/flan-t5-base model",
)

# Mount Gradio + API
app = FastAPI()
app.mount("/", api)
# NOTE(review): confirm that the installed Gradio's launch() accepts an
# `app=` keyword; gr.mount_gradio_app(...) served by an ASGI server is the
# documented way to attach a Gradio UI to an existing FastAPI app.
demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
    root_path="/",
    app=app,
)