alexnasa committed on
Commit
c71a46c
·
verified ·
1 Parent(s): 1924d17

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -76
app.py CHANGED
@@ -13,9 +13,6 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
 
16
- # STEP 1: Very first thing in the file: force spawn
17
- import multiprocessing as mp
18
- mp.set_start_method("spawn", force=True)
19
 
20
  import spaces
21
 
@@ -119,7 +116,7 @@ detector = ObjectDetector(device)
119
  config = get_train_config(config_path)
120
  model.config = config
121
 
122
- run_mode = "mod_only" # orig_only, mod_only, both
123
  store_attn_map = False
124
  run_name = time.strftime("%m%d-%H%M")
125
 
@@ -166,7 +163,6 @@ def crop_face_img(image):
166
  if isinstance(image, str):
167
  image = Image.open(image).convert("RGB")
168
 
169
- # image = resize_keep_aspect_ratio(image, 1024)
170
  image = pad_to_square(image).resize((2048, 2048))
171
 
172
  face_bbox = face_model.detect(
@@ -194,7 +190,7 @@ def vlm_img_caption(image):
194
 
195
 
196
  def generate_random_string(length=4):
197
- letters = string.ascii_letters # 包含大小写字母的字符串
198
  result_str = ''.join(random.choice(letters) for i in range(length))
199
  return result_str
200
 
@@ -209,31 +205,31 @@ def resize_keep_aspect_ratio(pil_image, target_size=1024):
209
 
210
  @spaces.GPU()
211
  def generate_image(
212
- prompt,
213
  cond_size, target_height, target_width,
214
  seed,
215
  vae_skip_iter, control_weight_lambda,
216
- double_attention, # 新增参数
217
- single_attention, # 新增参数
218
  ip_scale,
219
  latent_sblora_scale_str, vae_lora_scale,
220
- indexs, # 新增参数
221
- *images_captions_faces, # Combine all unpacked arguments into one tuple
 
222
  ):
223
  torch.cuda.empty_cache()
224
  num_images = 1
225
 
226
- # Determine the number of images, captions, and faces based on the indexs length
227
  images = list(images_captions_faces[:num_inputs])
228
  captions = list(images_captions_faces[num_inputs:2 * num_inputs])
229
  idips_checkboxes = list(images_captions_faces[2 * num_inputs:3 * num_inputs])
230
- images = [images[i] for i in indexs]
231
- captions = [captions[i] for i in indexs]
232
- idips_checkboxes = [idips_checkboxes[i] for i in indexs]
233
 
234
  print(f"Length of images: {len(images)}")
235
  print(f"Length of captions: {len(captions)}")
236
- print(f"Indexs: {indexs}")
237
 
238
  print(f"Control weight lambda: {control_weight_lambda}")
239
  if control_weight_lambda != "no":
@@ -243,7 +239,6 @@ def generate_image(
243
  if ':' in part:
244
  left, right = part.split(':')
245
  values = right.split('/')
246
- # 保存整体值
247
  global_value = values[0]
248
  id_value = values[1]
249
  ip_value = values[2]
@@ -265,7 +260,7 @@ def generate_image(
265
  use_words = []
266
  cur_run_time = time.strftime("%m%d-%H%M%S")
267
  tmp_dir_root = f"tmp/gradio_demo/{run_name}"
268
- temp_dir = f"{tmp_dir_root}/{cur_run_time}_{generate_random_string(4)}"
269
  os.makedirs(temp_dir, exist_ok=True)
270
  print(f"Temporary directory created: {temp_dir}")
271
  for i, (image_path, caption) in enumerate(zip(images, captions)):
@@ -279,7 +274,7 @@ def generate_image(
279
  prompt = prompt.replace(f"ENT{i+1}", caption)
280
 
281
  image = resize_keep_aspect_ratio(Image.open(image_path), 768)
282
- save_path = f"{temp_dir}/tmp_resized_input_{i}.png"
283
  image.save(save_path)
284
 
285
  input_image_path = save_path
@@ -317,7 +312,7 @@ def generate_image(
317
  ),
318
  ]
319
 
320
- json_dump(test_sample, f"{temp_dir}/test_sample.json", 'utf-8')
321
  assert single_attention == True
322
  target_size = int(round((target_width * target_height) ** 0.5) // 16 * 16)
323
  print(test_sample)
@@ -338,10 +333,10 @@ def generate_image(
338
  target_width=target_width,
339
  seed=seed,
340
  store_attn_map=store_attn_map,
341
- vae_skip_iter=vae_skip_iter, # 使用新的参数
342
- control_weight_lambda=control_weight_lambda, # 传递新的参数
343
- double_attention=double_attention, # 新增参数
344
- single_attention=single_attention, # 新增参数
345
  ip_scale=ip_scale,
346
  use_latent_sblora_control=use_latent_sblora_control,
347
  latent_sblora_scale=latent_sblora_scale_str,
@@ -353,12 +348,12 @@ def generate_image(
353
  num_rows = int(math.ceil(num_images / num_cols))
354
  image = image_grid(image, num_rows, num_cols)
355
 
356
- save_path = f"{temp_dir}/tmp_result.png"
357
- image.save(save_path)
358
 
359
  return image
360
 
361
- def create_image_input(index, open=True, indexs_state=None):
362
  accordion_state = gr.State(open)
363
  with gr.Column():
364
  with gr.Accordion(f"Input Image {index + 1}", open=accordion_state.value) as accordion:
@@ -366,18 +361,18 @@ def create_image_input(index, open=True, indexs_state=None):
366
  caption = gr.Textbox(label=f"Caption {index + 1}", value="")
367
  id_ip_checkbox = gr.Checkbox(value=False, label=f"ID or not {index + 1}", visible=True)
368
  with gr.Row():
369
- vlm_btn = gr.Button("Auto Caption")
370
  det_btn = gr.Button("Det & Seg")
371
  face_btn = gr.Button("Crop Face")
372
  accordion.expand(
373
- inputs=[indexs_state],
374
  fn = lambda x: update_inputs(True, index, x),
375
- outputs=[indexs_state, accordion_state],
376
  )
377
  accordion.collapse(
378
- inputs=[indexs_state],
379
  fn = lambda x: update_inputs(False, index, x),
380
- outputs=[indexs_state, accordion_state],
381
  )
382
  return image, caption, face_btn, det_btn, vlm_btn, accordion_state, accordion, id_ip_checkbox
383
 
@@ -402,44 +397,95 @@ def merge_instances(orig_img, indices, ins_bboxes, ins_images):
402
 
403
  def change_accordion(at: bool, index: int, state: list):
404
  print(at, state)
405
- indexs = state
406
  if at:
407
- if index not in indexs:
408
- indexs.append(index)
409
  else:
410
- if index in indexs:
411
- indexs.remove(index)
412
 
413
- # 确保 indexs 是有序的
414
- indexs.sort()
415
- print(indexs)
416
- return gr.Accordion(open=at), indexs
417
 
418
  def update_inputs(is_open, index, state: list):
419
- indexs = state
420
  if is_open:
421
- if index not in indexs:
422
- indexs.append(index)
423
  else:
424
- if index in indexs:
425
- indexs.remove(index)
426
 
427
- # 确保 indexs 是有序的
428
- indexs.sort()
429
- print(indexs)
430
- return indexs, is_open
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
431
 
432
  if __name__ == "__main__":
433
 
434
  with gr.Blocks() as demo:
435
-
436
- indexs_state = gr.State([0, 1]) # 添加状态来存储 indexs
 
437
 
438
  gr.Markdown("### XVerse Demo")
439
  with gr.Row():
440
  with gr.Column():
 
 
 
 
 
 
 
 
 
 
 
 
 
441
  prompt = gr.Textbox(label="Prompt", value="")
442
- with gr.Accordion("Open for More!", open=False):
 
443
 
444
  with gr.Row():
445
  target_height = gr.Slider(512, 1024, step=128, value=768, label="Generated Height", info="")
@@ -520,24 +566,13 @@ if __name__ == "__main__":
520
  double_attention = gr.Checkbox(value=False, label="Double Attention", visible=False)
521
  single_attention = gr.Checkbox(value=True, label="Single Attention", visible=False)
522
 
523
- clear_btn = gr.Button("清空输入图像")
524
- with gr.Row():
525
- for i in range(num_inputs):
526
- image, caption, face_btn, det_btn, vlm_btn, accordion_state, accordion, id_ip_checkbox = create_image_input(i, open=i<2, indexs_state=indexs_state)
527
- images.append(image)
528
- idip_checkboxes.append(id_ip_checkbox)
529
- captions.append(caption)
530
- face_btns.append(face_btn)
531
- det_btns.append(det_btn)
532
- vlm_btns.append(vlm_btn)
533
- accordion_states.append(accordion_state)
534
-
535
- accordions.append(accordion)
536
 
537
  with gr.Column():
538
- output = gr.Image(label="生成的图像")
539
  seed = gr.Number(value=42, label="Seed", info="")
540
- gen_btn = gr.Button("生成图像")
541
 
542
  gen_btn.click(
543
  generate_image,
@@ -546,25 +581,23 @@ if __name__ == "__main__":
546
  vae_skip_iter, weight_id_ip_str,
547
  double_attention, single_attention,
548
  db_latent_lora_scale_str, sb_latent_lora_scale_str, vae_lora_scale_str,
549
- indexs_state, # 传递 indexs 状态
 
550
  *images,
551
  *captions,
552
  *idip_checkboxes,
553
  ],
554
  outputs=output
555
  )
556
-
557
-
558
- # 修改清空函数的输出参数
559
  clear_btn.click(clear_images, outputs=images)
560
-
561
- # 循环绑定 Det & Seg 和 Auto Caption 按钮的点击事件
562
  for i in range(num_inputs):
563
  face_btns[i].click(crop_face_img, inputs=[images[i]], outputs=[images[i]])
564
  det_btns[i].click(det_seg_img, inputs=[images[i], captions[i]], outputs=[images[i]])
565
  vlm_btns[i].click(vlm_img_caption, inputs=[images[i]], outputs=[captions[i]])
566
- accordion_states[i].change(fn=lambda x, state, index=i: change_accordion(x, index, state), inputs=[accordion_states[i], indexs_state], outputs=[accordions[i], indexs_state])
567
-
 
568
 
569
  demo.queue()
570
  demo.launch()
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
 
 
 
 
16
 
17
  import spaces
18
 
 
116
  config = get_train_config(config_path)
117
  model.config = config
118
 
119
+ run_mode = "mod_only"
120
  store_attn_map = False
121
  run_name = time.strftime("%m%d-%H%M")
122
 
 
163
  if isinstance(image, str):
164
  image = Image.open(image).convert("RGB")
165
 
 
166
  image = pad_to_square(image).resize((2048, 2048))
167
 
168
  face_bbox = face_model.detect(
 
190
 
191
 
192
  def generate_random_string(length=4):
193
+ letters = string.ascii_letters
194
  result_str = ''.join(random.choice(letters) for i in range(length))
195
  return result_str
196
 
 
205
 
206
  @spaces.GPU()
207
  def generate_image(
208
+ prompt,
209
  cond_size, target_height, target_width,
210
  seed,
211
  vae_skip_iter, control_weight_lambda,
212
+ double_attention,
213
+ single_attention,
214
  ip_scale,
215
  latent_sblora_scale_str, vae_lora_scale,
216
+ indices,
217
+ session_id,
218
+ *images_captions_faces,
219
  ):
220
  torch.cuda.empty_cache()
221
  num_images = 1
222
 
 
223
  images = list(images_captions_faces[:num_inputs])
224
  captions = list(images_captions_faces[num_inputs:2 * num_inputs])
225
  idips_checkboxes = list(images_captions_faces[2 * num_inputs:3 * num_inputs])
226
+ images = [images[i] for i in indices]
227
+ captions = [captions[i] for i in indices]
228
+ idips_checkboxes = [idips_checkboxes[i] for i in indices]
229
 
230
  print(f"Length of images: {len(images)}")
231
  print(f"Length of captions: {len(captions)}")
232
+ print(f"indices: {indices}")
233
 
234
  print(f"Control weight lambda: {control_weight_lambda}")
235
  if control_weight_lambda != "no":
 
239
  if ':' in part:
240
  left, right = part.split(':')
241
  values = right.split('/')
 
242
  global_value = values[0]
243
  id_value = values[1]
244
  ip_value = values[2]
 
260
  use_words = []
261
  cur_run_time = time.strftime("%m%d-%H%M%S")
262
  tmp_dir_root = f"tmp/gradio_demo/{run_name}"
263
+ temp_dir = f"{tmp_dir_root}/{session_id}/{cur_run_time}_{generate_random_string(4)}"
264
  os.makedirs(temp_dir, exist_ok=True)
265
  print(f"Temporary directory created: {temp_dir}")
266
  for i, (image_path, caption) in enumerate(zip(images, captions)):
 
274
  prompt = prompt.replace(f"ENT{i+1}", caption)
275
 
276
  image = resize_keep_aspect_ratio(Image.open(image_path), 768)
277
+ save_path = f"{temp_dir}/{session_id}/tmp_resized_input_{i}.png"
278
  image.save(save_path)
279
 
280
  input_image_path = save_path
 
312
  ),
313
  ]
314
 
315
+ json_dump(test_sample, f"{temp_dir}/{session_id}/test_sample.json", 'utf-8')
316
  assert single_attention == True
317
  target_size = int(round((target_width * target_height) ** 0.5) // 16 * 16)
318
  print(test_sample)
 
333
  target_width=target_width,
334
  seed=seed,
335
  store_attn_map=store_attn_map,
336
+ vae_skip_iter=vae_skip_iter,
337
+ control_weight_lambda=control_weight_lambda,
338
+ double_attention=double_attention,
339
+ single_attention=single_attention,
340
  ip_scale=ip_scale,
341
  use_latent_sblora_control=use_latent_sblora_control,
342
  latent_sblora_scale=latent_sblora_scale_str,
 
348
  num_rows = int(math.ceil(num_images / num_cols))
349
  image = image_grid(image, num_rows, num_cols)
350
 
351
+ # save_path = f"{temp_dir}/tmp_result.png"
352
+ # image.save(save_path)
353
 
354
  return image
355
 
356
+ def create_image_input(index, open=True, indices_state=None):
357
  accordion_state = gr.State(open)
358
  with gr.Column():
359
  with gr.Accordion(f"Input Image {index + 1}", open=accordion_state.value) as accordion:
 
361
  caption = gr.Textbox(label=f"Caption {index + 1}", value="")
362
  id_ip_checkbox = gr.Checkbox(value=False, label=f"ID or not {index + 1}", visible=True)
363
  with gr.Row():
364
+ vlm_btn = gr.Button("Generate Caption")
365
  det_btn = gr.Button("Det & Seg")
366
  face_btn = gr.Button("Crop Face")
367
  accordion.expand(
368
+ inputs=[indices_state],
369
  fn = lambda x: update_inputs(True, index, x),
370
+ outputs=[indices_state, accordion_state],
371
  )
372
  accordion.collapse(
373
+ inputs=[indices_state],
374
  fn = lambda x: update_inputs(False, index, x),
375
+ outputs=[indices_state, accordion_state],
376
  )
377
  return image, caption, face_btn, det_btn, vlm_btn, accordion_state, accordion, id_ip_checkbox
378
 
 
397
 
398
  def change_accordion(at: bool, index: int, state: list):
399
  print(at, state)
400
+ indices = state
401
  if at:
402
+ if index not in indices:
403
+ indices.append(index)
404
  else:
405
+ if index in indices:
406
+ indices.remove(index)
407
 
408
+ # 确保 indices 是有序的
409
+ indices.sort()
410
+ print(indices)
411
+ return gr.Accordion(open=at), indices
412
 
413
  def update_inputs(is_open, index, state: list):
414
+ indices = state
415
  if is_open:
416
+ if index not in indices:
417
+ indices.append(index)
418
  else:
419
+ if index in indices:
420
+ indices.remove(index)
421
 
422
+ indices.sort()
423
+ print(indices)
424
+ return indices, is_open
425
+
426
def start_session(request: gr.Request):
    """Return the session hash of the client that just loaded the demo.

    The function does not create anything itself; it only exposes the
    ``session_hash`` attribute of the incoming request so it can be stored
    in a state component and later used as a per-session directory name.

    Args:
        request (gr.Request): Request object supplied by Gradio for the
            event that triggered this call.

    Returns:
        The value of ``request.session_hash``.
    """
    return request.session_hash
441
+
442
+
443
+ # Cleanup on unload
444
def cleanup(request: gr.Request):
    """Remove the temporary directories that belong to one user session.

    Deletes ``tmp/gradio_demo/<run_name>/<session_hash>``, which is where
    ``generate_image`` places its per-session working directories. For
    backward compatibility the session sub-directories under the
    ``PIXEL3DMM_PREPROCESSED_DATA`` and ``PIXEL3DMM_TRACKING_OUTPUT``
    environment roots are removed as well, but only when those variables
    are actually set (indexing ``os.environ`` directly raised ``KeyError``
    otherwise). Missing directories are ignored.

    Args:
        request (gr.Request): Request object supplied by Gradio; only its
            ``session_hash`` attribute is read.
    """
    import shutil  # local import so the block does not rely on a file-level one

    sid = request.session_hash
    if not sid:
        return

    # Refuse anything that is not a single path component before handing it
    # to rmtree. NOTE(review): the session hash presumably originates from
    # the client - confirm; either way the check is cheap.
    if sid in (".", "..") or os.path.basename(sid) != sid:
        return

    # Per-session root used by generate_image for its temp directories.
    targets = [os.path.join("tmp", "gradio_demo", run_name, sid)]

    # Optional legacy roots: honoured only when configured.
    for env_var in ("PIXEL3DMM_PREPROCESSED_DATA", "PIXEL3DMM_TRACKING_OUTPUT"):
        root = os.environ.get(env_var)
        if root:
            targets.append(os.path.join(root, sid))

    for path in targets:
        shutil.rmtree(path, ignore_errors=True)
461
+
462
 
463
  if __name__ == "__main__":
464
 
465
  with gr.Blocks() as demo:
466
+ session_state = gr.State()
467
+ demo.load(start_session, outputs=[session_state])
468
+ indices_state = gr.State([0, 1])
469
 
470
  gr.Markdown("### XVerse Demo")
471
  with gr.Row():
472
  with gr.Column():
473
+ with gr.Row():
474
+ for i in range(num_inputs):
475
+ image, caption, face_btn, det_btn, vlm_btn, accordion_state, accordion, id_ip_checkbox = create_image_input(i, open=i<2, indices_state=indices_state)
476
+ images.append(image)
477
+ idip_checkboxes.append(id_ip_checkbox)
478
+ captions.append(caption)
479
+ face_btns.append(face_btn)
480
+ det_btns.append(det_btn)
481
+ vlm_btns.append(vlm_btn)
482
+ accordion_states.append(accordion_state)
483
+
484
+ accordions.append(accordion)
485
+
486
  prompt = gr.Textbox(label="Prompt", value="")
487
+ gen_btn = gr.Button("Generate", variant="primary")
488
+ with gr.Accordion("Advanced Settings", open=False):
489
 
490
  with gr.Row():
491
  target_height = gr.Slider(512, 1024, step=128, value=768, label="Generated Height", info="")
 
566
  double_attention = gr.Checkbox(value=False, label="Double Attention", visible=False)
567
  single_attention = gr.Checkbox(value=True, label="Single Attention", visible=False)
568
 
569
+ clear_btn = gr.Button("Clear Images")
570
+
 
 
 
 
 
 
 
 
 
 
 
571
 
572
  with gr.Column():
573
+ output = gr.Image(label="Result")
574
  seed = gr.Number(value=42, label="Seed", info="")
575
+
576
 
577
  gen_btn.click(
578
  generate_image,
 
581
  vae_skip_iter, weight_id_ip_str,
582
  double_attention, single_attention,
583
  db_latent_lora_scale_str, sb_latent_lora_scale_str, vae_lora_scale_str,
584
+ indices_state,
585
+ session_state,
586
  *images,
587
  *captions,
588
  *idip_checkboxes,
589
  ],
590
  outputs=output
591
  )
 
 
 
592
  clear_btn.click(clear_images, outputs=images)
593
+
 
594
  for i in range(num_inputs):
595
  face_btns[i].click(crop_face_img, inputs=[images[i]], outputs=[images[i]])
596
  det_btns[i].click(det_seg_img, inputs=[images[i], captions[i]], outputs=[images[i]])
597
  vlm_btns[i].click(vlm_img_caption, inputs=[images[i]], outputs=[captions[i]])
598
+ accordion_states[i].change(fn=lambda x, state, index=i: change_accordion(x, index, state), inputs=[accordion_states[i], indices_state], outputs=[accordions[i], indices_state])
599
+
600
+ demo.unload(cleanup)
601
 
602
  demo.queue()
603
  demo.launch()