rp-yu commited on
Commit
cb05be2
·
1 Parent(s): da93475
Files changed (10) hide show
  1. app.py +580 -0
  2. assets/assistant.png +3 -0
  3. assets/human.png +3 -0
  4. constants.py +7 -0
  5. conversation.py +295 -0
  6. gallery/14.jfif +3 -0
  7. gallery/15.PNG +3 -0
  8. gallery/prod_9.jpg +3 -0
  9. model.py +35 -0
  10. utils.py +163 -0
app.py ADDED
@@ -0,0 +1,580 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
try:
    import spaces
except ImportError:
    # Local run (no HF Spaces runtime): provide a no-op stand-in so the
    # @spaces.GPU decorators in this file still work.
    class spaces:
        @staticmethod
        def GPU(duration=10, **kwargs):
            """Return a decorator that leaves the function unchanged.

            Accepts arbitrary keyword arguments so call sites remain
            compatible with the real `spaces.GPU` signature.
            """
            def dummy(func):
                return func
            return dummy
11
+
12
+ import argparse
13
+ import json
14
+ import time
15
+
16
+ import gradio as gr
17
+ from filelock import FileLock
18
+ from PIL import Image
19
+ import threading
20
+
21
+ from utils import (
22
+ build_logger,
23
+ server_error_msg,
24
+ violates_moderation,
25
+ moderation_msg,
26
+ get_log_filename,
27
+ )
28
+ from conversation import Conversation
29
+ from model import (
30
+ FullSequenceStreamer,
31
+ get_model,
32
+ )
33
+
34
+ logger = build_logger("dimple", "dimple.log")
35
+
36
+ no_change_btn = gr.Button()
37
+ enable_btn = gr.Button(interactive=True)
38
+ disable_btn = gr.Button(interactive=False)
39
+
40
+
41
@spaces.GPU(duration=10)
def make_zerogpu_happy():
    """No-op GPU task: exercising the @spaces.GPU decorator keeps ZeroGPU spaces functional."""
    return None
44
+
45
+
46
def write2file(path, content):
    """Append `content` to `path`, serialized by an inter-process file lock."""
    guard = FileLock(f"{path}.lock")
    with guard:
        with open(path, "a") as fout:
            fout.write(content)
51
+
52
# Load the model and processor once at import time so every request reuses them.
model, processor = get_model("cuda:0")

# JS snippet run in the browser: returns the page's URL query parameters
# as a JSON object (also logged to the browser console for debugging).
get_window_url_params = """
function() {
    const params = new URLSearchParams(window.location.search);
    url_params = Object.fromEntries(params);
    console.log(url_params);
    return url_params;
    }
"""
62
+
63
+
64
def init_state(state=None):
    """Return a brand-new Conversation, explicitly discarding any old state."""
    if state is not None:
        del state
    return Conversation()
68
+
69
def vote_last_response(state, liked, request: gr.Request):
    """Append a vote record for the latest response to the conversation log.

    Args:
        state: current Conversation (serialized into the record).
        liked: True/False for up/down votes, or the string "flag".
        request: Gradio request, used to record the client IP.
    """
    conv_data = {
        "tstamp": round(time.time(), 4),
        "like": liked,
        # Bug fix: the model id was logged as '"rp-yu/Dimple-7B"' with
        # embedded quotes, inconsistent with every other log site.
        "model": "rp-yu/Dimple-7B",
        "state": state.dict(),
        "ip": request.client.host,
    }
    write2file(get_log_filename(), json.dumps(conv_data) + "\n")
78
+
79
+
80
def upvote_last_response(state, request: gr.Request):
    """Record an upvote, clear the input box, and disable the vote buttons."""
    logger.info(f"upvote. ip: {request.client.host}")
    vote_last_response(state, True, request)
    cleared_box = gr.MultimodalTextbox(value=None, interactive=True)
    return (cleared_box, disable_btn, disable_btn, disable_btn)
85
+
86
+
87
def downvote_last_response(state, request: gr.Request):
    """Record a downvote, clear the input box, and disable the vote buttons."""
    logger.info(f"downvote. ip: {request.client.host}")
    vote_last_response(state, False, request)
    cleared_box = gr.MultimodalTextbox(value=None, interactive=True)
    return (cleared_box, disable_btn, disable_btn, disable_btn)
92
+
93
+
94
def vote_selected_response(
    state, request: gr.Request, data: gr.LikeData
):
    """Persist a like/dislike on a specific chatbot message to the vote log."""
    logger.info(
        f"Vote: {data.liked}, index: {data.index}, value: {data.value} , ip: {request.client.host}"
    )
    record = {
        "tstamp": round(time.time(), 4),
        "like": data.liked,
        "index": data.index,
        "model": 'rp-yu/Dimple-7B',
        "state": state.dict(),
        "ip": request.client.host,
    }
    write2file(get_log_filename(), json.dumps(record) + "\n")
