Clone04 committed on
Commit
1c3b830
·
verified ·
1 Parent(s): 8a77b59

Update gradio_sd3.py

Browse files
Files changed (1) hide show
  1. gradio_sd3.py +94 -104
gradio_sd3.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  import os
3
  import math
@@ -15,104 +16,102 @@ from src.transformer_sd3_garm import SD3Transformer2DModel as SD3Transformer2DMo
15
  from src.transformer_sd3_vton import SD3Transformer2DModel as SD3Transformer2DModel_Vton
16
  import cv2
17
  import random
 
18
 
19
  example_path = os.path.join(os.path.dirname(__file__), 'examples')
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- class FitDiTGenerator:
23
- def __init__(self, model_root, offload=False, aggressive_offload=False, device="cuda:0", with_fp16=False):
24
- weight_dtype = torch.float16 if with_fp16 else torch.bfloat16
25
- transformer_garm = SD3Transformer2DModel_Garm.from_pretrained(os.path.join(model_root, "transformer_garm"), torch_dtype=weight_dtype)
26
- transformer_vton = SD3Transformer2DModel_Vton.from_pretrained(os.path.join(model_root, "transformer_vton"), torch_dtype=weight_dtype)
27
- pose_guider = PoseGuider(conditioning_embedding_channels=1536, conditioning_channels=3, block_out_channels=(32, 64, 256, 512))
28
- pose_guider.load_state_dict(torch.load(os.path.join(model_root, "pose_guider", "diffusion_pytorch_model.bin")))
29
- image_encoder_large = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=weight_dtype)
30
- image_encoder_bigG = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", torch_dtype=weight_dtype)
31
- pose_guider.to(device=device, dtype=weight_dtype)
32
- image_encoder_large.to(device=device)
33
- image_encoder_bigG.to(device=device)
34
- self.pipeline = StableDiffusion3TryOnPipeline.from_pretrained(model_root, torch_dtype=weight_dtype, transformer_garm=transformer_garm, transformer_vton=transformer_vton, pose_guider=pose_guider, image_encoder_large=image_encoder_large, image_encoder_bigG=image_encoder_bigG)
35
- self.pipeline.to(device)
36
- if offload:
37
- self.pipeline.enable_model_cpu_offload()
38
- self.dwprocessor = DWposeDetector(model_root=model_root, device='cpu')
39
- self.parsing_model = Parsing(model_root=model_root, device='cpu')
40
- elif aggressive_offload:
41
- self.pipeline.enable_sequential_cpu_offload()
42
- self.dwprocessor = DWposeDetector(model_root=model_root, device='cpu')
43
- self.parsing_model = Parsing(model_root=model_root, device='cpu')
44
- else:
45
- self.pipeline.to(device)
46
- self.dwprocessor = DWposeDetector(model_root=model_root, device=device)
47
- self.parsing_model = Parsing(model_root=model_root, device=device)
48
-
49
- def generate_mask(self, vton_img, category, offset_top, offset_bottom, offset_left, offset_right):
50
- with torch.inference_mode():
51
- vton_img = Image.open(vton_img)
52
- vton_img_det = resize_image(vton_img)
53
- pose_image, keypoints, _, candidate = self.dwprocessor(np.array(vton_img_det)[:,:,::-1])
54
- candidate[candidate<0]=0
55
- candidate = candidate[0]
56
 
57
- candidate[:, 0]*=vton_img_det.width
58
- candidate[:, 1]*=vton_img_det.height
 
 
 
 
 
 
 
 
 
59
 
60
- pose_image = pose_image[:,:,::-1] #rgb
61
- pose_image = Image.fromarray(pose_image)
62
- model_parse, _ = self.parsing_model(vton_img_det)
63
 
64
- mask, mask_gray = get_mask_location(category, model_parse, \
65
- candidate, model_parse.width, model_parse.height, \
66
- offset_top, offset_bottom, offset_left, offset_right)
67
- mask = mask.resize(vton_img.size)
68
- mask_gray = mask_gray.resize(vton_img.size)
69
- mask = mask.convert("L")
70
- mask_gray = mask_gray.convert("L")
71
- masked_vton_img = Image.composite(mask_gray, vton_img, mask)
72
 
