REVISION = "2660b786445034596d9e05a58d98cb4652c16487"
try:
    import spaces

    IN_SPACES = True
except ImportError:
    from functools import wraps
    import inspect

    # Minimal stand-in for the `spaces` package so the app also runs outside
    # Hugging Face Spaces; the decorator then just calls through to the function.
    class spaces:
        @staticmethod
        def GPU(duration):
            def decorator(func):
                if inspect.isgeneratorfunction(func):
                    # Generator functions must stay generators when wrapped.
                    @wraps(func)  # Preserves the original function's metadata
                    def wrapper(*args, **kwargs):
                        yield from func(*args, **kwargs)
                else:
                    # Regular functions just return their result. Picking the wrapper
                    # at decoration time keeps them from turning into generators.
                    @wraps(func)
                    def wrapper(*args, **kwargs):
                        return func(*args, **kwargs)

                return wrapper

            return decorator

    IN_SPACES = False
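# Illustrative note (not part of the original file): with this fallback in place,
# functions decorated with `@spaces.GPU(...)` behave the same locally as on Spaces.
# `stream_numbers` below is a made-up example, not something defined in this app:
#
#     @spaces.GPU(duration=5)
#     def stream_numbers():
#         for i in range(3):
#             yield i
#
#     list(stream_numbers())  # -> [0, 1, 2] with or without the `spaces` package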
import torch
import os
import gradio as gr
import json
from queue import Queue
from threading import Thread
from transformers import AutoModelForCausalLM
from PIL import ImageDraw
from torchvision.transforms.v2 import Resize
os.environ["HF_TOKEN"] = os.environ.get("TOKEN_FROM_SECRET") or True
moondream = AutoModelForCausalLM.from_pretrained(
    "vikhyatk/moondream-next",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map={"": "cuda"},
    revision=REVISION,
)
# CKPT_DIRS = ["/tmp/md-ckpt/ckpt/ft/song-moon-4c-s15/s72001/"]
# def get_ckpt(filename):
#     ckpts = [
#         torch.load(os.path.join(dir, filename), map_location="cpu") for dir in CKPT_DIRS
#     ]
#     avg_ckpt = {
#         key.replace("._orig_mod", ""): sum(ckpt[key] for ckpt in ckpts) / len(ckpts)
#         for key in ckpts[0]
#     }
#     return avg_ckpt
# moondream.load_state_dict(get_ckpt("model.pt"))
moondream.eval()
def convert_to_entities(text, coords):
    """
    Converts a string with special markers into an entity representation.

    Markers:
    - <|coord|> pairs indicate coordinate markers
    - <|start_ground_points|> indicates the start of grounding
    - <|start_ground_text|> indicates the start of a ground term
    - <|end_ground|> indicates the end of a ground term

    Returns:
    - Dictionary with cleaned text and entities with their character positions
    """
    # Initialize variables
    cleaned_text = ""
    entities = []
    entity = []

    # Track current position in cleaned text
    current_pos = 0

    # Track if we're currently processing an entity
    in_entity = False
    entity_start = 0

    i = 0
    while i < len(text):
        # Check for markers
        if text[i : i + 9] == "<|coord|>":
            i += 9
            entity.append(coords.pop(0))
            continue
        elif text[i : i + 23] == "<|start_ground_points|>":
            in_entity = True
            entity_start = current_pos
            i += 23
            continue
        elif text[i : i + 21] == "<|start_ground_text|>":
            entity_start = current_pos
            i += 21
            continue
        elif text[i : i + 14] == "<|end_ground|>":
            # Store entity position
            entities.append(
                {
                    "entity": json.dumps(entity),
                    "start": entity_start,
                    "end": current_pos,
                }
            )
            entity = []
            in_entity = False
            i += 14
            continue

        # Add character to cleaned text
        cleaned_text += text[i]
        current_pos += 1
        i += 1

    return {"text": cleaned_text, "entities": entities}
@spaces.GPU(duration=30)
def answer_question(img, prompt):
    if img is None:
        yield "", ""
        return

    buffer = ""
    for new_text in moondream.query(img, prompt, stream=True)["answer"]:
        buffer += new_text
        yield buffer.strip(), {"text": "Thinking...", "entities": []}
@spaces.GPU(duration=10)
def caption(img, mode):
    if img is None:
        yield ""
        return

    buffer = ""
    if mode == "Short":
        l = "short"
    elif mode == "Long":
        l = "long"
    else:
        l = "normal"

    for t in moondream.caption(img, length=l, stream=True)["caption"]:
        buffer += t
        yield buffer.strip()
@spaces.GPU(duration=10)
def detect(img, object):
    if img is None:
        yield "", gr.update(visible=False, value=None)
        return

    w, h = img.size
    if w > 768 or h > 768:
        img = Resize(768)(img)
        w, h = img.size

    objs = moondream.detect(img, object)["objects"]

    draw_image = ImageDraw.Draw(img)
    for o in objs:
        draw_image.rectangle(
            (o["x_min"] * w, o["y_min"] * h, o["x_max"] * w, o["y_max"] * h),
            outline="red",
            width=3,
        )

    yield {"text": f"{len(objs)} detected", "entities": []}, gr.update(
        visible=True, value=img
    )
@spaces.GPU(duration=10)
def point(img, object):
    if img is None:
        yield "", gr.update(visible=False, value=None)
        return

    w, h = img.size
    if w > 768 or h > 768:
        img = Resize(768)(img)
        w, h = img.size

    objs = moondream.point(img, object)["points"]

    draw_image = ImageDraw.Draw(img)
    for o in objs:
        draw_image.ellipse(
            (o["x"] * w - 5, o["y"] * h - 5, o["x"] * w + 5, o["y"] * h + 5),
            fill="red",
            outline="blue",
            width=2,
        )

    yield {"text": f"{len(objs)} detected", "entities": []}, gr.update(
        visible=True, value=img
    )
# js = """
# function createBgAnimation() {
# var canvas = document.createElement('canvas');
# canvas.id = 'life-canvas';
# document.body.appendChild(canvas);
# var canvas = document.getElementById('life-canvas');
# var ctx = canvas.getContext('2d');
# function resizeCanvas() {
# canvas.width = window.innerWidth;
# canvas.height = window.innerHeight;
# }
# resizeCanvas();
# window.addEventListener('resize', resizeCanvas);
# var cellSize = 8;
# var cols = Math.ceil(canvas.width / cellSize);
# var rows = Math.ceil(canvas.height / cellSize);
# // Track cell age for color variation
# var grid = new Array(cols).fill(null)
# .map(() => new Array(rows).fill(null)
# .map(() => Math.random() > 0.8 ? 1 : 0)); // If alive, start with age 1
# function countNeighbors(grid, x, y) {
# var sum = 0;
# for (var i = -1; i < 2; i++) {
# for (var j = -1; j < 2; j++) {
# var col = (x + i + cols) % cols;
# var row = (y + j + rows) % rows;
# sum += grid[col][row] ? 1 : 0;
# }
# }
# sum -= grid[x][y] ? 1 : 0;
# return sum;
# }
# function computeNextGeneration() {
# var next = grid.map(arr => [...arr]);
# for (var i = 0; i < cols; i++) {
# for (var j = 0; j < rows; j++) {
# var neighbors = countNeighbors(grid, i, j);
# var state = grid[i][j];
# if (state) {
# if (neighbors < 2 || neighbors > 3) {
# next[i][j] = 0; // Cell dies
# } else {
# next[i][j] = Math.min(state + 1, 5); // Age the cell, max age of 5
# }
# } else if (neighbors === 3) {
# next[i][j] = 1; // New cell born
# }
# }
# }
# grid = next;
# }
# function getColor(age, isDarkMode) {
# // Light mode colors
# var lightColors = {
# 1: '#dae1f5', // Light blue-grey
# 2: '#d3e0f4',
# 3: '#ccdff3',
# 4: '#c5def2',
# 5: '#beddf1' // Slightly deeper blue-grey
# };
# // Dark mode colors
# var darkColors = {
# /*
# 1: '#4a5788', // Deep blue-grey
# 2: '#4c5a8d',
# 3: '#4e5d92',
# 4: '#506097',
# 5: '#52639c' // Brighter blue-grey
# */
# 1: 'rgb(16, 20, 32)',
# 2: 'rgb(21, 25, 39)',
# 3: 'rgb(26, 30, 46)',
# 4: 'rgb(31, 35, 53)',
# 5: 'rgb(36, 40, 60)'
# };
# return isDarkMode ? darkColors[age] : lightColors[age];
# }
# function draw() {
# var isDarkMode = document.body.classList.contains('dark');
# ctx.fillStyle = isDarkMode ? '#0b0f19' : '#f0f0f0';
# ctx.fillRect(0, 0, canvas.width, canvas.height);
# for (var i = 0; i < cols; i++) {
# for (var j = 0; j < rows; j++) {
# if (grid[i][j]) {
# ctx.fillStyle = getColor(grid[i][j], isDarkMode);
# ctx.fillRect(i * cellSize, j * cellSize, cellSize - 1, cellSize - 1);
# }
# }
# }
# }
# var lastFrame = 0;
# var frameInterval = 300;
# function animate(timestamp) {
# if (timestamp - lastFrame >= frameInterval) {
# draw();
# computeNextGeneration();
# lastFrame = timestamp;
# }
# requestAnimationFrame(animate);
# }
# animate(0);
# }
# """
js = ""
css = """
.output-text span p {
font-size: 1.4rem !important;
}
.chain-of-thought {
opacity: 0.7 !important;
}
.chain-of-thought span.label {
display: none;
}
.chain-of-thought span.textspan {
padding-right: 0;
}
#life-canvas {
/*position: fixed;
top: 0;
left: 0;
width: 100%;
height: 100%;
z-index: -1;
opacity: 0.3;*/
}
"""
with gr.Blocks(title="moondream vl (new)", css=css, js=js) as demo:
    if IN_SPACES:
        # gr.HTML("<style>body, body gradio-app { background: none !important; }</style>")
        pass

    gr.Markdown(
        """
        # 🌔 moondream vl (new)

        A tiny vision language model. [GitHub](https://github.com/vikhyat/moondream)
        """
    )

    mode_radio = gr.Radio(
        ["Caption", "Query", "Detect", "Point"],
        show_label=False,
        value=lambda: "Caption",
    )

    input_image = gr.State(None)
    with gr.Row():
        with gr.Column():

            @gr.render(inputs=[mode_radio])
            def show_inputs(mode):
                if mode == "Query":
                    with gr.Group():
                        with gr.Row():
                            prompt = gr.Textbox(
                                label="Input",
                                value="How many people are in this image?",
                                scale=4,
                            )
                            submit = gr.Button("Submit")
                        img = gr.Image(type="pil", label="Upload an Image")
                    submit.click(answer_question, [img, prompt], [output, thought])
                    prompt.submit(answer_question, [img, prompt], [output, thought])
                    img.change(answer_question, [img, prompt], [output, thought])
                    img.change(lambda img: img, [img], [input_image])
                elif mode == "Caption":
                    with gr.Group():
                        with gr.Row():
                            caption_mode = gr.Radio(
                                ["Short", "Normal", "Long"],
                                label="Caption Length",
                                value=lambda: "Normal",
                                scale=4,
                            )
                            submit = gr.Button("Submit")
                        img = gr.Image(type="pil", label="Upload an Image")
                    submit.click(caption, [img, caption_mode], output)
                    img.change(caption, [img, caption_mode], output)
                elif mode == "Detect":
                    with gr.Group():
                        with gr.Row():
                            prompt = gr.Textbox(
                                label="Object",
                                value="Cat",
                                scale=4,
                            )
                            submit = gr.Button("Submit")
                        img = gr.Image(type="pil", label="Upload an Image")
                    submit.click(detect, [img, prompt], [thought, ann])
                    prompt.submit(detect, [img, prompt], [thought, ann])
                    img.change(detect, [img, prompt], [thought, ann])
                elif mode == "Point":
                    with gr.Group():
                        with gr.Row():
                            prompt = gr.Textbox(
                                label="Object",
                                value="Cat",
                                scale=4,
                            )
                            submit = gr.Button("Submit")
                        img = gr.Image(type="pil", label="Upload an Image")
                    submit.click(point, [img, prompt], [thought, ann])
                    prompt.submit(point, [img, prompt], [thought, ann])
                    img.change(point, [img, prompt], [thought, ann])
                else:
                    gr.Markdown("Coming soon!")
        with gr.Column():
            thought = gr.HighlightedText(
                elem_classes=["chain-of-thought"],
                label="Thinking tokens",
                interactive=False,
            )
            output = gr.Markdown(label="Response", elem_classes=["output-text"], line_breaks=True)
            ann = gr.Image(visible=False)
    def on_select(img, evt: gr.SelectData):
        if img is None or evt.value[1] is None:
            return gr.update(visible=False, value=None)

        w, h = img.size
        if w > 768 or h > 768:
            img = Resize(768)(img)
            w, h = img.size

        coords = json.loads(evt.value[1])
        if len(coords) % 2 != 0:
            raise ValueError("Only points supported right now.")

        img_clone = img.copy()
        draw = ImageDraw.Draw(img_clone)
        for i in range(0, len(coords), 2):  # Step by 2 to handle x,y pairs
            x = int(coords[i] * w)
            y = int(coords[i + 1] * h)
            draw.ellipse(
                (x - 3, y - 3, x + 3, y + 3),
                fill="red",
                outline="red",
            )

        return gr.update(visible=True, value=img_clone)
    thought.select(on_select, [input_image], [ann])
    input_image.change(lambda: gr.update(visible=False), [], [ann])

    mode_radio.change(
        lambda: ("", "", gr.update(visible=False, value=None)),
        [],
        [output, thought, ann],
    )
demo.queue().launch()