110
+
111
+
112
def flag_last_response(state, request: gr.Request):
    """Record a "flag" verdict, clear the input box, and disable the vote buttons."""
    logger.info(f"flag. ip: {request.client.host}")
    vote_last_response(state, "flag", request)
    cleared_box = gr.MultimodalTextbox(value=None, interactive=True)
    return (cleared_box, disable_btn, disable_btn, disable_btn)
117
+
118
+
119
def regenerate(state, image_process_mode, request: gr.Request):
    """Blank the last assistant message so http_bot can regenerate the reply.

    Bug fix: conversation messages are dicts (see Conversation.append_message),
    so unconditionally indexing the previous human message with [1] raised
    KeyError.  The legacy tuple/list message layout is now handled only when
    it is actually present.
    """
    logger.info(f"regenerate. ip: {request.client.host}")
    state.update_message(Conversation.ASSISTANT, content='', image=None, idx=-1)
    prev_human_msg = state.messages[-2]
    # Legacy layout: message is a [role, (text, images, mode)] pair.
    if isinstance(prev_human_msg, (tuple, list)) and type(prev_human_msg[1]) in (tuple, list):
        prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
    state.skip_next = False
    textbox = gr.MultimodalTextbox(value=None, interactive=True)
    return (state, state.to_gradio_chatbot(), textbox) + (disable_btn,) * 5
129
+
130
+
131
def clear_history(request: gr.Request):
    """Reset the conversation and return a fresh, empty UI state."""
    logger.info(f"clear_history. ip: {request.client.host}")
    fresh = init_state()
    empty_box = gr.MultimodalTextbox(value=None, interactive=True)
    return (fresh, fresh.to_gradio_chatbot(), empty_box) + (disable_btn,) * 5
136
+
137
+
138
def add_text(state, message, system_prompt, request: gr.Request):
    """Append the user's multimodal message to the conversation state.

    Returns (state, chatbot history, textbox, *5 button updates).  Sets
    state.skip_next so the downstream http_bot call knows whether to
    actually run generation.
    """
    print(f"state: {state}")
    # A fresh session arrives with state=None; start a new conversation.
    if not state:
        state = init_state()
    images = message.get("files", [])
    text = message.get("text", "").strip()
    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
    # import pdb; pdb.set_trace()
    # Lock the textbox while this request is being processed.
    textbox = gr.MultimodalTextbox(value=None, interactive=False)
    # Nothing to send: mark the turn skipped and leave buttons untouched.
    if len(text) <= 0 and len(images) == 0:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), textbox) + (no_change_btn,) * 5
    # NOTE(review): `args` is the module-level namespace assigned under
    # __main__ — this assumes the script entry point ran first.
    if args.moderate:
        flagged = violates_moderation(text)
        if flagged:
            state.skip_next = True
            textbox = gr.MultimodalTextbox(
                value={"text": moderation_msg}, interactive=True
            )
            return (state, state.to_gradio_chatbot(), textbox) + (no_change_btn,) * 5
    images = [Image.open(path).convert("RGB") for path in images]

    # New upload while older user images exist: restart the conversation
    # (presumably so the model only ever sees one image set — TODO confirm).
    if len(images) > 0 and len(state.get_images(source=state.USER)) > 0:
        state = init_state(state)
    state.set_system_message(system_prompt)
    state.append_message(Conversation.USER, text, images)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), textbox) + (
        disable_btn,
    ) * 5
