Spaces: Saad0KH · Running on Zero

Saad0KH committed · Commit 227771d · verified · 1 Parent(s): 5183562

Update app.py

Files changed (1): app.py (+26 -65)
app.py CHANGED
@@ -36,59 +36,19 @@ app = Flask(__name__)
 base_path = 'yisol/IDM-VTON'
 example_path = os.path.join(os.path.dirname(__file__), 'example')
 
-unet = UNet2DConditionModel.from_pretrained(
-    base_path,
-    subfolder="unet",
-    torch_dtype=torch.float16,
-    force_download=False
-)
-unet.requires_grad_(False)
-tokenizer_one = AutoTokenizer.from_pretrained(
-    base_path,
-    subfolder="tokenizer",
-    revision=None,
-    use_fast=False,
-    force_download=False
-)
-tokenizer_two = AutoTokenizer.from_pretrained(
-    base_path,
-    subfolder="tokenizer_2",
-    revision=None,
-    use_fast=False,
-    force_download=False
-)
-noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
-
-text_encoder_one = CLIPTextModel.from_pretrained(
-    base_path,
-    subfolder="text_encoder",
-    torch_dtype=torch.float16,
-    force_download=False
-)
-text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
-    base_path,
-    subfolder="text_encoder_2",
-    torch_dtype=torch.float16,
-    force_download=False
-)
-image_encoder = CLIPVisionModelWithProjection.from_pretrained(
-    base_path,
-    subfolder="image_encoder",
-    torch_dtype=torch.float16,
-    force_download=False
-)
-vae = AutoencoderKL.from_pretrained(base_path,
-    subfolder="vae",
-    torch_dtype=torch.float16,
-    force_download=False
-)
-
-UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
-    base_path,
-    subfolder="unet_encoder",
-    torch_dtype=torch.float16,
-    force_download=False
-)
+def load_models():
+    global unet, tokenizer_one, tokenizer_two, noise_scheduler, text_encoder_one, text_encoder_two, image_encoder, vae, UNet_Encoder
+    unet = UNet2DConditionModel.from_pretrained(base_path, subfolder="unet", torch_dtype=torch.float16, force_download=False)
+    tokenizer_one = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer", use_fast=False, force_download=False)
+    tokenizer_two = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer_2", use_fast=False, force_download=False)
+    noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
+    text_encoder_one = CLIPTextModel.from_pretrained(base_path, subfolder="text_encoder", torch_dtype=torch.float16, force_download=False)
+    text_encoder_two = CLIPTextModelWithProjection.from_pretrained(base_path, subfolder="text_encoder_2", torch_dtype=torch.float16, force_download=False)
+    image_encoder = CLIPVisionModelWithProjection.from_pretrained(base_path, subfolder="image_encoder", torch_dtype=torch.float16, force_download=False)
+    vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae", torch_dtype=torch.float16, force_download=False)
+    UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(base_path, subfolder="unet_encoder", torch_dtype=torch.float16, force_download=False)
+
+load_models()
 
 parsing_model = Parsing(0)
 openpose_model = OpenPose(0)
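This hunk only moves the component loading into load_models(); the pipeline that consumes these globals sits outside the diff. For orientation, a minimal sketch of how the upstream yisol/IDM-VTON demo wires the loaded components together; TryonPipeline and the unet_encoder attribute are assumptions taken from that repo, not shown in this commit:

# Sketch, assuming the upstream IDM-VTON demo's TryonPipeline; not part of this diff.
pipe = TryonPipeline.from_pretrained(
    base_path,
    unet=unet,
    vae=vae,
    text_encoder=text_encoder_one,
    text_encoder_2=text_encoder_two,
    tokenizer=tokenizer_one,
    tokenizer_2=tokenizer_two,
    scheduler=noise_scheduler,
    image_encoder=image_encoder,
    torch_dtype=torch.float16,
)
pipe.unet_encoder = UNet_Encoder  # reference UNet later used by start_tryon()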
@@ -169,6 +129,12 @@ def save_image(img):
     img.save(unique_name, format="WEBP", lossless=True)
     return unique_name
 
+
+def clear_gpu_memory():
+    torch.cuda.empty_cache()
+    torch.cuda.synchronize()
+
+
 @spaces.GPU
 def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denoise_steps, seed, categorie = 'upper_body'):
     device = "cuda"
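The new clear_gpu_memory() helper pairs torch.cuda.empty_cache(), which returns unused cached allocator blocks to the driver, with torch.cuda.synchronize(), which waits for queued kernels to finish. A self-contained sketch of observing its effect (illustrative only, not from the commit):

import torch

def clear_gpu_memory():
    torch.cuda.empty_cache()   # release cached, unused blocks back to the driver
    torch.cuda.synchronize()   # block until all queued CUDA work has finished

x = torch.randn(4096, 4096, device="cuda")
del x                          # tensor is freed, but its block stays in the cache
before = torch.cuda.memory_reserved()
clear_gpu_memory()
print(f"reserved bytes: {before} -> {torch.cuda.memory_reserved()}")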
@@ -176,7 +142,7 @@ def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denois
     pipe.to(device)
     pipe.unet_encoder.to(device)
 
-    garm_img = garm_img.convert("RGB").resize((768, 1024))
+    garm_img = garm_img.convert("RGB").resize((512, 768))
     human_img_orig = dict["background"].convert("RGB")
 
     if is_checked_crop:
@@ -189,9 +155,9 @@ def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denois
         bottom = (height + target_height) / 2
         cropped_img = human_img_orig.crop((left, top, right, bottom))
         crop_size = cropped_img.size
-        human_img = cropped_img.resize((768, 1024))
+        human_img = cropped_img.resize((512, 768))
     else:
-        human_img = human_img_orig.resize((768, 1024))
+        human_img = human_img_orig.resize((512, 768))
 
     if is_checked:
         keypoints = openpose_model(human_img.resize((384, 512)))
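The crop branch above keeps the largest centered 3:4 window of the photo (left/top/right/bottom are presumably computed from a 3:4 target earlier in the function, as in the upstream demo). Note that the new 512x768 target is 2:3, so the resize after a 3:4 crop now stretches slightly. A self-contained sketch of the crop logic (the helper name is ours):

from PIL import Image

def center_crop_3_4(img: Image.Image) -> Image.Image:
    # Largest centered 3:4 window, mirroring the crop branch in start_tryon().
    width, height = img.size
    target_width = min(width, int(height * 3 / 4))
    target_height = min(height, int(width * 4 / 3))
    left = (width - target_width) // 2
    top = (height - target_height) // 2
    return img.crop((left, top, left + target_width, top + target_height))

human_img = center_crop_3_4(Image.open("person.jpg")).resize((512, 768))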
@@ -199,7 +165,7 @@ def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denois
         mask, mask_gray = get_mask_location('hd', categorie , model_parse, keypoints)
         mask = mask.resize((768, 1024))
     else:
-        mask = pil_to_binary_mask(dict['layers'][0].convert("RGB").resize((768, 1024)))
+        mask = pil_to_binary_mask(dict['layers'][0].convert("RGB").resize((512, 768)))
     mask_gray = (1 - transforms.ToTensor()(mask)) * tensor_transfrom(human_img)
     mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)
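tensor_transfrom (the misspelling is the source's own identifier) is defined outside this hunk; in the upstream demo it normalizes the human image to [-1, 1], which is why mask_gray is mapped back through (mask_gray + 1.0) / 2.0 before display. Its assumed definition:

# Assumed from the upstream IDM-VTON demo; not shown in this diff.
tensor_transfrom = transforms.Compose([
    transforms.ToTensor(),                # PIL image -> float tensor in [0, 1]
    transforms.Normalize([0.5], [0.5]),   # [0, 1] -> [-1, 1]
])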
 
@@ -209,7 +175,7 @@ def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denois
     args = apply_net.create_argument_parser().parse_args(('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', 'cuda'))
     pose_img = args.func(args, human_img_arg)
     pose_img = pose_img[:, :, ::-1]
-    pose_img = Image.fromarray(pose_img).resize((768, 1024))
+    pose_img = Image.fromarray(pose_img).resize((512, 768))
 
     with torch.no_grad():
         with torch.cuda.amp.autocast():
@@ -265,10 +231,10 @@ def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denois
                     image=human_img,
                     height=1024,
                     width=768,
-                    ip_adapter_image=garm_img.resize((768, 1024)),
+                    ip_adapter_image=garm_img.resize((512, 768)),
                     guidance_scale=2.0,
                 )[0]
-
+    clear_gpu_memory()
     if is_checked_crop:
         out_img = images[0].resize(crop_size)
         human_img_orig.paste(out_img, (int(left), int(top)))
@@ -277,10 +243,6 @@ def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denois
     return images[0], mask_gray
 
 
-def clear_gpu_memory():
-    torch.cuda.empty_cache()
-    torch.cuda.synchronize()
-
 def process_image(image_data):
     # Check whether the image is base64-encoded or a URL
     if image_data.startswith('http://') or image_data.startswith('https://'):
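process_image() dispatches on whether image_data is a URL or a base64 payload; its body is truncated in this view. A plausible completion under that reading (the requests/base64 handling is our sketch, not the commit's code):

import base64
import io
import requests
from PIL import Image

def process_image(image_data):
    # Check whether the image is base64-encoded or a URL
    if image_data.startswith('http://') or image_data.startswith('https://'):
        response = requests.get(image_data, timeout=30)
        response.raise_for_status()
        return Image.open(io.BytesIO(response.content))
    # Otherwise treat it as base64, tolerating a data:image/...;base64, prefix.
    if ',' in image_data:
        image_data = image_data.split(',', 1)[1]
    return Image.open(io.BytesIO(base64.b64decode(image_data)))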
@@ -304,7 +266,6 @@ def tryon():
         'layers': [human_image] if not use_auto_mask else None,
         'composite': None
     }
-    clear_gpu_memory()
 
     output_image, mask_image = start_tryon(human_dict, garment_image, description, use_auto_mask, use_auto_crop, denoise_steps, seed , categorie)
 
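For reference, a hypothetical client call to the tryon() route; the path, port, and the JSON field names for the two images are assumptions (only description, use_auto_mask, use_auto_crop, denoise_steps, seed, and categorie are visible in this diff):

import requests

payload = {
    "human_image": "https://example.com/person.jpg",    # URL or base64, per process_image()
    "garment_image": "https://example.com/garment.png",
    "description": "short sleeve round neck t-shirt",
    "use_auto_mask": True,
    "use_auto_crop": False,
    "denoise_steps": 30,
    "seed": 42,
    "categorie": "upper_body",
}
resp = requests.post("http://localhost:7860/tryon", json=payload)  # route path assumed
print(resp.status_code)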