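# Flask API for the IDM-VTON (yisol/IDM-VTON) virtual try-on pipeline, intended to
# run as a Hugging Face ZeroGPU Space. It exposes POST /tryon (base64 images in,
# base64 images out) and GET /api/get_image/<image_id> for saved result files.
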
import os
from flask import Flask, request, jsonify, send_file
from PIL import Image
from io import BytesIO
import base64
import logging
import torch
import requests
import numpy as np
import uuid
import spaces
from transformers import (
    CLIPImageProcessor,
    CLIPVisionModelWithProjection,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    AutoTokenizer
)
from diffusers import DDPMScheduler, AutoencoderKL, UNet2DConditionModel
# TryonPipeline is used below but was never imported; assuming the IDM-VTON
# repository layout, it is the SDXL inpainting variant in src/tryon_pipeline.py.
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
from utils_mask import get_mask_location
from torchvision import transforms
import apply_net
from preprocess.humanparsing.run_parsing import Parsing
from preprocess.openpose.run_openpose import OpenPose
from detectron2.data.detection_utils import convert_PIL_to_numpy, _apply_exif_orientation
from torchvision.transforms.functional import to_pil_image

app = Flask(__name__)

# Global variables holding the loaded models
models_loaded = False

def load_models():
    global unet, tokenizer_one, tokenizer_two, noise_scheduler, text_encoder_one, text_encoder_two
    global image_encoder, vae, UNet_Encoder, parsing_model, openpose_model, pipe
    global models_loaded
    
    if not models_loaded:
        base_path = 'yisol/IDM-VTON'
        unet = UNet2DConditionModel.from_pretrained(base_path, subfolder="unet", torch_dtype=torch.float16, force_download=False)
        unet.requires_grad_(False)
        
        tokenizer_one = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer", use_fast=False, force_download=False)
        tokenizer_two = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer_2", use_fast=False, force_download=False)
        
        noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
        text_encoder_one = CLIPTextModel.from_pretrained(base_path, subfolder="text_encoder", torch_dtype=torch.float16, force_download=False)
        text_encoder_two = CLIPTextModelWithProjection.from_pretrained(base_path, subfolder="text_encoder_2", torch_dtype=torch.float16, force_download=False)
        image_encoder = CLIPVisionModelWithProjection.from_pretrained(base_path, subfolder="image_encoder", torch_dtype=torch.float16, force_download=False)
        vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae", torch_dtype=torch.float16, force_download=False)
        
        # Set the correct encoder_hid_dim_type here
        UNet_Encoder = UNet2DConditionModel.from_pretrained(
            base_path,
            subfolder="unet_encoder",
            torch_dtype=torch.float16,
            encoder_hid_dim_type="text_proj",  # Update based on model type
            force_download=False
        )
        
        parsing_model = Parsing(0)
        openpose_model = OpenPose(0)
        
        UNet_Encoder.requires_grad_(False)
        image_encoder.requires_grad_(False)
        vae.requires_grad_(False)
        unet.requires_grad_(False)
        text_encoder_one.requires_grad_(False)
        text_encoder_two.requires_grad_(False)
        
        tensor_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
        
        pipe = TryonPipeline.from_pretrained(
            base_path,
            unet=unet,
            vae=vae,
            feature_extractor=CLIPImageProcessor(),
            text_encoder=text_encoder_one,
            text_encoder_2=text_encoder_two,
            tokenizer=tokenizer_one,
            tokenizer_2=tokenizer_two,
            scheduler=noise_scheduler,
            image_encoder=image_encoder,
            torch_dtype=torch.float16,
            force_download=False
        )
        pipe.unet_encoder = UNet_Encoder
        
        models_loaded = True

def pil_to_binary_mask(pil_image, threshold=0):
    np_image = np.array(pil_image.convert("L"))  # Convert to grayscale directly
    binary_mask = np_image > threshold
    mask = np.uint8(binary_mask * 255)
    return Image.fromarray(mask)

def get_image_from_url(url):
    try:
        response = requests.get(url)
        response.raise_for_status()
        return Image.open(BytesIO(response.content))
    except Exception as e:
        logging.error(f"Error fetching image from URL: {e}")
        raise

def decode_image_from_base64(base64_str):
    try:
        img_data = base64.b64decode(base64_str)
        return Image.open(BytesIO(img_data))
    except Exception as e:
        logging.error(f"Error decoding image: {e}")
        raise

def encode_image_to_base64(img):
    try:
        buffered = BytesIO()
        img.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")
    except Exception as e:
        logging.error(f"Error encoding image: {e}")
        raise

def save_image(img):
    unique_name = f"{uuid.uuid4()}.webp"
    img.save(unique_name, format="WEBP", lossless=True)
    return unique_name

def clear_gpu_memory():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
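
# start_tryon performs the full try-on flow: resize the garment and person images,
# build the garment mask (automatically via OpenPose + human parsing, or from the
# user-supplied layer), render a DensePose map, then run the diffusion pipeline.
# The @spaces.GPU decorator lets a ZeroGPU Space allocate a GPU only for this call.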

@spaces.GPU
def start_tryon(human_dict, garment_image, garment_description, use_auto_mask, use_auto_crop, denoise_steps, seed, category='upper_body'):
    device = "cuda"
    openpose_model.preprocessor.body_estimation.model.to(device)
    pipe.to(device)
    pipe.unet_encoder.to(device)

    garment_image = garment_image.convert("RGB").resize((768, 1024))
    human_image_orig = human_dict["background"].convert("RGB")

    if use_auto_crop:
        width, height = human_image_orig.size
        target_width = int(min(width, height * (3 / 4)))
        target_height = int(min(height, width * (4 / 3)))
        left, top = (width - target_width) / 2, (height - target_height) / 2
        right, bottom = (width + target_width) / 2, (height + target_height) / 2
        cropped_img = human_image_orig.crop((left, top, right, bottom)).resize((768, 1024))
    else:
        cropped_img = human_image_orig.resize((768, 1024))

    if use_auto_mask:
        keypoints = openpose_model(cropped_img.resize((384, 512)))
        model_parse, _ = parsing_model(cropped_img.resize((384, 512)))
        mask, mask_gray = get_mask_location('hd', category, model_parse, keypoints)
        mask = mask.resize((768, 1024))
    else:
        mask = pil_to_binary_mask(human_dict['layers'][0].convert("RGB").resize((768, 1024)))
    
    mask_gray = (1 - transforms.ToTensor()(mask)) * transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])(cropped_img)
    mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)

    human_image_arg = _apply_exif_orientation(cropped_img.resize((384, 512)))
    human_image_arg = convert_PIL_to_numpy(human_image_arg, format="BGR")

    args = apply_net.create_argument_parser().parse_args(
        ('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', 'cuda'))
    pose_image = args.func(args, human_image_arg)
    pose_image = Image.fromarray(pose_image[:, :, ::-1]).resize((768, 1024))

    with torch.no_grad(), torch.cuda.amp.autocast():
        prompt = "model is wearing " + garment_description
        negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
        prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = pipe.encode_prompt(
            prompt, num_images_per_prompt=1, do_classifier_free_guidance=True, negative_prompt=negative_prompt
        )

        prompt_c = "a photo of " + garment_description
        negative_prompt_c = "monochrome, lowres, bad anatomy, worst quality, low quality"
        prompt_embeds_c, _, _, _ = pipe.encode_prompt(
            prompt_c, num_images_per_prompt=1, do_classifier_free_guidance=False, negative_prompt=negative_prompt_c
        )

        pose_image = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])(pose_image).unsqueeze(0).to(device, torch.float16)
        garment_tensor = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])(garment_image).unsqueeze(0).to(device, torch.float16)

        images = pipe(
            prompt_embeds=prompt_embeds.to(device, torch.float16),
            negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
            pose_image=pose_image,
            garment_image=garment_tensor,
            mask_image=mask,  # pass the PIL binary mask; mask_gray is only a grayed preview and PIL images have no .to()
            generator=torch.Generator(device).manual_seed(seed),
            num_inference_steps=denoise_steps
        ).images

    if images:
        output_image = images[0]
        return output_image, mask  # the handler encodes both to base64
    else:
        raise ValueError("Failed to generate image")


# Route to retrieve a previously generated image
@app.route('/api/get_image/<image_id>', methods=['GET'])
def get_image(image_id):
    # Build the image path; the filename must match the one used when the image was saved
    image_path = image_id

    # Return the image file
    try:
        return send_file(image_path, mimetype='image/webp')
    except FileNotFoundError:
        return jsonify({'error': 'Image not found'}), 404

@app.route('/tryon', methods=['POST'])
def tryon_handler():
    try:
        data = request.json
        human_image = decode_image_from_base64(data['human_image'])
        garment_image = decode_image_from_base64(data['garment_image'])
        description = data.get('description', '')
        use_auto_mask = data.get('use_auto_mask', True)
        use_auto_crop = data.get('use_auto_crop', False)
        denoise_steps = int(data.get('denoise_steps', 30))
        seed = int(data.get('seed', 42))
        category = data.get('category', 'upper_body')
        
        human_dict = {
            'background': human_image,
            'layers': [human_image] if not use_auto_mask else None,
            'composite': None
        }
        clear_gpu_memory()

        output_image, mask_image = start_tryon(
            human_dict, garment_image, description, use_auto_mask, use_auto_crop, denoise_steps, seed, category
        )

        output_base64 = encode_image_to_base64(output_image)
        mask_base64 = encode_image_to_base64(mask_image)

        return jsonify({
            'output_image': output_base64,
            'mask_image': mask_base64
        })
    except Exception as e:
        logging.error(f"Error in tryon_handler: {e}")
        return jsonify({'error': str(e)}), 500
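
# Illustrative client sketch (not exercised by the app itself): how a caller might
# hit POST /tryon with two base64-encoded images. The host/port assume the
# app.run() defaults below; the field names match tryon_handler above, and the
# file paths are placeholders.
#
#   import base64, requests
#
#   def b64(path):
#       with open(path, "rb") as f:
#           return base64.b64encode(f.read()).decode("utf-8")
#
#   payload = {
#       "human_image": b64("person.jpg"),
#       "garment_image": b64("shirt.jpg"),
#       "description": "blue cotton t-shirt",
#       "denoise_steps": 30,
#       "seed": 42,
#       "category": "upper_body",
#   }
#   resp = requests.post("http://localhost:7860/tryon", json=payload, timeout=600)
#   result = resp.json()  # {"output_image": <base64 PNG>, "mask_image": <base64 PNG>}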

if __name__ == "__main__":
    load_models()  # Load the models at startup
    app.run(host='0.0.0.0', port=7860)