Spaces: Saad0KH (Running on Zero)

Saad0KH committed · verified
Commit 004975c · Parent: b19f449

Update app.py

Files changed (1):
  app.py (+253, -41)
app.py CHANGED
@@ -1,51 +1,263 @@
 
  from flask import Flask, request, jsonify
  from PIL import Image
- import base64
- import io
- import numpy as np
  import torch
- import torch.nn.functional as F
- from torchvision.transforms.functional import normalize
-

- # Initialize Flask app
  app = Flask(__name__)

- # Function to decode a base64 encoded image to a PIL image
- def decode_image_from_base64(image_data):
-     encoded_image = image_data.split(",")[1]
-     decoded_image = base64.b64decode(encoded_image)
-     image = Image.open(io.BytesIO(decoded_image))
-     return image
-
- # Function to encode a PIL image to base64
- def encode_image_to_base64(image):
-     buffered = io.BytesIO()
-     image.save(buffered, format="PNG")
-     encoded_image = base64.b64encode(buffered.getvalue()).decode("utf-8")
-     return "data:image/png;base64," + encoded_image
-
- # Function to process the image
- def process(image_data):
-     image = decode_image_from_base64(image_data)
-
-     return image
-
- @app.route("/")
- def root():
-     return "Welcome to StyleSync Outfit Backround API!"
-
- # Route for the REST API
- @app.route('/api/tryon', methods=['POST'])
- def classify():
      data = request.json
-     print(data)
-     image_data = data['image']
-     result_image = process(image_data)
-     result_base64 = encode_image_to_base64(result_image)
-     return jsonify({'result': result_base64})

  if __name__ == "__main__":
-     app.run(host="0.0.0.0", port=7860)
 
+ import os
  from flask import Flask, request, jsonify
  from PIL import Image
+ from io import BytesIO
  import torch
+ import gradio as gr
+ import numpy as np
+ from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
+ from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
+ from src.unet_hacked_tryon import UNet2DConditionModel
+ from transformers import (
+     CLIPImageProcessor,
+     CLIPVisionModelWithProjection,
+     CLIPTextModel,
+     CLIPTextModelWithProjection,
+     AutoTokenizer,
+ )
+ from diffusers import DDPMScheduler, AutoencoderKL
+ from utils_mask import get_mask_location
+ from torchvision import transforms
+ import apply_net
+ from preprocess.humanparsing.run_parsing import Parsing
+ from preprocess.openpose.run_openpose import OpenPose
+ from detectron2.data.detection_utils import convert_PIL_to_numpy, _apply_exif_orientation
+ from torchvision.transforms.functional import to_pil_image

  app = Flask(__name__)

+ base_path = 'yisol/IDM-VTON'
+ example_path = os.path.join(os.path.dirname(__file__), 'example')
+
+ unet = UNet2DConditionModel.from_pretrained(
+     base_path,
+     subfolder="unet",
+     torch_dtype=torch.float16,
+ )
+ unet.requires_grad_(False)
+ tokenizer_one = AutoTokenizer.from_pretrained(
+     base_path,
+     subfolder="tokenizer",
+     revision=None,
+     use_fast=False,
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+     base_path,
+     subfolder="tokenizer_2",
+     revision=None,
+     use_fast=False,
+ )
+ noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
+
+ text_encoder_one = CLIPTextModel.from_pretrained(
+     base_path,
+     subfolder="text_encoder",
+     torch_dtype=torch.float16,
+ )
+ text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
+     base_path,
+     subfolder="text_encoder_2",
+     torch_dtype=torch.float16,
+ )
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+     base_path,
+     subfolder="image_encoder",
+     torch_dtype=torch.float16,
+ )
+ vae = AutoencoderKL.from_pretrained(base_path,
+     subfolder="vae",
+     torch_dtype=torch.float16,
+ )
+
+ UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
+     base_path,
+     subfolder="unet_encoder",
+     torch_dtype=torch.float16,
+ )
+
+ parsing_model = Parsing(0)
+ openpose_model = OpenPose(0)
+
+ UNet_Encoder.requires_grad_(False)
+ image_encoder.requires_grad_(False)
+ vae.requires_grad_(False)
+ unet.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ tensor_transfrom = transforms.Compose(
+     [
+         transforms.ToTensor(),
+         transforms.Normalize([0.5], [0.5]),
+     ]
+ )
+
+ pipe = TryonPipeline.from_pretrained(
+     base_path,
+     unet=unet,
+     vae=vae,
+     feature_extractor=CLIPImageProcessor(),
+     text_encoder=text_encoder_one,
+     text_encoder_2=text_encoder_two,
+     tokenizer=tokenizer_one,
+     tokenizer_2=tokenizer_two,
+     scheduler=noise_scheduler,
+     image_encoder=image_encoder,
+     torch_dtype=torch.float16,
+ )
+ pipe.unet_encoder = UNet_Encoder
+
+ def pil_to_binary_mask(pil_image, threshold=0):
+     np_image = np.array(pil_image)
+     grayscale_image = Image.fromarray(np_image).convert("L")
+     binary_mask = np.array(grayscale_image) > threshold
+     mask = np.zeros(binary_mask.shape, dtype=np.uint8)
+     for i in range(binary_mask.shape[0]):
+         for j in range(binary_mask.shape[1]):
+             if binary_mask[i, j]:
+                 mask[i, j] = 1
+     mask = (mask * 255).astype(np.uint8)
+     output_mask = Image.fromarray(mask)
+     return output_mask
+
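Aside: the per-pixel double loop in pil_to_binary_mask can be collapsed into a single vectorized NumPy expression with the same result. A minimal sketch (illustrative only, not part of this commit):

    # Sketch: vectorized equivalent of pil_to_binary_mask (same inputs and outputs).
    def pil_to_binary_mask_fast(pil_image, threshold=0):
        gray = np.array(pil_image.convert("L"))            # grayscale pixel values
        mask = (gray > threshold).astype(np.uint8) * 255   # 255 where above threshold, else 0
        return Image.fromarray(mask)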
+ def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denoise_steps, seed):
+     device = "cuda"
+     openpose_model.preprocessor.body_estimation.model.to(device)
+     pipe.to(device)
+     pipe.unet_encoder.to(device)
+
+     garm_img = garm_img.convert("RGB").resize((768, 1024))
+     human_img_orig = dict["background"].convert("RGB")
+
+     if is_checked_crop:
+         width, height = human_img_orig.size
+         target_width = int(min(width, height * (3 / 4)))
+         target_height = int(min(height, width * (4 / 3)))
+         left = (width - target_width) / 2
+         top = (height - target_height) / 2
+         right = (width + target_width) / 2
+         bottom = (height + target_height) / 2
+         cropped_img = human_img_orig.crop((left, top, right, bottom))
+         crop_size = cropped_img.size
+         human_img = cropped_img.resize((768, 1024))
+     else:
+         human_img = human_img_orig.resize((768, 1024))
+
+     if is_checked:
+         keypoints = openpose_model(human_img.resize((384, 512)))
+         model_parse, _ = parsing_model(human_img.resize((384, 512)))
+         mask, mask_gray = get_mask_location('hd', "upper_body", model_parse, keypoints)
+         mask = mask.resize((768, 1024))
+     else:
+         mask = pil_to_binary_mask(dict['layers'][0].convert("RGB").resize((768, 1024)))
+     mask_gray = (1 - transforms.ToTensor()(mask)) * tensor_transfrom(human_img)
+     mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)
+
+     human_img_arg = _apply_exif_orientation(human_img.resize((384, 512)))
+     human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
+
+     args = apply_net.create_argument_parser().parse_args(('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', 'cuda'))
+     pose_img = args.func(args, human_img_arg)
+     pose_img = pose_img[:, :, ::-1]
+     pose_img = Image.fromarray(pose_img).resize((768, 1024))
+
+     with torch.no_grad():
+         with torch.cuda.amp.autocast():
+             prompt = "model is wearing " + garment_des
+             negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+             with torch.inference_mode():
+                 (
+                     prompt_embeds,
+                     negative_prompt_embeds,
+                     pooled_prompt_embeds,
+                     negative_pooled_prompt_embeds,
+                 ) = pipe.encode_prompt(
+                     prompt,
+                     num_images_per_prompt=1,
+                     do_classifier_free_guidance=True,
+                     negative_prompt=negative_prompt,
+                 )
+
+                 prompt = "a photo of " + garment_des
+                 negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+                 if not isinstance(prompt, list):
+                     prompt = [prompt] * 1
+                 if not isinstance(negative_prompt, list):
+                     negative_prompt = [negative_prompt] * 1
+                 with torch.inference_mode():
+                     (
+                         prompt_embeds_c,
+                         _,
+                         _,
+                         _,
+                     ) = pipe.encode_prompt(
+                         prompt,
+                         num_images_per_prompt=1,
+                         do_classifier_free_guidance=False,
+                         negative_prompt=negative_prompt,
+                     )
+
+                 pose_img = tensor_transfrom(pose_img).unsqueeze(0).to(device, torch.float16)
+                 garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device, torch.float16)
+                 generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
+                 images = pipe(
+                     prompt_embeds=prompt_embeds.to(device, torch.float16),
+                     negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
+                     pooled_prompt_embeds=pooled_prompt_embeds.to(device, torch.float16),
+                     negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device, torch.float16),
+                     num_inference_steps=denoise_steps,
+                     generator=generator,
+                     strength=1.0,
+                     pose_img=pose_img.to(device, torch.float16),
+                     text_embeds_cloth=prompt_embeds_c.to(device, torch.float16),
+                     cloth=garm_tensor.to(device, torch.float16),
+                     mask_image=mask,
+                     image=human_img,
+                     height=1024,
+                     width=768,
+                     ip_adapter_image=garm_img.resize((768, 1024)),
+                     guidance_scale=2.0,
+                 )[0]
+
+     if is_checked_crop:
+         out_img = images[0].resize(crop_size)
+         human_img_orig.paste(out_img, (int(left), int(top)))
+         return human_img_orig, mask_gray
+     else:
+         return images[0], mask_gray
+
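For reference, start_tryon can also be exercised directly from Python without the HTTP layer. A minimal sketch, assuming a CUDA GPU and two local files person.jpg and shirt.jpg (hypothetical file names):

    # Sketch: direct call to start_tryon; file names and description are placeholders.
    person = Image.open("person.jpg")     # photo of the person
    garment = Image.open("shirt.jpg")     # photo of the garment
    editor_dict = {"background": person, "layers": [person], "composite": None}
    result, mask_preview = start_tryon(
        editor_dict, garment, "red cotton t-shirt",
        True,    # is_checked: auto-generate the mask via OpenPose + human parsing
        False,   # is_checked_crop: skip the 3:4 auto-crop
        30,      # denoise_steps
        42,      # seed
    )
    result.save("tryon_result.png")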
+ @app.route('/tryon', methods=['POST'])
+ def tryon():
      data = request.json
+
+     human_image = Image.open(BytesIO(request.files['human_image'].read()))
+     garment_image = Image.open(BytesIO(request.files['garment_image'].read()))
+     description = data.get('description')
+     use_auto_mask = data.get('use_auto_mask', True)
+     use_auto_crop = data.get('use_auto_crop', False)
+     denoise_steps = int(data.get('denoise_steps', 30))
+     seed = int(data.get('seed', 42))
+
+     human_dict = {
+         'background': human_image,
+         'layers': [human_image] if not use_auto_mask else None,
+         'composite': None
+     }
+
+     output_image, mask_image = start_tryon(human_dict, garment_image, description, use_auto_mask, use_auto_crop, denoise_steps, seed)
+
+     output_bytes = BytesIO()
+     output_image.save(output_bytes, format='PNG')
+     output_bytes = output_bytes.getvalue()
+
+     mask_bytes = BytesIO()
+     mask_image.save(mask_bytes, format='PNG')
+     mask_bytes = mask_bytes.getvalue()
+
+     return jsonify({
+         'output_image': output_bytes.hex(),
+         'mask_image': mask_bytes.hex()
+     })

  if __name__ == "__main__":
+     app.run(debug=True, host="0.0.0.0", port=7860)
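Note on the new /tryon route: it reads the scalar options from request.json but the two images from request.files. A single POST cannot carry both a JSON body and multipart file parts, so as written the two sources cannot be populated by one request; one option is to read the options from the multipart form instead. A minimal sketch of that variant (illustrative, not part of this commit):

    # Sketch: read scalar options from the multipart form instead of a JSON body.
    form = request.form
    description = form.get('description', '')
    use_auto_mask = form.get('use_auto_mask', 'true').lower() == 'true'
    use_auto_crop = form.get('use_auto_crop', 'false').lower() == 'true'
    denoise_steps = int(form.get('denoise_steps', 30))
    seed = int(form.get('seed', 42))
    human_image = Image.open(request.files['human_image'].stream)
    garment_image = Image.open(request.files['garment_image'].stream)

A matching client call with the requests library, decoding the hex-encoded PNG bytes the route returns (also a sketch; file names are placeholders):

    # Sketch: multipart client call against the /tryon endpoint on port 7860.
    import requests

    resp = requests.post(
        "http://localhost:7860/tryon",
        files={"human_image": open("person.jpg", "rb"),
               "garment_image": open("shirt.jpg", "rb")},
        data={"description": "red cotton t-shirt", "denoise_steps": "30", "seed": "42"},
    )
    payload = resp.json()
    with open("tryon_result.png", "wb") as f:
        f.write(bytes.fromhex(payload["output_image"]))   # hex string back to PNG bytes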