168
+
169
+
170
def http_bot(
    state,
    temperature,
    top_p,
    p_threshold,
    alg_temp,
    max_new_tokens,
    steps,
    alg,
):
    """Run diffusion generation for the current conversation, streaming updates.

    Generator: yields (state, chatbot history, textbox, *5 button updates)
    tuples as the model produces tokens.

    Bug fix: the "finish" field of the log record previously repeated the
    start timestamp; it now records the real finish time.
    """
    start_tstamp = time.time()
    if hasattr(state, "skip_next") and state.skip_next:
        # This generate call is skipped due to invalid inputs
        yield (
            state,
            state.to_gradio_chatbot(),
            gr.MultimodalTextbox(interactive=False),
        ) + (no_change_btn,) * 5
        return

    all_images = state.get_images(source=state.USER)
    all_image_paths = [state.save_image(image) for image in all_images]

    # The processor expects None (not an empty list) when there are no images.
    if len(all_images) == 0:
        all_images = None

    messages = state.get_prompt()
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True, add_vision_id=False
    )

    inputs = processor(
        text=text,
        images=all_images,
        videos=None,
        padding="longest",
        return_tensors="pt",
    ).to(model.device)
    input_ids = inputs.pop("input_ids")

    # Streamer that re-decodes the full sequence on every diffusion step.
    streamer = FullSequenceStreamer(
        processor.tokenizer,
        timeout=10,
        skip_special_tokens=True,
    )

    def run_generate():
        # Runs in a worker thread; results arrive through `streamer`.
        model.diffusion_generate(
            input_ids,
            max_new_tokens=int(max_new_tokens),
            output_history=True,
            return_dict_in_generate=True,
            steps=int(steps),
            temperature=float(temperature),
            top_p=float(top_p),
            alg=alg,
            alg_temp=float(alg_temp),
            use_cache=True,
            alg_p_threshold=float(p_threshold),
            use_original_confidence=True,
            decoding_pipeline="dim",
            streamer=streamer,
            **inputs,
        )

    thread = threading.Thread(target=run_generate)
    thread.start()

    logger.info("==== wait for first token ====\n")
    # Show a placeholder reply immediately while the first step runs.
    state.append_message(Conversation.ASSISTANT, state.streaming_placeholder)
    yield (
        state,
        state.to_gradio_chatbot(),
        gr.MultimodalTextbox(interactive=False),
    ) + (disable_btn,) * 5

    try:
        # Stream output: each item is the fully re-decoded sequence so far.
        for ans in streamer:
            if len(ans) > 1:
                ans = "\n".join(ans)
            else:
                ans = ans[0]

            state.update_message(Conversation.ASSISTANT, ans, None)
            yield (
                state,
                state.to_gradio_chatbot(),
                gr.MultimodalTextbox(interactive=False),
            ) + (disable_btn,) * 5
    except Exception:
        # Surface the failure in the chat and re-enable regenerate/clear.
        state.update_message(Conversation.ASSISTANT, server_error_msg, None)
        yield (
            state,
            state.to_gradio_chatbot(),
            gr.MultimodalTextbox(interactive=True),
        ) + (
            disable_btn,
            disable_btn,
            disable_btn,
            enable_btn,
            enable_btn,
        )
        return

    state.end_of_current_turn()

    yield (
        state,
        state.to_gradio_chatbot(),
        gr.MultimodalTextbox(interactive=True),
    ) + (enable_btn,) * 5

    finish_tstamp = time.time()
    logger.info(f"{ans}")
    data = {
        "tstamp": round(finish_tstamp, 4),
        "like": None,
        "model": "rp-yu/Dimple-7B",
        "start": round(start_tstamp, 4),
        # Bug fix: previously logged start_tstamp here as well.
        "finish": round(finish_tstamp, 4),
        "state": state.dict(),
        "images": all_image_paths,
    }
    write2file(get_log_filename(), json.dumps(data) + "\n")
295
+
296
# Static page header shown above the settings column.
title_html = """
<div style="width:100%; max-width:600px; margin:auto;">
    <img src="https://cdn-uploads.huggingface.co/production/uploads/635364b3c41f548fe39db945/T6ffjtAkFkI76QjXmN6iR.png" style="width:100%;">
    <p style="margin:0; text-align:left;">
        Dimple: Discrete Diffusion Multimodal Large Language Model with Parallel Decoding
    </p>
    <a href="https://arxiv.org/abs/">[📜 Dimple Paper]</a><br>
    <a href="https://github.com/yu-rp/Dimple">[🌟 Github]</a><br>
    <a href="https://huggingface.co/rp-yu/Dimple-7B">[🤗 Huggingface Model]</a><br>
    <a href="https://huggingface.co/spaces/rp-yu/dimple">[💬 Huggingface Demo]</a><br>
</div>
"""


# Footer acknowledgement.  Bug fix: "huggingfcae" -> "huggingface".
tos_markdown = """
Acknowledgement: This demo is built on the huggingface space of [InternVL](https://huggingface.co/spaces/OpenGVLab/InternVL).
"""


