ML-Motivators commited on
Commit
130c731
·
verified ·
1 Parent(s): 6532741

Update app (3).py

Browse files
Files changed (1) hide show
  1. app (3).py +68 -150
app (3).py CHANGED
@@ -1,4 +1,9 @@
 
 
 
1
  import gradio as gr
 
 
2
  from PIL import Image
3
 
4
  from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
@@ -10,27 +15,21 @@ from transformers import (
10
  CLIPTextModel,
11
  CLIPTextModelWithProjection,
12
  )
13
- from diffusers import DDPMScheduler,AutoencoderKL
14
  from typing import List
15
 
16
- import torch
17
  import os
18
  from transformers import AutoTokenizer
19
- import spaces
20
  import numpy as np
21
  from utils_mask import get_mask_location
22
  from torchvision import transforms
23
  import apply_net
24
  from preprocess.humanparsing.run_parsing import Parsing
25
  from preprocess.openpose.run_openpose import OpenPose
26
- from detectron2.data.detection_utils import convert_PIL_to_numpy,_apply_exif_orientation
27
  from torchvision.transforms.functional import to_pil_image
28
 
29
-
30
-
31
-
32
-
33
-
34
  def pil_to_binary_mask(pil_image, threshold=0):
35
  np_image = np.array(pil_image)
36
  grayscale_image = Image.fromarray(np_image).convert("L")
@@ -38,16 +37,17 @@ def pil_to_binary_mask(pil_image, threshold=0):
38
  mask = np.zeros(binary_mask.shape, dtype=np.uint8)
39
  for i in range(binary_mask.shape[0]):
40
  for j in range(binary_mask.shape[1]):
41
- if binary_mask[i,j] == True :
42
- mask[i,j] = 1
43
- mask = (mask*255).astype(np.uint8)
44
  output_mask = Image.fromarray(mask)
45
  return output_mask
46
 
47
-
48
  base_path = 'yisol/IDM-VTON'
49
  example_path = os.path.join(os.path.dirname(__file__), 'example')
