Spaces:
Sleeping
Sleeping
Update gradio_sd3.py
Browse files- gradio_sd3.py +94 -104
gradio_sd3.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import gradio as gr
|
2 |
import os
|
3 |
import math
|
@@ -15,104 +16,102 @@ from src.transformer_sd3_garm import SD3Transformer2DModel as SD3Transformer2DMo
|
|
15 |
from src.transformer_sd3_vton import SD3Transformer2DModel as SD3Transformer2DModel_Vton
|
16 |
import cv2
|
17 |
import random
|
|
|
18 |
|
19 |
example_path = os.path.join(os.path.dirname(__file__), 'examples')
|
20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
-
class FitDiTGenerator:
|
23 |
-
def __init__(self, model_root, offload=False, aggressive_offload=False, device="cuda:0", with_fp16=False):
|
24 |
-
weight_dtype = torch.float16 if with_fp16 else torch.bfloat16
|
25 |
-
transformer_garm = SD3Transformer2DModel_Garm.from_pretrained(os.path.join(model_root, "transformer_garm"), torch_dtype=weight_dtype)
|
26 |
-
transformer_vton = SD3Transformer2DModel_Vton.from_pretrained(os.path.join(model_root, "transformer_vton"), torch_dtype=weight_dtype)
|
27 |
-
pose_guider = PoseGuider(conditioning_embedding_channels=1536, conditioning_channels=3, block_out_channels=(32, 64, 256, 512))
|
28 |
-
pose_guider.load_state_dict(torch.load(os.path.join(model_root, "pose_guider", "diffusion_pytorch_model.bin")))
|
29 |
-
image_encoder_large = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=weight_dtype)
|
30 |
-
image_encoder_bigG = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", torch_dtype=weight_dtype)
|
31 |
-
pose_guider.to(device=device, dtype=weight_dtype)
|
32 |
-
image_encoder_large.to(device=device)
|
33 |
-
image_encoder_bigG.to(device=device)
|
34 |
-
self.pipeline = StableDiffusion3TryOnPipeline.from_pretrained(model_root, torch_dtype=weight_dtype, transformer_garm=transformer_garm, transformer_vton=transformer_vton, pose_guider=pose_guider, image_encoder_large=image_encoder_large, image_encoder_bigG=image_encoder_bigG)
|
35 |
-
self.pipeline.to(device)
|
36 |
-
if offload:
|
37 |
-
self.pipeline.enable_model_cpu_offload()
|
38 |
-
self.dwprocessor = DWposeDetector(model_root=model_root, device='cpu')
|
39 |
-
self.parsing_model = Parsing(model_root=model_root, device='cpu')
|
40 |
-
elif aggressive_offload:
|
41 |
-
self.pipeline.enable_sequential_cpu_offload()
|
42 |
-
self.dwprocessor = DWposeDetector(model_root=model_root, device='cpu')
|
43 |
-
self.parsing_model = Parsing(model_root=model_root, device='cpu')
|
44 |
-
else:
|
45 |
-
self.pipeline.to(device)
|
46 |
-
self.dwprocessor = DWposeDetector(model_root=model_root, device=device)
|
47 |
-
self.parsing_model = Parsing(model_root=model_root, device=device)
|
48 |
-
|
49 |
-
def generate_mask(self, vton_img, category, offset_top, offset_bottom, offset_left, offset_right):
|
50 |
-
with torch.inference_mode():
|
51 |
-
vton_img = Image.open(vton_img)
|
52 |
-
vton_img_det = resize_image(vton_img)
|
53 |
-
pose_image, keypoints, _, candidate = self.dwprocessor(np.array(vton_img_det)[:,:,::-1])
|
54 |
-
candidate[candidate<0]=0
|
55 |
-
candidate = candidate[0]
|
56 |
|
57 |
-
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
|
|
88 |
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
|
117 |
|
118 |
def pad_and_resize(im, new_width=768, new_height=1024, pad_color=(255, 255, 255), mode=Image.LANCZOS):
|
@@ -186,11 +185,10 @@ HEADER = """
|
|
186 |
</div>
|
187 |
<br>
|
188 |
FitDiT is designed for high-fidelity virtual try-on using Diffusion Transformers (DiT). It can only be used for <b>Non-commercial Use</b>.<br>
|
189 |
-
If you like our work, please star <a href="https://github.com/BoyuanJiang/FitDiT" style="color: blue; text-decoration: underline;">our github repository</a>.
|
190 |
"""
|
191 |
|
192 |
-
def create_demo(
|
193 |
-
generator = FitDiTGenerator(model_path, offload, aggressive_offload, device, with_fp16)
|
194 |
with gr.Blocks(title="FitDiT") as demo:
|
195 |
gr.Markdown(HEADER)
|
196 |
with gr.Row():
|
@@ -264,7 +262,7 @@ def create_demo(model_path, device, offload, aggressive_offload, with_fp16):
|
|
264 |
inputs=garm_img,
|
265 |
examples_per_page=7,
|
266 |
examples=[
|
267 |
-
os.path.join(example_path, 'garment/12.
|
268 |
os.path.join(example_path, 'garment/0012.jpg'),
|
269 |
os.path.join(example_path, 'garment/0047.jpg'),
|
270 |
os.path.join(example_path, 'garment/0049.jpg'),
|
@@ -291,25 +289,17 @@ def create_demo(model_path, device, offload, aggressive_offload, with_fp16):
|
|
291 |
])
|
292 |
with gr.Column():
|
293 |
category = gr.Dropdown(label="Garment category", choices=["Upper-body", "Lower-body", "Dresses"], value="Upper-body")
|
294 |
-
resolution = gr.Dropdown(label="Try-on resolution", choices=["768x1024", "1152x1536", "1536x2048"], value="
|
295 |
with gr.Column():
|
296 |
run_mask_button = gr.Button(value="Step1: Run Mask")
|
297 |
run_button = gr.Button(value="Step2: Run Try-on")
|
298 |
|
299 |
ips1 = [vton_img, category, offset_top, offset_bottom, offset_left, offset_right]
|
300 |
ips2 = [vton_img, garm_img, masked_vton_img, pose_image, n_steps, image_scale, seed, num_images_per_prompt, resolution]
|
301 |
-
run_mask_button.click(fn=
|
302 |
-
run_button.click(fn=
|
303 |
return demo
|
304 |
|
305 |
if __name__ == "__main__":
|
306 |
-
|
307 |
-
|
308 |
-
parser.add_argument("--model_path", type=str, default="BoyuanJiang/FitDiT", required=True, help="The path of FitDiT model.")
|
309 |
-
parser.add_argument("--device", type=str, default="cuda:0", help="Device to use")
|
310 |
-
parser.add_argument("--fp16", action="store_true", help="Load model with fp16, default is bf16")
|
311 |
-
parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use.")
|
312 |
-
parser.add_argument("--aggressive_offload", action="store_true", help="Offload model more aggressively to CPU when not in use.")
|
313 |
-
args = parser.parse_args()
|
314 |
-
demo = create_demo(args.model_path, args.device, args.offload, args.aggressive_offload, args.fp16)
|
315 |
-
demo.launch(share=True)
|
|
|
1 |
+
import spaces
|
2 |
import gradio as gr
|
3 |
import os
|
4 |
import math
|
|
|
16 |
from src.transformer_sd3_vton import SD3Transformer2DModel as SD3Transformer2DModel_Vton
|
17 |
import cv2
|
18 |
import random
|
19 |
+
from huggingface_hub import snapshot_download
|
20 |
|
21 |
example_path = os.path.join(os.path.dirname(__file__), 'examples')
|
22 |
|
23 |
+
fitdit_repo = "BoyuanJiang/FitDiT"
|
24 |
+
repo_path = snapshot_download(repo_id=fitdit_repo)
|
25 |
+
|
26 |
+
weight_dtype = torch.bfloat16
|
27 |
+
device = "cuda"
|
28 |
+
transformer_garm = SD3Transformer2DModel_Garm.from_pretrained(os.path.join(repo_path, "transformer_garm"), torch_dtype=weight_dtype)
|
29 |
+
transformer_vton = SD3Transformer2DModel_Vton.from_pretrained(os.path.join(repo_path, "transformer_vton"), torch_dtype=weight_dtype)
|
30 |
+
pose_guider = PoseGuider(conditioning_embedding_channels=1536, conditioning_channels=3, block_out_channels=(32, 64, 256, 512))
|
31 |
+
pose_guider.load_state_dict(torch.load(os.path.join(repo_path, "pose_guider", "diffusion_pytorch_model.bin")))
|
32 |
+
image_encoder_large = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=weight_dtype)
|
33 |
+
image_encoder_bigG = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", torch_dtype=weight_dtype)
|
34 |
+
pose_guider.to(device=device, dtype=weight_dtype)
|
35 |
+
image_encoder_large.to(device=device)
|
36 |
+
image_encoder_bigG.to(device=device)
|
37 |
+
pipeline = StableDiffusion3TryOnPipeline.from_pretrained(repo_path, torch_dtype=weight_dtype, \
|
38 |
+
transformer_garm=transformer_garm, transformer_vton=transformer_vton, pose_guider=pose_guider, \
|
39 |
+
image_encoder_large=image_encoder_large, image_encoder_bigG=image_encoder_bigG)
|
40 |
+
pipeline.to(device)
|
41 |
+
dwprocessor = DWposeDetector(model_root=repo_path, device=device)
|
42 |
+
parsing_model = Parsing(model_root=repo_path, device=device)
|
43 |
+
|
44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
|
46 |
+
|
47 |
+
def generate_mask(vton_img, category, offset_top, offset_bottom, offset_left, offset_right):
|
48 |
+
with torch.inference_mode():
|
49 |
+
vton_img = Image.open(vton_img)
|
50 |
+
vton_img_det = resize_image(vton_img)
|
51 |
+
pose_image, keypoints, _, candidate = dwprocessor(np.array(vton_img_det)[:,:,::-1])
|
52 |
+
candidate[candidate<0]=0
|
53 |
+
candidate = candidate[0]
|
54 |
+
|
55 |
+
candidate[:, 0]*=vton_img_det.width
|
56 |
+
candidate[:, 1]*=vton_img_det.height
|
57 |
|
58 |
+
pose_image = pose_image[:,:,::-1] #rgb
|
59 |
+
pose_image = Image.fromarray(pose_image)
|
60 |
+
model_parse, _ = parsing_model(vton_img_det)
|
61 |
|
62 |
+
mask, mask_gray = get_mask_location(category, model_parse, \
|
63 |
+
candidate, model_parse.width, model_parse.height, \
|
64 |
+
offset_top, offset_bottom, offset_left, offset_right)
|
65 |
+
mask = mask.resize(vton_img.size)
|
66 |
+
mask_gray = mask_gray.resize(vton_img.size)
|
67 |
+
mask = mask.convert("L")
|
68 |
+
mask_gray = mask_gray.convert("L")
|
69 |
+
masked_vton_img = Image.composite(mask_gray, vton_img, mask)
|
70 |
|
71 |
+
im = {}
|
72 |
+
im['background'] = np.array(vton_img.convert("RGBA"))
|
73 |
+
im['layers'] = [np.concatenate((np.array(mask_gray.convert("RGB")), np.array(mask)[:,:,np.newaxis]),axis=2)]
|
74 |
+
im['composite'] = np.array(masked_vton_img.convert("RGBA"))
|
75 |
+
|
76 |
+
return im, pose_image
|
77 |
|
78 |
+
@spaces.GPU
|
79 |
+
def process(vton_img, garm_img, pre_mask, pose_image, n_steps, image_scale, seed, num_images_per_prompt, resolution):
|
80 |
+
assert resolution in ["768x1024", "1152x1536", "1536x2048"]
|
81 |
+
new_width, new_height = resolution.split("x")
|
82 |
+
new_width = int(new_width)
|
83 |
+
new_height = int(new_height)
|
84 |
+
with torch.inference_mode():
|
85 |
+
garm_img = Image.open(garm_img)
|
86 |
+
vton_img = Image.open(vton_img)
|
87 |
|
88 |
+
model_image_size = vton_img.size
|
89 |
+
garm_img, _, _ = pad_and_resize(garm_img, new_width=new_width, new_height=new_height)
|
90 |
+
vton_img, pad_w, pad_h = pad_and_resize(vton_img, new_width=new_width, new_height=new_height)
|
91 |
|
92 |
+
mask = pre_mask["layers"][0][:,:,3]
|
93 |
+
mask = Image.fromarray(mask)
|
94 |
+
mask, _, _ = pad_and_resize(mask, new_width=new_width, new_height=new_height, pad_color=(0,0,0))
|
95 |
+
mask = mask.convert("L")
|
96 |
+
pose_image = Image.fromarray(pose_image)
|
97 |
+
pose_image, _, _ = pad_and_resize(pose_image, new_width=new_width, new_height=new_height, pad_color=(0,0,0))
|
98 |
+
if seed==-1:
|
99 |
+
seed = random.randint(0, 2147483647)
|
100 |
+
res = pipeline(
|
101 |
+
height=new_height,
|
102 |
+
width=new_width,
|
103 |
+
guidance_scale=image_scale,
|
104 |
+
num_inference_steps=n_steps,
|
105 |
+
generator=torch.Generator("cpu").manual_seed(seed),
|
106 |
+
cloth_image=garm_img,
|
107 |
+
model_image=vton_img,
|
108 |
+
mask=mask,
|
109 |
+
pose_image=pose_image,
|
110 |
+
num_images_per_prompt=num_images_per_prompt
|
111 |
+
).images
|
112 |
+
for idx in range(len(res)):
|
113 |
+
res[idx] = unpad_and_resize(res[idx], pad_w, pad_h, model_image_size[0], model_image_size[1])
|
114 |
+
return res
|
115 |
|
116 |
|
117 |
def pad_and_resize(im, new_width=768, new_height=1024, pad_color=(255, 255, 255), mode=Image.LANCZOS):
|
|
|
185 |
</div>
|
186 |
<br>
|
187 |
FitDiT is designed for high-fidelity virtual try-on using Diffusion Transformers (DiT). It can only be used for <b>Non-commercial Use</b>.<br>
|
188 |
+
If you like our work, please star <a href="https://github.com/BoyuanJiang/FitDiT" style="color: blue; text-decoration: underline;">our github repository</a>. A <b>ComfyUI version</b> of FitDiT is available <a href="https://github.com/BoyuanJiang/FitDiT/tree/FitDiT-ComfyUI" style="color: blue; text-decoration: underline;">here</a>.
|
189 |
"""
|
190 |
|
191 |
+
def create_demo():
|
|
|
192 |
with gr.Blocks(title="FitDiT") as demo:
|
193 |
gr.Markdown(HEADER)
|
194 |
with gr.Row():
|
|
|
262 |
inputs=garm_img,
|
263 |
examples_per_page=7,
|
264 |
examples=[
|
265 |
+
os.path.join(example_path, 'garment/12.jpg'),
|
266 |
os.path.join(example_path, 'garment/0012.jpg'),
|
267 |
os.path.join(example_path, 'garment/0047.jpg'),
|
268 |
os.path.join(example_path, 'garment/0049.jpg'),
|
|
|
289 |
])
|
290 |
with gr.Column():
|
291 |
category = gr.Dropdown(label="Garment category", choices=["Upper-body", "Lower-body", "Dresses"], value="Upper-body")
|
292 |
+
resolution = gr.Dropdown(label="Try-on resolution", choices=["768x1024", "1152x1536", "1536x2048"], value="768x1024")
|
293 |
with gr.Column():
|
294 |
run_mask_button = gr.Button(value="Step1: Run Mask")
|
295 |
run_button = gr.Button(value="Step2: Run Try-on")
|
296 |
|
297 |
ips1 = [vton_img, category, offset_top, offset_bottom, offset_left, offset_right]
|
298 |
ips2 = [vton_img, garm_img, masked_vton_img, pose_image, n_steps, image_scale, seed, num_images_per_prompt, resolution]
|
299 |
+
run_mask_button.click(fn=generate_mask, inputs=ips1, outputs=[masked_vton_img, pose_image])
|
300 |
+
run_button.click(fn=process, inputs=ips2, outputs=[result_gallery])
|
301 |
return demo
|
302 |
|
303 |
if __name__ == "__main__":
|
304 |
+
demo = create_demo()
|
305 |
+
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|