73
- im = {}
74
- im['background'] = np.array(vton_img.convert("RGBA"))
75
- im['layers'] = [np.concatenate((np.array(mask_gray.convert("RGB")), np.array(mask)[:,:,np.newaxis]),axis=2)]
76
- im['composite'] = np.array(masked_vton_img.convert("RGBA"))
77
-
78
- return im, pose_image
79
 
80
- def process(self, vton_img, garm_img, pre_mask, pose_image, n_steps, image_scale, seed, num_images_per_prompt, resolution):
81
- assert resolution in ["768x1024", "1152x1536", "1536x2048"]
82
- new_width, new_height = resolution.split("x")
83
- new_width = int(new_width)
84
- new_height = int(new_height)
85
- with torch.inference_mode():
86
- garm_img = Image.open(garm_img)
87
- vton_img = Image.open(vton_img)
 
88
 
89
- model_image_size = vton_img.size
90
- garm_img, _, _ = pad_and_resize(garm_img, new_width=new_width, new_height=new_height)
91
- vton_img, pad_w, pad_h = pad_and_resize(vton_img, new_width=new_width, new_height=new_height)
92
 
93
- mask = pre_mask["layers"][0][:,:,3]
94
- mask = Image.fromarray(mask)
95
- mask, _, _ = pad_and_resize(mask, new_width=new_width, new_height=new_height, pad_color=(0,0,0))
96
- mask = mask.convert("L")
97
- pose_image = Image.fromarray(pose_image)
98
- pose_image, _, _ = pad_and_resize(pose_image, new_width=new_width, new_height=new_height, pad_color=(0,0,0))
99
- if seed==-1:
100
- seed = random.randint(0, 2147483647)
101
- res = self.pipeline(
102
- height=new_height,
103
- width=new_width,
104
- guidance_scale=image_scale,
105
- num_inference_steps=n_steps,
106
- generator=torch.Generator("cpu").manual_seed(seed),
107
- cloth_image=garm_img,
108
- model_image=vton_img,
109
- mask=mask,
110
- pose_image=pose_image,
111
- num_images_per_prompt=num_images_per_prompt
112
- ).images
113
- for idx in range(len(res)):
114
- res[idx] = unpad_and_resize(res[idx], pad_w, pad_h, model_image_size[0], model_image_size[1])
115
- return res
116
 
117
 
118
  def pad_and_resize(im, new_width=768, new_height=1024, pad_color=(255, 255, 255), mode=Image.LANCZOS):
@@ -186,11 +185,10 @@ HEADER = """
186
  </div>
187
  <br>
188
  FitDiT is designed for high-fidelity virtual try-on using Diffusion Transformers (DiT). It can only be used for <b>Non-commercial Use</b>.<br>
189
- If you like our work, please star <a href="https://github.com/BoyuanJiang/FitDiT" style="color: blue; text-decoration: underline;">our github repository</a>.
190
  """
191
 
192
- def create_demo(model_path, device, offload, aggressive_offload, with_fp16):
193
- generator = FitDiTGenerator(model_path, offload, aggressive_offload, device, with_fp16)
194
  with gr.Blocks(title="FitDiT") as demo:
195
  gr.Markdown(HEADER)
196
  with gr.Row():
@@ -264,7 +262,7 @@ def create_demo(model_path, device, offload, aggressive_offload, with_fp16):
264
  inputs=garm_img,
265
  examples_per_page=7,
266
  examples=[
267
- os.path.join(example_path, 'garment/12.png'),
268
  os.path.join(example_path, 'garment/0012.jpg'),
269
  os.path.join(example_path, 'garment/0047.jpg'),
270
  os.path.join(example_path, 'garment/0049.jpg'),
@@ -291,25 +289,17 @@ def create_demo(model_path, device, offload, aggressive_offload, with_fp16):
291
  ])
