huangrh9 committed on
Commit 20e77aa · verified · 1 Parent(s): 161ddb0

Update app.py

Files changed (1)
  1. app.py +900 -284
app.py CHANGED
@@ -1,351 +1,967 @@
1
  import argparse
2
- import datetime
3
- import json
4
  import os
5
- import time
6
  import torch
7
- import gradio as gr
8
  from PIL import Image
9
- from tokenizer.sdxl_decoder_pipe import StableDiffusionXLDecoderPipeline
10
- from torchvision import transforms
11
- import logging
12
- from utils.registry_utils import Config
13
- from tokenizer.builder import build_vq_model
14
- from dataset.multi_ratio_dataset import get_image_size, assign_ratio
15
-
16
-
17
- def read_config(file):
18
- # solve config loading conflict when multi-processes
19
- import time
20
- while True:
21
- config = Config.fromfile(file)
22
- if len(config) == 0:
23
- time.sleep(0.1)
24
- continue
25
- break
26
- return config
27
-
28
-
29
- def build_logger(name, log_file):
30
- logger = logging.getLogger(name)
31
- logger.setLevel(logging.INFO)
32
- handler = logging.FileHandler(log_file)
33
- formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
34
- handler.setFormatter(formatter)
35
- logger.addHandler(handler)
36
- return logger
37
-
38
-
39
- logger = build_logger("gradio_web_server", "gradio_web_server.log")
40
-
41
- vq_model = None
42
- is_ema_model = False
43
- diffusion_pipeline = None
44
- lazy_load = False
45
-
46
- # diffusion decoder hyperparameters.
47
- resolution_list = [
48
- (1024, 1024), (768, 1024), (1024, 768),
49
- (512, 2048), (2048, 512), (640, 1920),
50
- (1920, 640), (768, 1536),
51
- (1536, 768), (768, 1152), (1152, 768)
52
- ]
53
 
54
- cfg_range = (1, 10.0)
55
- step_range = (1, 100)
 
 
 
 
56
 
 
 
57
 
58
- def resize_to_shortest_edge(img, shortest_edge_resolution):
59
- width, height = img.size
 
 
60
 
61
- if width < height:
62
- new_width = shortest_edge_resolution
63
- new_height = int(height * (new_width / width))
64
- elif height < width:
65
- new_height = shortest_edge_resolution
66
- new_width = int(width * (new_height / height))
67
  else:
68
- new_width = shortest_edge_resolution
69
- new_height = shortest_edge_resolution
 
 
 
70
 
71
- resized_img = img.resize((new_width, new_height))
72
- return resized_img
73
 
74
 
75
- from PIL import Image
 
76
 
 
 
 
77
 
78
- def resize_to_square_with_long_edge(image: Image.Image, size: int = 512):
79
- """Resize image so that its *long* side equals `size`, short side scaled proportionally."""
80
- width, height = image.size
81
- if width > height:
82
- new_width = size
83
- new_height = int(size * height / width)
 
84
  else:
85
- new_height = size
86
- new_width = int(size * width / height)
87
- return image.resize((new_width, new_height), Image.LANCZOS)
 
 
 
88
 
89
 
90
- def pad_to_square(image: Image.Image, target_size: int = 512, color=(255, 255, 255)):
91
- image = resize_to_square_with_long_edge(image, target_size)
92
- new_img = Image.new("RGB", (target_size, target_size), color)
93
- offset_x = (target_size - image.width) // 2
94
- offset_y = (target_size - image.height) // 2
95
- new_img.paste(image, (offset_x, offset_y))
96
- return new_img
97
 
98
 
99
- def load_vqgan_model(args, model_dtype='fp16', use_ema=False, ):
100
- global vq_model
101
- vq_model = build_vq_model(args.vq_model)
102
 
103
- if model_dtype == 'fp16':
104
- vq_model = vq_model.to(torch.float16)
105
- logger.info("Convert the model dtype to float16")
106
- elif model_dtype == 'bf16':
107
- vq_model = vq_model.to(torch.bfloat16)
108
- logger.info("Convert the model dtype to bfloat16")
109
 
110
- vq_model.to('cuda')
111
- vq_model.eval()
112
- checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
113
 
114
- if "ema" in checkpoint:
115
- ema_state_dict = checkpoint["ema"]
116
- else:
117
- ema_state_dict = None
118
 
119
- if "model" in checkpoint:
120
- model_state_dict = checkpoint["model"]
121
- elif "state_dict" in checkpoint:
122
- model_state_dict = checkpoint["state_dict"]
123
- else:
124
- model_state_dict = checkpoint
125
 
126
- if use_ema:
127
- vq_model.load_state_dict(ema_state_dict, strict=True)
128
- else:
129
- vq_model.load_state_dict(model_state_dict, strict=True)
130
- return vq_model
131
 
 
 
 
132
 
133
- def load_diffusion_decoder(args):
134
- global diffusion_pipeline
135
- diffusion_pipeline = StableDiffusionXLDecoderPipeline.from_pretrained(
136
- args.sdxl_decoder_path,
137
- add_watermarker=False,
138
- vq_config=args,
139
- vq_model=vq_model,
140
- )
141
- diffusion_pipeline.to(vq_model.device)
142
 
 
 
143
 
144
- def vqgan_diffusion_decoder_reconstruct(input_image, diffusion_upsample, cfg_values, steps):
145
- transform = transforms.Compose([
146
- transforms.ToTensor(),
147
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
148
- ])
149
- input_tensor = transform(input_image).unsqueeze(0).to(vq_model.device)
150
 
151
- org_width, org_height = input_image.size
152
- if diffusion_upsample:
153
- width, height = org_width * 2, org_height * 2
154
- else:
155
- width, height = org_width, org_height
156
-
157
- print(diffusion_upsample, org_width, org_height, width, height)
158
- group_index = assign_ratio(height, width, resolution_list)
159
- select_h, select_w = resolution_list[group_index]
160
-
161
- diffusion_outputs = diffusion_pipeline(
162
- images=input_tensor,
163
- height=select_h,
164
- width=select_w,
165
- guidance_scale=cfg_values,
166
- num_inference_steps=steps
167
  )
168
- sample = diffusion_outputs.images[0]
169
- sample.resize((width, height))
170
- return sample, f"**Output Resolution**: {width}x{height}"
171
 
 
 
172
 
173
- @torch.no_grad()
174
- def vqgan_reconstruct(input_image):
175
- transform = transforms.Compose([
176
- transforms.ToTensor(),
177
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
178
- ])
179
 
180
- org_width, org_height = input_image.size
181
 
182
- width = org_width // 16 * 16
183
- height = org_height // 16 * 16
 
 
184
 
185
- input_image = input_image.resize((width, height))
186
- input_tensor = transform(input_image).unsqueeze(0).to(vq_model.device)
187
 
188
- with torch.no_grad():
189
- inputs = vq_model.get_input(dict(image=input_tensor))
190
- (quant_semantic, _, _, _), \
191
- (quant_detail, _, _) = vq_model.encode(**inputs)
192
- reconstructed_image = vq_model.decode(quant_semantic, quant_detail)
193
 
194
- reconstructed_image = torch.clamp(127.5 * reconstructed_image + 128.0, 0, 255)
195
- reconstructed_image = reconstructed_image.squeeze(0).permute(1, 2, 0).cpu().numpy().astype('uint8')
196
 
197
- output_image = Image.fromarray(reconstructed_image)
198
- output_image.resize((org_width, org_height))
199
- return output_image, f"**Output Resolution**: {org_width}x{org_height}"
200
 
 
201
 
202
- title_markdown = '''# DualViTok Demo
203
- The DualViTok is a dual-branch vision tokenizer designed to capture both deep semantics and fine-grained textures. Implementation details can be found in ILLUME+[[ArXiv](https://arxiv.org/abs/2504.01934)].
204
- '''
205
 
206
- usage_markdown = """
207
- <details>
208
- <summary><strong>Usage Instructions (click to expand)</strong></summary>
209
 
210
- 1. Upload an image and click the <strong>Reconstruct</strong> button.
211
- 2. Set <code>Max Shortest Side</code> to limit the image resolution.
212
- 3. Check <code>Force Upscale to Max Shortest Side</code> to force resizing the shortest side of the image to the <code>Max Shortest Side</code>.
213
- 4. <em>(Optional)</em> Check <code>Use EMA model</code> to use the EMA checkpoint for reconstruction.
214
- 5. <em>(Optional)</em> Click <code>Load Diffusion Decoder</code> to enable Diffusion Model decoding.
215
- You can also enable <code>2x Upsample</code> to apply super-resolution to the uploaded image.
 
217
- </details>
218
  """
219
 
220
 
221
- def build_gradio_interface(args):
222
- if not lazy_load:
223
- load_vqgan_model(args, model_dtype=args.model_dtype)
224
 
225
- with gr.Blocks() as demo:
226
- gr.Markdown(title_markdown)
227
- gr.Markdown(usage_markdown)
228
 
229
  with gr.Row():
230
- with gr.Column():
231
- gr.Markdown("## Input Image")
232
- input_image = gr.Image(type="pil", label="Upload Image", width=384, height=384)
233
- input_resolution_display = gr.Markdown("")
234
- gr.Examples(
235
- examples=[
236
- ["../configs/data_configs/test_data_examples/ImageUnderstandingExample/images/1.png",],
237
- ["../configs/data_configs/test_data_examples/ImageUnderstandingExample/images/2.png",],
238
- ["../configs/data_configs/test_data_examples/ImageUnderstandingExample/images/3.png",],
239
- ],
240
- inputs=input_image,
241
- label="Example Images",
242
- )
243
 
244
- with gr.Column():
245
- gr.Markdown("## Reconstructed Image")
246
- output_image_recon = gr.Image(type="pil", label="Reconstruction", width=384, height=384)
247
- output_resolution_display = gr.Markdown("")
248
-
249
- with gr.Column():
250
- gr.Markdown("## ⚙ Hyperparameters")
251
- # with gr.Row():
252
- short_resolution_dropdown = gr.Dropdown(
253
- choices=[None, 256, 384, 512, 1024],
254
- value=1024,
255
- label="Max Shortest Side"
256
- )
257
- force_upscale_checkbox = gr.Checkbox(label="Force Upscale to Max Shortest Side", value=False)
258
- use_ema_checkbox = gr.Checkbox(label="Use EMA Model", value=False)
259
-
260
- with gr.Accordion("Use Diffusion Decoder", open=False):
261
- use_diffusion_checkbox = gr.Checkbox(label="Load Diffusion Decoder", value=False)
262
- diffusion_upsample_checkbox = gr.Checkbox(label="Enable 2x Upsample", value=False)
263
- cfg_slider = gr.Slider(
264
- minimum=cfg_range[0], maximum=cfg_range[1],
265
- step=0.5, value=1.5,
266
- label="CFG Value"
267
  )
268
- step_slider = gr.Slider(
269
- minimum=step_range[0], maximum=step_range[1],
270
- step=1, value=20,
271
- label="Inference Steps"
272
  )
273
- reconstruct_btn = gr.Button("Reconstruct", variant="primary")
274
-
275
- def handle_input_image(image):
276
- if image is not None:
277
- image = image.convert("RGB")
278
- w, h = image.size
279
- return image, f"**Input Resolution**: {w}x{h}"
280
- return None, ""
281
-
282
- input_image.change(
283
- handle_input_image,
284
- inputs=input_image,
285
- outputs=[input_image, input_resolution_display]
286
- )
287
 
288
- def reconstruct_fn(image, use_ema_flag, short_edge_resolution, force_upscale,
289
- use_diffusion_flag, diffusion_upsample, cfg_value, num_steps):
 
 
 
290
 
291
- if short_edge_resolution is not None:
292
- if force_upscale or min(image.size) > short_edge_resolution:
293
- image = resize_to_shortest_edge(image, int(short_edge_resolution))
294
 
295
- global vq_model
296
- if lazy_load and vq_model is None:
297
- load_vqgan_model(args, model_dtype=args.model_dtype)
 
 
 
298
 
299
- if use_ema_flag:
300
- if not is_ema_model:
301
- load_vqgan_model(args, model_dtype=args.model_dtype, use_ema=True)
302
- logger.info("Switched to EMA checkpoint")
303
- else:
304
- if is_ema_model:
305
- load_vqgan_model(args, model_dtype=args.model_dtype, use_ema=False)
306
- logger.info("Switched to non-EMA checkpoint")
307
-
308
- if use_diffusion_flag:
309
- if diffusion_pipeline is None:
310
- load_diffusion_decoder(args)
311
- recon_image, resolution_str = vqgan_diffusion_decoder_reconstruct(image, diffusion_upsample, cfg_value,
312
- num_steps)
313
- else:
314
- recon_image, resolution_str = vqgan_reconstruct(image)
315
 
316
- return pad_to_square(recon_image, target_size=384), resolution_str
317
 
318
- reconstruct_btn.click(
319
- reconstruct_fn,
320
- inputs=[input_image, use_ema_checkbox, short_resolution_dropdown, force_upscale_checkbox,
321
- use_diffusion_checkbox, diffusion_upsample_checkbox, cfg_slider, step_slider],
322
- outputs=[output_image_recon, output_resolution_display])
323
 
324
- demo.launch(server_name='0.0.0.0')
325
 
326
 
327
- # Main function
328
- def main():
329
  parser = argparse.ArgumentParser()
330
- parser.add_argument("config", type=str)
331
- parser.add_argument("--local_rank", type=int, default=0)
332
- parser.add_argument("--vq-ckpt", type=str, help="ckpt path for vq model")
333
- parser.add_argument("--torch-dtype", type=str, default='fp32')
334
- parser.add_argument("--model-dtype", type=str, default='fp32')
335
- parser.add_argument("--sdxl-decoder-path", type=str, default=None)
336
- parser.add_argument("--verbose", action='store_true')
337
 
338
- args = parser.parse_args()
 
339
 
340
- config = read_config(args.config)
341
- config.vq_ckpt = args.vq_ckpt
342
- config.torch_dtype = args.torch_dtype
343
- config.model_dtype = args.model_dtype
344
- config.verbose = args.verbose
345
- config.sdxl_decoder_path = args.sdxl_decoder_path
346
 
347
- build_gradio_interface(config)
 
 
 
348
 
 
349
 
350
- if __name__ == "__main__":
351
- main()
1
  import argparse
 
 
2
  import os
3
+ import traceback
4
+ import logging
5
+ from functools import partial
6
+ from threading import Thread
7
+
8
+ import re # Added for parsing image tokens
9
+
10
  import torch
11
+
12
+ from transformers import TextIteratorStreamer
13
+
14
+ from transformers import AutoModel, AutoProcessor
15
  from PIL import Image
16
 
17
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
18
+ datefmt='%Y-%m-%d %H:%M:%S')
19
+ logging.getLogger("http").setLevel(logging.WARNING)
20
+ logging.getLogger("httpx").setLevel(logging.WARNING)
21
+
22
+ import gradio as gr
23
 
24
+ from illume.conversation import default_conversation, conv_templates, SeparatorStyle
25
+ # from conversation import default_conversation, conv_templates, SeparatorStyle
26
 
27
+ # --- Global Variables and Model Loading ---
28
+ model = None # Global variable to hold the loaded ILLUME model
29
+ args = None # Global variable to hold command line args
30
+ streamer = None # Global variable to hold the token streamer
31
 
32
+ DEFAULT_IMAGE_TOKEN = '<image>'
33
+
34
+ # Define common resolutions
35
+ DEFAULT_RESOLUTIONS = [
36
+ (256, 256), (512, 512), (384, 640), (640, 384), (512, 384),
37
+ (384, 512), (256, 384), (384, 256), (256, 512), (512, 256)
38
+ ]
39
+
40
+ DEFAULT_DIFFUSION_RESOLUTIONS = [
41
+ (512, 512), (1024, 1024), (768, 1280), (1280, 768), (1024, 768),
42
+ (768, 1024), (512, 768), (768, 512), (512, 1024), (1024, 512)
43
+ ]
44
+
45
+ conv_templates_version = 'qwen2'
46
+
47
+
48
+ # inputs = processor(**inputs, return_tensors="pt")
49
+ # inputs = inputs.to(model.device)
50
+
51
+ # # prepare generation arguments
52
+ # gen_kwargs = dict(
53
+ # max_new_tokens=2048, do_sample=True
54
+ # )
55
+
56
+ # image_gen_kwargs = dict(
57
+ # negative_image_prompt_ids=uncond_inputs.input_ids,
58
+ # target_image_resolution=target_image_resolution,
59
+ # guidance_scale=2.0,
60
+ # image_semantic_temperature=1.0,
61
+ # image_semantic_top_k=2048,
62
+ # image_semantic_top_p=1.0,
63
+ # image_pixel_temperature=1.0,
64
+ # image_pixel_top_k=2048 * 3,
65
+ # image_pixel_top_p=1.0,
66
+ # )
67
+
68
+ # gen_kwargs = dict(
69
+ # max_new_tokens=2048, do_sample=False
70
+ # )
71
+
72
+ # # run generation
73
+ # with torch.no_grad():
74
+ # outputs = model.generate(**inputs, **gen_kwargs)
75
+ # outputs = outputs[:, inputs['input_ids'].shape[1]:]
76
+ # outputs_text = processor.batch_decode(outputs, skip_special_tokens=True)
77
+
78
+ # # It extracts the image tokens of each image and replaces them with the `image_placeholder` in order.
79
+ # generated_text, image_embed_inds_list, list_image_token_parts = processor.parse_text_image(outputs_text[0],
80
+ # image_placeholder='<image_out>')
81
+
82
+ # # batch decoding the image by using the DualViTok.
83
+ # vq_decoded_images = processor.decode_images(image_embed_inds_list, target_resolution=target_image_resolution)
84
+
85
+ # # batch decoding the image by using the sdxl diffusion decoder.
86
+ # # The output image resolution would be [target_image_resolution[0] * 2, target_image_resolution[1] * 2]
87
+ # diffusion_decoded_images = processor.decode_images(image_embed_inds_list, target_resolution=target_image_resolution,
88
+ # use_diffusion=True, diffusion_cfg_scale=2.0,
89
+ # diffusion_num_inference_steps=20)
90
+
91
+ # vq_decoded_images[0].save('vq_decoded_cat.png')
92
+ # diffusion_decoded_images[0].save('diffusion_decoded_cat.png')
93
+
94
+
95
+ # Adapted from your code
96
+ def check_image_token_num(image_embed_inds, token_nums=[81, 256], identifier=""):
97
+ image_embed_inds_out = []
98
+ if len(image_embed_inds) != len(token_nums):
99
+ logging.error(
100
+ f"{identifier} Mismatch between number of image token levels ({len(image_embed_inds)}) and expected token_nums ({len(token_nums)})")
101
+ # Handle error appropriately - maybe return None or raise exception
102
+ return None # Indicate error
103
+
104
+ for level, (embed_inds, token_num) in enumerate(zip(image_embed_inds, token_nums)):
105
+ if not len(embed_inds) == token_num:
106
+ logging.warning(
107
+ f"{identifier} Level {level} embed_inds length {len(embed_inds)} not equal to expected {token_num}! Padding/truncating.")
108
+ if len(embed_inds) > token_num:
109
+ embed_inds = embed_inds[:token_num]
110
+ elif len(embed_inds) == 0:
111
+ # Handle empty case - perhaps fill with a default token?
112
+ logging.warning(f"{identifier} Level {level} embed_inds is empty. Filling with zeros.")
113
+ embed_inds = [0] * token_num # Or a placeholder token ID
114
+ else:
115
+ # Pad with the last token ID
116
+ embed_inds.extend([embed_inds[-1]] * (token_num - len(embed_inds)))
117
+ image_embed_inds_out.append(embed_inds)
118
+ return image_embed_inds_out
119
+
120
+
121
+ # Adapted from your code
122
+ def pad_sequence(tokenizer, input_ids, batch_first, padding_value):
123
+ # Assuming input_ids is a list of Tensors
124
+ if tokenizer.padding_side == "left":
125
+ input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids]
126
+ # Manually pad if needed, or use torch utils if input_ids are tensors
127
+ # This assumes input_ids are already tensors
128
+ input_ids_padded = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value)
129
+ if tokenizer.padding_side == "left":
130
+ input_ids_padded = torch.flip(input_ids_padded, [1])
131
+ return input_ids_padded
132
+
133
+
134
+ # --- Gradio UI Functions ---
135
+ no_change_btn = gr.Button()
136
+ enable_btn = gr.Button(interactive=True)
137
+ disable_btn = gr.Button(interactive=False)
138
+ server_error_msg = "**NETWORK ERROR OR SERVER ISSUE. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
139
+ server_oom_msg = "**OUT OF GPU MEMORY DETECTED. PLEASE DECREASE THE MAX OUTPUT TOKENS OR IMAGE RESOLUTION AND REGENERATE.**"
140
+
141
+
142
+ def load_demo_refresh_model_list():
143
+ logging.info("load_demo.")
144
+ # Use the conversation template from the loaded model/config
145
+ # Ensure model is loaded before this runs
146
+ if conv_templates_version in conv_templates:
147
+ state = conv_templates[conv_templates_version].copy()
148
+ logging.info(f"Using conversation template: {conv_templates_version}")
149
  else:
150
+ logging.warning(f"Conversation template '{conv_templates_version}' not found. Using default.")
151
+ # Find a default template name from conv_templates or define one
152
+ default_template_name = next(iter(conv_templates)) # Get the first available template
153
+ state = conv_templates[default_template_name].copy()
154
+ return state
155
 
 
 
156
 
157
+ def regenerate(state):
158
+ logging.info("regenerate.")
159
+ if not state.messages or len(state.messages) < 2:
160
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 2 # Use state's image
161
 
162
+ # Clear the last assistant message
163
+ state.messages[-1][-1] = None
164
 
165
+ state.skip_next = False
166
+ # Return state, updated chatbot display, refill textbox, keep image
167
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 2
168
 
169
+
170
+ def http_bot_conditional_then(state, temperature, top_k, top_p,
171
+ image_gen_temperature, image_gen_top_k, image_gen_top_p, max_output_tokens,
172
+ llm_cfg_scale, resolution_wh, use_diffusion, diffusion_cfg_scale,
173
+ diffusion_num_inference_steps):
174
+ if state.mode == 'chat':
175
+ result = yield from http_chat_bot(state, temperature, top_k, top_p, max_output_tokens)
176
  else:
177
+ # result = yield from http_gen_edit_bot(state, temperature, top_k, top_p, max_output_tokens,
178
+ result = yield from http_gen_edit_bot(
179
+ state, temperature, top_k, top_p, image_gen_temperature, image_gen_top_k, image_gen_top_p,
180
+ max_output_tokens,
181
+ llm_cfg_scale, resolution_wh, use_diffusion, diffusion_cfg_scale, diffusion_num_inference_steps)
182
+ return result
183
 
184
 
185
+ def clear_history():
186
+ logging.info("clear_history.")
187
+ state = load_demo_refresh_model_list()
188
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 2
 
 
 
189
 
190
 
191
+ def add_text(state, text, image, mode):
192
+ global model # Ensure we use the loaded model
 
193
 
194
+ logging.info(f"add_text. Text len: {len(text)}, Image provided: {image is not None}")
195
+ if len(text.strip()) == 0 and image is None:
196
+ state.skip_next = True
197
+ # Keep image in the imagebox if only image was present
198
+ return (state, state.to_gradio_chatbot(), "", image) + (no_change_btn,) * 2
 
199
 
200
+ if state.messages and state.messages[-1][1] and \
201
+ isinstance(state.messages[-1][1], str) and state.messages[-1][1].startswith("**"):
202
+ state = load_demo_refresh_model_list() # Start fresh after error
203
 
204
+ if mode == 'image-generation':
205
+ state = load_demo_refresh_model_list()
 
 
206
 
207
+ image_process_mode = "Default"
208
 
209
+ if image is not None:
210
+ if state.get_images():
211
+ state = load_demo_refresh_model_list()
 
 
212
 
213
+ if '<image>' not in text:
214
+ text = f'<image>\n{text}'
215
+ text = (text, image, image_process_mode)
216
 
217
+ # Append user message
218
+ state.append_message(state.roles[0], text)
219
+ state.append_message(state.roles[1], None) # Placeholder for assistant
220
+ state.skip_next = False
221
+ state.mode = mode
222
+ logging.info(f"Updated state messages: {len(state.messages)}")
 
 
 
223
 
224
+ # Return new state, updated chatbot, clear textbox, clear imagebox
225
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 2
226
 
227
 
228
+ def stream_response(model, inputs, streamer, prompt, gen_kwargs):
229
+ thread = Thread(target=model.generate, kwargs=dict(
230
+ streamer=streamer,
231
+ **inputs,
232
+ **gen_kwargs
233
+ ))
234
+ thread.start()
235
+
236
+ generated_text = prompt
237
+
238
+ for new_text in streamer:
239
+ generated_text += new_text
240
+ yield generated_text
241
+
242
+
243
+ # @spaces.GPU
244
+ def http_chat_bot(state, temperature, top_k, top_p, max_new_tokens):
245
+ global model, args, streamer # Use global model and args
246
+ logging.info("http_chat_bot.")
247
+
248
+ if state.skip_next:
249
+ logging.warning("Skipping bot generation. skip_next or model not ready.")
250
+ yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 2
251
+ return
252
+
253
+ if len(state.messages) < 2:
254
+ yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 2
255
+ return
256
+
257
+ # --- Prepare Inputs for ILLUME ---
258
+ # Get the full prompt from the conversation state
259
+ prompt = state.get_prompt()
260
+ all_images = state.get_images(return_pil=True)
261
+
262
+ logging.info(f"Raw Prompt: {prompt}")
263
+
264
+ inputs = dict(
265
+ text=prompt,
266
  )
267
+ # Tokenize the prompt
268
+ # run processors
269
+ inputs = processor(**inputs, return_tensors="pt")
270
+ inputs = inputs.to(model.device)
271
+
272
+ # avoid resolution mismatch: process the images separately
273
+ if len(all_images):
274
+ images = []
275
+ for image in all_images:
276
+ images.append(processor.image_processor(image, return_tensors="pt")['pixel_values'].to(model.device))
277
+ pixel_values = images
278
+ inputs['pixel_values'] = pixel_values
279
+
280
+ logging.info(f"Input IDs shape: {inputs.input_ids.shape}")
281
+
282
+ # Set initial response placeholder
283
+ state.messages[-1][-1] = "▌"
284
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 2
285
+
286
+ # --- MLLM Generation ---
287
+ gen_kwargs = dict(
288
+ pad_token_id=processor.tokenizer.pad_token_id,
289
+ do_sample=True if temperature > 0 else False, # Controlled by dynamic sampler now, but keep flag
290
+ temperature=temperature,
291
+ top_k=top_k,
292
+ top_p=top_p,
293
+ max_new_tokens=max_new_tokens,
294
+ use_cache=True,
295
+ eos_token_id=processor.tokenizer.eos_token_id # Ensure EOS token is set
296
+ )
297
+ logging.info(f"==== request kwargs====\n{gen_kwargs}")
298
+
299
+ if max_new_tokens < 1:
300
+ state.messages[-1][-1] = "Exceeds max token length. Please start a new conversation, thanks."
301
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 2
302
+ return
303
+
304
+ state.messages[-1][-1] = "▌"
305
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 2
306
+
307
+ # Stream output
308
+ try:
309
+ for generated_text in stream_response(model, inputs, streamer, prompt, gen_kwargs):
310
+ output = generated_text[len(prompt):].strip()
311
+ state.messages[-1][-1] = output
312
+ yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 2
313
+ except Exception as e:
314
+ os.system("nvidia-smi")
315
+ logging.info(traceback.format_exc())
316
+ state.messages[-1][-1] = server_error_msg
317
+ yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 2
318
+ return (state, state.to_gradio_chatbot()) + (enable_btn,) * 2
319
+
320
+
321
+ def http_gen_edit_bot(state, temperature, top_k, top_p, image_gen_temperature,
322
+ image_gen_top_k, image_gen_top_p, max_output_tokens,
323
+ llm_cfg_scale, resolution_wh, use_diffusion, diffusion_cfg_scale, diffusion_num_inference_steps):
324
+ global model, args # Use global model and args
325
+ logging.info("http_gen_edit_bot.")
326
+
327
+ if state.skip_next:
328
+ logging.warning("Skipping bot generation. skip_next or model not ready.")
329
+ yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 2
330
+ return
331
+
332
+ if len(state.messages) < 2:
333
+ yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 2
334
+ return
335
+
336
+ # --- Prepare Inputs for ILLUME ---
337
+ # Get the full prompt from the conversation state
338
+ all_images = state.get_images(return_pil=True)
339
+
340
+ # read the output resolution selected by the user.
341
+ h_str, w_str = resolution_wh.split('x')
342
+ h_out, w_out = int(h_str), int(w_str)
343
+
344
+ if use_diffusion:
345
+ h_out, w_out = (h_out // 2, w_out // 2)
346
+ else:
347
+ h_out, w_out = (h_out, w_out)
348
+ ratio_tag = f"<height_{h_out}><width_{w_out}>"
349
+
350
+ input_state = state.copy()
351
+
352
+ # prepare the text.
353
+ original_image_sizes = None
354
+ if len(all_images):
355
+ # image editing.
356
+ original_image_sizes = [image.size for image in all_images]
357
+ logging.info(f"original_image_sizes: {original_image_sizes}")
358
+
359
+ all_images = [processor.transform_image_nearest_resolution_ratio(image) for image in all_images]
360
+
361
+ inputs = dict(
362
+ images=all_images
363
+ )
364
 
365
+ image_inputs = processor.image_processor(**inputs, return_tensors="pt")
366
+ image_inputs = image_inputs.to(model.device)
367
 
368
+ # overwrite the output resolution
369
+ h, w = image_inputs['pixel_values'].shape[-2:]
370
+ ratio_tag = f"<height_{h}><width_{w}>"
371
+ h_out, w_out = h, w
 
 
372
 
373
+ unconditional_text = f"{ratio_tag}{DEFAULT_IMAGE_TOKEN}\nReconstruct the image according to the given image\n" # of {ratio_tag}
374
 
375
+ instruction, img, image_process_type = input_state.messages[-2][-1]
376
+ instruction = instruction.replace(DEFAULT_IMAGE_TOKEN, '').strip()
377
+ text = f"{ratio_tag}{DEFAULT_IMAGE_TOKEN}\nPlease edit the image according to the instruction: {instruction}\n"
378
+ input_state.messages[-2][-1] = text, img, image_process_type
379
 
380
+ else:
381
+ # image generation
382
+ unconditional_text = f"Generate a random image of {ratio_tag}"
383
+
384
+ text = input_state.messages[-2][-1]
385
+ logging.info(f"Current text is {text}")
386
+ text = f"Generate an image of {ratio_tag}, the content of image is {text}"
387
+ input_state.messages[-2][-1] = text
388
+ logging.info(f"After formatting, current text is {text}")
389
+ image_inputs = {}
390
+
391
+ # Calculate ratio tag based on base resolution from config
392
+ logging.info(f"Target Resolution: {h_out}x{w_out}, Ratio Tag: {ratio_tag}")
393
+ target_image_resolution = (h_out, w_out)
394
+ prompt = input_state.get_prompt()
395
+ logging.info(f"Raw Prompt: {prompt}")
396
+
397
+ # Tokenize the prompt
398
+ inputs = dict(
399
+ text=prompt + ratio_tag,
400
+ )
401
 
402
+ inputs = processor(**inputs, return_tensors="pt")
403
+ inputs = inputs.to(model.device)
404
+ inputs.update(image_inputs)
 
 
405
 
406
+ conv_uncond = conv_templates[conv_templates_version].copy()
407
+ conv_uncond.append_message(conv_uncond.roles[0], unconditional_text)
408
+ conv_uncond.append_message(conv_uncond.roles[1], None)
409
+ unconditional_prompt_str = conv_uncond.get_prompt() # Add ratio tag
410
+
411
+ uncond_inputs = dict(
412
+ text=unconditional_prompt_str + ratio_tag,
413
+ images=all_images
414
+ )
415
 
416
+ uncond_inputs = processor(**uncond_inputs, return_tensors="pt")
417
+ uncond_inputs = uncond_inputs.to(model.device)
 
418
 
419
+ logging.info(f"Input IDs shape: {inputs.input_ids.shape}")
420
 
421
+ # Set initial response placeholder
422
+ state.messages[-1][-1] = "image generating..."
423
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 2
424
 
425
+ gen_kwargs = dict(
426
+ max_new_tokens=2048,
427
+ do_sample=True if temperature > 0 else False,
428
+ temperature=temperature,
429
+ top_k=top_k,
430
+ top_p=top_p,
431
+ )
432
 
433
+ image_gen_kwargs = dict(
434
+ negative_image_prompt_ids=uncond_inputs.input_ids,
435
+ negative_image_prompt_attention_mask=uncond_inputs.attention_mask,
436
+ target_image_resolution=target_image_resolution,
437
+ guidance_scale=llm_cfg_scale,
438
+ image_semantic_temperature=image_gen_temperature,
439
+ image_semantic_top_k=image_gen_top_k,
440
+ image_semantic_top_p=image_gen_top_p,
441
+ image_pixel_temperature=image_gen_temperature,
442
+ image_pixel_top_k=image_gen_top_k * 3,
443
+ image_pixel_top_p=image_gen_top_p,
444
+ )
445
 
446
+ # --- MLLM Generation ---
447
+ generated_image = None
448
+ generated_text = ""
449
+ try:
450
+ with torch.inference_mode(): # Ensure no gradients are calculated
451
+ output_ids = model.generate(
452
+ **inputs,
453
+ use_cache=True,
454
+ **gen_kwargs,
455
+ **image_gen_kwargs
456
+ )
457
+
458
+ output_ids = output_ids[:, inputs['input_ids'].shape[1]:]
459
+
460
+ logging.info(f"Generated output IDs shape: {output_ids.shape}")
461
+
462
+ # Decode the generated IDs, skipping prompt and special tokens
463
+ # We need to decode the full output first to parse image tokens
464
+ # output_ids shape is likely (batch_size, seq_len), batch_size=1 here
465
+ generated_ids = output_ids[0] # Get only generated tokens
466
+ full_output_text = processor.tokenizer.decode(generated_ids, skip_special_tokens=True)
467
+ logging.info(f"Full decoded output: {full_output_text}")
468
+
469
+ # --- Parse Output for Image Tokens and Text ---
470
+ # Ensure levels are sorted and create the final list
471
+ generated_text, image_embed_inds_list, list_image_token_parts = processor.parse_text_image(full_output_text,
472
+ DEFAULT_IMAGE_TOKEN)
473
+
474
+ assert len(image_embed_inds_list) == 1, 'The number of generated images should be 1.'
475
+ image_embed_inds = image_embed_inds_list[0]
476
+ logging.info(f"The generated text: {full_output_text}")
477
+ logging.info(f"Parsed generated text (images are represented as {DEFAULT_IMAGE_TOKEN}): {generated_text}")
478
+
479
+ # Update chat with generated text first
480
+ state.messages[-1][-1] = "vision tokenizer decoding..." # Remove cursor
481
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 2 # Yield text update
482
+
483
+ # --- Image Detokenization ---
484
+ if any(image_embed_inds):
485
+ logging.info("Image tokens found. Attempting detokenization...")
486
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 2
487
+
488
+ samples = processor.decode_images(image_embed_inds_list, target_resolution=target_image_resolution,
489
+ use_diffusion=use_diffusion, diffusion_cfg_scale=diffusion_cfg_scale,
490
+ diffusion_num_inference_steps=diffusion_num_inference_steps)
491
+ generated_image = samples[0]
492
+ if use_diffusion:
493
+ logging.info(
494
+ f"Using Diffusion Decoder (cfg: {diffusion_cfg_scale}, steps: {diffusion_num_inference_steps}) Image size: {generated_image.size}")
495
+ else:
496
+ logging.info(f"Using VQ Tokenizer Decoder. Image size: {generated_image.size}")
497
+
498
+ if generated_image:
499
+ if original_image_sizes is not None and len(
500
+ original_image_sizes) == 1: # editing task, unpad and resize image to original size
501
+ original_size = original_image_sizes[0]
502
+ logging.info(f"original size: {original_size}. Output Image size: {generated_image.size}")
503
+ generated_image = processor.unpad_and_resize_back(generated_image, original_size[0], original_size[1])
504
+ logging.info(f"final image size: {generated_image.size}")
505
+ logging.info("Image successfully generated.")
506
+ # <image> is placeholder.
507
+
508
+ logging.info("Image successfully generated.")
509
+ # <image> is placeholder.
510
+ state.messages[-1][-1] = (DEFAULT_IMAGE_TOKEN, [generated_image], list_image_token_parts)
511
+ else:
512
+ # No image tokens generated
513
+ state.messages[-1][-1] = generated_text # Final text without image
514
+
515
+ # Final yield with potentially updated message (text + image)
516
+ yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 2
517
+
518
+ except torch.cuda.OutOfMemoryError as e:
519
+ logging.error(f"CUDA OutOfMemoryError during generation: {e}\n{traceback.format_exc()}")
520
+ state.messages[-1][-1] = server_oom_msg
521
+ yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 2
522
+ except Exception as e:
523
+ logging.error(f"Error during model generation or detokenization: {e}\n{traceback.format_exc()}")
524
+ state.messages[-1][-1] = f"{server_error_msg}\n```\n{traceback.format_exc()}\n```" # Show traceback in error
525
+ yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 2
526
+
527
+ logging.info(f"Final Assistant Message Length: {len(state.messages[-1][-1])}")
528
+
529
+
530
+ def update_resolution_dropdown(diffusion_enabled, current_resolution_str):
531
+ logging.info(f"Updating resolution dropdown. Diffusion: {diffusion_enabled}, Current: {current_resolution_str}")
532
+ current_h_str, current_w_str = current_resolution_str.split('x')
533
+ current_h, current_w = int(current_h_str), int(current_w_str)
534
+
535
+ new_value_str = None
536
+ if diffusion_enabled:
537
+ new_h, new_w = int(current_h) * 2, int(current_w) * 2
538
+
539
+ if (new_h, new_w) not in DEFAULT_DIFFUSION_RESOLUTIONS:
540
+ new_h, new_w = DEFAULT_DIFFUSION_RESOLUTIONS[0]
541
+ new_value_str = f"{new_h}x{new_w}"
542
+ return gr.Dropdown.update(choices=[f'{h}x{w}' for h, w in DEFAULT_DIFFUSION_RESOLUTIONS],
543
+ value=new_value_str)
544
+ else:
545
+ new_h, new_w = int(current_h) // 2, int(current_w) // 2
546
+
547
+ if (new_h, new_w) not in DEFAULT_RESOLUTIONS:
548
+ new_h, new_w = DEFAULT_RESOLUTIONS[0]
549
+ new_value_str = f"{new_h}x{new_w}"
550
+
551
+ return gr.Dropdown.update(choices=[f'{h}x{w}' for h, w in DEFAULT_RESOLUTIONS],
552
+ value=new_value_str)
553
+
554
+
555
+ # --- Gradio Layout ---
556
+ title_markdown = """
557
+ <div style="display: flex; align-items: center; padding: 20px; border-radius: 10px; background-color: #f0f0f0;">
558
+ <div>
559
+ <h1 style="margin: 0;"> ILLUME+: Illuminating Unified MLLM with Dual Visual Tokenization and Diffusion Refinement</h1>
560
+ <h2 style="margin: 10px 0;">
561
+ Links:
562
+ <a href="https://arxiv.org/abs/2504.01934" target="_blank" rel="noopener noreferrer">Paper</a> |
563
+ <a href="https://github.com/illume-unified-mllm/ILLUME_plus" target="_blank" rel="noopener noreferrer">Code</a> |
564
+ <a href="#" target="_blank" rel="noopener noreferrer">Model</a> |
565
+ <a href="https://illume-unified-mllm.github.io/" target="_blank" rel="noopener noreferrer">Project Page</a>
566
+ </h2>
567
+ <ul style="margin: 20px 0; padding-left: 20px;">
568
+ <li><strong>1.</strong> Enter text and/or upload an image.</li>
569
+ <li><strong>2.</strong> Click the 💬 <strong>Chat</strong> button for conversations with image input.</li>
570
+ <li><strong>3.</strong> Click the 🖼️ <strong>Generate</strong> button for image generation and image editing.</li>
571
+ <li><strong>4.</strong> (Optional) Enable the Diffusion Decoder for super-resolution decoding of generated images.</li>
572
+ <li><strong>5.</strong> Adjust generation parameters if needed.
573
+ <br/><strong>💡 Tip 1:</strong> For better image generation quality, we recommend setting <code>top_k = 2048</code>.
574
+ <br/><strong>💡 Tip 2:</strong> For diffusion decoder, CFG scale of 1.5 or 2.0 is enough.
575
+ </li>
576
+ </ul>
577
+ </div>
578
+ </div>
579
  """
580
 
581
+ tos_markdown = ("""
582
+ ## Terms of use
583
+ By using this service, users are required to agree to the following terms:
584
+ The service is a research preview intended for non-commercial use only. It may generate inaccurate or offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user data for future research. Check the specific license of the ILLUME model and its components.
585
+ """)
586
+
587
+ learn_more_markdown = ("""
588
+ ## Citation
589
+
590
+
591
+ @article{huang2025illume_plus,
592
+ title={ILLUME+: Illuminating Unified MLLM with Dual Visual Tokenization and Diffusion Refinement},
593
+ author={Huang, Runhui and Wang, Chunwei and Yang, Junwei and Lu, Guansong and Yuan, Yunlong and Han, Jianhua and Hou, Lu and Zhang, Wei and Hong, Lanqing and Zhao, Hengshuang and Xu, Hang}
594
+ journal={arXiv preprint arXiv:2504.01934},
595
+ year={2025}
596
+ }
597
+ """)
598
+
599
+ block_css = """
600
+ #buttons button {
601
+ min-width: min(120px,100%);
602
+ }
603
+ .message-row img {
604
+ max-width: 80%;
605
+ max-height: 400px;
606
+ height: auto;
607
+ display: block;
608
+ margin-top: 10px;
609
+ margin-bottom: 5px;
610
+ border-radius: 5px;
611
+ border: 1px solid #e0e0e0; /* Add a light border */
612
+ }
613
+ .avatar-container img {
614
+ padding: 0px !important;
615
+ }
616
+ /* Style for resolution dropdown */
617
+ #resolution_dropdown .gradio-dropdown {
618
+ min-width: 150px !important;
619
+ }
620
+ """
621
 
 
 
 
622
 
623
+ def load_initial_state_and_example1():
624
+ """
625
+ Loads the initial Conversation state and prepares the inputs
626
+ for the first example to populate the UI on startup.
627
+ """
628
+ logging.info("Loading initial state and Example 1 inputs for UI.")
629
+
630
+ # 1. Get the base initial state object
631
+ initial_state = load_demo_refresh_model_list()
632
+ # At this point, initial_state is a Conversation object with empty messages.
633
+ initial_state.mode = 'chat'  # default to chat mode for the first example
634
+
635
+ # 2. Define Example 1 inputs
636
+ image_path = "./examples/example_1.png" # Make sure this path is correct relative to where you run the script
637
+ text_prompt = "Describe this scene in detail."
638
+ image_pil = None
639
+
640
+ # 3. Load the example image
641
+ try:
642
+ # Ensure the image file exists and load it
643
+ if os.path.exists(image_path):
644
+ image_pil = Image.open(image_path)
645
+ logging.info(f"Successfully loaded example image: {image_path}")
646
+ else:
647
+ logging.warning(f"Example image not found at: {image_path}. Image box will be empty.")
648
+ # Optionally provide a placeholder blank image?
649
+ # image_pil = Image.new('RGB', (60, 30), color = 'red') # Example placeholder
650
+ except Exception as e:
651
+ logging.error(f"Error loading example image {image_path}: {e}")
652
+ image_pil = None # Ensure it's None on error
653
+
654
+ # 4. Return values to populate the UI components
655
+ # - state: The initial Conversation object
656
+ # - chatbot: The initial empty chatbot display ([]) derived from the initial state
657
+ # - textbox: The example text prompt
658
+ # - imagebox: The loaded PIL image (or None)
659
+ return initial_state, initial_state.to_gradio_chatbot(), text_prompt, image_pil
660
+
661
+
662
+ def load_initial_state_and_example2():
663
+ """
664
+ Loads the initial Conversation state and prepares the inputs
665
+ for the second example to populate the UI on startup.
666
+ """
667
+ logging.info("Loading initial state and Example 2 inputs for UI.")
668
+
669
+ # 1. Get the base initial state object
670
+ initial_state = load_demo_refresh_model_list()
671
+ # At this point, initial_state is a Conversation object with empty messages.
672
+
673
+ # 2. Define Example 2 inputs
674
+ # text_prompt = "Generate a photorealistic image of an astronaut riding a horse on the moon."
675
+ text_prompt = "Generate an image based on the description: A man with a white beard wearing a deep purple robe with gold crosses and a chain with a cross pendant is seated on a red upholstered chair with a small decorative pillow featuring gold embroidery. He is holding an ornate gold staff."
676
+ text_prompt = "What does a typical scene of a woman enjoying a sunny day by a luxury pool, complete with appropriate attire and refreshment, look like? Please generate the corresponding image."
677
+
678
+ return initial_state, initial_state.to_gradio_chatbot(), text_prompt, None
679
+
680
+
681
+ def build_demo(embed_mode):
682
+ textbox = gr.Textbox(label="Text Input / Prompt", show_label=False,
683
+ placeholder="Enter text prompt. Ask about the image or request image generation...",
684
+ container=False, scale=8)
685
+
686
+ with gr.Blocks(title="ILLUME Demo", theme=gr.themes.Default(), css=block_css) as demo:
687
+ conversation_state = gr.State() # Holds conversation state (instance of illume.conversation.Conversation)
688
+
689
+ if not embed_mode:
690
+ gr.HTML(title_markdown)
691
 
692
  with gr.Row():
693
+ with gr.Column(scale=2):
694
+ imagebox = gr.Image(type="pil", label="Input Image", height=300)
695
+
696
+ # Text Generation Parameters
697
+ with gr.Accordion("Text Generation Parameters", open=True):
698
+ temperature = gr.Slider(
699
+ minimum=0.0, maximum=1.5, value=1.0, step=0.1,
700
+ label="Temperature",
701
+ info="Controls randomness of the output (higher = more diverse)."
702
+ )
703
+ top_k = gr.Slider(
704
+ minimum=1, maximum=4096, value=128, step=1,
705
+ label="Top-K",
706
+ )
707
+ top_p = gr.Slider(
708
+ minimum=0.1, maximum=1.0, value=1.0, step=0.05,
709
+ label="Top-P",
710
+ )
711
+ max_output_tokens = gr.Slider(
712
+ minimum=128, maximum=8192, value=1024, step=128,
713
+ label="Max Output Tokens",
714
+ )
715
 
716
+ # Image Generation Parameters
717
+ with gr.Accordion("Image Generation Parameters", open=True):
718
+ image_gen_temperature = gr.Slider(
719
+ minimum=0.0, maximum=1.5, value=1.0, step=0.1,
720
+ label="Temperature",
721
  )
722
+ image_gen_top_k = gr.Slider(
723
+ minimum=1, maximum=4096 * 2, value=2048, step=32,
724
+ label="Top-K",
725
+ info="Recommended value for better image generation: 2048."
726
+ )
727
+ image_gen_top_p = gr.Slider(
728
+ minimum=0.1, maximum=1.0, value=1.0, step=0.05,
729
+ label="Top-P",
730
+ )
731
+
732
+ resolution_wh_dropdown = gr.Dropdown(
733
+ [f'{h}x{w}' for h, w in DEFAULT_RESOLUTIONS],
734
+ value="512x512",
735
+ label="Output Resolution (HxW)",
736
+ elem_id="resolution_dropdown",
737
+ info="Select target size for generated images."
738
  )
739
 
740
+ llm_cfg_scale = gr.Slider(
741
+ minimum=1.0, maximum=10.0, value=2.0, step=0.1,
742
+ label="LLM CFG Scale",
743
+ info="Guidance for text-to-image conditioning (higher = stricter to prompt)."
744
+ )
745
 
746
+ with gr.Accordion("Diffusion Decoder (Optional)", open=False):
747
+ use_diffusion_checkbox = gr.Checkbox(
748
+ value=False, interactive=True,
749
+ label="Use diffusion decoder for image generation",
750
+ info="Enable diffusion decoder."
751
+ )
752
+ diffusion_cfg_scale = gr.Slider(
753
+ minimum=1.0, maximum=15.0, value=2.0, step=0.1,
754
+ label="Diffusion CFG Scale",
755
+ info="Guidance strength for diffusion decoder."
756
+ )
757
+ diffusion_num_inference_steps = gr.Slider(
758
+ minimum=5, maximum=100, value=20, step=5,
759
+ label="Diffusion Inference Steps",
760
+ info="Number of steps during denoising."
761
+ )
762
+
763
+ with gr.Column(scale=8):
764
+ chatbot = gr.Chatbot(
765
+ elem_id="chatbot",
766
+ label="ILLUME Chat",
767
+ layout="bubble",
768
+ height=650, # Increased height
769
+ bubble_full_width=False,
770
+ render_markdown=True # Crucial for images
771
+ )
772
+ with gr.Row():
773
+ textbox.render()
774
+ with gr.Row(elem_id="buttons") as button_row:
775
+ chat_btn = gr.Button(value="💬 Chat", variant="primary")
776
+ gen_btn = gr.Button(value="🖼️ Generate", variant="secondary")
777
+ with gr.Row(elem_id="additional-buttons") as button_row_additional:
778
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
779
+ clear_btn = gr.Button(value="🗑️ Clear History", interactive=False)
780
+
781
+ # Update examples for ILLUME
782
+ with gr.Accordion("Examples (Click to Load)", open=True):
783
+ with gr.Row():
784
+ gr.Examples(examples=[
785
+ ["examples/ImageUnderstandingExample/images/1.png",
786
+ "Depict the image in detail."],
787
+ ["examples/ImageUnderstandingExample/images/2.png",
788
+ "What are they doing?"],
789
+ ["examples/ImageUnderstandingExample/images/3.png",
790
+ "What objects are on the table?"],
791
+ ], inputs=[imagebox, textbox], label='Image Understanding Examples')
792
+
793
+ gr.Examples(examples=[
794
+ [None, "a cat with a hat."],
795
+ [None, "a smiling child."],
796
+ [None, "tiger cub playing with soccer ball"],
797
+ [None, "screenshot from a 16 bit platform game in a lush green landscape"],
798
+ [None, "Old car in kandy sri lanka,lake road,flower, bright, sunny, orange sky, photorealistic"],
799
+ [None, "Create a vibrant painting of a tropical beach at sunset in the style of Van Gogh."],
800
+ ], inputs=[imagebox, textbox], label='Image Generation Examples')
801
+
802
+ gr.Examples(examples=[
803
+ ["examples/EditingSingleTurnExample/images/0.jpg",
804
+ "Change the color of the boots to a deep forest green"],
805
+ ["examples/EditingSingleTurnExample/images/1.jpg",
806
+ "Add a hat on the dog"],
807
+ ["examples/EditingSingleTurnExample/images/2.jpg",
808
+ "Remove the dried flowers"],
809
+ ["examples/EditingSingleTurnExample/images/3.jpg",
810
+ "Change it into winter"],
811
+ ["examples/EditingSingleTurnExample/images/4.jpg",
812
+ "Delete the tennis racket from the man’s hand"],
813
+ ["examples/EditingSingleTurnExample/images/5.jpg",
814
+ "Show me this as it would appear in a comic book"],
815
+ ], inputs=[imagebox, textbox], label='Image Editing Examples')
816
+
817
+ if not embed_mode:
818
+ gr.Markdown(tos_markdown)
819
+ gr.Markdown(learn_more_markdown)
820
+
821
+ # Register listeners
822
+ btn_list = [regenerate_btn, clear_btn]
823
+ parameter_chat_inputs = [temperature, top_k, top_p, max_output_tokens]
824
+ parameter_gen_edit_inputs = [temperature, top_k, top_p,
825
+ image_gen_temperature, image_gen_top_k, image_gen_top_p, max_output_tokens,
826
+ llm_cfg_scale, resolution_wh_dropdown,
827
+ use_diffusion_checkbox, diffusion_cfg_scale, diffusion_num_inference_steps]
828
+
829
+ regenerate_btn.click(
830
+ regenerate,
831
+ [conversation_state],
832
+ [conversation_state, chatbot, textbox, imagebox] + btn_list
833
+ ).then(
834
+ http_bot_conditional_then,
835
+ [conversation_state] + parameter_gen_edit_inputs, # Pass state and all params
836
+ [conversation_state, chatbot] + btn_list,
837
+ )
838
 
839
+ clear_btn.click(
840
+ clear_history,
841
+ None,
842
+ [conversation_state, chatbot, textbox, imagebox] + btn_list,
843
+ queue=False
844
+ )
845
 
846
+ # Default use chat.
847
+ textbox.submit(
848
+ partial(add_text, mode="chat"),
849
+ [conversation_state, textbox, imagebox],
850
+ [conversation_state, chatbot, textbox, imagebox] + btn_list,
851
+ queue=False
852
+ ).then(
853
+ http_chat_bot,
854
+ [conversation_state] + parameter_chat_inputs,
855
+ [conversation_state, chatbot] + btn_list,
856
+ )
857
 
858
+ # Regular Vision-language Chat
859
+ chat_btn.click(partial(add_text, mode="chat"),
860
+ [conversation_state, textbox, imagebox],
861
+ [conversation_state, chatbot, textbox, imagebox] + btn_list,
862
+ queue=False
863
+ ).then(
864
+ http_chat_bot,
865
+ [conversation_state] + parameter_chat_inputs,
866
+ [conversation_state, chatbot] + btn_list,
867
+ )
868
 
869
+ # Image Generation
870
+ gen_btn.click(
871
+ partial(add_text, mode="image-generation"),
872
+ [conversation_state, textbox, imagebox],
873
+ [conversation_state, chatbot, textbox, imagebox] + btn_list
874
+ ).then(
875
+ http_gen_edit_bot,
876
+ [conversation_state] + parameter_gen_edit_inputs,
877
+ [conversation_state, chatbot] + btn_list
878
+ )
879
 
880
+ use_diffusion_checkbox.change(
881
+ fn=update_resolution_dropdown,
882
+ inputs=[use_diffusion_checkbox, resolution_wh_dropdown],
883
+ outputs=[resolution_wh_dropdown],
884
+ queue=False
885
+ )
886
+
887
+ # Load initial state when demo starts
888
+ demo.load(
889
+ load_demo_refresh_model_list,
890
+ None,
891
+ conversation_state,
892
+ queue=False
893
+ )
894
+ return demo
895
 
896
 
897
+ # --- Main Execution Block ---
898
+ if __name__ == "__main__":
899
  parser = argparse.ArgumentParser()
900
+ # --- Add arguments for ILLUME configs and checkpoints ---
901
+ parser.add_argument("--model_name", type=str, default="illume-unified-mllm/illume_plus-qwen-2_5-3b-hf",
902
+ help="Name for builder.")
903
+ parser.add_argument("--torch_dtype", type=str, default='fp32', choices=['fp32', 'bf16', 'fp16'],
904
+ help="Computation data type.")
 
 
905
 
906
+ parser.add_argument("--diffusion_decoder_path", type=str, default='illume-unified-mllm/dualvitok_sdxl_decoder.pt',
907
+ help="Path to Diffusion Decoder checkpoint (.pt). Required if using diffusion.")
908
 
909
+ parser.add_argument("--tokenizer_path", type=str, default='illume-unified-mllm/dualvitok',
910
+ help="Path to Tokenizer config file (e.g., tokenizer_config.py).")
 
912
+ # --- End ILLUME arguments ---
913
+ parser.add_argument("--share", action="store_true", help="Create a public Gradio share link")
914
+ parser.add_argument("--embed", action="store_true", help="Run in embed mode (minimal UI)")
915
+ parser.add_argument("--device", type=str, default="cuda", help="Device to run on (cuda, cpu).")
916
 
917
+ args = parser.parse_args()
918
 
919
+ # --- Model Loading ---
921
+ # Set device
922
+ if "cuda" in args.device and torch.cuda.is_available():
923
+ device = args.device
924
+ local_rank = 0 # Assume single GPU for Gradio unless configured otherwise
925
+ torch.cuda.set_device(local_rank) # Set default CUDA device
926
+ else:
927
+ device = "cpu"
928
+ local_rank = -1 # Indicate CPU
929
+ logging.info(f"Using device: {device}")
930
+
931
+ args.torch_dtype = dict(fp16=torch.float16, fp32=torch.float32, bf16=torch.bfloat16)[args.torch_dtype]
932
+
933
+ # Build the ILLUME model instance
934
+ logging.info("Building ILLUME model...")
935
+ # prepare models and processors
936
+ model = AutoModel.from_pretrained(
937
+ args.model_name,
938
+ # torch_dtype=torch.bfloat16,
939
+ # attn_implementation='flash_attention_2', # OR 'sdpa' for Ascend NPUs
940
+ torch_dtype=args.torch_dtype,
941
+ attn_implementation='sdpa', # 'sdpa' also works on Ascend NPUs
942
+ low_cpu_mem_usage=True,
943
+ trust_remote_code=True).eval().cuda()
944
+ processor = AutoProcessor.from_pretrained(args.model_name, trust_remote_code=True)
945
+
946
+ # set the vision tokenizer for decoding image.
947
+ dualvitok = AutoModel.from_pretrained(args.tokenizer_path,
948
+ torch_dtype=torch.float32,
949
+ trust_remote_code=True).eval().cuda()
950
+ processor.set_vision_tokenizer(dualvitok)
951
+
952
+ # (Optional): set the sdxl diffusion decoder. It will enable upsample 2x image resolution.
953
+ processor.load_diffusion_vision_detokenizer(args.diffusion_decoder_path)
954
+
955
+ # Create a streamer to surface generated text incrementally
956
+ streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
957
+
958
+ logging.info("ILLUME model built successfully.")
959
+
960
+ demo = build_demo(args.embed)
961
+ demo.queue(
962
+ max_size=10,
963
+ api_open=False
964
+ ).launch(
965
+ share=args.share,
966
+ server_name="0.0.0.0" # Allow network access if not using --share
967
+ )
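
For reference, below is a minimal sketch of the text-to-image path the updated app.py wires up in http_gen_edit_bot and the __main__ block. It assumes the default repos illume-unified-mllm/illume_plus-qwen-2_5-3b-hf and illume-unified-mllm/dualvitok resolve, and that the processor exposes set_vision_tokenizer, parse_text_image, and decode_images as used above. The example subject and the omission of the qwen2 chat template (app.py builds prompts through illume.conversation) are simplifications, not the app's exact behavior.

import torch
from transformers import AutoModel, AutoProcessor

MODEL_NAME = "illume-unified-mllm/illume_plus-qwen-2_5-3b-hf"   # default --model_name above
TOKENIZER_PATH = "illume-unified-mllm/dualvitok"                # default --tokenizer_path above

# Load the MLLM and its processor, then attach the DualViTok vision tokenizer for image decoding.
model = AutoModel.from_pretrained(MODEL_NAME, torch_dtype=torch.float32,
                                  attn_implementation='sdpa', low_cpu_mem_usage=True,
                                  trust_remote_code=True).eval().cuda()
processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
dualvitok = AutoModel.from_pretrained(TOKENIZER_PATH, torch_dtype=torch.float32,
                                      trust_remote_code=True).eval().cuda()
processor.set_vision_tokenizer(dualvitok)

# Build a text-to-image prompt the way http_gen_edit_bot does: a resolution tag plus the request.
h, w = 512, 512
ratio_tag = f"<height_{h}><width_{w}>"
prompt = f"Generate an image of {ratio_tag}, the content of image is a cat with a hat."
inputs = processor(text=prompt + ratio_tag, return_tensors="pt").to(model.device)

# Unconditional prompt used for classifier-free guidance on the LLM side (simplified: no chat template).
uncond = processor(text=f"Generate a random image of {ratio_tag}" + ratio_tag,
                   return_tensors="pt").to(model.device)

with torch.inference_mode():
    output_ids = model.generate(
        **inputs, use_cache=True, max_new_tokens=2048, do_sample=True,
        temperature=1.0, top_k=2048, top_p=1.0,
        negative_image_prompt_ids=uncond.input_ids,
        negative_image_prompt_attention_mask=uncond.attention_mask,
        target_image_resolution=(h, w), guidance_scale=2.0,
        image_semantic_temperature=1.0, image_semantic_top_k=2048, image_semantic_top_p=1.0,
        image_pixel_temperature=1.0, image_pixel_top_k=2048 * 3, image_pixel_top_p=1.0,
    )

# Drop the prompt tokens, decode, then split text from per-image token ids and detokenize with DualViTok.
output_ids = output_ids[:, inputs['input_ids'].shape[1]:]
decoded = processor.tokenizer.decode(output_ids[0], skip_special_tokens=True)
_, image_embed_inds_list, _ = processor.parse_text_image(decoded, '<image>')
images = processor.decode_images(image_embed_inds_list, target_resolution=(h, w))
images[0].save('generated_cat.png')

To run the full demo instead, the __main__ block above launches the Gradio app directly, optionally with --share, --embed, --device, and the checkpoint arguments it defines.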