Update app.py
app.py CHANGED
@@ -14,127 +14,124 @@
 # limitations under the License.
 import spaces
 
-import torch.multiprocessing as mp
-mp.set_start_method('spawn', force=True)
-
 import tempfile
 from PIL import Image
 import torch
 import gradio as gr
-import string
-import random, time, math
-
-from src.flux.generate import generate_from_test_sample, seed_everything
-from src.flux.pipeline_tools import CustomFluxPipeline, load_modulation_adapter, load_dit_lora
-from src.utils.data_utils import get_train_config, image_grid, pil2tensor, json_dump, pad_to_square, cv2pil, merge_bboxes
-from eval.tools.face_id import FaceID
-from eval.tools.florence_sam import ObjectDetector
-import shutil
-import yaml
-import numpy as np
-from huggingface_hub import snapshot_download, hf_hub_download
-import os
-
-# FLUX.1-dev
-snapshot_download(
-    repo_id="black-forest-labs/FLUX.1-dev",
-    local_dir="./checkpoints/FLUX.1-dev",
-    local_dir_use_symlinks=False
-)
-
-# Florence-2-large
-snapshot_download(
-    repo_id="microsoft/Florence-2-large",
-    local_dir="./checkpoints/Florence-2-large",
-    local_dir_use_symlinks=False
-)
-
-# CLIP ViT Large
-snapshot_download(
-    repo_id="openai/clip-vit-large-patch14",
-    local_dir="./checkpoints/clip-vit-large-patch14",
-    local_dir_use_symlinks=False
-)
-
-# DINO ViT-s16
-snapshot_download(
-    repo_id="facebook/dino-vits16",
-    local_dir="./checkpoints/dino-vits16",
-    local_dir_use_symlinks=False
-)
-
-# mPLUG Visual Question Answering
-snapshot_download(
-    repo_id="xingjianleng/mplug_visual-question-answering_coco_large_en",
-    local_dir="./checkpoints/mplug_visual-question-answering_coco_large_en",
-    local_dir_use_symlinks=False
-)
-
-# XVerse
-snapshot_download(
-    repo_id="ByteDance/XVerse",
-    local_dir="./checkpoints/XVerse",
-    local_dir_use_symlinks=False
-)
-
-hf_hub_download(
-    repo_id="facebook/sam2.1-hiera-large",
-    local_dir="./checkpoints/",
-    filename="sam2.1_hiera_large.pt",
-)
-
-
-
-os.environ["FLORENCE2_MODEL_PATH"] = "./checkpoints/Florence-2-large"
-os.environ["SAM2_MODEL_PATH"] = "./checkpoints/sam2.1_hiera_large.pt"
-os.environ["FACE_ID_MODEL_PATH"] = "./checkpoints/model_ir_se50.pth"
-os.environ["CLIP_MODEL_PATH"] = "./checkpoints/clip-vit-large-patch14"
-os.environ["FLUX_MODEL_PATH"] = "./checkpoints/FLUX.1-dev"
-os.environ["DPG_VQA_MODEL_PATH"] = "./checkpoints/mplug_visual-question-answering_coco_large_en"
-os.environ["DINO_MODEL_PATH"] = "./checkpoints/dino-vits16"
-
-dtype = torch.bfloat16
-device = "cuda"
-
-config_path = "train/config/XVerse_config_demo.yaml"
-
-config = config_train = get_train_config(config_path)
-# config["model"]["dit_quant"] = "int8-quanto"
-config["model"]["use_dit_lora"] = False
-model = CustomFluxPipeline(
-    config, device, torch_dtype=dtype,
-)
-model.pipe.set_progress_bar_config(leave=False)
-
-face_model = FaceID(device)
-detector = ObjectDetector(device)
-
-config = get_train_config(config_path)
-model.config = config
-
-run_mode = "mod_only" # orig_only, mod_only, both
-store_attn_map = False
-run_name = time.strftime("%m%d-%H%M")
-
-num_inputs = 6
-
-ckpt_root = "./checkpoints/XVerse"
-model.clear_modulation_adapters()
-model.pipe.unload_lora_weights()
-if not os.path.exists(ckpt_root):
-    print("Checkpoint root does not exist.")
-
-modulation_adapter = load_modulation_adapter(model, config, dtype, device, f"{ckpt_root}/modulation_adapter", is_training=False)
-model.add_modulation_adapter(modulation_adapter)
-if config["model"]["use_dit_lora"]:
-    load_dit_lora(model, model.pipe, config, dtype, device, f"{ckpt_root}", is_training=False)
-
-vae_skip_iter = None
-attn_skip_iter = 0
-
-
-def clear_images():
-    return [None, ]*num_inputs
+# import string
+# import random, time, math
+
+# from src.flux.generate import generate_from_test_sample, seed_everything
+# from src.flux.pipeline_tools import CustomFluxPipeline, load_modulation_adapter, load_dit_lora
+# from src.utils.data_utils import get_train_config, image_grid, pil2tensor, json_dump, pad_to_square, cv2pil, merge_bboxes
+# from eval.tools.face_id import FaceID
+# from eval.tools.florence_sam import ObjectDetector
+# import shutil
+# import yaml
+# import numpy as np
+# from huggingface_hub import snapshot_download, hf_hub_download
+# import os
+
+# # FLUX.1-dev
+# snapshot_download(
+#     repo_id="black-forest-labs/FLUX.1-dev",
+#     local_dir="./checkpoints/FLUX.1-dev",
+#     local_dir_use_symlinks=False
+# )
+
+# # Florence-2-large
+# snapshot_download(
+#     repo_id="microsoft/Florence-2-large",
+#     local_dir="./checkpoints/Florence-2-large",
+#     local_dir_use_symlinks=False
+# )
+
+# # CLIP ViT Large
+# snapshot_download(
+#     repo_id="openai/clip-vit-large-patch14",
+#     local_dir="./checkpoints/clip-vit-large-patch14",
+#     local_dir_use_symlinks=False
+# )
+
+# # DINO ViT-s16
+# snapshot_download(
+#     repo_id="facebook/dino-vits16",
+#     local_dir="./checkpoints/dino-vits16",
+#     local_dir_use_symlinks=False
+# )
+
+# # mPLUG Visual Question Answering
+# snapshot_download(
+#     repo_id="xingjianleng/mplug_visual-question-answering_coco_large_en",
+#     local_dir="./checkpoints/mplug_visual-question-answering_coco_large_en",
+#     local_dir_use_symlinks=False
+# )
+
+# # XVerse
+# snapshot_download(
+#     repo_id="ByteDance/XVerse",
+#     local_dir="./checkpoints/XVerse",
+#     local_dir_use_symlinks=False
+# )
+
+# hf_hub_download(
+#     repo_id="facebook/sam2.1-hiera-large",
+#     local_dir="./checkpoints/",
+#     filename="sam2.1_hiera_large.pt",
+# )
+
+
+
+# os.environ["FLORENCE2_MODEL_PATH"] = "./checkpoints/Florence-2-large"
+# os.environ["SAM2_MODEL_PATH"] = "./checkpoints/sam2.1_hiera_large.pt"
+# os.environ["FACE_ID_MODEL_PATH"] = "./checkpoints/model_ir_se50.pth"
+# os.environ["CLIP_MODEL_PATH"] = "./checkpoints/clip-vit-large-patch14"
+# os.environ["FLUX_MODEL_PATH"] = "./checkpoints/FLUX.1-dev"
+# os.environ["DPG_VQA_MODEL_PATH"] = "./checkpoints/mplug_visual-question-answering_coco_large_en"
+# os.environ["DINO_MODEL_PATH"] = "./checkpoints/dino-vits16"
+
+# dtype = torch.bfloat16
+# device = "cuda"
+
+# config_path = "train/config/XVerse_config_demo.yaml"
+
+# config = config_train = get_train_config(config_path)
+# # config["model"]["dit_quant"] = "int8-quanto"
+# config["model"]["use_dit_lora"] = False
+# model = CustomFluxPipeline(
+#     config, device, torch_dtype=dtype,
+# )
+# model.pipe.set_progress_bar_config(leave=False)
+
+# face_model = FaceID(device)
+# detector = ObjectDetector(device)
+
+# config = get_train_config(config_path)
+# model.config = config
+
+# run_mode = "mod_only" # orig_only, mod_only, both
+# store_attn_map = False
+# run_name = time.strftime("%m%d-%H%M")
+
+# num_inputs = 6
+
+# ckpt_root = "./checkpoints/XVerse"
+# model.clear_modulation_adapters()
+# model.pipe.unload_lora_weights()
+# if not os.path.exists(ckpt_root):
+#     print("Checkpoint root does not exist.")
+
+# modulation_adapter = load_modulation_adapter(model, config, dtype, device, f"{ckpt_root}/modulation_adapter", is_training=False)
+# model.add_modulation_adapter(modulation_adapter)
+# if config["model"]["use_dit_lora"]:
+#     load_dit_lora(model, model.pipe, config, dtype, device, f"{ckpt_root}", is_training=False)
+
+# vae_skip_iter = None
+# attn_skip_iter = 0
+
+
+# def clear_images():
+#     return [None, ]*num_inputs
 
 @spaces.GPU()
 def det_seg_img(image, label):
