Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,351 +1,967 @@
|
|
1 |
import argparse
|
2 |
-
import datetime
|
3 |
-
import json
|
4 |
import os
|
5 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
import torch
|
7 |
-
|
|
|
|
|
|
|
8 |
from PIL import Image
|
9 |
-
from tokenizer.sdxl_decoder_pipe import StableDiffusionXLDecoderPipeline
|
10 |
-
from torchvision import transforms
|
11 |
-
import logging
|
12 |
-
from utils.registry_utils import Config
|
13 |
-
from tokenizer.builder import build_vq_model
|
14 |
-
from dataset.multi_ratio_dataset import get_image_size, assign_ratio
|
15 |
-
|
16 |
-
|
17 |
-
def read_config(file):
|
18 |
-
# solve config loading conflict when multi-processes
|
19 |
-
import time
|
20 |
-
while True:
|
21 |
-
config = Config.fromfile(file)
|
22 |
-
if len(config) == 0:
|
23 |
-
time.sleep(0.1)
|
24 |
-
continue
|
25 |
-
break
|
26 |
-
return config
|
27 |
-
|
28 |
-
|
29 |
-
def build_logger(name, log_file):
|
30 |
-
logger = logging.getLogger(name)
|
31 |
-
logger.setLevel(logging.INFO)
|
32 |
-
handler = logging.FileHandler(log_file)
|
33 |
-
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
34 |
-
handler.setFormatter(formatter)
|
35 |
-
logger.addHandler(handler)
|
36 |
-
return logger
|
37 |
-
|
38 |
-
|
39 |
-
logger = build_logger("gradio_web_server", "gradio_web_server.log")
|
40 |
-
|
41 |
-
vq_model = None
|
42 |
-
is_ema_model = False
|
43 |
-
diffusion_pipeline = None
|
44 |
-
lazy_load = False
|
45 |
-
|
46 |
-
# diffusion decoder hyperparameters.
|
47 |
-
resolution_list = [
|
48 |
-
(1024, 1024), (768, 1024), (1024, 768),
|
49 |
-
(512, 2048), (2048, 512), (640, 1920),
|
50 |
-
(1920, 640), (768, 1536),
|
51 |
-
(1536, 768), (768, 1152), (1152, 768)
|
52 |
-
]
|
53 |
|
54 |
-
|
55 |
-
|
|
|
|
|
|
|
|
|
56 |
|
|
|
|
|
57 |
|
58 |
-
|
59 |
-
|
|
|
|
|
60 |
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
else:
|
68 |
-
|
69 |
-
|
|
|
|
|
|
|
70 |
|
71 |
-
resized_img = img.resize((new_width, new_height))
|
72 |
-
return resized_img
|
73 |
|
|
|
|
|
|
|
|
|
74 |
|
75 |
-
|
|
|
76 |
|
|
|
|
|
|
|
77 |
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
|
|
84 |
else:
|
85 |
-
|
86 |
-
|
87 |
-
|
|
|
|
|
|
|
88 |
|
89 |
|
90 |
-
def
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
offset_y = (target_size - image.height) // 2
|
95 |
-
new_img.paste(image, (offset_x, offset_y))
|
96 |
-
return new_img
|
97 |
|
98 |
|
99 |
-
def
|
100 |
-
global
|
101 |
-
vq_model = build_vq_model(args.vq_model)
|
102 |
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
logger.info("Convert the model dtype to bfloat16")
|
109 |
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
|
114 |
-
if
|
115 |
-
|
116 |
-
else:
|
117 |
-
ema_state_dict = None
|
118 |
|
119 |
-
|
120 |
-
model_state_dict = checkpoint["model"]
|
121 |
-
elif "state_dict" in checkpoint:
|
122 |
-
model_state_dict = checkpoint["state_dict"]
|
123 |
-
else:
|
124 |
-
model_state_dict = checkpoint
|
125 |
|
126 |
-
if
|
127 |
-
|
128 |
-
|
129 |
-
vq_model.load_state_dict(model_state_dict, strict=True)
|
130 |
-
return vq_model
|
131 |
|
|
|
|
|
|
|
132 |
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
vq_model=vq_model,
|
140 |
-
)
|
141 |
-
diffusion_pipeline.to(vq_model.device)
|
142 |
|
|
|
|
|
143 |
|
144 |
-
def vqgan_diffusion_decoder_reconstruct(input_image, diffusion_upsample, cfg_values, steps):
|
145 |
-
transform = transforms.Compose([
|
146 |
-
transforms.ToTensor(),
|
147 |
-
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
148 |
-
])
|
149 |
-
input_tensor = transform(input_image).unsqueeze(0).to(vq_model.device)
|
150 |
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
)
|
168 |
-
|
169 |
-
|
170 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
171 |
|
|
|
|
|
172 |
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
178 |
-
])
|
179 |
|
180 |
-
|
181 |
|
182 |
-
|
183 |
-
|
|
|
|
|
184 |
|
185 |
-
|
186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
187 |
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
(quant_detail, _, _) = vq_model.encode(**inputs)
|
192 |
-
reconstructed_image = vq_model.decode(quant_semantic, quant_detail)
|
193 |
|
194 |
-
|
195 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
196 |
|
197 |
-
|
198 |
-
|
199 |
-
return output_image, f"�� **Output Resolution**: {org_width}x{org_height}"
|
200 |
|
|
|
201 |
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
|
206 |
-
|
207 |
-
|
208 |
-
|
|
|
|
|
|
|
|
|
209 |
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
|
217 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
218 |
"""
|
219 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
|
221 |
-
def build_gradio_interface(args):
|
222 |
-
if not lazy_load:
|
223 |
-
load_vqgan_model(args, model_dtype=args.model_dtype)
|
224 |
|
225 |
-
|
226 |
-
|
227 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
228 |
|
229 |
with gr.Row():
|
230 |
-
with gr.Column():
|
231 |
-
gr.
|
232 |
-
|
233 |
-
|
234 |
-
gr.
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
243 |
|
244 |
-
|
245 |
-
gr.
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
with gr.Column():
|
250 |
-
gr.Markdown("## ⚙ Hyperparameters")
|
251 |
-
# with gr.Row():
|
252 |
-
short_resolution_dropdown = gr.Dropdown(
|
253 |
-
choices=[None, 256, 384, 512, 1024],
|
254 |
-
value=1024,
|
255 |
-
label="Max Shortest Side"
|
256 |
-
)
|
257 |
-
force_upscale_checkbox = gr.Checkbox(label="Force Upscale to Max Shortest Side", value=False)
|
258 |
-
use_ema_checkbox = gr.Checkbox(label="Use EMA Model", value=False)
|
259 |
-
|
260 |
-
with gr.Accordion("�� Use Diffusion Decoder", open=False):
|
261 |
-
use_diffusion_checkbox = gr.Checkbox(label="Load Diffusion Decoder", value=False)
|
262 |
-
diffusion_upsample_checkbox = gr.Checkbox(label="Enable 2x Upsample", value=False)
|
263 |
-
cfg_slider = gr.Slider(
|
264 |
-
minimum=cfg_range[0], maximum=cfg_range[1],
|
265 |
-
step=0.5, value=1.5,
|
266 |
-
label="CFG Value"
|
267 |
)
|
268 |
-
|
269 |
-
minimum=
|
270 |
-
|
271 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
272 |
)
|
273 |
-
reconstruct_btn = gr.Button("�� Reconstruct", variant="primary")
|
274 |
-
|
275 |
-
def handle_input_image(image):
|
276 |
-
if image is not None:
|
277 |
-
image = image.convert("RGB")
|
278 |
-
w, h = image.size
|
279 |
-
return image, f"�� **Input Resolution**: {w}x{h}"
|
280 |
-
return None, ""
|
281 |
-
|
282 |
-
input_image.change(
|
283 |
-
handle_input_image,
|
284 |
-
inputs=input_image,
|
285 |
-
outputs=[input_image, input_resolution_display]
|
286 |
-
)
|
287 |
|
288 |
-
|
289 |
-
|
|
|
|
|
|
|
290 |
|
291 |
-
|
292 |
-
|
293 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
294 |
|
295 |
-
|
296 |
-
|
297 |
-
|
|
|
|
|
|
|
298 |
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
load_diffusion_decoder(args)
|
311 |
-
recon_image, resolution_str = vqgan_diffusion_decoder_reconstruct(image, diffusion_upsample, cfg_value,
|
312 |
-
num_steps)
|
313 |
-
else:
|
314 |
-
recon_image, resolution_str = vqgan_reconstruct(image)
|
315 |
|
316 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
317 |
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
|
|
|
|
|
|
|
|
|
|
323 |
|
324 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
325 |
|
326 |
|
327 |
-
#
|
328 |
-
|
329 |
parser = argparse.ArgumentParser()
|
330 |
-
|
331 |
-
parser.add_argument("--
|
332 |
-
|
333 |
-
parser.add_argument("--
|
334 |
-
|
335 |
-
parser.add_argument("--sdxl-decoder-path", type=str, default=None)
|
336 |
-
parser.add_argument("--verbose", action='store_true')
|
337 |
|
338 |
-
|
|
|
339 |
|
340 |
-
|
341 |
-
|
342 |
-
config.torch_dtype = args.torch_dtype
|
343 |
-
config.model_dtype = args.model_dtype
|
344 |
-
config.verbose = args.verbose
|
345 |
-
config.sdxl_decoder_path = args.sdxl_decoder_path
|
346 |
|
347 |
-
|
|
|
|
|
|
|
348 |
|
|
|
349 |
|
350 |
-
|
351 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import argparse
|
|
|
|
|
2 |
import os
|
3 |
+
import traceback
|
4 |
+
import logging
|
5 |
+
from functools import partial
|
6 |
+
from threading import Thread
|
7 |
+
|
8 |
+
import re # Added for parsing image tokens
|
9 |
+
|
10 |
import torch
|
11 |
+
|
12 |
+
from transformers import TextIteratorStreamer
|
13 |
+
|
14 |
+
from transformers import AutoModel, AutoProcessor
|
15 |
from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
18 |
+
datefmt='%Y-%m-%d %H:%M:%S')
|
19 |
+
logging.getLogger("http").setLevel(logging.WARNING)
|
20 |
+
logging.getLogger("httpx").setLevel(logging.WARNING)
|
21 |
+
|
22 |
+
import gradio as gr
|
23 |
|
24 |
+
from illume.conversation import default_conversation, conv_templates, SeparatorStyle
|
25 |
+
# from conversation import default_conversation, conv_templates, SeparatorStyle
|
26 |
|
27 |
+
# --- Global Variables and Model Loading ---
|
28 |
+
model = None # Global variable to hold the loaded ILLUME model
|
29 |
+
args = None # Global variable to hold command line args
|
30 |
+
streamer = None # Global variable to hold command line args
|
31 |
|
32 |
+
DEFAULT_IMAGE_TOKEN = '<image>'
|
33 |
+
|
34 |
+
# Define common resolutions
|
35 |
+
DEFAULT_RESOLUTIONS = [
|
36 |
+
(256, 256), (512, 512), (384, 640), (640, 384), (512, 384),
|
37 |
+
(384, 512), (256, 384), (384, 256), (256, 512), (512, 256)
|
38 |
+
]
|
39 |
+
|
40 |
+
DEFAULT_DIFFUSION_RESOLUTIONS = [
|
41 |
+
(512, 512), (1024, 1024), (768, 1280), (1280, 768), (1024, 768),
|
42 |
+
(768, 1024), (512, 768), (768, 512), (512, 1024), (1024, 512)
|
43 |
+
]
|
44 |
+
|
45 |
+
conv_templates_version = 'qwen2'
|
46 |
+
|
47 |
+
|
48 |
+
# inputs = processor(**inputs, return_tensors="pt")
|
49 |
+
# inputs = inputs.to(model.device)
|
50 |
+
|
51 |
+
# # prepare generation arguments
|
52 |
+
# gen_kwargs = dict(
|
53 |
+
# max_new_tokens=2048, do_sample=True
|
54 |
+
# )
|
55 |
+
|
56 |
+
# image_gen_kwargs = dict(
|
57 |
+
# negative_image_prompt_ids=uncond_inputs.input_ids,
|
58 |
+
# target_image_resolution=target_image_resolution,
|
59 |
+
# guidance_scale=2.0,
|
60 |
+
# image_semantic_temperature=1.0,
|
61 |
+
# image_semantic_top_k=2048,
|
62 |
+
# image_semantic_top_p=1.0,
|
63 |
+
# image_pixel_temperature=1.0,
|
64 |
+
# image_pixel_top_k=2048 * 3,
|
65 |
+
# image_pixel_top_p=1.0,
|
66 |
+
# )
|
67 |
+
|
68 |
+
# gen_kwargs = dict(
|
69 |
+
# max_new_tokens=2048, do_sample=False
|
70 |
+
# )
|
71 |
+
|
72 |
+
# # run generation
|
73 |
+
# with torch.no_grad():
|
74 |
+
# outputs = model.generate(**inputs, **gen_kwargs)
|
75 |
+
# outputs = outputs[:, inputs['input_ids'].shape[1]:]
|
76 |
+
# outputs_text = processor.batch_decode(outputs, skip_special_tokens=True)
|
77 |
+
|
78 |
+
# # It extract the image tokens of each image and replace the image tokens with the `image_placeholder` in order.
|
79 |
+
# generated_text, image_embed_inds_list, list_image_token_parts = processor.parse_text_image(outputs_text[0],
|
80 |
+
# image_placeholder='<image_out>')
|
81 |
+
|
82 |
+
# # batch decoding the image by using the DualViTok.
|
83 |
+
# vq_decoded_images = processor.decode_images(image_embed_inds_list, target_resolution=target_image_resolution)
|
84 |
+
|
85 |
+
# # batch decoding the image by using the sdxl diffusion decoder.
|
86 |
+
# # The output image resolution would be [target_image_resolution[0] * 2, target_image_resolution[1] * 2]
|
87 |
+
# diffusion_decoded_images = processor.decode_images(image_embed_inds_list, target_resolution=target_image_resolution,
|
88 |
+
# use_diffusion=True, diffusion_cfg_scale=2.0,
|
89 |
+
# diffusion_num_inference_steps=20)
|
90 |
+
|
91 |
+
# vq_decoded_images[0].save('vq_decoded_cat.png')
|
92 |
+
# diffusion_decoded_images[0].save('diffusion_decoded_cat.png')
|
93 |
+
|
94 |
+
|
95 |
+
# Adapted from your code
|
96 |
+
def check_image_token_num(image_embed_inds, token_nums=[81, 256], identifier=""):
|
97 |
+
image_embed_inds_out = []
|
98 |
+
if len(image_embed_inds) != len(token_nums):
|
99 |
+
logging.error(
|
100 |
+
f"{identifier} Mismatch between number of image token levels ({len(image_embed_inds)}) and expected token_nums ({len(token_nums)})")
|
101 |
+
# Handle error appropriately - maybe return None or raise exception
|
102 |
+
return None # Indicate error
|
103 |
+
|
104 |
+
for level, (embed_inds, token_num) in enumerate(zip(image_embed_inds, token_nums)):
|
105 |
+
if not len(embed_inds) == token_num:
|
106 |
+
logging.warning(
|
107 |
+
f"{identifier} Level {level} embed_inds length {len(embed_inds)} not equal to expected {token_num}! Padding/truncating.")
|
108 |
+
if len(embed_inds) > token_num:
|
109 |
+
embed_inds = embed_inds[:token_num]
|
110 |
+
elif len(embed_inds) == 0:
|
111 |
+
# Handle empty case - perhaps fill with a default token?
|
112 |
+
logging.warning(f"{identifier} Level {level} embed_inds is empty. Filling with zeros.")
|
113 |
+
embed_inds = [0] * token_num # Or a placeholder token ID
|
114 |
+
else:
|
115 |
+
# Pad with the last token ID
|
116 |
+
embed_inds.extend([embed_inds[-1]] * (token_num - len(embed_inds)))
|
117 |
+
image_embed_inds_out.append(embed_inds)
|
118 |
+
return image_embed_inds_out
|
119 |
+
|
120 |
+
|
121 |
+
# Adapted from your code
|
122 |
+
def pad_sequence(tokenizer, input_ids, batch_first, padding_value):
|
123 |
+
# Assuming input_ids is a list of Tensors
|
124 |
+
if tokenizer.padding_side == "left":
|
125 |
+
input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids]
|
126 |
+
# Manually pad if needed, or use torch utils if input_ids are tensors
|
127 |
+
# This assumes input_ids are already tensors
|
128 |
+
input_ids_padded = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value)
|
129 |
+
if tokenizer.padding_side == "left":
|
130 |
+
input_ids_padded = torch.flip(input_ids_padded, [1])
|
131 |
+
return input_ids_padded
|
132 |
+
|
133 |
+
|
134 |
+
# --- Gradio UI Functions ---
|
135 |
+
no_change_btn = gr.Button()
|
136 |
+
enable_btn = gr.Button(interactive=True)
|
137 |
+
disable_btn = gr.Button(interactive=False)
|
138 |
+
server_error_msg = "**NETWORK ERROR OR SERVER ISSUE. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
|
139 |
+
server_oom_msg = "**OUT OF GPU MEMORY DETECTED. PLEASE DECREASE THE MAX OUTPUT TOKENS OR IMAGE RESOLUTION AND REGENERATE.**"
|
140 |
+
|
141 |
+
|
142 |
+
def load_demo_refresh_model_list():
|
143 |
+
logging.info("load_demo.")
|
144 |
+
# Use the conversation template from the loaded model/config
|
145 |
+
# Ensure model is loaded before this runs
|
146 |
+
if conv_templates_version in conv_templates:
|
147 |
+
state = conv_templates[conv_templates_version].copy()
|
148 |
+
logging.info(f"Using conversation template: {conv_templates_version}")
|
149 |
else:
|
150 |
+
logging.warning(f"Conversation template '{conv_templates_version}' not found. Using default.")
|
151 |
+
# Find a default template name from conv_templates or define one
|
152 |
+
default_template_name = next(iter(conv_templates)) # Get the first available template
|
153 |
+
state = conv_templates[default_template_name].copy()
|
154 |
+
return state
|
155 |
|
|
|
|
|
156 |
|
157 |
+
def regenerate(state): # Added resolution_wh
|
158 |
+
logging.info("regenerate.")
|
159 |
+
if not state.messages or len(state.messages) < 2:
|
160 |
+
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 2 # Use state's image
|
161 |
|
162 |
+
# Clear the last assistant message
|
163 |
+
state.messages[-1][-1] = None
|
164 |
|
165 |
+
state.skip_next = False
|
166 |
+
# Return state, updated chatbot display, refill textbox, keep image
|
167 |
+
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 2
|
168 |
|
169 |
+
|
170 |
+
def http_bot_conditional_then(state, temperature, top_k, top_p,
|
171 |
+
image_gen_temperature, image_gen_top_k, image_gen_top_p, max_output_tokens,
|
172 |
+
llm_cfg_scale, resolution_wh, use_diffusion, diffusion_cfg_scale,
|
173 |
+
diffusion_num_inference_steps):
|
174 |
+
if state.mode == 'chat':
|
175 |
+
result = yield from http_chat_bot(state, temperature, top_k, top_p, max_output_tokens)
|
176 |
else:
|
177 |
+
# result = yield from http_gen_edit_bot(state, temperature, top_k, top_p, max_output_tokens,
|
178 |
+
result = yield from http_gen_edit_bot(
|
179 |
+
state, temperature, top_k, top_p, image_gen_temperature, image_gen_top_k, image_gen_top_p,
|
180 |
+
max_output_tokens,
|
181 |
+
llm_cfg_scale, resolution_wh, use_diffusion, diffusion_cfg_scale, diffusion_num_inference_steps)
|
182 |
+
return result
|
183 |
|
184 |
|
185 |
+
def clear_history():
|
186 |
+
logging.info("clear_history.")
|
187 |
+
state = load_demo_refresh_model_list()
|
188 |
+
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 2
|
|
|
|
|
|
|
189 |
|
190 |
|
191 |
+
def add_text(state, text, image, mode):
|
192 |
+
global model # Ensure we use the loaded model
|
|
|
193 |
|
194 |
+
logging.info(f"add_text. Text len: {len(text)}, Image provided: {image is not None}")
|
195 |
+
if len(text.strip()) == 0 and image is None:
|
196 |
+
state.skip_next = True
|
197 |
+
# Keep image in the imagebox if only image was present
|
198 |
+
return (state, state.to_gradio_chatbot(), "", image) + (no_change_btn,) * 2
|
|
|
199 |
|
200 |
+
if state.messages and state.messages[-1][1] and \
|
201 |
+
isinstance(state.messages[-1][1], str) and state.messages[-1][1].startswith("**"):
|
202 |
+
state = load_demo_refresh_model_list() # Start fresh after error
|
203 |
|
204 |
+
if mode == 'image-generation':
|
205 |
+
state = load_demo_refresh_model_list()
|
|
|
|
|
206 |
|
207 |
+
image_process_mode = "Default"
|
|
|
|
|
|
|
|
|
|
|
208 |
|
209 |
+
if image is not None:
|
210 |
+
if state.get_images():
|
211 |
+
state = load_demo_refresh_model_list()
|
|
|
|
|
212 |
|
213 |
+
if '<image>' not in text:
|
214 |
+
text = f'<image>\n{text}'
|
215 |
+
text = (text, image, image_process_mode)
|
216 |
|
217 |
+
# Append user message
|
218 |
+
state.append_message(state.roles[0], text)
|
219 |
+
state.append_message(state.roles[1], None) # Placeholder for assistant
|
220 |
+
state.skip_next = False
|
221 |
+
state.mode = mode
|
222 |
+
logging.info(f"Updated state messages: {len(state.messages)}")
|
|
|
|
|
|
|
223 |
|
224 |
+
# Return new state, updated chatbot, clear textbox, clear imagebox
|
225 |
+
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 2
|
226 |
|
|
|
|
|
|
|
|
|
|
|
|
|
227 |
|
228 |
+
def stream_response(model, inputs, streamer, prompt, gen_kwargs):
|
229 |
+
thread = Thread(target=model.generate, kwargs=dict(
|
230 |
+
streamer=streamer,
|
231 |
+
**inputs,
|
232 |
+
**gen_kwargs
|
233 |
+
))
|
234 |
+
thread.start()
|
235 |
+
|
236 |
+
generated_text = prompt
|
237 |
+
|
238 |
+
for new_text in streamer:
|
239 |
+
generated_text += new_text
|
240 |
+
yield generated_text
|
241 |
+
|
242 |
+
|
243 |
+
# @spaces.GPU
|
244 |
+
def http_chat_bot(state, temperature, top_k, top_p, max_new_tokens):
|
245 |
+
global model, args, streamer # Use global model and args
|
246 |
+
logging.info("http_chat_bot.")
|
247 |
+
|
248 |
+
if state.skip_next:
|
249 |
+
logging.warning("Skipping bot generation. skip_next or model not ready.")
|
250 |
+
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 2
|
251 |
+
return
|
252 |
+
|
253 |
+
if len(state.messages) < 2:
|
254 |
+
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 2
|
255 |
+
return
|
256 |
+
|
257 |
+
# --- Prepare Inputs for ILLUME ---
|
258 |
+
# Get the full prompt from the conversation state
|
259 |
+
prompt = state.get_prompt()
|
260 |
+
all_images = state.get_images(return_pil=True)
|
261 |
+
|
262 |
+
logging.info(f"Raw Prompt: {prompt}")
|
263 |
+
|
264 |
+
inputs = dict(
|
265 |
+
text=prompt,
|
266 |
)
|
267 |
+
# Tokenize the prompt
|
268 |
+
# run processors
|
269 |
+
inputs = processor(**inputs, return_tensors="pt")
|
270 |
+
inputs = inputs.to(model.device)
|
271 |
+
|
272 |
+
# avoid mismatch resolution. process the images alone
|
273 |
+
if len(all_images):
|
274 |
+
images = []
|
275 |
+
for image in all_images:
|
276 |
+
images.append(processor.image_processor(image, return_tensors="pt")['pixel_values'].to(model.device))
|
277 |
+
pixel_values = images
|
278 |
+
inputs['pixel_values'] = pixel_values
|
279 |
+
|
280 |
+
logging.info(f"Input IDs shape: {inputs.input_ids.shape}")
|
281 |
+
|
282 |
+
# Set initial response placeholder
|
283 |
+
state.messages[-1][-1] = "▌"
|
284 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 2
|
285 |
+
|
286 |
+
# --- MLLM Generation ---’
|
287 |
+
gen_kwargs = dict(
|
288 |
+
pad_token_id=processor.tokenizer.pad_token_id,
|
289 |
+
do_sample=True if temperature > 0 else False, # Controlled by dynamic sampler now, but keep flag
|
290 |
+
temperature=temperature,
|
291 |
+
top_k=top_k,
|
292 |
+
top_p=top_p,
|
293 |
+
max_new_tokens=max_new_tokens,
|
294 |
+
use_cache=True,
|
295 |
+
eos_token_id=processor.tokenizer.eos_token_id # Ensure EOS token is set
|
296 |
+
)
|
297 |
+
logging.info(f"==== request kwargs====\n{gen_kwargs}")
|
298 |
+
|
299 |
+
if max_new_tokens < 1:
|
300 |
+
state.messages[-1][-1] = "Exceeds max token length. Please start a new conversation, thanks."
|
301 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 2
|
302 |
+
return
|
303 |
+
|
304 |
+
state.messages[-1][-1] = "▌"
|
305 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 2
|
306 |
+
|
307 |
+
# Stream output
|
308 |
+
try:
|
309 |
+
for generated_text in stream_response(model, inputs, streamer, prompt, gen_kwargs):
|
310 |
+
output = generated_text[len(prompt):].strip()
|
311 |
+
state.messages[-1][-1] = output
|
312 |
+
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 2
|
313 |
+
except Exception as e:
|
314 |
+
os.system("nvidia-smi")
|
315 |
+
logging.info(traceback.print_exc())
|
316 |
+
state.messages[-1][-1] = server_error_msg
|
317 |
+
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 2
|
318 |
+
return (state, state.to_gradio_chatbot()) + (enable_btn,) * 2
|
319 |
+
|
320 |
+
|
321 |
+
def http_gen_edit_bot(state, temperature, top_k, top_p, image_gen_temperature,
|
322 |
+
image_gen_top_k, image_gen_top_p, max_output_tokens,
|
323 |
+
llm_cfg_scale, resolution_wh, use_diffusion, diffusion_cfg_scale, diffusion_num_inference_steps):
|
324 |
+
global model, args # Use global model and args
|
325 |
+
logging.info("http_gen_edit_bot.")
|
326 |
+
|
327 |
+
if state.skip_next:
|
328 |
+
logging.warning("Skipping bot generation. skip_next or model not ready.")
|
329 |
+
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 2
|
330 |
+
return
|
331 |
+
|
332 |
+
if len(state.messages) < 2:
|
333 |
+
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 2
|
334 |
+
return
|
335 |
+
|
336 |
+
# --- Prepare Inputs for ILLUME ---
|
337 |
+
# Get the full prompt from the conversation state
|
338 |
+
all_images = state.get_images(return_pil=True)
|
339 |
+
|
340 |
+
# read resolution from user defined.
|
341 |
+
h_str, w_str = resolution_wh.split('x')
|
342 |
+
h_out, w_out = int(h_str), int(w_str)
|
343 |
+
|
344 |
+
if use_diffusion:
|
345 |
+
h_out, w_out = (h_out // 2, w_out // 2)
|
346 |
+
else:
|
347 |
+
h_out, w_out = (h_out, w_out)
|
348 |
+
ratio_tag = f"<height_{h_out}><width_{w_out}>"
|
349 |
+
|
350 |
+
input_state = state.copy()
|
351 |
+
|
352 |
+
# prepare the text.
|
353 |
+
original_image_sizes = None
|
354 |
+
if len(all_images):
|
355 |
+
# image editing.
|
356 |
+
original_image_sizes = [image.size for image in all_images]
|
357 |
+
logging.info(f"original_image_sizes: {original_image_sizes}")
|
358 |
+
|
359 |
+
all_images = [processor.transform_image_nearest_resolution_ratio(image) for image in all_images]
|
360 |
+
|
361 |
+
inputs = dict(
|
362 |
+
images=all_images
|
363 |
+
)
|
364 |
|
365 |
+
image_inputs = processor.image_processor(**inputs, return_tensors="pt")
|
366 |
+
image_inputs = image_inputs.to(model.device)
|
367 |
|
368 |
+
# overwrite the output resolution
|
369 |
+
h, w = image_inputs['pixel_values'].shape[-2:]
|
370 |
+
ratio_tag = f"<height_{h}><width_{w}>"
|
371 |
+
h_out, w_out = h, w
|
|
|
|
|
372 |
|
373 |
+
unconditional_text = f"{ratio_tag}{DEFAULT_IMAGE_TOKEN}\nReconstruct the image according to the given image\n" # of {ratio_tag}
|
374 |
|
375 |
+
instruction, img, image_process_type = input_state.messages[-2][-1]
|
376 |
+
instruction = instruction.replace(DEFAULT_IMAGE_TOKEN, '').strip()
|
377 |
+
text = f"{ratio_tag}{DEFAULT_IMAGE_TOKEN}\nPlease edit the image according to the instruction: {instruction}\n"
|
378 |
+
input_state.messages[-2][-1] = text, img, image_process_type
|
379 |
|
380 |
+
else:
|
381 |
+
# image generation
|
382 |
+
unconditional_text = f"Generate a random image of {ratio_tag}"
|
383 |
+
|
384 |
+
text = input_state.messages[-2][-1]
|
385 |
+
logging.info(f"Current text is {text}")
|
386 |
+
text = f"Generate an image of {ratio_tag}, the content of image is {text}"
|
387 |
+
input_state.messages[-2][-1] = text
|
388 |
+
logging.info(f"After formating. current text is {text}")
|
389 |
+
image_inputs = {}
|
390 |
+
|
391 |
+
# Calculate ratio tag based on base resolution from config
|
392 |
+
logging.info(f"Target Resolution: {h_out}x{w_out}, Ratio Tag: {ratio_tag}")
|
393 |
+
target_image_resolution = (h_out, w_out)
|
394 |
+
prompt = input_state.get_prompt()
|
395 |
+
logging.info(f"Raw Prompt: {prompt}")
|
396 |
+
|
397 |
+
# Tokenize the prompt
|
398 |
+
inputs = dict(
|
399 |
+
text=prompt + ratio_tag,
|
400 |
+
)
|
401 |
|
402 |
+
inputs = processor(**inputs, return_tensors="pt")
|
403 |
+
inputs = inputs.to(model.device)
|
404 |
+
inputs.update(image_inputs)
|
|
|
|
|
405 |
|
406 |
+
conv_uncond = conv_templates[conv_templates_version].copy()
|
407 |
+
conv_uncond.append_message(conv_uncond.roles[0], unconditional_text)
|
408 |
+
conv_uncond.append_message(conv_uncond.roles[1], None)
|
409 |
+
unconditional_prompt_str = conv_uncond.get_prompt() # Add ratio tag
|
410 |
+
|
411 |
+
uncond_inputs = dict(
|
412 |
+
text=unconditional_prompt_str + ratio_tag,
|
413 |
+
images=all_images
|
414 |
+
)
|
415 |
|
416 |
+
uncond_inputs = processor(**uncond_inputs, return_tensors="pt")
|
417 |
+
uncond_inputs = uncond_inputs.to(model.device)
|
|
|
418 |
|
419 |
+
logging.info(f"Input IDs shape: {inputs.input_ids.shape}")
|
420 |
|
421 |
+
# Set initial response placeholder
|
422 |
+
state.messages[-1][-1] = "image generating..."
|
423 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 2
|
424 |
|
425 |
+
gen_kwargs = dict(
|
426 |
+
max_new_tokens=2048,
|
427 |
+
do_sample=True if temperature > 0 else False,
|
428 |
+
temperature=temperature,
|
429 |
+
top_k=top_k,
|
430 |
+
top_p=top_p,
|
431 |
+
)
|
432 |
|
433 |
+
image_gen_kwargs = dict(
|
434 |
+
negative_image_prompt_ids=uncond_inputs.input_ids,
|
435 |
+
negative_image_prompt_attention_mask=uncond_inputs.attention_mask,
|
436 |
+
target_image_resolution=target_image_resolution,
|
437 |
+
guidance_scale=llm_cfg_scale,
|
438 |
+
image_semantic_temperature=image_gen_temperature,
|
439 |
+
image_semantic_top_k=image_gen_top_k,
|
440 |
+
image_semantic_top_p=image_gen_top_p,
|
441 |
+
image_pixel_temperature=image_gen_temperature,
|
442 |
+
image_pixel_top_k=image_gen_top_k * 3,
|
443 |
+
image_pixel_top_p=image_gen_top_p,
|
444 |
+
)
|
445 |
|
446 |
+
# --- MLLM Generation ---
|
447 |
+
generated_image = None
|
448 |
+
generated_text = ""
|
449 |
+
try:
|
450 |
+
with torch.inference_mode(): # Ensure no gradients are calculated
|
451 |
+
output_ids = model.generate(
|
452 |
+
**inputs,
|
453 |
+
use_cache=True,
|
454 |
+
**gen_kwargs,
|
455 |
+
**image_gen_kwargs
|
456 |
+
)
|
457 |
+
|
458 |
+
output_ids = output_ids[:, inputs['input_ids'].shape[1]:]
|
459 |
+
|
460 |
+
logging.info(f"Generated output IDs shape: {output_ids.shape}")
|
461 |
+
|
462 |
+
# Decode the generated IDs, skipping prompt and special tokens
|
463 |
+
# We need to decode the full output first to parse image tokens
|
464 |
+
# output_ids shape is likely (batch_size, seq_len), batch_size=1 here
|
465 |
+
generated_ids = output_ids[0] # Get only generated tokens
|
466 |
+
full_output_text = processor.tokenizer.decode(generated_ids, skip_special_tokens=True)
|
467 |
+
logging.info(f"Full decoded output: {full_output_text}")
|
468 |
+
|
469 |
+
# --- Parse Output for Image Tokens and Text ---
|
470 |
+
# Ensure levels are sorted and create the final list
|
471 |
+
generated_text, image_embed_inds_list, list_image_token_parts = processor.parse_text_image(full_output_text,
|
472 |
+
DEFAULT_IMAGE_TOKEN)
|
473 |
+
|
474 |
+
assert len(image_embed_inds_list) == 1, 'The number of generated image should be 1.'
|
475 |
+
image_embed_inds = image_embed_inds_list[0]
|
476 |
+
logging.info(f"The generated text: {full_output_text}")
|
477 |
+
logging.info(f"Parsed generated text (image presents as {DEFAULT_IMAGE_TOKEN}): {generated_text}")
|
478 |
+
|
479 |
+
# Update chat with generated text first
|
480 |
+
state.messages[-1][-1] = "vision tokenizer decoding..." # Remove cursor
|
481 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 2 # Yield text update
|
482 |
+
|
483 |
+
# --- Image Detokenization ---
|
484 |
+
if any(image_embed_inds):
|
485 |
+
logging.info("Image tokens found. Attempting detokenization...")
|
486 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 2
|
487 |
+
|
488 |
+
samples = processor.decode_images(image_embed_inds_list, target_resolution=target_image_resolution,
|
489 |
+
use_diffusion=use_diffusion, diffusion_cfg_scale=diffusion_cfg_scale,
|
490 |
+
diffusion_num_inference_steps=diffusion_num_inference_steps)
|
491 |
+
generated_image = samples[0]
|
492 |
+
if use_diffusion:
|
493 |
+
logging.info(
|
494 |
+
f"Using Diffusion Decoder (cfg: {diffusion_cfg_scale}, steps: {diffusion_num_inference_steps}) Image size: {generated_image.size}")
|
495 |
+
else:
|
496 |
+
logging.info(f"Using VQ Tokenizer Decoder. Image size: {generated_image.size}")
|
497 |
+
|
498 |
+
if generated_image:
|
499 |
+
if original_image_sizes is not None and len(
|
500 |
+
original_image_sizes) == 1: # editing task, unpad and resize image to original size
|
501 |
+
original_size = original_image_sizes[0]
|
502 |
+
logging.info(f"original size: {original_size}. Output Image size: {generated_image.size}")
|
503 |
+
generated_image = processor.unpad_and_resize_back(generated_image, original_size[0], original_size[1])
|
504 |
+
logging.info(f"final image size: {generated_image.size}")
|
505 |
+
logging.info("Image successfully generated.")
|
506 |
+
# <image> is placeholder.
|
507 |
+
|
508 |
+
logging.info("Image successfully generated.")
|
509 |
+
# <image> is placeholder.
|
510 |
+
state.messages[-1][-1] = (DEFAULT_IMAGE_TOKEN, [generated_image], list_image_token_parts)
|
511 |
+
else:
|
512 |
+
# No image tokens generated
|
513 |
+
state.messages[-1][-1] = generated_text # Final text without image
|
514 |
+
|
515 |
+
# Final yield with potentially updated message (text + image)
|
516 |
+
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 2
|
517 |
+
|
518 |
+
except torch.cuda.OutOfMemoryError as e:
|
519 |
+
logging.error(f"CUDA OutOfMemoryError during generation: {e}\n{traceback.format_exc()}")
|
520 |
+
state.messages[-1][-1] = server_oom_msg
|
521 |
+
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 2
|
522 |
+
except Exception as e:
|
523 |
+
logging.error(f"Error during model generation or detokenization: {e}\n{traceback.format_exc()}")
|
524 |
+
state.messages[-1][-1] = f"{server_error_msg}\n```\n{traceback.format_exc()}\n```" # Show traceback in error
|
525 |
+
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 2
|
526 |
+
|
527 |
+
logging.info(f"Final Assistant Message Length: {len(state.messages[-1][-1])}")
|
528 |
+
|
529 |
+
|
530 |
+
def update_resolution_dropdown(diffusion_enabled, current_resolution_str):
|
531 |
+
logging.info(f"Updating resolution dropdown. Diffusion: {diffusion_enabled}, Current: {current_resolution_str}")
|
532 |
+
current_h_str, current_w_str = current_resolution_str.split('x')
|
533 |
+
current_h, current_w = int(current_h_str), int(current_w_str)
|
534 |
+
|
535 |
+
new_value_str = None
|
536 |
+
if diffusion_enabled:
|
537 |
+
new_h, new_w = int(current_h) * 2, int(current_w) * 2
|
538 |
+
|
539 |
+
if (new_h, new_w) not in DEFAULT_DIFFUSION_RESOLUTIONS:
|
540 |
+
new_h, new_w = DEFAULT_DIFFUSION_RESOLUTIONS[0]
|
541 |
+
new_value_str = f"{new_h}x{new_w}"
|
542 |
+
return gr.Dropdown.update(choices=[f'{h}x{w}' for h, w in DEFAULT_DIFFUSION_RESOLUTIONS],
|
543 |
+
value=new_value_str)
|
544 |
+
else:
|
545 |
+
new_h, new_w = int(current_h) // 2, int(current_w) // 2
|
546 |
+
|
547 |
+
if (new_h, new_w) not in DEFAULT_RESOLUTIONS:
|
548 |
+
new_h, new_w = DEFAULT_RESOLUTIONS[0]
|
549 |
+
new_value_str = f"{new_h}x{new_w}"
|
550 |
+
|
551 |
+
return gr.Dropdown.update(choices=[f'{h}x{w}' for h, w in DEFAULT_RESOLUTIONS],
|
552 |
+
value=new_value_str)
|
553 |
+
|
554 |
+
|
555 |
+
# --- Gradio Layout ---
|
556 |
+
title_markdown = """
|
557 |
+
<div style="display: flex; align-items: center; padding: 20px; border-radius: 10px; background-color: #f0f0f0;">
|
558 |
+
<div>
|
559 |
+
<h1 style="margin: 0;"> ILLUME+: Illuminating Unified MLLM with Dual Visual Tokenization and Diffusion Refinement</h1>
|
560 |
+
<h2 style="margin: 10px 0;">
|
561 |
+
Links:
|
562 |
+
<a href="https://arxiv.org/abs/2504.01934" target="_blank" rel="noopener noreferrer">Paper</a> |
|
563 |
+
<a href="https://github.com/illume-unified-mllm/ILLUME_plus" target="_blank" rel="noopener noreferrer">Code</a> |
|
564 |
+
<a href="#" target="_blank" rel="noopener noreferrer">Model</a> |
|
565 |
+
<a href="https://illume-unified-mllm.github.io/" target="_blank" rel="noopener noreferrer">Project Page</a>
|
566 |
+
</h2>
|
567 |
+
<ul style="margin: 20px 0; padding-left: 20px;">
|
568 |
+
<li><strong>1.</strong> Enter text and/or upload an image.</li>
|
569 |
+
<li><strong>2.</strong> Click the 💬 <strong>Chat</strong> button for image inputted conversations</li>
|
570 |
+
<li><strong>3.</strong> Click the 🖼️ <strong>Generate</strong> for image generation and image editing.</li>
|
571 |
+
<li><strong>5.</strong> (Optional) Enable Diffusion Decoder for image super resolution decoding.
|
572 |
+
<li><strong>4.</strong> Adjust generation parameters if needed.
|
573 |
+
<br/><strong>💡 Tip 1:</strong> For better image generation quality, we recommend setting <code>top_k = 2048</code>.
|
574 |
+
<br/><strong>💡 Tip 2:</strong> For diffusion decoder, CFG scale of 1.5 or 2.0 is enough.
|
575 |
+
</li>
|
576 |
+
</ul>
|
577 |
+
</div>
|
578 |
+
</div>
|
579 |
"""
|
580 |
|
581 |
+
tos_markdown = ("""
|
582 |
+
## Terms of use
|
583 |
+
By using this service, users are required to agree to the following terms:
|
584 |
+
The service is a research preview intended for non-commercial use only. It may generate inaccurate or offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user data for future research. Check the specific license of the ILLUME model and its components.
|
585 |
+
""")
|
586 |
+
|
587 |
+
learn_more_markdown = ("""
|
588 |
+
## Citation
|
589 |
+
|
590 |
+
|
591 |
+
@article{huang2025illume_plus,
|
592 |
+
title={ILLUME+: Illuminating Unified MLLM with Dual Visual Tokenization and Diffusion Refinement},
|
593 |
+
author={Huang, Runhui and Wang, Chunwei and Yang, Junwei and Lu, Guansong and Yuan, Yunlong and Han, Jianhua and Hou, Lu and Zhang, Wei and Hong, Lanqing and Zhao, Hengshuang and Xu, Hang}
|
594 |
+
journal={arXiv preprint arXiv:2504.01934},
|
595 |
+
year={2025}
|
596 |
+
}
|
597 |
+
""")
|
598 |
+
|
599 |
+
block_css = """
|
600 |
+
#buttons button {
|
601 |
+
min-width: min(120px,100%);
|
602 |
+
}
|
603 |
+
.message-row img {
|
604 |
+
max-width: 80%;
|
605 |
+
max-height: 400px;
|
606 |
+
height: auto;
|
607 |
+
display: block;
|
608 |
+
margin-top: 10px;
|
609 |
+
margin-bottom: 5px;
|
610 |
+
border-radius: 5px;
|
611 |
+
border: 1px solid #e0e0e0; /* Add a light border */
|
612 |
+
}
|
613 |
+
.avatar-container img {
|
614 |
+
padding: 0px !important;
|
615 |
+
}
|
616 |
+
/* Style for resolution dropdown */
|
617 |
+
#resolution_dropdown .gradio-dropdown {
|
618 |
+
min-width: 150px !important;
|
619 |
+
}
|
620 |
+
"""
|
621 |
|
|
|
|
|
|
|
622 |
|
623 |
+
def load_initial_state_and_example1():
|
624 |
+
"""
|
625 |
+
Loads the initial Conversation state and prepares the inputs
|
626 |
+
for the first example to populate the UI on startup.
|
627 |
+
"""
|
628 |
+
logging.info("Loading initial state and Example 1 inputs for UI.")
|
629 |
+
|
630 |
+
# 1. Get the base initial state object
|
631 |
+
initial_state = load_demo_refresh_model_list()
|
632 |
+
# At this point, initial_state is a Conversation object with empty messages.
|
633 |
+
initial_state = 'chat'
|
634 |
+
|
635 |
+
# 2. Define Example 1 inputs
|
636 |
+
image_path = "./examples/example_1.png" # Make sure this path is correct relative to where you run the script
|
637 |
+
text_prompt = "Describe this scene in detail."
|
638 |
+
image_pil = None
|
639 |
+
|
640 |
+
# 3. Load the example image
|
641 |
+
try:
|
642 |
+
# Ensure the image file exists and load it
|
643 |
+
if os.path.exists(image_path):
|
644 |
+
image_pil = Image.open(image_path)
|
645 |
+
logging.info(f"Successfully loaded example image: {image_path}")
|
646 |
+
else:
|
647 |
+
logging.warning(f"Example image not found at: {image_path}. Image box will be empty.")
|
648 |
+
# Optionally provide a placeholder blank image?
|
649 |
+
# image_pil = Image.new('RGB', (60, 30), color = 'red') # Example placeholder
|
650 |
+
except Exception as e:
|
651 |
+
logging.error(f"Error loading example image {image_path}: {e}")
|
652 |
+
image_pil = None # Ensure it's None on error
|
653 |
+
|
654 |
+
# 4. Return values to populate the UI components
|
655 |
+
# - state: The initial Conversation object
|
656 |
+
# - chatbot: The initial empty chatbot display ([]) derived from the initial state
|
657 |
+
# - textbox: The example text prompt
|
658 |
+
# - imagebox: The loaded PIL image (or None)
|
659 |
+
return initial_state, initial_state.to_gradio_chatbot(), text_prompt, image_pil
|
660 |
+
|
661 |
+
|
662 |
+
def load_initial_state_and_example2():
|
663 |
+
"""
|
664 |
+
Loads the initial Conversation state and prepares the inputs
|
665 |
+
for the first example to populate the UI on startup.
|
666 |
+
"""
|
667 |
+
logging.info("Loading initial state and Example 1 inputs for UI.")
|
668 |
+
|
669 |
+
# 1. Get the base initial state object
|
670 |
+
initial_state = load_demo_refresh_model_list()
|
671 |
+
# At this point, initial_state is a Conversation object with empty messages.
|
672 |
+
|
673 |
+
# 2. Define Example 1 inputs
|
674 |
+
# text_prompt = "Generate a photorealistic image of an astronaut riding a horse on the moon."
|
675 |
+
text_prompt = "Generate an image based on the description: A man with a white beard wearing a deep purple robe with gold crosses and a chain with a cross pendant is seated on a red upholstered chair with a small decorative pillow featuring gold embroidery. He is holding an ornate gold staff."
|
676 |
+
text_prompt = "What does a typical scene of a woman enjoying a sunny day by a luxury pool, complete with appropriate attire and refreshment, look like? Please generate the corresponding image."
|
677 |
+
|
678 |
+
return initial_state, initial_state.to_gradio_chatbot(), text_prompt, None
|
679 |
+
|
680 |
+
|
681 |
+
def build_demo(embed_mode):
|
682 |
+
textbox = gr.Textbox(label="Text Input / Prompt", show_label=False,
|
683 |
+
placeholder="Enter text prompt. Ask about the image or request image generation...",
|
684 |
+
container=False, scale=8)
|
685 |
+
|
686 |
+
with gr.Blocks(title="ILLUME Demo", theme=gr.themes.Default(), css=block_css) as demo:
|
687 |
+
conversation_state = gr.State() # Holds conversation state (instance of illume.conversation.Conversation)
|
688 |
+
|
689 |
+
if not embed_mode:
|
690 |
+
gr.HTML(title_markdown)
|
691 |
|
692 |
with gr.Row():
|
693 |
+
with gr.Column(scale=2):
|
694 |
+
imagebox = gr.Image(type="pil", label="Input Image", height=300)
|
695 |
+
|
696 |
+
# Text Generation Parameters
|
697 |
+
with gr.Accordion("Text Generation Parameters", open=True):
|
698 |
+
temperature = gr.Slider(
|
699 |
+
minimum=0.0, maximum=1.5, value=1.0, step=0.1,
|
700 |
+
label="Temperature",
|
701 |
+
info="Controls randomness of the output (higher = more diverse)."
|
702 |
+
)
|
703 |
+
top_k = gr.Slider(
|
704 |
+
minimum=1, maximum=4096, value=128, step=1,
|
705 |
+
label="Top-K",
|
706 |
+
)
|
707 |
+
top_p = gr.Slider(
|
708 |
+
minimum=0.1, maximum=1.0, value=1.0, step=0.05,
|
709 |
+
label="Top-P",
|
710 |
+
)
|
711 |
+
max_output_tokens = gr.Slider(
|
712 |
+
minimum=128, maximum=8192, value=1024, step=128,
|
713 |
+
label="Max Output Tokens",
|
714 |
+
)
|
715 |
|
716 |
+
# Image Generation Parameters
|
717 |
+
with gr.Accordion("Image Generation Parameters", open=True):
|
718 |
+
image_gen_temperature = gr.Slider(
|
719 |
+
minimum=0.0, maximum=1.5, value=1.0, step=0.1,
|
720 |
+
label="Temperature",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
721 |
)
|
722 |
+
image_gen_top_k = gr.Slider(
|
723 |
+
minimum=1, maximum=4096 * 2, value=2048, step=32,
|
724 |
+
label="Top-K",
|
725 |
+
info="Recommended value for better image generation: 2048."
|
726 |
+
)
|
727 |
+
image_gen_top_p = gr.Slider(
|
728 |
+
minimum=0.1, maximum=1.0, value=1.0, step=0.05,
|
729 |
+
label="Top-P",
|
730 |
+
)
|
731 |
+
|
732 |
+
resolution_wh_dropdown = gr.Dropdown(
|
733 |
+
[f'{h}x{w}' for h, w in DEFAULT_RESOLUTIONS],
|
734 |
+
value="512x512",
|
735 |
+
label="Output Resolution (HxW)",
|
736 |
+
elem_id="resolution_dropdown",
|
737 |
+
info="Select target size for generated images."
|
738 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
739 |
|
740 |
+
llm_cfg_scale = gr.Slider(
|
741 |
+
minimum=1.0, maximum=10.0, value=2.0, step=0.1,
|
742 |
+
label="LLM CFG Scale",
|
743 |
+
info="Guidance for text-to-image conditioning (higher = stricter to prompt)."
|
744 |
+
)
|
745 |
|
746 |
+
with gr.Accordion("Diffusion Decoder (Optional)", open=False):
|
747 |
+
use_diffusion_checkbox = gr.Checkbox(
|
748 |
+
value=False, interactive=True,
|
749 |
+
label="Use diffusion decoder for image generation",
|
750 |
+
info="Enable diffusion decoder."
|
751 |
+
)
|
752 |
+
diffusion_cfg_scale = gr.Slider(
|
753 |
+
minimum=1.0, maximum=15.0, value=2.0, step=0.1,
|
754 |
+
label="Diffusion CFG Scale",
|
755 |
+
info="Guidance strength for diffusion decoder."
|
756 |
+
)
|
757 |
+
diffusion_num_inference_steps = gr.Slider(
|
758 |
+
minimum=5, maximum=100, value=20, step=5,
|
759 |
+
label="Diffusion Inference Steps",
|
760 |
+
info="Number of steps during denoising."
|
761 |
+
)
|
762 |
+
|
763 |
+
with gr.Column(scale=8):
|
764 |
+
chatbot = gr.Chatbot(
|
765 |
+
elem_id="chatbot",
|
766 |
+
label="ILLUME Chat",
|
767 |
+
layout="bubble",
|
768 |
+
height=650, # Increased height
|
769 |
+
bubble_full_width=False,
|
770 |
+
render_markdown=True # Crucial for images
|
771 |
+
)
|
772 |
+
with gr.Row():
|
773 |
+
textbox.render()
|
774 |
+
with gr.Row(elem_id="buttons") as button_row:
|
775 |
+
chat_btn = gr.Button(value="💬 Chat", variant="primary")
|
776 |
+
gen_btn = gr.Button(value="🖼️ Generate", variant="secondary")
|
777 |
+
with gr.Row(elem_id="additional-buttons") as button_row_additional:
|
778 |
+
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
|
779 |
+
clear_btn = gr.Button(value="🗑️ Clear History", interactive=False)
|
780 |
+
|
781 |
+
# Update examples for ILLUME
|
782 |
+
with gr.Accordion("Examples (Click to Load)", open=True):
|
783 |
+
with gr.Row():
|
784 |
+
gr.Examples(examples=[
|
785 |
+
["examples/ImageUnderstandingExample/images/1.png",
|
786 |
+
"Depict the image in detail."],
|
787 |
+
["examples/ImageUnderstandingExample/images/2.png",
|
788 |
+
"What are they doing?"],
|
789 |
+
["examples/ImageUnderstandingExample/images/3.png",
|
790 |
+
"What objects are on the table?"],
|
791 |
+
], inputs=[imagebox, textbox], label='Image Understanding Examples')
|
792 |
+
|
793 |
+
gr.Examples(examples=[
|
794 |
+
[None, "a cat with a hat."],
|
795 |
+
[None, "a smiling child."],
|
796 |
+
[None, "tiger cub playing with soccer ball"],
|
797 |
+
[None, "screenshot from a 16 bit platform game in a lush green landscape"],
|
798 |
+
[None, "Old car in kandy sri lanka,lake road,flower, bright, sunny, orange sky, photorealistic"],
|
799 |
+
[None, "Create a vibrant painting of a tropical beach at sunset in the style of Van Gogh."],
|
800 |
+
], inputs=[imagebox, textbox], label='Image Generation Examples')
|
801 |
+
|
802 |
+
gr.Examples(examples=[
|
803 |
+
["examples/EditingSingleTurnExample/images/0.jpg",
|
804 |
+
"Change the color of the boots to a deep forest green"],
|
805 |
+
["examples/EditingSingleTurnExample/images/1.jpg",
|
806 |
+
"Add a hat on the dog"],
|
807 |
+
["examples/EditingSingleTurnExample/images/2.jpg",
|
808 |
+
"Remove the dried flowers"],
|
809 |
+
["examples/EditingSingleTurnExample/images/3.jpg",
|
810 |
+
"Change it into winter"],
|
811 |
+
["examples/EditingSingleTurnExample/images/4.jpg",
|
812 |
+
"Delete the tennis racket from the man’s hand"],
|
813 |
+
["examples/EditingSingleTurnExample/images/5.jpg",
|
814 |
+
"Show me this as it would appear in a comic book"],
|
815 |
+
], inputs=[imagebox, textbox], label='Image Editing Examples')
|
816 |
+
|
817 |
+
if not embed_mode:
|
818 |
+
gr.Markdown(tos_markdown)
|
819 |
+
gr.Markdown(learn_more_markdown)
|
820 |
+
|
821 |
+
# Register listeners
|
822 |
+
btn_list = [regenerate_btn, clear_btn]
|
823 |
+
parameter_chat_inputs = [temperature, top_k, top_p, max_output_tokens]
|
824 |
+
parameter_gen_edit_inputs = [temperature, top_k, top_p,
|
825 |
+
image_gen_temperature, image_gen_top_k, image_gen_top_p, max_output_tokens,
|
826 |
+
llm_cfg_scale, resolution_wh_dropdown,
|
827 |
+
use_diffusion_checkbox, diffusion_cfg_scale, diffusion_num_inference_steps]
|
828 |
+
|
829 |
+
regenerate_btn.click(
|
830 |
+
regenerate,
|
831 |
+
[conversation_state],
|
832 |
+
[conversation_state, chatbot, textbox, imagebox] + btn_list
|
833 |
+
).then(
|
834 |
+
http_bot_conditional_then,
|
835 |
+
[conversation_state] + parameter_gen_edit_inputs, # Pass state and all params
|
836 |
+
[conversation_state, chatbot] + btn_list,
|
837 |
+
)
|
838 |
|
839 |
+
clear_btn.click(
|
840 |
+
clear_history,
|
841 |
+
None,
|
842 |
+
[conversation_state, chatbot, textbox, imagebox] + btn_list,
|
843 |
+
queue=False
|
844 |
+
)
|
845 |
|
846 |
+
# Default use chat.
|
847 |
+
textbox.submit(
|
848 |
+
partial(add_text, mode="chat"),
|
849 |
+
[conversation_state, textbox, imagebox],
|
850 |
+
[conversation_state, chatbot, textbox, imagebox] + btn_list,
|
851 |
+
queue=False
|
852 |
+
).then(
|
853 |
+
http_chat_bot,
|
854 |
+
[conversation_state] + parameter_chat_inputs,
|
855 |
+
[conversation_state, chatbot] + btn_list,
|
856 |
+
)
|
|
|
|
|
|
|
|
|
|
|
857 |
|
858 |
+
# Regular Vision-language Chat
|
859 |
+
chat_btn.click(partial(add_text, mode="chat"),
|
860 |
+
[conversation_state, textbox, imagebox],
|
861 |
+
[conversation_state, chatbot, textbox, imagebox] + btn_list,
|
862 |
+
queue=False
|
863 |
+
).then(
|
864 |
+
http_chat_bot,
|
865 |
+
[conversation_state] + parameter_chat_inputs,
|
866 |
+
[conversation_state, chatbot] + btn_list,
|
867 |
+
)
|
868 |
|
869 |
+
# Image Generation
|
870 |
+
gen_btn.click(
|
871 |
+
partial(add_text, mode="image-generation"),
|
872 |
+
[conversation_state, textbox, imagebox],
|
873 |
+
[conversation_state, chatbot, textbox, imagebox] + btn_list
|
874 |
+
).then(
|
875 |
+
http_gen_edit_bot,
|
876 |
+
[conversation_state] + parameter_gen_edit_inputs,
|
877 |
+
[conversation_state, chatbot] + btn_list
|
878 |
+
)
|
879 |
|
880 |
+
use_diffusion_checkbox.change(
|
881 |
+
fn=update_resolution_dropdown,
|
882 |
+
inputs=[use_diffusion_checkbox, resolution_wh_dropdown],
|
883 |
+
outputs=[resolution_wh_dropdown],
|
884 |
+
queue=False
|
885 |
+
)
|
886 |
+
|
887 |
+
# Load initial state when demo starts
|
888 |
+
demo.load(
|
889 |
+
load_demo_refresh_model_list,
|
890 |
+
None,
|
891 |
+
conversation_state,
|
892 |
+
queue=False
|
893 |
+
)
|
894 |
+
return demo
|
895 |
|
896 |
|
897 |
+
# --- Main Execution Block ---
|
898 |
+
if __name__ == "__main__":
|
899 |
parser = argparse.ArgumentParser()
|
900 |
+
# --- Add arguments for ILLUME configs and checkpoints ---
|
901 |
+
parser.add_argument("--model_name", type=str, default="illume-unified-mllm/illume_plus-qwen-2_5-3b-hf",
|
902 |
+
help="Name for builder.")
|
903 |
+
parser.add_argument("--torch_dtype", type=str, default='fp32', choices=['fp32', 'bf16', 'fp16'],
|
904 |
+
help="Computation data type.")
|
|
|
|
|
905 |
|
906 |
+
parser.add_argument("--diffusion_decoder_path", type=str, default='illume-unified-mllm/dualvitok_sdxl_decoder.pt',
|
907 |
+
help="Path to Diffusion Decoder checkpoint (.pt). Required if using diffusion.")
|
908 |
|
909 |
+
parser.add_argument("--tokenizer_path", type=str, default='illume-unified-mllm/dualvitok',
|
910 |
+
help="Path to Tokenizer config file (e.g., tokenizer_config.py).")
|
|
|
|
|
|
|
|
|
911 |
|
912 |
+
# --- End ILLUME arguments ---
|
913 |
+
parser.add_argument("--share", action="store_true", help="Create a public Gradio share link")
|
914 |
+
parser.add_argument("--embed", action="store_true", help="Run in embed mode (minimal UI)")
|
915 |
+
parser.add_argument("--device", type=str, default="cuda", help="Device to run on (cuda, cpu).")
|
916 |
|
917 |
+
args = parser.parse_args()
|
918 |
|
919 |
+
# --- Model Loading ---
|
920 |
+
# --- Model Loading ---set
|
921 |
+
# Set device
|
922 |
+
if "cuda" in args.device and torch.cuda.is_available():
|
923 |
+
device = args.device
|
924 |
+
local_rank = 0 # Assume single GPU for Gradio unless configured otherwise
|
925 |
+
torch.cuda.set_device(local_rank) # Set default CUDA device
|
926 |
+
else:
|
927 |
+
device = "cpu"
|
928 |
+
local_rank = -1 # Indicate CPU
|
929 |
+
logging.info(f"Using device: {device}")
|
930 |
+
|
931 |
+
args.torch_dtype = dict(fp16=torch.float16, fp32=torch.float32, bf16=torch.bfloat16)[args.torch_dtype]
|
932 |
+
|
933 |
+
# Build the ILLUME model instance
|
934 |
+
logging.info("Building ILLUME model...")
|
935 |
+
# prepare models and processors
|
936 |
+
model = AutoModel.from_pretrained(
|
937 |
+
args.model_name,
|
938 |
+
# torch_dtype=torch.bfloat16,
|
939 |
+
# attn_implementation='flash_attention_2', # OR 'sdpa' for Ascend NPUs
|
940 |
+
torch_dtype=args.torch_dtype,
|
941 |
+
attn_implementation='sdpa', # OR 'sdpa' for Ascend NPUs
|
942 |
+
low_cpu_mem_usage=True,
|
943 |
+
trust_remote_code=True).eval().cuda()
|
944 |
+
processor = AutoProcessor.from_pretrained(args.model_name, trust_remote_code=True)
|
945 |
+
|
946 |
+
# set the vision tokenizer for decoding image.
|
947 |
+
dualvitok = AutoModel.from_pretrained(args.tokenizer_path,
|
948 |
+
torch_dtype=torch.float32,
|
949 |
+
trust_remote_code=True).eval().cuda()
|
950 |
+
processor.set_vision_tokenizer(dualvitok)
|
951 |
+
|
952 |
+
# (Optional): set the sdxl diffusion decoder. It will enable upsample 2x image resolution.
|
953 |
+
processor.load_diffusion_vision_detokenizer(args.diffusion_decoder_path)
|
954 |
+
|
955 |
+
# Assign device to model for later use
|
956 |
+
streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
|
957 |
+
|
958 |
+
logging.info("ILLUME model built successfully.")
|
959 |
+
|
960 |
+
demo = build_demo(args.embed)
|
961 |
+
demo.queue(
|
962 |
+
max_size=10,
|
963 |
+
api_open=False
|
964 |
+
).launch(
|
965 |
+
share=args.share,
|
966 |
+
server_name="0.0.0.0" # Allow network access if not using --share
|
967 |
+
)
|