# .gradio-container {margin: 5px 10px 0 10px !important};
block_css = """
.gradio-container {margin: 0.1% 1% 0 1% !important; max-width: 98% !important;};
#buttons button {
    min-width: min(120px,100%);
}

.gradient-text {
    font-size: 28px;
    width: auto;
    font-weight: bold;
    background: linear-gradient(45deg, red, orange, yellow, green, blue, indigo, violet);
    background-clip: text;
    -webkit-background-clip: text;
    color: transparent;
}

.plain-text {
    font-size: 22px;
    width: auto;
    font-weight: bold;
}
"""
338
+
339
+
340
def build_demo():
    """Assemble the Gradio Blocks UI and wire all event handlers.

    Bug fix: user-visible slider label typo "Selectiion" -> "Selection".
    """
    textbox = gr.MultimodalTextbox(
        interactive=True,
        file_types=["image"],
        placeholder="Enter message or upload file...",
        show_label=False,
    )

    with gr.Blocks(
        title="Dimple-7B",
        theme=gr.themes.Default(),
        css=block_css,
    ) as demo:
        # Per-session Conversation object (None until the first message).
        state = gr.State()

        with gr.Row():
            with gr.Column(scale=2):
                gr.HTML(title_html)

                with gr.Accordion("Settings", open=False) as setting_row:
                    system_prompt = gr.Textbox(
                        value="You are a helpful assistant.",
                        label="System Prompt",
                        interactive=True,
                    )
                    temperature = gr.Slider(
                        minimum=0.0,
                        maximum=2.0,
                        value=0.2,
                        step=0.1,
                        interactive=True,
                        label="Temperature",
                    )
                    top_p = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.95,
                        step=0.1,
                        interactive=True,
                        label="Top P",
                    )
                    alg = gr.Radio(
                        choices=["origin", "maskgit_plus", "entropy"],
                        value="origin",
                        label="Selection Algorithm",
                        interactive=True,
                    )
                    p_threshold = gr.Slider(
                        minimum=0.,
                        maximum=1.0,
                        value=0.95,
                        step=0.01,
                        interactive=True,
                        label="Probability threshold for Confident Decoding",
                    )
                    alg_temp = gr.Slider(
                        minimum=0.0,
                        maximum=2.0,
                        value=0.2,
                        step=0.1,
                        interactive=True,
                        label="Temperature for Selection Algorithm",
                    )
                    max_new_tokens = gr.Slider(
                        minimum=1,
                        maximum=128,
                        value=64,
                        step=2,
                        interactive=True,
                        label="Max output tokens",
                    )
                    steps = gr.Slider(
                        minimum=1,
                        maximum=128,
                        value=64,
                        step=2,
                        interactive=True,
                        label="Number of decoding steps",
                    )

                examples = gr.Examples(
                    examples=[
                        [
                            {
                                "files": [
                                    "gallery/14.jfif",
                                ],
                                "text": "Please help me analyze this picture.",
                            }
                        ],
                        [
                            {
                                "files": [
                                    "gallery/prod_9.jpg",
                                ],
                                "text": "Please help me describe the image.",
                            }
                        ],
                        [
                            {
                                "files": [
                                    "gallery/15.PNG",
                                ],
                                "text": "Please help me analyze this picture.",
                            }
                        ],
                    ],
                    inputs=[textbox],
                )

            with gr.Column(scale=8):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="Dimple-7B",
                    height=580,
                    show_copy_button=True,
                    show_share_button=True,
                    avatar_images=[
                        "assets/human.png",
                        "assets/assistant.png",
                    ],
                    bubble_full_width=False,
                )
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox.render()
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button(value="Send", variant="primary")
                with gr.Row(elem_id="buttons") as button_row:
                    upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                    downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                    flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
                    # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
                    regenerate_btn = gr.Button(
                        value="🔄 Regenerate", interactive=False
                    )
                    clear_btn = gr.Button(value="🗑️ Clear", interactive=False)

        gr.Markdown(tos_markdown)
        url_params = gr.JSON(visible=False)

        # Register listeners
        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
        upvote_btn.click(
            upvote_last_response,
            [state],
            [textbox, upvote_btn, downvote_btn, flag_btn],
        )
        downvote_btn.click(
            downvote_last_response,
            [state],
            [textbox, upvote_btn, downvote_btn, flag_btn],
        )
        chatbot.like(
            vote_selected_response,
            [state],
            [],
        )
        flag_btn.click(
            flag_last_response,
            [state],
            [textbox, upvote_btn, downvote_btn, flag_btn],
        )
        # NOTE(review): regenerate()'s second parameter is named
        # image_process_mode but receives system_prompt here — confirm intent.
        regenerate_btn.click(
            regenerate,
            [state, system_prompt],
            [state, chatbot, textbox] + btn_list,
        ).then(
            http_bot,
            [
                state,
                temperature,
                top_p,
                p_threshold,
                alg_temp,
                max_new_tokens,
                steps,
                alg,
            ],
            [state, chatbot, textbox] + btn_list,
        )
        clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)

        textbox.submit(
            add_text,
            [state, textbox, system_prompt],
            [state, chatbot, textbox] + btn_list,
        ).then(
            http_bot,
            [
                state,
                temperature,
                top_p,
                p_threshold,
                alg_temp,
                max_new_tokens,
                steps,
                alg,
            ],
            [state, chatbot, textbox] + btn_list,
        )
        submit_btn.click(
            add_text,
            [state, textbox, system_prompt],
            [state, chatbot, textbox] + btn_list,
        ).then(
            http_bot,
            [
                state,
                temperature,
                top_p,
                p_threshold,
                alg_temp,
                max_new_tokens,
                steps,
                alg,
            ],
            [state, chatbot, textbox] + btn_list,
        )

    return demo
