nguyenlam0306 committed
Commit 55b9944 · 1 Parent(s): 226ec5f
Files changed (1):
  1. app.py +106 -39

app.py CHANGED
@@ -1,57 +1,105 @@
 import gradio as gr
-from transformers import pipeline, AutoModelForSeq2SeqLM, BartTokenizer, GenerationConfig, AutoModelForCausalLM, AutoTokenizer
-from diffusers import StableDiffusionPipeline
 import torch
+from transformers import pipeline, AutoModelForSeq2SeqLM, BartTokenizer, AutoModelForCausalLM, AutoTokenizer
+from diffusers import StableDiffusionPipeline
 import io
 from PIL import Image
 import traceback
+import os
+from pathlib import Path

-# === Load models ===
-device = "cuda" if torch.cuda.is_available() else "cpu"

-# Summarizer (BART)
-model_name = "lacos03/bart-base-finetuned-xsum"
-tokenizer = BartTokenizer.from_pretrained(model_name, use_fast=False)
-model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
-generation_config = GenerationConfig.from_pretrained(model_name)
-generation_config.early_stopping = True
-summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, generation_config=generation_config)
-
-# Promptist
-promptist_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
-promptist_tokenizer = AutoTokenizer.from_pretrained("microsoft/Promptist")
-
-# Stable Diffusion + LoRA
-sd_model_id = "runwayml/stable-diffusion-v1-5"
-image_generator = StableDiffusionPipeline.from_pretrained(
-    sd_model_id,
-    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
-    use_safetensors=True
-).to(device)
-lora_weights = "lacos03/std-1.5-lora-midjourney-1.0"
-image_generator.load_lora_weights(lora_weights)
+# === Environment setup ===
+device = "cuda" if torch.cuda.is_available() else "cpu"
+print(f"Device: {device}")
+
+# === Load models with error handling ===
+try:
+    # Summarizer (BART)
+    model_name = "lacos03/bart-base-finetuned-xsum"
+    print(f"Loading BART model from {model_name}...")
+    tokenizer = BartTokenizer.from_pretrained(model_name, use_fast=False)
+    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.float16 if device == "cuda" else torch.float32)
+    model.to(device)
+    summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, device=device)
+    print("✅ BART loaded successfully")
+except Exception as e:
+    print(f"❌ Error loading BART: {e}")
+    summarizer = None
+
+try:
+    # Promptist
+    print("Loading Promptist model...")
+    def load_prompter():
+        prompter_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist", torch_dtype=torch.float16 if device == "cuda" else torch.float32).to(device)
+        tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.padding_side = "left"
+        return prompter_model, tokenizer
+    promptist_model, promptist_tokenizer = load_prompter()
+    print("✅ Promptist loaded successfully")
+except Exception as e:
+    print(f"❌ Error loading Promptist: {e}")
+    promptist_model = None
+    promptist_tokenizer = None
+
+try:
+    # Stable Diffusion + LoRA
+    print("Loading Stable Diffusion model...")
+    sd_model_id = "runwayml/stable-diffusion-v1-5"
+    image_generator = StableDiffusionPipeline.from_pretrained(
+        sd_model_id,
+        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+        use_safetensors=True
+    ).to(device)
+    lora_weights = "lacos03/std-1.5-lora-midjourney-1.0"
+    print(f"Loading LoRA weights from {lora_weights}...")
+    image_generator.load_lora_weights(lora_weights)
+    print("✅ Stable Diffusion with LoRA loaded successfully")
+except Exception as e:
+    print(f"❌ Error loading Stable Diffusion or LoRA: {e}")
+    image_generator = None

 # === Modularization ===
 def summarize(article_text):
+    if not summarizer or not article_text.strip():
+        return "[Empty input or model not loaded]", "[Empty input or model not loaded]"
     try:
-        if not article_text.strip():
-            return "[Empty input]", "[Empty input]"
         summary = summarizer(article_text, max_length=100, min_length=30, do_sample=False)[0]["summary_text"]
-        title = summary.split(".")[0]
+        title = summary.split(".")[0] + "."
         return title, summary
     except Exception as e:
-        return "[Error in summarization]", str(e)
+        return f"[Error in summarization: {e}]", f"[Error in summarization: {e}]"

 def generate_prompt(title):
+    if not promptist_model or not promptist_tokenizer or not title:
+        return "[Error: Promptist not loaded or no title]"
     try:
-        inputs = promptist_tokenizer(title, return_tensors="pt").to(device)
-        output = promptist_model.generate(**inputs, max_length=50, num_return_sequences=1)
-        prompt = promptist_tokenizer.decode(output[0], skip_special_tokens=True)
+        input_ids = promptist_tokenizer(title.strip() + " Rephrase:", return_tensors="pt").input_ids.to(device)
+        eos_id = promptist_tokenizer.eos_token_id
+        outputs = promptist_model.generate(
+            input_ids,
+            do_sample=False,
+            max_new_tokens=75,
+            num_beams=8,
+            num_return_sequences=8,
+            eos_token_id=eos_id,
+            pad_token_id=eos_id,
+            length_penalty=-1.0
+        )
+        output_texts = promptist_tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        prompt = output_texts[0].replace(title.strip() + " Rephrase:", "").strip()
         return prompt
     except Exception as e:
-        return "[Error in prompt generation]"
+        return f"[Error in prompt generation: {e}]"

 def generate_image(prompt, style):
+    if not image_generator or not prompt:
+        blank = Image.new("RGB", (512, 512), (255, 255, 255))
+        img_byte_arr = io.BytesIO()
+        blank.save(img_byte_arr, format="PNG")
+        img_byte_arr.seek(0)
+        return blank, img_byte_arr
     try:
         styled_prompt = f"{prompt}, {style.lower()} style"
         result = image_generator(
@@ -64,16 +112,35 @@ def generate_image(prompt, style):
         img_byte_arr.seek(0)
         return result, img_byte_arr
     except Exception as e:
-        print(traceback.format_exc())
+        print(f"❌ Image generation error: {traceback.format_exc()}")
         blank = Image.new("RGB", (512, 512), (255, 255, 255))
-        return blank, io.BytesIO()
+        img_byte_arr = io.BytesIO()
+        blank.save(img_byte_arr, format="PNG")
+        img_byte_arr.seek(0)
+        return blank, img_byte_arr

 # === Main processing function ===
 def process(article_text, style_choice):
+    print(f"Processing article: {article_text[:50]}...")
     title, summary = summarize(article_text)
+    print(f"Summary title: {title}")
     prompt = generate_prompt(title)
+    print(f"Generated prompt: {prompt}")
     image, img_bytes = generate_image(prompt, style_choice)
-    return title, prompt, image, img_bytes
+    print(f"Image generated: {image.size if image else 'None'}")
+
+    # Write the image to a temp file and return its path
+    temp_dir = "./temp"
+    os.makedirs(temp_dir, exist_ok=True)
+    temp_file = os.path.join(temp_dir, f"generated_image_{id(image)}.png")
+    image.save(temp_file, format="PNG")
+    with open(temp_file, "rb") as f:
+        img_file = f.read()
+    # Return the temp file path for Gradio
+    file_path = temp_file
+
+    print(f"✅ Process completed")
+    return title, prompt, image, file_path

 # === Gradio UI ===
 def create_app():
@@ -83,7 +150,7 @@ def create_app():

     with gr.Row():
         article_input = gr.Textbox(label="📄 Bài viết", lines=10, placeholder="Dán nội dung bài viết ở đây...")
-        style_dropdown = gr.Dropdown(choices=["Realistic", "Anime", "Watercolor", "Cyberpunk"], label="🎨 Phong cách ảnh", value="Realistic")
+        style_dropdown = gr.Dropdown(choices=["Art", "Anime", "Watercolor", "Cyberpunk"], label="🎨 Phong cách ảnh", value="Art")

     with gr.Row():
         submit_button = gr.Button("🚀 Tạo Tiêu đề & Ảnh Minh họa")
@@ -91,7 +158,7 @@ def create_app():
     with gr.Row():
         title_output = gr.Textbox(label="📌 Tiêu đề được tạo")
         prompt_output = gr.Textbox(label="🔧 Prompt sinh ảnh")
-
+
     image_output = gr.Image(label="🖼️ Ảnh minh họa", interactive=True)
     download_button = gr.File(label="📥 Tải ảnh")

@@ -108,4 +175,4 @@ def create_app():
 # === Launch ===
 if __name__ == "__main__":
     app = create_app()
-    app.launch()
+    app.launch(debug=True, share=True)
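Note on the hunks above: old lines 58-63 (new lines 106-111), i.e. the keyword arguments of the image_generator(...) call and the conversion of its output to PNG bytes, fall between the first two hunks and are not part of the diff. For orientation, a typical diffusers pattern for that elided step is sketched below; the sampler settings are illustrative assumptions, not values taken from this commit:

    # Illustrative sketch of the elided step, not the commit's exact code.
    result = image_generator(
        styled_prompt,
        num_inference_steps=30,  # assumed value, not visible in the diff
        guidance_scale=7.5,      # assumed value, not visible in the diff
    ).images[0]                  # StableDiffusionPipeline exposes PIL images via .images
    img_byte_arr = io.BytesIO()
    result.save(img_byte_arr, format="PNG")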
 
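A design note on process(): the temp-file name is derived from id(image), which is unique only while that object is alive; CPython reuses ids, so repeated or concurrent requests can silently overwrite each other's downloads. A minimal collision-free sketch using the standard tempfile module (the helper name save_image_for_download is hypothetical, not part of the commit):

    import tempfile
    from PIL import Image

    def save_image_for_download(image: Image.Image) -> str:
        # delete=False keeps the file on disk after the handle closes,
        # so Gradio's File component can serve it from the returned path.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
            image.save(f, format="PNG")
            return f.name

process() could return save_image_for_download(image) in place of the hand-built ./temp path, which would also make the unused img_file = f.read() round-trip unnecessary.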