REVISION = "d643fc6a0410366a34fb89c6877c52a09cb41820"

try:
    import spaces

    IN_SPACES = True
except ImportError:
    from functools import wraps
    import inspect

    # Minimal stand-in for the Hugging Face `spaces` package so the app can
    # also run outside Spaces: GPU() becomes a no-op decorator.
    class spaces:
        @staticmethod
        def GPU(duration):
            def decorator(func):
                # Choose the wrapper shape at decoration time so plain
                # functions are not accidentally turned into generators by an
                # unreachable `yield from`.
                if inspect.isgeneratorfunction(func):

                    @wraps(func)  # Preserves the original function's metadata
                    def wrapper(*args, **kwargs):
                        # If the decorated function is a generator, yield from it
                        yield from func(*args, **kwargs)

                else:

                    @wraps(func)
                    def wrapper(*args, **kwargs):
                        # For regular functions, just return the result
                        return func(*args, **kwargs)

                return wrapper

            return decorator

    IN_SPACES = False

import torch
import os
import gradio as gr
import json
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)
from PIL import ImageDraw
from torchvision.transforms.v2 import Resize

# os.environ values must be strings, so only forward the token when the
# secret is actually set.
token = os.environ.get("TOKEN_FROM_SECRET")
if token:
    os.environ["HF_TOKEN"] = token

tokenizer = AutoTokenizer.from_pretrained("vikhyatk/moondream-next", revision=REVISION)
moondream = AutoModelForCausalLM.from_pretrained(
    "vikhyatk/moondream-next",
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map={"": "cuda"},
    revision=REVISION,
)
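# Note: trust_remote_code executes the model's custom modeling code from the
# Hub, and device_map={"": "cuda"} with float16 assumes a CUDA device is
# visible when this module is imported.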

# CKPT_DIRS = ["/tmp/md-ckpt/ckpt/ft/song-moon-4c-s15/s72001/"]
#
# def get_ckpt(filename):
#     ckpts = [
#         torch.load(os.path.join(dir, filename), map_location="cpu") for dir in CKPT_DIRS
#     ]
#     avg_ckpt = {
#         key.replace("._orig_mod", ""): sum(ckpt[key] for ckpt in ckpts) / len(ckpts)
#         for key in ckpts[0]
#     }
#     return avg_ckpt
#
# moondream.load_state_dict(get_ckpt("model.pt"))
moondream.eval()

def convert_to_entities(text, coords):
    """
    Converts a string with special markers into an entity representation.

    Markers:
    - <|coord|> consumes one value from `coords` for the current entity
    - <|start_ground_points|> indicates the start of grounding
    - <|start_ground_text|> indicates the start of a ground term
    - <|end_ground|> indicates the end of a ground term

    Returns:
    - Dictionary with cleaned text and entities with their character positions
    """
    cleaned_text = ""
    entities = []
    entity = []

    # Track current position in cleaned text
    current_pos = 0
    # Track if we're currently processing an entity
    in_entity = False
    entity_start = 0

    i = 0
    while i < len(text):
        # Check for markers
        if text[i : i + 9] == "<|coord|>":
            i += 9
            entity.append(coords.pop(0))
            continue
        elif text[i : i + 23] == "<|start_ground_points|>":
            in_entity = True
            entity_start = current_pos
            i += 23
            continue
        elif text[i : i + 21] == "<|start_ground_text|>":
            entity_start = current_pos
            i += 21
            continue
        elif text[i : i + 14] == "<|end_ground|>":
            # Store entity position
            entities.append(
                {
                    "entity": json.dumps(entity),
                    "start": entity_start,
                    "end": current_pos,
                }
            )
            entity = []
            in_entity = False
            i += 14
            continue

        # Add character to cleaned text
        cleaned_text += text[i]
        current_pos += 1
        i += 1

    return {"text": cleaned_text, "entities": entities}
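
# Illustrative example (marker layout assumed from the docstring above):
#   convert_to_entities(
#       "A <|start_ground_points|><|coord|><|coord|>"
#       "<|start_ground_text|>cat<|end_ground|> sits.",
#       [0.5, 0.25],
#   )
#   == {"text": "A cat sits.",
#       "entities": [{"entity": "[0.5, 0.25]", "start": 2, "end": 5}]}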

@spaces.GPU(duration=30)
def answer_question(img, prompt):
    if img is None:
        yield "", ""
        return

    buffer = ""
    for new_text in moondream.query(img, prompt, stream=True)["answer"]:
        buffer += new_text
        yield buffer.strip(), {"text": "Thinking...", "entities": []}

@spaces.GPU(duration=10)
def caption(img, mode):
    if img is None:
        yield ""
        return

    if mode == "Short":
        length = "short"
    elif mode == "Long":
        length = "long"
    else:
        length = "normal"

    buffer = ""
    for t in moondream.caption(img, length=length, stream=True)["caption"]:
        buffer += t
        yield buffer.strip()

@spaces.GPU(duration=10)
def detect(img, object_name):
    if img is None:
        yield "", gr.update(visible=False, value=None)
        return

    w, h = img.size
    if w > 768 or h > 768:
        # Resize(768) rescales the shorter edge to 768 px, preserving aspect ratio.
        img = Resize(768)(img)
        w, h = img.size

    objs = moondream.detect(img, object_name)["objects"]

    draw = ImageDraw.Draw(img)
    for o in objs:
        draw.rectangle(
            (o["x_min"] * w, o["y_min"] * h, o["x_max"] * w, o["y_max"] * h),
            outline="red",
            width=3,
        )

    yield {"text": f"{len(objs)} detected", "entities": []}, gr.update(
        visible=True, value=img
    )
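
# The box coordinates appear to be normalized to [0, 1] (they are scaled by
# the image width/height above), e.g.:
#   {"objects": [{"x_min": 0.12, "y_min": 0.08, "x_max": 0.47, "y_max": 0.63}]}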

@spaces.GPU(duration=10)
def point(img, object_name):
    if img is None:
        yield "", gr.update(visible=False, value=None)
        return

    w, h = img.size
    if w > 768 or h > 768:
        img = Resize(768)(img)
        w, h = img.size

    objs = moondream.point(img, object_name)["points"]

    draw = ImageDraw.Draw(img)
    for o in objs:
        draw.ellipse(
            (o["x"] * w - 5, o["y"] * h - 5, o["x"] * w + 5, o["y"] * h + 5),
            fill="red",
            outline="blue",
            width=2,
        )

    yield {"text": f"{len(objs)} detected", "entities": []}, gr.update(
        visible=True, value=img
    )
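
# As with detect, the point coordinates are normalized, e.g.:
#   {"points": [{"x": 0.52, "y": 0.48}]}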
# js = """
# function createBgAnimation() {
# var canvas = document.createElement('canvas');
# canvas.id = 'life-canvas';
# document.body.appendChild(canvas);
# var canvas = document.getElementById('life-canvas');
# var ctx = canvas.getContext('2d');
# function resizeCanvas() {
# canvas.width = window.innerWidth;
# canvas.height = window.innerHeight;
# }
# resizeCanvas();
# window.addEventListener('resize', resizeCanvas);
# var cellSize = 8;
# var cols = Math.ceil(canvas.width / cellSize);
# var rows = Math.ceil(canvas.height / cellSize);
# // Track cell age for color variation
# var grid = new Array(cols).fill(null)
# .map(() => new Array(rows).fill(null)
# .map(() => Math.random() > 0.8 ? 1 : 0)); // If alive, start with age 1
# function countNeighbors(grid, x, y) {
# var sum = 0;
# for (var i = -1; i < 2; i++) {
# for (var j = -1; j < 2; j++) {
# var col = (x + i + cols) % cols;
# var row = (y + j + rows) % rows;
# sum += grid[col][row] ? 1 : 0;
# }
# }
# sum -= grid[x][y] ? 1 : 0;
# return sum;
# }
# function computeNextGeneration() {
# var next = grid.map(arr => [...arr]);
# for (var i = 0; i < cols; i++) {
# for (var j = 0; j < rows; j++) {
# var neighbors = countNeighbors(grid, i, j);
# var state = grid[i][j];
# if (state) {
# if (neighbors < 2 || neighbors > 3) {
# next[i][j] = 0; // Cell dies
# } else {
# next[i][j] = Math.min(state + 1, 5); // Age the cell, max age of 5
# }
# } else if (neighbors === 3) {
# next[i][j] = 1; // New cell born
# }
# }
# }
# grid = next;
# }
# function getColor(age, isDarkMode) {
# // Light mode colors
# var lightColors = {
# 1: '#dae1f5', // Light blue-grey
# 2: '#d3e0f4',
# 3: '#ccdff3',
# 4: '#c5def2',
# 5: '#beddf1' // Slightly deeper blue-grey
# };
# // Dark mode colors
# var darkColors = {
# /*
# 1: '#4a5788', // Deep blue-grey
# 2: '#4c5a8d',
# 3: '#4e5d92',
# 4: '#506097',
# 5: '#52639c' // Brighter blue-grey
# */
# 1: 'rgb(16, 20, 32)',
# 2: 'rgb(21, 25, 39)',
# 3: 'rgb(26, 30, 46)',
# 4: 'rgb(31, 35, 53)',
# 5: 'rgb(36, 40, 60)'
# };
# return isDarkMode ? darkColors[age] : lightColors[age];
# }
# function draw() {
# var isDarkMode = document.body.classList.contains('dark');
# ctx.fillStyle = isDarkMode ? '#0b0f19' : '#f0f0f0';
# ctx.fillRect(0, 0, canvas.width, canvas.height);
# for (var i = 0; i < cols; i++) {
# for (var j = 0; j < rows; j++) {
# if (grid[i][j]) {
# ctx.fillStyle = getColor(grid[i][j], isDarkMode);
# ctx.fillRect(i * cellSize, j * cellSize, cellSize - 1, cellSize - 1);
# }
# }
# }
# }
# var lastFrame = 0;
# var frameInterval = 300;
# function animate(timestamp) {
# if (timestamp - lastFrame >= frameInterval) {
# draw();
# computeNextGeneration();
# lastFrame = timestamp;
# }
# requestAnimationFrame(animate);
# }
# animate(0);
# }
# """
js = ""
css = """
.output-text span p {
font-size: 1.4rem !important;
}
.chain-of-thought {
opacity: 0.7 !important;
}
.chain-of-thought span.label {
display: none;
}
.chain-of-thought span.textspan {
padding-right: 0;
}
#life-canvas {
/*position: fixed;
top: 0;
left: 0;
width: 100%;
height: 100%;
z-index: -1;
opacity: 0.3;*/
}
"""

with gr.Blocks(title="moondream vl (new)", css=css, js=js) as demo:
    if IN_SPACES:
        # gr.HTML("<style>body, body gradio-app { background: none !important; }</style>")
        pass

    gr.Markdown(
        """
        # 🌔 moondream vl (new)

        A tiny vision language model. [GitHub](https://github.com/vikhyat/moondream)
        """
    )

    mode_radio = gr.Radio(
        ["Caption", "Query", "Detect", "Point"],
        show_label=False,
        value=lambda: "Caption",
    )
    input_image = gr.State(None)

    with gr.Row():
        with gr.Column():

            # output, thought, and ann are defined in the results column
            # below; gr.render runs after the full layout is built, so these
            # names resolve by the time the callback fires.
            @gr.render(inputs=[mode_radio])
            def show_inputs(mode):
                if mode == "Query":
                    with gr.Group():
                        with gr.Row():
                            prompt = gr.Textbox(
                                label="Input",
                                value="How many people are in this image?",
                                scale=4,
                            )
                            submit = gr.Button("Submit")
                        img = gr.Image(type="pil", label="Upload an Image")
                    submit.click(answer_question, [img, prompt], [output, thought])
                    prompt.submit(answer_question, [img, prompt], [output, thought])
                    img.change(answer_question, [img, prompt], [output, thought])
                    img.change(lambda img: img, [img], [input_image])
                elif mode == "Caption":
                    with gr.Group():
                        with gr.Row():
                            caption_mode = gr.Radio(
                                ["Short", "Normal", "Long"],
                                label="Caption Length",
                                value=lambda: "Normal",
                                scale=4,
                            )
                            submit = gr.Button("Submit")
                        img = gr.Image(type="pil", label="Upload an Image")
                    submit.click(caption, [img, caption_mode], output)
                    img.change(caption, [img, caption_mode], output)
                elif mode == "Detect":
                    with gr.Group():
                        with gr.Row():
                            prompt = gr.Textbox(
                                label="Object",
                                value="Cat",
                                scale=4,
                            )
                            submit = gr.Button("Submit")
                        img = gr.Image(type="pil", label="Upload an Image")
                    submit.click(detect, [img, prompt], [thought, ann])
                    prompt.submit(detect, [img, prompt], [thought, ann])
                    img.change(detect, [img, prompt], [thought, ann])
                elif mode == "Point":
                    with gr.Group():
                        with gr.Row():
                            prompt = gr.Textbox(
                                label="Object",
                                value="Cat",
                                scale=4,
                            )
                            submit = gr.Button("Submit")
                        img = gr.Image(type="pil", label="Upload an Image")
                    submit.click(point, [img, prompt], [thought, ann])
                    prompt.submit(point, [img, prompt], [thought, ann])
                    img.change(point, [img, prompt], [thought, ann])
                else:
                    gr.Markdown("Coming soon!")
        with gr.Column():
            thought = gr.HighlightedText(
                elem_classes=["chain-of-thought"],
                label="Thinking tokens",
                interactive=False,
            )
            output = gr.Markdown(
                label="Response", elem_classes=["output-text"], line_breaks=True
            )
            ann = gr.Image(visible=False)

    def on_select(img, evt: gr.SelectData):
        if img is None or evt.value[1] is None:
            return gr.update(visible=False, value=None)

        w, h = img.size
        if w > 768 or h > 768:
            img = Resize(768)(img)
            w, h = img.size

        # The selected entity carries a JSON-encoded flat list of x, y pairs.
        coords = json.loads(evt.value[1])
        if len(coords) % 2 != 0:
            raise ValueError("Only points supported right now.")

        img_clone = img.copy()
        draw = ImageDraw.Draw(img_clone)
        for i in range(0, len(coords), 2):  # Step by 2 to handle x,y pairs
            x = int(coords[i] * w)
            y = int(coords[i + 1] * h)
            draw.ellipse(
                (x - 3, y - 3, x + 3, y + 3),
                fill="red",
                outline="red",
            )
        return gr.update(visible=True, value=img_clone)
    thought.select(on_select, [input_image], [ann])
    input_image.change(lambda: gr.update(visible=False), [], [ann])
    mode_radio.change(
        lambda: ("", "", gr.update(visible=False, value=None)),
        [],
        [output, thought, ann],
    )

demo.queue().launch()
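
# When run directly (python app.py), Gradio serves the app on
# http://127.0.0.1:7860 by default.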