561
+
562
+
563
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--concurrency-count", type=int, default=10)
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    args = parser.parse_args()
    # Log the parsed arguments once (previously logged twice back to back).
    logger.info(f"args: {args}")

    demo = build_demo()
    demo.queue(api_open=False).launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
        max_threads=args.concurrency_count,
    )
assets/assistant.png ADDED

Git LFS Details

  • SHA256: 686325b01ec18e03afdaed80d623e5dce6300c8c7c27de106f298b7a8eeecca9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.83 MB
assets/human.png ADDED

Git LFS Details

  • SHA256: 241c11c27fd4e5ff9756851845759f85ad5f89d68b8266083a2f64b70cfc191c
  • Pointer size: 130 Bytes
  • Size of remote file: 47.2 kB
constants.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Dimple
3
+ # Copyright (c) 2025 Dimple Team
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ LOGDIR = 'logs/'
conversation.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import dataclasses
3
+ import base64
4
+ import copy
5
+ import hashlib
6
+ import datetime
7
+ from io import BytesIO
8
+ from PIL import Image
9
+ from typing import Any, List, Dict, Union
10
+ from dataclasses import field
11
+
12
+ from utils import LOGDIR
13
+
14
+
15
def pil2base64(img: Image.Image) -> str:
    """Serialize a PIL image to a base64-encoded PNG string."""
    buffer = BytesIO()
    img.save(buffer, format="PNG")
    raw = buffer.getvalue()
    return base64.b64encode(raw).decode()
19
+
20
+
21
def resize_img(img: Image.Image, max_len: int, min_len: int) -> Image.Image:
    """Shrink `img`, preserving aspect ratio, so the short edge is at most
    `min_len` and the long edge at most `max_len` (never upscales)."""
    longest, shortest = max(img.size), min(img.size)
    aspect_ratio = longest / shortest
    short_edge = int(min(max_len / aspect_ratio, min_len, shortest))
    long_edge = int(short_edge * aspect_ratio)
    width, height = img.size
    if height > width:
        # Portrait: height takes the long edge.
        new_w, new_h = short_edge, long_edge
    else:
        # Landscape / square: width takes the long edge.
        new_w, new_h = long_edge, short_edge
    return img.resize((new_w, new_h))
