Nyandwi committed (verified)
Commit 857bec0 · 1 Parent(s): c5a991f

Create app.py

Files changed (1)
  1. app.py +596 -0
app.py ADDED
@@ -0,0 +1,596 @@
# from .demo_modelpart import InferenceDemo
import gradio as gr
import gradio_client
import os
import sys
import subprocess
import datetime
import hashlib
import json
import base64
from io import BytesIO
from threading import Thread

import cv2
import numpy as np
import requests
import torch
import spaces
import PIL
from PIL import Image
from transformers import TextStreamer, TextIteratorStreamer

from llava import conversation as conversation_lib
from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
)
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import (
    tokenizer_image_token,
    get_model_name_from_path,
    KeywordsStoppingCriteria,
)

from serve_constants import html_header, bibtext, learn_more_markdown, tos_markdown

from huggingface_hub import HfApi, login, revision_exists

login(token=os.environ["HF_TOKEN"], write_permission=True)

api = HfApi()
repo_name = os.environ["LOG_REPO"]

external_log_dir = "./logs"
LOGDIR = external_log_dir
VOTEDIR = "./votes"

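# The Space expects two environment variables at startup:
#   HF_TOKEN - a Hugging Face token with write access (consumed by login() above)
#   LOG_REPO - the id of the dataset repo that receives conversation/vote logs
# A hypothetical local setup (placeholder values, not real credentials):
#   export HF_TOKEN=hf_xxxxxxxx
#   export LOG_REPO=your-org/pangea-chat-logs
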
def install_gradio_4_35_0():
    # Pin Gradio to 4.35.0. A force-reinstall only takes effect on the next
    # process start; the already-imported gradio module is not reloaded here.
    current_version = gr.__version__
    if current_version != "4.35.0":
        print(f"Current Gradio version: {current_version}")
        print("Installing Gradio 4.35.0...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "gradio==4.35.0", "--force-reinstall"])
        print("Gradio 4.35.0 installed successfully.")
    else:
        print("Gradio 4.35.0 is already installed.")

# Call the function to install Gradio 4.35.0 if needed
install_gradio_4_35_0()

print(f"Gradio version: {gr.__version__}")
print(f"Gradio-client version: {gradio_client.__version__}")

def get_conv_log_filename():
    t = datetime.datetime.now()
    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_conv.json")
    if not os.path.isfile(name):
        os.makedirs(os.path.dirname(name), exist_ok=True)
    return name


def get_conv_vote_filename():
    t = datetime.datetime.now()
    name = os.path.join(VOTEDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_vote.json")
    if not os.path.isfile(name):
        os.makedirs(os.path.dirname(name), exist_ok=True)
    return name

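# Both helpers name the log file by date; e.g. on 2025-03-01 the conversation
# log resolves to "./logs/2025-03-01-user_conv.json" and the vote log to
# "./votes/2025-03-01-user_vote.json" (month and day are zero-padded above).
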
def vote_last_response(state, vote_type, model_selector):
    with open(get_conv_vote_filename(), "a") as fout:
        data = {
            "type": vote_type,
            "model": model_selector,
            "state": state,
        }
        fout.write(json.dumps(data) + "\n")
    api.upload_file(
        path_or_fileobj=get_conv_vote_filename(),
        path_in_repo=get_conv_vote_filename().replace("./votes/", ""),
        repo_id=repo_name,
        repo_type="dataset",
    )


def upvote_last_response(state):
    vote_last_response(state, "upvote", "CulturalPangea-7B")
    gr.Info("Thank you for voting!")
    return state


def downvote_last_response(state):
    vote_last_response(state, "downvote", "CulturalPangea-7B")
    gr.Info("Thank you for voting!")
    return state

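# Each vote is appended as one JSON line and mirrored to the logging dataset
# repo; an illustrative record looks like:
#   {"type": "upvote", "model": "CulturalPangea-7B", "state": [...chat history...]}
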
class InferenceDemo(object):
    def __init__(
        self, args, model_path, tokenizer, model, image_processor, context_len
    ) -> None:
        disable_torch_init()

        self.tokenizer, self.model, self.image_processor, self.context_len = (
            tokenizer,
            model,
            image_processor,
            context_len,
        )

        # Infer the conversation template from the module-level model_name,
        # which is set in the __main__ block before the demo is used.
        if "llama-2" in model_name.lower():
            conv_mode = "llava_llama_2"
        elif "v1" in model_name.lower():
            conv_mode = "llava_v1"
        elif "mpt" in model_name.lower():
            conv_mode = "mpt"
        elif "qwen" in model_name.lower():
            conv_mode = "qwen_1_5"
        elif "pangea" in model_name.lower():
            conv_mode = "qwen_1_5"
        else:
            conv_mode = "llava_v0"

        if args.conv_mode is not None and conv_mode != args.conv_mode:
            print(
                "[WARNING] the auto-inferred conversation mode is {}, while `--conv-mode` is {}; using {}".format(
                    conv_mode, args.conv_mode, args.conv_mode
                )
            )
            conv_mode = args.conv_mode
        else:
            args.conv_mode = conv_mode
        self.conv_mode = conv_mode
        self.conversation = conv_templates[args.conv_mode].copy()
        self.num_frames = args.num_frames

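# A minimal sketch of how the conversation template is consumed later in bot(),
# using the llava conv_templates API (the values here are illustrative):
#
#   conv = conv_templates["qwen_1_5"].copy()
#   conv.append_message(conv.roles[0], "<image>\nDescribe this photo.")
#   conv.append_message(conv.roles[1], None)  # placeholder for the model's reply
#   prompt = conv.get_prompt()                # serialized prompt string
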
class ChatSessionManager:
    def __init__(self):
        self.chatbot_instance = None

    def initialize_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
        self.chatbot_instance = InferenceDemo(args, model_path, tokenizer, model, image_processor, context_len)
        print(f"Initialized Chatbot instance with ID: {id(self.chatbot_instance)}")

    def reset_chatbot(self):
        self.chatbot_instance = None

    def get_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
        if self.chatbot_instance is None:
            self.initialize_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
        return self.chatbot_instance

def is_valid_video_filename(name):
    video_extensions = ["avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg"]
    ext = name.split(".")[-1].lower()
    return ext in video_extensions


def is_valid_image_filename(name):
    image_extensions = ["jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp", "heic", "heif", "jfif", "svg", "eps", "raw"]
    ext = name.split(".")[-1].lower()
    return ext in image_extensions

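# Examples: is_valid_image_filename("photo.JPG") -> True (case-insensitive),
# is_valid_video_filename("clip.mp4") -> True,
# is_valid_image_filename("notes.txt") -> False.
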
def sample_frames(video_file, num_frames):
    # Uniformly sample num_frames frames by stepping through the video.
    video = cv2.VideoCapture(video_file)
    total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    interval = max(total_frames // num_frames, 1)  # guard against modulo-by-zero on short clips
    frames = []
    for i in range(total_frames):
        ret, frame = video.read()
        if not ret:
            continue  # skip unreadable frames before converting them
        if i % interval == 0:
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    video.release()
    return frames

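# Hypothetical usage (the path is a placeholder, not a file in this repo):
#   frames = sample_frames("clip.mp4", num_frames=16)
#   # -> up to ~16 evenly spaced RGB PIL.Image frames from the clip
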
def load_image(image_file):
    if image_file.startswith(("http://", "https://")):
        response = requests.get(image_file)
        if response.status_code != 200:
            # Fail loudly rather than continuing with no image bound.
            raise ValueError(f"Failed to load image from {image_file}")
        image = Image.open(BytesIO(response.content)).convert("RGB")
    else:
        print("Load image from local file")
        print(image_file)
        image = Image.open(image_file).convert("RGB")

    return image

def clear_response(history):
    # Walk backwards until the last user text turn is found.
    for index_conv in range(1, len(history)):
        conv = history[-index_conv]
        if not (conv[0] is None):
            break
    question = history[-index_conv][0]
    history = history[:-index_conv]
    return history, question


chat_manager = ChatSessionManager()


def clear_history(history):
    chatbot_instance = chat_manager.get_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
    chatbot_instance.conversation = conv_templates[chatbot_instance.conv_mode].copy()
    return None

def add_message(history, message):
    global chat_image_num
    print("#### len(history)", len(history))
    if not history:
        history = []
        print("### Initialize chatbot")
        our_chatbot = chat_manager.get_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
        chat_image_num = 0
        print("chat_image_num", chat_image_num)

    if len(message["files"]) <= 1:
        for x in message["files"]:
            history.append(((x,), None))
            chat_image_num += 1
            if chat_image_num > 1:
                # A second image arrived: reset the session and keep only the
                # files from the current message.
                history = []
                chat_manager.reset_chatbot()
                our_chatbot = chat_manager.get_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
                chat_image_num = 0
                for x in message["files"]:
                    history.append(((x,), None))
                    chat_image_num += 1

        if message["text"] is not None:
            history.append((message["text"], None))
        print("chat_image_num", chat_image_num)
        # print(f"### Chatbot instance ID: {id(our_chatbot)}")
        return history, gr.MultimodalTextbox(value=None, interactive=False)
    else:
        for x in message["files"]:
            history.append(((x,), None))
        if message["text"] is not None:
            history.append((message["text"], None))
        return history, gr.MultimodalTextbox(value=None, interactive=False)

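# Note on the history format shared by add_message() and bot(): a file turn is
# stored as ((path,), None) while a text turn is (text, None); bot() relies on
# type(message[0]) is tuple to tell images apart from text.
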
@spaces.GPU
def bot(history, temperature, top_p, max_output_tokens):
    our_chatbot = chat_manager.get_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
    print(f"### Chatbot instance ID: {id(our_chatbot)}")
    text = history[-1][0]
    images_this_term = []
    text_this_term = ""

    num_new_images = 0
    previous_image = False
    for i, message in enumerate(history[:-1]):
        if type(message[0]) is tuple:
            if previous_image:
                gr.Warning("Only one image can be uploaded in a conversation. Please reduce the number of images and start a new conversation.")
                our_chatbot.conversation = conv_templates[our_chatbot.conv_mode].copy()
                return None

            images_this_term.append(message[0][0])
            if is_valid_video_filename(message[0][0]):
                # Video input is disabled in this demo.
                raise ValueError("Video is not supported")
            elif is_valid_image_filename(message[0][0]):
                print("#### Load image from local file", message[0][0])
                num_new_images += 1
            else:
                raise ValueError("Invalid image file")
            previous_image = True
        else:
            num_new_images = 0
            previous_image = False

    # Hash each image and save a copy under the log directory for upload.
    all_image_hash = []
    all_image_path = []
    for image_path in images_this_term:
        with open(image_path, "rb") as image_file:
            image_data = image_file.read()
        image_hash = hashlib.md5(image_data).hexdigest()
        all_image_hash.append(image_hash)
        image = PIL.Image.open(image_path).convert("RGB")
        t = datetime.datetime.now()
        filename = os.path.join(
            LOGDIR,
            "serve_images",
            f"{t.year}-{t.month:02d}-{t.day:02d}",
            f"{image_hash}.jpg",
        )
        all_image_path.append(filename)
        if not os.path.isfile(filename):
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            print("image saved to", filename)
            image.save(filename)

    image_list = []
    for f in images_this_term:
        if is_valid_video_filename(f):
            image_list += sample_frames(f, our_chatbot.num_frames)
        elif is_valid_image_filename(f):
            image_list.append(load_image(f))
        else:
            raise ValueError("Invalid image file")

    image_tensor = []
    if num_new_images > 0:
        image_tensor = [
            our_chatbot.image_processor.preprocess(f, return_tensors="pt")["pixel_values"][0]
            .half()
            .to(our_chatbot.model.device)
            for f in image_list
        ]
        image_tensor = torch.stack(image_tensor)
        image_token = DEFAULT_IMAGE_TOKEN * num_new_images
        inp = image_token + "\n" + text
    else:
        inp = text

    our_chatbot.conversation.append_message(our_chatbot.conversation.roles[0], inp)
    our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None)
    prompt = our_chatbot.conversation.get_prompt()

    input_ids = tokenizer_image_token(
        prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
    ).unsqueeze(0).to(our_chatbot.model.device)

    stop_str = (
        our_chatbot.conversation.sep
        if our_chatbot.conversation.sep_style != SeparatorStyle.TWO
        else our_chatbot.conversation.sep2
    )
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(
        keywords, our_chatbot.tokenizer, input_ids
    )

    streamer = TextIteratorStreamer(
        our_chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    print(our_chatbot.model.device)
    print(input_ids.device)

    generate_kwargs = dict(
        inputs=input_ids,
        streamer=streamer,
        images=image_tensor if num_new_images > 0 else None,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_output_tokens,
        use_cache=False,
        stopping_criteria=[stopping_criteria],
    )

    # Run generation in a background thread and stream tokens into the UI.
    t = Thread(target=our_chatbot.model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for stream_token in streamer:
        outputs.append(stream_token)
        history[-1] = [text, "".join(outputs)]
        yield history
    our_chatbot.conversation.messages[-1][-1] = "".join(outputs)

    # Log the finished turn, then mirror the images and the log file to the
    # logging dataset repo.
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "type": "chat",
            "model": "CulturalPangea-7B",
            "state": history,
            "images": all_image_hash,
            "images_path": all_image_path,
        }
        print("#### conv log", data)
        fout.write(json.dumps(data) + "\n")

    for upload_img in all_image_path:
        api.upload_file(
            path_or_fileobj=upload_img,
            path_in_repo=upload_img.replace("./logs/", ""),
            repo_id=repo_name,
            repo_type="dataset",
        )
    # upload json
    api.upload_file(
        path_or_fileobj=get_conv_log_filename(),
        path_in_repo=get_conv_log_filename().replace("./logs/", ""),
        repo_id=repo_name,
        repo_type="dataset",
    )

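# The streaming pattern used in bot(), in isolation (a minimal sketch assuming
# any transformers model/tokenizer pair; the variable names are illustrative):
#
#   streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
#   Thread(target=model.generate,
#          kwargs=dict(inputs=input_ids, streamer=streamer, max_new_tokens=256)).start()
#   partial = ""
#   for token_text in streamer:
#       partial += token_text  # push the growing reply to the UI after each token
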
txt = gr.Textbox(
    scale=4,
    show_label=False,
    placeholder="Enter text and press enter.",
    container=False,
)  # Note: this standalone Textbox is defined but not used below.

with gr.Blocks(
    css=".message-wrap.svelte-1lcyrx4>div.svelte-1lcyrx4 img {min-width: 40px}",
) as demo:

    cur_dir = os.path.dirname(os.path.abspath(__file__))
    # gr.Markdown(title_markdown)
    gr.HTML(html_header)

    with gr.Column():
        with gr.Accordion("Parameters", open=False) as parameter_row:
            temperature = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.7,
                step=0.1,
                interactive=True,
                label="Temperature",
            )
            top_p = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=1,
                step=0.1,
                interactive=True,
                label="Top P",
            )
            max_output_tokens = gr.Slider(
                minimum=0,
                maximum=8192,
                value=4096,
                step=256,
                interactive=True,
                label="Max output tokens",
            )
        with gr.Row():
            chatbot = gr.Chatbot([], elem_id="Pangea", bubble_full_width=False, height=750)

        with gr.Row():
            upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
            downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
            flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
            # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=True)
            regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
            clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)

    chat_input = gr.MultimodalTextbox(
        interactive=True,
        file_types=["image"],
        placeholder="Enter message or upload file...",
        show_label=False,
        submit_btn="🚀",
    )

    print(cur_dir)
    # Multilingual example prompts (German, Hindi, Swahili).
    gr.Examples(
        examples_per_page=20,
        examples=[
            [
                {
                    "files": [
                        f"{cur_dir}/examples/norway.jpg",
                    ],
                    # German: "Analyze which country this scene was most likely shot in."
                    "text": "Analysieren, in welchem Land diese Szene höchstwahrscheinlich gedreht wurde.",
                },
            ],
            [
                {
                    "files": [
                        f"{cur_dir}/examples/africa.jpg",
                    ],
                    # Hindi: "What does each visual element in this picture represent?"
                    "text": "इस तस्वीर में हर एक दृश्य तत्व का क्या प्रतिनिधित्व करता है?",
                },
            ],
            [
                {
                    "files": [
                        f"{cur_dir}/examples/food.jpg",
                    ],
                    # Swahili: "Can you give me a recipe for making this pancake?"
                    "text": "Unaweza kunipa kichocheo cha kutengeneza hii pancake?",
                },
            ],
        ],
        inputs=[chat_input],
        label="Image",
    )

    gr.Markdown(tos_markdown)
    gr.Markdown(learn_more_markdown)
    gr.Markdown(bibtext)

    chat_input.submit(
        add_message, [chatbot, chat_input], [chatbot, chat_input]
    ).then(
        bot, [chatbot, temperature, top_p, max_output_tokens], chatbot, api_name="bot_response"
    ).then(
        lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input]
    )

    # chatbot.like(print_like_dislike, None, None)
    clear_btn.click(
        fn=clear_history, inputs=[chatbot], outputs=[chatbot], api_name="clear_all"
    )

    upvote_btn.click(
        fn=upvote_last_response, inputs=chatbot, outputs=chatbot, api_name="upvote_last_response"
    )

    downvote_btn.click(
        fn=downvote_last_response, inputs=chatbot, outputs=chatbot, api_name="downvote_last_response"
    )

demo.queue()

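# The submit chain above: add_message() records the user turn and disables the
# textbox, bot() streams the assistant reply into the chatbot, and the final
# lambda re-enables the input box.
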
if __name__ == "__main__":
    import argparse

    argparser = argparse.ArgumentParser()
    argparser.add_argument("--server_name", default="0.0.0.0", type=str)
    argparser.add_argument("--port", default="6123", type=str)
    argparser.add_argument(
        "--model_path", default="neulab/CulturalPangea-7B", type=str
    )
    # argparser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    argparser.add_argument("--model-base", type=str, default=None)
    argparser.add_argument("--num-gpus", type=int, default=1)
    argparser.add_argument("--conv-mode", type=str, default=None)
    argparser.add_argument("--temperature", type=float, default=0.7)
    argparser.add_argument("--max-new-tokens", type=int, default=4096)
    argparser.add_argument("--num_frames", type=int, default=16)
    argparser.add_argument("--load-8bit", action="store_true")
    argparser.add_argument("--load-4bit", action="store_true")
    argparser.add_argument("--debug", action="store_true")

    args = argparser.parse_args()

    model_path = args.model_path
    filt_invalid = "cut"
    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
    )
    model = model.to(torch.device("cuda"))
    chat_image_num = 0
    # --server_name/--port are parsed but not forwarded; on Spaces the host and
    # port come from the environment, so demo.launch() keeps its defaults.
    demo.launch()
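# Hypothetical local run (flags as defined above; the Space itself just calls
# demo.launch() with defaults):
#   python app.py --conv-mode qwen_1_5 --num_frames 16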