test-space / app.py
vikhyatk's picture
Update app.py
d382dff verified
raw
history blame
15.8 kB
try:
    import spaces

    IN_SPACES = True
except ImportError:
    from functools import wraps
    import inspect

    class spaces:
        """Local stand-in used when the ``spaces`` package cannot be imported.

        Only ``GPU`` is provided; it is a pass-through decorator so the app's
        ``@spaces.GPU(duration=...)`` usages keep working off-platform.
        """

        @staticmethod
        def GPU(duration):
            """Return a decorator that calls the wrapped function unchanged.

            Args:
                duration: Accepted for signature compatibility; ignored here.

            Returns:
                A decorator. Generator functions stay generator functions and
                plain functions stay plain functions.
            """

            def decorator(func):
                # The generator check has to happen here, not inside a single
                # wrapper: any function body containing ``yield`` is always a
                # generator function, so a combined wrapper would turn plain
                # functions into generators and swallow their return value.
                if inspect.isgeneratorfunction(func):

                    @wraps(func)  # Preserves the original function's metadata
                    def generator_wrapper(*args, **kwargs):
                        yield from func(*args, **kwargs)

                    return generator_wrapper

                @wraps(func)  # Preserves the original function's metadata
                def wrapper(*args, **kwargs):
                    return func(*args, **kwargs)

                return wrapper

            return decorator

    IN_SPACES = False
import torch
import os
import gradio as gr
import json
from queue import Queue
from threading import Thread
from transformers import (
TextIteratorStreamer,
AutoTokenizer,
AutoModelForCausalLM,
)
from PIL import ImageDraw
from torchvision.transforms.v2 import Resize
if IN_SPACES:
    import subprocess

    # Install flash-attn at startup when running on Spaces, with
    # FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE set for the pip process
    # (presumably to skip compiling CUDA kernels — confirm against flash-attn).
    # The variable is merged into a copy of the current environment: passing a
    # bare dict as ``env`` replaces the whole environment, which drops PATH,
    # HOME and anything else the shell and pip rely on.
    # The exit status is not checked, so a failed install does not stop startup.
    subprocess.run(
        "pip install flash-attn --no-build-isolation",
        env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        shell=True,
    )
# Hub credential: the TOKEN_FROM_SECRET environment variable when it is set and
# non-empty, otherwise True (which asks from_pretrained to use the locally
# stored login token). It is only passed along when running on Spaces.
auth_token = os.environ.get("TOKEN_FROM_SECRET") or True

# The tokenizer and the model come from the same repository, so both receive
# the same credential; previously only the model load was authenticated.
tokenizer = AutoTokenizer.from_pretrained(
    "vikhyatk/moondream-next",
    token=auth_token if IN_SPACES else None,
)
moondream = AutoModelForCausalLM.from_pretrained(
    "vikhyatk/moondream-next",
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map={"": "cuda"},
    attn_implementation="flash_attention_2",
    token=auth_token if IN_SPACES else None,
)

# Disabled: load averaged fine-tuning checkpoints over the hub weights.
# CKPT_DIRS = ["/tmp/md-ckpt/ckpt/ft/song-moon-4c-s15/s72001/"]
# def get_ckpt(filename):
#     ckpts = [
#         torch.load(os.path.join(dir, filename), map_location="cpu") for dir in CKPT_DIRS
#     ]
#     avg_ckpt = {
#         key.replace("._orig_mod", ""): sum(ckpt[key] for ckpt in ckpts) / len(ckpts)
#         for key in ckpts[0]
#     }
#     return avg_ckpt
# moondream.load_state_dict(get_ckpt("model.pt"))

# Inference only: switch the module to evaluation mode.
moondream.eval()
def convert_to_entities(text, coords):
    """Strip grounding markers from ``text`` and describe the marked spans.

    Markers:
    - ``<|coord|>`` consumes the next value from ``coords`` and attaches it to
      the entity currently being collected
    - ``<|start_ground|>`` records where a grounded span starts
    - ``<|end_ground|>`` closes the span and emits an entity for it

    Args:
        text: Model output containing the markers above.
        coords: Sequence of coordinate values, consumed left to right, one per
            ``<|coord|>`` marker. The sequence is not modified.

    Returns:
        ``{"text": cleaned_text, "entities": [...]}`` where each entity is
        ``{"entity": <JSON list of consumed coords>, "start": int, "end": int}``
        with character offsets into the cleaned text.

    Raises:
        IndexError: If ``text`` has more ``<|coord|>`` markers than ``coords``
            has values.
    """
    coord_marker = "<|coord|>"
    start_marker = "<|start_ground|>"
    end_marker = "<|end_ground|>"

    kept_chars = []  # characters of the cleaned text, joined once at the end
    entities = []
    entity = []  # coords collected since the last <|end_ground|>
    current_pos = 0  # length of the cleaned text so far
    entity_start = 0
    next_coord = 0  # cursor into coords; avoids mutating the caller's list

    i = 0
    while i < len(text):
        if text.startswith(coord_marker, i):
            entity.append(coords[next_coord])
            next_coord += 1
            i += len(coord_marker)
        elif text.startswith(start_marker, i):
            entity_start = current_pos
            i += len(start_marker)
        elif text.startswith(end_marker, i):
            entities.append(
                {
                    "entity": json.dumps(entity),
                    "start": entity_start,
                    "end": current_pos,
                }
            )
            entity = []
            i += len(end_marker)
        else:
            # Ordinary character: keep it in the cleaned text.
            kept_chars.append(text[i])
            current_pos += 1
            i += 1

    return {"text": "".join(kept_chars), "entities": entities}
