Spaces:
Saad0KH
/
Running on Zero

Saad0KH commited on
Commit
df466fb
·
verified ·
1 Parent(s): 29c157c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +390 -167
app.py CHANGED
@@ -1,13 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
- from flask import Flask, request, jsonify, logging
3
  from PIL import Image
4
  from io import BytesIO
5
  import torch
6
  import base64
7
- import uuid
8
- import random
9
  import logging
 
 
10
  import spaces
 
 
11
  from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
12
  from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
13
  from src.unet_hacked_tryon import UNet2DConditionModel
@@ -21,6 +70,7 @@ from transformers import (
21
  from diffusers import DDPMScheduler, AutoencoderKL
22
  from utils_mask import get_mask_location
23
  from torchvision import transforms
 
24
  from preprocess.humanparsing.run_parsing import Parsing
25
  from preprocess.openpose.run_openpose import OpenPose
26
  from detectron2.data.detection_utils import convert_PIL_to_numpy, _apply_exif_orientation
@@ -28,97 +78,117 @@ from torchvision.transforms.functional import to_pil_image
28
 
29
  app = Flask(__name__)
30
 
31
- # Set base paths
32
  base_path = 'yisol/IDM-VTON'
 
33
 
34
- # Load models
35
- def load_models():
36
- global unet, tokenizer_one, tokenizer_two, noise_scheduler, text_encoder_one, text_encoder_two, image_encoder, vae, UNet_Encoder, pipe, parsing_model, openpose_model
37
- try:
38
- unet = UNet2DConditionModel.from_pretrained(
39
- base_path,
40
- subfolder="unet",
41
- torch_dtype=torch.float16,
42
- force_download=False
43
- )
44
- tokenizer_one = AutoTokenizer.from_pretrained(
45
- base_path,
46
- subfolder="tokenizer",
47
- use_fast=False,
48
- force_download=False
49
- )
50
- tokenizer_two = AutoTokenizer.from_pretrained(
51
- base_path,
52
- subfolder="tokenizer_2",
53
- use_fast=False,
54
- force_download=False
55
- )
56
- noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
57
- text_encoder_one = CLIPTextModel.from_pretrained(
58
- base_path,
59
- subfolder="text_encoder",
60
- torch_dtype=torch.float16,
61
- force_download=False
62
- )
63
- text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
64
- base_path,
65
- subfolder="text_encoder_2",
66
- torch_dtype=torch.float16,
67
- force_download=False
68
- )
69
- image_encoder = CLIPVisionModelWithProjection.from_pretrained(
70
- base_path,
71
- subfolder="image_encoder",
72
- torch_dtype=torch.float16,
73
- force_download=False
74
- )
75
- vae = AutoencoderKL.from_pretrained(base_path,
76
- subfolder="vae",
77
- torch_dtype=torch.float16,
78
- force_download=False
79
- )
80
- UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
81
- base_path,
82
- subfolder="unet_encoder",
83
- torch_dtype=torch.float16,
84
- force_download=False
85
- )
86
- parsing_model = Parsing(0)
87
- openpose_model = OpenPose(0)
88
-
89
- # Disable gradients for performance
90
- for model in [unet, text_encoder_one, text_encoder_two, image_encoder, vae, UNet_Encoder]:
91
- model.requires_grad_(False)
92
-
93
- # Initialize pipeline
94
- pipe = TryonPipeline.from_pretrained(
95
- base_path,
96
- unet=unet,
97
- vae=vae,
98
- feature_extractor=CLIPImageProcessor(),
99
- text_encoder=text_encoder_one,
100
- text_encoder_2=text_encoder_two,
101
- tokenizer=tokenizer_one,
102
- tokenizer_2=tokenizer_two,
103
- scheduler=noise_scheduler,
104
- image_encoder=image_encoder,
105
- torch_dtype=torch.float16,
106
- force_download=False
107
- )
108
- pipe.unet_encoder = UNet_Encoder
109
- except Exception as e:
110
- logging.error(f"Error loading models: {e}")
111
- raise
112
 
113
- # Pre-load models during initialization
114
- load_models()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
- # Utility functions
117
  def pil_to_binary_mask(pil_image, threshold=0):
