# Pinned model revision for reproducibility.
REVISION = "d643fc6a0410366a34fb89c6877c52a09cb41820"
try:
    import spaces

    IN_SPACES = True
except ImportError:
    from functools import wraps
    import inspect

    class spaces:
        @staticmethod
        def GPU(duration):
            def decorator(func):
                if inspect.isgeneratorfunction(func):
                    # If the decorated function is a generator, yield from it
                    @wraps(func)  # preserve the original function's metadata
                    def wrapper(*args, **kwargs):
                        yield from func(*args, **kwargs)
                else:
                    # For regular functions, just return the result
                    @wraps(func)  # preserve the original function's metadata
                    def wrapper(*args, **kwargs):
                        return func(*args, **kwargs)

                return wrapper

            return decorator

    IN_SPACES = False
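# Example (assumed usage, mirroring the real `spaces` package): with the shim in
# place, a handler can be decorated the same way whether or not the app runs on
# ZeroGPU. The function name below is hypothetical:
#
#   @spaces.GPU(duration=60)
#   def generate(img, prompt):
#       ...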
import torch
import os
import json

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from PIL import ImageDraw
from torchvision.transforms.v2 import Resize
# Environment variable values must be strings; only set HF_TOKEN when the
# secret is actually present.
_token = os.environ.get("TOKEN_FROM_SECRET")
if _token:
    os.environ["HF_TOKEN"] = _token
tokenizer = AutoTokenizer.from_pretrained("vikhyatk/moondream-next", revision=REVISION)
moondream = AutoModelForCausalLM.from_pretrained(
    "vikhyatk/moondream-next",
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map={"": "cuda"},
    revision=REVISION,
)
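# Note: device_map={"": "cuda"} places the entire model on the default CUDA
# device; transformers requires the `accelerate` package for device_map loading.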
# CKPT_DIRS = ["/tmp/md-ckpt/ckpt/ft/song-moon-4c-s15/s72001/"]
# def get_ckpt(filename):
#     ckpts = [
#         torch.load(os.path.join(dir, filename), map_location="cpu") for dir in CKPT_DIRS
#     ]
#     avg_ckpt = {
#         key.replace("._orig_mod", ""): sum(ckpt[key] for ckpt in ckpts) / len(ckpts)
#         for key in ckpts[0]
#     }
#     return avg_ckpt
# moondream.load_state_dict(get_ckpt("model.pt"))
moondream.eval()
def convert_to_entities(text, coords):
    """
    Converts a string with special markers into an entity representation.

    Markers:
    - <|coord|> marks a coordinate; each one consumes the next value from `coords`
    - <|start_ground_points|> marks the start of a grounded span
    - <|start_ground_text|> marks the start of the grounded term's text
    - <|end_ground|> marks the end of the grounded term

    Returns:
    - Dictionary with the cleaned text and entities with their character positions
    """
    cleaned_text = ""
    entities = []
    entity = []

    # Track current position in cleaned text
    current_pos = 0
    # Track whether we're currently inside a grounded span
    in_entity = False
    entity_start = 0

    i = 0
    while i < len(text):
        # Check for markers
        if text[i : i + 9] == "<|coord|>":
            i += 9
            entity.append(coords.pop(0))
            continue
        elif text[i : i + 23] == "<|start_ground_points|>":
            in_entity = True
            entity_start = current_pos
            i += 23
            continue
        elif text[i : i + 21] == "<|start_ground_text|>":
            entity_start = current_pos
            i += 21
            continue
        elif text[i : i + 14] == "<|end_ground|>":
            # Store the entity's coordinates and character span
            entities.append(
                {
                    "entity": json.dumps(entity),
                    "start": entity_start,
                    "end": current_pos,
                }
            )
            entity = []
            in_entity = False
            i += 14
            continue

        # Add character to cleaned text
        cleaned_text += text[i]
        current_pos += 1
        i += 1

    return {"text": cleaned_text, "entities": entities}
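# Illustrative sketch (the marker stream below is hypothetical, inferred from
# the parser above; each <|coord|> consumes one value from the coords list):
#
#   convert_to_entities(
#       "The <|start_ground_points|><|coord|><|coord|>"
#       "<|start_ground_text|>cat<|end_ground|> is asleep.",
#       [0.41, 0.62],
#   )
#   # -> {"text": "The cat is asleep.",
#   #     "entities": [{"entity": "[0.41, 0.62]", "start": 4, "end": 7}]}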
def answer_question(img, prompt):
    if img is None:
        yield "", ""
        return

    buffer = ""
    for new_text in moondream.query(img, prompt, stream=True)["answer"]:
        buffer += new_text
        yield buffer.strip(), {"text": "Thinking...", "entities": []}
def caption(img, mode):
    if img is None:
        yield ""
        return

    buffer = ""
    if mode == "Short":
        length = "short"
    elif mode == "Long":
        length = "long"
    else:
        length = "normal"
    for t in moondream.caption(img, length=length, stream=True)["caption"]:
        buffer += t
        yield buffer.strip()
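# The radio labels map directly onto the lowercase length values that the code
# above passes to moondream.caption; an equivalent one-liner would be:
#   length = {"Short": "short", "Long": "long"}.get(mode, "normal")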
def detect(img, obj_name):
    if img is None:
        yield "", gr.update(visible=False, value=None)
        return

    w, h = img.size
    if w > 768 or h > 768:
        img = Resize(768)(img)
        w, h = img.size

    objs = moondream.detect(img, obj_name)["objects"]

    # Boxes come back with coordinates normalized to [0, 1]; scale to pixels.
    draw = ImageDraw.Draw(img)
    for o in objs:
        draw.rectangle(
            (o["x_min"] * w, o["y_min"] * h, o["x_max"] * w, o["y_max"] * h),
            outline="red",
            width=3,
        )
    yield {"text": f"{len(objs)} detected", "entities": []}, gr.update(
        visible=True, value=img
    )
def point(img, obj_name):
    if img is None:
        yield "", gr.update(visible=False, value=None)
        return

    w, h = img.size
    if w > 768 or h > 768:
        img = Resize(768)(img)
        w, h = img.size

    objs = moondream.point(img, obj_name)["points"]

    # Points also use normalized coordinates; draw a small marker at each one.
    draw = ImageDraw.Draw(img)
    for o in objs:
        draw.ellipse(
            (o["x"] * w - 5, o["y"] * h - 5, o["x"] * w + 5, o["y"] * h + 5),
            fill="red",
            outline="blue",
            width=2,
        )
    yield {"text": f"{len(objs)} detected", "entities": []}, gr.update(
        visible=True, value=img
    )
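# Caveat (torchvision semantics): Resize(768) with a single int scales the
# *shorter* side to 768, so one dimension of a non-square image can remain
# larger than 768 after the resize in detect/point above. The subsequent
# `w, h = img.size` re-read keeps the pixel-scaling math correct either way.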
# js = """
# function createBgAnimation() {
#     // Full-page Game of Life background (currently disabled; see `js = ""` below).
#     var canvas = document.createElement('canvas');
#     canvas.id = 'life-canvas';
#     document.body.appendChild(canvas);
#     var ctx = canvas.getContext('2d');
#
#     function resizeCanvas() {
#         canvas.width = window.innerWidth;
#         canvas.height = window.innerHeight;
#     }
#     resizeCanvas();
#     window.addEventListener('resize', resizeCanvas);
#
#     var cellSize = 8;
#     var cols = Math.ceil(canvas.width / cellSize);
#     var rows = Math.ceil(canvas.height / cellSize);
#
#     // Track cell age for color variation
#     var grid = new Array(cols).fill(null)
#         .map(() => new Array(rows).fill(null)
#             .map(() => Math.random() > 0.8 ? 1 : 0)); // If alive, start with age 1
#
#     function countNeighbors(grid, x, y) {
#         var sum = 0;
#         for (var i = -1; i < 2; i++) {
#             for (var j = -1; j < 2; j++) {
#                 var col = (x + i + cols) % cols;
#                 var row = (y + j + rows) % rows;
#                 sum += grid[col][row] ? 1 : 0;
#             }
#         }
#         sum -= grid[x][y] ? 1 : 0;
#         return sum;
#     }
#
#     function computeNextGeneration() {
#         var next = grid.map(arr => [...arr]);
#         for (var i = 0; i < cols; i++) {
#             for (var j = 0; j < rows; j++) {
#                 var neighbors = countNeighbors(grid, i, j);
#                 var state = grid[i][j];
#                 if (state) {
#                     if (neighbors < 2 || neighbors > 3) {
#                         next[i][j] = 0; // Cell dies
#                     } else {
#                         next[i][j] = Math.min(state + 1, 5); // Age the cell, max age of 5
#                     }
#                 } else if (neighbors === 3) {
#                     next[i][j] = 1; // New cell born
#                 }
#             }
#         }
#         grid = next;
#     }
#
#     function getColor(age, isDarkMode) {
#         // Light mode colors
#         var lightColors = {
#             1: '#dae1f5', // Light blue-grey
#             2: '#d3e0f4',
#             3: '#ccdff3',
#             4: '#c5def2',
#             5: '#beddf1'  // Slightly deeper blue-grey
#         };
#         // Dark mode colors
#         var darkColors = {
#             /*
#             1: '#4a5788', // Deep blue-grey
#             2: '#4c5a8d',
#             3: '#4e5d92',
#             4: '#506097',
#             5: '#52639c'  // Brighter blue-grey
#             */
#             1: 'rgb(16, 20, 32)',
#             2: 'rgb(21, 25, 39)',
#             3: 'rgb(26, 30, 46)',
#             4: 'rgb(31, 35, 53)',
#             5: 'rgb(36, 40, 60)'
#         };
#         return isDarkMode ? darkColors[age] : lightColors[age];
#     }
#
#     function draw() {
#         var isDarkMode = document.body.classList.contains('dark');
#         ctx.fillStyle = isDarkMode ? '#0b0f19' : '#f0f0f0';
#         ctx.fillRect(0, 0, canvas.width, canvas.height);
#         for (var i = 0; i < cols; i++) {
#             for (var j = 0; j < rows; j++) {
#                 if (grid[i][j]) {
#                     ctx.fillStyle = getColor(grid[i][j], isDarkMode);
#                     ctx.fillRect(i * cellSize, j * cellSize, cellSize - 1, cellSize - 1);
#                 }
#             }
#         }
#     }
#
#     var lastFrame = 0;
#     var frameInterval = 300;
#     function animate(timestamp) {
#         if (timestamp - lastFrame >= frameInterval) {
#             draw();
#             computeNextGeneration();
#             lastFrame = timestamp;
#         }
#         requestAnimationFrame(animate);
#     }
#     animate(0);
# }
# """
js = ""

css = """
.output-text span p {
    font-size: 1.4rem !important;
}
.chain-of-thought {
    opacity: 0.7 !important;
}
.chain-of-thought span.label {
    display: none;
}
.chain-of-thought span.textspan {
    padding-right: 0;
}
#life-canvas {
    /* Styles for the disabled Game of Life background:
    position: fixed;
    top: 0;
    left: 0;
    width: 100%;
    height: 100%;
    z-index: -1;
    opacity: 0.3; */
}
"""
with gr.Blocks(title="moondream vl (new)", css=css, js=js) as demo:
    if IN_SPACES:
        # gr.HTML("<style>body, body gradio-app { background: none !important; }</style>")
        pass

    gr.Markdown(
        """
        # 🌔 moondream vl (new)
        A tiny vision language model. [GitHub](https://github.com/vikhyat/moondream)
        """
    )
    mode_radio = gr.Radio(
        ["Caption", "Query", "Detect", "Point"],
        show_label=False,
        value=lambda: "Caption",
    )
    input_image = gr.State(None)
    with gr.Row():
        with gr.Column():
            # gr.render re-runs this function whenever mode_radio changes,
            # rebuilding the input widgets for the selected mode.
            @gr.render(inputs=[mode_radio])
            def show_inputs(mode):
                if mode == "Query":
                    with gr.Group():
                        with gr.Row():
                            prompt = gr.Textbox(
                                label="Input",
                                value="How many people are in this image?",
                                scale=4,
                            )
                            submit = gr.Button("Submit")
                        img = gr.Image(type="pil", label="Upload an Image")
                    submit.click(answer_question, [img, prompt], [output, thought])
                    prompt.submit(answer_question, [img, prompt], [output, thought])
                    img.change(answer_question, [img, prompt], [output, thought])
                    img.change(lambda img: img, [img], [input_image])
                elif mode == "Caption":
                    with gr.Group():
                        with gr.Row():
                            caption_mode = gr.Radio(
                                ["Short", "Normal", "Long"],
                                label="Caption Length",
                                value=lambda: "Normal",
                                scale=4,
                            )
                            submit = gr.Button("Submit")
                        img = gr.Image(type="pil", label="Upload an Image")
                    submit.click(caption, [img, caption_mode], output)
                    img.change(caption, [img, caption_mode], output)
                elif mode == "Detect":
                    with gr.Group():
                        with gr.Row():
                            prompt = gr.Textbox(
                                label="Object",
                                value="Cat",
                                scale=4,
                            )
                            submit = gr.Button("Submit")
                        img = gr.Image(type="pil", label="Upload an Image")
                    submit.click(detect, [img, prompt], [thought, ann])
                    prompt.submit(detect, [img, prompt], [thought, ann])
                    img.change(detect, [img, prompt], [thought, ann])
                elif mode == "Point":
                    with gr.Group():
                        with gr.Row():
                            prompt = gr.Textbox(
                                label="Object",
                                value="Cat",
                                scale=4,
                            )
                            submit = gr.Button("Submit")
                        img = gr.Image(type="pil", label="Upload an Image")
                    submit.click(point, [img, prompt], [thought, ann])
                    prompt.submit(point, [img, prompt], [thought, ann])
                    img.change(point, [img, prompt], [thought, ann])
                else:
                    gr.Markdown("Coming soon!")
        with gr.Column():
            thought = gr.HighlightedText(
                elem_classes=["chain-of-thought"],
                label="Thinking tokens",
                interactive=False,
            )
            output = gr.Markdown(
                label="Response", elem_classes=["output-text"], line_breaks=True
            )
            ann = gr.Image(visible=False)
    def on_select(img, evt: gr.SelectData):
        # evt.value is (selected text, entity label); the label holds the
        # JSON-encoded coordinate list produced by convert_to_entities.
        if img is None or evt.value[1] is None:
            return gr.update(visible=False, value=None)

        w, h = img.size
        if w > 768 or h > 768:
            img = Resize(768)(img)
            w, h = img.size

        coords = json.loads(evt.value[1])
        if len(coords) % 2 != 0:
            raise ValueError("Only points supported right now.")

        img_clone = img.copy()
        draw = ImageDraw.Draw(img_clone)
        for i in range(0, len(coords), 2):  # Step by 2 to handle x,y pairs
            x = int(coords[i] * w)
            y = int(coords[i + 1] * h)
            draw.ellipse(
                (x - 3, y - 3, x + 3, y + 3),
                fill="red",
                outline="red",
            )
        return gr.update(visible=True, value=img_clone)

    thought.select(on_select, [input_image], [ann])
    input_image.change(lambda: gr.update(visible=False), [], [ann])
    mode_radio.change(
        lambda: ("", "", gr.update(visible=False, value=None)),
        [],
        [output, thought, ann],
    )

demo.queue().launch()