|
import io
import os
import tempfile
import traceback
import zipfile

import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
from transformers import pipeline, AutoModelForSeq2SeqLM, BartTokenizer, AutoModelForCausalLM, AutoTokenizer
|
|
|
|
|
# Run every model on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: {device}")
|
|
|
|
|
|
|
# Summarization stack used by summarize_article(): a BART seq2seq checkpoint
# (per its name, bart-base fine-tuned on XSum) wrapped in a transformers pipeline.
model_name = "lacos03/bart-base-finetuned-xsum"
tokenizer = BartTokenizer.from_pretrained(model_name, use_fast=False)

# Half precision only when running on the GPU.
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
model = model.to(device)

# The pipeline takes a CUDA ordinal (0 = first GPU) or -1 for CPU.
summarizer = pipeline(
    "summarization",
    model=model,
    tokenizer=tokenizer,
    device=0 if device == "cuda" else -1,
)
|
|
|
|
|
# Causal LM used by summarize_article() to rephrase the title into an image prompt.
promptist_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Promptist",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
promptist_model = promptist_model.to(device)

# Paired with the stock GPT-2 tokenizer. It has no dedicated pad token, so EOS is
# reused for padding, and padding goes on the left for decoder-only generation.
promptist_tokenizer = AutoTokenizer.from_pretrained("gpt2")
promptist_tokenizer.pad_token = promptist_tokenizer.eos_token
promptist_tokenizer.padding_side = "left"
|
|
|
|
|
# Stable Diffusion v1.5 pipeline used by generate_images().
# The original "runwayml/stable-diffusion-v1-5" repository was removed from the
# Hugging Face Hub; "stable-diffusion-v1-5/stable-diffusion-v1-5" is the
# maintained mirror of the same weights.
sd_model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"

image_generator = StableDiffusionPipeline.from_pretrained(
    sd_model_id,
    # Half precision only when running on the GPU.
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    use_safetensors=True,
).to(device)

# LoRA adapter loaded into the pipeline to steer its output style.
lora_weights = "lacos03/std-1.5-lora-midjourney-1.0"
image_generator.load_lora_weights(lora_weights)
|
|
|
|
|
# Decoding settings for the BART summarizer call in summarize_article().
NUM_BEAMS: int = 10             # beam-search width
NO_REPEAT_NGRAM_SIZE: int = 3   # forbid repeating any 3-gram in the summary
LENGTH_PENALTY: float = 1.0     # passed through to beam search unchanged
MIN_NEW_TOKENS: int = 10        # lower bound on generated summary tokens
MAX_NEW_TOKENS: int = 62        # upper bound on generated summary tokens
|
|
|
def summarize_article(article_text):
    """Summarize the article and generate a refined image prompt from it.

    The article must be non-empty and contain 20-300 whitespace-separated
    words. It is summarized with the BART ``summarizer`` pipeline, and the
    first sentence of the summary becomes the title. The title is then
    rephrased by the Promptist model into an image-generation prompt.

    Args:
        article_text: Article text from the UI textbox. ``None`` is treated
            as empty input.

    Returns:
        A 3-tuple ``(error_update, title, prompt)``. On validation failure,
        ``error_update`` is a ``gr.update`` carrying a red HTML error message,
        and ``title`` and ``prompt`` are empty strings. On success, the error
        message is cleared.
    """
    # Guard against None as well as whitespace-only text so .strip() cannot raise.
    if not article_text or not article_text.strip():
        return gr.update(value="❌ <span style='color:red'>Bạn chưa nhập bài viết</span>"), "", ""

    word_count = len(article_text.split())
    if word_count < 20 or word_count > 300:
        return gr.update(value=f"❌ <span style='color:red'>Bài viết phải từ 20–300 từ (hiện tại: {word_count} từ)</span>"), "", ""

    summary = summarizer(
        article_text,
        num_beams=NUM_BEAMS,
        no_repeat_ngram_size=NO_REPEAT_NGRAM_SIZE,
        length_penalty=LENGTH_PENALTY,
        min_new_tokens=MIN_NEW_TOKENS,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
    )[0]["summary_text"]

    # Title = first sentence of the summary. Strip it once, so the returned
    # title and the text fed to Promptist are the same string.
    title = summary.split(".")[0].strip() + "."
    prompt = _refine_prompt(title)

    # Clear any error left over from a previous invalid submission.
    return gr.update(value=""), title, prompt


def _refine_prompt(title):
    """Rephrase *title* into an image-generation prompt with Promptist."""
    input_ids = promptist_tokenizer(title + " Rephrase:", return_tensors="pt").input_ids.to(device)
    eos_id = promptist_tokenizer.eos_token_id
    outputs = promptist_model.generate(
        input_ids,
        do_sample=False,
        max_new_tokens=75,
        num_beams=8,
        num_return_sequences=1,
        eos_token_id=eos_id,
        pad_token_id=eos_id,
        # A negative length penalty makes beam search favour shorter outputs.
        length_penalty=-1.0,
    )
    # For a decoder-only model, generate() returns the input tokens followed by
    # the continuation. Slice the input off by token count rather than
    # string-replacing the decoded input; the replace silently fails whenever
    # the decoded text differs from the string that was encoded.
    new_tokens = outputs[0, input_ids.shape[-1]:]
    return promptist_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