118
- np_image = np.array(pil_image.convert("L")) # Convert to grayscale
119
- binary_mask = np_image > threshold
120
- mask = (binary_mask * 255).astype(np.uint8)
121
- return Image.fromarray(mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
  def decode_image_from_base64(base64_str):
124
  try:
@@ -133,104 +203,257 @@ def encode_image_to_base64(img):
133
  try:
134
  buffered = BytesIO()
135
  img.save(buffered, format="PNG")
136
- return base64.b64encode(buffered.getvalue()).decode("utf-8")
 
137
  except Exception as e:
138
  logging.error(f"Error encoding image: {e}")
139
  raise
140
 
141
- def process_image(image_data):
142
- try:
143
- if image_data.startswith('http://') or image_data.startswith('https://'):
144
- response = requests.get(image_data)
145
- response.raise_for_status()
146
- return Image.open(BytesIO(response.content))
147
- else:
148
- return decode_image_from_base64(image_data)
149
- except Exception as e:
150
- logging.error(f"Error processing image: {e}")
151
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
  def clear_gpu_memory():
154
  torch.cuda.empty_cache()
 
155
 
156
- # Main try-on function
157
- @torch.no_grad()
158
- @spaces.GPU
159
- def start_tryon(human_dict, garment_img, garment_des, use_auto_mask, use_auto_crop, denoise_steps, seed, categorie='upper_body'):
160
- try:
161
- device = torch.device("cuda")
162
- pipe.to(device)
163
- pipe.unet_encoder.to(device)
164
- openpose_model.preprocessor.body_estimation.model.to(device)
165
-
166
- human_img = human_dict["background"].convert("RGB").resize((768, 1024))
167
- garment_img = garment_img.convert("RGB").resize((768, 1024))
168
-
169
- if use_auto_crop:
170
- width, height = human_img.size
171
- target_width = int(min(width, height * (3 / 4)))
172
- target_height = int(min(height, width * (4 / 3)))
173
- left = (width - target_width) // 2
174
- top = (height - target_height) // 2
175
- cropped_img = human_img.crop((left, top, left + target_width, top + target_height))
176
- crop_size = cropped_img.size
177
- human_img = cropped_img.resize((768, 1024))
178
- else:
179
- crop_size = None
180
-
181
- if use_auto_mask:
182
- keypoints = openpose_model(human_img.resize((384, 512)))
183
- model_parse, _ = parsing_model(human_img.resize((384, 512)))
184
- mask, _ = get_mask_location('hd', categorie, model_parse, keypoints)
185
- mask = mask.resize((768, 1024))
186
- else:
187
- mask = pil_to_binary_mask(human_dict['layers'][0].convert("RGB").resize((768, 1024)))
188
-
189
- generator = torch.Generator(device).manual_seed(seed)
190
- images = pipe(
191
- prompt=f"model is wearing {garment_des}",
192
- negative_prompt="monochrome, lowres, bad anatomy, worst quality",
193
- num_inference_steps=denoise_steps,
194
- generator=generator,
195
- guidance_scale=2.0,
196
- image=human_img,
197
- mask_image=mask,
198
- cloth=garment_img,
199
- height=1024,
200
- width=768
201
- )[0]
202
-
203
- if crop_size:
204
- out_img = images[0].resize(crop_size)
205
- human_dict["background"].paste(out_img, (left, top))
206
- return human_dict["background"]
207
- return images[0]
208
-
209
- except Exception as e:
210
- logging.error(f"Error during try-on: {e}")
211
- raise
212
- finally:
213
- print("end generation")
214
- #clear_gpu_memory()
215
 
216
- # API endpoints
217
  @app.route('/tryon', methods=['POST'])
218
  def tryon():
219
  data = request.json
220
  human_image = process_image(data['human_image'])
221
  garment_image = process_image(data['garment_image'])
222
- description = data.get('description', 'garment')
223
  use_auto_mask = data.get('use_auto_mask', True)
224
  use_auto_crop = data.get('use_auto_crop', False)
225
  denoise_steps = int(data.get('denoise_steps', 30))
226
- seed = int(data.get('seed', random.randint(0, 999999)))
227
- categorie = data.get('categorie', 'upper_body')
 
 
 
 
 
 
228
 
229
- human_dict = {'background': human_image, 'layers': [human_image]}
230
- output_image = start_tryon(human_dict, garment_image, description, use_auto_mask, use_auto_crop, denoise_steps, seed, categorie)
231
 
232
  output_base64 = encode_image_to_base64(output_image)
233
- return jsonify({'output_image': output_base64})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
  if __name__ == "__main__":
236
  app.run(debug=False, host="0.0.0.0", port=7860)
 
1
+ Hugging Face's logo
2
+ Hugging Face
3
+ Search models, datasets, users...
4
+ Models
5
+ Datasets
6
+ Spaces
7
+ Posts
8
+ Docs
9
+ Solutions
10
+ Pricing
11
+
12
+
13
+
14
+ Spaces:
15
+
16
+ Saad0KH
17
+ /
18
+ IDM-VTON
19
+
20
+
21
+ like
22
+ 0
23
+
24
+ Logs
25
+ App
26
+ Files
27
+ Community
28
+ Settings
29
+ IDM-VTON
30
+ /
31
+ app.py
32
+
33
+ Saad0KH's picture
34
+ Saad0KH
35
+ Update app.py
36
+ f45631d
37
+ verified
38
+ 2 days ago
39
+ raw
40
+
41
+ Copy download link
42
+ history
43
+ blame
44
+ No virus
45
+
46
+ 14.9 kB
47
  import os
48
+ from flask import Flask, request, jsonify,send_file
49
  from PIL import Image
50
  from io import BytesIO
51
  import torch
52
  import base64
53
+ import io
 
54
  import logging
55
+ import gradio as gr
56
+ import numpy as np
57
  import spaces
58
+ import uuid
59
+ import random
60
  from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
61
  from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
62
  from src.unet_hacked_tryon import UNet2DConditionModel
 
70
  from diffusers import DDPMScheduler, AutoencoderKL
71
  from utils_mask import get_mask_location
72
  from torchvision import transforms
73
+ import apply_net
74
  from preprocess.humanparsing.run_parsing import Parsing
75
  from preprocess.openpose.run_openpose import OpenPose
76
  from detectron2.data.detection_utils import convert_PIL_to_numpy, _apply_exif_orientation
 
78
 
79
  app = Flask(__name__)
80
 
 
81
  base_path = 'yisol/IDM-VTON'
82
+ example_path = os.path.join(os.path.dirname(__file__), 'example')
83
 
84
+ unet = UNet2DConditionModel.from_pretrained(
85
+ base_path,
86
+ subfolder="unet",
87
+ torch_dtype=torch.float16,
88
+ force_download=False
89
+ )
90
+ unet.requires_grad_(False)
91
+ tokenizer_one = AutoTokenizer.from_pretrained(
92
+ base_path,
93
+ subfolder="tokenizer",
94
+ revision=None,
95
+ use_fast=False,
96
+ force_download=False
97
+ )
98
+ tokenizer_two = AutoTokenizer.from_pretrained(
99
+ base_path,
100
+ subfolder="tokenizer_2",
101
+ revision=None,
102
+ use_fast=False,
103
+ force_download=False
104
+ )
105
+ noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
+ text_encoder_one = CLIPTextModel.from_pretrained(
108
+ base_path,
109
+ subfolder="text_encoder",
110
+ torch_dtype=torch.float16,
111
+ force_download=False
112
+ )
113
+ text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
114
+ base_path,
115
+ subfolder="text_encoder_2",
116
+ torch_dtype=torch.float16,
117
+ force_download=False
118
+ )
119
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
120
+ base_path,
121
+ subfolder="image_encoder",
122
+ torch_dtype=torch.float16,
123
+ force_download=False
124
+ )
125
+ vae = AutoencoderKL.from_pretrained(base_path,
126
+ subfolder="vae",
127
+ torch_dtype=torch.float16,
128
+ force_download=False
129
+ )
130
+
131
+ UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
132
+ base_path,
133
+ subfolder="unet_encoder",
134
+ torch_dtype=torch.float16,
135
+ force_download=False
136
+ )
137
+
138
+ parsing_model = Parsing(0)
139
+ openpose_model = OpenPose(0)
140
+
141
+ UNet_Encoder.requires_grad_(False)
142
+ image_encoder.requires_grad_(False)
143
+ vae.requires_grad_(False)
144
+ unet.requires_grad_(False)
145
+ text_encoder_one.requires_grad_(False)
146
+ text_encoder_two.requires_grad_(False)
147
+ tensor_transfrom = transforms.Compose(
148
+ [
149
+ transforms.ToTensor(),
150
+ transforms.Normalize([0.5], [0.5]),
151
+ ]
152
+ )
153
+
154
+ pipe = TryonPipeline.from_pretrained(
155
+ base_path,
156
+ unet=unet,
157
+ vae=vae,
158
+ feature_extractor= CLIPImageProcessor(),
159
+ text_encoder = text_encoder_one,
160
+ text_encoder_2 = text_encoder_two,
161
+ tokenizer = tokenizer_one,
162
+ tokenizer_2 = tokenizer_two,
163
+ scheduler = noise_scheduler,
164
+ image_encoder=image_encoder,
165
+ torch_dtype=torch.float16,
166
+ force_download=False
167
+ )
168
+ pipe.unet_encoder = UNet_Encoder
169
 
 
170
def pil_to_binary_mask(pil_image, threshold=0):
    """Convert a PIL image into a black/white binary mask.

    Pixels whose grayscale value is strictly greater than ``threshold``
    become 255; all others become 0.

    Args:
        pil_image: input PIL image (any mode convertible to "L").
        threshold: grayscale cutoff in [0, 255]; default 0 keeps every
            non-black pixel.

    Returns:
        A mode-"L" PIL image containing only the values 0 and 255.
    """
    grayscale = np.array(pil_image.convert("L"))
    # Vectorized thresholding: replaces the original per-pixel Python double
    # loop (identical output, one C-level pass instead of width*height
    # iterations).
    mask = (grayscale > threshold).astype(np.uint8) * 255
    return Image.fromarray(mask)
