Curative commited on
Commit
5e41748
·
verified ·
1 Parent(s): 0caec9c

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +31 -22
main.py CHANGED
@@ -3,35 +3,44 @@ from fastapi import FastAPI
3
  from pydantic import BaseModel
4
  from transformers import T5ForConditionalGeneration, T5Tokenizer
5
  import torch
 
 
6
 
7
- # Load your fine-tuned model
8
- model_path = "./t5-summarizer" # Path inside Docker container
9
- model = T5ForConditionalGeneration.from_pretrained(model_path)
10
  tokenizer = T5Tokenizer.from_pretrained(model_path, legacy=False)
11
-
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
- model = model.to(device)
14
 
 
15
  app = FastAPI()
16
-
17
  class TextInput(BaseModel):
18
  text: str
19
 
20
  @app.post("/summarize/")
21
  def summarize_text(input: TextInput):
22
- input_text = "summarize: " + input.text.strip().replace("\n", " ")
23
-
24
- inputs = tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True).to(device)
25
- summary_ids = model.generate(inputs, max_length=150, min_length=30, length_penalty=2.0, num_beams=4, early_stopping=True)
26
-
27
- summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
28
- return {"summary": summary}
29
-
30
- # Gradio UI setup
31
- gr.Interface(
32
- fn=lambda text: summarize_text(TextInput(text=text))["summary"], # Ensure it returns summary
33
- inputs=gr.Textbox(label="Input Text"),
34
- outputs=gr.Textbox(label="Summarized Text"),
35
- flagging=False # Disable flagging to prevent permission issues
36
- ).launch()
37
-
 
 
 
 
 
 
 
 
 
3
  from pydantic import BaseModel
4
  from transformers import T5ForConditionalGeneration, T5Tokenizer
5
  import torch
6
+ import threading
7
+ import uvicorn
8
 
9
+ # Load model & tokenizer
10
+ model_path = "./t5-summarizer"
 
11
  tokenizer = T5Tokenizer.from_pretrained(model_path, legacy=False)
12
+ model = T5ForConditionalGeneration.from_pretrained(model_path)
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+ model.to(device)
15
 
16
+ # FastAPI setup
17
  app = FastAPI()
 
18
  class TextInput(BaseModel):
19
  text: str
20
 
21
  @app.post("/summarize/")
22
  def summarize_text(input: TextInput):
23
+ inputs = tokenizer("summarize: " + input.text.replace("\n"," "),
24
+ return_tensors="pt", max_length=512, truncation=True)
25
+ outputs = model.generate(inputs.input_ids.to(device),
26
+ max_length=150, min_length=30,
27
+ length_penalty=2.0, num_beams=4, early_stopping=True)
28
+ return {"summary": tokenizer.decode(outputs[0], skip_special_tokens=True)}
29
+
30
+ def run_fastapi():
31
+ uvicorn.run(app, host="0.0.0.0", port=8000)
32
+
33
+ # Gradio UI
34
+ iface = gr.Interface(
35
+ fn=lambda text: summarize_text(TextInput(text=text))["summary"],
36
+ inputs=gr.Textbox(lines=10, placeholder="Paste text here..."),
37
+ outputs=gr.Textbox(label="Summary"),
38
+ title="Text Summarizer",
39
+ description="Fine-tuned T5 summarizer",
40
+ flagging_mode="never", # Disable flagging
41
+ examples=[["Your example text here..."]] # Pre-load examples
42
+ )
43
+
44
+ # Start FastAPI in background, then launch Gradio
45
+ threading.Thread(target=run_fastapi, daemon=True).start()
46
+ iface.launch(server_name="0.0.0.0", server_port=7860)