REVISION = "af989d6d7d933af3df965d739f1528ef990f8eda"
try:
import spaces
IN_SPACES = True
except ImportError:
from functools import wraps
import inspect
class spaces:
@staticmethod
        def GPU(duration):
            def decorator(func):
                # A `yield` anywhere in a single wrapper body would make that
                # wrapper a generator function even when it wraps a plain
                # function, silently swallowing its return value. So the two
                # cases get separate wrappers, chosen at decoration time.
                if inspect.isgeneratorfunction(func):
                    @wraps(func)  # preserves the original function's metadata
                    def wrapper(*args, **kwargs):
                        # Stream the decorated generator through unchanged
                        yield from func(*args, **kwargs)
                else:
                    @wraps(func)
                    def wrapper(*args, **kwargs):
                        # For regular functions, just return the result
                        return func(*args, **kwargs)
                return wrapper
            return decorator
IN_SPACES = False
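# The stub above mirrors the `spaces.GPU` call signature just enough for local
# development: `duration` is accepted and ignored, and streaming handlers keep
# streaming. Illustrative usage (hypothetical function, not part of this app):
#
#   @spaces.GPU(duration=10)
#   def stream():
#       yield "partial"
#       yield "done"
#
#   list(stream())  # -> ["partial", "done"], same locally as on Spaces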
import torch
import os
import gradio as gr
import json
from queue import Queue
from threading import Thread
from transformers import AutoModelForCausalLM
from PIL import ImageDraw
from torchvision.transforms.v2 import Resize
os.environ["HF_TOKEN"] = os.environ.get("TOKEN_FROM_SECRET") or True
moondream = AutoModelForCausalLM.from_pretrained(
"vikhyatk/moondream-next",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
device_map={"": "cuda"},
revision=REVISION
)
moondream.eval()
def convert_to_entities(text, coords):
"""
Converts a string with special markers into an entity representation.
Markers:
- <|coord|> pairs indicate coordinate markers
- <|start_ground_points|> indicates the start of grounding
- <|start_ground_text|> indicates the start of a ground term
- <|end_ground|> indicates the end of a ground term
Returns:
- Dictionary with cleaned text and entities with their character positions
"""
# Initialize variables
cleaned_text = ""
entities = []
entity = []
# Track current position in cleaned text
current_pos = 0
# Track if we're currently processing an entity
in_entity = False
entity_start = 0
i = 0
    while i < len(text):
        # Check for markers, advancing past each one without copying it out
        if text.startswith("<|coord|>", i):
            i += len("<|coord|>")
            entity.append(coords.pop(0))
            continue
        elif text.startswith("<|start_ground_points|>", i):
            in_entity = True
            entity_start = current_pos
            i += len("<|start_ground_points|>")
            continue
        elif text.startswith("<|start_ground_text|>", i):
            entity_start = current_pos
            i += len("<|start_ground_text|>")
            continue
        elif text.startswith("<|end_ground|>", i):
            # Store entity position
            entities.append(
                {
                    "entity": json.dumps(entity),
                    "start": entity_start,
                    "end": current_pos,
                }
            )
            entity = []
            in_entity = False
            i += len("<|end_ground|>")
            continue
# Add character to cleaned text
cleaned_text += text[i]
current_pos += 1
i += 1
return {"text": cleaned_text, "entities": entities}
@spaces.GPU(duration=30)
def answer_question(img, prompt, reasoning):
    if img is None:
        # Guard against cleared images, matching the other handlers below
        yield "", {"text": "", "entities": []}
        return
    buffer = ""
resp = moondream.query(img, prompt, stream=True, reasoning=reasoning)
reasoning_text = resp["reasoning"]["text"] if reasoning else "[reasoning disabled]"
entities = [
{"start": g["start_idx"], "end": g["end_idx"], "entity": json.dumps(g["points"])}
for g in resp["reasoning"]["grounding"]
] if reasoning else []
for new_text in resp["answer"]:
buffer += new_text
yield buffer.strip(), {"text": reasoning_text, "entities": entities}
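# Shape of `resp` as consumed above (inferred from usage; the exact schema is
# defined by the moondream-next remote code): with stream=True, "answer" is an
# iterable of text chunks, and with reasoning enabled, "reasoning" carries the
# chain-of-thought text plus grounding spans:
#   {"answer": <generator of str>,
#    "reasoning": {"text": str,
#                  "grounding": [{"start_idx": int, "end_idx": int,
#                                 "points": [[x, y], ...]}, ...]}}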
@spaces.GPU(duration=10)
def caption(img, mode):
if img is None:
yield ""
return
buffer = ""
    if mode == "Short":
        length = "short"
    elif mode == "Long":
        length = "long"
    else:
        length = "normal"
    for t in moondream.caption(img, length=length, stream=True)["caption"]:
buffer += t
yield buffer.strip()
@spaces.GPU(duration=10)
def detect(img, obj):
if img is None:
yield "", gr.update(visible=False, value=None)
return
w, h = img.size
if w > 768 or h > 768:
img = Resize(768)(img)
w, h = img.size
    objs = moondream.detect(img, obj)["objects"]
    img = img.copy()  # draw on a copy so the caller's image is left untouched
    draw_image = ImageDraw.Draw(img)
for o in objs:
draw_image.rectangle(
(o["x_min"] * w, o["y_min"] * h, o["x_max"] * w, o["y_max"] * h),
outline="red",
width=3,
)
yield {"text": f"{len(objs)} detected", "entities": []}, gr.update(
visible=True, value=img
)
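# Note: the detector's box coordinates are normalized to [0, 1] and scaled by
# the (possibly resized) image size above; e.g. x_min=0.25 on a 768-px-wide
# image lands at pixel 192. Resize(768) rescales so the image's shorter edge
# is 768 px before drawing.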
@spaces.GPU(duration=10)
def point(img, obj):
if img is None:
yield "", gr.update(visible=False, value=None)
return
w, h = img.size
if w > 768 or h > 768:
img = Resize(768)(img)
w, h = img.size
    objs = moondream.point(img, obj, settings={"max_objects": 200})["points"]
    img = img.copy()  # draw on a copy so the caller's image is left untouched
    draw_image = ImageDraw.Draw(img)
for o in objs:
draw_image.ellipse(
(o["x"] * w - 5, o["y"] * h - 5, o["x"] * w + 5, o["y"] * h + 5),
fill="red",
outline="blue",
width=2,
)
yield {"text": f"{len(objs)} detected", "entities": []}, gr.update(
visible=True, value=img
)
@spaces.GPU(duration=10)
def localized_query(img, x, y, question):
if img is None:
yield "", {"text": "", "entities": []}, gr.update(visible=False, value=None)
return
answer = moondream.query(img, question, spatial_refs=[(x, y)])["answer"]
w, h = img.size
x, y = x * w, y * h
img_clone = img.copy()
draw = ImageDraw.Draw(img_clone)
draw.ellipse(
(x - 5, y - 5, x + 5, y + 5),
fill="red",
outline="blue",
)
yield answer, {"text": "", "entities": []}, gr.update(visible=True, value=img_clone)
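# The (x, y) pair comes from the 0-1 sliders in the UI below, matching the
# normalized coordinate convention used elsewhere in this app; it is forwarded
# to the model as a spatial reference for the question, then echoed back to the
# user as a marker dot on a copy of the image.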
js = ""
css = """
.output-text span p {
font-size: 1.4rem !important;
}
.chain-of-thought {
opacity: 0.7 !important;
}
.chain-of-thought span.label {
display: none;
}
.chain-of-thought span.textspan {
padding-right: 0;
}
"""
with gr.Blocks(title="moondream vl (new)", css=css, js=js) as demo:
if IN_SPACES:
# gr.HTML("<style>body, body gradio-app { background: none !important; }</style>")
pass
gr.Markdown(
"""
# 🌔 test space, pls ignore
"""
)
mode_radio = gr.Radio(
["Caption", "Query", "Detect", "Point", "Localized"],
show_label=False,
value=lambda: "Caption",
)
input_image = gr.State(None)
with gr.Row():
with gr.Column():
@gr.render(inputs=[mode_radio])
def show_inputs(mode):
if mode == "Query":
with gr.Group():
with gr.Row():
prompt = gr.Textbox(
label="Input",
value="How many people are in this image?",
scale=4,
)
submit = gr.Button("Submit")
reasoning = gr.Checkbox(label="Enable reasoning")
img = gr.Image(type="pil", label="Upload an Image")
submit.click(answer_question, [img, prompt, reasoning], [output, thought])
prompt.submit(answer_question, [img, prompt, reasoning], [output, thought])
reasoning.change(answer_question, [img, prompt, reasoning], [output, thought])
img.change(answer_question, [img, prompt, reasoning], [output, thought])
img.change(lambda img: img, [img], [input_image])
elif mode == "Caption":
with gr.Group():
with gr.Row():
caption_mode = gr.Radio(
["Short", "Normal", "Long"],
label="Caption Length",
value=lambda: "Normal",
scale=4,
)
submit = gr.Button("Submit")
img = gr.Image(type="pil", label="Upload an Image")
submit.click(caption, [img, caption_mode], output)
img.change(caption, [img, caption_mode], output)
elif mode == "Detect":
with gr.Group():
with gr.Row():
prompt = gr.Textbox(
label="Object",
value="Cat",
scale=4,
)
submit = gr.Button("Submit")
img = gr.Image(type="pil", label="Upload an Image")
submit.click(detect, [img, prompt], [thought, ann])
prompt.submit(detect, [img, prompt], [thought, ann])
img.change(detect, [img, prompt], [thought, ann])
elif mode == "Point":
with gr.Group():
with gr.Row():
prompt = gr.Textbox(
label="Object",
value="Cat",
scale=4,
)
submit = gr.Button("Submit")
img = gr.Image(type="pil", label="Upload an Image")
submit.click(point, [img, prompt], [thought, ann])
prompt.submit(point, [img, prompt], [thought, ann])
img.change(point, [img, prompt], [thought, ann])
elif mode == "Localized":
with gr.Group():
with gr.Row():
prompt = gr.Textbox(
label="Input",
value="What is this?",
scale=4,
)
submit = gr.Button("Submit")
img = gr.Image(type="pil", label="Upload an Image")
x_slider = gr.Slider(label="x", minimum=0, maximum=1)
y_slider = gr.Slider(label="y", minimum=0, maximum=1)
submit.click(localized_query, [img, x_slider, y_slider, prompt], [output, thought, ann])
prompt.submit(localized_query, [img, x_slider, y_slider, prompt], [output, thought, ann])
x_slider.change(localized_query, [img, x_slider, y_slider, prompt], [output, thought, ann])
y_slider.change(localized_query, [img, x_slider, y_slider, prompt], [output, thought, ann])
img.change(localized_query, [img, x_slider, y_slider, prompt], [output, thought, ann])
else:
gr.Markdown("Coming soon!")
with gr.Column():
thought = gr.HighlightedText(
elem_classes=["chain-of-thought"],
label="Thinking tokens",
interactive=False,
)
output = gr.Markdown(label="Response", elem_classes=["output-text"], line_breaks=True)
ann = gr.Image(visible=False)
def on_select(img, evt: gr.SelectData):
if img is None or evt.value[1] is None:
return gr.update(visible=False, value=None)
w, h = img.size
if w > 768 or h > 768:
img = Resize(768)(img)
w, h = img.size
points = json.loads(evt.value[1])
img_clone = img.copy()
draw = ImageDraw.Draw(img_clone)
        for pt in points:
            x = int(pt[0] * w)
            y = int(pt[1] * h)
draw.ellipse(
(x - 3, y - 3, x + 3, y + 3),
fill="red",
outline="red",
)
return gr.update(visible=True, value=img_clone)
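    # As consumed here, a HighlightedText selection carries (span_text, label),
    # where the label is the JSON-encoded point list attached to the entity by
    # answer_question; unlabeled spans (label None) are ignored above.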
thought.select(on_select, [input_image], [ann])
input_image.change(lambda: gr.update(visible=False), [], [ann])
mode_radio.change(
lambda: ("", "", gr.update(visible=False, value=None)),
[],
[output, thought, ann],
)
demo.queue().launch()