182
+
183
def get_image_from_url(url):
    """Download an image from *url* and return it as a PIL Image.

    Errors (HTTP status, network failure, undecodable image) are logged and
    re-raised for the caller to handle.
    """
    # The original body called `requests.get`, but `requests` is never
    # imported in this file, so every call died with NameError. Use the
    # standard library instead; the local import keeps the module-level
    # import block untouched.
    from urllib.request import urlopen

    try:
        # urlopen raises HTTPError/URLError on failure, mirroring the old
        # `response.raise_for_status()` check.
        with urlopen(url) as response:
            data = response.read()
        return Image.open(BytesIO(data))
    except Exception as e:
        logging.error(f"Error fetching image from URL: {e}")
        raise
192
 
193
  def decode_image_from_base64(base64_str):
194
  try:
 
203
  try:
204
  buffered = BytesIO()
205
  img.save(buffered, format="PNG")
206
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
207
+ return img_str
208
  except Exception as e:
209
  logging.error(f"Error encoding image: {e}")
210
  raise
211
 
212
def save_image(img):
    """Persist *img* to the working directory under a random name.

    The image is written as lossless WebP; the generated file name
    (``<uuid4>.webp``) is returned so callers can hand it back to the
    ``/api/get_image/<image_id>`` route.
    """
    file_name = f"{uuid.uuid4()}.webp"
    img.save(file_name, format="WEBP", lossless=True)
    return file_name