292
  with gr.Column():
293
  category = gr.Dropdown(label="Garment category", choices=["Upper-body", "Lower-body", "Dresses"], value="Upper-body")
294
- resolution = gr.Dropdown(label="Try-on resolution", choices=["768x1024", "1152x1536", "1536x2048"], value="1152x1536")
295
  with gr.Column():
296
  run_mask_button = gr.Button(value="Step1: Run Mask")
297
  run_button = gr.Button(value="Step2: Run Try-on")
298
 
299
  ips1 = [vton_img, category, offset_top, offset_bottom, offset_left, offset_right]
300
  ips2 = [vton_img, garm_img, masked_vton_img, pose_image, n_steps, image_scale, seed, num_images_per_prompt, resolution]
301
- run_mask_button.click(fn=generator.generate_mask, inputs=ips1, outputs=[masked_vton_img, pose_image])
302
- run_button.click(fn=generator.process, inputs=ips2, outputs=[result_gallery])
303
  return demo
304
 
305
  if __name__ == "__main__":
306
- import argparse
307
- parser = argparse.ArgumentParser(description="FitDiT")
308
- parser.add_argument("--model_path", type=str, default="BoyuanJiang/FitDiT", required=True, help="The path of FitDiT model.")
309
- parser.add_argument("--device", type=str, default="cuda:0", help="Device to use")
310
- parser.add_argument("--fp16", action="store_true", help="Load model with fp16, default is bf16")
311
- parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use.")
312
- parser.add_argument("--aggressive_offload", action="store_true", help="Offload model more aggressively to CPU when not in use.")
313
- args = parser.parse_args()
314
- demo = create_demo(args.model_path, args.device, args.offload, args.aggressive_offload, args.fp16)
315
- demo.launch(share=True)
 
1
+ import spaces
2
  import gradio as gr
3
  import os
4
  import math
 
16
  from src.transformer_sd3_vton import SD3Transformer2DModel as SD3Transformer2DModel_Vton
17
  import cv2
18
  import random
19
+ from huggingface_hub import snapshot_download
20
 
21
  example_path = os.path.join(os.path.dirname(__file__), 'examples')
22
 
23
# --- One-time model initialization (runs at import time) ---
# Fetch the FitDiT weights from the Hugging Face Hub and assemble the
# try-on pipeline once; all request handlers share these loaded components.
fitdit_repo = "BoyuanJiang/FitDiT"
repo_path = snapshot_download(repo_id=fitdit_repo)

weight_dtype = torch.bfloat16
device = "cuda"

# Garment branch and try-on branch of the SD3 transformer.
transformer_garm = SD3Transformer2DModel_Garm.from_pretrained(
    os.path.join(repo_path, "transformer_garm"), torch_dtype=weight_dtype
)
transformer_vton = SD3Transformer2DModel_Vton.from_pretrained(
    os.path.join(repo_path, "transformer_vton"), torch_dtype=weight_dtype
)

# Pose-conditioning network; its weights ship as a plain state dict.
pose_guider = PoseGuider(
    conditioning_embedding_channels=1536,
    conditioning_channels=3,
    block_out_channels=(32, 64, 256, 512),
)
pose_guider.load_state_dict(
    torch.load(os.path.join(repo_path, "pose_guider", "diffusion_pytorch_model.bin"))
)
pose_guider.to(device=device, dtype=weight_dtype)

# CLIP vision encoders used to embed the garment image.
image_encoder_large = CLIPVisionModelWithProjection.from_pretrained(
    "openai/clip-vit-large-patch14", torch_dtype=weight_dtype
)
image_encoder_bigG = CLIPVisionModelWithProjection.from_pretrained(
    "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", torch_dtype=weight_dtype
)
image_encoder_large.to(device=device)
image_encoder_bigG.to(device=device)

pipeline = StableDiffusion3TryOnPipeline.from_pretrained(
    repo_path,
    torch_dtype=weight_dtype,
    transformer_garm=transformer_garm,
    transformer_vton=transformer_vton,
    pose_guider=pose_guider,
    image_encoder_large=image_encoder_large,
    image_encoder_bigG=image_encoder_bigG,
)
pipeline.to(device)

