Keshabwi66 committed
Commit 09402c7 · verified · 1 Parent(s): c0d124d

Create app.py

Files changed (1)
  1. app.py +238 -0
app.py ADDED
@@ -0,0 +1,238 @@
+ import sys
+ sys.path.append('./')
+
+ import gradio as gr
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import torch.utils.data as data
+ import torchvision
+ import matplotlib.pyplot as plt
+ from PIL import Image
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Literal
+
+ from accelerate.logging import get_logger
+ from accelerate.utils import set_seed
+ from torchvision import transforms
+
+ from diffusers import AutoencoderKL, DDPMScheduler
+ from transformers import (
+     AutoTokenizer,
+     CLIPImageProcessor,
+     CLIPVisionModelWithProjection,
+     CLIPTextModelWithProjection,
+     CLIPTextModel,
+ )
+
+ # Project-specific modules (IDM-VTON)
+ from ip_adapter.ip_adapter import Resampler
+ from src.unet_hacked_tryon import UNet2DConditionModel
+ from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
+ from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
+
+ # Configuration arguments for inference
+ class Args:
+     def __init__(self):
+         self.pretrained_model_name_or_path = "yisol/IDM-VTON"
+         self.width = 768
+         self.height = 1024
+         self.num_inference_steps = 10
+         self.seed = 42
+         self.guidance_scale = 2.0
+         self.mixed_precision = None
+
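+ # Note: 768x1024 is the resolution IDM-VTON is typically run at; 10 inference steps
+ # keeps this demo fast, and a higher step count (e.g. 30) usually trades latency for
+ # more garment detail.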
+ # Device used for computation (CUDA if available)
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ logger = get_logger(__name__, log_level="INFO")
+
+ def pil_to_tensor(images):
+     images = np.array(images).astype(np.float32) / 255.0
+     images = torch.from_numpy(images.transpose(2, 0, 1))
+     return images
+
+
+ args = Args()
+
+ # Data type for model weights
+ weight_dtype = torch.float16
+
+ if args.seed is not None:
+     set_seed(args.seed)
+
+ # Load scheduler, tokenizers and models.
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ vae = AutoencoderKL.from_pretrained(
+     args.pretrained_model_name_or_path,
+     subfolder="vae",
+     torch_dtype=torch.float16,
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+     args.pretrained_model_name_or_path,
+     subfolder="unet",
+     torch_dtype=torch.float16,
+ )
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+     args.pretrained_model_name_or_path,
+     subfolder="image_encoder",
+     torch_dtype=torch.float16,
+ )
+ unet_encoder = UNet2DConditionModel_ref.from_pretrained(
+     args.pretrained_model_name_or_path,
+     subfolder="unet_encoder",
+     torch_dtype=torch.float16,
+ )
+ text_encoder_one = CLIPTextModel.from_pretrained(
+     args.pretrained_model_name_or_path,
+     subfolder="text_encoder",
+     torch_dtype=torch.float16,
+ )
+ text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
+     args.pretrained_model_name_or_path,
+     subfolder="text_encoder_2",
+     torch_dtype=torch.float16,
+ )
+ tokenizer_one = AutoTokenizer.from_pretrained(
+     args.pretrained_model_name_or_path,
+     subfolder="tokenizer",
+     revision=None,
+     use_fast=False,
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+     args.pretrained_model_name_or_path,
+     subfolder="tokenizer_2",
+     revision=None,
+     use_fast=False,
+ )
+
+ # Freeze all components: this app runs inference only, so no gradients are needed
+ unet.requires_grad_(False)
+ vae.requires_grad_(False)
+ image_encoder.requires_grad_(False)
+ unet_encoder.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ unet_encoder.to(device, weight_dtype)
+ unet.eval()
+ unet_encoder.eval()
+
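+ # Note: the components below are handed to TryonPipeline and moved to the GPU by the
+ # pipe.to(device) call; the explicit unet_encoder.to(device, weight_dtype) above makes
+ # sure the garment UNet ends up on the GPU in fp16 regardless of how the custom
+ # pipeline registers it.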
+ pipe = TryonPipeline.from_pretrained(
+     args.pretrained_model_name_or_path,
+     unet=unet,
+     vae=vae,
+     feature_extractor=CLIPImageProcessor(),
+     text_encoder=text_encoder_one,
+     text_encoder_2=text_encoder_two,
+     tokenizer=tokenizer_one,
+     tokenizer_2=tokenizer_two,
+     scheduler=noise_scheduler,
+     image_encoder=image_encoder,
+     unet_encoder=unet_encoder,
+     torch_dtype=torch.float16,
+ ).to(device)
+ # pipe.enable_sequential_cpu_offload()
+ # pipe.enable_model_cpu_offload()
+ # pipe.enable_vae_slicing()
+
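+ # The commented-out calls above are standard diffusers memory optimizations:
+ # enable_sequential_cpu_offload()/enable_model_cpu_offload() stream weights between
+ # CPU and GPU, and enable_vae_slicing() decodes latents in slices. They could be
+ # enabled on GPUs with limited VRAM, at some cost in speed.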
+ # Run one virtual try-on pass for the given person/garment/mask/pose images
+ def generate_virtual_try_on(person_image, cloth_image, mask_image, pose_image, cloth_des):
+     # Resize all inputs to the resolution the model expects
+     person_image = person_image.resize((args.width, args.height))
+     cloth_image = cloth_image.resize((args.width, args.height))
+     mask_image = mask_image.resize((args.width, args.height))
+     pose_image = pose_image.resize((args.width, args.height))
+
+     # Transformations: scale images to [-1, 1] for the diffusion inputs
+     transform = transforms.Compose([
+         transforms.ToTensor(),
+         transforms.Normalize([0.5], [0.5]),
+     ])
+     guidance_scale = args.guidance_scale
+     seed = args.seed
+
+     to_tensor = transforms.ToTensor()
+
+     person_tensor = transform(person_image).unsqueeze(0).to(device)  # add batch dimension
+     cloth_pure = transform(cloth_image).unsqueeze(0).to(device, dtype=weight_dtype)
+     mask_tensor = to_tensor(mask_image)[:1].unsqueeze(0).to(device)  # keep a single channel
+     pose_tensor = transform(pose_image).unsqueeze(0).to(device, dtype=weight_dtype)
+
+     # Prepare text prompts
+     prompt = ["A person wearing the cloth " + cloth_des]
+     negative_prompt = ["monochrome, lowres, bad anatomy, worst quality, low quality"]
+
+     # Encode the try-on prompt
+     with torch.inference_mode():
+         (
+             prompt_embeds,
+             negative_prompt_embeds,
+             pooled_prompt_embeds,
+             negative_pooled_prompt_embeds,
+         ) = pipe.encode_prompt(
+             prompt,
+             num_images_per_prompt=1,
+             do_classifier_free_guidance=True,
+             negative_prompt=negative_prompt,
+         )
+
+     # Encode the garment prompt (conditions the garment UNet)
+     prompt_cloth = ["a photo of " + cloth_des]
+     with torch.inference_mode():
+         (
+             prompt_embeds_c,
+             _,
+             _,
+             _,
+         ) = pipe.encode_prompt(
+             prompt_cloth,
+             num_images_per_prompt=1,
+             do_classifier_free_guidance=False,
+             negative_prompt=negative_prompt,
+         )
+
+     # Preprocess the garment image for the IP-Adapter image encoder
+     clip_processor = CLIPImageProcessor()
+     image_embeds = clip_processor(images=cloth_image, return_tensors="pt").pixel_values.to(device, dtype=weight_dtype)
+
+     # Generate the image
+     generator = torch.Generator(pipe.device).manual_seed(seed) if seed is not None else None
+
+     with torch.no_grad():
+         images = pipe(
+             prompt_embeds=prompt_embeds,
+             negative_prompt_embeds=negative_prompt_embeds,
+             pooled_prompt_embeds=pooled_prompt_embeds,
+             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+             num_inference_steps=args.num_inference_steps,
+             generator=generator,
+             strength=1.0,
+             pose_img=pose_tensor,
+             text_embeds_cloth=prompt_embeds_c,
+             cloth=cloth_pure,
+             mask_image=mask_tensor,
+             image=(person_tensor + 1.0) / 2.0,
+             height=args.height,
+             width=args.width,
+             guidance_scale=guidance_scale,
+             ip_adapter_image=image_embeds,
+         )[0]
+
+     # The pipeline returns PIL images by default, so the first one can be shown directly
+     generated_image = images[0]
+     return generated_image
+
+ # Create the Gradio interface
+ iface = gr.Interface(
+     fn=generate_virtual_try_on,
+     inputs=[
+         gr.Image(type="pil", label="Person Image"),
+         gr.Image(type="pil", label="Cloth Image"),
+         gr.Image(type="pil", label="Mask Image"),
+         gr.Image(type="pil", label="Pose Image"),
+         gr.Textbox(label="Cloth Description"),  # free-text garment description (cloth_des)
+     ],
+     outputs=gr.Image(type="pil", label="Generated Image"),
+ )
+
+ # Launch the interface
+ iface.launch()
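+ # For local testing outside a Space, iface.launch(share=True) could optionally be used
+ # to expose a temporary public link (standard Gradio option).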