Spaces: Saad0KH / (Running on Zero)

Saad0KH committed on
Commit 4ddf508 · verified · 1 Parent(s): 8b0546d

Update app.py

Files changed (1):
  1. app.py +145 -97
app.py CHANGED
@@ -4,7 +4,7 @@ from PIL import Image
 from io import BytesIO
 import torch
 import base64
-import io
+import io
 import logging
 import gradio as gr
 import numpy as np
@@ -33,14 +33,12 @@ app = Flask(__name__)
 base_path = 'yisol/IDM-VTON'
 example_path = os.path.join(os.path.dirname(__file__), 'example')

-# Load models
 unet = UNet2DConditionModel.from_pretrained(
     base_path,
     subfolder="unet",
     torch_dtype=torch.float16,
 )
 unet.requires_grad_(False)
-
 tokenizer_one = AutoTokenizer.from_pretrained(
     base_path,
     subfolder="tokenizer",
@@ -84,33 +82,31 @@ UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
 parsing_model = Parsing(0)
 openpose_model = OpenPose(0)

-# Set models to evaluation mode
 UNet_Encoder.requires_grad_(False)
 image_encoder.requires_grad_(False)
 vae.requires_grad_(False)
 unet.requires_grad_(False)
 text_encoder_one.requires_grad_(False)
 text_encoder_two.requires_grad_(False)
-
 tensor_transfrom = transforms.Compose(
-    [
-        transforms.ToTensor(),
-        transforms.Normalize([0.5], [0.5]),
-    ]
-)
+    [
+        transforms.ToTensor(),
+        transforms.Normalize([0.5], [0.5]),
+    ]
+)

 pipe = TryonPipeline.from_pretrained(
-    base_path,
-    unet=unet,
-    vae=vae,
-    feature_extractor=CLIPImageProcessor(),
-    text_encoder=text_encoder_one,
-    text_encoder_2=text_encoder_two,
-    tokenizer=tokenizer_one,
-    tokenizer_2=tokenizer_two,
-    scheduler=noise_scheduler,
-    image_encoder=image_encoder,
-    torch_dtype=torch.float16,
+    base_path,
+    unet=unet,
+    vae=vae,
+    feature_extractor= CLIPImageProcessor(),
+    text_encoder = text_encoder_one,
+    text_encoder_2 = text_encoder_two,
+    tokenizer = tokenizer_one,
+    tokenizer_2 = tokenizer_two,
+    scheduler = noise_scheduler,
+    image_encoder=image_encoder,
+    torch_dtype=torch.float16,
 )
 pipe.unet_encoder = UNet_Encoder

@@ -119,11 +115,15 @@ def pil_to_binary_mask(pil_image, threshold=0):
     grayscale_image = Image.fromarray(np_image).convert("L")
     binary_mask = np.array(grayscale_image) > threshold
     mask = np.zeros(binary_mask.shape, dtype=np.uint8)
-    mask[binary_mask] = 1
+    for i in range(binary_mask.shape[0]):
+        for j in range(binary_mask.shape[1]):
+            if binary_mask[i, j]:
+                mask[i, j] = 1
     mask = (mask * 255).astype(np.uint8)
     output_mask = Image.fromarray(mask)
     return output_mask

+
 def decode_image_from_base64(base64_str):
     try:
         img_data = base64.b64decode(base64_str)
@@ -144,7 +144,7 @@ def encode_image_to_base64(img):
         raise

 @spaces.GPU
-def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denoise_steps, seed, categorie='upper_body'):
+def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denoise_steps, seed, categorie = 'upper_body'):
     device = "cuda"
     openpose_model.preprocessor.body_estimation.model.to(device)
     pipe.to(device)
@@ -169,93 +169,141 @@ def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denois

     if is_checked:
         keypoints = openpose_model(human_img.resize((384, 512)))
-        model_parse, _ = parsing_model(human_img)
-        return model_parse
+        model_parse, _ = parsing_model(human_img.resize((384, 512)))
+        mask, mask_gray = get_mask_location('hd', categorie , model_parse, keypoints)
+        mask = mask.resize((768, 1024))
+    else:
+        mask = pil_to_binary_mask(dict['layers'][0].convert("RGB").resize((768, 1024)))
+    mask_gray = (1 - transforms.ToTensor()(mask)) * tensor_transfrom(human_img)
+    mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)

-    # Generate mask for the garment
-    garment_mask = pil_to_binary_mask(garm_img)
-    garment_mask = garment_mask.resize((768, 1024))
-    garment_mask = torch.from_numpy(np.array(garment_mask)).float().unsqueeze(0).to(device) / 255.0
-    garment_mask = torch.cat([garment_mask, garment_mask, garment_mask], dim=1)
+    human_img_arg = _apply_exif_orientation(human_img.resize((384, 512)))
+    human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")

-    # Process human image and garment image
-    try:
-        img_human = tensor_transfrom(human_img)
-        img_garment = tensor_transfrom(garm_img)
-    except Exception as e:
-        logging.error(f"Error processing images: {e}")
-        raise
+    args = apply_net.create_argument_parser().parse_args(('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', 'cuda'))
+    pose_img = args.func(args, human_img_arg)
+    pose_img = pose_img[:, :, ::-1]
+    pose_img = Image.fromarray(pose_img).resize((768, 1024))

-    # Generate image
-    try:
-        result = pipe(
-            image_human=img_human.unsqueeze(0).to(device),
-            image_garment=img_garment.unsqueeze(0).to(device),
-            garment_mask=garment_mask,
-            denoise_steps=denoise_steps,
-            seed=seed,
-        ).images[0]
-        return result
-    except Exception as e:
-        logging.error(f"Error generating image: {e}")
-        raise
+    with torch.no_grad():
+        with torch.cuda.amp.autocast():
+            prompt = "model is wearing " + garment_des
+            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+            with torch.inference_mode():
+                (
+                    prompt_embeds,
+                    negative_prompt_embeds,
+                    pooled_prompt_embeds,
+                    negative_pooled_prompt_embeds,
+                ) = pipe.encode_prompt(
+                    prompt,
+                    num_images_per_prompt=1,
+                    do_classifier_free_guidance=True,
+                    negative_prompt=negative_prompt,
+                )
+
+            prompt = "a photo of " + garment_des
+            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+            if not isinstance(prompt, list):
+                prompt = [prompt] * 1
+            if not isinstance(negative_prompt, list):
+                negative_prompt = [negative_prompt] * 1
+            with torch.inference_mode():
+                (
+                    prompt_embeds_c,
+                    _,
+                    _,
+                    _,
+                ) = pipe.encode_prompt(
+                    prompt,
+                    num_images_per_prompt=1,
+                    do_classifier_free_guidance=False,
+                    negative_prompt=negative_prompt,
+                )

-def combine_images_with_masks(img_top, img_bottom, mask_top, mask_bottom):
-    img_top_pil = Image.open(io.BytesIO(img_top)).convert("RGBA")
-    img_bottom_pil = Image.open(io.BytesIO(img_bottom)).convert("RGBA")
-    mask_top_pil = Image.open(io.BytesIO(mask_top)).convert("L")
-    mask_bottom_pil = Image.open(io.BytesIO(mask_bottom)).convert("L")
+            pose_img = tensor_transfrom(pose_img).unsqueeze(0).to(device, torch.float16)
+            garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device, torch.float16)
+            generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
+            images = pipe(
+                prompt_embeds=prompt_embeds.to(device, torch.float16),
+                negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
+                pooled_prompt_embeds=pooled_prompt_embeds.to(device, torch.float16),
+                negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device, torch.float16),
+                num_inference_steps=denoise_steps,
+                generator=generator,
+                strength=1.0,
+                pose_img=pose_img.to(device, torch.float16),
+                text_embeds_cloth=prompt_embeds_c.to(device, torch.float16),
+                cloth=garm_tensor.to(device, torch.float16),
+                mask_image=mask,
+                image=human_img,
+                height=1024,
+                width=768,
+                ip_adapter_image=garm_img.resize((768, 1024)),
+                guidance_scale=2.0,
+            )[0]

-    img_top_pil = img_top_pil.resize((768, 1024))
-    img_bottom_pil = img_bottom_pil.resize((768, 1024))
-    mask_top_pil = mask_top_pil.resize((768, 1024))
-    mask_bottom_pil = mask_bottom_pil.resize((768, 1024))
+    if is_checked_crop:
+        out_img = images[0].resize(crop_size)
+        human_img_orig.paste(out_img, (int(left), int(top)))
+        return human_img_orig, mask_gray
+    else:
+        return images[0], mask_gray

-    img_top_pil.paste(img_bottom_pil, (0, 0), mask_bottom_pil)
-    img_top_pil.paste(img_top_pil, (0, 0), mask_top_pil)

-    with BytesIO() as output:
-        img_top_pil.save(output, format="PNG")
-        combined_img_str = base64.b64encode(output.getvalue()).decode("utf-8")
-        return combined_img_str
+def clear_gpu_memory():
+    torch.cuda.empty_cache()
+    torch.cuda.synchronize()

 @app.route('/tryon', methods=['POST'])
 def tryon():
+    data = request.json
+    human_image = decode_image_from_base64(data['human_image'])
+    garment_image = decode_image_from_base64(data['garment_image'])
+    description = data.get('description')
+    use_auto_mask = data.get('use_auto_mask', True)
+    use_auto_crop = data.get('use_auto_crop', False)
+    denoise_steps = int(data.get('denoise_steps', 30))
+    seed = int(data.get('seed', 42))
+    categorie = data.get('categorie' , 'upper_body')
+    human_dict = {
+        'background': human_image,
+        'layers': [human_image] if not use_auto_mask else None,
+        'composite': None
+    }
+    clear_gpu_memory()
+
+    output_image, mask_image = start_tryon(human_dict, garment_image, description, use_auto_mask, use_auto_crop, denoise_steps, seed , categorie)
+
+    output_base64 = encode_image_to_base64(output_image)
+    mask_base64 = encode_image_to_base64(mask_image)
+
+    return jsonify({
+        'output_image': output_base64,
+        'mask_image': mask_base64
+    })
+
+
+def combine_images_with_masks(tops_image, bottoms_image, mask, is_checked_crop, crop_size):
     try:
-        data = request.json
-        image_base64 = data.get('human_image')
-        garm_img_base64 = data.get('garment_image')
-        is_checked = data.get('is_checked', False)
-        is_checked_crop = data.get('is_checked_crop', False)
-        denoise_steps = data.get('denoise_steps', 20)
-        seed = data.get('seed', 42)
-        category = data.get('category', 'upper_body')
-
-        human_img = decode_image_from_base64(image_base64)
-        garm_img = decode_image_from_base64(garm_img_base64)
-
-        combined_image = start_tryon(
-            dict={"background": human_img},
-            garm_img=garm_img,
-            garment_des="",
-            is_checked=is_checked,
-            is_checked_crop=is_checked_crop,
-            denoise_steps=denoise_steps,
-            seed=seed,
-            categorie=category,
-        )
-        result = combine_images_with_masks(combined_image, mask)
-
-        return jsonify({
-            'image': result
-        })
+        # Logic for combining the top and bottom images
+        if is_checked_crop:
+            tops_image = tops_image.resize(crop_size)
+            bottoms_image = bottoms_image.resize(crop_size)
+            combined_image = Image.new('RGB', (tops_image.width, tops_image.height))
+            combined_image.paste(tops_image, (0, 0))
+            combined_image.paste(bottoms_image, (0, tops_image.height // 2))
+        else:
+            combined_image = Image.new('RGB', (tops_image.width, tops_image.height))
+            combined_image.paste(tops_image, (0, 0))
+            combined_image.paste(bottoms_image, (0, tops_image.height // 2))
+
+        return combined_image
+
     except Exception as e:
-        logging.error(f"Error in tryon route: {e}")
-        return jsonify({'error': str(e)}), 500
+        raise ValueError(f"Error combining images with masks: {e}")

-@app.route('/', methods=['GET'])
-def welcome():
-    return "Welcome to IDM-VTON API"

 if __name__ == "__main__":
     app.run(debug=True, host="0.0.0.0", port=7860)
+
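
For reference, a minimal client sketch for the /tryon route introduced in this commit. The JSON field names ('human_image', 'garment_image', 'description', 'use_auto_mask', 'use_auto_crop', 'denoise_steps', 'seed', 'categorie'), the response keys ('output_image', 'mask_image'), and the port 7860 are taken from the route above; the host URL, file paths, and garment description are placeholder assumptions, not part of the commit.

# Hypothetical client for the /tryon endpoint; URL and file paths are assumptions.
import base64
from io import BytesIO

import requests
from PIL import Image


def encode_file_to_base64(path):
    # Read a local image file and return its base64 string for the JSON payload.
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")


payload = {
    "human_image": encode_file_to_base64("person.jpg"),      # placeholder file
    "garment_image": encode_file_to_base64("garment.jpg"),   # placeholder file
    "description": "a red cotton t-shirt",
    "use_auto_mask": True,
    "use_auto_crop": False,
    "denoise_steps": 30,
    "seed": 42,
    "categorie": "upper_body",
}

# The Space serves Flask on port 7860 (see app.run above); generation can take a while.
response = requests.post("http://localhost:7860/tryon", json=payload, timeout=600)
response.raise_for_status()
result = response.json()

# Both response fields are base64-encoded images produced by encode_image_to_base64.
output_image = Image.open(BytesIO(base64.b64decode(result["output_image"])))
mask_image = Image.open(BytesIO(base64.b64decode(result["mask_image"])))
output_image.save("tryon_output.png")
mask_image.save("tryon_mask.png")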