Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import os
|
2 |
from flask import Flask, request, jsonify
|
3 |
-
from PIL import Image
|
4 |
from io import BytesIO
|
5 |
import torch
|
6 |
import base64
|
@@ -283,72 +283,36 @@ def tryon():
|
|
283 |
'mask_image': mask_base64
|
284 |
})
|
285 |
|
286 |
-
|
287 |
-
if mask_upper.size != mask_lower.size:
|
288 |
-
mask_lower = mask_lower.resize(mask_upper.size)
|
289 |
-
|
290 |
-
# Ensure masks are in 'L' mode
|
291 |
-
mask_upper = mask_upper.convert("L")
|
292 |
-
mask_lower = mask_lower.convert("L")
|
293 |
-
|
294 |
-
# Combine the two masks with logical OR
|
295 |
-
combined_mask = ImageChops.logical_or(mask_upper, mask_lower)
|
296 |
-
return combined_mask
|
297 |
-
|
298 |
-
@spaces.GPU
|
299 |
def generate_mask(human_img, categorie='upper_body'):
|
300 |
-
device = "cuda"
|
301 |
-
openpose_model.preprocessor.body_estimation.model.to(device)
|
302 |
-
pipe.to(device)
|
303 |
-
|
304 |
try:
|
|
|
|
|
305 |
# Redimensionner l'image pour le modèle
|
306 |
-
human_img_resized = human_img.
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
mask_upper_body, _ = get_mask_location('hd', 'upper_body', model_parse, keypoints)
|
324 |
-
mask_lower_body, _ = get_mask_location('hd', 'lower_body', model_parse, keypoints)
|
325 |
-
|
326 |
-
# Combiner les deux masques
|
327 |
-
combined_mask = combine_masks(mask_upper_body, mask_lower_body)
|
328 |
-
else:
|
329 |
-
# Générer le masque pour la catégorie spécifiée
|
330 |
-
keypoints = openpose_model(human_img_resized)
|
331 |
-
model_parse, _ = parsing_model(human_img_resized)
|
332 |
-
|
333 |
-
# Ajouter des logs pour vérifier la sortie
|
334 |
-
logging.info(f"model_parse type: {type(model_parse)}")
|
335 |
-
logging.info(f"keypoints type: {type(keypoints)}")
|
336 |
-
|
337 |
-
# Vérifiez si `parsing_model` renvoie des données correctes
|
338 |
-
if not isinstance(model_parse, dict):
|
339 |
-
raise TypeError("parsing_model must return a dictionary.")
|
340 |
-
if not isinstance(keypoints, dict):
|
341 |
-
raise TypeError("openpose_model must return a dictionary.")
|
342 |
-
|
343 |
-
combined_mask, _ = get_mask_location('hd', categorie, model_parse, keypoints)
|
344 |
-
|
345 |
# Redimensionner le masque à la taille d'origine de l'image
|
346 |
-
mask_resized =
|
347 |
|
348 |
return mask_resized
|
349 |
except Exception as e:
|
350 |
logging.error(f"Error generating mask: {e}")
|
351 |
-
raise e
|
352 |
|
353 |
|
354 |
@app.route('/generate_mask', methods=['POST'])
|
@@ -356,14 +320,11 @@ def generate_mask_api():
|
|
356 |
try:
|
357 |
# Récupérer les données de l'image à partir de la requête
|
358 |
data = request.json
|
|
|
359 |
categorie = data.get('categorie', 'upper_body')
|
360 |
|
361 |
-
valid_categories = ['upper_body', 'lower_body', 'dresses', 'full_body']
|
362 |
-
if categorie not in valid_categories:
|
363 |
-
raise ValueError(f"Invalid category '{categorie}'. Valid categories are: {', '.join(valid_categories)}")
|
364 |
-
|
365 |
# Décodage de l'image à partir de base64
|
366 |
-
human_img = decode_image_from_base64(
|
367 |
|
368 |
# Appeler la fonction pour générer le masque
|
369 |
mask_resized = generate_mask(human_img, categorie)
|
@@ -381,4 +342,3 @@ def generate_mask_api():
|
|
381 |
|
382 |
if __name__ == "__main__":
|
383 |
app.run(debug=True, host="0.0.0.0", port=7860)
|
384 |
-
|
|
|
1 |
import os
|
2 |
from flask import Flask, request, jsonify
|
3 |
+
from PIL import Image
|
4 |
from io import BytesIO
|
5 |
import torch
|
6 |
import base64
|
|
|
283 |
'mask_image': mask_base64
|
284 |
})
|
285 |
|
286 |
+
@spaces.GPU
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
287 |
def generate_mask(human_img, categorie='upper_body'):
|
|
|
|
|
|
|
|
|
288 |
try:
|
289 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
290 |
+
|
291 |
# Redimensionner l'image pour le modèle
|
292 |
+
human_img_resized = human_img.resize((384, 512))
|
293 |
+
|
294 |
+
# Déplacer les modèles sur le bon appareil
|
295 |
+
openpose_model.to(device)
|
296 |
+
parsing_model.to(device)
|
297 |
+
|
298 |
+
# Convertir l'image en tenseur et déplacer sur l'appareil
|
299 |
+
input_tensor = transforms.ToTensor()(human_img_resized).unsqueeze(0).to(device)
|
300 |
+
|
301 |
+
# Génération des points clés et du masque
|
302 |
+
keypoints = openpose_model(input_tensor)
|
303 |
+
model_parse, _ = parsing_model(input_tensor)
|
304 |
+
mask, _ = get_mask_location('hd', categorie, model_parse, keypoints)
|
305 |
+
|
306 |
+
# Déplacement du masque sur le CPU pour traitement avec PIL
|
307 |
+
mask = mask.cpu()
|
308 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
309 |
# Redimensionner le masque à la taille d'origine de l'image
|
310 |
+
mask_resized = transforms.ToPILImage()(mask.squeeze(0)).resize(human_img.size)
|
311 |
|
312 |
return mask_resized
|
313 |
except Exception as e:
|
314 |
logging.error(f"Error generating mask: {e}")
|
315 |
+
raise e # Renvoyer l'exception pour que l'API puisse la gérer
|
316 |
|
317 |
|
318 |
@app.route('/generate_mask', methods=['POST'])
|
|
|
320 |
try:
|
321 |
# Récupérer les données de l'image à partir de la requête
|
322 |
data = request.json
|
323 |
+
base64_image = data.get('human_image')
|
324 |
categorie = data.get('categorie', 'upper_body')
|
325 |
|
|
|
|
|
|
|
|
|
326 |
# Décodage de l'image à partir de base64
|
327 |
+
human_img = decode_image_from_base64(base64_image).convert("RGB")
|
328 |
|
329 |
# Appeler la fonction pour générer le masque
|
330 |
mask_resized = generate_mask(human_img, categorie)
|
|
|
342 |
|
343 |
if __name__ == "__main__":
|
344 |
app.run(debug=True, host="0.0.0.0", port=7860)
|
|