alexnasa committed on
Commit f7eba29 · verified · 1 Parent(s): 9e7749e

Update app.py

Files changed (1)
  1. app.py +116 -143
app.py CHANGED
@@ -14,127 +14,124 @@
 # limitations under the License.
 import spaces
 
-import torch.multiprocessing as mp
-mp.set_start_method('spawn', force=True)
-
 import tempfile
 from PIL import Image
 import torch
 import gradio as gr
-import string
-import random, time, math
-
-from src.flux.generate import generate_from_test_sample, seed_everything
-from src.flux.pipeline_tools import CustomFluxPipeline, load_modulation_adapter, load_dit_lora
-from src.utils.data_utils import get_train_config, image_grid, pil2tensor, json_dump, pad_to_square, cv2pil, merge_bboxes
-from eval.tools.face_id import FaceID
-from eval.tools.florence_sam import ObjectDetector
-import shutil
-import yaml
-import numpy as np
-from huggingface_hub import snapshot_download, hf_hub_download
-import os
-
-# FLUX.1-dev
-snapshot_download(
-    repo_id="black-forest-labs/FLUX.1-dev",
-    local_dir="./checkpoints/FLUX.1-dev",
-    local_dir_use_symlinks=False
-)
-
-# Florence-2-large
-snapshot_download(
-    repo_id="microsoft/Florence-2-large",
-    local_dir="./checkpoints/Florence-2-large",
-    local_dir_use_symlinks=False
-)
-
-# CLIP ViT Large
-snapshot_download(
-    repo_id="openai/clip-vit-large-patch14",
-    local_dir="./checkpoints/clip-vit-large-patch14",
-    local_dir_use_symlinks=False
-)
-
-# DINO ViT-s16
-snapshot_download(
-    repo_id="facebook/dino-vits16",
-    local_dir="./checkpoints/dino-vits16",
-    local_dir_use_symlinks=False
-)
-
-# mPLUG Visual Question Answering
-snapshot_download(
-    repo_id="xingjianleng/mplug_visual-question-answering_coco_large_en",
-    local_dir="./checkpoints/mplug_visual-question-answering_coco_large_en",
-    local_dir_use_symlinks=False
-)
-
-# XVerse
-snapshot_download(
-    repo_id="ByteDance/XVerse",
-    local_dir="./checkpoints/XVerse",
-    local_dir_use_symlinks=False
-)
-
-hf_hub_download(
-    repo_id="facebook/sam2.1-hiera-large",
-    local_dir="./checkpoints/",
-    filename="sam2.1_hiera_large.pt",
-)
-
-
-
-os.environ["FLORENCE2_MODEL_PATH"] = "./checkpoints/Florence-2-large"
-os.environ["SAM2_MODEL_PATH"] = "./checkpoints/sam2.1_hiera_large.pt"
-os.environ["FACE_ID_MODEL_PATH"] = "./checkpoints/model_ir_se50.pth"
-os.environ["CLIP_MODEL_PATH"] = "./checkpoints/clip-vit-large-patch14"
-os.environ["FLUX_MODEL_PATH"] = "./checkpoints/FLUX.1-dev"
-os.environ["DPG_VQA_MODEL_PATH"] = "./checkpoints/mplug_visual-question-answering_coco_large_en"
-os.environ["DINO_MODEL_PATH"] = "./checkpoints/dino-vits16"
-
-dtype = torch.bfloat16
-device = "cuda"
-
-config_path = "train/config/XVerse_config_demo.yaml"
-
-config = config_train = get_train_config(config_path)
-# config["model"]["dit_quant"] = "int8-quanto"
-config["model"]["use_dit_lora"] = False
-model = CustomFluxPipeline(
-    config, device, torch_dtype=dtype,
-)
-model.pipe.set_progress_bar_config(leave=False)
-
-face_model = FaceID(device)
-detector = ObjectDetector(device)
-
-config = get_train_config(config_path)
-model.config = config
-
-run_mode = "mod_only" # orig_only, mod_only, both
-store_attn_map = False
-run_name = time.strftime("%m%d-%H%M")
-
-num_inputs = 6
-
-ckpt_root = "./checkpoints/XVerse"
-model.clear_modulation_adapters()
-model.pipe.unload_lora_weights()
-if not os.path.exists(ckpt_root):
-    print("Checkpoint root does not exist.")
-
-modulation_adapter = load_modulation_adapter(model, config, dtype, device, f"{ckpt_root}/modulation_adapter", is_training=False)
-model.add_modulation_adapter(modulation_adapter)
-if config["model"]["use_dit_lora"]:
-    load_dit_lora(model, model.pipe, config, dtype, device, f"{ckpt_root}", is_training=False)
-
-vae_skip_iter = None
-attn_skip_iter = 0
-
-
-def clear_images():
-    return [None, ]*num_inputs
+# import string
+# import random, time, math
+
+# from src.flux.generate import generate_from_test_sample, seed_everything
+# from src.flux.pipeline_tools import CustomFluxPipeline, load_modulation_adapter, load_dit_lora
+# from src.utils.data_utils import get_train_config, image_grid, pil2tensor, json_dump, pad_to_square, cv2pil, merge_bboxes
+# from eval.tools.face_id import FaceID
+# from eval.tools.florence_sam import ObjectDetector
+# import shutil
+# import yaml
+# import numpy as np
+# from huggingface_hub import snapshot_download, hf_hub_download
+# import os
+
+# # FLUX.1-dev
+# snapshot_download(
+# repo_id="black-forest-labs/FLUX.1-dev",
+# local_dir="./checkpoints/FLUX.1-dev",
+# local_dir_use_symlinks=False
+# )
+
+# # Florence-2-large
+# snapshot_download(
+# repo_id="microsoft/Florence-2-large",
+# local_dir="./checkpoints/Florence-2-large",
+# local_dir_use_symlinks=False
+# )
+
+# # CLIP ViT Large
+# snapshot_download(
+# repo_id="openai/clip-vit-large-patch14",
+# local_dir="./checkpoints/clip-vit-large-patch14",
+# local_dir_use_symlinks=False
+# )
+
+# # DINO ViT-s16
+# snapshot_download(
+# repo_id="facebook/dino-vits16",
+# local_dir="./checkpoints/dino-vits16",
+# local_dir_use_symlinks=False
+# )
+
+# # mPLUG Visual Question Answering
+# snapshot_download(
+# repo_id="xingjianleng/mplug_visual-question-answering_coco_large_en",
+# local_dir="./checkpoints/mplug_visual-question-answering_coco_large_en",
+# local_dir_use_symlinks=False
+# )
+
+# # XVerse
+# snapshot_download(
+# repo_id="ByteDance/XVerse",
+# local_dir="./checkpoints/XVerse",
+# local_dir_use_symlinks=False
+# )
+
+# hf_hub_download(
+# repo_id="facebook/sam2.1-hiera-large",
+# local_dir="./checkpoints/",
+# filename="sam2.1_hiera_large.pt",
+# )
+
+
+
+# os.environ["FLORENCE2_MODEL_PATH"] = "./checkpoints/Florence-2-large"
+# os.environ["SAM2_MODEL_PATH"] = "./checkpoints/sam2.1_hiera_large.pt"
+# os.environ["FACE_ID_MODEL_PATH"] = "./checkpoints/model_ir_se50.pth"
+# os.environ["CLIP_MODEL_PATH"] = "./checkpoints/clip-vit-large-patch14"
+# os.environ["FLUX_MODEL_PATH"] = "./checkpoints/FLUX.1-dev"
+# os.environ["DPG_VQA_MODEL_PATH"] = "./checkpoints/mplug_visual-question-answering_coco_large_en"
+# os.environ["DINO_MODEL_PATH"] = "./checkpoints/dino-vits16"
+
+# dtype = torch.bfloat16
+# device = "cuda"
+
+# config_path = "train/config/XVerse_config_demo.yaml"
+
+# config = config_train = get_train_config(config_path)
+# # config["model"]["dit_quant"] = "int8-quanto"
+# config["model"]["use_dit_lora"] = False
+# model = CustomFluxPipeline(
+# config, device, torch_dtype=dtype,
+# )
+# model.pipe.set_progress_bar_config(leave=False)
+
+# face_model = FaceID(device)
+# detector = ObjectDetector(device)
+
+# config = get_train_config(config_path)
+# model.config = config
+
+# run_mode = "mod_only" # orig_only, mod_only, both
+# store_attn_map = False
+# run_name = time.strftime("%m%d-%H%M")
+
+# num_inputs = 6
+
+# ckpt_root = "./checkpoints/XVerse"
+# model.clear_modulation_adapters()
+# model.pipe.unload_lora_weights()
+# if not os.path.exists(ckpt_root):
+# print("Checkpoint root does not exist.")
+
+# modulation_adapter = load_modulation_adapter(model, config, dtype, device, f"{ckpt_root}/modulation_adapter", is_training=False)
+# model.add_modulation_adapter(modulation_adapter)
+# if config["model"]["use_dit_lora"]:
+# load_dit_lora(model, model.pipe, config, dtype, device, f"{ckpt_root}", is_training=False)
+
+# vae_skip_iter = None
+# attn_skip_iter = 0
+
+
+# def clear_images():
+# return [None, ]*num_inputs
 
 @spaces.GPU()
 def det_seg_img(image, label):
@@ -191,30 +188,6 @@ def resize_keep_aspect_ratio(pil_image, target_size=1024):
     new_W = int(round(W * scaling_factor))
     return pil_image.resize((new_W, new_H))
 
-# Generate the six image inputs in a loop
-images = []
-captions = []
-face_btns = []
-det_btns = []
-vlm_btns = []
-accordions = []
-idip_checkboxes = []
-accordion_states = []
-
-def open_accordion_on_example_selection(*args):
-    print("enter open_accordion_on_example_selection")
-    images = list(args[-18:-12])
-    outputs = []
-    for i, img in enumerate(images):
-        if img is not None:
-            print(f"open accordions {i}")
-            outputs.append(True)
-        else:
-            print(f"close accordions {i}")
-            outputs.append(False)
-    print(outputs)
-    return outputs
-
 @spaces.GPU()
 def generate_image(
     prompt,
@@ -547,8 +520,8 @@ with gr.Blocks() as demo:
 outputs=output
 )
 
-# Update the clear function's output parameters
-clear_btn.click(clear_images, outputs=images)
+# # Update the clear function's output parameters
+# clear_btn.click(clear_images, outputs=images)
 
 face_btn_1.click(crop_face_img, inputs=[image_1], outputs=[image_1])
 det_btn_1.click(det_seg_img, inputs=[image_1, caption_1], outputs=[image_1])
 
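Taken together, the hunks remove the module-level setup from app.py: the torch.multiprocessing spawn configuration is deleted, while the checkpoint downloads (snapshot_download / hf_hub_download), the model-path environment variables, and the CustomFluxPipeline / FaceID / ObjectDetector construction are commented out, along with the clear_btn.click wiring that depended on them. The sketch below is illustrative only and is not part of this commit: it shows one way such setup could instead be deferred until the first @spaces.GPU() call. The get_pipeline and _download_checkpoints helpers are hypothetical, and only a subset of the original downloads is shown.

# Illustrative sketch, not from the commit: lazily initialize the pipeline on
# first use inside a GPU-scoped function instead of at import time.
import os
import torch
import spaces
from huggingface_hub import snapshot_download

_PIPELINE = None  # cached so downloads and model construction run only once

def _download_checkpoints():
    # A subset of the repositories the original module-level code fetched eagerly.
    snapshot_download(repo_id="black-forest-labs/FLUX.1-dev",
                      local_dir="./checkpoints/FLUX.1-dev")
    snapshot_download(repo_id="ByteDance/XVerse",
                      local_dir="./checkpoints/XVerse")
    os.environ["FLUX_MODEL_PATH"] = "./checkpoints/FLUX.1-dev"

def get_pipeline():
    global _PIPELINE
    if _PIPELINE is None:
        _download_checkpoints()
        # Deferred imports keep app start-up light; these modules come from
        # the XVerse repository, as in the original file.
        from src.flux.pipeline_tools import CustomFluxPipeline
        from src.utils.data_utils import get_train_config
        config = get_train_config("train/config/XVerse_config_demo.yaml")
        config["model"]["use_dit_lora"] = False
        _PIPELINE = CustomFluxPipeline(config, "cuda", torch_dtype=torch.bfloat16)
    return _PIPELINE

@spaces.GPU()
def generate_image(prompt):
    model = get_pipeline()  # heavy setup happens here, on the GPU worker
    return model  # placeholder: the real app would run the FLUX pipeline here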