vikhyatk commited on
Commit
abc934b
·
verified ·
1 Parent(s): c57ffa7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -15
app.py CHANGED
@@ -4,25 +4,41 @@ import re
4
  import os
5
  import gradio as gr
6
  from threading import Thread
7
- from transformers import TextIteratorStreamer, AutoTokenizer, AutoModelForCausalLM
 
 
 
 
 
8
  from PIL import ImageDraw
9
  from torchvision.transforms.v2 import Resize
10
 
11
  import subprocess
12
- subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
 
 
 
 
13
 
14
  auth_token = os.environ.get("TOKEN_FROM_SECRET") or True
15
  tokenizer = AutoTokenizer.from_pretrained("vikhyatk/moondream2")
16
  moondream = AutoModelForCausalLM.from_pretrained(
17
- "vikhyatk/moondream-next", trust_remote_code=True,
18
- torch_dtype=torch.bfloat16, device_map={"": "cuda"},
19
- attn_implementation="flash_attention_2", use_auth_token=auth_token
 
 
 
20
  )
21
  moondream.eval()
22
 
23
 
24
  @spaces.GPU(duration=10)
25
  def answer_question(img, prompt):
 
 
 
26
  image_embeds = moondream.encode_image(img)
27
  streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
28
  thread = Thread(
@@ -41,6 +57,30 @@ def answer_question(img, prompt):
41
  buffer += new_text
42
  yield buffer.strip()
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  def extract_floats(text):
45
  # Regular expression to match an array of four floating point numbers
46
  pattern = r"\[\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*\]"
@@ -58,6 +98,7 @@ def extract_bbox(text):
58
  bbox = (x1, y1, x2, y2)
59
  return bbox
60
 
 
61
  def process_answer(img, answer):
62
  if extract_bbox(answer) is not None:
63
  x1, y1, x2, y2 = extract_bbox(answer)
@@ -71,7 +112,41 @@ def process_answer(img, answer):
71
 
72
  return gr.update(visible=False, value=None)
73
 
74
- with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  gr.Markdown(
76
  """
77
  # 🌔 moondream vl (new)
@@ -79,16 +154,44 @@ with gr.Blocks() as demo:
79
  """
80
  )
81
  with gr.Row():
82
- prompt = gr.Textbox(label="Input", value="Describe this image.", scale=4)
83
- submit = gr.Button("Submit")
84
- with gr.Row():
85
- img = gr.Image(type="pil", label="Upload an Image")
86
  with gr.Column():
87
- output = gr.Markdown(label="Response")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  ann = gr.Image(visible=False, label="Annotated Image")
89
 
90
- submit.click(answer_question, [img, prompt], output)
91
- prompt.submit(answer_question, [img, prompt], output)
92
- output.change(process_answer, [img, output], ann, show_progress=False)
93
 
94
- demo.queue().launch()
 
4
  import os
5
  import gradio as gr
6
  from threading import Thread
7
+ from transformers import (
8
+ TextIteratorStreamer,
9
+ AutoTokenizer,
10
+ AutoModelForCausalLM,
11
+ StaticCache,
12
+ )
13
  from PIL import ImageDraw
14
  from torchvision.transforms.v2 import Resize
15
 
16
  import subprocess
17
+
18
+ subprocess.run(
19
+ "pip install flash-attn --no-build-isolation",
20
+ env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
21
+ shell=True,
22
+ )
23
 
24
  auth_token = os.environ.get("TOKEN_FROM_SECRET") or True
25
  tokenizer = AutoTokenizer.from_pretrained("vikhyatk/moondream2")
26
  moondream = AutoModelForCausalLM.from_pretrained(
27
+ "vikhyatk/moondream-next",
28
+ trust_remote_code=True,
29
+ torch_dtype=torch.float16,
30
+ device_map={"": "cuda"},
31
+ attn_implementation="flash_attention_2",
32
+ token=auth_token,
33
  )
34
  moondream.eval()
35
 
36
 
37
  @spaces.GPU(duration=10)
38
  def answer_question(img, prompt):
39
+ if img is None:
40
+ return
41
+
42
  image_embeds = moondream.encode_image(img)
43
  streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
44
  thread = Thread(
 
57
  buffer += new_text
58
  yield buffer.strip()
59
 
60
+
61
+ @spaces.GPU(duration=10)
62
+ def caption(img, mode):
63
+ if img is None:
64
+ return
65
+
66
+ streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
67
+ thread = Thread(
68
+ target=moondream.caption,
69
+ kwargs={
70
+ "images": [img],
71
+ "length": "short" if mode == "Short" else None,
72
+ "tokenizer": tokenizer,
73
+ "streamer": streamer,
74
+ },
75
+ )
76
+ thread.start()
77
+
78
+ buffer = ""
79
+ for new_text in streamer:
80
+ buffer += new_text
81
+ yield buffer.strip()
82
+
83
+
84
  def extract_floats(text):
85
  # Regular expression to match an array of four floating point numbers
86
  pattern = r"\[\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*\]"
 
98
  bbox = (x1, y1, x2, y2)
99
  return bbox
100
 
101
+
102
  def process_answer(img, answer):
103
  if extract_bbox(answer) is not None:
104
  x1, y1, x2, y2 = extract_bbox(answer)
 
112
 
113
  return gr.update(visible=False, value=None)
114
 
115
+
116
+ with gr.Blocks(title="moondream vl (new)") as demo:
117
+ gr.HTML(
118
+ """
119
+ <script>
120
+ window.addEventListener('load', function () {
121
+ gradioURL = window.location.href;
122
+ if (!gradioURL.endsWith('?__theme=dark')) {
123
+ window.location.replace(gradioURL + '?__theme=dark');
124
+ }
125
+ });
126
+ </script>
127
+ <style type="text/css">
128
+ .output-text span p { font-size: 1.4rem !important; }
129
+ /* Add a beautiful dark background animation for space theme */
130
+ body gradio-app {
131
+ background: linear-gradient(to right, #0c0d21, #1f1e33) !important;
132
+ animation: gradientBG 15s ease infinite;
133
+ background-size: 400% 400%;
134
+ }
135
+
136
+ @keyframes gradientBG {
137
+ 0% {
138
+ background-position: 0% 50%;
139
+ }
140
+ 50% {
141
+ background-position: 100% 50%;
142
+ }
143
+ 100% {
144
+ background-position: 0% 50%;
145
+ }
146
+ }
147
+ </style>
148
+ """
149
+ )
150
  gr.Markdown(
151
  """
152
  # 🌔 moondream vl (new)
 
154
  """
155
  )
156
  with gr.Row():
 
 
 
 
157
  with gr.Column():
158
+ mode_radio = gr.Radio(
159
+ ["Caption", "Query", "Detect"],
160
+ show_label=False,
161
+ value=lambda: "Caption",
162
+ )
163
+
164
+ @gr.render(inputs=[mode_radio])
165
+ def show_inputs(mode):
166
+ if mode == "Query":
167
+ with gr.Group():
168
+ with gr.Row():
169
+ prompt = gr.Textbox(
170
+ label="Input",
171
+ value="How many people are in this image?",
172
+ scale=4,
173
+ )
174
+ submit = gr.Button("Submit")
175
+ img = gr.Image(type="pil", label="Upload an Image")
176
+ submit.click(answer_question, [img, prompt], output)
177
+ prompt.submit(answer_question, [img, prompt], output)
178
+ img.change(answer_question, [img, prompt], output)
179
+ elif mode == "Caption":
180
+ with gr.Group():
181
+ caption_mode = gr.Radio(
182
+ ["Short", "Normal"],
183
+ show_label=False,
184
+ value=lambda: "Normal",
185
+ )
186
+ img = gr.Image(type="pil", label="Upload an Image")
187
+ caption_mode.change(caption, [img, caption_mode], output)
188
+ img.change(caption, [img, caption_mode], output)
189
+ else:
190
+ gr.Markdown("Coming soon!")
191
+
192
+ with gr.Column():
193
+ output = gr.Markdown(label="Response", elem_classes=["output-text"])
194
  ann = gr.Image(visible=False, label="Annotated Image")
195
 
 
 
 
196
 
197
+ demo.queue().launch()