REVISION = "af989d6d7d933af3df965d739f1528ef990f8eda" | |
try: | |
import spaces | |
IN_SPACES = True | |
except ImportError: | |
from functools import wraps | |
import inspect | |
class spaces: | |
def GPU(duration): | |
def decorator(func): | |
# Preserves the original function's metadata | |
def wrapper(*args, **kwargs): | |
if inspect.isgeneratorfunction(func): | |
# If the decorated function is a generator, yield from it | |
yield from func(*args, **kwargs) | |
else: | |
# For regular functions, just return the result | |
return func(*args, **kwargs) | |
return wrapper | |
return decorator | |
IN_SPACES = False | |
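# On ZeroGPU hardware the GPU-bound handlers would typically be decorated,
# e.g. `@spaces.GPU(duration=120)` (the duration value here is illustrative);
# the fallback class above makes that decorator a no-op when running locally.
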
import torch
import os
import gradio as gr
import json

from transformers import AutoModelForCausalLM
from PIL import ImageDraw
from torchvision.transforms.v2 import Resize

# os.environ values must be strings, so only forward the token when it is set.
_token = os.environ.get("TOKEN_FROM_SECRET")
if _token:
    os.environ["HF_TOKEN"] = _token

moondream = AutoModelForCausalLM.from_pretrained(
    "vikhyatk/moondream-next",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map={"": "cuda"},
    revision=REVISION,
)
moondream.eval()
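# Loaded once at startup: trust_remote_code pulls the model's own inference
# code from the pinned revision, and bfloat16 with device_map={"": "cuda"}
# keeps the full model on a single GPU.
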
def convert_to_entities(text, coords):
    """
    Converts a string with special markers into an entity representation.

    Markers:
    - <|coord|> pairs indicate coordinate markers
    - <|start_ground_points|> indicates the start of grounding
    - <|start_ground_text|> indicates the start of a ground term
    - <|end_ground|> indicates the end of a ground term

    Returns:
    - Dictionary with cleaned text and entities with their character positions
    """
    cleaned_text = ""
    entities = []
    entity = []

    # Track current position in cleaned text
    current_pos = 0
    # Track if we're currently processing an entity
    in_entity = False
    entity_start = 0

    i = 0
    while i < len(text):
        # Check for markers
        if text[i : i + 9] == "<|coord|>":
            i += 9
            entity.append(coords.pop(0))
            continue
        elif text[i : i + 23] == "<|start_ground_points|>":
            in_entity = True
            entity_start = current_pos
            i += 23
            continue
        elif text[i : i + 21] == "<|start_ground_text|>":
            entity_start = current_pos
            i += 21
            continue
        elif text[i : i + 14] == "<|end_ground|>":
            # Store entity position
            entities.append(
                {
                    "entity": json.dumps(entity),
                    "start": entity_start,
                    "end": current_pos,
                }
            )
            entity = []
            in_entity = False
            i += 14
            continue

        # Add character to cleaned text
        cleaned_text += text[i]
        current_pos += 1
        i += 1

    return {"text": cleaned_text, "entities": entities}
def answer_question(img, prompt, reasoning):
    buffer = ""
    resp = moondream.query(img, prompt, stream=True, reasoning=reasoning)
    reasoning_text = resp["reasoning"]["text"] if reasoning else "[reasoning disabled]"
    entities = (
        [
            {
                "start": g["start_idx"],
                "end": g["end_idx"],
                "entity": json.dumps(g["points"]),
            }
            for g in resp["reasoning"]["grounding"]
        ]
        if reasoning
        else []
    )
    for new_text in resp["answer"]:
        buffer += new_text
        yield buffer.strip(), {"text": reasoning_text, "entities": entities}
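# Each yield streams the partial answer plus the reasoning payload; in the UI
# below these map to the `output` Markdown and `thought` HighlightedText.
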
def caption(img, mode):
    if img is None:
        yield ""
        return
    buffer = ""
    length = {"Short": "short", "Long": "long"}.get(mode, "normal")
    for t in moondream.caption(img, length=length, stream=True)["caption"]:
        buffer += t
        yield buffer.strip()

def detect(img, obj_name):
    if img is None:
        yield "", gr.update(visible=False, value=None)
        return
    w, h = img.size
    if w > 768 or h > 768:
        # Resize(768) scales the *shorter* edge to 768, preserving aspect ratio.
        img = Resize(768)(img)
        w, h = img.size
    objs = moondream.detect(img, obj_name)["objects"]

    # Box coordinates are normalized to [0, 1]; scale them to pixels to draw.
    draw_image = ImageDraw.Draw(img)
    for o in objs:
        draw_image.rectangle(
            (o["x_min"] * w, o["y_min"] * h, o["x_max"] * w, o["y_max"] * h),
            outline="red",
            width=3,
        )

    yield {"text": f"{len(objs)} detected", "entities": []}, gr.update(
        visible=True, value=img
    )

def point(img, obj_name):
    if img is None:
        yield "", gr.update(visible=False, value=None)
        return
    w, h = img.size
    if w > 768 or h > 768:
        img = Resize(768)(img)
        w, h = img.size
    objs = moondream.point(img, obj_name, settings={"max_objects": 200})["points"]

    # Point coordinates are normalized to [0, 1]; scale to pixels and draw dots.
    draw_image = ImageDraw.Draw(img)
    for o in objs:
        draw_image.ellipse(
            (o["x"] * w - 5, o["y"] * h - 5, o["x"] * w + 5, o["y"] * h + 5),
            fill="red",
            outline="blue",
            width=2,
        )

    yield {"text": f"{len(objs)} detected", "entities": []}, gr.update(
        visible=True, value=img
    )

def localized_query(img, x, y, question):
    if img is None:
        yield "", {"text": "", "entities": []}, gr.update(visible=False, value=None)
        return
    answer = moondream.query(img, question, spatial_refs=[(x, y)])["answer"]

    # Mark the queried location on a copy so the original image is untouched.
    w, h = img.size
    px, py = x * w, y * h
    img_clone = img.copy()
    draw = ImageDraw.Draw(img_clone)
    draw.ellipse(
        (px - 5, py - 5, px + 5, py + 5),
        fill="red",
        outline="blue",
    )
    yield answer, {"text": "", "entities": []}, gr.update(visible=True, value=img_clone)
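# spatial_refs takes normalized (x, y) references, matching the 0-1 slider
# range in the UI below, so the question is grounded at that spot in the image.
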
js = "" | |
css = """ | |
.output-text span p { | |
font-size: 1.4rem !important; | |
} | |
.chain-of-thought { | |
opacity: 0.7 !important; | |
} | |
.chain-of-thought span.label { | |
display: none; | |
} | |
.chain-of-thought span.textspan { | |
padding-right: 0; | |
} | |
""" | |
with gr.Blocks(title="moondream vl (new)", css=css, js=js) as demo:
    if IN_SPACES:
        # gr.HTML("<style>body, body gradio-app { background: none !important; }</style>")
        pass

    gr.Markdown(
        """
        # 🌔 test space, pls ignore
        """
    )
    mode_radio = gr.Radio(
        ["Caption", "Query", "Detect", "Point", "Localized"],
        show_label=False,
        value=lambda: "Caption",
    )
    input_image = gr.State(None)

    with gr.Row():
        with gr.Column():
            # Rebuild the input panel whenever the mode changes. The output
            # components referenced in the event wiring (output, thought, ann)
            # are created in the second column below; render functions run
            # after the initial layout, so the forward references resolve.
            @gr.render(inputs=mode_radio)
            def show_inputs(mode):
                if mode == "Query":
                    with gr.Group():
                        with gr.Row():
                            prompt = gr.Textbox(
                                label="Input",
                                value="How many people are in this image?",
                                scale=4,
                            )
                            submit = gr.Button("Submit")
                        reasoning = gr.Checkbox(label="Enable reasoning")
                        img = gr.Image(type="pil", label="Upload an Image")

                    submit.click(answer_question, [img, prompt, reasoning], [output, thought])
                    prompt.submit(answer_question, [img, prompt, reasoning], [output, thought])
                    reasoning.change(answer_question, [img, prompt, reasoning], [output, thought])
                    img.change(answer_question, [img, prompt, reasoning], [output, thought])
                    img.change(lambda img: img, [img], [input_image])
elif mode == "Caption": | |
with gr.Group(): | |
with gr.Row(): | |
caption_mode = gr.Radio( | |
["Short", "Normal", "Long"], | |
label="Caption Length", | |
value=lambda: "Normal", | |
scale=4, | |
) | |
submit = gr.Button("Submit") | |
img = gr.Image(type="pil", label="Upload an Image") | |
submit.click(caption, [img, caption_mode], output) | |
img.change(caption, [img, caption_mode], output) | |
elif mode == "Detect": | |
with gr.Group(): | |
with gr.Row(): | |
prompt = gr.Textbox( | |
label="Object", | |
value="Cat", | |
scale=4, | |
) | |
submit = gr.Button("Submit") | |
img = gr.Image(type="pil", label="Upload an Image") | |
submit.click(detect, [img, prompt], [thought, ann]) | |
prompt.submit(detect, [img, prompt], [thought, ann]) | |
img.change(detect, [img, prompt], [thought, ann]) | |
elif mode == "Point": | |
with gr.Group(): | |
with gr.Row(): | |
prompt = gr.Textbox( | |
label="Object", | |
value="Cat", | |
scale=4, | |
) | |
submit = gr.Button("Submit") | |
img = gr.Image(type="pil", label="Upload an Image") | |
submit.click(point, [img, prompt], [thought, ann]) | |
prompt.submit(point, [img, prompt], [thought, ann]) | |
img.change(point, [img, prompt], [thought, ann]) | |
elif mode == "Localized": | |
with gr.Group(): | |
with gr.Row(): | |
prompt = gr.Textbox( | |
label="Input", | |
value="What is this?", | |
scale=4, | |
) | |
submit = gr.Button("Submit") | |
img = gr.Image(type="pil", label="Upload an Image") | |
x_slider = gr.Slider(label="x", minimum=0, maximum=1) | |
y_slider = gr.Slider(label="y", minimum=0, maximum=1) | |
submit.click(localized_query, [img, x_slider, y_slider, prompt], [output, thought, ann]) | |
prompt.submit(localized_query, [img, x_slider, y_slider, prompt], [output, thought, ann]) | |
x_slider.change(localized_query, [img, x_slider, y_slider, prompt], [output, thought, ann]) | |
y_slider.change(localized_query, [img, x_slider, y_slider, prompt], [output, thought, ann]) | |
img.change(localized_query, [img, x_slider, y_slider, prompt], [output, thought, ann]) | |
else: | |
gr.Markdown("Coming soon!") | |
        with gr.Column():
            thought = gr.HighlightedText(
                elem_classes=["chain-of-thought"],
                label="Thinking tokens",
                interactive=False,
            )
            output = gr.Markdown(label="Response", elem_classes=["output-text"], line_breaks=True)
            ann = gr.Image(visible=False)
    def on_select(img, evt: gr.SelectData):
        # Clicking a grounded span in the thinking tokens redraws its points
        # (stored as JSON in the entity label) on the source image.
        if img is None or evt.value[1] is None:
            return gr.update(visible=False, value=None)
        w, h = img.size
        if w > 768 or h > 768:
            img = Resize(768)(img)
            w, h = img.size
        points = json.loads(evt.value[1])
        img_clone = img.copy()
        draw = ImageDraw.Draw(img_clone)
        for pt in points:
            x = int(pt[0] * w)
            y = int(pt[1] * h)
            draw.ellipse(
                (x - 3, y - 3, x + 3, y + 3),
                fill="red",
                outline="red",
            )
        return gr.update(visible=True, value=img_clone)
    thought.select(on_select, [input_image], [ann])
    input_image.change(lambda: gr.update(visible=False), [], [ann])
    mode_radio.change(
        lambda: ("", "", gr.update(visible=False, value=None)),
        [],
        [output, thought, ann],
    )

demo.queue().launch()