Curative committed on
Commit
0caec9c
·
verified ·
1 Parent(s): 3e6721b

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +10 -37
main.py CHANGED
@@ -1,10 +1,8 @@
 
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
  from transformers import T5ForConditionalGeneration, T5Tokenizer
4
  import torch
5
- import gradio as gr
6
- import threading
7
- import uvicorn
8
 
9
  # Load your fine-tuned model
10
  model_path = "./t5-summarizer" # Path inside Docker container
@@ -14,7 +12,6 @@ tokenizer = T5Tokenizer.from_pretrained(model_path, legacy=False)
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
  model = model.to(device)
16
 
17
- # FastAPI app
18
  app = FastAPI()
19
 
20
  class TextInput(BaseModel):
@@ -23,42 +20,18 @@ class TextInput(BaseModel):
23
  @app.post("/summarize/")
24
  def summarize_text(input: TextInput):
25
  input_text = "summarize: " + input.text.strip().replace("\n", " ")
26
- inputs = tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True).to(device)
27
- summary_ids = model.generate(inputs, max_length=150, min_length=30, length_penalty=2.0, num_beams=4, early_stopping=True)
28
- summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
29
- return {"summary": summary}
30
 
31
- # Summarization function for Gradio
32
- def summarize_ui(text):
33
- input_text = "summarize: " + text.strip().replace("\n", " ")
34
  inputs = tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True).to(device)
35
  summary_ids = model.generate(inputs, max_length=150, min_length=30, length_penalty=2.0, num_beams=4, early_stopping=True)
36
- summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
37
- return summary
38
-
39
- # Gradio interface with example texts
40
- gradio_app = gr.Interface(
41
- fn=summarize_ui,
42
- inputs=gr.Textbox(lines=10, placeholder="Paste your text here..."),
43
- outputs=gr.Textbox(label="Summary"),
44
- title="Text Summarizer",
45
- description="Paste your long text and get a concise summary using a fine-tuned T5 model.",
46
- examples=[
47
- ["Scientists have recently discovered a new species of frog in the Amazon rainforest. This frog is notable for its bright blue legs and unique mating call, which sounds like a series of short whistles. Researchers believe that the discovery of this species could shed new light on the ecological diversity of the region."],
48
- ["The global economy is expected to grow at a slower pace this year, according to new forecasts released today. Economists point to ongoing geopolitical tensions, supply chain disruptions, and inflationary pressures as key factors contributing to the reduced growth outlook."],
49
- ["In a thrilling final match, the underdog team scored a last-minute goal to secure their first championship title. Fans erupted into celebration as the team lifted the trophy, marking a historic moment in the club's history."]
50
- ],
51
- flagging=False # Disable flagging to prevent permission issue
52
- )
53
-
54
 
55
- # Function to run Gradio in a thread
56
- def run_gradio():
57
- gradio_app.launch(server_name="0.0.0.0", server_port=7860, share=False)
58
 
59
- # Run Gradio in a separate thread
60
- threading.Thread(target=run_gradio).start()
 
 
 
 
 
61
 
62
- # Run FastAPI with Uvicorn if needed (for local dev)
63
- if __name__ == "__main__":
64
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
+ import gradio as gr
2
  from fastapi import FastAPI
3
  from pydantic import BaseModel
4
  from transformers import T5ForConditionalGeneration, T5Tokenizer
5
  import torch
 
 
 
6
 
7
  # Load your fine-tuned model
8
  model_path = "./t5-summarizer" # Path inside Docker container
 
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
  model = model.to(device)
14
 
 
15
  app = FastAPI()
16
 
17
  class TextInput(BaseModel):
 
20
  @app.post("/summarize/")
21
  def summarize_text(input: TextInput):
22
  input_text = "summarize: " + input.text.strip().replace("\n", " ")
 
 
 
 
23
 
 
 
 
24
  inputs = tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True).to(device)
25
  summary_ids = model.generate(inputs, max_length=150, min_length=30, length_penalty=2.0, num_beams=4, early_stopping=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
28
+ return {"summary": summary}
 
29
 
30
+ # Gradio UI setup
31
+ gr.Interface(
32
+ fn=lambda text: summarize_text(TextInput(text=text))["summary"], # Ensure it returns summary
33
+ inputs=gr.Textbox(label="Input Text"),
34
+ outputs=gr.Textbox(label="Summarized Text"),
35
+ flagging=False # Disable flagging to prevent permission issues
36
+ ).launch()
37