# Preprocessing models: pose detection and human parsing for mask generation.
dwprocessor = DWposeDetector(model_root=repo_path, device=device)
parsing_model = Parsing(model_root=repo_path, device=device)
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
+
47
def generate_mask(vton_img, category, offset_top, offset_bottom, offset_left, offset_right):
    """Detect pose and build the try-on inpainting mask for a model photo.

    Args:
        vton_img: Path to the model (person) image file.
        category: Garment category ("Upper-body", "Lower-body" or "Dresses").
        offset_top/offset_bottom/offset_left/offset_right: Pixel offsets that
            grow or shrink the generated mask on each side.

    Returns:
        Tuple of (editor image dict with 'background'/'layers'/'composite'
        RGBA arrays for the gradio ImageEditor, pose visualization PIL image).
    """
    # NOTE(review): this handler runs `dwprocessor`/`parsing_model`, which are
    # loaded on "cuda", yet it carries no @spaces.GPU decorator (unlike
    # `process`) — confirm this is intentional for the ZeroGPU deployment.
    with torch.inference_mode():
        vton_img = Image.open(vton_img)
        vton_img_det = resize_image(vton_img)
        # DWpose expects BGR input; the loaded array is RGB, so reverse channels.
        pose_image, _, _, candidate = dwprocessor(np.array(vton_img_det)[:, :, ::-1])
        candidate[candidate < 0] = 0
        candidate = candidate[0]

        # Keypoints are normalized; scale to detection-image pixel coordinates.
        candidate[:, 0] *= vton_img_det.width
        candidate[:, 1] *= vton_img_det.height

        pose_image = pose_image[:, :, ::-1]  # BGR -> RGB
        pose_image = Image.fromarray(pose_image)
        model_parse, _ = parsing_model(vton_img_det)

        mask, mask_gray = get_mask_location(category, model_parse,
                                            candidate, model_parse.width, model_parse.height,
                                            offset_top, offset_bottom, offset_left, offset_right)
        # Bring masks back to the original photo resolution.
        mask = mask.resize(vton_img.size)
        mask_gray = mask_gray.resize(vton_img.size)
        mask = mask.convert("L")
        mask_gray = mask_gray.convert("L")
        masked_vton_img = Image.composite(mask_gray, vton_img, mask)

        # Assemble the gradio ImageEditor value: the mask goes into the layer's
        # alpha channel so the user can refine it with the brush tool.
        im = {}
        im['background'] = np.array(vton_img.convert("RGBA"))
        im['layers'] = [np.concatenate((np.array(mask_gray.convert("RGB")), np.array(mask)[:, :, np.newaxis]), axis=2)]
        im['composite'] = np.array(masked_vton_img.convert("RGBA"))

        return im, pose_image
77
 
