Update main.py
main.py CHANGED
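The change moves the gradio import to the top of the file, drops the threading and uvicorn imports, and replaces the standalone summarize_ui helper and its example-laden gr.Interface with a minimal UI wired directly to the summarize_text FastAPI handler: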
@@ -1,10 +1,8 @@
+import gradio as gr
 from fastapi import FastAPI
 from pydantic import BaseModel
 from transformers import T5ForConditionalGeneration, T5Tokenizer
 import torch
-import gradio as gr
-import threading
-import uvicorn
 
 # Load your fine-tuned model
 model_path = "./t5-summarizer"  # Path inside Docker container
@@ -14,7 +12,6 @@ tokenizer = T5Tokenizer.from_pretrained(model_path, legacy=False)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model = model.to(device)
 
-# FastAPI app
 app = FastAPI()
 
 class TextInput(BaseModel):
@@ -23,42 +20,18 @@ class TextInput(BaseModel):
 @app.post("/summarize/")
 def summarize_text(input: TextInput):
     input_text = "summarize: " + input.text.strip().replace("\n", " ")
-    inputs = tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True).to(device)
-    summary_ids = model.generate(inputs, max_length=150, min_length=30, length_penalty=2.0, num_beams=4, early_stopping=True)
-    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
-    return {"summary": summary}
 
-# Summarization function for Gradio
-def summarize_ui(text):
-    input_text = "summarize: " + text.strip().replace("\n", " ")
     inputs = tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True).to(device)
     summary_ids = model.generate(inputs, max_length=150, min_length=30, length_penalty=2.0, num_beams=4, early_stopping=True)
-    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
-    return summary
-
-# Gradio interface with example texts
-gradio_app = gr.Interface(
-    fn=summarize_ui,
-    inputs=gr.Textbox(lines=10, placeholder="Paste your text here..."),
-    outputs=gr.Textbox(label="Summary"),
-    title="Text Summarizer",
-    description="Paste your long text and get a concise summary using a fine-tuned T5 model.",
-    examples=[
-        ["Scientists have recently discovered a new species of frog in the Amazon rainforest. This frog is notable for its bright blue legs and unique mating call, which sounds like a series of short whistles. Researchers believe that the discovery of this species could shed new light on the ecological diversity of the region."],
-        ["The global economy is expected to grow at a slower pace this year, according to new forecasts released today. Economists point to ongoing geopolitical tensions, supply chain disruptions, and inflationary pressures as key factors contributing to the reduced growth outlook."],
-        ["In a thrilling final match, the underdog team scored a last-minute goal to secure their first championship title. Fans erupted into celebration as the team lifted the trophy, marking a historic moment in the club's history."]
-    ],
-    flagging=False  # Disable flagging to prevent permission issue
-)
-
 
-
-
-gradio_app.launch(server_name="0.0.0.0", server_port=7860, share=False)
+    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+    return {"summary": summary}
 
-#
-
+# Gradio UI setup
+gr.Interface(
+    fn=lambda text: summarize_text(TextInput(text=text))["summary"],  # Ensure it returns summary
+    inputs=gr.Textbox(label="Input Text"),
+    outputs=gr.Textbox(label="Summarized Text"),
+    flagging=False  # Disable flagging to prevent permission issues
+).launch()
 
-# Run FastAPI with Uvicorn if needed (for local dev)
-if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=8000)
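A caveat on the new interface block: flagging is not a keyword that gr.Interface accepts, so this version should fail with a TypeError at startup. The switch is spelled allow_flagging="never" in Gradio 3.x/4.x and flagging_mode="never" in Gradio 5. A minimal corrected sketch, assuming a 3.x/4.x runtime (the lambda bridge and labels are unchanged from the commit; the launch arguments come from the removed line of the old file):

import gradio as gr

# Same UI wiring as the commit, with the flagging switch spelled the way
# gr.Interface actually accepts it.
demo = gr.Interface(
    fn=lambda text: summarize_text(TextInput(text=text))["summary"],
    inputs=gr.Textbox(label="Input Text"),
    outputs=gr.Textbox(label="Summarized Text"),
    allow_flagging="never",  # use flagging_mode="never" on Gradio 5
)
demo.launch(server_name="0.0.0.0", server_port=7860)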
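Note also that with the uvicorn entry point deleted, the FastAPI app is still constructed but nothing serves it; .launch() starts only Gradio's own server. If the POST /summarize/ endpoint should remain reachable alongside the UI, gr.mount_gradio_app can attach the interface to the existing app so a single process serves both. A sketch under that assumption, reusing the demo interface from the previous snippet:

import gradio as gr
import uvicorn

# Mount the Gradio UI on the existing FastAPI app. Routes registered earlier,
# such as POST /summarize/, should keep precedence over the UI mounted at "/".
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)  # one server for UI and API

With that setup, the JSON route stays callable on the same port as the UI.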