File size: 6,478 Bytes
7742b37
 
55b9944
 
7742b37
c1253b3
55b9944
c1253b3
 
7742b37
c1253b3
55b9944
 
 
c1253b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c75b94
 
 
 
 
 
c1253b3
 
 
ebd5453
c1253b3
ebd5453
 
 
 
 
 
 
 
 
 
 
 
 
3c75b94
 
 
ebd5453
 
 
 
c1253b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ebd5453
 
 
c1253b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55b9944
c1253b3
 
 
 
 
 
 
 
 
 
7742b37
 
c1253b3
d0f68a4
c1253b3
226ec5f
c1253b3
 
 
 
 
ebd5453
c1253b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ebd5453
c1253b3
226ec5f
c1253b3
 
 
 
226ec5f
c1253b3
55b9944
c1253b3
 
 
 
 
226ec5f
c1253b3
 
 
 
7742b37
 
 
 
c1253b3
7742b37
 
c1253b3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import gradio as gr
import torch
from transformers import pipeline, AutoModelForSeq2SeqLM, BartTokenizer, AutoModelForCausalLM, AutoTokenizer
from diffusers import StableDiffusionPipeline
from PIL import Image
import io
import os
import zipfile
import traceback

# === Thiết lập thiết bị ===
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

# === Load models ===
# BART Summarizer
model_name = "lacos03/bart-base-finetuned-xsum"
tokenizer = BartTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32
).to(device)
summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, device=0 if device=="cuda" else -1)

# Promptist
promptist_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Promptist",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32
).to(device)
promptist_tokenizer = AutoTokenizer.from_pretrained("gpt2")
promptist_tokenizer.pad_token = promptist_tokenizer.eos_token
promptist_tokenizer.padding_side = "left"

# Stable Diffusion + LoRA
sd_model_id = "runwayml/stable-diffusion-v1-5"
image_generator = StableDiffusionPipeline.from_pretrained(
    sd_model_id,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    use_safetensors=True
).to(device)
lora_weights = "lacos03/std-1.5-lora-midjourney-1.0"
image_generator.load_lora_weights(lora_weights)

# Cấu hình beam search cho BART
NUM_BEAMS = 10
NO_REPEAT_NGRAM_SIZE = 3
LENGTH_PENALTY = 1.0
MIN_NEW_TOKENS = 10
MAX_NEW_TOKENS = 62
# === Hàm xử lý ===
def summarize_article(article_text):
    """Tóm tắt bài viết và tạo prompt refinement"""
    # Kiểm tra rỗng
    if not article_text.strip():
        return gr.update(value="❌ <span style='color:red'>Bạn chưa nhập bài viết</span>"), "", ""
    
    # Kiểm tra số từ
    word_count = len(article_text.split())
    if word_count < 20 or word_count > 300:
        return gr.update(value=f"❌ <span style='color:red'>Bài viết phải từ 20–300 từ (hiện tại: {word_count} từ)</span>"), "", ""
    
    # Nếu hợp lệ thì xóa cảnh báo
    error_msg = gr.update(value="")

    summary = summarizer(
        article_text,
        num_beams=NUM_BEAMS,
        no_repeat_ngram_size=NO_REPEAT_NGRAM_SIZE,
        length_penalty=LENGTH_PENALTY,
        min_new_tokens=MIN_NEW_TOKENS,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False
    )[0]["summary_text"]

    title = summary.split(".")[0] + "."
    input_ids = promptist_tokenizer(title.strip() + " Rephrase:", return_tensors="pt").input_ids.to(device)
    eos_id = promptist_tokenizer.eos_token_id
    outputs = promptist_model.generate(
        input_ids,
        do_sample=False,
        max_new_tokens=75,
        num_beams=8,
        num_return_sequences=1,
        eos_token_id=eos_id,
        pad_token_id=eos_id,
        length_penalty=-1.0
    )
    output_texts = promptist_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    prompt = output_texts[0].replace(title + " Rephrase:", "").strip()
    
    return error_msg, title, prompt


def generate_images(prompt, style, num_images=4):
    """Sinh nhiều ảnh"""
    styled_prompt = f"{prompt}, {style.lower()} style"
    results = image_generator(
        styled_prompt,
        num_inference_steps=50,
        guidance_scale=7.5,
        num_images_per_prompt=num_images
    ).images
    return results

def save_selected_images(selected_idx, all_images):
    """Lưu ảnh đã chọn và nén thành ZIP"""
    if not selected_idx:
        return None
    temp_dir = "./temp_selected"
    os.makedirs(temp_dir, exist_ok=True)
    zip_path = os.path.join(temp_dir, "selected_images.zip")
    with zipfile.ZipFile(zip_path, 'w') as zipf:
        for idx in selected_idx:
            img = all_images[int(idx)]
            img_path = os.path.join(temp_dir, f"image_{idx}.png")
            img.save(img_path, format="PNG")
            zipf.write(img_path, f"image_{idx}.png")
    return zip_path

# === UI Gradio ===
def create_app():
    with gr.Blocks() as demo:
        gr.Markdown("## 📰 Article → 🖼️ Multiple Image Generator with Selection")

        # Bước 1: Nhập bài viết và sinh tiêu đề + prompt
        with gr.Row():
            article_input = gr.Textbox(label="📄 Bài viết", lines=10)
            style_dropdown = gr.Dropdown(
                choices=["Art", "Anime", "Watercolor", "Cyberpunk"],
                label="🎨 Phong cách ảnh", value="Art"
            )
        error_box = gr.Markdown(value="", elem_id="error-msg")
        num_images_slider = gr.Slider(1, 8, value=4, step=1, label="🔢 Số lượng ảnh")
        btn_summary = gr.Button("📌 Sinh Tiêu đề & Prompt")

        title_output = gr.Textbox(label="Tiêu đề")
        prompt_output = gr.Textbox(label="Prompt sinh ảnh")

        # Bước 2: Sinh ảnh từ prompt đã refine
        btn_generate_images = gr.Button("🎨 Sinh ảnh từ Prompt")
        gallery = gr.Gallery(label="🖼️ Ảnh minh họa", columns=2, height=600)
        selected_indices = gr.CheckboxGroup(choices=[], label="Chọn ảnh để tải về")

        # Bước 3: Tải ảnh đã chọn
        btn_download = gr.Button("📥 Tải ảnh đã chọn")
        download_file = gr.File(label="File ZIP tải về")

        # Logic
        btn_summary.click(
            fn=summarize_article,
            inputs=[article_input],
            outputs=[error_box, title_output, prompt_output]
        )

        def update_gallery(prompt, style, num_images):
            images = generate_images(prompt, style, num_images)
            choices = [str(i) for i in range(len(images))]
            return images, gr.update(choices=choices, value=[]), images  # images lưu tạm trong state

        image_state = gr.State([])

        btn_generate_images.click(
            fn=update_gallery,
            inputs=[prompt_output, style_dropdown, num_images_slider],
            outputs=[gallery, selected_indices, image_state]
        )

        btn_download.click(
            fn=save_selected_images,
            inputs=[selected_indices, image_state],
            outputs=[download_file]
        )

    return demo

# === Chạy app ===
if __name__ == "__main__":
    app = create_app()
    app.launch(debug=True, share=True)