216
+
217
# NOTE(review): the first parameter shadows the builtin `dict`; left as-is
# because renaming it would break keyword callers.
@spaces.GPU
def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denoise_steps, seed, categorie = 'upper_body'):
    """Run the IDM-VTON try-on pipeline on one person/garment pair.

    Args:
        dict: editor-style payload with key 'background' (person image, PIL)
            and 'layers' (list whose first entry is a manual mask layer,
            read only when ``is_checked`` is False).
        garm_img: garment image (PIL); resized to 768x1024.
        garment_des: textual garment description inserted into the prompts.
        is_checked: True -> auto-generate the mask via OpenPose + parsing.
        is_checked_crop: True -> center-crop the person to 3:4 first and
            paste the generated crop back into the original photo at the end.
        denoise_steps: number of diffusion inference steps.
        seed: RNG seed for the generator; None -> non-deterministic.
        categorie: garment region passed to get_mask_location
            (default 'upper_body').

    Returns:
        (result_image, mask_gray): the try-on result (PIL) and a grayscale
        visualization of the masked region (PIL).
    """
    device = "cuda"
    # Move the heavy models onto the GPU inside the @spaces.GPU context.
    openpose_model.preprocessor.body_estimation.model.to(device)
    pipe.to(device)
    pipe.unet_encoder.to(device)

    garm_img = garm_img.convert("RGB").resize((768, 1024))
    human_img_orig = dict["background"].convert("RGB")

    if is_checked_crop:
        # Center-crop to 3:4, remembering the crop origin/size so the
        # generated image can be pasted back into the original frame below.
        width, height = human_img_orig.size
        target_width = int(min(width, height * (3 / 4)))
        target_height = int(min(height, width * (4 / 3)))
        left = (width - target_width) / 2
        top = (height - target_height) / 2
        right = (width + target_width) / 2
        bottom = (height + target_height) / 2
        cropped_img = human_img_orig.crop((left, top, right, bottom))
        crop_size = cropped_img.size
        human_img = cropped_img.resize((768, 1024))
    else:
        human_img = human_img_orig.resize((768, 1024))

    if is_checked:
        # Automatic mask: OpenPose keypoints + human parsing at 384x512,
        # then upscale the located mask to working resolution.
        keypoints = openpose_model(human_img.resize((384, 512)))
        model_parse, _ = parsing_model(human_img.resize((384, 512)))
        mask, mask_gray = get_mask_location('hd', categorie , model_parse, keypoints)
        mask = mask.resize((768, 1024))
    else:
        # Manual mask supplied by the caller in layers[0].
        mask = pil_to_binary_mask(dict['layers'][0].convert("RGB").resize((768, 1024)))
    # Grayed-out preview of the masked region. Note this unconditionally
    # overwrites the mask_gray returned by get_mask_location above.
    mask_gray = (1 - transforms.ToTensor()(mask)) * tensor_transfrom(human_img)
    mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)

    # DensePose conditioning image (detectron2 expects BGR numpy input).
    human_img_arg = _apply_exif_orientation(human_img.resize((384, 512)))
    human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")

    args = apply_net.create_argument_parser().parse_args(('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', 'cuda'))
    pose_img = args.func(args, human_img_arg)
    # BGR -> RGB before handing the array to PIL.
    pose_img = pose_img[:, :, ::-1]
    pose_img = Image.fromarray(pose_img).resize((768, 1024))

    with torch.no_grad():
        # fp16 autocast for the whole denoising pass.
        with torch.cuda.amp.autocast():
            prompt = "model is wearing " + garment_des
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
            with torch.inference_mode():
                (
                    prompt_embeds,
                    negative_prompt_embeds,
                    pooled_prompt_embeds,
                    negative_pooled_prompt_embeds,
                ) = pipe.encode_prompt(
                    prompt,
                    num_images_per_prompt=1,
                    do_classifier_free_guidance=True,
                    negative_prompt=negative_prompt,
                )

            # Second prompt: describes the garment alone for the cloth
            # encoder branch (no classifier-free guidance here).
            prompt = "a photo of " + garment_des
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
            if not isinstance(prompt, list):
                prompt = [prompt] * 1
            if not isinstance(negative_prompt, list):
                negative_prompt = [negative_prompt] * 1
            with torch.inference_mode():
                (
                    prompt_embeds_c,
                    _,
                    _,
                    _,
                ) = pipe.encode_prompt(
                    prompt,
                    num_images_per_prompt=1,
                    do_classifier_free_guidance=False,
                    negative_prompt=negative_prompt,
                )

            # Tensorize conditioning images; add batch dim, match fp16.
            pose_img = tensor_transfrom(pose_img).unsqueeze(0).to(device, torch.float16)
            garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device, torch.float16)
            generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
            images = pipe(
                prompt_embeds=prompt_embeds.to(device, torch.float16),
                negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
                pooled_prompt_embeds=pooled_prompt_embeds.to(device, torch.float16),
                negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device, torch.float16),
                num_inference_steps=denoise_steps,
                generator=generator,
                strength=1.0,
                pose_img=pose_img.to(device, torch.float16),
                text_embeds_cloth=prompt_embeds_c.to(device, torch.float16),
                cloth=garm_tensor.to(device, torch.float16),
                mask_image=mask,
                image=human_img,
                height=1024,
                width=768,
                ip_adapter_image=garm_img.resize((768, 1024)),
                guidance_scale=2.0,
            )[0]

    if is_checked_crop:
        # Paste the generated crop back into the original-size photo.
        out_img = images[0].resize(crop_size)
        human_img_orig.paste(out_img, (int(left), int(top)))
        return human_img_orig, mask_gray
    else:
        return images[0], mask_gray
