import gradio as gr import torch from transformers import pipeline, AutoModelForSeq2SeqLM, BartTokenizer, AutoModelForCausalLM, AutoTokenizer from diffusers import StableDiffusionPipeline from PIL import Image import io import os import zipfile import traceback # === Thiết lập thiết bị === device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Device: {device}") # === Load models === # BART Summarizer model_name = "lacos03/bart-base-finetuned-xsum" tokenizer = BartTokenizer.from_pretrained(model_name, use_fast=False) model = AutoModelForSeq2SeqLM.from_pretrained( model_name, torch_dtype=torch.float16 if device == "cuda" else torch.float32 ).to(device) summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, device=0 if device=="cuda" else -1) # Promptist promptist_model = AutoModelForCausalLM.from_pretrained( "microsoft/Promptist", torch_dtype=torch.float16 if device == "cuda" else torch.float32 ).to(device) promptist_tokenizer = AutoTokenizer.from_pretrained("gpt2") promptist_tokenizer.pad_token = promptist_tokenizer.eos_token promptist_tokenizer.padding_side = "left" # Stable Diffusion + LoRA sd_model_id = "runwayml/stable-diffusion-v1-5" image_generator = StableDiffusionPipeline.from_pretrained( sd_model_id, torch_dtype=torch.float16 if device == "cuda" else torch.float32, use_safetensors=True ).to(device) lora_weights = "lacos03/std-1.5-lora-midjourney-1.0" image_generator.load_lora_weights(lora_weights) # Cấu hình beam search cho BART NUM_BEAMS = 10 NO_REPEAT_NGRAM_SIZE = 3 LENGTH_PENALTY = 1.0 MIN_NEW_TOKENS = 10 MAX_NEW_TOKENS = 62 # === Hàm xử lý === def summarize_article(article_text): """Tóm tắt bài viết và tạo prompt refinement""" # Kiểm tra rỗng if not article_text.strip(): return gr.update(value="❌ Bạn chưa nhập bài viết"), "", "" # Kiểm tra số từ word_count = len(article_text.split()) if word_count < 20 or word_count > 300: return gr.update(value=f"❌ Bài viết phải từ 20–300 từ (hiện tại: {word_count} từ)"), "", "" # Nếu hợp lệ thì xóa cảnh báo error_msg = gr.update(value="") summary = summarizer( article_text, num_beams=NUM_BEAMS, no_repeat_ngram_size=NO_REPEAT_NGRAM_SIZE, length_penalty=LENGTH_PENALTY, min_new_tokens=MIN_NEW_TOKENS, max_new_tokens=MAX_NEW_TOKENS, do_sample=False )[0]["summary_text"] title = summary.split(".")[0] + "." input_ids = promptist_tokenizer(title.strip() + " Rephrase:", return_tensors="pt").input_ids.to(device) eos_id = promptist_tokenizer.eos_token_id outputs = promptist_model.generate( input_ids, do_sample=False, max_new_tokens=75, num_beams=8, num_return_sequences=1, eos_token_id=eos_id, pad_token_id=eos_id, length_penalty=-1.0 ) output_texts = promptist_tokenizer.batch_decode(outputs, skip_special_tokens=True) prompt = output_texts[0].replace(title + " Rephrase:", "").strip() return error_msg, title, prompt def generate_images(prompt, style, num_images=4): """Sinh nhiều ảnh""" styled_prompt = f"{prompt}, {style.lower()} style" results = image_generator( styled_prompt, num_inference_steps=50, guidance_scale=7.5, num_images_per_prompt=num_images ).images return results def save_selected_images(selected_idx, all_images): """Lưu ảnh đã chọn và nén thành ZIP""" if not selected_idx: return None temp_dir = "./temp_selected" os.makedirs(temp_dir, exist_ok=True) zip_path = os.path.join(temp_dir, "selected_images.zip") with zipfile.ZipFile(zip_path, 'w') as zipf: for idx in selected_idx: img = all_images[int(idx)] img_path = os.path.join(temp_dir, f"image_{idx}.png") img.save(img_path, format="PNG") zipf.write(img_path, f"image_{idx}.png") return zip_path # === UI Gradio === def create_app(): with gr.Blocks() as demo: gr.Markdown("## 📰 Article → 🖼️ Multiple Image Generator with Selection") # Bước 1: Nhập bài viết và sinh tiêu đề + prompt with gr.Row(): article_input = gr.Textbox(label="📄 Bài viết", lines=10) style_dropdown = gr.Dropdown( choices=["Art", "Anime", "Watercolor", "Cyberpunk"], label="🎨 Phong cách ảnh", value="Art" ) error_box = gr.Markdown(value="", elem_id="error-msg") num_images_slider = gr.Slider(1, 8, value=4, step=1, label="🔢 Số lượng ảnh") btn_summary = gr.Button("📌 Sinh Tiêu đề & Prompt") title_output = gr.Textbox(label="Tiêu đề") prompt_output = gr.Textbox(label="Prompt sinh ảnh") # Bước 2: Sinh ảnh từ prompt đã refine btn_generate_images = gr.Button("🎨 Sinh ảnh từ Prompt") gallery = gr.Gallery(label="🖼️ Ảnh minh họa", columns=2, height=600) selected_indices = gr.CheckboxGroup(choices=[], label="Chọn ảnh để tải về") # Bước 3: Tải ảnh đã chọn btn_download = gr.Button("📥 Tải ảnh đã chọn") download_file = gr.File(label="File ZIP tải về") # Logic btn_summary.click( fn=summarize_article, inputs=[article_input], outputs=[error_box, title_output, prompt_output] ) def update_gallery(prompt, style, num_images): images = generate_images(prompt, style, num_images) choices = [str(i) for i in range(len(images))] return images, gr.update(choices=choices, value=[]), images # images lưu tạm trong state image_state = gr.State([]) btn_generate_images.click( fn=update_gallery, inputs=[prompt_output, style_dropdown, num_images_slider], outputs=[gallery, selected_indices, image_state] ) btn_download.click( fn=save_selected_images, inputs=[selected_indices, image_state], outputs=[download_file] ) return demo # === Chạy app === if __name__ == "__main__": app = create_app() app.launch(debug=True, share=True)