Spaces:
Saad0KH
/
Running on Zero

Saad0KH commited on
Commit
0114ce2
·
verified ·
1 Parent(s): 1035629

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -65
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
2
  from flask import Flask, request, jsonify
3
- from PIL import Image , ImageChops
4
  from io import BytesIO
5
  import torch
6
  import base64
@@ -283,72 +283,36 @@ def tryon():
283
  'mask_image': mask_base64
284
  })
285
 
286
- def combine_masks(mask_upper, mask_lower):
287
- if mask_upper.size != mask_lower.size:
288
- mask_lower = mask_lower.resize(mask_upper.size)
289
-
290
- # Ensure masks are in 'L' mode
291
- mask_upper = mask_upper.convert("L")
292
- mask_lower = mask_lower.convert("L")
293
-
294
- # Combine the two masks with logical OR
295
- combined_mask = ImageChops.logical_or(mask_upper, mask_lower)
296
- return combined_mask
297
-
298
- @spaces.GPU
299
  def generate_mask(human_img, categorie='upper_body'):
300
- device = "cuda"
301
- openpose_model.preprocessor.body_estimation.model.to(device)
302
- pipe.to(device)
303
-
304
  try:
 
 
305
  # Redimensionner l'image pour le modèle
306
- human_img_resized = human_img.convert("RGB").resize((384, 512))
307
-
308
- if categorie == 'full_body':
309
- # Générer les masques pour 'upper_body' et 'lower_body'
310
- model_parse, _ = parsing_model(human_img_resized)
311
- keypoints = openpose_model(human_img_resized)
312
-
313
- # Ajouter des logs pour vérifier la sortie
314
- logging.info(f"model_parse type: {type(model_parse)}")
315
- logging.info(f"keypoints type: {type(keypoints)}")
316
-
317
- # Vérifiez si `parsing_model` renvoie des données correctes
318
- if not isinstance(model_parse, dict):
319
- raise TypeError("parsing_model must return a dictionary.")
320
- if not isinstance(keypoints, dict):
321
- raise TypeError("openpose_model must return a dictionary.")
322
-
323
- mask_upper_body, _ = get_mask_location('hd', 'upper_body', model_parse, keypoints)
324
- mask_lower_body, _ = get_mask_location('hd', 'lower_body', model_parse, keypoints)
325
-
326
- # Combiner les deux masques
327
- combined_mask = combine_masks(mask_upper_body, mask_lower_body)
328
- else:
329
- # Générer le masque pour la catégorie spécifiée
330
- keypoints = openpose_model(human_img_resized)
331
- model_parse, _ = parsing_model(human_img_resized)
332
-
333
- # Ajouter des logs pour vérifier la sortie
334
- logging.info(f"model_parse type: {type(model_parse)}")
335
- logging.info(f"keypoints type: {type(keypoints)}")
336
-
337
- # Vérifiez si `parsing_model` renvoie des données correctes
338
- if not isinstance(model_parse, dict):
339
- raise TypeError("parsing_model must return a dictionary.")
340
- if not isinstance(keypoints, dict):
341
- raise TypeError("openpose_model must return a dictionary.")
342
-
343
- combined_mask, _ = get_mask_location('hd', categorie, model_parse, keypoints)
344
-
345
  # Redimensionner le masque à la taille d'origine de l'image
346
- mask_resized = combined_mask.resize(human_img.size)
347
 
348
  return mask_resized
349
  except Exception as e:
350
  logging.error(f"Error generating mask: {e}")
351
- raise e
352
 
353
 
354
  @app.route('/generate_mask', methods=['POST'])
@@ -356,14 +320,11 @@ def generate_mask_api():
356
  try:
357
  # Récupérer les données de l'image à partir de la requête
358
  data = request.json
 
359
  categorie = data.get('categorie', 'upper_body')
360
 
361
- valid_categories = ['upper_body', 'lower_body', 'dresses', 'full_body']
362
- if categorie not in valid_categories:
363
- raise ValueError(f"Invalid category '{categorie}'. Valid categories are: {', '.join(valid_categories)}")
364
-
365
  # Décodage de l'image à partir de base64
366
- human_img = decode_image_from_base64(data['human_image'])
367
 
368
  # Appeler la fonction pour générer le masque
369
  mask_resized = generate_mask(human_img, categorie)
@@ -381,4 +342,3 @@ def generate_mask_api():
381
 
382
  if __name__ == "__main__":
383
  app.run(debug=True, host="0.0.0.0", port=7860)
384
-
 
1
  import os
2
  from flask import Flask, request, jsonify
3
+ from PIL import Image
4
  from io import BytesIO
5
  import torch
6
  import base64
 
283
  'mask_image': mask_base64
284
  })
285
 
286
+ @spaces.GPU
 
 
 
 
 
 
 
 
 
 
 
 
287
  def generate_mask(human_img, categorie='upper_body'):
 
 
 
 
288
  try:
289
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
290
+
291
  # Redimensionner l'image pour le modèle
292
+ human_img_resized = human_img.resize((384, 512))
293
+
294
+ # Déplacer les modèles sur le bon appareil
295
+ openpose_model.to(device)
296
+ parsing_model.to(device)
297
+
298
+ # Convertir l'image en tenseur et déplacer sur l'appareil
299
+ input_tensor = transforms.ToTensor()(human_img_resized).unsqueeze(0).to(device)
300
+
301
+ # Génération des points clés et du masque
302
+ keypoints = openpose_model(input_tensor)
303
+ model_parse, _ = parsing_model(input_tensor)
304
+ mask, _ = get_mask_location('hd', categorie, model_parse, keypoints)
305
+
306
+ # Déplacement du masque sur le CPU pour traitement avec PIL
307
+ mask = mask.cpu()
308
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
  # Redimensionner le masque à la taille d'origine de l'image
310
+ mask_resized = transforms.ToPILImage()(mask.squeeze(0)).resize(human_img.size)
311
 
312
  return mask_resized
313
  except Exception as e:
314
  logging.error(f"Error generating mask: {e}")
315
+ raise e # Renvoyer l'exception pour que l'API puisse la gérer
316
 
317
 
318
  @app.route('/generate_mask', methods=['POST'])
 
320
  try:
321
  # Récupérer les données de l'image à partir de la requête
322
  data = request.json
323
+ base64_image = data.get('human_image')
324
  categorie = data.get('categorie', 'upper_body')
325
 
 
 
 
 
326
  # Décodage de l'image à partir de base64
327
+ human_img = decode_image_from_base64(base64_image).convert("RGB")
328
 
329
  # Appeler la fonction pour générer le masque
330
  mask_resized = generate_mask(human_img, categorie)
 
342
 
343
  if __name__ == "__main__":
344
  app.run(debug=True, host="0.0.0.0", port=7860)