323
+
324
 
325
def clear_gpu_memory():
    """Release cached CUDA memory and wait for pending GPU work to finish.

    No-op on hosts without a usable CUDA device.
    """
    # Guard added: empty_cache() is harmless without CUDA, but
    # torch.cuda.synchronize() raises when no CUDA device is available,
    # which made this helper crash on CPU-only machines.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
328
 
329
def process_image(image_data):
    """Resolve *image_data* into a PIL Image.

    Args:
        image_data: either an http(s) URL or a base64-encoded image string.

    Returns:
        The decoded/downloaded PIL Image (exceptions from the helpers
        propagate to the caller).
    """
    # Tuple form of startswith replaces the chained `or` test.
    if image_data.startswith(('http://', 'https://')):
        return get_image_from_url(image_data)    # download from URL
    return decode_image_from_base64(image_data)  # decode base64 payload
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
 
 
336
@app.route('/tryon', methods=['POST'])
def tryon():
    """POST /tryon — run virtual try-on, return result + mask as base64.

    JSON body:
        human_image / garment_image: base64 string or http(s) URL (required).
        description: garment description used in the prompt (default 'garment').
        use_auto_mask (default True), use_auto_crop (default False),
        denoise_steps (default 30), seed (default 42),
        categorie (default 'upper_body').
    """
    data = request.json
    human_image = process_image(data['human_image'])
    garment_image = process_image(data['garment_image'])
    # Fix: the original `data.get('description')` could yield None, and
    # start_tryon's `"model is wearing " + None` then raised TypeError.
    # `or` also covers an explicit JSON null.
    description = data.get('description') or 'garment'
    use_auto_mask = data.get('use_auto_mask', True)
    use_auto_crop = data.get('use_auto_crop', False)
    denoise_steps = int(data.get('denoise_steps', 30))
    seed = int(data.get('seed', 42))
    categorie = data.get('categorie', 'upper_body')
    human_dict = {
        'background': human_image,
        # NOTE(review): with auto-mask off, layers[0] is the person image
        # itself and start_tryon binarizes it as the mask — /tryon has no
        # explicit mask parameter, so confirm this is intended (see
        # /tryon-v2 for the variant that accepts a real mask).
        'layers': [human_image] if not use_auto_mask else None,
        'composite': None
    }

    output_image, mask_image = start_tryon(human_dict, garment_image, description, use_auto_mask, use_auto_crop, denoise_steps, seed, categorie)

    output_base64 = encode_image_to_base64(output_image)
    mask_base64 = encode_image_to_base64(mask_image)

    return jsonify({
        'output_image': output_base64,
        'mask_image': mask_base64
    })
