lzyhha committed · Commit b4faa43 · 1 Parent(s): 391f6e1
Files changed (3):
  1. app.py +22 -16
  2. demo_tasks/gradio_tasks_unseen.py +1 -1
  3. visualcloze.py +19 -17
app.py CHANGED
@@ -17,6 +17,10 @@ default_steps = 30
 
 GUIDANCE = """
 
+
+ ## 📧 Contact:
+ Need help or have questions? Contact us at: lizhongyu [AT] mail.nankai.edu.cn.
+
 ## 📋 Quick Start Guide:
 1. Adjust **Number of In-context Examples**, 0 disables in-context learning.
 2. Set **Task Columns**, the number of images involved in a task.
@@ -24,16 +28,18 @@ GUIDANCE = """
 4. Click **Generate** to create the images.
 5. Parameters can be fine-tuned under **Advanced Options**.
 
- <div style='font-size: 20px; color:red;'>🔥 Click the task button in the right bottom to acquire examples of various tasks.</div>
+ ## 🔥 Task Examples:
+ Click the task button in the right bottom to acquire **examples** of various tasks.
+ Make sure all images and prompts are loaded before clicking the generate button.
 
- <div style='font-size: 20px; '> 📧 Need help or have questions? Contact us at: lizhongyu [AT] mail.nankai.edu.cn</div>
 
- <div style='font-size: 20px;'>
- 💻 The runtime on the zero GPU runtime depends on the size of the image grid.
- When generating an image with the resoluation of 1024, the runtime is approximately <span style='font-weight: bold; color:red;'>[45s for a 2x2 grid], [55s for a 2x3 grid], [70s for a 3x3 grid], [90s for a 3x4 grid]</span>.
- When generating three images in a 3x4 grid, i.e., Image to Depth + Normal + Hed, the runtime is approximately <span style='font-weight: bold; color:red;'>110s</span>.
- Deploying locally with an 80G A100 can reduce the runtime by more than half.
- </div>
+ ## 💻 Runtime on the Zero GPU:
+ The runtime on the Zero GPU runtime depends on the size of the image grid.
+ When generating an image with the resoluation of 1024,
+ the runtime is approximately **[45s for a 2x2 grid], [55s for a 2x3 grid], [70s for a 3x3 grid], [90s for a 3x4 grid]**.
+ When generating three images in a 3x4 grid, i.e., Image to Depth + Normal + Hed,
+ the runtime is approximately **110s**.
+ **Deploying locally with an 80G A100 can reduce the runtime by more than half.**
 
 """
 
@@ -90,9 +96,7 @@ def create_demo(model):
 for i in range(max_grid_h):
     # Add row label before each row
     row_texts.append(gr.Markdown(
-         "<div style='font-size: 24px; font-weight: bold;'>" +
-         ("query" if i == default_grid_h - 1 else f"In-context Example {i + 1}") +
-         "</div>",
+         "## Query" if i == default_grid_h - 1 else f"## In-context Example {i + 1}",
         elem_id=f"row_text_{i}",
         visible=i < default_grid_h
     ))
@@ -297,9 +301,7 @@ def create_demo(model):
 gr.update(
     elem_id=f"row_text_{i}",
     visible=i < actual_h,
-     value="<div style='font-size: 24px; font-weight: bold;'>" +
-     ("Query" if i == actual_h - 1 else f"In-context Example {i + 1}") +
-     "</div>",
+     value="## Query" if i == actual_h - 1 else f"## In-context Example {i + 1}",
 )
 )
 
@@ -314,6 +316,9 @@ def create_demo(model):
 images.append([])
 for j in range(model.grid_w):
     images[i].append(inputs[i * max_grid_w + j])
+     if i != model.grid_h - 1:
+         if inputs[i * max_grid_w + j] is None:
+             raise gr.Error('Please upload in-context examples.')
 seed, cfg, steps, upsampling_steps, upsampling_noise, layout_text, task_text, content_text = inputs[-8:]
 
 results = generate(
@@ -489,7 +494,7 @@ def parse_args():
 if __name__ == "__main__":
     args = parse_args()
 
-     snapshot_download(repo_id="VisualCloze/VisualCloze", repo_type="model", local_dir="models")
+     # snapshot_download(repo_id="VisualCloze/VisualCloze", repo_type="model", local_dir="models")
 
     # Initialize model
     model = VisualClozeModel(resolution=args.resolution, model_path=args.model_path, precision=args.precision)
@@ -498,4 +503,5 @@ if __name__ == "__main__":
     demo = create_demo(model)
 
     # Start Gradio server
-     demo.launch()
+     demo.launch()
+     # demo.launch(share=False, server_port=10050, server_name="0.0.0.0")
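The new loop in `create_demo` now rejects incomplete in-context rows before generation starts. A minimal sketch of that validation, assuming the flat `inputs` list and the `grid_h`/`grid_w`/`max_grid_w` layout used in the diff (the surrounding Gradio wiring is omitted):

```python
import gradio as gr

def collect_grid(inputs, grid_h, grid_w, max_grid_w):
    """Rebuild the image grid from the flat Gradio input list (sketch only)."""
    images = []
    for i in range(grid_h):
        images.append([])
        for j in range(grid_w):
            img = inputs[i * max_grid_w + j]
            images[i].append(img)
            # Every row except the last (query) row is an in-context example,
            # so a missing image there aborts the request with a UI error.
            if i != grid_h - 1 and img is None:
                raise gr.Error('Please upload in-context examples.')
    return images
```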
demo_tasks/gradio_tasks_unseen.py CHANGED
@@ -253,7 +253,7 @@ def process_unseen_tasks(x):
 mask = task.get('mask', [0 for _ in range(grid_w - 1)] + [1])
 layout_prompt = get_layout_instruction(grid_w, grid_h)
 
- upsampling_noise = None
+ upsampling_noise = 0.7
 steps = None
 outputs = [mask, grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps] + rets
 break
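Replacing the `None` default with `0.7` matters downstream: the guard added in `visualcloze.py` (below) compares `upsampling_noise < 1.0`, and comparing `None` against a float raises a `TypeError` in Python 3. A small illustration, assuming the value chosen here eventually reaches that comparison through the UI:

```python
# Illustration only: why a numeric default is needed once the
# `upsampling_noise < 1.0` guard from this commit is in place.
upsampling_noise = None
try:
    upsampling_noise < 1.0
except TypeError as exc:
    print(f"TypeError: {exc}")      # None cannot be ordered against a float

upsampling_noise = 0.7              # new default set in process_unseen_tasks
print(upsampling_noise < 1.0)       # True -> the upsampling branch may run
```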
visualcloze.py CHANGED
@@ -91,26 +91,26 @@ class VisualClozeModel:
 self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 self.dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[self.precision]
 
- # Initialize model
- print("Initializing model...")
- self.model = load_flow_model(model_name, device=self.device, lora_rank=self.lora_rank)
+ # # Initialize model
+ # print("Initializing model...")
+ # self.model = load_flow_model(model_name, device=self.device, lora_rank=self.lora_rank)
 
- # Initialize VAE
- print("Initializing VAE...")
- self.ae = AutoencoderKL.from_pretrained(f"black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=self.dtype).to(self.device)
- self.ae.requires_grad_(False)
+ # # Initialize VAE
+ # print("Initializing VAE...")
+ # self.ae = AutoencoderKL.from_pretrained(f"black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=self.dtype).to(self.device)
+ # self.ae.requires_grad_(False)
 
- # Initialize text encoders
- print("Initializing text encoders...")
- self.t5 = load_t5(self.device, max_length=self.max_length)
- self.clip = load_clip(self.device)
+ # # Initialize text encoders
+ # print("Initializing text encoders...")
+ # self.t5 = load_t5(self.device, max_length=self.max_length)
+ # self.clip = load_clip(self.device)
 
- self.model.eval().to(self.device, dtype=self.dtype)
+ # self.model.eval().to(self.device, dtype=self.dtype)
 
- # Load model weights
- ckpt = torch.load(model_path)
- self.model.load_state_dict(ckpt, strict=False)
- del ckpt
+ # # Load model weights
+ # ckpt = torch.load(model_path)
+ # self.model.load_state_dict(ckpt, strict=False)
+ # del ckpt
 
 # Initialize sampler
 transport = create_transport(
@@ -337,6 +337,8 @@ class VisualClozeModel:
 processed_images.append(blank)
 if i == grid_h - 1:
     mask_position.append(1)
+
+ return processed_images
 
 if len(mask_position) > 1 and sum(mask_position) > 1:
     if target_size is None:
@@ -443,7 +445,7 @@ class VisualClozeModel:
 if True: # images[i] is None:
     cropped = output_images[-1].crop(((i - row_start) * ret_w // self.grid_w, 0, ((i - row_start) + 1) * ret_w // self.grid_w, ret_h))
     ret.append(cropped)
-     if mask_position[i - row_start] and is_upsampling:
+     if mask_position[i - row_start] and is_upsampling and upsampling_noise < 1.0:
         upsampled = self.upsampling(
             cropped,
             upsampling_size,
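The tightened condition means a crop is only sent through the extra upsampling pass when its grid position was masked (i.e., generated), upsampling is requested, and the noise level is below 1.0, so a noise value of 1.0 now acts as an off switch. A minimal sketch of that predicate, reusing the names from the diff:

```python
def should_upsample(mask_flag: int, is_upsampling: bool, upsampling_noise: float) -> bool:
    """Sketch of the guard added in this commit (names follow the diff)."""
    # Only masked (generated) positions qualify, upsampling must be enabled,
    # and a noise level of 1.0 or higher skips the extra pass entirely.
    return bool(mask_flag) and is_upsampling and upsampling_noise < 1.0

print(should_upsample(1, True, 0.7))   # True  -> refine the generated crop
print(should_upsample(1, True, 1.0))   # False -> noise >= 1.0 disables upsampling
print(should_upsample(0, True, 0.7))   # False -> in-context cells are never upsampled
```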