Update main.py: serve a Gradio demo UI in a background thread alongside the existing FastAPI summarization endpoint
main.py CHANGED
@@ -2,6 +2,9 @@ from fastapi import FastAPI
 from pydantic import BaseModel
 from transformers import T5ForConditionalGeneration, T5Tokenizer
 import torch
+import gradio as gr
+import threading
+import uvicorn
 
 # Load your fine-tuned model
 model_path = "./t5-summarizer"  # Path inside Docker container
@@ -11,6 +14,7 @@ tokenizer = T5Tokenizer.from_pretrained(model_path, legacy=False)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model = model.to(device)
 
+# FastAPI app
 app = FastAPI()
 
 class TextInput(BaseModel):
@@ -19,9 +23,41 @@ class TextInput(BaseModel):
 @app.post("/summarize/")
 def summarize_text(input: TextInput):
     input_text = "summarize: " + input.text.strip().replace("\n", " ")
-
     inputs = tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True).to(device)
     summary_ids = model.generate(inputs, max_length=150, min_length=30, length_penalty=2.0, num_beams=4, early_stopping=True)
-
     summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
     return {"summary": summary}
+
+# Summarization function for Gradio
+def summarize_ui(text):
+    input_text = "summarize: " + text.strip().replace("\n", " ")
+    inputs = tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True).to(device)
+    summary_ids = model.generate(inputs, max_length=150, min_length=30, length_penalty=2.0, num_beams=4, early_stopping=True)
+    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+    return summary
+
+# Gradio interface with example texts
+gradio_app = gr.Interface(
+    fn=summarize_ui,
+    inputs=gr.Textbox(lines=10, placeholder="Paste your text here..."),
+    outputs=gr.Textbox(label="Summary"),
+    title="Text Summarizer",
+    description="Paste your long text and get a concise summary using a fine-tuned T5 model.",
+    examples=[
+        ["Scientists have recently discovered a new species of frog in the Amazon rainforest. This frog is notable for its bright blue legs and unique mating call, which sounds like a series of short whistles. Researchers believe that the discovery of this species could shed new light on the ecological diversity of the region."],
+        ["The global economy is expected to grow at a slower pace this year, according to new forecasts released today. Economists point to ongoing geopolitical tensions, supply chain disruptions, and inflationary pressures as key factors contributing to the reduced growth outlook."],
+        ["In a thrilling final match, the underdog team scored a last-minute goal to secure their first championship title. Fans erupted into celebration as the team lifted the trophy, marking a historic moment in the club's history."]
+    ]
+)
+
+
+# Function to run Gradio in a thread
+def run_gradio():
+    gradio_app.launch(server_name="0.0.0.0", server_port=7860, share=False)
+
+# Run Gradio in a separate thread
+threading.Thread(target=run_gradio).start()
+
+# Run FastAPI with Uvicorn if needed (for local dev)
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", port=8000)
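After this commit, one process serves two ports: the FastAPI JSON endpoint on 8000 and the Gradio UI on 7860. A minimal client sketch for the JSON endpoint follows; the `requests` library, the `localhost` host, and the published ports are assumptions for local testing, not part of the commit itself:

import requests

# Call the /summarize/ endpoint defined above. Assumes the container is
# running locally with both ports published, e.g.:
#   docker run -p 8000:8000 -p 7860:7860 <image>
resp = requests.post(
    "http://localhost:8000/summarize/",
    json={"text": "Scientists have recently discovered a new species of frog in the Amazon rainforest."},
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["summary"])

# The Gradio UI launched by the background thread is served separately
# at http://localhost:7860

One design note: `threading.Thread(target=run_gradio).start()` runs at import time, so the Gradio server also starts whenever Uvicorn imports the module, not only under `__main__`. If that is not desired, passing `daemon=True` to the thread would let the process exit without waiting on the Gradio server.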