50
 
 
51
  unet = UNet2DConditionModel.from_pretrained(
52
  base_path,
53
  subfolder="unet",
@@ -57,13 +57,11 @@ unet.requires_grad_(False)
57
  tokenizer_one = AutoTokenizer.from_pretrained(
58
  base_path,
59
  subfolder="tokenizer",
60
- revision=None,
61
  use_fast=False,
62
  )
63
  tokenizer_two = AutoTokenizer.from_pretrained(
64
  base_path,
65
  subfolder="tokenizer_2",
66
- revision=None,
67
  use_fast=False,
68
  )
69
  noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
@@ -82,7 +80,7 @@ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
82
  base_path,
83
  subfolder="image_encoder",
84
  torch_dtype=torch.float16,
85
- )
86
  vae = AutoencoderKL.from_pretrained(base_path,
87
  subfolder="vae",
88
  torch_dtype=torch.float16,
@@ -105,40 +103,40 @@ unet.requires_grad_(False)
105
  text_encoder_one.requires_grad_(False)
106
  text_encoder_two.requires_grad_(False)
107
  tensor_transfrom = transforms.Compose(
108
- [
109
- transforms.ToTensor(),
110
- transforms.Normalize([0.5], [0.5]),
111
- ]
112
- )
113
 
 
114
  pipe = TryonPipeline.from_pretrained(
115
- base_path,
116
- unet=unet,
117
- vae=vae,
118
- feature_extractor= CLIPImageProcessor(),
119
- text_encoder = text_encoder_one,
120
- text_encoder_2 = text_encoder_two,
121
- tokenizer = tokenizer_one,
122
- tokenizer_2 = tokenizer_two,
123
- scheduler = noise_scheduler,
124
- image_encoder=image_encoder,
125
- torch_dtype=torch.float16,
126
  )
127
  pipe.unet_encoder = UNet_Encoder
128
 
129
-
130
-
131
  @spaces.GPU
132
- def start_tryon(dict,garm_img,garment_des,is_checked,is_checked_crop,denoise_steps,seed):
133
  device = "cuda"
134
-
135
  openpose_model.preprocessor.body_estimation.model.to(device)
136
  pipe.to(device)
137
  pipe.unet_encoder.to(device)
138
 
139
- garm_img= garm_img.convert("RGB").resize((768,1024))
140
- human_img_orig = dict["background"].convert("RGB")
141
-
142
  if is_checked_crop:
143
  width, height = human_img_orig.size
144
  target_width = int(min(width, height * (3 / 4)))
@@ -149,37 +147,31 @@ def start_tryon(dict,garm_img,garment_des,is_checked,is_checked_crop,denoise_ste
149
  bottom = (height + target_height) / 2
150
  cropped_img = human_img_orig.crop((left, top, right, bottom))
151
  crop_size = cropped_img.size
152
- human_img = cropped_img.resize((768,1024))
153
  else:
154
- human_img = human_img_orig.resize((768,1024))
155
-
156
 
157
  if is_checked:
158
- keypoints = openpose_model(human_img.resize((384,512)))
159
- model_parse, _ = parsing_model(human_img.resize((384,512)))
160
  mask, mask_gray = get_mask_location('hd', "upper_body", model_parse, keypoints)
161
- mask = mask.resize((768,1024))
162
  else:
163
  mask = pil_to_binary_mask(dict['layers'][0].convert("RGB").resize((768, 1024)))
164
- # mask = transforms.ToTensor()(mask)
165
- # mask = mask.unsqueeze(0)
166
- mask_gray = (1-transforms.ToTensor()(mask)) * tensor_transfrom(human_img)
167
- mask_gray = to_pil_image((mask_gray+1.0)/2.0)
168
 
169
-
170
- human_img_arg = _apply_exif_orientation(human_img.resize((384,512)))
171
  human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
172
-
173
-
174
 
175
- args = apply_net.create_argument_parser().parse_args(('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', 'cuda'))
176
- # verbosity = getattr(args, "verbosity", None)
177
- pose_img = args.func(args,human_img_arg)
178
- pose_img = pose_img[:,:,::-1]
179
- pose_img = Image.fromarray(pose_img).resize((768,1024))
180
-
 
181
  with torch.no_grad():
182
- # Extract the images
183
  with torch.cuda.amp.autocast():
184
  with torch.no_grad():
185
  prompt = "model is wearing " + garment_des
@@ -196,7 +188,7 @@ def start_tryon(dict,garm_img,garment_des,is_checked,is_checked_crop,denoise_ste
196
  do_classifier_free_guidance=True,
197
  negative_prompt=negative_prompt,
198
  )
199
-
200
  prompt = "a photo of " + garment_des
201
  negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
202
  if not isinstance(prompt, List):
@@ -216,105 +208,31 @@ def start_tryon(dict,garm_img,garment_des,is_checked,is_checked_crop,denoise_ste
216
  negative_prompt=negative_prompt,
217
  )
218
 
219
-
220
-
221
- pose_img = tensor_transfrom(pose_img).unsqueeze(0).to(device,torch.float16)
222
- garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device,torch.float16)
223
  generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
224
  images = pipe(
225
- prompt_embeds=prompt_embeds.to(device,torch.float16),
226
- negative_prompt_embeds=negative_prompt_embeds.to(device,torch.float16),
227
- pooled_prompt_embeds=pooled_prompt_embeds.to(device,torch.float16),
228
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device,torch.float16),
229
  num_inference_steps=denoise_steps,
230
  generator=generator,
231
- strength = 1.0,
232
- pose_img = pose_img.to(device,torch.float16),
233
- text_embeds_cloth=prompt_embeds_c.to(device,torch.float16),
234
- cloth = garm_tensor.to(device,torch.float16),
235
  mask_image=mask,
236
- image=human_img,
237
  height=1024,
238
  width=768,
239
- ip_adapter_image = garm_img.resize((768,1024)),
240
  guidance_scale=2.0,
241
  )[0]
242
 
243
  if is_checked_crop:
244
- out_img = images[0].resize(crop_size)
245
- human_img_orig.paste(out_img, (int(left), int(top)))
246
  return human_img_orig, mask_gray
247
  else:
