try:
    import spaces

    IN_SPACES = True
except ImportError:
    from functools import wraps
    import inspect

    class spaces:
        @staticmethod
        def GPU(duration):
            def decorator(func):
                # Pick the wrapper shape at decoration time: a single wrapper
                # containing `yield from` would itself be a generator function
                # and silently swallow the return value of plain functions.
                if inspect.isgeneratorfunction(func):
                    @wraps(func)  # Preserves the original function's metadata
                    def wrapper(*args, **kwargs):
                        yield from func(*args, **kwargs)
                else:
                    @wraps(func)
                    def wrapper(*args, **kwargs):
                        return func(*args, **kwargs)
                return wrapper
            return decorator

    IN_SPACES = False
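
# With the stub above in place, the same file runs unchanged outside Spaces:
# the decorator becomes a pass-through and no GPU scheduling is requested.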
import torch
from queue import Queue
import os
import gradio as gr
from threading import Thread
from transformers import (
    TextIteratorStreamer,
    AutoTokenizer,
    AutoModelForCausalLM,
)
from PIL import ImageDraw
from torchvision.transforms.v2 import Resize
if IN_SPACES:
    import subprocess

    subprocess.run(
        "pip install flash-attn --no-build-isolation",
        env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        shell=True,
    )
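
# If the TOKEN_FROM_SECRET Space secret is unset, `or True` makes transformers
# fall back to the locally cached Hugging Face credentials for the gated model.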
auth_token = os.environ.get("TOKEN_FROM_SECRET") or True
tokenizer = AutoTokenizer.from_pretrained("vikhyatk/moondream2")
moondream = AutoModelForCausalLM.from_pretrained(
    "vikhyatk/moondream-next",
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map={"": "cuda"},
    attn_implementation="flash_attention_2",
    token=auth_token if IN_SPACES else None,
)
moondream.eval()
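
# Each handler below follows the same streaming pattern: generation runs in a
# background thread while TextIteratorStreamer yields decoded tokens on this
# thread, so partial output can be pushed to the UI as it is produced.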
@spaces.GPU(duration=10)
def answer_question(img, prompt):
    if img is None:
        yield "", ""
        return

    image_embeds = moondream.encode_image(img)
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
    queue = Queue()
    thread = Thread(
        target=moondream.answer_question,
        kwargs={
            "image_embeds": image_embeds,
            "question": prompt,
            "tokenizer": tokenizer,
            "allow_cot": True,
            "result_queue": queue,
            "streamer": streamer,
        },
    )
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer.strip(), "Thinking..."

    answer = queue.get()
    yield answer["answer"], answer["thought"]
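
# Minimal local sanity check (hypothetical image path, run without Gradio):
#   from PIL import Image
#   for partial, thought in answer_question(Image.open("test.jpg"), "Describe this image."):
#       print(thought, "|", partial)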
@spaces.GPU(duration=10)
def caption(img, mode):
    if img is None:
        yield ""
        return

    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
    thread = Thread(
        target=moondream.caption,
        kwargs={
            "images": [img],
            "length": "short" if mode == "Short" else None,
            "tokenizer": tokenizer,
            "streamer": streamer,
        },
    )
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer.strip()
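
# Note the checkpoint's custom `caption` method (loaded via trust_remote_code)
# takes a list of images; only a single image is passed here.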
@spaces.GPU(duration=10)
def detect(img, object):
    if img is None:
        yield "", gr.update(visible=False, value=None)
        return

    w, h = img.size
    if w > 768 or h > 768:
        # Resize(768) scales the shorter side to 768 px; refresh w/h so the
        # normalized box coordinates are scaled to the image actually drawn on.
        img = Resize(768)(img)
        w, h = img.size

    objs = moondream.detect(img, object, tokenizer)
    draw_image = ImageDraw.Draw(img)
    for o in objs:
        draw_image.rectangle(
            (o["x_min"] * w, o["y_min"] * h, o["x_max"] * w, o["y_max"] * h),
            outline="red",
            width=3,
        )

    yield f"{len(objs)} detected", gr.update(visible=True, value=img)
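
# The annotated image component (`ann`, defined in the layout below) starts
# hidden and is revealed only after a detection pass via gr.update.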
js = """
function createBgAnimation() {
var canvas = document.createElement('canvas');
canvas.id = 'life-canvas';
document.body.appendChild(canvas);
var canvas = document.getElementById('life-canvas');
var ctx = canvas.getContext('2d');
function resizeCanvas() {
canvas.width = window.innerWidth;
canvas.height = window.innerHeight;
}
resizeCanvas();
window.addEventListener('resize', resizeCanvas);
var cellSize = 8;
var cols = Math.ceil(canvas.width / cellSize);
var rows = Math.ceil(canvas.height / cellSize);
// Track cell age for color variation
var grid = new Array(cols).fill(null)
.map(() => new Array(rows).fill(null)
.map(() => Math.random() > 0.8 ? 1 : 0)); // If alive, start with age 1
function countNeighbors(grid, x, y) {
var sum = 0;
for (var i = -1; i < 2; i++) {
for (var j = -1; j < 2; j++) {
var col = (x + i + cols) % cols;
var row = (y + j + rows) % rows;
sum += grid[col][row] ? 1 : 0;
}
}
sum -= grid[x][y] ? 1 : 0;
return sum;
}
function computeNextGeneration() {
var next = grid.map(arr => [...arr]);
for (var i = 0; i < cols; i++) {
for (var j = 0; j < rows; j++) {
var neighbors = countNeighbors(grid, i, j);
var state = grid[i][j];
if (state) {
if (neighbors < 2 || neighbors > 3) {
next[i][j] = 0; // Cell dies
} else {
next[i][j] = Math.min(state + 1, 5); // Age the cell, max age of 5
}
} else if (neighbors === 3) {
next[i][j] = 1; // New cell born
}
}
}
grid = next;
}
function getColor(age, isDarkMode) {
// Light mode colors
var lightColors = {
1: '#dae1f5', // Light blue-grey
2: '#d3e0f4',
3: '#ccdff3',
4: '#c5def2',
5: '#beddf1' // Slightly deeper blue-grey
};
// Dark mode colors
var darkColors = {
/*
1: '#4a5788', // Deep blue-grey
2: '#4c5a8d',
3: '#4e5d92',
4: '#506097',
5: '#52639c' // Brighter blue-grey
*/
1: 'rgb(16, 20, 32)',
2: 'rgb(21, 25, 39)',
3: 'rgb(26, 30, 46)',
4: 'rgb(31, 35, 53)',
5: 'rgb(36, 40, 60)'
};
return isDarkMode ? darkColors[age] : lightColors[age];
}
function draw() {
var isDarkMode = document.body.classList.contains('dark');
ctx.fillStyle = isDarkMode ? '#0b0f19' : '#f0f0f0';
ctx.fillRect(0, 0, canvas.width, canvas.height);
for (var i = 0; i < cols; i++) {
for (var j = 0; j < rows; j++) {
if (grid[i][j]) {
ctx.fillStyle = getColor(grid[i][j], isDarkMode);
ctx.fillRect(i * cellSize, j * cellSize, cellSize - 1, cellSize - 1);
}
}
}
}
var lastFrame = 0;
var frameInterval = 300;
function animate(timestamp) {
if (timestamp - lastFrame >= frameInterval) {
draw();
computeNextGeneration();
lastFrame = timestamp;
}
requestAnimationFrame(animate);
}
animate(0);
}
"""
css = """
.output-text span p {
font-size: 1.4rem !important;
}
.chain-of-thought {
opacity: 0.7 !important;
}
#life-canvas {
position: fixed;
top: 0;
left: 0;
width: 100%;
height: 100%;
z-index: -1;
opacity: 0.3;
}
"""
with gr.Blocks(title="moondream vl (new)", css=css, js=js) as demo:
    if IN_SPACES:
        gr.HTML("<style>body, body gradio-app { background: none !important; }</style>")

    gr.Markdown(
        """
        # 🌔 moondream vl (new)
        A tiny vision language model. [GitHub](https://github.com/vikhyat/moondream)
        """
    )
    mode_radio = gr.Radio(
        ["Caption", "Query", "Detect"],
        show_label=False,
        value=lambda: "Caption",
    )

    with gr.Row():
        with gr.Column():
            # `output`, `thought`, and `ann` are defined in the second column
            # below; the references here resolve at event time, after the full
            # Blocks layout has been built.
            @gr.render(inputs=[mode_radio])
            def show_inputs(mode):
                if mode == "Query":
                    with gr.Group():
                        with gr.Row():
                            prompt = gr.Textbox(
                                label="Input",
                                value="How many people are in this image?",
                                scale=4,
                            )
                            submit = gr.Button("Submit")
                        img = gr.Image(type="pil", label="Upload an Image")
                    submit.click(answer_question, [img, prompt], [output, thought])
                    prompt.submit(answer_question, [img, prompt], [output, thought])
                    img.change(answer_question, [img, prompt], [output, thought])
                elif mode == "Caption":
                    with gr.Group():
                        with gr.Row():
                            caption_mode = gr.Radio(
                                ["Short", "Normal"],
                                label="Caption Length",
                                value=lambda: "Normal",
                                scale=4,
                            )
                            submit = gr.Button("Submit")
                        img = gr.Image(type="pil", label="Upload an Image")
                    submit.click(caption, [img, caption_mode], output)
                    img.change(caption, [img, caption_mode], output)
                elif mode == "Detect":
                    with gr.Group():
                        with gr.Row():
                            prompt = gr.Textbox(
                                label="Object",
                                value="Cat",
                                scale=4,
                            )
                            submit = gr.Button("Submit")
                        img = gr.Image(type="pil", label="Upload an Image")
                    submit.click(detect, [img, prompt], [thought, ann])
                    prompt.submit(detect, [img, prompt], [thought, ann])
                    img.change(detect, [img, prompt], [thought, ann])
                else:
                    gr.Markdown("Coming soon!")

        with gr.Column():
            thought = gr.Markdown(elem_classes=["chain-of-thought"], line_breaks=True)
            output = gr.Markdown(label="Response", elem_classes=["output-text"])
            ann = gr.Image(visible=False)

    # Clear all outputs whenever the mode changes.
    mode_radio.change(
        lambda: ("", "", gr.update(visible=False, value=None)),
        [],
        [output, thought, ann],
    )

demo.queue().launch()
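# queue() enables the generator handlers above to stream token-by-token; when
# running locally, launch(share=True) would expose a temporary public URL.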