363
+
364
@app.route('/tryon-v2', methods=['POST'])
def tryon_v2():
    """POST /tryon-v2 — try-on with optional manual mask; returns an image id.

    Same JSON body as /tryon plus an optional 'mask_image' (base64 or URL),
    used when 'use_auto_mask' is false. The response contains 'image_id',
    to be fetched via GET /api/get_image/<image_id>.
    """
    data = request.json
    # Process images (base64 or URL).
    human_image = process_image(data['human_image'])
    garment_image = process_image(data['garment_image'])

    # Fix: default the description so a missing/null value cannot reach the
    # prompt concatenation in start_tryon (TypeError on `str + None`).
    description = data.get('description') or 'garment'
    use_auto_mask = data.get('use_auto_mask', True)
    use_auto_crop = data.get('use_auto_crop', False)
    denoise_steps = int(data.get('denoise_steps', 30))
    seed = int(data.get('seed', random.randint(0, 9999999)))
    categorie = data.get('categorie', 'upper_body')

    # Optional caller-provided mask.
    mask_image = None
    if 'mask_image' in data:
        mask_image = process_image(data['mask_image'])

    # Fix: the original passed layers=[None] to start_tryon when auto-mask
    # was off and no mask was supplied, crashing deep in the pipeline;
    # reject the request up front instead.
    if not use_auto_mask and mask_image is None:
        return jsonify({'error': 'mask_image is required when use_auto_mask is false'}), 400

    human_dict = {
        'background': human_image,
        'layers': [mask_image] if not use_auto_mask else None,
        'composite': None
    }
    output_image, mask_image = start_tryon(human_dict, garment_image, description, use_auto_mask, use_auto_crop, denoise_steps, seed, categorie)
    return jsonify({
        'image_id': save_image(output_image)
    })
