import sys
import os

sys.path.append('./')
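# Install the pinned runtime dependencies inside the Space; --no-deps avoids overwriting preinstalled packages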
os.system("pip install gradio accelerate==0.25.0 torchmetrics==1.2.1 tqdm==4.66.1 fastapi==0.111.0 transformers==4.36.2 diffusers==0.25 einops==0.7.0 bitsandbytes scipy==1.11.1 opencv-python gradio==4.24.0 fvcore cloudpickle omegaconf pycocotools basicsr av onnxruntime==1.16.2 peft==0.11.1 huggingface_hub==0.24.7 --no-deps")
import spaces
from fastapi import FastAPI
app = FastAPI()

from PIL import Image
import gradio as gr
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
from src.unet_hacked_tryon import UNet2DConditionModel
from transformers import (
    CLIPImageProcessor,
    CLIPVisionModelWithProjection,
    CLIPTextModel,
    CLIPTextModelWithProjection,
)
from diffusers import DDPMScheduler, AutoencoderKL
from typing import List

import torch
from transformers import AutoTokenizer
import numpy as np
from torchvision import transforms
import apply_net
# Parsing / OpenPose preprocessors shipped with the IDM-VTON repo (preprocess/ package)
from preprocess.humanparsing.run_parsing import Parsing
from preprocess.openpose.run_openpose import OpenPose

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

def pil_to_binary_mask(pil_image, threshold=0):
    """Convert a PIL image to a single-channel binary mask with values 0 or 255."""
    grayscale_image = pil_image.convert("L")
    binary_mask = np.array(grayscale_image) > threshold
    mask = binary_mask.astype(np.uint8) * 255
    return Image.fromarray(mask)


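# All pretrained weights below are loaded from the IDM-VTON checkpoint on the Hugging Face Hub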
base_path = 'yisol/IDM-VTON'

unet = UNet2DConditionModel.from_pretrained(
    base_path,
    subfolder="unet",
    torch_dtype=torch.float16,
)
unet.requires_grad_(False)
tokenizer_one = AutoTokenizer.from_pretrained(
    base_path,
    subfolder="tokenizer",
    revision=None,
    use_fast=False,
)
tokenizer_two = AutoTokenizer.from_pretrained(
    base_path,
    subfolder="tokenizer_2",
    revision=None,
    use_fast=False,
)
noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")

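# SDXL-style dual CLIP text encoders, loaded in fp16 for inference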
text_encoder_one = CLIPTextModel.from_pretrained(
    base_path,
    subfolder="text_encoder",
    torch_dtype=torch.float16,
)
text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
    base_path,
    subfolder="text_encoder_2",
    torch_dtype=torch.float16,
)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    base_path,
    subfolder="image_encoder",
    torch_dtype=torch.float16,
)
vae = AutoencoderKL.from_pretrained(
    base_path,
    subfolder="vae",
    torch_dtype=torch.float16,
)

# "stabilityai/stable-diffusion-xl-base-1.0",
UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
    base_path,
    subfolder="unet_encoder",
    torch_dtype=torch.float16,
)

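# Human-parsing and OpenPose preprocessors from the IDM-VTON repo, instantiated on GPU index 0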
parsing_model = Parsing(0)
openpose_model = OpenPose(0)

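# Freeze all model weights; the app only runs inference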
UNet_Encoder.requires_grad_(False)
image_encoder.requires_grad_(False)
vae.requires_grad_(False)
unet.requires_grad_(False)
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)
tensor_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

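# Assemble the try-on inpainting pipeline from the components loaded above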
pipe = TryonPipeline.from_pretrained(
    base_path,
    unet=unet,
    vae=vae,
    feature_extractor=CLIPImageProcessor(),
    text_encoder=text_encoder_one,
    text_encoder_2=text_encoder_two,
    tokenizer=tokenizer_one,
    tokenizer_2=tokenizer_two,
    scheduler=noise_scheduler,
    image_encoder=image_encoder,
    torch_dtype=torch.float16,
)
pipe.unet_encoder = UNet_Encoder

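# spaces.GPU requests a GPU for the duration of each call when running on Hugging Face ZeroGPU Spaces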
@spaces.GPU
def start_tryon(person_img, pose_img, mask_img, cloth_img, garment_des, denoise_steps, seed):
    # Assuming device is set up (e.g., "cuda" or "cpu")
    openpose_model.preprocessor.body_estimation.model.to(device)
    pipe.to(device)
    pipe.unet_encoder.to(device)

    # Resize and prepare images; gr.ImageEditor returns a dict, so use its background layer
    if isinstance(person_img, dict):
        person_img = person_img["background"]
    garm_img = cloth_img.convert("RGB").resize((768, 1024))
    human_img = person_img.convert("RGB").resize((768, 1024))
    mask = mask_img.convert("RGB").resize((768, 1024))

    # Prepare pose image (already uploaded)
    pose_img = pose_img.resize((768, 1024))

    # Generate text embeddings for garment description
    prompt = f"model is wearing {garment_des}"
    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
    
    # Embedding generation for prompts
    with torch.no_grad():
        with torch.cuda.amp.autocast():
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = pipe.encode_prompt(
                prompt,
                num_images_per_prompt=1,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )

            # encode_prompt returns four tensors; only the positive embeddings are needed for the garment prompt
            (
                prompt_embeds_cloth,
                _,
                _,
                _,
            ) = pipe.encode_prompt(
                f"a photo of {garment_des}",
                num_images_per_prompt=1,
                do_classifier_free_guidance=False,
                negative_prompt=negative_prompt,
            )

            # Convert images to normalized tensors for the pipeline
            pose_img_tensor = tensor_transform(pose_img).unsqueeze(0).to(device, torch.float16)
            garm_tensor = tensor_transform(garm_img).unsqueeze(0).to(device, torch.float16)
            mask_tensor = tensor_transform(mask).unsqueeze(0).to(device, torch.float16)

            # Prepare the generator; a seed of -1 (or None) means non-deterministic
            seed = int(seed) if seed is not None else -1
            generator = torch.Generator(device).manual_seed(seed) if seed != -1 else None

            # Generate the virtual try-on output image
            images = pipe(
                prompt_embeds=prompt_embeds.to(device, torch.float16),
                negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
                pooled_prompt_embeds=pooled_prompt_embeds.to(device, torch.float16),
                negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device, torch.float16),
                num_inference_steps=int(denoise_steps),
                generator=generator,
                strength=1.0,
                pose_img=pose_img_tensor.to(device, torch.float16),
                text_embeds_cloth=prompt_embeds_cloth.to(device, torch.float16),
                cloth=garm_tensor.to(device, torch.float16),
                mask_image=mask_tensor,
                image=human_img,
                height=1024,
                width=768,
                ip_adapter_image=garm_img.resize((768, 1024)),
                guidance_scale=2.0,
            )[0]

    # Return the generated image and the mask used, matching the two Gradio outputs
    return images[0], mask

# Gradio interface for the virtual try-on model
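# .queue() enables request queuing so concurrent requests are processed in order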
image_blocks = gr.Blocks().queue()

with image_blocks as demo:
    gr.Markdown("## SmartLuga ")
    with gr.Row():
        with gr.Column():
            imgs = gr.ImageEditor(sources='upload', type="pil", label='Human Image', interactive=True)
            # start_tryon expects an uploaded pose image and mask image, so expose inputs for both
            pose_img_input = gr.Image(label="Pose Image", sources='upload', type="pil")
            mask_img_input = gr.Image(label="Mask Image", sources='upload', type="pil")
            with gr.Row():
                is_checked_crop = gr.Checkbox(label="Use auto-crop & resizing", value=False)

        with gr.Column():
            garm_img = gr.Image(label="Garment", sources='upload', type="pil")
            with gr.Row(elem_id="prompt-container"):
                prompt = gr.Textbox(placeholder="Description of garment ex) Short Sleeve Round Neck T-shirts", show_label=False, elem_id="prompt")

        with gr.Column():
            masked_img = gr.Image(label="Masked image output", elem_id="masked-img", show_share_button=False)

        with gr.Column():
            image_out = gr.Image(label="Output", elem_id="output-img", show_share_button=False)

    with gr.Column():
        try_button = gr.Button(value="Try-on")
        with gr.Accordion(label="Advanced Settings", open=False):
            with gr.Row():
                denoise_steps = gr.Number(label="Denoising Steps", minimum=20, maximum=40, value=30, step=1)
                seed = gr.Number(label="Seed", minimum=-1, maximum=2147483647, step=1, value=42)

    try_button.click(fn=start_tryon, inputs=[imgs, pose_img_input, mask_img_input, garm_img, prompt, denoise_steps, seed], outputs=[image_out, masked_img], api_name='tryon')

image_blocks.launch()