@spaces.GPU(duration=10)
def answer_question(img, prompt):
    """Stream the model's answer to ``prompt`` about ``img``.

    Yields ``(answer, thought)`` pairs:
    - ``("", "")`` once, when no image is given;
    - while tokens arrive, the accumulated (stripped) text together with a
      placeholder thought ``{"text": "Thinking...", "entities": []}``;
    - finally the ``"answer"`` field of the result taken from the queue and
      its ``"thought"``/``"coords"`` converted by ``convert_to_entities``.

    ``moondream.answer_question`` runs on a background thread and is expected
    to put a dict with ``answer``, ``thought`` and ``coords`` on
    ``result_queue``.
    """
    if img is None:
        yield "", ""
        return

    embeds = moondream.encode_image(img)
    token_stream = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
    results = Queue()

    worker = Thread(
        target=moondream.answer_question,
        kwargs=dict(
            image_embeds=embeds,
            question=prompt,
            tokenizer=tokenizer,
            allow_cot=True,
            result_queue=results,
            streamer=token_stream,
        ),
    )
    worker.start()

    # Re-emit the whole text so far on every chunk so the UI shows it growing.
    so_far = ""
    for chunk in token_stream:
        so_far = so_far + chunk
        yield so_far.strip(), {"text": "Thinking...", "entities": []}

    # Blocks until the worker has published its final result.
    final = results.get()
    thought = convert_to_entities(final["thought"], final["coords"])
    yield final["answer"], thought
@spaces.GPU(duration=10)
def caption(img, mode):
    """Stream a caption for ``img``.

    Yields ``""`` once when no image is given. Otherwise yields the
    accumulated, stripped caption text each time the streamer produces a new
    chunk. ``mode == "Short"`` requests ``length="short"``; any other value
    passes ``length=None``.
    """
    if img is None:
        yield ""
        return

    token_stream = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
    requested_length = "short" if mode == "Short" else None

    # Generation runs on a background thread; this generator drains the streamer.
    worker = Thread(
        target=moondream.caption,
        kwargs=dict(
            images=[img],
            length=requested_length,
            tokenizer=tokenizer,
            streamer=token_stream,
        ),
    )
    worker.start()

    so_far = ""
    for chunk in token_stream:
        so_far = so_far + chunk
        yield so_far.strip()
@spaces.GPU(duration=10)
def detect(img, object):
    """Draw a red box around every ``object`` the model detects in ``img``.

    Yields a single ``(summary, image_update)`` pair: ``("", hidden update)``
    when no image is given, otherwise a ``"<n> detected"`` summary dict and an
    update that shows the annotated image. Box coordinates returned by the
    model are multiplied by the image width/height before drawing.
    """
    if img is None:
        yield "", gr.update(visible=False, value=None)
        return

    width, height = img.size
    if max(width, height) > 768:
        # NOTE(review): per torchvision docs an int size matches the *smaller*
        # edge, so the longer edge may still exceed 768 — confirm intent.
        img = Resize(768)(img)
        width, height = img.size

    found = moondream.detect(img, object, tokenizer)

    # Boxes are drawn directly onto ``img`` (the resized copy if one was made).
    canvas = ImageDraw.Draw(img)
    for box in found:
        left = box["x_min"] * width
        top = box["y_min"] * height
        right = box["x_max"] * width
        bottom = box["y_max"] * height
        canvas.rectangle((left, top, right, bottom), outline="red", width=3)

    summary = {"text": f"{len(found)} detected", "entities": []}
    yield summary, gr.update(visible=True, value=img)
