Curative committed
Commit 296695c · verified · 1 parent: 2b58f82

Update main.py: add a Gradio demo UI alongside the FastAPI endpoint

Files changed (1)
  1. main.py +38 -2
main.py CHANGED
@@ -2,6 +2,9 @@ from fastapi import FastAPI
 from pydantic import BaseModel
 from transformers import T5ForConditionalGeneration, T5Tokenizer
 import torch
+import gradio as gr
+import threading
+import uvicorn
 
 # Load your fine-tuned model
 model_path = "./t5-summarizer"  # Path inside Docker container
@@ -11,6 +14,7 @@ tokenizer = T5Tokenizer.from_pretrained(model_path, legacy=False)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model = model.to(device)
 
+# FastAPI app
 app = FastAPI()
 
 class TextInput(BaseModel):
@@ -19,9 +23,41 @@ class TextInput(BaseModel):
 @app.post("/summarize/")
 def summarize_text(input: TextInput):
     input_text = "summarize: " + input.text.strip().replace("\n", " ")
-
     inputs = tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True).to(device)
     summary_ids = model.generate(inputs, max_length=150, min_length=30, length_penalty=2.0, num_beams=4, early_stopping=True)
-
     summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
     return {"summary": summary}
+
+# Summarization function for Gradio
+def summarize_ui(text):
+    input_text = "summarize: " + text.strip().replace("\n", " ")
+    inputs = tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True).to(device)
+    summary_ids = model.generate(inputs, max_length=150, min_length=30, length_penalty=2.0, num_beams=4, early_stopping=True)
+    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+    return summary
+
+# Gradio interface with example texts
+gradio_app = gr.Interface(
+    fn=summarize_ui,
+    inputs=gr.Textbox(lines=10, placeholder="Paste your text here..."),
+    outputs=gr.Textbox(label="Summary"),
+    title="Text Summarizer",
+    description="Paste your long text and get a concise summary using a fine-tuned T5 model.",
+    examples=[
+        ["Scientists have recently discovered a new species of frog in the Amazon rainforest. This frog is notable for its bright blue legs and unique mating call, which sounds like a series of short whistles. Researchers believe that the discovery of this species could shed new light on the ecological diversity of the region."],
+        ["The global economy is expected to grow at a slower pace this year, according to new forecasts released today. Economists point to ongoing geopolitical tensions, supply chain disruptions, and inflationary pressures as key factors contributing to the reduced growth outlook."],
+        ["In a thrilling final match, the underdog team scored a last-minute goal to secure their first championship title. Fans erupted into celebration as the team lifted the trophy, marking a historic moment in the club's history."]
+    ]
+)
+
+
+# Function to run Gradio in a thread
+def run_gradio():
+    gradio_app.launch(server_name="0.0.0.0", server_port=7860, share=False)
+
+# Run Gradio in a separate thread
+threading.Thread(target=run_gradio).start()
+
+# Run FastAPI with Uvicorn if needed (for local dev)
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", port=8000)
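
For reference, a minimal client sketch for exercising the service after this commit. It is not part of the commit itself: localhost and the ports follow the uvicorn.run and gradio_app.launch calls above, the "text" field follows the summarize_text handler's TextInput model, and the requests dependency is an assumption (any HTTP client works).

# Hypothetical client for the /summarize/ endpoint (not part of this commit).
# Host/port follow the uvicorn.run(...) call above; the "text" key matches
# the field that summarize_text reads from its TextInput body.
import requests

payload = {"text": "Scientists have recently discovered a new species of frog in the Amazon rainforest."}
resp = requests.post("http://localhost:8000/summarize/", json=payload)
resp.raise_for_status()
print(resp.json()["summary"])

The Gradio demo serves the same model as a browser UI on port 7860. Because threading.Thread(...).start() runs at module import, the UI comes up even when the app is launched by an external Uvicorn process rather than via the __main__ block, so one container exposes both a JSON API and an interactive front end.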