REVISION = "d2e8f9152cc719f1fc3b42b088d65b46ce837462" try: import spaces IN_SPACES = True except ImportError: from functools import wraps import inspect class spaces: @staticmethod def GPU(duration): def decorator(func): @wraps(func) # Preserves the original function's metadata def wrapper(*args, **kwargs): if inspect.isgeneratorfunction(func): # If the decorated function is a generator, yield from it yield from func(*args, **kwargs) else: # For regular functions, just return the result return func(*args, **kwargs) return wrapper return decorator IN_SPACES = False import torch import os import gradio as gr import json from queue import Queue from threading import Thread from transformers import AutoModelForCausalLM from PIL import ImageDraw from torchvision.transforms.v2 import Resize os.environ["HF_TOKEN"] = os.environ.get("TOKEN_FROM_SECRET") or True moondream = AutoModelForCausalLM.from_pretrained( "vikhyatk/moondream-next", trust_remote_code=True, torch_dtype=torch.bfloat16, device_map={"": "cuda"}, revision=REVISION ) # CKPT_DIRS = ["/tmp/md-ckpt/ckpt/ft/song-moon-4c-s15/s72001/"] # def get_ckpt(filename): # ckpts = [ # torch.load(os.path.join(dir, filename), map_location="cpu") for dir in CKPT_DIRS # ] # avg_ckpt = { # key.replace("._orig_mod", ""): sum(ckpt[key] for ckpt in ckpts) / len(ckpts) # for key in ckpts[0] # } # return avg_ckpt # moondream.load_state_dict(get_ckpt("model.pt")) moondream.eval() def convert_to_entities(text, coords): """ Converts a string with special markers into an entity representation. Markers: - <|coord|> pairs indicate coordinate markers - <|start_ground_points|> indicates the start of grounding - <|start_ground_text|> indicates the start of a ground term - <|end_ground|> indicates the end of a ground term Returns: - Dictionary with cleaned text and entities with their character positions """ # Initialize variables cleaned_text = "" entities = [] entity = [] # Track current position in cleaned text current_pos = 0 # Track if we're currently processing an entity in_entity = False entity_start = 0 i = 0 while i < len(text): # Check for markers if text[i : i + 9] == "<|coord|>": i += 9 entity.append(coords.pop(0)) continue elif text[i : i + 23] == "<|start_ground_points|>": in_entity = True entity_start = current_pos i += 23 continue elif text[i : i + 21] == "<|start_ground_text|>": entity_start = current_pos i += 21 continue elif text[i : i + 14] == "<|end_ground|>": # Store entity position entities.append( { "entity": json.dumps(entity), "start": entity_start, "end": current_pos, } ) entity = [] in_entity = False i += 14 continue # Add character to cleaned text cleaned_text += text[i] current_pos += 1 i += 1 return {"text": cleaned_text, "entities": entities} @spaces.GPU(duration=30) def answer_question(img, prompt): buffer = "" for new_text in moondream.query(img, prompt, stream=True)["answer"]: buffer += new_text yield buffer.strip(), {"text": "Thinking...", "entities": []} @spaces.GPU(duration=10) def caption(img, mode): if img is None: yield "" return buffer = "" if mode == "Short": l = "short" elif mode == "Long": l = "long" else: l = "normal" for t in moondream.caption(img, length=l, stream=True)["caption"]: buffer += t yield buffer.strip() @spaces.GPU(duration=10) def detect(img, object): if img is None: yield "", gr.update(visible=False, value=None) return w, h = img.size if w > 768 or h > 768: img = Resize(768)(img) w, h = img.size objs = moondream.detect(img, object)["objects"] draw_image = ImageDraw.Draw(img) for o in objs: draw_image.rectangle( (o["x_min"] * w, o["y_min"] * h, o["x_max"] * w, o["y_max"] * h), outline="red", width=3, ) yield {"text": f"{len(objs)} detected", "entities": []}, gr.update( visible=True, value=img ) @spaces.GPU(duration=10) def point(img, object): if img is None: yield "", gr.update(visible=False, value=None) return w, h = img.size if w > 768 or h > 768: img = Resize(768)(img) w, h = img.size objs = moondream.point(img, object)["points"] draw_image = ImageDraw.Draw(img) for o in objs: draw_image.ellipse( (o["x"] * w - 5, o["y"] * h - 5, o["x"] * w + 5, o["y"] * h + 5), fill="red", outline="blue", width=2, ) yield {"text": f"{len(objs)} detected", "entities": []}, gr.update( visible=True, value=img ) @spaces.GPU(duration=10) def localized_query(img, x, y, question): if img is None: yield "", {"text": "", "entities": []}, gr.update(visible=False, value=None) return answer = moondream.query(img, question, spatial_refs=[(x, y)])["answer"] w, h = img.size x, y = x * w, y * h img_clone = img.copy() draw = ImageDraw.Draw(img_clone) draw.ellipse( (x - 5, y - 5, x + 5, y + 5), fill="red", outline="blue", ) yield answer, {"text": "", "entities": []}, gr.update(visible=True, value=img_clone) js = "" css = """ .output-text span p { font-size: 1.4rem !important; } .chain-of-thought { opacity: 0.7 !important; } .chain-of-thought span.label { display: none; } .chain-of-thought span.textspan { padding-right: 0; } """ with gr.Blocks(title="moondream vl (new)", css=css, js=js) as demo: if IN_SPACES: # gr.HTML("") pass gr.Markdown( """ # 🌔 moondream vl (new) A tiny vision language model. [GitHub](https://github.com/vikhyat/moondream) """ ) mode_radio = gr.Radio( ["Caption", "Query", "Detect", "Point", "Localized"], show_label=False, value=lambda: "Caption", ) input_image = gr.State(None) with gr.Row(): with gr.Column(): @gr.render(inputs=[mode_radio]) def show_inputs(mode): if mode == "Query": with gr.Group(): with gr.Row(): prompt = gr.Textbox( label="Input", value="How many people are in this image?", scale=4, ) submit = gr.Button("Submit") img = gr.Image(type="pil", label="Upload an Image") submit.click(answer_question, [img, prompt], [output, thought]) prompt.submit(answer_question, [img, prompt], [output, thought]) img.change(answer_question, [img, prompt], [output, thought]) img.change(lambda img: img, [img], [input_image]) elif mode == "Caption": with gr.Group(): with gr.Row(): caption_mode = gr.Radio( ["Short", "Normal", "Long"], label="Caption Length", value=lambda: "Normal", scale=4, ) submit = gr.Button("Submit") img = gr.Image(type="pil", label="Upload an Image") submit.click(caption, [img, caption_mode], output) img.change(caption, [img, caption_mode], output) elif mode == "Detect": with gr.Group(): with gr.Row(): prompt = gr.Textbox( label="Object", value="Cat", scale=4, ) submit = gr.Button("Submit") img = gr.Image(type="pil", label="Upload an Image") submit.click(detect, [img, prompt], [thought, ann]) prompt.submit(detect, [img, prompt], [thought, ann]) img.change(detect, [img, prompt], [thought, ann]) elif mode == "Point": with gr.Group(): with gr.Row(): prompt = gr.Textbox( label="Object", value="Cat", scale=4, ) submit = gr.Button("Submit") img = gr.Image(type="pil", label="Upload an Image") submit.click(point, [img, prompt], [thought, ann]) prompt.submit(point, [img, prompt], [thought, ann]) img.change(point, [img, prompt], [thought, ann]) elif mode == "Localized": with gr.Group(): with gr.Row(): prompt = gr.Textbox( label="Input", value="What is this?", scale=4, ) submit = gr.Button("Submit") img = gr.Image(type="pil", label="Upload an Image") x_slider = gr.Slider(label="x", minimum=0, maximum=1) y_slider = gr.Slider(label="y", minimum=0, maximum=1) submit.click(localized_query, [img, x_slider, y_slider, prompt], [output, thought, ann]) prompt.submit(localized_query, [img, x_slider, y_slider, prompt], [output, thought, ann]) x_slider.change(localized_query, [img, x_slider, y_slider, prompt], [output, thought, ann]) y_slider.change(localized_query, [img, x_slider, y_slider, prompt], [output, thought, ann]) img.change(localized_query, [img, x_slider, y_slider, prompt], [output, thought, ann]) else: gr.Markdown("Coming soon!") with gr.Column(): thought = gr.HighlightedText( elem_classes=["chain-of-thought"], label="Thinking tokens", interactive=False, ) output = gr.Markdown(label="Response", elem_classes=["output-text"], line_breaks=True) ann = gr.Image(visible=False) def on_select(img, evt: gr.SelectData): if img is None or evt.value[1] is None: return gr.update(visible=False, value=None) w, h = img.size if w > 768 or h > 768: img = Resize(768)(img) w, h = img.size coords = json.loads(evt.value[1]) if len(coords) % 2 != 0: raise ValueError("Only points supported right now.") img_clone = img.copy() draw = ImageDraw.Draw(img_clone) for i in range(0, len(coords), 2): # Step by 2 to handle x,y pairs x = int(coords[i] * w) y = int(coords[i + 1] * h) draw.ellipse( (x - 3, y - 3, x + 3, y + 3), fill="red", outline="red", ) return gr.update(visible=True, value=img_clone) thought.select(on_select, [input_image], [ann]) input_image.change(lambda: gr.update(visible=False), [], [ann]) mode_radio.change( lambda: ("", "", gr.update(visible=False, value=None)), [], [output, thought, ann], ) demo.queue().launch()