@spaces.GPU(duration=10)
def point(img, object):
    """Mark every ``object`` location the model points at in ``img``.

    Yields a single ``(summary, image_update)`` pair: ``("", hidden update)``
    when no image is given, otherwise a ``"<n> detected"`` summary dict and an
    update that shows the annotated image. Each returned point's ``x``/``y``
    is multiplied by the image width/height and drawn as a red dot with a
    blue outline, 5 px in radius.
    """
    if img is None:
        yield "", gr.update(visible=False, value=None)
        return

    width, height = img.size
    if max(width, height) > 768:
        # NOTE(review): per torchvision docs an int size matches the *smaller*
        # edge, so the longer edge may still exceed 768 — confirm intent.
        img = Resize(768)(img)
        width, height = img.size

    found = moondream.point(img, object, tokenizer)

    # Dots are drawn directly onto ``img`` (the resized copy if one was made).
    canvas = ImageDraw.Draw(img)
    for spot in found:
        cx = spot["x"] * width
        cy = spot["y"] * height
        canvas.ellipse(
            (cx - 5, cy - 5, cx + 5, cy + 5),
            fill="red",
            outline="blue",
            width=2,
        )

    summary = {"text": f"{len(found)} detected", "entities": []}
    yield summary, gr.update(visible=True, value=img)
