Ephemeral182 committed
Commit 1aca16b · verified · Parent(s): 4b3c0f3

Update app.py

Files changed (1):
  1. app.py  +179 -211

app.py CHANGED
@@ -45,14 +45,7 @@ logging.basicConfig(
 # 2. Model Download Function (CPU only)
 # ------------------------------------------------------------------
 def download_model_weights(target_dir, repo_id, subdir=None):
-    """
-    Download model weights to specified directory (CPU operation)
-
-    Args:
-        target_dir (str): Local target directory
-        repo_id (str): HuggingFace repository ID
-        subdir (str): Subdirectory path in the repository (optional)
-    """
+    """Download model weights to specified directory (CPU operation)"""
     from huggingface_hub import snapshot_download
     import shutil
@@ -71,7 +64,6 @@ def download_model_weights(target_dir, repo_id, subdir=None):
         "local_dir_use_symlinks": False,
     }
 
-    # Add token if available
     if hf_token:
         download_kwargs["token"] = hf_token
 
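Note: the body of download_model_weights between these two hunks is unchanged and not shown. For orientation, a minimal sketch of how the visible pieces (snapshot_download, download_kwargs, shutil, and the optional subdir argument) could plausibly fit together; the allow_patterns filter and the flattening step are assumptions, not code from this commit:

    def download_model_weights(target_dir, repo_id, subdir=None):
        """Download model weights to specified directory (CPU operation)"""
        import os
        import shutil
        from huggingface_hub import snapshot_download

        download_kwargs = {
            "repo_id": repo_id,
            "local_dir": target_dir,
            "local_dir_use_symlinks": False,
        }
        if hf_token:  # module-level token, as in app.py
            download_kwargs["token"] = hf_token
        if subdir:  # assumption: fetch only one subdirectory of the repo
            download_kwargs["allow_patterns"] = [f"{subdir}/*"]

        snapshot_download(**download_kwargs)

        if subdir:  # assumption: flatten the subdirectory into target_dir
            src = os.path.join(target_dir, subdir)
            if os.path.isdir(src):
                for name in os.listdir(src):
                    shutil.move(os.path.join(src, name), os.path.join(target_dir, name))
                shutil.rmtree(src, ignore_errors=True)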
@@ -125,26 +117,23 @@ def ensure_models_downloaded():
 ensure_models_downloaded()
 
 # ------------------------------------------------------------------
-# 4. Qwen Prompt Rewriting Agent
+# 4. Qwen Recap Agent (based on your original logic)
 # ------------------------------------------------------------------
-def create_qwen_agent(model_path):
-    """Create Qwen agent inside GPU context"""
-    load_kwargs = {
-        "torch_dtype": torch.bfloat16,
-        "device_map": "auto"
-    }
-
-    # Add token if available
-    if hf_token:
-        load_kwargs["token"] = hf_token
-
-    tokenizer = AutoTokenizer.from_pretrained(model_path, **load_kwargs)
-    model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
-    return tokenizer, model
-
-def recap_prompt(tokenizer, model, text):
-    """Recap prompt using Qwen model"""
-    prompt_template = """You are an expert poster prompt designer. Your task is to rewrite a user's short poster prompt into a detailed and vivid long-format prompt. Follow these steps carefully:
+class QwenRecapAgent:
+    def __init__(self, model_path, max_retries=3, retry_delay=2, device_map="auto"):
+        self.max_retries = max_retries
+        self.retry_delay = retry_delay
+        self.device = device_map
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path, token=hf_token)
+        model_kwargs = {"torch_dtype": torch.bfloat16, "device_map": device_map if device_map == "auto" else None}
+        if hf_token:
+            model_kwargs["token"] = hf_token
+        self.model = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)
+        if device_map != "auto":
+            self.model.to(device_map)
+
+        self.prompt_template = """You are an expert poster prompt designer. Your task is to rewrite a user's short poster prompt into a detailed and vivid long-format prompt. Follow these steps carefully:
 
 **Step 1: Analyze the Core Requirements**
 Identify the key elements in the user's prompt. Do not miss any details.
@@ -182,33 +171,120 @@ Elaborate on each core requirement to create a rich description.
 ---
 **User Prompt:**
 {brief_description}"""
 
-    try:
-        messages = [
-            {"role": "user", "content": prompt_template.format(brief_description=text)}
-        ]
-        chat = tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
-        )
-        inputs = tokenizer([chat], return_tensors="pt").to(model.device)
-
-        with torch.no_grad():
-            ids = model.generate(
-                **inputs, max_new_tokens=1024, temperature=0.6, do_sample=True
-            )
-        out = tokenizer.decode(
-            ids[0][len(inputs.input_ids[0]):], skip_special_tokens=True
-        ).strip()
-
-        if "</think>" in out:
-            out = out.split("</think>")[-1].strip()
-        return out or text
-    except Exception as e:
-        logging.error(f"Prompt recap failed: {e}")
-        return text
+    def recap_prompt(self, original_prompt):
+        full_prompt = self.prompt_template.format(brief_description=original_prompt)
+        messages = [{"role": "user", "content": full_prompt}]
+        try:
+            text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)
+            model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
+
+            with torch.no_grad():
+                generated_ids = self.model.generate(**model_inputs, max_new_tokens=1024, temperature=0.6, do_sample=True)
+
+            output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
+            full_response = self.tokenizer.decode(output_ids, skip_special_tokens=True)
+            final_answer = self._extract_final_answer(full_response)
+
+            if final_answer:
+                return final_answer.strip()
+
+            logging.info("Qwen returned an empty answer. Using original prompt.")
+            return original_prompt
+        except Exception as e:
+            logging.error(f"Qwen recap failed: {e}. Using original prompt.")
+            return original_prompt
+
+    def _extract_final_answer(self, full_response):
+        if "</think>" in full_response:
+            return full_response.split("</think>")[-1].strip()
+        if "<think>" not in full_response:
+            return full_response.strip()
+        return None
+
+# ------------------------------------------------------------------
+# 5. Poster Generator Class (based on your original logic, with caching added)
+# ------------------------------------------------------------------
+class PosterGenerator:
+    def __init__(self, pipeline_path, qwen_model_path, custom_weights_path, device):
+        self.device = device
+        self.pipeline_path = pipeline_path
+        self.qwen_model_path = qwen_model_path
+        self.custom_weights_path = custom_weights_path
+
+        # Cached instances
+        self.qwen_agent = None
+        self.pipeline = None
+
+    def _load_qwen_agent(self):
+        if self.qwen_agent is None:
+            if not self.qwen_model_path:
+                return None
+
+            # Check local path first
+            qwen_local = "local_weights/Qwen3-8B"
+            model_path = qwen_local if os.path.exists(qwen_local) else self.qwen_model_path
+
+            logging.info(f"Loading Qwen agent from {model_path}")
+            self.qwen_agent = QwenRecapAgent(model_path=model_path, device_map=str(self.device))
+        return self.qwen_agent
+
+    def _load_flux_pipeline(self):
+        if self.pipeline is None:
+            logging.info("Loading FLUX pipeline...")
+            self.pipeline = FluxPipeline.from_pretrained(
+                self.pipeline_path,
+                torch_dtype=torch.bfloat16,
+                token=hf_token
+            )
+
+            # Load custom weights
+            custom_weights_local = "local_weights/PosterCraft-v1_RL"
+            if os.path.exists(custom_weights_local):
+                logging.info(f"Loading custom Transformer from directory: {custom_weights_local}")
+                transformer = FluxTransformer2DModel.from_pretrained(
+                    custom_weights_local,
+                    torch_dtype=torch.bfloat16,
+                    token=hf_token
+                )
+                self.pipeline.transformer = transformer
+            elif self.custom_weights_path and os.path.exists(self.custom_weights_path):
+                logging.info(f"Loading custom Transformer from directory: {self.custom_weights_path}")
+                transformer = FluxTransformer2DModel.from_pretrained(
+                    self.custom_weights_path,
+                    torch_dtype=torch.bfloat16,
+                    token=hf_token
+                )
+                self.pipeline.transformer = transformer
+
+            self.pipeline.to(self.device)
+        return self.pipeline
+
+    def generate(self, prompt, enable_recap, **kwargs):
+        final_prompt = prompt
+        if enable_recap:
+            qwen_agent = self._load_qwen_agent()
+            if not qwen_agent:
+                raise gr.Error("Recap is enabled, but the recap model is not available. Check model path.")
+            final_prompt = qwen_agent.recap_prompt(prompt)
+
+        pipeline = self._load_flux_pipeline()
+        generator = torch.Generator(device=self.device).manual_seed(kwargs['seed'])
+
+        with torch.inference_mode():
+            image = pipeline(
+                prompt=final_prompt,
+                generator=generator,
+                num_inference_steps=kwargs['num_inference_steps'],
+                guidance_scale=kwargs['guidance_scale'],
+                width=kwargs['width'],
+                height=kwargs['height']
+            ).images[0]
+
+        return image, final_prompt
 
 # ------------------------------------------------------------------
-# 5. Main Generation Function (GPU)
+# 6. Main Generation Function (GPU) - keeps your original logic
 # ------------------------------------------------------------------
 @spaces.GPU(duration=300)
 def generate_image_interface(
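Note: a minimal usage sketch of the two new classes, assuming the module-level names (DEFAULT_QWEN_MODEL_PATH, hf_token) defined earlier in app.py; the prompt string is illustrative only:

    agent = QwenRecapAgent(model_path=DEFAULT_QWEN_MODEL_PATH, device_map="cuda")
    long_prompt = agent.recap_prompt("A jazz festival poster")
    # recap_prompt never raises: on any failure (or an empty/think-only
    # response) it falls back to returning the original prompt unchanged.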
@@ -217,102 +293,47 @@ def generate_image_interface(
     progress=gr.Progress(track_tqdm=True),
 ):
     """Generate image using FLUX pipeline"""
+    if not original_prompt or not original_prompt.strip():
+        return None, "❌ Prompt cannot be empty!", ""
+
     try:
-        # If no token available, return error message
         if not hf_token:
             return None, "❌ Error: HF_TOKEN not found. Please configure authentication.", ""
 
-        # Set device and dtype
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        torch_dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
-
-        # Initialize FLUX pipeline
-        progress(0.1, desc="Loading FLUX pipeline...")
-        pipeline = FluxPipeline.from_pretrained(
-            DEFAULT_PIPELINE_PATH,
-            torch_dtype=torch_dtype,
-            device_map="balanced" if device.type == "cuda" else None,
-            token=hf_token
-        )
-
-        # Load custom transformer weights if available
-        custom_weights_local = "local_weights/PosterCraft-v1_RL"
-        if os.path.exists(custom_weights_local):
-            progress(0.3, desc="Loading custom transformer weights...")
-            try:
-                custom_transformer = FluxTransformer2DModel.from_pretrained(
-                    custom_weights_local,
-                    torch_dtype=torch_dtype,
-                    device_map="balanced" if device.type == "cuda" else None,
-                    token=hf_token
-                )
-                pipeline.transformer = custom_transformer
-                logging.info("Custom transformer weights loaded successfully")
-            except Exception as e:
-                logging.warning(f"Failed to load custom transformer weights: {e}")
-
-        # Process prompt
-        final_prompt = original_prompt
-        if enable_recap:
-            progress(0.5, desc="Processing prompt with Qwen...")
-            qwen_local = "local_weights/Qwen3-8B"
-            if os.path.exists(qwen_local):
-                try:
-                    tokenizer, model = create_qwen_agent(qwen_local)
-                    final_prompt = recap_prompt(tokenizer, model, original_prompt)
-                    logging.info(f"Enhanced prompt: {final_prompt}")
-                    # Clean up Qwen model to free memory
-                    del tokenizer, model
-                    torch.cuda.empty_cache()
-                except Exception as e:
-                    logging.warning(f"Qwen processing failed: {e}")
-                    final_prompt = original_prompt
-            else:
-                # Fallback to online Qwen model
-                try:
-                    tokenizer, model = create_qwen_agent(DEFAULT_QWEN_MODEL_PATH)
-                    final_prompt = recap_prompt(tokenizer, model, original_prompt)
-                    del tokenizer, model
-                    torch.cuda.empty_cache()
-                except Exception as e:
-                    logging.warning(f"Online Qwen failed: {e}")
-                    final_prompt = original_prompt
-
-        # Generate seed
-        if seed_input == -1:
-            seed = random.randint(0, MAX_SEED)
-        else:
-            seed = int(seed_input)
-
-        generator = torch.Generator(device=device).manual_seed(seed)
-
-        # Generate image
-        progress(0.7, desc="Generating image...")
-        with torch.no_grad():
-            result = pipeline(
-                prompt=final_prompt,
-                height=height,
-                width=width,
-                num_inference_steps=num_inference_steps,
-                guidance_scale=guidance_scale,
-                generator=generator,
-            )
-
-        image = result.images[0]
-
-        # Clean up
-        del pipeline
-        torch.cuda.empty_cache()
-
-        progress(1.0, desc="Complete!")
-        return image, f"✅ Generation complete! Seed: {seed}", final_prompt
+        # Global generator instance (cached on the function object)
+        if not hasattr(generate_image_interface, 'generator'):
+            generate_image_interface.generator = PosterGenerator(
+                pipeline_path=DEFAULT_PIPELINE_PATH,
+                qwen_model_path=DEFAULT_QWEN_MODEL_PATH,
+                custom_weights_path=DEFAULT_CUSTOM_WEIGHTS_PATH,
+                device=device
+            )
+
+        actual_seed = int(seed_input) if seed_input is not None and int(seed_input) != -1 else random.randint(1, 2**32 - 1)
+
+        progress(0.1, desc="Starting generation...")
+
+        image, final_prompt = generate_image_interface.generator.generate(
+            prompt=original_prompt,
+            enable_recap=enable_recap,
+            height=int(height),
+            width=int(width),
+            num_inference_steps=int(num_inference_steps),
+            guidance_scale=float(guidance_scale),
+            seed=actual_seed
+        )
+
+        status_log = f"✅ Generation complete! Seed: {actual_seed}"
+        return image, status_log, final_prompt
 
     except Exception as e:
         logging.error(f"Generation failed: {e}")
         return None, f"❌ Generation failed: {str(e)}", ""
 
 # ------------------------------------------------------------------
-# 6. Gradio Interface
+# 7. Gradio Interface (keeps your original style)
 # ------------------------------------------------------------------
 def create_interface():
     """Create Gradio interface"""
@@ -323,8 +344,6 @@ def create_interface():
     css="""
     .main-container { max-width: 1200px; margin: 0 auto; }
     .status-box { padding: 10px; border-radius: 5px; margin: 10px 0; }
-    .auth-success { background-color: #d4edda; border: 1px solid #c3e6cb; color: #155724; }
-    .auth-error { background-color: #f8d7da; border: 1px solid #f5c6cb; color: #721c24; }
     """
 ) as demo:
 
@@ -343,98 +362,47 @@ def create_interface():
     gr.HTML("""
     <div class="status-box">
-        <p><strong>⚠️ First use requires model download, please wait about 10-15 minutes</strong></p>
+        <p><strong>⚠️ First generation requires model loading (5-10 minutes). Subsequent generations are much faster!</strong></p>
     </div>
     """)
 
     with gr.Row():
         with gr.Column(scale=1):
-            original_prompt = gr.Textbox(
-                label="Poster Prompt",
-                placeholder="Enter your poster description...",
-                lines=3,
-                value="A vintage travel poster for Paris, featuring the Eiffel Tower at sunset with warm golden lighting"
-            )
-
-            enable_recap = gr.Checkbox(
-                label="Enable Prompt Enhancement (Qwen3-8B)",
-                value=True,
-                info="Use AI to enhance and expand your prompt"
-            )
+            gr.Markdown("### 1. Configuration")
+            prompt_input = gr.Textbox(
+                label="Poster Prompt",
+                lines=3,
+                placeholder="Enter your poster description...",
+                value="A vintage travel poster for Paris, featuring the Eiffel Tower at sunset with warm golden lighting"
+            )
+            enable_recap_checkbox = gr.Checkbox(
+                label="Enable Prompt Enhancement (Qwen3-8B)",
+                value=True,
+                info="Use AI to enhance and expand your prompt"
+            )
 
             with gr.Row():
-                height = gr.Slider(
-                    label="Height",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    value=1024,
-                    step=32
-                )
-                width = gr.Slider(
-                    label="Width",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    value=768,
-                    step=32
-                )
-
-            with gr.Row():
-                num_inference_steps = gr.Slider(
-                    label="Inference Steps",
-                    minimum=1,
-                    maximum=50,
-                    value=20,
-                    step=1
-                )
-                guidance_scale = gr.Slider(
-                    label="Guidance Scale",
-                    minimum=1.0,
-                    maximum=15.0,
-                    value=3.5,
-                    step=0.1
-                )
-
-            seed_input = gr.Number(
-                label="Seed (-1 for random)",
-                value=-1,
-                precision=0
-            )
-
-            generate_btn = gr.Button(
-                "🎨 Generate Poster",
-                variant="primary",
-                size="lg"
-            )
+                width_input = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, value=768, step=32)
+                height_input = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, value=1024, step=32)
+
+            num_inference_steps_input = gr.Slider(label="Inference Steps", minimum=1, maximum=100, value=20, step=1)
+            guidance_scale_input = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=20.0, value=3.5, step=0.1)
+            seed_number_input = gr.Number(label="Seed (-1 for random)", value=-1, minimum=-1, step=1)
+            generate_button = gr.Button("🎨 Generate Poster", variant="primary", size="lg")
 
         with gr.Column(scale=1):
-            output_image = gr.Image(
-                label="Generated Poster",
-                type="pil",
-                height=600
-            )
-
-            status_output = gr.Textbox(
-                label="Generation Status",
-                interactive=False,
-                lines=2
-            )
-
-            enhanced_prompt = gr.Textbox(
-                label="Enhanced Prompt",
-                interactive=False,
-                lines=5,
-                info="The final prompt used for generation"
-            )
+            gr.Markdown("### 2. Results")
+            image_output = gr.Image(label="Generated Poster", type="pil", height=600)
+            status_output = gr.Textbox(label="Generation Status", lines=2, interactive=False)
+            recapped_prompt_output = gr.Textbox(label="Enhanced Prompt", lines=5, interactive=False, info="The final prompt used for generation")
 
-        # Event handlers
-        generate_btn.click(
-            fn=generate_image_interface,
-            inputs=[
-                original_prompt, enable_recap, height, width,
-                num_inference_steps, guidance_scale, seed_input
-            ],
-            outputs=[output_image, status_output, enhanced_prompt]
-        )
+        inputs_list = [
+            prompt_input, enable_recap_checkbox, height_input, width_input,
+            num_inference_steps_input, guidance_scale_input, seed_number_input
+        ]
+        outputs_list = [image_output, status_output, recapped_prompt_output]
+
+        generate_button.click(fn=generate_image_interface, inputs=inputs_list, outputs=outputs_list)
 
     # Examples
     gr.Examples(
@@ -444,13 +412,13 @@ def create_interface():
             ["A minimalist concert poster with bold typography"],
             ["A vintage advertisement for organic coffee"],
         ],
-        inputs=[original_prompt]
+        inputs=[prompt_input]
     )
 
     return demo
 
 # ------------------------------------------------------------------
-# 7. Launch Application
+# 8. Launch Application
 # ------------------------------------------------------------------
 if __name__ == "__main__":
     demo = create_interface()
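Note: the positional order of inputs_list must match the parameters of generate_image_interface. A sketch of the expected signature, reconstructed from the old event-handler list and the trailing lines of the def shown in the hunks above (the full def line is not shown in this diff):

    @spaces.GPU(duration=300)
    def generate_image_interface(
        original_prompt, enable_recap, height, width,
        num_inference_steps, guidance_scale, seed_input,
        progress=gr.Progress(track_tqdm=True),
    ):
        ...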
@@ -458,4 +426,4 @@ if __name__ == "__main__":
         server_name="0.0.0.0",
         server_port=7860,
         show_api=False
-    )
+    )