Update app.py
Browse files
app.py
CHANGED
@@ -285,34 +285,26 @@ def tryon():
|
|
285 |
|
286 |
@spaces.GPU
|
287 |
def generate_mask(human_img, categorie='upper_body'):
|
|
|
|
|
|
|
|
|
288 |
try:
|
289 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
290 |
-
|
291 |
# Redimensionner l'image pour le modèle
|
292 |
-
human_img_resized = human_img.resize((384, 512))
|
293 |
-
|
294 |
-
#
|
295 |
-
openpose_model
|
296 |
-
parsing_model
|
297 |
-
|
298 |
-
# Convertir l'image en tenseur et déplacer sur l'appareil
|
299 |
-
input_tensor = transforms.ToTensor()(human_img_resized).unsqueeze(0).to(device)
|
300 |
-
|
301 |
-
# Génération des points clés et du masque
|
302 |
-
keypoints = openpose_model(input_tensor)
|
303 |
-
model_parse, _ = parsing_model(input_tensor)
|
304 |
mask, _ = get_mask_location('hd', categorie, model_parse, keypoints)
|
305 |
-
|
306 |
-
# Déplacement du masque sur le CPU pour traitement avec PIL
|
307 |
-
mask = mask.cpu()
|
308 |
-
|
309 |
# Redimensionner le masque à la taille d'origine de l'image
|
310 |
-
mask_resized =
|
311 |
|
312 |
return mask_resized
|
313 |
except Exception as e:
|
314 |
logging.error(f"Error generating mask: {e}")
|
315 |
-
raise e
|
316 |
|
317 |
|
318 |
@app.route('/generate_mask', methods=['POST'])
|
@@ -324,7 +316,7 @@ def generate_mask_api():
|
|
324 |
categorie = data.get('categorie', 'upper_body')
|
325 |
|
326 |
# Décodage de l'image à partir de base64
|
327 |
-
human_img = decode_image_from_base64(base64_image)
|
328 |
|
329 |
# Appeler la fonction pour générer le masque
|
330 |
mask_resized = generate_mask(human_img, categorie)
|
|
|
285 |
|
286 |
@spaces.GPU
|
287 |
def generate_mask(human_img, categorie='upper_body'):
|
288 |
+
device = "cuda"
|
289 |
+
openpose_model.preprocessor.body_estimation.model.to(device)
|
290 |
+
pipe.to(device)
|
291 |
+
|
292 |
try:
|
|
|
|
|
293 |
# Redimensionner l'image pour le modèle
|
294 |
+
human_img_resized = human_img.convert("RGB").resize((384, 512))
|
295 |
+
|
296 |
+
# Générer les points clés et le masque
|
297 |
+
keypoints = openpose_model(human_img_resized)
|
298 |
+
model_parse, _ = parsing_model(human_img_resized)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
299 |
mask, _ = get_mask_location('hd', categorie, model_parse, keypoints)
|
300 |
+
|
|
|
|
|
|
|
301 |
# Redimensionner le masque à la taille d'origine de l'image
|
302 |
+
mask_resized = mask.resize(human_img.size)
|
303 |
|
304 |
return mask_resized
|
305 |
except Exception as e:
|
306 |
logging.error(f"Error generating mask: {e}")
|
307 |
+
raise e
|
308 |
|
309 |
|
310 |
@app.route('/generate_mask', methods=['POST'])
|
|
|
316 |
categorie = data.get('categorie', 'upper_body')
|
317 |
|
318 |
# Décodage de l'image à partir de base64
|
319 |
+
human_img = decode_image_from_base64(base64_image)
|
320 |
|
321 |
# Appeler la fonction pour générer le masque
|
322 |
mask_resized = generate_mask(human_img, categorie)
|