huangrh9 committed on
Commit
cff4f35
·
verified ·
1 Parent(s): 83e02a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +336 -49
app.py CHANGED
@@ -1,64 +1,351 @@
 
 
 
 
 
 
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
 
 
 
 
 
 
3
 
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
19
 
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
 
 
 
25
 
26
- messages.append({"role": "user", "content": message})
27
 
28
- response = ""
29
 
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
38
 
39
- response += token
40
- yield response
 
 
 
 
 
41
 
 
 
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- demo = gr.ChatInterface(
47
- respond,
48
- additional_inputs=[
49
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
50
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
51
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
52
- gr.Slider(
53
- minimum=0.1,
54
- maximum=1.0,
55
- value=0.95,
56
- step=0.05,
57
- label="Top-p (nucleus sampling)",
58
- ),
59
- ],
60
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
 
63
  if __name__ == "__main__":
64
- demo.launch()
 
1
+ import argparse
2
+ import datetime
3
+ import json
4
+ import os
5
+ import time
6
+ import torch
7
  import gradio as gr
8
+ from PIL import Image
9
+ from tokenizer.sdxl_decoder_pipe import StableDiffusionXLDecoderPipeline
10
+ from torchvision import transforms
11
+ import logging
12
+ from utils.registry_utils import Config
13
+ from tokenizer.builder import build_vq_model
14
+ from dataset.multi_ratio_dataset import get_image_size, assign_ratio
15
 
 
 
 
 
16
 
17
def read_config(file):
    """Load a config file, retrying until a non-empty config is obtained.

    Works around config-loading conflicts when several processes read the
    same file: an empty result is treated as "not ready yet" and the load is
    retried every 0.1 s. There is no retry limit.

    Args:
        file: Path handed to ``Config.fromfile``.

    Returns:
        The first non-empty config returned by ``Config.fromfile``.
    """
    import time

    config = Config.fromfile(file)
    while len(config) == 0:
        time.sleep(0.1)
        config = Config.fromfile(file)
    return config
27
 
 
 
 
 
 
 
 
 
 
28
 
29
def build_logger(name, log_file):
    """Return the INFO-level logger ``name`` that writes to ``log_file``.

    ``logging.getLogger`` returns the same object for the same name, so
    calling this function twice used to attach a second ``FileHandler`` and
    every record was then written twice. The handler is now attached only if
    the logger does not already have a ``FileHandler`` for that file.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        log_file: Path of the log file.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    # FileHandler stores its target as an absolute path in ``baseFilename``.
    target = os.path.abspath(log_file)
    for existing in logger.handlers:
        if isinstance(existing, logging.FileHandler) and existing.baseFilename == target:
            return logger

    handler = logging.FileHandler(log_file)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    return logger
37
 
 
38
 
39
# Module-wide logger; records go to the file gradio_web_server.log.
logger = build_logger("gradio_web_server", "gradio_web_server.log")

# Shared state: written by the loader functions, read by the Gradio callbacks.
vq_model = None  # assigned by load_vqgan_model()
is_ema_model = False  # read by reconstruct_fn to decide whether a reload is needed
diffusion_pipeline = None  # assigned by load_diffusion_decoder()
lazy_load = False  # when True, the VQ model is loaded on the first reconstruction, not at startup

# diffusion decoder hyperparameters.
# Each entry is unpacked as (height, width) by vqgan_diffusion_decoder_reconstruct.
resolution_list = [
    (1024, 1024), (768, 1024), (1024, 768),
    (512, 2048), (2048, 512), (640, 1920),
    (1920, 640), (768, 1536),
    (1536, 768), (768, 1152), (1152, 768)
]

# (minimum, maximum) bounds used for the "CFG Value" and "Inference Steps" sliders.
cfg_range = (1, 10.0)
step_range = (1, 100)
56
 
57
+
58
def resize_to_shortest_edge(img, shortest_edge_resolution):
    """Resize ``img`` so its shorter side equals ``shortest_edge_resolution``.

    The longer side is scaled by the same factor and truncated to an int.
    A square image becomes ``shortest_edge_resolution`` on both sides.

    Args:
        img: Object exposing ``size`` as ``(width, height)`` and ``resize``.
        shortest_edge_resolution: Target length of the shorter side.

    Returns:
        Whatever ``img.resize`` returns for the computed size.
    """
    width, height = img.size
    target = shortest_edge_resolution

    if width == height:
        new_size = (target, target)
    elif width < height:
        new_size = (target, int(height * (target / width)))
    else:
        new_size = (int(width * (target / height)), target)

    return img.resize(new_size)
73
+
74
+
75
+ from PIL import Image
76
+
77
+
78
def resize_to_square_with_long_edge(image: Image.Image, size: int = 512):
    """Resize ``image`` so that its long side equals ``size``, keeping the aspect ratio.

    Despite the name, the result is not square: only the long side is fixed
    and the short side is scaled proportionally (truncated to an int).

    The short side is clamped to at least 1 pixel. Without the clamp a very
    elongated image (e.g. 2048x2 at ``size=384``) truncates to 0, which
    ``Image.resize`` rejects.

    Args:
        image: Source PIL image.
        size: Target length of the long side, in pixels.

    Returns:
        A new LANCZOS-resampled PIL image.
    """
    width, height = image.size
    if width > height:
        new_width = size
        new_height = max(1, int(size * height / width))
    else:
        new_height = size
        new_width = max(1, int(size * width / height))
    return image.resize((new_width, new_height), Image.LANCZOS)
88
+
89
+
90
def pad_to_square(image: Image.Image, target_size: int = 512, color=(255, 255, 255)):
    """Fit ``image`` inside a ``target_size`` x ``target_size`` RGB canvas.

    The image is first resized with ``resize_to_square_with_long_edge`` and
    then pasted centred onto a canvas filled with ``color``.

    Args:
        image: Source PIL image.
        target_size: Side length of the square canvas, in pixels.
        color: Fill colour of the canvas.

    Returns:
        The new square RGB image.
    """
    fitted = resize_to_square_with_long_edge(image, target_size)
    canvas = Image.new("RGB", (target_size, target_size), color)
    left = (target_size - fitted.width) // 2
    top = (target_size - fitted.height) // 2
    canvas.paste(fitted, (left, top))
    return canvas
97
+
98
+
99
def load_vqgan_model(args, model_dtype='fp16', use_ema=False, ):
    """Build the VQ model, load its checkpoint, and store it in the global ``vq_model``.

    Also records in the global ``is_ema_model`` which set of weights is now
    loaded, so callers can tell whether a reload is needed when the user
    toggles the EMA option.

    Args:
        args: Config object. ``args.vq_model`` is passed to ``build_vq_model``
            and ``args.vq_ckpt`` is the checkpoint path given to ``torch.load``.
        model_dtype: ``'fp16'`` or ``'bf16'`` casts the model to that dtype;
            any other value leaves the dtype untouched.
        use_ema: Load the checkpoint's ``"ema"`` weights instead of the
            regular ones.

    Returns:
        The loaded model (the same object as the global ``vq_model``).

    Raises:
        ValueError: If ``use_ema`` is True but the checkpoint has no ``"ema"`` entry.
    """
    global vq_model, is_ema_model

    # Read and validate the checkpoint before replacing the global model, so a
    # bad request cannot leave a randomly initialised model behind.
    # NOTE: torch.load unpickles the file; only use trusted checkpoints.
    checkpoint = torch.load(args.vq_ckpt, map_location="cpu")

    if use_ema:
        if "ema" not in checkpoint:
            raise ValueError(
                f"Checkpoint {args.vq_ckpt!r} has no 'ema' weights; "
                "cannot load the EMA model."
            )
        state_dict = checkpoint["ema"]
    elif "model" in checkpoint:
        state_dict = checkpoint["model"]
    elif "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        # The file is assumed to be a bare state dict.
        state_dict = checkpoint

    vq_model = build_vq_model(args.vq_model)

    if model_dtype == 'fp16':
        vq_model = vq_model.to(torch.float16)
        logger.info("Convert the model dtype to float16")
    elif model_dtype == 'bf16':
        vq_model = vq_model.to(torch.bfloat16)
        logger.info("Convert the model dtype to bfloat16")

    vq_model.to('cuda')
    vq_model.eval()
    vq_model.load_state_dict(state_dict, strict=True)

    # Record which weights are live only after they were loaded successfully.
    is_ema_model = bool(use_ema)

    # NOTE(review): an already-built diffusion_pipeline was constructed with the
    # previous vq_model object and presumably keeps using it - confirm whether
    # the pipeline must be rebuilt after a reload.
    return vq_model
131
+
132
+
133
def load_diffusion_decoder(args):
    """Create the SDXL decoder pipeline and store it in the global ``diffusion_pipeline``.

    The pipeline is loaded from ``args.sdxl_decoder_path``, receives the
    current global ``vq_model`` and ``args`` as its VQ config, and is moved
    to the device the VQ model is on.

    Args:
        args: Config object providing ``sdxl_decoder_path``.
    """
    global diffusion_pipeline

    pipeline_kwargs = dict(
        add_watermarker=False,
        vq_config=args,
        vq_model=vq_model,
    )
    diffusion_pipeline = StableDiffusionXLDecoderPipeline.from_pretrained(
        args.sdxl_decoder_path, **pipeline_kwargs
    )
    diffusion_pipeline.to(vq_model.device)
142
+
143
+
144
@torch.no_grad()
def vqgan_diffusion_decoder_reconstruct(input_image, diffusion_upsample, cfg_values, steps):
    """Reconstruct ``input_image`` through the diffusion decoder pipeline.

    Args:
        input_image: PIL image to reconstruct.
        diffusion_upsample: When truthy, the output is twice the input size
            on each side.
        cfg_values: Forwarded to the pipeline as ``guidance_scale``.
        steps: Forwarded to the pipeline as ``num_inference_steps``.

    Returns:
        Tuple ``(image, markdown)``: the reconstructed PIL image resized to
        the target resolution, and a Markdown string reporting it.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    input_tensor = transform(input_image).unsqueeze(0).to(vq_model.device)

    org_width, org_height = input_image.size
    if diffusion_upsample:
        width, height = org_width * 2, org_height * 2
    else:
        width, height = org_width, org_height

    print(diffusion_upsample, org_width, org_height, width, height)
    # The pipeline runs at one of the resolutions in resolution_list;
    # assign_ratio picks the index for the requested size.
    group_index = assign_ratio(height, width, resolution_list)
    select_h, select_w = resolution_list[group_index]

    diffusion_outputs = diffusion_pipeline(
        images=input_tensor,
        height=select_h,
        width=select_w,
        guidance_scale=cfg_values,
        num_inference_steps=steps
    )
    sample = diffusion_outputs.images[0]
    # Image.resize returns a new image; the result must be kept, otherwise the
    # output stays at the pipeline resolution while the label reports the target.
    sample = sample.resize((width, height))
    return sample, f"�� **Output Resolution**: {width}x{height}"
171
+
172
+
173
@torch.no_grad()
def vqgan_reconstruct(input_image):
    """Encode ``input_image`` with the global VQ model and decode it back.

    Args:
        input_image: PIL image to reconstruct.

    Returns:
        Tuple ``(image, markdown)``: the reconstructed PIL image at the
        original resolution, and a Markdown string reporting that resolution.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    org_width, org_height = input_image.size

    # Snap each side down to a multiple of 16 (presumably the tokenizer's
    # stride - confirm), but never below 16 so tiny inputs do not become 0.
    width = max(16, org_width // 16 * 16)
    height = max(16, org_height // 16 * 16)

    input_image = input_image.resize((width, height))
    input_tensor = transform(input_image).unsqueeze(0).to(vq_model.device)

    # Gradients are already disabled by the decorator.
    inputs = vq_model.get_input(dict(image=input_tensor))
    (quant_semantic, _, _, _), \
        (quant_detail, _, _) = vq_model.encode(**inputs)
    reconstructed_image = vq_model.decode(quant_semantic, quant_detail)

    # Map the decoder output from roughly [-1, 1] to [0, 255].
    reconstructed_image = torch.clamp(127.5 * reconstructed_image + 128.0, 0, 255)
    # .float() first: NumPy has no bfloat16, so .numpy() fails on bf16 tensors.
    reconstructed_image = (
        reconstructed_image.squeeze(0).permute(1, 2, 0).float().cpu().numpy().astype('uint8')
    )

    output_image = Image.fromarray(reconstructed_image)
    # Image.resize returns a new image; keep it so the output really has the
    # original resolution that the label reports.
    output_image = output_image.resize((org_width, org_height))
    return output_image, f"�� **Output Resolution**: {org_width}x{org_height}"
200
+
201
+
202
# Page title and short description rendered at the top of the demo.
title_markdown = '''# DualViTok Demo
The DualViTok is a dual-branch vision tokenizer designed to capture both deep semantics and fine-grained textures. Implementation details can be found in ILLUME+[[ArXiv](https://arxiv.org/abs/2504.01934)].
'''

# Collapsible usage instructions. Option names match the widget labels, and
# every <code> tag is closed so the HTML renders correctly.
usage_markdown = """
<details>
<summary><strong>�� Usage Instructions (click to expand)</strong></summary>

1. Upload an image and click the <strong>Reconstruct</strong> button.
2. Set <code>Max Shortest Side</code> to limit the image resolution.
3. Check <code>Force Upscale to Max Shortest Side</code> to enable <strong>Force Upscale</strong>, which resizes the shortest side of the image to the <code>Max Shortest Side</code>.
4. <em>(Optional)</em> Check <code>Use EMA Model</code> to use the EMA checkpoint for reconstruction.
5. <em>(Optional)</em> Click <code>Load Diffusion Decoder</code> to enable Diffusion Model decoding.
   You can also enable <code>Enable 2x Upsample</code> to apply super-resolution to the uploaded image.

</details>
"""
219
+
220
+
221
def build_gradio_interface(args):
    """Build the Gradio demo UI and launch it on all network interfaces.

    Unless ``lazy_load`` is set, the VQ model is loaded up front. The page
    has three columns: input image, reconstructed image, and
    hyperparameters (including the optional diffusion decoder controls).

    Args:
        args: Config object forwarded to ``load_vqgan_model`` and
            ``load_diffusion_decoder``; ``args.model_dtype`` selects the
            model dtype.
    """
    if not lazy_load:
        load_vqgan_model(args, model_dtype=args.model_dtype)

    with gr.Blocks() as demo:
        gr.Markdown(title_markdown)
        gr.Markdown(usage_markdown)

        with gr.Row():
            with gr.Column():
                gr.Markdown("## ��️ Input Image")
                input_image = gr.Image(type="pil", label="Upload Image", width=384, height=384)
                input_resolution_display = gr.Markdown("")
                gr.Examples(
                    examples=[
                        ["../configs/data_configs/test_data_examples/ImageUnderstandingExample/images/1.png",],
                        ["../configs/data_configs/test_data_examples/ImageUnderstandingExample/images/2.png",],
                        ["../configs/data_configs/test_data_examples/ImageUnderstandingExample/images/3.png",],
                    ],
                    inputs=input_image,
                    label="Example Images",
                )

            with gr.Column():
                gr.Markdown("## �� Reconstructed Image")
                output_image_recon = gr.Image(type="pil", label="Reconstruction", width=384, height=384)
                output_resolution_display = gr.Markdown("")

            with gr.Column():
                gr.Markdown("## ⚙ Hyperparameters")
                short_resolution_dropdown = gr.Dropdown(
                    choices=[None, 256, 384, 512, 1024],
                    value=1024,
                    label="Max Shortest Side"
                )
                force_upscale_checkbox = gr.Checkbox(label="Force Upscale to Max Shortest Side", value=False)
                use_ema_checkbox = gr.Checkbox(label="Use EMA Model", value=False)

                with gr.Accordion("�� Use Diffusion Decoder", open=False):
                    use_diffusion_checkbox = gr.Checkbox(label="Load Diffusion Decoder", value=False)
                    diffusion_upsample_checkbox = gr.Checkbox(label="Enable 2x Upsample", value=False)
                    cfg_slider = gr.Slider(
                        minimum=cfg_range[0], maximum=cfg_range[1],
                        step=0.5, value=1.5,
                        label="CFG Value"
                    )
                    step_slider = gr.Slider(
                        minimum=step_range[0], maximum=step_range[1],
                        step=1, value=20,
                        label="Inference Steps"
                    )
                reconstruct_btn = gr.Button("�� Reconstruct", variant="primary")

        def handle_input_image(image):
            """Normalise the uploaded image to RGB and report its resolution."""
            if image is not None:
                image = image.convert("RGB")
                w, h = image.size
                return image, f"�� **Input Resolution**: {w}x{h}"
            return None, ""

        input_image.change(
            handle_input_image,
            inputs=input_image,
            outputs=[input_image, input_resolution_display]
        )

        def reconstruct_fn(image, use_ema_flag, short_edge_resolution, force_upscale,
                           use_diffusion_flag, diffusion_upsample, cfg_value, num_steps):
            """Run one reconstruction with the options currently selected in the UI."""
            global vq_model

            # Clicking the button before uploading passes None; show a
            # readable message instead of failing on ``image.size``.
            if image is None:
                raise gr.Error("Please upload an image before clicking Reconstruct.")

            if short_edge_resolution is not None:
                # Downscale when the image exceeds the limit; upscale only on request.
                if force_upscale or min(image.size) > short_edge_resolution:
                    image = resize_to_shortest_edge(image, int(short_edge_resolution))

            if lazy_load and vq_model is None:
                load_vqgan_model(args, model_dtype=args.model_dtype)

            # Reload the checkpoint when the requested weights differ from the loaded ones.
            if use_ema_flag:
                if not is_ema_model:
                    load_vqgan_model(args, model_dtype=args.model_dtype, use_ema=True)
                    logger.info("Switched to EMA checkpoint")
            else:
                if is_ema_model:
                    load_vqgan_model(args, model_dtype=args.model_dtype, use_ema=False)
                    logger.info("Switched to non-EMA checkpoint")

            if use_diffusion_flag:
                # The diffusion decoder is loaded on first use only.
                if diffusion_pipeline is None:
                    load_diffusion_decoder(args)
                recon_image, resolution_str = vqgan_diffusion_decoder_reconstruct(
                    image, diffusion_upsample, cfg_value, num_steps)
            else:
                recon_image, resolution_str = vqgan_reconstruct(image)

            # The preview is letterboxed to 384x384; resolution_str reports the real size.
            return pad_to_square(recon_image, target_size=384), resolution_str

        reconstruct_btn.click(
            reconstruct_fn,
            inputs=[input_image, use_ema_checkbox, short_resolution_dropdown, force_upscale_checkbox,
                    use_diffusion_checkbox, diffusion_upsample_checkbox, cfg_slider, step_slider],
            outputs=[output_image_recon, output_resolution_display])

    demo.launch(server_name='0.0.0.0')
325
+
326
+
327
# Main function
def main():
    """Parse the command line, merge the options into the config, and start the demo."""
    parser = argparse.ArgumentParser()
    parser.add_argument("config", type=str)
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--vq-ckpt", type=str, help="ckpt path for vq model")
    parser.add_argument("--torch-dtype", type=str, default='fp32')
    parser.add_argument("--model-dtype", type=str, default='fp32')
    parser.add_argument("--sdxl-decoder-path", type=str, default=None)
    parser.add_argument("--verbose", action='store_true')

    cli_args = parser.parse_args()

    config = read_config(cli_args.config)
    # Copy the command-line overrides onto the loaded config, in the same order as before.
    for option in ("vq_ckpt", "torch_dtype", "model_dtype", "verbose", "sdxl_decoder_path"):
        setattr(config, option, getattr(cli_args, option))

    build_gradio_interface(config)
348
 
349
 
350
# Start the demo only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    main()