Curative committed on
Commit
222eba8
·
verified ·
1 Parent(s): 5e41748

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +32 -15
main.py CHANGED
import threading
import uvicorn

# Load model & tokenizer once at import time so every request reuses them.
model_path = "./t5-summarizer"
tokenizer = T5Tokenizer.from_pretrained(model_path, legacy=False)
model = T5ForConditionalGeneration.from_pretrained(model_path)
# Prefer the GPU when one is visible; otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# FastAPI setup
app = FastAPI()

class TextInput(BaseModel):
    # Request body schema: the raw text to summarize.
    text: str
21
@app.post("/summarize/")
def summarize_text(input: TextInput):
    """Summarize the posted text with the fine-tuned T5 model."""
    # T5 selects its task via a text prefix; newlines are flattened first.
    prompt = "summarize: " + input.text.replace("\n", " ")
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        max_length=512,
        truncation=True,
    )
    generated = model.generate(
        encoded.input_ids.to(device),
        max_length=150,
        min_length=30,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
    )
    return {"summary": tokenizer.decode(generated[0], skip_special_tokens=True)}
 
 
 
 
 
 
 
 
 
 
29
 
30
def run_fastapi():
    """Blocking uvicorn server for the FastAPI app; meant to run in a thread."""
    uvicorn.run(app, host="0.0.0.0", port=8000)
32
 
33
# Gradio UI: a thin wrapper that feeds the textbox through the API handler.
iface = gr.Interface(
    fn=lambda text: summarize_text(TextInput(text=text))["summary"],
    inputs=gr.Textbox(lines=10, placeholder="Paste text here..."),
    outputs=gr.Textbox(label="Summary"),
    title="Text Summarizer",
    description="Fine-tuned T5 summarizer",
    flagging_mode="never",  # Disable flagging
    examples=[["Your example text here..."]],  # Pre-load examples
)
43
 
44
# Start FastAPI in the background (daemon thread), then launch Gradio
# in the foreground so the process stays alive while the UI runs.
threading.Thread(target=run_fastapi, daemon=True).start()
iface.launch(server_name="0.0.0.0", server_port=7860)
 
6
  import threading
7
  import uvicorn
8
 
9
# 1. Load model & tokenizer (once, at startup — shared by all requests)
model_path = "./t5-summarizer"
tokenizer = T5Tokenizer.from_pretrained(model_path, legacy=False)
model = T5ForConditionalGeneration.from_pretrained(model_path)
# Move the model to GPU when available, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# 2. FastAPI setup
app = FastAPI()

class TextInput(BaseModel):
    # Request body schema: the raw text to summarize.
    text: str
20
 
21
@app.post("/summarize/")
def summarize_text(input: TextInput):
    """Summarize the posted text with the fine-tuned T5 model.

    Args:
        input: Request body carrying a single ``text`` field.

    Returns:
        dict: ``{"summary": <generated summary string>}``.
    """
    # T5 is a text-to-text model: the task is chosen via the "summarize: "
    # prefix. Newlines are flattened so the prompt is a single line.
    inputs = tokenizer(
        "summarize: " + input.text.replace("\n", " "),
        return_tensors="pt",
        max_length=512,
        truncation=True,
    ).to(device)
    # Inference only: no_grad avoids building the autograd graph.
    with torch.no_grad():
        summary_ids = model.generate(
            inputs.input_ids,
            # Bug fix: the tokenizer's attention mask was previously discarded;
            # pass it through so padded/truncated positions are masked correctly.
            attention_mask=inputs.attention_mask,
            max_length=150,
            min_length=30,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True,
        )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return {"summary": summary}
39
 
40
def run_fastapi():
    """Serve the FastAPI app on port 8000 (blocks; intended for a thread)."""
    uvicorn.run(app, host="0.0.0.0", port=8000)
42
 
43
# 3. Gradio UI
def summarize_ui(text):
    """Bridge the Gradio textbox (plain string) to the FastAPI handler."""
    payload = TextInput(text=text)
    return summarize_text(payload)["summary"]
46
+
47
iface = gr.Interface(
    fn=summarize_ui,
    inputs=gr.Textbox(lines=10, placeholder="Paste your text here..."),
    outputs=gr.Textbox(label="Summary"),
    title="Text Summarizer",
    description="Fine-tuned T5 summarizer on CNN/DailyMail v3.0.0",
    examples=[
        ["Scientists have recently discovered a new species of frog in the Amazon rainforest..."],
        ["The global economy is expected to grow at a slower pace this year..."],
        ["In a thrilling final match, the underdog team scored a last-minute goal..."]
    ],
    # Disable the flag button. Fixed: a pasted citation artifact
    # (":contentReference[oaicite:3]{index=3}") was removed from this comment.
    # NOTE(review): Gradio 5 renamed this kwarg to flagging_mode= — confirm
    # the installed Gradio version before deploying.
    allow_flagging="never"
)
60
 
61
# 4. Run both servers: FastAPI in a daemon thread, Gradio in the foreground
# so the process stays alive for the UI.
threading.Thread(target=run_fastapi, daemon=True).start()
iface.launch(server_name="0.0.0.0", server_port=7860)