248
- return images[0], mask_gray
249
- # return images[0], mask_gray
250
-
251
- garm_list = os.listdir(os.path.join(example_path,"cloth"))
252
- garm_list_path = [os.path.join(example_path,"cloth",garm) for garm in garm_list]
253
-
254
- human_list = os.listdir(os.path.join(example_path,"human"))
255
- human_list_path = [os.path.join(example_path,"human",human) for human in human_list]
256
-
257
- human_ex_list = []
258
- for ex_human in human_list_path:
259
- ex_dict= {}
260
- ex_dict['background'] = ex_human
261
- ex_dict['layers'] = None
262
- ex_dict['composite'] = None
263
- human_ex_list.append(ex_dict)
264
-
265
- ##default human
266
-
267
-
268
- image_blocks = gr.Blocks().queue()
269
- with image_blocks as demo:
270
- gr.Markdown("## IDM-VTON 👕👔👚")
271
- gr.Markdown("Virtual Try-on with your image and garment image. Check out the [source codes](https://github.com/yisol/IDM-VTON) and the [model](https://huggingface.co/yisol/IDM-VTON)")
272
- with gr.Row():
273
- with gr.Column():
274
- imgs = gr.ImageEditor(sources='upload', type="pil", label='Human. Mask with pen or use auto-masking', interactive=True)
275
- with gr.Row():
276
- is_checked = gr.Checkbox(label="Yes", info="Use auto-generated mask (Takes 5 seconds)",value=True)
277
- with gr.Row():
278
- is_checked_crop = gr.Checkbox(label="Yes", info="Use auto-crop & resizing",value=False)
279
-
280
- example = gr.Examples(
281
- inputs=imgs,
282
- examples_per_page=10,
283
- examples=human_ex_list
284
- )
285
-
286
- with gr.Column():
287
- garm_img = gr.Image(label="Garment", sources='upload', type="pil")
288
- with gr.Row(elem_id="prompt-container"):
289
- with gr.Row():
290
- prompt = gr.Textbox(placeholder="Description of garment ex) Short Sleeve Round Neck T-shirts", show_label=False, elem_id="prompt")
291
- example = gr.Examples(
292
- inputs=garm_img,
293
- examples_per_page=8,
294
- examples=garm_list_path)
295
- with gr.Column():
296
- # image_out = gr.Image(label="Output", elem_id="output-img", height=400)
297
- masked_img = gr.Image(label="Masked image output", elem_id="masked-img",show_share_button=False)
298
- with gr.Column():
299
- # image_out = gr.Image(label="Output", elem_id="output-img", height=400)
300
- image_out = gr.Image(label="Output", elem_id="output-img",show_share_button=False)
301
-
302
-
303
-
304
-
305
- with gr.Column():
306
- try_button = gr.Button(value="Try-on")
307
- with gr.Accordion(label="Advanced Settings", open=False):
308
- with gr.Row():
309
- denoise_steps = gr.Number(label="Denoising Steps", minimum=20, maximum=40, value=30, step=1)
310
- seed = gr.Number(label="Seed", minimum=-1, maximum=2147483647, step=1, value=42)
311
-
312
-
313
-
314
- try_button.click(fn=start_tryon, inputs=[imgs, garm_img, prompt, is_checked,is_checked_crop, denoise_steps, seed], outputs=[image_out,masked_img], api_name='tryon')
315
-
316
-
317
-
318
-
319
- image_blocks.launch()
320
-
 
1
+ # You can add the `demo = gr.Interface(fn=greet, inputs=[], outputs=[])` and `demo.launch()` at the end of your script, after all the function definitions and imports. Here's the updated version:
2
+
3
+ # ```python
4
  import gradio as gr
5
+ import spaces
6
+ import torch
7
  from PIL import Image
8
 
9
  from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
 
15
  CLIPTextModel,
16
  CLIPTextModelWithProjection,
17
  )
18
+ from diffusers import DDPMScheduler, AutoencoderKL
19
  from typing import List
20
 
 
21
  import os
22
  from transformers import AutoTokenizer
 
23
  import numpy as np
24
  from utils_mask import get_mask_location
25
  from torchvision import transforms
26
  import apply_net
27
  from preprocess.humanparsing.run_parsing import Parsing
28
  from preprocess.openpose.run_openpose import OpenPose
29
+ from detectron2.data.detection_utils import convert_PIL_to_numpy, _apply_exif_orientation
30
  from torchvision.transforms.functional import to_pil_image
31
 
32
+ # Function to convert PIL image to binary mask
 
 
 
 
33
  def pil_to_binary_mask(pil_image, threshold=0):
34
  np_image = np.array(pil_image)
35
  grayscale_image = Image.fromarray(np_image).convert("L")
 
37
  mask = np.zeros(binary_mask.shape, dtype=np.uint8)
38
  for i in range(binary_mask.shape[0]):
39
  for j in range(binary_mask.shape[1]):
40
+ if binary_mask[i, j]:
41
+ mask[i, j] = 1
42
+ mask = (mask * 255).astype(np.uint8)
43
  output_mask = Image.fromarray(mask)
44
  return output_mask
45
 
46
+ # Base path setup
47
  base_path = 'yisol/IDM-VTON'
48
  example_path = os.path.join(os.path.dirname(__file__), 'example')
49
 