@@ -191,30 +188,6 @@ def resize_keep_aspect_ratio(pil_image, target_size=1024):
     new_W = int(round(W * scaling_factor))
     return pil_image.resize((new_W, new_H))
 
-# Generate the six image inputs in a loop
-images = []
-captions = []
-face_btns = []
-det_btns = []
-vlm_btns = []
-accordions = []
-idip_checkboxes = []
-accordion_states = []
-
-def open_accordion_on_example_selection(*args):
-    print("enter open_accordion_on_example_selection")
-    images = list(args[-18:-12])
-    outputs = []
-    for i, img in enumerate(images):
-        if img is not None:
-            print(f"open accordions {i}")
-            outputs.append(True)
-        else:
-            print(f"close accordions {i}")
-            outputs.append(False)
-    print(outputs)
-    return outputs
-
 @spaces.GPU()
 def generate_image(
     prompt,
@@ -547,8 +520,8 @@ with gr.Blocks() as demo:
         outputs=output
    )
 
-    # Modify the output parameters of the clear function
-    clear_btn.click(clear_images, outputs=images)
+    # # Modify the output parameters of the clear function
+    # clear_btn.click(clear_images, outputs=images)
 
    face_btn_1.click(crop_face_img, inputs=[image_1], outputs=[image_1])
    det_btn_1.click(det_seg_img, inputs=[image_1, caption_1], outputs=[image_1])