397
+
398
@spaces.GPU
def generate_mask(human_img, categorie='upper_body'):
    """Generate the garment-area mask for *human_img*.

    Args:
        human_img: person image (PIL).
        categorie: garment region for get_mask_location (default 'upper_body').

    Returns:
        The mask (PIL) resized back to the original image size.

    Raises:
        Whatever the pose/parsing models raise, after logging.
    """
    device = "cuda"
    openpose_model.preprocessor.body_estimation.model.to(device)
    # NOTE(review): pipe is not used below — presumably moved to keep the
    # pipeline warm for subsequent try-on calls; confirm it is needed here.
    pipe.to(device)

    try:
        # The pose/parsing models operate at 384x512.
        human_img_resized = human_img.convert("RGB").resize((384, 512))

        # Keypoints + human parsing -> mask location.
        keypoints = openpose_model(human_img_resized)
        model_parse, _ = parsing_model(human_img_resized)
        mask, _ = get_mask_location('hd', categorie, model_parse, keypoints)

        # Scale the mask back up to the original image size.
        mask_resized = mask.resize(human_img.size)

        return mask_resized
    except Exception as e:
        logging.error(f"Error generating mask: {e}")
        # Fix: bare `raise` preserves the original traceback (was `raise e`,
        # which re-raises with the chain rooted at this frame).
        raise
420
+
421
+
422
@app.route('/generate_mask', methods=['POST'])
def generate_mask_api():
    """POST /generate_mask — return the garment mask for a person image.

    JSON body: 'human_image' (base64 or URL) and optional 'categorie'
    (default 'upper_body'). Responds with {'mask_image': <base64>} on
    success, or {'error': ...} with HTTP 500 on failure.
    """
    try:
        payload = request.json
        source_image = payload.get('human_image')
        category = payload.get('categorie', 'upper_body')

        # Accept base64 as well as URLs.
        person = process_image(source_image)

        # Run the pose/parsing models to build the mask.
        mask = generate_mask(person, category)

        # Ship the mask back as base64.
        encoded_mask = encode_image_to_base64(mask)
        return jsonify({'mask_image': encoded_mask}), 200
    except Exception as e:
        logging.error(f"Error generating mask: {e}")
        return jsonify({'error': str(e)}), 500
445
+
446
# Route to retrieve a generated image by id.
@app.route('/api/get_image/<image_id>', methods=['GET'])
def get_image(image_id):
    """GET /api/get_image/<image_id> — serve an image saved by save_image().

    *image_id* is the file name returned by /tryon-v2 (``<uuid4>.webp``).
    Responds with the WebP bytes, or 404 if the file does not exist.
    """
    # Security fix: image_id comes straight from the URL and was used as a
    # raw path, allowing traversal (e.g. '../app.py'). Stripping directory
    # components confines lookups to the working directory, where
    # save_image() writes its files.
    image_path = os.path.basename(image_id)

    try:
        return send_file(image_path, mimetype='image/webp')
    except FileNotFoundError:
        return jsonify({'error': 'Image not found'}), 404
457
 
458
if __name__ == "__main__":
    # Bind on all interfaces; 7860 is the conventional Hugging Face Spaces port.
    app.run(debug=False, host="0.0.0.0", port=7860)