|
|
|
|
|
def generate_images(prompt, style, num_images=4, *, num_inference_steps=50, guidance_scale=7.5):
    """Generate multiple images for *prompt* rendered in *style*.

    Args:
        prompt: Text prompt, typically the Promptist output. ``None`` is
            treated as an empty prompt.
        style: Style name from the UI dropdown. It is appended as
            ``", <style> style"``; a falsy value (e.g. a cleared dropdown)
            skips the suffix instead of raising.
        num_images: Number of images to generate. It is coerced to ``int`` in
            case the UI slider delivers a float.
        num_inference_steps: Denoising steps passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.

    Returns:
        The ``.images`` list produced by ``image_generator``.
    """
    base_prompt = prompt or ""
    styled_prompt = f"{base_prompt}, {style.lower()} style" if style else base_prompt
    results = image_generator(
        styled_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=int(num_images),
    ).images
    return results
|
|
|
def save_selected_images(selected_idx, all_images):
    """Save the selected images and compress them into a ZIP archive.

    Args:
        selected_idx: Iterable of image indices as selected in the UI
            CheckboxGroup (whose choices are the strings "0", "1", ...).
            Integers are accepted too. Unparseable, out-of-range and duplicate
            indices are ignored.
        all_images: Sequence of image objects exposing
            ``save(fp, format=...)``, such as the PIL images from the generator.

    Returns:
        The path to ``selected_images.zip`` inside a fresh temporary
        directory, or ``None`` when no valid image was selected.
    """
    if not selected_idx or not all_images:
        return None

    # Keep only indices that point at an existing image, de-duplicated in
    # selection order, so a stale selection cannot crash the download and the
    # archive never gets two entries with the same name.
    positions = []
    for idx in selected_idx:
        try:
            position = int(idx)
        except (TypeError, ValueError):
            continue
        if 0 <= position < len(all_images):
            positions.append(position)
    positions = list(dict.fromkeys(positions))
    if not positions:
        return None

    # Use a fresh directory per request so concurrent downloads cannot
    # overwrite each other's archive. The archive keeps its user-facing name.
    out_dir = tempfile.mkdtemp(prefix="selected_images_")
    zip_path = os.path.join(out_dir, "selected_images.zip")
    with zipfile.ZipFile(zip_path, "w") as zipf:
        for position in positions:
            # Encode in memory and write straight into the archive, instead of
            # leaving intermediate PNG files on disk.
            buffer = io.BytesIO()
            all_images[position].save(buffer, format="PNG")
            zipf.writestr(f"image_{position}.png", buffer.getvalue())
    return zip_path
|
|
|
|
|
def create_app():
    """Build and return the Gradio Blocks UI; the app is not launched here.

    Event wiring set up below:
      1. ``btn_summary`` runs ``summarize_article`` on the article text and
         fills the error box, the title and the prompt.
      2. ``btn_generate_images`` calls ``generate_images`` with the prompt,
         style and slider value. It shows the images in the gallery, resets
         the index checkboxes and stores the images in a ``gr.State``.
      3. ``btn_download`` passes the checked indices and the stored images to
         ``save_selected_images`` and exposes the resulting ZIP as a file.

    Returns:
        The assembled ``gr.Blocks`` instance.
    """
    # Components are rendered in the order they are created inside the Blocks
    # context, so statement order here defines the page layout.
    with gr.Blocks() as demo:
        gr.Markdown("## 📰 Article → 🖼️ Multiple Image Generator with Selection")

        # NOTE(review): the original indentation was lost; only the textbox and
        # the style dropdown are assumed to sit inside this Row. Confirm.
        with gr.Row():
            article_input = gr.Textbox(label="📄 Bài viết", lines=10)
            style_dropdown = gr.Dropdown(
                choices=["Art", "Anime", "Watercolor", "Cyberpunk"],
                label="🎨 Phong cách ảnh", value="Art"
            )
        # Receives validation errors (HTML) from summarize_article.
        error_box = gr.Markdown(value="", elem_id="error-msg")
        num_images_slider = gr.Slider(1, 8, value=4, step=1, label="🔢 Số lượng ảnh")
        btn_summary = gr.Button("📌 Sinh Tiêu đề & Prompt")

        title_output = gr.Textbox(label="Tiêu đề")
        prompt_output = gr.Textbox(label="Prompt sinh ảnh")

        btn_generate_images = gr.Button("🎨 Sinh ảnh từ Prompt")
        gallery = gr.Gallery(label="🖼️ Ảnh minh họa", columns=2, height=600)
        # Choices are filled in by update_gallery with the string indices
        # "0".."n-1" of the generated images.
        selected_indices = gr.CheckboxGroup(choices=[], label="Chọn ảnh để tải về")

        btn_download = gr.Button("📥 Tải ảnh đã chọn")
        download_file = gr.File(label="File ZIP tải về")

        btn_summary.click(
            fn=summarize_article,
            inputs=[article_input],
            outputs=[error_box, title_output, prompt_output]
        )

        def update_gallery(prompt, style, num_images):
            """Generate images and refresh the gallery, checkboxes and state.

            Returns a 3-tuple matching ``[gallery, selected_indices,
            image_state]``: the images, a CheckboxGroup update whose choices are
            the string indices of the images with nothing selected, and the same
            image list for the State.
            """
            images = generate_images(prompt, style, num_images)
            choices = [str(i) for i in range(len(images))]
            return images, gr.update(choices=choices, value=[]), images

        # Holds the most recently generated images so the download handler can
        # look them up by the checkbox index values.
        image_state = gr.State([])

        btn_generate_images.click(
            fn=update_gallery,
            inputs=[prompt_output, style_dropdown, num_images_slider],
            outputs=[gallery, selected_indices, image_state]
        )

        btn_download.click(
            fn=save_selected_images,
            inputs=[selected_indices, image_state],
            outputs=[download_file]
        )

    return demo
|
|
|
|
|
if __name__ == "__main__":
    # Build the UI, then serve it. debug=True surfaces errors in the console,
    # and share=True asks Gradio for a public share link; anyone holding that
    # link can use the app.
    app = create_app()
    app.launch(
        debug=True,
        share=True,
    )
|
|