import sys
import os

sys.path.append('./')
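# Install pinned runtime dependencies at startup (common pattern for Hugging Face Spaces apps)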
os.system("pip install gradio accelerate==0.25.0 torchmetrics==1.2.1 tqdm==4.66.1 fastapi==0.111.0 transformers==4.36.2 diffusers==0.25 einops==0.7.0 bitsandbytes scipy==1.11.1 opencv-python gradio==4.24.0 fvcore cloudpickle omegaconf pycocotools basicsr av onnxruntime==1.16.2 peft==0.11.1 huggingface_hub==0.24.7 --no-deps")
import spaces
from fastapi import FastAPI
app = FastAPI()
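# Note: the FastAPI app is not used directly below; the UI is served by Gradio via image_blocks.launch()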

from PIL import Image
import gradio as gr
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
from src.unet_hacked_tryon import UNet2DConditionModel
from transformers import (
    CLIPImageProcessor,
    CLIPVisionModelWithProjection,
    CLIPTextModel,
    CLIPTextModelWithProjection,
)
from diffusers import DDPMScheduler, AutoencoderKL
from typing import List

import torch
from transformers import AutoTokenizer
import numpy as np
from torchvision import transforms


# Use the first GPU when available, otherwise fall back to CPU
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

def pil_to_binary_mask(pil_image, threshold=0):
    """Convert a PIL image to a single-channel binary mask (0 or 255)."""
    grayscale_image = pil_image.convert("L")
    binary_mask = np.array(grayscale_image) > threshold
    mask = binary_mask.astype(np.uint8) * 255
    return Image.fromarray(mask)


base_path = 'Keshabwi66/SmartLugaModel'
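# All pipeline components below are loaded from subfolders of this Hugging Face Hub repository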

unet = UNet2DConditionModel.from_pretrained(
    base_path,
    subfolder="unet",
    torch_dtype=torch.float16,
)
unet.requires_grad_(False)
tokenizer_one = AutoTokenizer.from_pretrained(
    base_path,
    subfolder="tokenizer",
    revision=None,
    use_fast=False,
)
tokenizer_two = AutoTokenizer.from_pretrained(
    base_path,
    subfolder="tokenizer_2",
    revision=None,
    use_fast=False,
)
noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")

text_encoder_one = CLIPTextModel.from_pretrained(
    base_path,
    subfolder="text_encoder",
    torch_dtype=torch.float16,
)
text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
    base_path,
    subfolder="text_encoder_2",
    torch_dtype=torch.float16,
)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    base_path,
    subfolder="image_encoder",
    torch_dtype=torch.float16,
)
vae = AutoencoderKL.from_pretrained(
    base_path,
    subfolder="vae",
    torch_dtype=torch.float16,
)

# "stabilityai/stable-diffusion-xl-base-1.0",
UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
    base_path,
    subfolder="unet_encoder",
    torch_dtype=torch.float16,
)



# Freeze all weights; the app performs inference only
UNet_Encoder.requires_grad_(False)
image_encoder.requires_grad_(False)
vae.requires_grad_(False)
unet.requires_grad_(False)
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)

# Normalize PIL images to tensors in [-1, 1]
tensor_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

# Assemble the SDXL inpainting-based try-on pipeline from the loaded components
pipe = TryonPipeline.from_pretrained(
    base_path,
    unet=unet,
    vae=vae,
    feature_extractor=CLIPImageProcessor(),
    text_encoder=text_encoder_one,
    text_encoder_2=text_encoder_two,
    tokenizer=tokenizer_one,
    tokenizer_2=tokenizer_two,
    scheduler=noise_scheduler,
    image_encoder=image_encoder,
    torch_dtype=torch.float16,
)
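# Attach the garment (reference) UNet encoder used by the custom try-on pipeline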
pipe.unet_encoder = UNet_Encoder

# @spaces.GPU requests a GPU allocation for each call when running on ZeroGPU Spaces
@spaces.GPU
def start_tryon(person_img, mask_img, cloth_img, garment_des, denoise_steps=10, seed=42):
    """Run the virtual try-on pipeline and return the generated PIL image."""
    # Move the pipeline and the garment encoder to the selected device
    pipe.to(device)
    pipe.unet_encoder.to(device)

    # Resize the garment, person, and mask images to the model resolution
    garm_img = cloth_img.convert("RGB").resize((768, 1024))
    human_img = person_img.convert("RGB").resize((768, 1024))
    mask = pil_to_binary_mask(mask_img.convert("RGB").resize((768, 1024)))

    # Load the fixed pose image bundled with the app and resize it as well
    pose_img = Image.open("00006_00.jpg").resize((768, 1024))

    
    
    # Embedding generation for prompts
    with torch.no_grad():
        with torch.cuda.amp.autocast():
            # Text embeddings for the person wearing the described garment
            prompt = f"model is wearing {garment_des}"
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
            with torch.inference_mode():
                (
                    prompt_embeds,
                    negative_prompt_embeds,
                    pooled_prompt_embeds,
                    negative_pooled_prompt_embeds,
                ) = pipe.encode_prompt(
                    prompt,
                    num_images_per_prompt=1,
                    do_classifier_free_guidance=True,
                    negative_prompt=negative_prompt,
                )

            # Text embeddings for the garment image itself
            prompt = "a photo of " + garment_des
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
            if not isinstance(prompt, List):
                prompt = [prompt]
            if not isinstance(negative_prompt, List):
                negative_prompt = [negative_prompt]
            with torch.inference_mode():
                (
                    prompt_embeds_cloth,
                    _,
                    _,
                    _,
                ) = pipe.encode_prompt(
                    prompt,
                    num_images_per_prompt=1,
                    do_classifier_free_guidance=False,
                    negative_prompt=negative_prompt,
                )

            # Convert pose and garment images to normalized float16 tensors
            pose_img_tensor = tensor_transform(pose_img).unsqueeze(0).to(device, torch.float16)
            garm_tensor = tensor_transform(garm_img).unsqueeze(0).to(device, torch.float16)

            # Seed the generator for reproducible results (None -> nondeterministic)
            generator = torch.Generator(device).manual_seed(seed) if seed is not None else None

            # Generate the virtual try-on output image
            images = pipe(
                prompt_embeds=prompt_embeds.to(device, torch.float16),
                negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
                pooled_prompt_embeds=pooled_prompt_embeds.to(device, torch.float16),
                negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device, torch.float16),
                num_inference_steps=denoise_steps,
                generator=generator,
                strength=1.0,
                pose_img=pose_img_tensor.to(device, torch.float16),
                text_embeds_cloth=prompt_embeds_cloth.to(device, torch.float16),
                cloth=garm_tensor.to(device, torch.float16),
                mask_image=mask,
                image=human_img,
                height=1024,
                width=768,
                ip_adapter_image=garm_img.resize((768, 1024)),
                guidance_scale=2.0,
            )[0]

    return images[0]


# Gradio interface for the virtual try-on model; queue() serializes requests to the GPU handler
image_blocks = gr.Blocks().queue()

with image_blocks as demo:
    gr.Markdown("## SmartLuga")    
    with gr.Row():
        with gr.Column():
            person_img = gr.Image(label='Person Image', sources='upload', type="pil")
            mask_img = gr.Image(label='Mask Image', sources='upload', type="pil")

        with gr.Column():
            cloth_img = gr.Image(label='Garment Image', sources='upload', type="pil")
            garment_des = gr.Textbox(placeholder="Description of the garment, e.g. Short Sleeve Round Neck T-shirt", label="Garment Description")


        with gr.Column():
            image_out = gr.Image(label="Output Image", elem_id="output-img", show_share_button=False)

    try_button = gr.Button(value="Try-on")
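    # Wire the button to the try-on function; denoise_steps and seed fall back to their defaults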
    try_button.click(fn=start_tryon, inputs=[person_img, mask_img, cloth_img, garment_des], outputs=[image_out], api_name='tryon')

image_blocks.launch()