33
+
34
+
35
@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history.

    Messages are dicts with keys "role", "content" and (usually) "image",
    a list of PIL images attached to that turn.

    Bug fixes in this revision:
      * ``mandatory_system_message`` is now a real dataclass field, so
        ``copy()`` no longer raises TypeError on the unexpected keyword.
      * ``get_images`` assertion message referenced undefined ``soure``
        (NameError whenever the assertion fired).
      * ``end_of_current_turn`` compared only the last character against the
        multi-character streaming placeholder, so it could never strip it.
    """

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"

    roles: List[str] = field(
        default_factory=lambda: [
            Conversation.SYSTEM,
            Conversation.USER,
            Conversation.ASSISTANT,
        ]
    )
    system_message: str = "You are a helpful assistant."
    messages: List[Dict[str, Any]] = field(default_factory=lambda: [])
    max_image_limit: int = 2
    skip_next: bool = False
    streaming_placeholder: str = "•••"
    # Appended last so existing positional construction stays compatible.
    mandatory_system_message: str = ""

    def get_system_message(self):
        """Combined mandatory + user-configurable system prompt."""
        if len(self.mandatory_system_message) == 0:
            return self.system_message
        else:
            return self.mandatory_system_message + "\n\n" + self.system_message

    def set_system_message(self, system_message: str):
        """Replace the configurable system prompt; returns self for chaining."""
        self.system_message = system_message
        return self

    def get_prompt(self):
        """Render history into the chat-template message list (images as base64)."""
        send_messages = [
            {
                "role": "system",
                "content": self.get_system_message(),
            }
        ]
        for message in self.messages:
            if message["role"] == self.USER:
                user_message = {
                    "role": self.USER,
                    "content": message["content"],
                }
                if "image" in message:
                    user_message["image"] = []
                    for image in message["image"]:
                        user_message["image"].append(pil2base64(image))

                    content = [{"type": "text", "text": message["content"]}]
                    for image_base64 in user_message["image"]:
                        content.append({
                            "type": "image",
                            "image": f"data:image/jpeg;base64,{image_base64}"
                        })
                    send_messages.append({'role': self.USER, 'content': content})
                else:
                    send_messages.append(user_message)
            elif message["role"] == self.ASSISTANT:
                send_messages.append(
                    {"role": self.ASSISTANT, "content": message["content"]}
                )
            elif message["role"] == self.SYSTEM:
                send_messages.append(
                    {
                        "role": self.SYSTEM,
                        "content": message["content"],
                    }
                )
            else:
                raise ValueError(f"Invalid role: {message['role']}")
        return send_messages

    def append_message(
        self,
        role,
        content,
        image_list=None,
    ):
        """Add a new turn; `image_list` defaults to an empty image list."""
        self.messages.append(
            {
                "role": role,
                "content": content,
                "image": [] if image_list is None else image_list,
                # "filenames": save_filenames,
            }
        )

    def get_images(
        self,
        return_copy=False,
        return_base64=False,
        source: Union[str, None] = None,
    ):
        """Collect images across turns, optionally filtered by `source` role."""
        # Bug fix: the f-string referenced undefined `soure`.
        assert source in [self.USER, self.ASSISTANT, None], f"Invalid source: {source}"
        images = []
        for i, msg in enumerate(self.messages):
            if source and msg["role"] != source:
                continue

            for image in msg.get("image", []):
                if return_copy:
                    image = image.copy()

                if return_base64:
                    image = pil2base64(image)

                images.append(image)

        return images

    def to_gradio_chatbot(self):
        """Render history as [[user_html, assistant_html], ...] for gr.Chatbot."""
        ret = []
        for i, msg in enumerate(self.messages):
            if msg["role"] == self.SYSTEM:
                continue

            alt_str = (
                "user upload image" if msg["role"] == self.USER else "output image"
            )
            image = msg.get("image", [])
            if not isinstance(image, list):
                images = [image]
            else:
                images = image

            img_str_list = []
            for i in range(len(images)):
                image = resize_img(
                    images[i],
                    400,
                    200,
                )
                img_b64_str = pil2base64(image)
                W, H = image.size
                img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="{alt_str}" style="width: {W}px; max-width:none; max-height:none"></img>'
                img_str_list.append(img_str)

            # Convert LaTeX delimiters \[ \] / \( \) to $$ so the chatbot
            # renders inline and display math.
            if (r'\[' in msg["content"] and r'\]' in msg["content"]) or (r'\(' in msg["content"] and r'\)' in msg["content"]):
                content = msg["content"].replace(r'\[', '$$').replace(r'\]', '$$').replace(r'\(', '$$').replace(r'\)', '$$')
                content = content.split('$$')
                for i in range(len(content)):
                    if i % 2:
                        content[i] = content[i].strip()
                content = '$$'.join(content)
                print('content:', content)
            else:
                content = msg["content"]
            if msg["role"] == self.USER:
                msg_str = " ".join(img_str_list) + content
                ret.append([msg_str, None])
            else:
                msg_str = content + " ".join(img_str_list)
                ret[-1][-1] = msg_str
        return ret

    def update_message(self, role, content, image=None, idx=-1):
        """Overwrite message `idx` (negative indices allowed) for `role`."""
        assert len(self.messages) > 0, "No message in the conversation."

        idx = (idx + len(self.messages)) % len(self.messages)

        assert (
            self.messages[idx]["role"] == role
        ), f"Role mismatch: {role} vs {self.messages[idx]['role']}"

        self.messages[idx]["content"] = content
        if image is not None:
            # NOTE(review): a *new* image resets the list before extending;
            # an already-present image is appended again — confirm intent.
            if image not in self.messages[idx]["image"]:
                self.messages[idx]["image"] = []
            if not isinstance(image, list):
                image = [image]
            self.messages[idx]["image"].extend(image)

    def return_last_message(self):
        """Content of the most recent message."""
        return self.messages[-1]["content"]

    def end_of_current_turn(self):
        """Strip a trailing streaming placeholder from the last assistant turn."""
        assert len(self.messages) > 0, "No message in the conversation."
        assert (
            self.messages[-1]["role"] == self.ASSISTANT
        ), f"It should end with the message from assistant instead of {self.messages[-1]['role']}."

        content = self.messages[-1]["content"]
        # Bug fix: compare the string tail, not just the final character —
        # the placeholder is several characters long.
        if not content.endswith(self.streaming_placeholder):
            return

        self.update_message(
            self.ASSISTANT, content[: -len(self.streaming_placeholder)], None
        )

    def copy(self):
        """Deep-copied clone of this conversation."""
        return Conversation(
            mandatory_system_message=self.mandatory_system_message,
            system_message=self.system_message,
            roles=copy.deepcopy(self.roles),
            messages=copy.deepcopy(self.messages),
        )

    def dict(self):
        """JSON-serializable snapshot; images are saved to disk and replaced
        by their file paths."""
        messages = []
        for message in self.messages:
            images = []
            for image in message.get("image", []):
                filename = self.save_image(image)
                images.append(filename)

            messages.append(
                {
                    "role": message["role"],
                    "content": message["content"],
                    "image": images,
                }
            )
            if len(images) == 0:
                messages[-1].pop("image")

        return {
            "mandatory_system_message": self.mandatory_system_message,
            "system_message": self.system_message,
            "roles": self.roles,
            "messages": messages,
        }

    def save_image(self, image: "Image.Image") -> str:
        """Persist an image under LOGDIR (content-addressed by MD5) and
        return the file path; skips the write if the file already exists."""
        t = datetime.datetime.now()
        image_hash = hashlib.md5(image.tobytes()).hexdigest()
        filename = os.path.join(
            LOGDIR,
            "serve_images",
            f"{t.year}-{t.month:02d}-{t.day:02d}",
            f"{image_hash}.jpg",
        )
        if not os.path.isfile(filename):
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            image.save(filename)

        return filename
293
+
294
+
295
+
gallery/14.jfif ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:43e2db149258578de2d456c951d534278dbf3cd8f1987bf67219acf966139556
3
+ size 112403
gallery/15.PNG ADDED

Git LFS Details

  • SHA256: 6750ff31a7af7d6433432dfd9420e2c90ad9c3999490bb1200540dcb3913a882
  • Pointer size: 131 Bytes
  • Size of remote file: 202 kB
gallery/prod_9.jpg ADDED

Git LFS Details

  • SHA256: c6ab05f3070c946ed6c2bc67e6755fbf64082ed4b1ace734f02cfd721180f338
  • Pointer size: 131 Bytes
  • Size of remote file: 313 kB
model.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ from transformers import AutoProcessor, AutoModel, TextIteratorStreamer
4
+
5
class FullSequenceStreamer(TextIteratorStreamer):
    """Streamer variant that receives the *entire* token sequence on every
    step and pushes the fully re-decoded texts into the queue."""

    def __init__(self, tokenizer, **kwargs):
        super().__init__(tokenizer, **kwargs)

    def put(self, value, stream_end=False):
        # `value` holds the complete token ids each call, not an increment.
        texts = self.tokenizer.batch_decode(value, **self.decode_kwargs)
        self.text_queue.put(texts)
        if stream_end:
            self.text_queue.put(self.stop_signal, timeout=self.timeout)

    def end(self):
        # Signal consumers that generation has finished.
        self.text_queue.put(self.stop_signal, timeout=self.timeout)
18
+
19
def get_model(device):
    """Load the Dimple-7B processor and bfloat16 model onto `device`.

    Returns (model, processor); the model is switched to eval mode.
    """
    model_name = "rp-yu/Dimple-7B"
    processor = AutoProcessor.from_pretrained(
        model_name,
        trust_remote_code=True,
    )
    model = AutoModel.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    model = model.eval().to(device)
    return model, processor
35
+
utils.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ast import Dict
2
+ import logging
3
+ import logging.handlers
4
+ import os
5
+ import sys
6
+ import base64
7
+ from PIL import Image
8
+ from io import BytesIO
9
+ import json
10
+ import requests
11
+ from constants import LOGDIR
12
+ import datetime
13
+
14
# Message shown in place of the assistant reply when generation fails.
server_error_msg = (
    "**GENERATION ERROR. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
)
# Message echoed back when the input fails the moderation check.
moderation_msg = (
    "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
)

# Shared rotating file handler, created once by build_logger().
handler = None
22
+
23
+
24
def build_logger(logger_name, logger_filename):
    """Create a logger that mirrors stdout/stderr and all existing loggers
    into a shared daily-rotating file under LOGDIR.

    Side effects: replaces sys.stdout and sys.stderr with StreamToLogger
    wrappers, and (on first call) attaches the shared file handler to every
    logger known at that point.
    """
    global handler

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Set the format of root handlers
    if not logging.getLogger().handlers:
        logging.basicConfig(level=logging.INFO)
    logging.getLogger().handlers[0].setFormatter(formatter)

    # Redirect stdout and stderr to loggers
    stdout_logger = logging.getLogger("stdout")
    stdout_logger.setLevel(logging.INFO)
    sl = StreamToLogger(stdout_logger, logging.INFO)
    sys.stdout = sl

    stderr_logger = logging.getLogger("stderr")
    stderr_logger.setLevel(logging.ERROR)
    sl = StreamToLogger(stderr_logger, logging.ERROR)
    sys.stderr = sl

    # Get logger
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)

    # Add a file handler for all loggers (only on the first call — the
    # handler is a module-level singleton rotated daily in UTC).
    if handler is None:
        os.makedirs(LOGDIR, exist_ok=True)
        filename = os.path.join(LOGDIR, logger_filename)
        handler = logging.handlers.TimedRotatingFileHandler(
            filename, when="D", utc=True
        )
        handler.setFormatter(formatter)

        for name, item in logging.root.manager.loggerDict.items():
            if isinstance(item, logging.Logger):
                item.addHandler(handler)

    return logger
66
+
67
+
68
class StreamToLogger(object):
    """
    Fake file-like stream object that redirects writes to a logger instance.
    """

    def __init__(self, logger, log_level=logging.INFO):
        self.terminal = sys.stdout
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ""

    def __getattr__(self, attr):
        # Anything not defined here is delegated to the real stdout.
        return getattr(self.terminal, attr)

    def write(self, buf):
        pending = self.linebuf + buf
        self.linebuf = ""
        # splitlines(True) keeps line endings, so complete lines are logged
        # immediately while a trailing partial line is buffered until the
        # newline (or a flush) arrives.
        for line in pending.splitlines(True):
            if line.endswith("\n"):
                self.logger.log(self.log_level, line.rstrip())
            else:
                self.linebuf += line

    def flush(self):
        if self.linebuf:
            self.logger.log(self.log_level, self.linebuf.rstrip())
            self.linebuf = ""
100
+
101
+
102
def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    import torch

    noop = lambda self: None
    torch.nn.Linear.reset_parameters = noop
    torch.nn.LayerNorm.reset_parameters = noop
110
+
111
+
112
def violates_moderation(text):
    """
    Check whether the text violates OpenAI moderation API.

    Best-effort: returns False on any failure (missing API key, network
    error, unexpected response shape) instead of raising.
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        # Bug fix: os.environ["OPENAI_API_KEY"] raised KeyError (outside the
        # try block) whenever the key was unset.
        return False
    url = "https://api.openai.com/v1/moderations"
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + api_key,
    }
    # Bug fix: the payload was hand-concatenated JSON, which produced invalid
    # JSON whenever `text` contained quotes or backslashes; json.dumps
    # escapes correctly.
    data = json.dumps({"input": text.replace("\n", "")}).encode("utf-8")
    try:
        ret = requests.post(url, headers=headers, data=data, timeout=5)
        flagged = ret.json()["results"][0]["flagged"]
    except requests.exceptions.RequestException:
        flagged = False
    except KeyError:
        flagged = False

    return flagged