# JavaScript handed to gr.Blocks(js=...) below. It defines createBgAnimation(),
# which appends a <canvas id="life-canvas"> to the page and animates Conway's
# Game of Life on it: 8px cells on a grid that wraps at the edges, one new
# generation at most every 300 ms, and live cells coloured by age (capped at 5)
# from a light or dark palette depending on whether <body> has the "dark" class.
# NOTE(review): cols/rows are computed once from the initial canvas size, so
# the grid dimensions do not follow later window resizes — confirm acceptable.
js = """
function createBgAnimation() {
var canvas = document.createElement('canvas');
canvas.id = 'life-canvas';
document.body.appendChild(canvas);
var canvas = document.getElementById('life-canvas');
var ctx = canvas.getContext('2d');
function resizeCanvas() {
canvas.width = window.innerWidth;
canvas.height = window.innerHeight;
}
resizeCanvas();
window.addEventListener('resize', resizeCanvas);
var cellSize = 8;
var cols = Math.ceil(canvas.width / cellSize);
var rows = Math.ceil(canvas.height / cellSize);
// Track cell age for color variation
var grid = new Array(cols).fill(null)
.map(() => new Array(rows).fill(null)
.map(() => Math.random() > 0.8 ? 1 : 0)); // If alive, start with age 1
function countNeighbors(grid, x, y) {
var sum = 0;
for (var i = -1; i < 2; i++) {
for (var j = -1; j < 2; j++) {
var col = (x + i + cols) % cols;
var row = (y + j + rows) % rows;
sum += grid[col][row] ? 1 : 0;
}
}
sum -= grid[x][y] ? 1 : 0;
return sum;
}
function computeNextGeneration() {
var next = grid.map(arr => [...arr]);
for (var i = 0; i < cols; i++) {
for (var j = 0; j < rows; j++) {
var neighbors = countNeighbors(grid, i, j);
var state = grid[i][j];
if (state) {
if (neighbors < 2 || neighbors > 3) {
next[i][j] = 0; // Cell dies
} else {
next[i][j] = Math.min(state + 1, 5); // Age the cell, max age of 5
}
} else if (neighbors === 3) {
next[i][j] = 1; // New cell born
}
}
}
grid = next;
}
function getColor(age, isDarkMode) {
// Light mode colors
var lightColors = {
1: '#dae1f5', // Light blue-grey
2: '#d3e0f4',
3: '#ccdff3',
4: '#c5def2',
5: '#beddf1' // Slightly deeper blue-grey
};
// Dark mode colors
var darkColors = {
/*
1: '#4a5788', // Deep blue-grey
2: '#4c5a8d',
3: '#4e5d92',
4: '#506097',
5: '#52639c' // Brighter blue-grey
*/
1: 'rgb(16, 20, 32)',
2: 'rgb(21, 25, 39)',
3: 'rgb(26, 30, 46)',
4: 'rgb(31, 35, 53)',
5: 'rgb(36, 40, 60)'
};
return isDarkMode ? darkColors[age] : lightColors[age];
}
function draw() {
var isDarkMode = document.body.classList.contains('dark');
ctx.fillStyle = isDarkMode ? '#0b0f19' : '#f0f0f0';
ctx.fillRect(0, 0, canvas.width, canvas.height);
for (var i = 0; i < cols; i++) {
for (var j = 0; j < rows; j++) {
if (grid[i][j]) {
ctx.fillStyle = getColor(grid[i][j], isDarkMode);
ctx.fillRect(i * cellSize, j * cellSize, cellSize - 1, cellSize - 1);
}
}
}
}
var lastFrame = 0;
var frameInterval = 300;
function animate(timestamp) {
if (timestamp - lastFrame >= frameInterval) {
draw();
computeNextGeneration();
lastFrame = timestamp;
}
requestAnimationFrame(animate);
}
animate(0);
}
"""
# Stylesheet handed to gr.Blocks(css=...) below: enlarges the response text
# (.output-text), dims the chain-of-thought panel and hides its entity labels,
# and pins #life-canvas behind the page as a faint full-window background.
css = """
.output-text span p {
font-size: 1.4rem !important;
}
.chain-of-thought {
opacity: 0.7 !important;
}
.chain-of-thought span.label {
display: none;
}
.chain-of-thought span.textspan {
padding-right: 0;
}
#life-canvas {
position: fixed;
top: 0;
left: 0;
width: 100%;
height: 100%;
z-index: -1;
opacity: 0.3;
}
"""
# NOTE(review): the nesting of the layout containers below (which widgets sit
# inside which Row/Group/Column) is the conventional reading of this code —
# confirm against the rendered page.
with gr.Blocks(title="moondream vl (new)", css=css, js=js) as demo:
    if IN_SPACES:
        gr.HTML("<style>body, body gradio-app { background: none !important; }</style>")

    gr.Markdown(
        """
# 🌔 moondream vl (new)
A tiny vision language model. [GitHub](https://github.com/vikhyat/moondream)
"""
    )

    mode_radio = gr.Radio(
        ["Caption", "Query", "Detect", "Point"],
        show_label=False,
        value=lambda: "Caption",
    )
    # Last image uploaded in Query mode; read by on_select to draw points.
    input_image = gr.State(None)

    with gr.Row():
        with gr.Column():

            @gr.render(inputs=[mode_radio])
            def show_inputs(mode):
                """Build the input widgets and event wiring for ``mode``."""
                if mode == "Caption":
                    with gr.Group():
                        with gr.Row():
                            caption_mode = gr.Radio(
                                ["Short", "Normal"],
                                label="Caption Length",
                                value=lambda: "Normal",
                                scale=4,
                            )
                            submit = gr.Button("Submit")
                        img = gr.Image(type="pil", label="Upload an Image")
                        for trigger in (submit.click, img.change):
                            trigger(caption, [img, caption_mode], output)
                    return

                # Modes that share the "text prompt + submit + image" layout:
                # mode -> (textbox label, default text, handler, output widgets)
                prompt_modes = {
                    "Query": (
                        "Input",
                        "How many people are in this image?",
                        answer_question,
                        (output, thought),
                    ),
                    "Detect": ("Object", "Cat", detect, (thought, ann)),
                    "Point": ("Object", "Cat", point, (thought, ann)),
                }
                if mode not in prompt_modes:
                    gr.Markdown("Coming soon!")
                    return

                box_label, default_text, handler, targets = prompt_modes[mode]
                with gr.Group():
                    with gr.Row():
                        prompt = gr.Textbox(
                            label=box_label,
                            value=default_text,
                            scale=4,
                        )
                        submit = gr.Button("Submit")
                    img = gr.Image(type="pil", label="Upload an Image")
                    # Run the handler on button click, Enter in the textbox,
                    # and whenever the image changes.
                    for trigger in (submit.click, prompt.submit, img.change):
                        trigger(handler, [img, prompt], list(targets))
                    if mode == "Query":
                        # Remember the image so thought spans can be visualised.
                        img.change(lambda img: img, [img], [input_image])

        with gr.Column():
            thought = gr.HighlightedText(
                elem_classes=["chain-of-thought"],
                label="Thinking tokens",
                interactive=False,
            )
            output = gr.Markdown(label="Response", elem_classes=["output-text"])
            ann = gr.Image(visible=False)

    def on_select(img, evt: gr.SelectData):
        """Show ``img`` with the point attached to the selected thought span.

        ``evt.value[1]`` is the span's entity payload: a JSON list that must
        hold exactly two numbers, which are multiplied by the image width and
        height. Returns an update hiding the annotation image when there is no
        image or no payload.

        Raises:
            ValueError: If the payload does not contain exactly two values.
        """
        if img is None:
            return gr.update(visible=False, value=None)
        payload = evt.value[1]
        if payload is None:
            return gr.update(visible=False, value=None)

        if max(img.size) > 768:
            img = Resize(768)(img)
        width, height = img.size

        xy = json.loads(payload)
        if len(xy) != 2:
            raise ValueError("Only points supported right now.")
        px = int(xy[0] * width)
        py = int(xy[1] * height)

        # Draw on a copy so the stored input image stays clean.
        marked = img.copy()
        ImageDraw.Draw(marked).ellipse(
            (px - 3, py - 3, px + 3, py + 3),
            fill="red",
            outline="red",
        )
        return gr.update(visible=True, value=marked)

    thought.select(on_select, [input_image], [ann])
    # A new input image invalidates any annotation currently shown.
    input_image.change(lambda: gr.update(visible=False), [], [ann])
    # Switching mode clears the response, the thought panel and the annotation.
    mode_radio.change(
        lambda: ("", "", gr.update(visible=False, value=None)),
        [],
        [output, thought, ann],
    )

demo.queue().launch(share=True)