50
+ # Model loading
51
  unet = UNet2DConditionModel.from_pretrained(
52
  base_path,
53
  subfolder="unet",
 
57
  tokenizer_one = AutoTokenizer.from_pretrained(
58
  base_path,
59
  subfolder="tokenizer",
 
60
  use_fast=False,
61
  )
62
  tokenizer_two = AutoTokenizer.from_pretrained(
63
  base_path,
64
  subfolder="tokenizer_2",
 
65
  use_fast=False,
66
  )
67
  noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
 
80
  base_path,
81
  subfolder="image_encoder",
82
  torch_dtype=torch.float16,
83
+ )
84
  vae = AutoencoderKL.from_pretrained(base_path,
85
  subfolder="vae",
86
  torch_dtype=torch.float16,
 
103
  text_encoder_one.requires_grad_(False)
104
  text_encoder_two.requires_grad_(False)
105
  tensor_transfrom = transforms.Compose(
106
+ [
107
+ transforms.ToTensor(),
108
+ transforms.Normalize([0.5], [0.5]),
109
+ ]
110
+ )
111
 
112
+ # Tryon pipeline setup
113
  pipe = TryonPipeline.from_pretrained(
114
+ base_path,
115
+ unet=unet,
116
+ vae=vae,
117
+ feature_extractor=CLIPImageProcessor(),
118
+ text_encoder=text_encoder_one,
119
+ text_encoder_2=text_encoder_two,
120
+ tokenizer=tokenizer_one,
121
+ tokenizer_2=tokenizer_two,
122
+ scheduler=noise_scheduler,
123
+ image_encoder=image_encoder,
124
+ torch_dtype=torch.float16,
125
  )
126
  pipe.unet_encoder = UNet_Encoder
127
 
128
+ # Start try-on function
 
129
  @spaces.GPU
130
+ def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denoise_steps, seed):
131
  device = "cuda"
132
+
133
  openpose_model.preprocessor.body_estimation.model.to(device)
134
  pipe.to(device)
135
  pipe.unet_encoder.to(device)
136
 
137
+ garm_img = garm_img.convert("RGB").resize((768, 1024))
138
+ human_img_orig = dict["background"].convert("RGB")
139
+
140
  if is_checked_crop:
141
  width, height = human_img_orig.size
142
  target_width = int(min(width, height * (3 / 4)))
 
147
  bottom = (height + target_height) / 2
148
  cropped_img = human_img_orig.crop((left, top, right, bottom))
149
  crop_size = cropped_img.size
150
+ human_img = cropped_img.resize((768, 1024))
151
  else:
152
+ human_img = human_img_orig.resize((768, 1024))
 
153
 
154
  if is_checked:
155
+ keypoints = openpose_model(human_img.resize((384, 512)))
156
+ model_parse, _ = parsing_model(human_img.resize((384, 512)))
157
  mask, mask_gray = get_mask_location('hd', "upper_body", model_parse, keypoints)
158
+ mask = mask.resize((768, 1024))
159
  else:
160
  mask = pil_to_binary_mask(dict['layers'][0].convert("RGB").resize((768, 1024)))
161
+ mask_gray = (1 - transforms.ToTensor()(mask)) * tensor_transfrom(human_img)
162
+ mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)
 
 
163
 
164
+ human_img_arg = _apply_exif_orientation(human_img.resize((384, 512)))
 
165
  human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
 
 
166
 
167
+ args = apply_net.create_argument_parser().parse_args(
168
+ ('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', 'cuda')
169
+ )
170
+ pose_img = args.func(args, human_img_arg)
171
+ pose_img = pose_img[:, :, ::-1]
172
+ pose_img = Image.fromarray(pose_img).resize((768, 1024))
173
+
174
  with torch.no_grad():
 
175
  with torch.cuda.amp.autocast():
176
  with torch.no_grad():
177
  prompt = "model is wearing " + garment_des
 
188
  do_classifier_free_guidance=True,
189
  negative_prompt=negative_prompt,
190
  )
191
+
192
  prompt = "a photo of " + garment_des
193
  negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
194
  if not isinstance(prompt, List):
 
208
  negative_prompt=negative_prompt,
209
  )
210
 
211
+ pose_img = tensor_transfrom(pose_img).unsqueeze(0).to(device, torch.float16)
212
+ garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device, torch.float16)
 
 
213
  generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
214
  images = pipe(
215
+ prompt_embeds=prompt_embeds.to(device, torch.float16),
216
+ negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
217
+ pooled_prompt_embeds=pooled_prompt_embeds.to(device, torch.float16),
218
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device, torch.float16),
219
  num_inference_steps=denoise_steps,
220
  generator=generator,
221
+ strength=1.0,
222
+ pose_img=pose_img.to(device, torch.float16),
223
+ text_embeds_cloth=prompt_embeds_c.to(device, torch.float16),
224
+ cloth=garm_tensor.to(device, torch.float16),
225
  mask_image=mask,
226
+ image=human_img,
227
  height=1024,
228
  width=768,
229
+ ip_adapter_image=garm_img.resize((768, 1024)),
230
  guidance_scale=2.0,
231
  )[0]
232
 
233
  if is_checked_crop:
234
+ out_img = images[0].resize(crop_size)
235
+ human_img_orig.paste(out_img, (int(left), int(top)))
236
  return human_img_orig, mask_gray
237
  else:
238
+ return images[0], mask_gray