Spaces:
Saad0KH
/
Running on Zero

Saad0KH commited on
Commit
a37ae3f
·
verified ·
1 Parent(s): 0114ce2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -21
app.py CHANGED
@@ -285,34 +285,26 @@ def tryon():
285
 
286
  @spaces.GPU
287
  def generate_mask(human_img, categorie='upper_body'):
 
 
 
 
288
  try:
289
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
290
-
291
  # Redimensionner l'image pour le modèle
292
- human_img_resized = human_img.resize((384, 512))
293
-
294
- # Déplacer les modèles sur le bon appareil
295
- openpose_model.to(device)
296
- parsing_model.to(device)
297
-
298
- # Convertir l'image en tenseur et déplacer sur l'appareil
299
- input_tensor = transforms.ToTensor()(human_img_resized).unsqueeze(0).to(device)
300
-
301
- # Génération des points clés et du masque
302
- keypoints = openpose_model(input_tensor)
303
- model_parse, _ = parsing_model(input_tensor)
304
  mask, _ = get_mask_location('hd', categorie, model_parse, keypoints)
305
-
306
- # Déplacement du masque sur le CPU pour traitement avec PIL
307
- mask = mask.cpu()
308
-
309
  # Redimensionner le masque à la taille d'origine de l'image
310
- mask_resized = transforms.ToPILImage()(mask.squeeze(0)).resize(human_img.size)
311
 
312
  return mask_resized
313
  except Exception as e:
314
  logging.error(f"Error generating mask: {e}")
315
- raise e # Renvoyer l'exception pour que l'API puisse la gérer
316
 
317
 
318
  @app.route('/generate_mask', methods=['POST'])
@@ -324,7 +316,7 @@ def generate_mask_api():
324
  categorie = data.get('categorie', 'upper_body')
325
 
326
  # Décodage de l'image à partir de base64
327
- human_img = decode_image_from_base64(base64_image).convert("RGB")
328
 
329
  # Appeler la fonction pour générer le masque
330
  mask_resized = generate_mask(human_img, categorie)
 
285
 
286
  @spaces.GPU
287
  def generate_mask(human_img, categorie='upper_body'):
288
+ device = "cuda"
289
+ openpose_model.preprocessor.body_estimation.model.to(device)
290
+ pipe.to(device)
291
+
292
  try:
 
 
293
  # Redimensionner l'image pour le modèle
294
+ human_img_resized = human_img.convert("RGB").resize((384, 512))
295
+
296
+ # Générer les points clés et le masque
297
+ keypoints = openpose_model(human_img_resized)
298
+ model_parse, _ = parsing_model(human_img_resized)
 
 
 
 
 
 
 
299
  mask, _ = get_mask_location('hd', categorie, model_parse, keypoints)
300
+
 
 
 
301
  # Redimensionner le masque à la taille d'origine de l'image
302
+ mask_resized = mask.resize(human_img.size)
303
 
304
  return mask_resized
305
  except Exception as e:
306
  logging.error(f"Error generating mask: {e}")
307
+ raise e
308
 
309
 
310
  @app.route('/generate_mask', methods=['POST'])
 
316
  categorie = data.get('categorie', 'upper_body')
317
 
318
  # Décodage de l'image à partir de base64
319
+ human_img = decode_image_from_base64(base64_image)
320
 
321
  # Appeler la fonction pour générer le masque
322
  mask_resized = generate_mask(human_img, categorie)