try:
    import spaces

    IN_SPACES = True
except ImportError:
    from functools import wraps
    import inspect

    # Minimal stand-in for the `spaces` package when running outside
    # Hugging Face Spaces (e.g. locally).
    class spaces:
        @staticmethod
        def GPU(duration):
            def decorator(func):
                if inspect.isgeneratorfunction(func):
                    # If the decorated function is a generator, yield from it
                    @wraps(func)  # Preserves the original function's metadata
                    def wrapper(*args, **kwargs):
                        yield from func(*args, **kwargs)

                else:
                    # For regular functions, just return the result
                    @wraps(func)
                    def wrapper(*args, **kwargs):
                        return func(*args, **kwargs)

                return wrapper

            return decorator

    IN_SPACES = False

import torch
import os
import gradio as gr
import json
from queue import Queue
from threading import Thread
from transformers import (
    TextIteratorStreamer,
    AutoTokenizer,
    AutoModelForCausalLM,
)
from PIL import ImageDraw
from torchvision.transforms.v2 import Resize

if IN_SPACES:
    import subprocess

    subprocess.run(
        "pip install flash-attn --no-build-isolation",
        env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        shell=True,
    )

auth_token = os.environ.get("TOKEN_FROM_SECRET") or True

tokenizer = AutoTokenizer.from_pretrained("vikhyatk/moondream-next")
moondream = AutoModelForCausalLM.from_pretrained(
    "vikhyatk/moondream-next",
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map={"": "cuda"},
    attn_implementation="flash_attention_2",
    token=auth_token if IN_SPACES else None,
)

# CKPT_DIRS = ["/tmp/md-ckpt/ckpt/ft/song-moon-4c-s15/s72001/"]

# def get_ckpt(filename):
#     ckpts = [
#         torch.load(os.path.join(dir, filename), map_location="cpu") for dir in CKPT_DIRS
#     ]
#     avg_ckpt = {
#         key.replace("._orig_mod", ""): sum(ckpt[key] for ckpt in ckpts) / len(ckpts)
#         for key in ckpts[0]
#     }
#     return avg_ckpt

# moondream.load_state_dict(get_ckpt("model.pt"))
moondream.eval()


def convert_to_entities(text, coords):
    """
    Converts a string with special markers into an entity representation.
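
    Example (illustrative only; assumes the model emits two <|coord|> markers
    per point and that `coords` is a flat list of floats):

        convert_to_entities(
            "A <|start_ground|>cat<|coord|><|coord|><|end_ground|> sat.", [0.5, 0.5]
        )
        # -> {"text": "A cat sat.",
        #     "entities": [{"entity": "[0.5, 0.5]", "start": 2, "end": 5}]}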

    Markers:
    - <|coord|> pairs indicate coordinate markers
    - <|start_ground|> indicates the start of a ground term
    - <|end_ground|> indicates the end of a ground term

    Returns:
    - Dictionary with cleaned text and entities with their character positions
    """
    # Initialize variables
    cleaned_text = ""
    entities = []
    entity = []

    # Track current position in cleaned text
    current_pos = 0
    # Track if we're currently processing an entity
    in_entity = False
    entity_start = 0

    i = 0
    while i < len(text):
        # Check for markers
        if text[i : i + 9] == "<|coord|>":
            i += 9
            entity.append(coords.pop(0))
            continue
        elif text[i : i + 16] == "<|start_ground|>":
            in_entity = True
            entity_start = current_pos
            i += 16
            continue
        elif text[i : i + 14] == "<|end_ground|>":
            # Store entity position
            entities.append(
                {
                    "entity": json.dumps(entity),
                    "start": entity_start,
                    "end": current_pos,
                }
            )
            entity = []
            in_entity = False
            i += 14
            continue

        # Add character to cleaned text
        cleaned_text += text[i]
        current_pos += 1
        i += 1

    return {"text": cleaned_text, "entities": entities}


@spaces.GPU(duration=10)
def answer_question(img, prompt):
    if img is None:
        yield "", ""
        return

    image_embeds = moondream.encode_image(img)

    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
    queue = Queue()
    thread = Thread(
        target=moondream.answer_question,
        kwargs={
            "image_embeds": image_embeds,
            "question": prompt,
            "tokenizer": tokenizer,
            "allow_cot": True,
            "result_queue": queue,
            "streamer": streamer,
        },
    )
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer.strip(), {"text": "Thinking...", "entities": []}

    answer = queue.get()
    thought = convert_to_entities(answer["thought"], answer["coords"])
    yield answer["answer"], thought


@spaces.GPU(duration=10)
def caption(img, mode):
    if img is None:
        yield ""
        return

    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
    thread = Thread(
        target=moondream.caption,
        kwargs={
            "images": [img],
            "length": "short" if mode == "Short" else None,
            "tokenizer": tokenizer,
            "streamer": streamer,
        },
    )
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer.strip()


@spaces.GPU(duration=10)
def detect(img, object):
    if img is None:
        yield "", gr.update(visible=False, value=None)
        return

    w, h = img.size
    if w > 768 or h > 768:
        img = Resize(768)(img)
        w, h = img.size

    objs = moondream.detect(img, object, tokenizer)

    draw_image = ImageDraw.Draw(img)
    for o in objs:
        draw_image.rectangle(
            (o["x_min"] * w, o["y_min"] * h, o["x_max"] * w, o["y_max"] * h),
            outline="red",
            width=3,
        )

    yield {"text": f"{len(objs)} detected", "entities": []}, gr.update(
        visible=True, value=img
    )


@spaces.GPU(duration=10)
def point(img, object):
    if img is None:
        yield "", gr.update(visible=False, value=None)
        return

    w, h = img.size
    if w > 768 or h > 768:
        img = Resize(768)(img)
        w, h = img.size

    objs = moondream.point(img, object, tokenizer)

    draw_image = ImageDraw.Draw(img)
    for o in objs:
        draw_image.ellipse(
            (o["x"] * w - 5, o["y"] * h - 5, o["x"] * w + 5, o["y"] * h + 5),
            fill="red",
            outline="blue",
            width=2,
        )

    yield {"text": f"{len(objs)} detected", "entities": []}, gr.update(
        visible=True, value=img
    )


js = """
function createBgAnimation() {
    var canvas = document.createElement('canvas');
    canvas.id = 'life-canvas';
    document.body.appendChild(canvas);

    var canvas = document.getElementById('life-canvas');
    var ctx = canvas.getContext('2d');

    function resizeCanvas() {
        canvas.width = window.innerWidth;
        canvas.height = window.innerHeight;
    }
    resizeCanvas();
    window.addEventListener('resize', resizeCanvas);

    var cellSize = 8;
    var cols = Math.ceil(canvas.width / cellSize);
    var rows = Math.ceil(canvas.height / cellSize);

    // Track cell age for color variation
    var grid = new Array(cols).fill(null)
        .map(() => new Array(rows).fill(null)
            .map(() => Math.random() > 0.8 ? 1 : 0));  // If alive, start with age 1

    function countNeighbors(grid, x, y) {
        var sum = 0;
        for (var i = -1; i < 2; i++) {
            for (var j = -1; j < 2; j++) {
                var col = (x + i + cols) % cols;
                var row = (y + j + rows) % rows;
                sum += grid[col][row] ? 1 : 0;
            }
        }
        sum -= grid[x][y] ? 1 : 0;
        return sum;
    }

    function computeNextGeneration() {
        var next = grid.map(arr => [...arr]);
        for (var i = 0; i < cols; i++) {
            for (var j = 0; j < rows; j++) {
                var neighbors = countNeighbors(grid, i, j);
                var state = grid[i][j];
                if (state) {
                    if (neighbors < 2 || neighbors > 3) {
                        next[i][j] = 0;  // Cell dies
                    } else {
                        next[i][j] = Math.min(state + 1, 5);  // Age the cell, max age of 5
                    }
                } else if (neighbors === 3) {
                    next[i][j] = 1;  // New cell born
                }
            }
        }
        grid = next;
    }

    function getColor(age, isDarkMode) {
        // Light mode colors
        var lightColors = {
            1: '#dae1f5',  // Light blue-grey
            2: '#d3e0f4',
            3: '#ccdff3',
            4: '#c5def2',
            5: '#beddf1'   // Slightly deeper blue-grey
        };

        // Dark mode colors
        var darkColors = {
            /*
            1: '#4a5788',  // Deep blue-grey
            2: '#4c5a8d',
            3: '#4e5d92',
            4: '#506097',
            5: '#52639c'   // Brighter blue-grey
            */
            1: 'rgb(16, 20, 32)',
            2: 'rgb(21, 25, 39)',
            3: 'rgb(26, 30, 46)',
            4: 'rgb(31, 35, 53)',
            5: 'rgb(36, 40, 60)'
        };

        return isDarkMode ? darkColors[age] : lightColors[age];
    }

    function draw() {
        var isDarkMode = document.body.classList.contains('dark');
        ctx.fillStyle = isDarkMode ? '#0b0f19' : '#f0f0f0';
        ctx.fillRect(0, 0, canvas.width, canvas.height);

        for (var i = 0; i < cols; i++) {
            for (var j = 0; j < rows; j++) {
                if (grid[i][j]) {
                    ctx.fillStyle = getColor(grid[i][j], isDarkMode);
                    ctx.fillRect(i * cellSize, j * cellSize, cellSize - 1, cellSize - 1);
                }
            }
        }
    }

    var lastFrame = 0;
    var frameInterval = 300;

    function animate(timestamp) {
        if (timestamp - lastFrame >= frameInterval) {
            draw();
            computeNextGeneration();
            lastFrame = timestamp;
        }
        requestAnimationFrame(animate);
    }
    animate(0);
}
"""

css = """
.output-text span p {
    font-size: 1.4rem !important;
}
.chain-of-thought {
    opacity: 0.7 !important;
}
.chain-of-thought span.label {
    display: none;
}
.chain-of-thought span.textspan {
    padding-right: 0;
}
#life-canvas {
    position: fixed;
    top: 0;
    left: 0;
    width: 100%;
    height: 100%;
    z-index: -1;
    opacity: 0.3;
}
"""

with gr.Blocks(title="moondream vl (new)", css=css, js=js) as demo:
    if IN_SPACES:
        gr.HTML("")

    gr.Markdown(
        """
        # 🌔 moondream vl (new)
        A tiny vision language model.

        [GitHub](https://github.com/vikhyat/moondream)
        """
    )

    mode_radio = gr.Radio(
        ["Caption", "Query", "Detect", "Point"],
        show_label=False,
        value=lambda: "Caption",
    )

    input_image = gr.State(None)

    with gr.Row():
        with gr.Column():

            @gr.render(inputs=[mode_radio])
            def show_inputs(mode):
                if mode == "Query":
                    with gr.Group():
                        with gr.Row():
                            prompt = gr.Textbox(
                                label="Input",
                                value="How many people are in this image?",
                                scale=4,
                            )
                            submit = gr.Button("Submit")
                        img = gr.Image(type="pil", label="Upload an Image")
                        submit.click(answer_question, [img, prompt], [output, thought])
                        prompt.submit(answer_question, [img, prompt], [output, thought])
                        img.change(answer_question, [img, prompt], [output, thought])
                        img.change(lambda img: img, [img], [input_image])
                elif mode == "Caption":
                    with gr.Group():
                        with gr.Row():
                            caption_mode = gr.Radio(
                                ["Short", "Normal"],
                                label="Caption Length",
                                value=lambda: "Normal",
                                scale=4,
                            )
                            submit = gr.Button("Submit")
                        img = gr.Image(type="pil", label="Upload an Image")
                        submit.click(caption, [img, caption_mode], output)
                        img.change(caption, [img, caption_mode], output)
                elif mode == "Detect":
                    with gr.Group():
                        with gr.Row():
                            prompt = gr.Textbox(
                                label="Object",
                                value="Cat",
                                scale=4,
                            )
                            submit = gr.Button("Submit")
                        img = gr.Image(type="pil", label="Upload an Image")
                        submit.click(detect, [img, prompt], [thought, ann])
                        prompt.submit(detect, [img, prompt], [thought, ann])
                        img.change(detect, [img, prompt], [thought, ann])
                elif mode == "Point":
                    with gr.Group():
                        with gr.Row():
                            prompt = gr.Textbox(
                                label="Object",
                                value="Cat",
                                scale=4,
                            )
                            submit = gr.Button("Submit")
                        img = gr.Image(type="pil", label="Upload an Image")
                        submit.click(point, [img, prompt], [thought, ann])
                        prompt.submit(point, [img, prompt], [thought, ann])
                        img.change(point, [img, prompt], [thought, ann])
                else:
                    gr.Markdown("Coming soon!")

        with gr.Column():
            thought = gr.HighlightedText(
                elem_classes=["chain-of-thought"],
                label="Thinking tokens",
                interactive=False,
            )
            output = gr.Markdown(label="Response", elem_classes=["output-text"])
            ann = gr.Image(visible=False)

            def on_select(img, evt: gr.SelectData):
                if img is None or evt.value[1] is None:
                    return gr.update(visible=False, value=None)

                w, h = img.size
                if w > 768 or h > 768:
                    img = Resize(768)(img)
                    w, h = img.size

                coords = json.loads(evt.value[1])
                if len(coords) != 2:
                    raise ValueError("Only points supported right now.")
                coords[0] = int(coords[0] * w)
                coords[1] = int(coords[1] * h)

                img_clone = img.copy()
                draw = ImageDraw.Draw(img_clone)
                draw.ellipse(
                    (coords[0] - 3, coords[1] - 3, coords[0] + 3, coords[1] + 3),
                    fill="red",
                    outline="red",
                )
                return gr.update(visible=True, value=img_clone)

            thought.select(on_select, [input_image], [ann])

    input_image.change(lambda: gr.update(visible=False), [], [ann])
    mode_radio.change(
        lambda: ("", "", gr.update(visible=False, value=None)),
        [],
        [output, thought, ann],
    )

demo.queue().launch(share=True)