nguyenlam0306
commited on
Commit
·
7742b37
1
Parent(s):
b37c6e8
Load ung dung
Browse files- app.py +91 -0
- requirements.txt +1 -0
app.py
CHANGED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
|
3 |
+
from diffusers import StableDiffusionPipeline
|
4 |
+
import torch
|
5 |
+
import io
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
# Load the summarization model (lacos03/bart-base-finetuned-xsum)
|
9 |
+
summarizer = pipeline("summarization", model="lacos03/bart-base-finetuned-xsum")
|
10 |
+
|
11 |
+
# Load the prompt generation model (microsoft/Promptist)
|
12 |
+
promptist_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
|
13 |
+
promptist_tokenizer = AutoTokenizer.from_pretrained("microsoft/Promptist")
|
14 |
+
|
15 |
+
# Load the base Stable Diffusion model
|
16 |
+
model_id = "runwayml/stable-diffusion-v1-5"
|
17 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
18 |
+
|
19 |
+
# Initialize the Stable Diffusion pipeline
|
20 |
+
image_generator = StableDiffusionPipeline.from_pretrained(
|
21 |
+
model_id,
|
22 |
+
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
|
23 |
+
use_auth_token=False
|
24 |
+
).to(device)
|
25 |
+
|
26 |
+
# Load LoRA weights
|
27 |
+
lora_weights = "lacos03/std-1.5-lora-midjourney-1.0"
|
28 |
+
image_generator.load_lora_weights(lora_weights)
|
29 |
+
|
30 |
+
def summarize_and_generate_image(article_text):
|
31 |
+
# Step 1: Summarize the article and extract a title
|
32 |
+
summary = summarizer(article_text, max_length=100, min_length=30, do_sample=False)[0]["summary_text"]
|
33 |
+
title = summary.split(".")[0] # Use the first sentence as the title
|
34 |
+
|
35 |
+
# Step 2: Generate a prompt from the title using Promptist
|
36 |
+
inputs = promptist_tokenizer(title, return_tensors="pt").to(device)
|
37 |
+
prompt_output = promptist_model.generate(**inputs, max_length=50, num_return_sequences=1)
|
38 |
+
generated_prompt = promptist_tokenizer.decode(prompt_output[0], skip_special_tokens=True)
|
39 |
+
|
40 |
+
# Step 3: Generate an image from the prompt
|
41 |
+
image = image_generator(
|
42 |
+
generated_prompt,
|
43 |
+
num_inference_steps=50,
|
44 |
+
guidance_scale=7.5
|
45 |
+
).images[0]
|
46 |
+
|
47 |
+
# Step 4: Save the image to a BytesIO object for download
|
48 |
+
img_byte_arr = io.BytesIO()
|
49 |
+
image.save(img_byte_arr, format="PNG")
|
50 |
+
img_byte_arr.seek(0)
|
51 |
+
|
52 |
+
return title, generated_prompt, image, img_byte_arr
|
53 |
+
|
54 |
+
# Define the Gradio interface
|
55 |
+
def create_app():
|
56 |
+
with gr.Blocks() as demo:
|
57 |
+
gr.Markdown("# Article Summarizer and Image Generator")
|
58 |
+
gr.Markdown("Enter an article, get a summarized title, and generate an illustrative image.")
|
59 |
+
|
60 |
+
# Input for article text
|
61 |
+
article_input = gr.Textbox(label="Input Article", lines=10, placeholder="Paste your article here...")
|
62 |
+
|
63 |
+
# Output components
|
64 |
+
title_output = gr.Textbox(label="Generated Title")
|
65 |
+
prompt_output = gr.Textbox(label="Generated Prompt")
|
66 |
+
image_output = gr.Image(label="Generated Image")
|
67 |
+
|
68 |
+
# Buttons
|
69 |
+
submit_button = gr.Button("Generate")
|
70 |
+
copy_button = gr.Button("Copy Title")
|
71 |
+
download_button = gr.File(label="Download Image")
|
72 |
+
|
73 |
+
# Define actions
|
74 |
+
submit_button.click(
|
75 |
+
fn=summarize_and_generate_image,
|
76 |
+
inputs=article_input,
|
77 |
+
outputs=[title_output, prompt_output, image_output, download_button]
|
78 |
+
)
|
79 |
+
copy_button.click(
|
80 |
+
fn=lambda x: x,
|
81 |
+
inputs=title_output,
|
82 |
+
outputs=gr.State(value=title_output), # For copying to clipboard
|
83 |
+
_js="function(x) { navigator.clipboard.writeText(x); return x; }"
|
84 |
+
)
|
85 |
+
|
86 |
+
return demo
|
87 |
+
|
88 |
+
# Launch the app
|
89 |
+
if __name__ == "__main__":
|
90 |
+
app = create_app()
|
91 |
+
app.launch()
|
requirements.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
gradio==3.50.2 transformers==4.38.2 torch==2.0.1 diffusers==0.21.4 pillow==9.5.0 safetensors==0.4.0
|