Ephemeral182 committed on
Commit 9819163 · verified · 1 Parent(s): bbb89b4

Update app.py

Files changed (1)
  1. app.py +212 -96
app.py CHANGED
@@ -7,22 +7,27 @@ import spaces
  import torch
  from diffusers import FluxPipeline, FluxTransformer2DModel
  from transformers import AutoModelForCausalLM, AutoTokenizer
- from huggingface_hub import login
+ from huggingface_hub import login, whoami
 
  # ------------------------------------------------------------------
  # 1. Authentication and Global Configuration
  # ------------------------------------------------------------------
  # Authenticate with HF token
  hf_token = os.getenv("HF_TOKEN")
+ auth_status = "🔴 Not Authenticated"
+
  if hf_token:
      try:
          login(token=hf_token, add_to_git_credential=True)
-         logging.info("Successfully authenticated with Hugging Face")
+         user_info = whoami(hf_token)
+         auth_status = f"✅ Authenticated as {user_info['name']}"
+         logging.info(f"Successfully authenticated with Hugging Face as {user_info['name']}")
      except Exception as e:
          logging.error(f"HF authentication failed: {e}")
-         raise Exception("Authentication failed. Please check your HF_TOKEN.")
+         auth_status = f"🔴 Authentication Error: {str(e)}"
  else:
      logging.warning("No HF_TOKEN found in environment variables")
+     auth_status = "🔴 No HF_TOKEN found"
 
  DEFAULT_PIPELINE_PATH = "black-forest-labs/FLUX.1-dev"
  DEFAULT_QWEN_MODEL_PATH = "Qwen/Qwen3-8B"
@@ -199,11 +204,11 @@ Elaborate on each core requirement to create a rich description.
          out = out.split("</think>")[-1].strip()
          return out or text
      except Exception as e:
-         logging.warning(f"Recap failed: {e}. Using original prompt.")
+         logging.error(f"Prompt recap failed: {e}")
          return text
 
  # ------------------------------------------------------------------
- # 5. ZeroGPU Inference Function
+ # 5. Main Generation Function (GPU)
  # ------------------------------------------------------------------
  @spaces.GPU(duration=300)
  def generate_image_interface(
@@ -211,135 +216,246 @@ def generate_image_interface(
      num_inference_steps, guidance_scale, seed_input,
      progress=gr.Progress(track_tqdm=True),
  ):
-     if not original_prompt or not original_prompt.strip():
-         raise gr.Error("Prompt cannot be empty!")
-
-     if width > MAX_IMAGE_SIZE or height > MAX_IMAGE_SIZE:
-         raise gr.Error(f"Maximum resolution limit is {MAX_IMAGE_SIZE}×{MAX_IMAGE_SIZE}")
-
-     try:
-         actual_seed = int(seed_input) if seed_input and seed_input > 0 else random.randint(1, 2**32 - 1)
-
+     """Generate image using FLUX pipeline"""
+     try:
+         # If no token available, return error message
+         if not hf_token:
+             return None, "❌ Error: HF_TOKEN not found. Please configure authentication.", ""
+
+         # Set device and dtype
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         torch_dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
+
+         # Initialize FLUX pipeline
          progress(0.1, desc="Loading FLUX pipeline...")
-
-         # Load FLUX pipeline with explicit token
-         load_kwargs = {
-             "torch_dtype": torch.bfloat16
-         }
-         if hf_token:
-             load_kwargs["token"] = hf_token
-
-         pipeline = FluxPipeline.from_pretrained(DEFAULT_PIPELINE_PATH, **load_kwargs)
-
-         progress(0.2, desc="Loading custom transformer...")
-
-         # Load custom transformer if available
+         pipeline = FluxPipeline.from_pretrained(
+             DEFAULT_PIPELINE_PATH,
+             torch_dtype=torch_dtype,
+             device_map="balanced" if device.type == "cuda" else None,
+             token=hf_token
+         )
+
+         # Load custom transformer weights if available
          custom_weights_local = "local_weights/PosterCraft-v1_RL"
          if os.path.exists(custom_weights_local):
+             progress(0.3, desc="Loading custom transformer weights...")
              try:
-                 transformer_kwargs = {"torch_dtype": torch.bfloat16}
-                 if hf_token:
-                     transformer_kwargs["token"] = hf_token
-
-                 transformer = FluxTransformer2DModel.from_pretrained(
-                     custom_weights_local, **transformer_kwargs
+                 custom_transformer = FluxTransformer2DModel.from_pretrained(
+                     custom_weights_local,
+                     torch_dtype=torch_dtype,
+                     device_map="balanced" if device.type == "cuda" else None,
+                     token=hf_token
                  )
-                 pipeline.transformer = transformer
-                 logging.info("Custom Transformer loaded successfully")
+                 pipeline.transformer = custom_transformer
+                 logging.info("Custom transformer weights loaded successfully")
              except Exception as e:
-                 logging.warning(f"Custom weights loading failed: {e}, using default weights")
-
-         # Move pipeline to GPU
-         pipeline = pipeline.to("cuda")
+                 logging.warning(f"Failed to load custom transformer weights: {e}")
 
+         # Process prompt
          final_prompt = original_prompt
-
          if enable_recap:
-             progress(0.4, desc="Loading Qwen model for prompt enhancement...")
-
+             progress(0.5, desc="Processing prompt with Qwen...")
              qwen_local = "local_weights/Qwen3-8B"
             if os.path.exists(qwen_local):
                  try:
                      tokenizer, model = create_qwen_agent(qwen_local)
                      final_prompt = recap_prompt(tokenizer, model, original_prompt)
-                     progress(0.6, desc="Prompt enhanced, starting generation...")
-
+                     logging.info(f"Enhanced prompt: {final_prompt}")
                      # Clean up Qwen model to free memory
                      del tokenizer, model
                      torch.cuda.empty_cache()
                  except Exception as e:
-                     logging.warning(f"Qwen model failed: {e}, using original prompt")
+                     logging.warning(f"Qwen processing failed: {e}")
                      final_prompt = original_prompt
              else:
-                 logging.warning("Qwen model not found, using original prompt")
-                 final_prompt = original_prompt
+                 # Fallback to online Qwen model
+                 try:
+                     tokenizer, model = create_qwen_agent(DEFAULT_QWEN_MODEL_PATH)
+                     final_prompt = recap_prompt(tokenizer, model, original_prompt)
+                     del tokenizer, model
+                     torch.cuda.empty_cache()
+                 except Exception as e:
+                     logging.warning(f"Online Qwen failed: {e}")
+                     final_prompt = original_prompt
-
-         progress(0.7, desc="Generating image...")
+
+         # Generate seed
+         if seed_input == -1:
+             seed = random.randint(0, MAX_SEED)
+         else:
+             seed = int(seed_input)
+
+         generator = torch.Generator(device=device).manual_seed(seed)
+
          # Generate image
-         generator = torch.Generator(device="cuda").manual_seed(actual_seed)
-
-         with torch.inference_mode():
-             image = pipeline(
+         progress(0.7, desc="Generating image...")
+         with torch.no_grad():
+             result = pipeline(
                  prompt=final_prompt,
+                 height=height,
+                 width=width,
+                 num_inference_steps=num_inference_steps,
+                 guidance_scale=guidance_scale,
                  generator=generator,
-                 num_inference_steps=int(num_inference_steps),
-                 guidance_scale=float(guidance_scale),
-                 width=int(width),
-                 height=int(height)
-             ).images[0]
+             )
 
-         progress(1.0, desc="Generation complete!")
-
-         status_log = f"Seed: {actual_seed} | Generation complete."
-         return image, final_prompt, status_log
+         image = result.images[0]
+
+         # Clean up
+         del pipeline
+         torch.cuda.empty_cache()
+
+         progress(1.0, desc="Complete!")
+         return image, f"✅ Generation complete! Seed: {seed}", final_prompt
+
      except Exception as e:
          logging.error(f"Generation failed: {e}")
-         raise gr.Error(f"An error occurred: {str(e)}")
+         return None, f"❌ Generation failed: {str(e)}", ""
 
  # ------------------------------------------------------------------
  # 6. Gradio Interface
  # ------------------------------------------------------------------
- with gr.Blocks(theme=gr.themes.Soft(), title="PosterCraft") as demo:
-     gr.Markdown("# PosterCraft-v1.0")
-     gr.Markdown(f"Base Pipeline: **{DEFAULT_PIPELINE_PATH}**")
-
-     # Show authentication status
-     auth_status = "🟢 Authenticated" if hf_token else "🔴 Not Authenticated"
-     gr.Markdown(f"Authentication Status: {auth_status}")
-
-     gr.Markdown("⚠️ **First use requires model download, please wait about 10-15 minutes**")
-
-     with gr.Row():
-         with gr.Column(scale=1):
-             gr.Markdown("### 1. Configuration")
-             prompt_input = gr.Textbox(label="Prompt", lines=3, placeholder="Enter your creative prompt...")
-             enable_recap_checkbox = gr.Checkbox(label="Enable Prompt Recap", value=True, info="Uses Qwen3-8B for prompt enhancement")
-
-             with gr.Row():
-                 width_input = gr.Slider(label="Width", minimum=256, maximum=2048, value=832, step=64)
-                 height_input = gr.Slider(label="Height", minimum=256, maximum=2048, value=1216, step=64)
-             gr.Markdown("Tip: Recommended size is 832x1216 for best results.")
-
-             num_inference_steps_input = gr.Slider(label="Inference Steps", minimum=1, maximum=100, value=28, step=1)
-             guidance_scale_input = gr.Slider(label="Guidance Scale (CFG)", minimum=0.0, maximum=20.0, value=3.5, step=0.1)
-             seed_number_input = gr.Number(label="Seed", value=None, minimum=-1, step=1, info="Leave blank or set to -1 for a random seed.")
-             generate_button = gr.Button("Generate Image", variant="primary")
-
-         with gr.Column(scale=1):
-             gr.Markdown("### 2. Results")
-             image_output = gr.Image(label="Generated Image", type="pil", show_download_button=True, height=512)
-             recapped_prompt_output = gr.Textbox(label="Final Prompt Used", lines=5, interactive=False)
-             status_output = gr.Textbox(label="Status Log", lines=4, interactive=False)
-
-     inputs_list = [
-         prompt_input, enable_recap_checkbox, height_input, width_input,
-         num_inference_steps_input, guidance_scale_input, seed_number_input
-     ]
-     outputs_list = [image_output, recapped_prompt_output, status_output]
-
-     generate_button.click(fn=generate_image_interface, inputs=inputs_list, outputs=outputs_list)
+ def create_interface():
+     """Create Gradio interface"""
+
+     with gr.Blocks(
+         title="PosterCraft-v1.0",
+         theme=gr.themes.Soft(),
+         css="""
+         .main-container { max-width: 1200px; margin: 0 auto; }
+         .status-box { padding: 10px; border-radius: 5px; margin: 10px 0; }
+         .auth-success { background-color: #d4edda; border: 1px solid #c3e6cb; color: #155724; }
+         .auth-error { background-color: #f8d7da; border: 1px solid #f5c6cb; color: #721c24; }
+         """
+     ) as demo:
+
+         gr.HTML("""
+         <div class="main-container">
+             <h1 style="text-align: center; margin-bottom: 20px;">🎨 PosterCraft-v1.0</h1>
+             <p style="text-align: center; color: #666; margin-bottom: 30px;">
+                 Professional poster generation with FLUX.1-dev and custom fine-tuned weights
+             </p>
+         </div>
+         """)
+
+         with gr.Row():
+             gr.Markdown(f"**Base Pipeline:** `{DEFAULT_PIPELINE_PATH}`")
+             gr.Markdown(f"**Authentication Status:** {auth_status}")
+
+         gr.HTML("""
+         <div class="status-box">
+             <p><strong>⚠️ First use requires model download, please wait about 10-15 minutes</strong></p>
+         </div>
+         """)
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 original_prompt = gr.Textbox(
+                     label="Poster Prompt",
+                     placeholder="Enter your poster description...",
+                     lines=3,
+                     value="A vintage travel poster for Paris, featuring the Eiffel Tower at sunset with warm golden lighting"
+                 )
+
+                 enable_recap = gr.Checkbox(
+                     label="Enable Prompt Enhancement (Qwen3-8B)",
+                     value=True,
+                     info="Use AI to enhance and expand your prompt"
+                 )
+
+                 with gr.Row():
+                     height = gr.Slider(
+                         label="Height",
+                         minimum=256,
+                         maximum=MAX_IMAGE_SIZE,
+                         value=1024,
+                         step=32
+                     )
+                     width = gr.Slider(
+                         label="Width",
+                         minimum=256,
+                         maximum=MAX_IMAGE_SIZE,
+                         value=768,
+                         step=32
+                     )
+
+                 with gr.Row():
+                     num_inference_steps = gr.Slider(
+                         label="Inference Steps",
+                         minimum=1,
+                         maximum=50,
+                         value=20,
+                         step=1
+                     )
+                     guidance_scale = gr.Slider(
+                         label="Guidance Scale",
+                         minimum=1.0,
+                         maximum=15.0,
+                         value=3.5,
+                         step=0.1
+                     )
+
+                 seed_input = gr.Number(
+                     label="Seed (-1 for random)",
+                     value=-1,
+                     precision=0
+                 )
+
+                 generate_btn = gr.Button(
+                     "🎨 Generate Poster",
+                     variant="primary",
+                     size="lg"
+                 )
+
+             with gr.Column(scale=1):
+                 output_image = gr.Image(
+                     label="Generated Poster",
+                     type="pil",
+                     height=600
+                 )
+
+                 status_output = gr.Textbox(
+                     label="Generation Status",
+                     interactive=False,
+                     lines=2
+                 )
+
+                 enhanced_prompt = gr.Textbox(
+                     label="Enhanced Prompt",
+                     interactive=False,
+                     lines=5,
+                     info="The final prompt used for generation"
+                 )
+
+         # Event handlers
+         generate_btn.click(
+             fn=generate_image_interface,
+             inputs=[
+                 original_prompt, enable_recap, height, width,
+                 num_inference_steps, guidance_scale, seed_input
+             ],
+             outputs=[output_image, status_output, enhanced_prompt]
+         )
+
+         # Examples
+         gr.Examples(
+             examples=[
+                 ["A retro sci-fi movie poster with neon colors and flying cars"],
+                 ["An elegant art deco poster for a luxury hotel"],
+                 ["A minimalist concert poster with bold typography"],
+                 ["A vintage advertisement for organic coffee"],
+             ],
+             inputs=[original_prompt]
+         )
+
+     return demo
+
+ # ------------------------------------------------------------------
+ # 7. Launch Application
+ # ------------------------------------------------------------------
  if __name__ == "__main__":
-     demo.launch()
+     demo = create_interface()
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         show_api=False
+     )
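
For reference, a minimal local smoke test of the updated app might look like the sketch below. It is an assumption-laden example, not part of this commit: it assumes the new app.py is importable from the working directory, that the Space's dependencies (gradio, torch, diffusers, transformers, huggingface_hub, spaces) are installed, and that the token has access to black-forest-labs/FLUX.1-dev. The launch arguments simply mirror the __main__ block above.

    # Hypothetical local run of the updated app.py (not part of this commit).
    import os

    # The app reads HF_TOKEN at import time, so set it before importing.
    os.environ.setdefault("HF_TOKEN", "hf_xxx")  # placeholder; replace with a real token

    from app import create_interface

    # Build the Gradio Blocks UI and serve it with the same arguments as the commit's __main__ block.
    demo = create_interface()
    demo.launch(server_name="0.0.0.0", server_port=7860, show_api=False)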