78
@spaces.GPU
def process(vton_img, garm_img, pre_mask, pose_image, n_steps, image_scale, seed, num_images_per_prompt, resolution):
    """Run the FitDiT try-on pipeline and return result images.

    Args:
        vton_img: Path to the model (person) image file.
        garm_img: Path to the garment image file.
        pre_mask: gradio ImageEditor value; the first layer's alpha channel is
            the (possibly user-refined) inpainting mask.
        pose_image: Pose visualization as a numpy array (from generate_mask).
        n_steps: Number of diffusion inference steps.
        image_scale: Classifier-free guidance scale.
        seed: RNG seed; -1 picks a random seed.
        num_images_per_prompt: Number of result images to generate.
        resolution: One of "768x1024", "1152x1536", "1536x2048".

    Returns:
        List of PIL images resized back to the input photo's dimensions.

    Raises:
        ValueError: If `resolution` is not one of the supported values.
    """
    # Validate with an explicit raise: `assert` is stripped under `python -O`.
    if resolution not in ("768x1024", "1152x1536", "1536x2048"):
        raise ValueError(f"Unsupported resolution: {resolution!r}")
    new_width, new_height = (int(v) for v in resolution.split("x"))
    with torch.inference_mode():
        garm_img = Image.open(garm_img)
        vton_img = Image.open(vton_img)

        model_image_size = vton_img.size
        garm_img, _, _ = pad_and_resize(garm_img, new_width=new_width, new_height=new_height)
        vton_img, pad_w, pad_h = pad_and_resize(vton_img, new_width=new_width, new_height=new_height)

        # The editor layer's alpha channel carries the inpainting mask.
        mask = pre_mask["layers"][0][:, :, 3]
        mask = Image.fromarray(mask)
        mask, _, _ = pad_and_resize(mask, new_width=new_width, new_height=new_height, pad_color=(0, 0, 0))
        mask = mask.convert("L")
        pose_image = Image.fromarray(pose_image)
        pose_image, _, _ = pad_and_resize(pose_image, new_width=new_width, new_height=new_height, pad_color=(0, 0, 0))
        if seed == -1:
            seed = random.randint(0, 2147483647)
        res = pipeline(
            height=new_height,
            width=new_width,
            guidance_scale=image_scale,
            num_inference_steps=n_steps,
            generator=torch.Generator("cpu").manual_seed(seed),
            cloth_image=garm_img,
            model_image=vton_img,
            mask=mask,
            pose_image=pose_image,
            num_images_per_prompt=num_images_per_prompt,
        ).images
        # Undo the padding/resizing so outputs match the original photo size.
        for idx in range(len(res)):
            res[idx] = unpad_and_resize(res[idx], pad_w, pad_h, model_image_size[0], model_image_size[1])
        return res
115
 
116
 
117
  def pad_and_resize(im, new_width=768, new_height=1024, pad_color=(255, 255, 255), mode=Image.LANCZOS):
 
185
  </div>
186
  <br>
187
  FitDiT is designed for high-fidelity virtual try-on using Diffusion Transformers (DiT). It can only be used for <b>Non-commercial Use</b>.<br>
188
+ If you like our work, please star <a href="https://github.com/BoyuanJiang/FitDiT" style="color: blue; text-decoration: underline;">our github repository</a>. A <b>ComfyUI version</b> of FitDiT is available <a href="https://github.com/BoyuanJiang/FitDiT/tree/FitDiT-ComfyUI" style="color: blue; text-decoration: underline;">here</a>.
189
  """
190
 
191
+ def create_demo():
 
192
  with gr.Blocks(title="FitDiT") as demo:
193
  gr.Markdown(HEADER)
194
  with gr.Row():
 
262
  inputs=garm_img,
263
  examples_per_page=7,
264
  examples=[
265
+ os.path.join(example_path, 'garment/12.jpg'),
266
  os.path.join(example_path, 'garment/0012.jpg'),
267
  os.path.join(example_path, 'garment/0047.jpg'),
268
  os.path.join(example_path, 'garment/0049.jpg'),
 
289
  ])
290
  with gr.Column():
291
  category = gr.Dropdown(label="Garment category", choices=["Upper-body", "Lower-body", "Dresses"], value="Upper-body")
292
+ resolution = gr.Dropdown(label="Try-on resolution", choices=["768x1024", "1152x1536", "1536x2048"], value="768x1024")
293
  with gr.Column():
294
  run_mask_button = gr.Button(value="Step1: Run Mask")
295
  run_button = gr.Button(value="Step2: Run Try-on")
296
 
297
  ips1 = [vton_img, category, offset_top, offset_bottom, offset_left, offset_right]
298
  ips2 = [vton_img, garm_img, masked_vton_img, pose_image, n_steps, image_scale, seed, num_images_per_prompt, resolution]
299
+ run_mask_button.click(fn=generate_mask, inputs=ips1, outputs=[masked_vton_img, pose_image])
300
+ run_button.click(fn=process, inputs=ips2, outputs=[result_gallery])
301
  return demo
302
 
303
  if __name__ == "__main__":
304
+ demo = create_demo()
305
+ demo.launch()