ZebangCheng committed
Commit 9cc1833 · Parent(s): fbbaf0c

zerogpu test

.gitignore ADDED
@@ -0,0 +1 @@
+__pycache__/
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 😁
 colorFrom: yellow
 colorTo: red
 sdk: gradio
-sdk_version: 4.36.1
+sdk_version: 5.34.0
 app_file: app.py
 pinned: false
 license: apache-2.0
app.py CHANGED
@@ -27,22 +27,6 @@ from minigpt4.tasks import *
 import socket
 import os
 
-def find_free_port(start_port, end_port):
-    for port in range(start_port, end_port + 1):
-        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
-            if sock.connect_ex(('localhost', port)) != 0:  # Port is not open
-                return port
-    raise OSError(f"Cannot find empty port in range: {start_port}-{end_port}")
-
-def set_gradio_server_port():
-    start_port = 7870
-    end_port = 9999
-    free_port = find_free_port(start_port, end_port)
-    os.environ["GRADIO_SERVER_PORT"] = str(free_port)
-    print(f"Set GRADIO_SERVER_PORT to {free_port}")
-
-# Set GRADIO_SERVER_PORT
-set_gradio_server_port()
 
 def parse_args():
     parser = argparse.ArgumentParser(description="Demo")
@@ -747,4 +731,4 @@ with gr.Blocks() as demo:
     clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, chat_state, img_list], queue=False)
 
 demo.queue()
-demo.launch(share=True)
+demo.launch()
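Taken together, the two app.py hunks move the demo toward ZeroGPU Spaces: the hand-rolled port scanner goes away (Spaces assigns the server port itself), and `share=True` is dropped since a Space is already publicly reachable. Below is a minimal sketch of the usual ZeroGPU pattern, assuming the `spaces` package that ZeroGPU hardware provides; the `gpu_info` function is a hypothetical placeholder, not code from this repo:

```python
import spaces          # provided on ZeroGPU Spaces hardware (assumption: ZeroGPU runtime)
import torch
import gradio as gr

@spaces.GPU            # a GPU is attached only for the duration of this call
def gpu_info(_: str) -> str:
    # Inside a @spaces.GPU-decorated function CUDA is available;
    # outside of one, the process may be running CPU-only.
    return f"CUDA available: {torch.cuda.is_available()}"

demo = gr.Interface(fn=gpu_info, inputs="text", outputs="text")
demo.queue()
demo.launch()          # no share=True and no manual port selection on Spaces
```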
minigpt4/common/eval_utils.py CHANGED
@@ -54,7 +54,7 @@ def init_model(args):
 
     model_config = cfg.model_cfg
     model_cls = registry.get_model_class(model_config.arch)
-    model = model_cls.from_config(model_config).to('cuda:0')
+    model = model_cls.from_config(model_config).to('cuda')
 
     # import pudb; pudb.set_trace()
     key = list(cfg.datasets_cfg.keys())[0]
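Switching `'cuda:0'` to `'cuda'` binds to whichever GPU the runtime exposes rather than a fixed device ordinal. A small sketch, not part of this commit, of the common companion guard for runtimes where a GPU may be absent (e.g. while a ZeroGPU Space is idle):

```python
import torch

# Sketch only (not in this commit): resolve the device once, falling
# back to CPU when no GPU is visible, instead of assuming a fixed GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)  # "cuda" on GPU runtimes, "cpu" otherwise

# model = model_cls.from_config(model_config).to(device)  # as in init_model above
```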
minigpt4/conversation/conversation.py CHANGED
@@ -21,8 +21,7 @@ from typing import List, Tuple, Any
 from minigpt4.common.registry import registry
 
 import os
-from huggingface_hub import login
-login(token=os.environ['API_TOKEN'])
+
 
 class SeparatorStyle(Enum):
     """Different separator style."""
@@ -172,7 +171,7 @@ def extract_audio_from_video(video_path):
     return samples, sr
 
 class Chat:
-    def __init__(self, model, vis_processor, device='cuda:0', stopping_criteria=None):
+    def __init__(self, model, vis_processor, device='cuda', stopping_criteria=None):
         self.device = device
         self.model = model
         self.vis_processor = vis_processor
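The removed lines logged in to the Hub unconditionally at import time using an `API_TOKEN` secret, which crashes when the variable is unset. A hedged sketch of the usual alternative, assuming the conventional `HF_TOKEN` Spaces secret name (this commit only deletes the eager login; it does not add code like this):

```python
import os
from huggingface_hub import login

# Sketch only: read the conventional HF_TOKEN secret and log in
# lazily, instead of failing at import time when no token is set.
token = os.getenv("HF_TOKEN")
if token:
    login(token=token)
```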
minigpt4/models/minigpt_v2.py CHANGED
@@ -90,7 +90,7 @@ class MiniGPTv2(MiniGPTBase):
         self.llama_model.gradient_checkpointing_enable()
 
     def encode_img(self, image, video_features):
-        # device = 'cuda:0'
+        # device = 'cuda'
         device = image.device
         if len(image.shape) > 4:
             image = image.reshape(-1, *image.shape[-3:])
requirements.txt CHANGED
@@ -16,6 +16,6 @@ torch==2.1.2
 torchvision
 timm==0.6.13
 transformers==4.30.0
-gradio
+gradio==5.34.0
 gradio_client
 numpy<2.0
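The pin keeps pip's Gradio in lockstep with the `sdk_version: 5.34.0` declared in README.md, so the Space builder and the dependency resolver agree on one version. An illustrative runtime check, not code from this repo:

```python
import gradio

# Illustrative: confirm the installed Gradio matches both the pin in
# requirements.txt and the sdk_version in README.md.
expected = "5.34.0"
print(gradio.__version__, "OK" if gradio.__version__ == expected else "mismatch")
```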