Commit 3761983 (verified) by aiqcamp · Parent(s): 05cb457

Delete app.py

Files changed (1): app.py (+0 -498)
app.py DELETED
@@ -1,498 +0,0 @@
- import argparse
- import os
- os.environ['CUDA_HOME'] = '/usr/local/cuda'
- os.environ['PATH'] = os.environ['PATH'] + ':/usr/local/cuda/bin'
- from datetime import datetime
-
- import gradio as gr
- import spaces
- import numpy as np
- import torch
- from diffusers.image_processor import VaeImageProcessor
- from huggingface_hub import snapshot_download
- from PIL import Image
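- # Make torch.jit.script a no-op so downstream model code that uses it loads without TorchScript compilation (assumed intent of this patch).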
- torch.jit.script = lambda f: f
- from model.cloth_masker import AutoMasker, vis_mask
- from model.pipeline import CatVTONPipeline
- from utils import init_weight_dtype, resize_and_crop, resize_and_padding
-
- def parse_args():
-     parser = argparse.ArgumentParser(description="Gradio demo for CatVTON virtual try-on.")
-     parser.add_argument(
-         "--base_model_path",
-         type=str,
-         default="booksforcharlie/stable-diffusion-inpainting",
-         help=(
-             "The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
-         ),
-     )
-     parser.add_argument(
-         "--resume_path",
-         type=str,
-         default="zhengchong/CatVTON",
-         help=(
-             "The path to the checkpoint of the trained try-on model."
-         ),
-     )
-     parser.add_argument(
-         "--output_dir",
-         type=str,
-         default="resource/demo/output",
-         help="The output directory where the model predictions will be written.",
-     )
-     parser.add_argument(
-         "--width",
-         type=int,
-         default=768,
-         help=(
-             "The resolution for input images; all input images will be resized to this"
-             " resolution."
-         ),
-     )
-     parser.add_argument(
-         "--height",
-         type=int,
-         default=1024,
-         help=(
-             "The resolution for input images; all input images will be resized to this"
-             " resolution."
-         ),
-     )
-     parser.add_argument(
-         "--repaint",
-         action="store_true",
-         help="Whether to repaint the result image with the original background."
-     )
-     parser.add_argument(
-         "--allow_tf32",
-         action="store_true",
-         default=True,
-         help=(
-             "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up inference. For more information, see"
-             " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
-         ),
-     )
-     parser.add_argument(
-         "--mixed_precision",
-         type=str,
-         default="bf16",
-         choices=["no", "fp16", "bf16"],
-         help=(
-             "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
-             " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
-             " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
-         ),
-     )
-
-     args = parser.parse_args()
-     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
-     if env_local_rank != -1 and env_local_rank != getattr(args, "local_rank", -1):
-         args.local_rank = env_local_rank
-
-     return args
-
- def image_grid(imgs, rows, cols):
-     assert len(imgs) == rows * cols
-
-     w, h = imgs[0].size
-     grid = Image.new("RGB", size=(cols * w, rows * h))
-
-     for i, img in enumerate(imgs):
-         grid.paste(img, box=(i % cols * w, i // cols * h))
-     return grid
-
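- # One-time startup: download the CatVTON checkpoint repo from the Hub, then build the try-on pipeline and the automatic masker.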
- args = parse_args()
- repo_path = snapshot_download(repo_id=args.resume_path)
-
- # Pipeline
- pipeline = CatVTONPipeline(
-     base_ckpt=args.base_model_path,
-     attn_ckpt=repo_path,
-     attn_ckpt_version="mix",
-     weight_dtype=init_weight_dtype(args.mixed_precision),
-     use_tf32=args.allow_tf32,
-     device='cuda'
- )
-
- # AutoMasker
- mask_processor = VaeImageProcessor(
-     vae_scale_factor=8,
-     do_normalize=False,
-     do_binarize=True,
-     do_convert_grayscale=True
- )
- automasker = AutoMasker(
-     densepose_ckpt=os.path.join(repo_path, "DensePose"),
-     schp_ckpt=os.path.join(repo_path, "SCHP"),
-     device='cuda',
- )
-
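- # On Hugging Face Spaces (ZeroGPU), the @spaces.GPU decorator requests a GPU for each call, with a budget of 120 seconds.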
- @spaces.GPU(duration=120)
- def submit_function(
-     person_image,
-     cloth_image,
-     cloth_type,
-     num_inference_steps,
-     guidance_scale,
-     seed,
-     show_type
- ):
-     # Split the ImageEditor payload into the background image and the first drawn layer (the mask).
-     person_image, mask = person_image["background"], person_image["layers"][0]
-     mask = Image.open(mask).convert("L")
-
-     # If the mask is entirely black (nothing was drawn), treat it as None; otherwise binarize it to 0/255.
-     if len(np.unique(np.array(mask))) == 1:
-         mask = None
-     else:
-         mask = np.array(mask)
-         mask[mask > 0] = 255
-         mask = Image.fromarray(mask)
-
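-     # Results are saved as <output_dir>/<YYYYMMDD>/<HHMMSS>.png.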
-     tmp_folder = args.output_dir
-     date_str = datetime.now().strftime("%Y%m%d%H%M%S")
-     result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
-     if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
-         os.makedirs(os.path.join(tmp_folder, date_str[:8]))
-
-     generator = None
-     if seed != -1:
-         generator = torch.Generator(device='cuda').manual_seed(seed)
-
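-     # Bring both inputs to the target resolution: the person photo is resized and cropped, the garment is resized and padded.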
-     person_image = Image.open(person_image).convert("RGB")
-     cloth_image = Image.open(cloth_image).convert("RGB")
-     person_image = resize_and_crop(person_image, (args.width, args.height))
-     cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
-
-     # Use the hand-drawn mask if one exists; otherwise let the AutoMasker derive it from the person image.
-     if mask is not None:
-         mask = resize_and_crop(mask, (args.width, args.height))
-     else:
-         mask = automasker(
-             person_image,
-             cloth_type
-         )['mask']
-     mask = mask_processor.blur(mask, blur_factor=9)
-
-     # Inference
-     result_image = pipeline(
-         image=person_image,
-         condition_image=cloth_image,
-         mask=mask,
-         num_inference_steps=num_inference_steps,
-         guidance_scale=guidance_scale,
-         generator=generator
-     )[0]
-
-     # Post-process & Save
-     masked_person = vis_mask(person_image, mask)
-     save_result_image = image_grid([person_image, masked_person, cloth_image, result_image], 1, 4)
-     save_result_image.save(result_save_path)
-
-     if show_type == "result only":
-         return result_image
-     else:
-         width, height = person_image.size
-         if show_type == "input & result":
-             condition_width = width // 2
-             conditions = image_grid([person_image, cloth_image], 2, 1)
-         else:
-             condition_width = width // 3
-             conditions = image_grid([person_image, masked_person, cloth_image], 3, 1)
-
-         conditions = conditions.resize((condition_width, height), Image.NEAREST)
-         new_result_image = Image.new("RGB", (width + condition_width + 5, height))
-         new_result_image.paste(conditions, (0, 0))
-         new_result_image.paste(result_image, (condition_width + 5, 0))
-         return new_result_image
-
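- # Forwards a selected example path to the person ImageEditor (used by the hidden example-image input in the UI).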
- def person_example_fn(image_path):
-     return image_path
-
- # Custom CSS
- css = """
- footer {visibility: hidden}
-
- /* Main container styling */
- .gradio-container {
-     background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
-     border-radius: 20px;
-     box-shadow: 0 8px 32px rgba(31, 38, 135, 0.15);
- }
-
- /* Header styling */
- h1, h2, h3 {
-     color: #2c3e50;
-     font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
-     text-shadow: 1px 1px 2px rgba(0,0,0,0.1);
- }
-
- /* Button styling */
- button.primary-button {
-     background: linear-gradient(45deg, #4CAF50, #45a049);
-     border: none;
-     border-radius: 10px;
-     color: white;
-     padding: 12px 24px;
-     font-weight: bold;
-     transition: all 0.3s ease;
-     box-shadow: 0 4px 15px rgba(76, 175, 80, 0.3);
- }
-
- button.primary-button:hover {
-     transform: translateY(-2px);
-     box-shadow: 0 6px 20px rgba(76, 175, 80, 0.4);
- }
-
- /* Image container styling */
- .image-container {
-     border-radius: 15px;
-     overflow: hidden;
-     box-shadow: 0 4px 15px rgba(0,0,0,0.1);
-     transition: transform 0.3s ease;
- }
-
- .image-container:hover {
-     transform: scale(1.02);
- }
-
- /* Radio button styling */
- .radio-group label {
-     background-color: #ffffff;
-     border-radius: 8px;
-     padding: 10px 15px;
-     margin: 5px;
-     cursor: pointer;
-     transition: all 0.3s ease;
- }
-
- .radio-group input:checked + label {
-     background-color: #4CAF50;
-     color: white;
- }
-
- /* Slider styling */
- .slider-container {
-     background: white;
-     padding: 15px;
-     border-radius: 10px;
-     box-shadow: 0 2px 10px rgba(0,0,0,0.05);
- }
-
- .slider {
-     height: 8px;
-     border-radius: 4px;
-     background: #e0e0e0;
- }
-
- .slider .thumb {
-     width: 20px;
-     height: 20px;
-     background: #4CAF50;
-     border-radius: 50%;
-     box-shadow: 0 2px 5px rgba(0,0,0,0.2);
- }
-
- /* Alert/warning text styling */
- .warning-text {
-     color: #ff5252;
-     font-weight: bold;
-     text-align: center;
-     padding: 10px;
-     background: rgba(255,82,82,0.1);
-     border-radius: 8px;
-     margin: 10px 0;
- }
-
- /* Example gallery styling */
- .example-gallery {
-     display: grid;
-     grid-template-columns: repeat(auto-fit, minmax(150px, 1fr));
-     gap: 15px;
-     padding: 15px;
-     background: white;
-     border-radius: 10px;
-     box-shadow: 0 2px 10px rgba(0,0,0,0.05);
- }
-
- .example-item {
-     border-radius: 8px;
-     overflow: hidden;
-     transition: transform 0.3s ease;
- }
-
- .example-item:hover {
-     transform: scale(1.05);
- }
- """
-
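- # Build the Gradio Blocks UI: uploads and settings in the left column, the generated result on the right, example galleries underneath.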
- def app_gradio():
-     with gr.Blocks(theme=gr.themes.Soft(primary_hue="green", secondary_hue="blue"), css=css) as demo:
-         gr.Markdown(
-             """
-             # 👔 Fashion Fit
-             Transform your look with AI-powered virtual clothing try-on!
-             """
-         )
-
-         with gr.Row():
-             with gr.Column(scale=1, min_width=350):
-                 with gr.Group():
-                     gr.Markdown("### 📸 Upload Images")
-                     with gr.Row():
-                         image_path = gr.Image(
-                             type="filepath",
-                             interactive=True,
-                             visible=False,
-                         )
-                         person_image = gr.ImageEditor(
-                             interactive=True,
-                             label="Person Image",
-                             type="filepath",
-                             elem_classes="image-container"
-                         )
-
-                     with gr.Row():
-                         with gr.Column(scale=1, min_width=230):
-                             cloth_image = gr.Image(
-                                 interactive=True,
-                                 label="Clothing Item",
-                                 type="filepath",
-                                 elem_classes="image-container"
-                             )
-                         with gr.Column(scale=1, min_width=120):
-
-                             cloth_type = gr.Radio(
-                                 label="Clothing Type",
-                                 choices=["upper", "lower", "overall"],
-                                 value="upper",
-                                 elem_classes="radio-group"
-                             )
-
-                 submit = gr.Button("🚀 Generate Try-On", elem_classes="primary-button")
-
-
-                 with gr.Accordion("⚙️ Advanced Settings", open=False):
-                     num_inference_steps = gr.Slider(
-                         label="Quality Level",
-                         minimum=10,
-                         maximum=100,
-                         step=5,
-                         value=50,
-                         elem_classes="slider-container"
-                     )
-                     guidance_scale = gr.Slider(
-                         label="Style Strength",
-                         minimum=0.0,
-                         maximum=7.5,
-                         step=0.5,
-                         value=2.5,
-                         elem_classes="slider-container"
-                     )
-                     seed = gr.Slider(
-                         label="Random Seed",
-                         minimum=-1,
-                         maximum=10000,
-                         step=1,
-                         value=42,
-                         elem_classes="slider-container"
-                     )
-                     show_type = gr.Radio(
-                         label="Display Mode",
-                         choices=["result only", "input & result", "input & mask & result"],
-                         value="input & mask & result",
-                         elem_classes="radio-group"
-                     )
-
-             with gr.Column(scale=2, min_width=500):
-                 result_image = gr.Image(
-                     interactive=False,
-                     label="Final Result",
-                     elem_classes="image-container"
-                 )
-         with gr.Row():
-             root_path = "resource/demo/example"
-             with gr.Column():
-                 gr.Markdown("#### 👤 Model Examples")
-                 # The elem_classes argument must be omitted here; passing it to gr.Examples raises an error.
-                 men_exm = gr.Examples(
-                     examples=[
-                         os.path.join(root_path, "person", "men", file)
-                         for file in os.listdir(os.path.join(root_path, "person", "men"))
-                     ],
-                     examples_per_page=4,
-                     inputs=image_path,
-                     label="Men's Examples"
-                 )
-                 women_exm = gr.Examples(
-                     examples=[
-                         os.path.join(root_path, "person", "women", file)
-                         for file in os.listdir(os.path.join(root_path, "person", "women"))
-                     ],
-                     examples_per_page=4,
-                     inputs=image_path,
-                     label="Women's Examples"
-                 )
-                 gr.Markdown(
-                     '<div class="info-text">Model examples courtesy of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a></div>'
-                 )
-
-             with gr.Column():
-                 gr.Markdown("#### 👕 Clothing Examples")
-                 condition_upper_exm = gr.Examples(
-                     examples=[
-                         os.path.join(root_path, "condition", "upper", file)
-                         for file in os.listdir(os.path.join(root_path, "condition", "upper"))
-                     ],
-                     examples_per_page=4,
-                     inputs=cloth_image,
-                     label="Upper Garments"
-                 )
-                 condition_overall_exm = gr.Examples(
-                     examples=[
-                         os.path.join(root_path, "condition", "overall", file)
-                         for file in os.listdir(os.path.join(root_path, "condition", "overall"))
-                     ],
-                     examples_per_page=4,
-                     inputs=cloth_image,
-                     label="Full Outfits"
-                 )
-                 condition_person_exm = gr.Examples(
-                     examples=[
-                         os.path.join(root_path, "condition", "person", file)
-                         for file in os.listdir(os.path.join(root_path, "condition", "person"))
-                     ],
-                     examples_per_page=4,
-                     inputs=cloth_image,
-                     label="Reference Styles"
-                 )
-                 gr.Markdown(
-                     '<div class="info-text">Clothing examples sourced from various online retailers</div>'
-                 )
-
-         image_path.change(
-             person_example_fn,
-             inputs=image_path,
-             outputs=person_image
-         )
-
-         submit.click(
-             submit_function,
-             [
-                 person_image,
-                 cloth_image,
-                 cloth_type,
-                 num_inference_steps,
-                 guidance_scale,
-                 seed,
-                 show_type,
-             ],
-             result_image,
-         )
-
-
-
-     demo.queue().launch(share=True, show_error=True)
-
- if __name__ == "__main__":
-     app_gradio()