133
+
134
+
135
def pretty_print_semaphore(semaphore):
    """Render a semaphore's value and locked state for log messages."""
    if semaphore is None:
        return "None"
    value, locked = semaphore._value, semaphore.locked()
    return f"Semaphore(value={value}, locked={locked})"
139
+
140
+
141
def load_image_from_base64(image):
    """Decode a base64 string back into a PIL image."""
    raw = base64.b64decode(image)
    return Image.open(BytesIO(raw))
143
+
144
+
145
def get_log_filename():
    """Path of today's conversation log, e.g. logs/2025-01-31-conv.json."""
    now = datetime.datetime.now()
    stamp = f"{now.year}-{now.month:02d}-{now.day:02d}"
    return os.path.join(LOGDIR, f"{stamp}-conv.json")
149
+
150
+
151
def data_wrapper(data):
    """Coerce `data` to bytes for logging/transport.

    Supports bytes (passed through), str (UTF-8 encoded), dict (JSON
    encoded), and PIL images (PNG encoded); anything else raises ValueError.
    """
    if isinstance(data, bytes):
        return data
    elif isinstance(data, str):
        return data.encode()
    elif isinstance(data, dict):
        # Bug fix: the original tested isinstance(data, ast.Dict) — an AST
        # node class imported via `from ast import Dict` — which is never
        # true for a real dict, so dicts always fell through to ValueError.
        return json.dumps(data).encode()
    elif isinstance(data, Image.Image):
        buffered = BytesIO()
        data.save(buffered, format="PNG")
        return buffered.getvalue()
    else:
        raise ValueError(f"Unsupported data type: {type(data)}")