REVISION = "d643fc6a0410366a34fb89c6877c52a09cb41820" try: import spaces IN_SPACES = True except ImportError: from functools import wraps import inspect class spaces: @staticmethod def GPU(duration): def decorator(func): @wraps(func) # Preserves the original function's metadata def wrapper(*args, **kwargs): if inspect.isgeneratorfunction(func): # If the decorated function is a generator, yield from it yield from func(*args, **kwargs) else: # For regular functions, just return the result return func(*args, **kwargs) return wrapper return decorator IN_SPACES = False import torch import os import gradio as gr import json from queue import Queue from threading import Thread from transformers import ( TextIteratorStreamer, AutoTokenizer, AutoModelForCausalLM, ) from PIL import ImageDraw from torchvision.transforms.v2 import Resize os.environ["HF_TOKEN"] = os.environ.get("TOKEN_FROM_SECRET") or True tokenizer = AutoTokenizer.from_pretrained("vikhyatk/moondream-next", revision=REVISION) moondream = AutoModelForCausalLM.from_pretrained( "vikhyatk/moondream-next", trust_remote_code=True, torch_dtype=torch.float16, device_map={"": "cuda"}, revision=REVISION ) # CKPT_DIRS = ["/tmp/md-ckpt/ckpt/ft/song-moon-4c-s15/s72001/"] # def get_ckpt(filename): # ckpts = [ # torch.load(os.path.join(dir, filename), map_location="cpu") for dir in CKPT_DIRS # ] # avg_ckpt = { # key.replace("._orig_mod", ""): sum(ckpt[key] for ckpt in ckpts) / len(ckpts) # for key in ckpts[0] # } # return avg_ckpt # moondream.load_state_dict(get_ckpt("model.pt")) moondream.eval() def convert_to_entities(text, coords): """ Converts a string with special markers into an entity representation. Markers: - <|coord|> pairs indicate coordinate markers - <|start_ground_points|> indicates the start of grounding - <|start_ground_text|> indicates the start of a ground term - <|end_ground|> indicates the end of a ground term Returns: - Dictionary with cleaned text and entities with their character positions """ # Initialize variables cleaned_text = "" entities = [] entity = [] # Track current position in cleaned text current_pos = 0 # Track if we're currently processing an entity in_entity = False entity_start = 0 i = 0 while i < len(text): # Check for markers if text[i : i + 9] == "<|coord|>": i += 9 entity.append(coords.pop(0)) continue elif text[i : i + 23] == "<|start_ground_points|>": in_entity = True entity_start = current_pos i += 23 continue elif text[i : i + 21] == "<|start_ground_text|>": entity_start = current_pos i += 21 continue elif text[i : i + 14] == "<|end_ground|>": # Store entity position entities.append( { "entity": json.dumps(entity), "start": entity_start, "end": current_pos, } ) entity = [] in_entity = False i += 14 continue # Add character to cleaned text cleaned_text += text[i] current_pos += 1 i += 1 return {"text": cleaned_text, "entities": entities} @spaces.GPU(duration=30) def answer_question(img, prompt): if img is None: yield "", "" return buffer = "" for new_text in moondream.query(img, prompt, stream=True)["answer"]: buffer += new_text yield buffer.strip(), {"text": "Thinking...", "entities": []} @spaces.GPU(duration=10) def caption(img, mode): if img is None: yield "" return buffer = "" if mode == "Short": l = "short" elif mode == "Long": l = "long" else: l = "normal" for t in moondream.caption(img, length=l, stream=True)["caption"]: buffer += t yield buffer.strip() @spaces.GPU(duration=10) def detect(img, object): if img is None: yield "", gr.update(visible=False, value=None) return w, h = img.size if w > 768 


@spaces.GPU(duration=30)
def answer_question(img, prompt):
    if img is None:
        yield "", ""
        return

    buffer = ""
    for new_text in moondream.query(img, prompt, stream=True)["answer"]:
        buffer += new_text
        yield buffer.strip(), {"text": "Thinking...", "entities": []}


@spaces.GPU(duration=10)
def caption(img, mode):
    if img is None:
        yield ""
        return

    if mode == "Short":
        length = "short"
    elif mode == "Long":
        length = "long"
    else:
        length = "normal"

    buffer = ""
    for t in moondream.caption(img, length=length, stream=True)["caption"]:
        buffer += t
        yield buffer.strip()


@spaces.GPU(duration=10)
def detect(img, object):
    if img is None:
        yield "", gr.update(visible=False, value=None)
        return

    # Downscale large images before detection; coordinates are normalized, so the
    # drawing below still works against the resized width/height.
    w, h = img.size
    if w > 768 or h > 768:
        img = Resize(768)(img)
        w, h = img.size

    objs = moondream.detect(img, object)["objects"]

    draw_image = ImageDraw.Draw(img)
    for o in objs:
        draw_image.rectangle(
            (o["x_min"] * w, o["y_min"] * h, o["x_max"] * w, o["y_max"] * h),
            outline="red",
            width=3,
        )

    yield {"text": f"{len(objs)} detected", "entities": []}, gr.update(
        visible=True, value=img
    )


@spaces.GPU(duration=10)
def point(img, object):
    if img is None:
        yield "", gr.update(visible=False, value=None)
        return

    w, h = img.size
    if w > 768 or h > 768:
        img = Resize(768)(img)
        w, h = img.size

    objs = moondream.point(img, object)["points"]

    draw_image = ImageDraw.Draw(img)
    for o in objs:
        draw_image.ellipse(
            (o["x"] * w - 5, o["y"] * h - 5, o["x"] * w + 5, o["y"] * h + 5),
            fill="red",
            outline="blue",
            width=2,
        )

    yield {"text": f"{len(objs)} detected", "entities": []}, gr.update(
        visible=True, value=img
    )
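
# Shape of the results consumed above, for reference. The keys are the ones the
# drawing code relies on; the values are invented and assumed to be normalized
# to [0, 1]:
#
#   moondream.detect(img, "cat")["objects"]
#   # -> [{"x_min": 0.12, "y_min": 0.30, "x_max": 0.55, "y_max": 0.88}, ...]
#
#   moondream.point(img, "cat")["points"]
#   # -> [{"x": 0.33, "y": 0.59}, ...]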


# js = """
# function createBgAnimation() {
#     var canvas = document.createElement('canvas');
#     canvas.id = 'life-canvas';
#     document.body.appendChild(canvas);
#     var canvas = document.getElementById('life-canvas');
#     var ctx = canvas.getContext('2d');
#
#     function resizeCanvas() {
#         canvas.width = window.innerWidth;
#         canvas.height = window.innerHeight;
#     }
#     resizeCanvas();
#     window.addEventListener('resize', resizeCanvas);
#
#     var cellSize = 8;
#     var cols = Math.ceil(canvas.width / cellSize);
#     var rows = Math.ceil(canvas.height / cellSize);
#
#     // Track cell age for color variation
#     var grid = new Array(cols).fill(null)
#         .map(() => new Array(rows).fill(null)
#             .map(() => Math.random() > 0.8 ? 1 : 0));  // If alive, start with age 1
#
#     function countNeighbors(grid, x, y) {
#         var sum = 0;
#         for (var i = -1; i < 2; i++) {
#             for (var j = -1; j < 2; j++) {
#                 var col = (x + i + cols) % cols;
#                 var row = (y + j + rows) % rows;
#                 sum += grid[col][row] ? 1 : 0;
#             }
#         }
#         sum -= grid[x][y] ? 1 : 0;
#         return sum;
#     }
#
#     function computeNextGeneration() {
#         var next = grid.map(arr => [...arr]);
#         for (var i = 0; i < cols; i++) {
#             for (var j = 0; j < rows; j++) {
#                 var neighbors = countNeighbors(grid, i, j);
#                 var state = grid[i][j];
#                 if (state) {
#                     if (neighbors < 2 || neighbors > 3) {
#                         next[i][j] = 0;  // Cell dies
#                     } else {
#                         next[i][j] = Math.min(state + 1, 5);  // Age the cell, max age of 5
#                     }
#                 } else if (neighbors === 3) {
#                     next[i][j] = 1;  // New cell born
#                 }
#             }
#         }
#         grid = next;
#     }
#
#     function getColor(age, isDarkMode) {
#         // Light mode colors
#         var lightColors = {
#             1: '#dae1f5',  // Light blue-grey
#             2: '#d3e0f4',
#             3: '#ccdff3',
#             4: '#c5def2',
#             5: '#beddf1'   // Slightly deeper blue-grey
#         };
#         // Dark mode colors
#         var darkColors = {
#             /*
#             1: '#4a5788',  // Deep blue-grey
#             2: '#4c5a8d',
#             3: '#4e5d92',
#             4: '#506097',
#             5: '#52639c'   // Brighter blue-grey
#             */
#             1: 'rgb(16, 20, 32)',
#             2: 'rgb(21, 25, 39)',
#             3: 'rgb(26, 30, 46)',
#             4: 'rgb(31, 35, 53)',
#             5: 'rgb(36, 40, 60)'
#         };
#         return isDarkMode ? darkColors[age] : lightColors[age];
#     }
#
#     function draw() {
#         var isDarkMode = document.body.classList.contains('dark');
#         ctx.fillStyle = isDarkMode ? '#0b0f19' : '#f0f0f0';
#         ctx.fillRect(0, 0, canvas.width, canvas.height);
#         for (var i = 0; i < cols; i++) {
#             for (var j = 0; j < rows; j++) {
#                 if (grid[i][j]) {
#                     ctx.fillStyle = getColor(grid[i][j], isDarkMode);
#                     ctx.fillRect(i * cellSize, j * cellSize, cellSize - 1, cellSize - 1);
#                 }
#             }
#         }
#     }
#
#     var lastFrame = 0;
#     var frameInterval = 300;
#
#     function animate(timestamp) {
#         if (timestamp - lastFrame >= frameInterval) {
#             draw();
#             computeNextGeneration();
#             lastFrame = timestamp;
#         }
#         requestAnimationFrame(animate);
#     }
#
#     animate(0);
# }
# """
js = ""

css = """
.output-text span p {
    font-size: 1.4rem !important;
}
.chain-of-thought {
    opacity: 0.7 !important;
}
.chain-of-thought span.label {
    display: none;
}
.chain-of-thought span.textspan {
    padding-right: 0;
}
#life-canvas {
    /*position: fixed;
    top: 0;
    left: 0;
    width: 100%;
    height: 100%;
    z-index: -1;
    opacity: 0.3;*/
}
"""

with gr.Blocks(title="moondream vl (new)", css=css, js=js) as demo:
    if IN_SPACES:
        # gr.HTML("")
        pass

    gr.Markdown(
        """
        # 🌔 moondream vl (new)
        A tiny vision language model. [GitHub](https://github.com/vikhyat/moondream)
        """
    )

    mode_radio = gr.Radio(
        ["Caption", "Query", "Detect", "Point"],
        show_label=False,
        value=lambda: "Caption",
    )

    input_image = gr.State(None)

    with gr.Row():
        with gr.Column():

            @gr.render(inputs=[mode_radio])
            def show_inputs(mode):
                if mode == "Query":
                    with gr.Group():
                        with gr.Row():
                            prompt = gr.Textbox(
                                label="Input",
                                value="How many people are in this image?",
                                scale=4,
                            )
                            submit = gr.Button("Submit")
                        img = gr.Image(type="pil", label="Upload an Image")
                        submit.click(answer_question, [img, prompt], [output, thought])
                        prompt.submit(answer_question, [img, prompt], [output, thought])
                        img.change(answer_question, [img, prompt], [output, thought])
                        img.change(lambda img: img, [img], [input_image])
                elif mode == "Caption":
                    with gr.Group():
                        with gr.Row():
                            caption_mode = gr.Radio(
                                ["Short", "Normal", "Long"],
                                label="Caption Length",
                                value=lambda: "Normal",
                                scale=4,
                            )
                            submit = gr.Button("Submit")
                        img = gr.Image(type="pil", label="Upload an Image")
                        submit.click(caption, [img, caption_mode], output)
                        img.change(caption, [img, caption_mode], output)
                elif mode == "Detect":
                    with gr.Group():
                        with gr.Row():
                            prompt = gr.Textbox(
                                label="Object",
                                value="Cat",
                                scale=4,
                            )
                            submit = gr.Button("Submit")
                        img = gr.Image(type="pil", label="Upload an Image")
                        submit.click(detect, [img, prompt], [thought, ann])
                        prompt.submit(detect, [img, prompt], [thought, ann])
                        img.change(detect, [img, prompt], [thought, ann])
                elif mode == "Point":
                    with gr.Group():
                        with gr.Row():
                            prompt = gr.Textbox(
                                label="Object",
                                value="Cat",
                                scale=4,
                            )
                            submit = gr.Button("Submit")
                        img = gr.Image(type="pil", label="Upload an Image")
                        submit.click(point, [img, prompt], [thought, ann])
                        prompt.submit(point, [img, prompt], [thought, ann])
                        img.change(point, [img, prompt], [thought, ann])
                else:
                    gr.Markdown("Coming soon!")

        with gr.Column():
            thought = gr.HighlightedText(
                elem_classes=["chain-of-thought"],
                label="Thinking tokens",
                interactive=False,
            )
            output = gr.Markdown(
                label="Response", elem_classes=["output-text"], line_breaks=True
            )
            ann = gr.Image(visible=False)

    def on_select(img, evt: gr.SelectData):
        if img is None or evt.value[1] is None:
            return gr.update(visible=False, value=None)

        w, h = img.size
        if w > 768 or h > 768:
            img = Resize(768)(img)
            w, h = img.size

        coords = json.loads(evt.value[1])
        if len(coords) % 2 != 0:
            raise ValueError("Only points supported right now.")

        img_clone = img.copy()
        draw = ImageDraw.Draw(img_clone)
        for i in range(0, len(coords), 2):  # Step by 2 to handle x,y pairs
            x = int(coords[i] * w)
            y = int(coords[i + 1] * h)
            draw.ellipse(
                (x - 3, y - 3, x + 3, y + 3),
                fill="red",
                outline="red",
            )
        return gr.update(visible=True, value=img_clone)
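
    # Illustrative round trip (values invented): an entity produced by
    # convert_to_entities stores its points as a JSON string, which on_select
    # decodes back into a flat [x, y, x, y, ...] list before scaling to pixels:
    #
    #   entity = {"entity": "[0.25, 0.5]", "start": 2, "end": 5}
    #   coords = json.loads(entity["entity"])  # -> [0.25, 0.5]
    #   # pairs up to (0.25 * w, 0.5 * h) -> one drawn point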

    thought.select(on_select, [input_image], [ann])
    input_image.change(lambda: gr.update(visible=False), [], [ann])
    mode_radio.change(
        lambda: ("", "", gr.update(visible=False, value=None)),
        [],
        [output, thought, ann],
    )

demo.queue().launch()