import gradio as gr
import torch
from transformers import pipeline, AutoModelForSeq2SeqLM, BartTokenizer, AutoModelForCausalLM, AutoTokenizer
from diffusers import StableDiffusionPipeline
from PIL import Image
import io
import os
import zipfile
import traceback
# === Thiết lập thiết bị ===
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")
# === Load models ===
# BART Summarizer
model_name = "lacos03/bart-base-finetuned-xsum"
tokenizer = BartTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForSeq2SeqLM.from_pretrained(
model_name,
torch_dtype=torch.float16 if device == "cuda" else torch.float32
).to(device)
summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, device=0 if device=="cuda" else -1)
# Promptist
promptist_model = AutoModelForCausalLM.from_pretrained(
"microsoft/Promptist",
torch_dtype=torch.float16 if device == "cuda" else torch.float32
).to(device)
promptist_tokenizer = AutoTokenizer.from_pretrained("gpt2")
promptist_tokenizer.pad_token = promptist_tokenizer.eos_token
promptist_tokenizer.padding_side = "left"
# Stable Diffusion + LoRA
sd_model_id = "runwayml/stable-diffusion-v1-5"
image_generator = StableDiffusionPipeline.from_pretrained(
sd_model_id,
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
use_safetensors=True
).to(device)
lora_weights = "lacos03/std-1.5-lora-midjourney-1.0"
image_generator.load_lora_weights(lora_weights)
# Cấu hình beam search cho BART
NUM_BEAMS = 10
NO_REPEAT_NGRAM_SIZE = 3
LENGTH_PENALTY = 1.0
MIN_NEW_TOKENS = 10
MAX_NEW_TOKENS = 62
# === Hàm xử lý ===
def summarize_article(article_text):
"""Tóm tắt bài viết và tạo prompt refinement"""
# Kiểm tra rỗng
if not article_text.strip():
return gr.update(value="❌ Bạn chưa nhập bài viết"), "", ""
# Kiểm tra số từ
word_count = len(article_text.split())
if word_count < 20 or word_count > 300:
return gr.update(value=f"❌ Bài viết phải từ 20–300 từ (hiện tại: {word_count} từ)"), "", ""
# Nếu hợp lệ thì xóa cảnh báo
error_msg = gr.update(value="")
summary = summarizer(
article_text,
num_beams=NUM_BEAMS,
no_repeat_ngram_size=NO_REPEAT_NGRAM_SIZE,
length_penalty=LENGTH_PENALTY,
min_new_tokens=MIN_NEW_TOKENS,
max_new_tokens=MAX_NEW_TOKENS,
do_sample=False
)[0]["summary_text"]
title = summary.split(".")[0] + "."
input_ids = promptist_tokenizer(title.strip() + " Rephrase:", return_tensors="pt").input_ids.to(device)
eos_id = promptist_tokenizer.eos_token_id
outputs = promptist_model.generate(
input_ids,
do_sample=False,
max_new_tokens=75,
num_beams=8,
num_return_sequences=1,
eos_token_id=eos_id,
pad_token_id=eos_id,
length_penalty=-1.0
)
output_texts = promptist_tokenizer.batch_decode(outputs, skip_special_tokens=True)
prompt = output_texts[0].replace(title + " Rephrase:", "").strip()
return error_msg, title, prompt
def generate_images(prompt, style, num_images=4):
"""Sinh nhiều ảnh"""
styled_prompt = f"{prompt}, {style.lower()} style"
results = image_generator(
styled_prompt,
num_inference_steps=50,
guidance_scale=7.5,
num_images_per_prompt=num_images
).images
return results
def save_selected_images(selected_idx, all_images):
"""Lưu ảnh đã chọn và nén thành ZIP"""
if not selected_idx:
return None
temp_dir = "./temp_selected"
os.makedirs(temp_dir, exist_ok=True)
zip_path = os.path.join(temp_dir, "selected_images.zip")
with zipfile.ZipFile(zip_path, 'w') as zipf:
for idx in selected_idx:
img = all_images[int(idx)]
img_path = os.path.join(temp_dir, f"image_{idx}.png")
img.save(img_path, format="PNG")
zipf.write(img_path, f"image_{idx}.png")
return zip_path
# === UI Gradio ===
def create_app():
with gr.Blocks() as demo:
gr.Markdown("## 📰 Article → 🖼️ Multiple Image Generator with Selection")
# Bước 1: Nhập bài viết và sinh tiêu đề + prompt
with gr.Row():
article_input = gr.Textbox(label="📄 Bài viết", lines=10)
style_dropdown = gr.Dropdown(
choices=["Art", "Anime", "Watercolor", "Cyberpunk"],
label="🎨 Phong cách ảnh", value="Art"
)
error_box = gr.Markdown(value="", elem_id="error-msg")
num_images_slider = gr.Slider(1, 8, value=4, step=1, label="🔢 Số lượng ảnh")
btn_summary = gr.Button("📌 Sinh Tiêu đề & Prompt")
title_output = gr.Textbox(label="Tiêu đề")
prompt_output = gr.Textbox(label="Prompt sinh ảnh")
# Bước 2: Sinh ảnh từ prompt đã refine
btn_generate_images = gr.Button("🎨 Sinh ảnh từ Prompt")
gallery = gr.Gallery(label="🖼️ Ảnh minh họa", columns=2, height=600)
selected_indices = gr.CheckboxGroup(choices=[], label="Chọn ảnh để tải về")
# Bước 3: Tải ảnh đã chọn
btn_download = gr.Button("📥 Tải ảnh đã chọn")
download_file = gr.File(label="File ZIP tải về")
# Logic
btn_summary.click(
fn=summarize_article,
inputs=[article_input],
outputs=[error_box, title_output, prompt_output]
)
def update_gallery(prompt, style, num_images):
images = generate_images(prompt, style, num_images)
choices = [str(i) for i in range(len(images))]
return images, gr.update(choices=choices, value=[]), images # images lưu tạm trong state
image_state = gr.State([])
btn_generate_images.click(
fn=update_gallery,
inputs=[prompt_output, style_dropdown, num_images_slider],
outputs=[gallery, selected_indices, image_state]
)
btn_download.click(
fn=save_selected_images,
inputs=[selected_indices, image_state],
outputs=[download_file]
)
return demo
# === Chạy app ===
if __name__ == "__main__":
app = create_app()
app.launch(debug=True, share=True)