import gradio as gr
from typing import Any, List, Dict
from PIL import Image, ImageDraw
import requests
from transformers import AutoModelForImageTextToText, AutoProcessor
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
import torch
import re
import traceback

# --- Configuration ---
MODEL_ID = "Hcompany/Holo1-3B"


# --- Helpers (robust across different transformers versions) ---
def pick_device() -> str:
    # Force CPU per request
    return "cpu"


def apply_chat_template_compat(processor, messages: List[Dict[str, Any]]) -> str:
    """
    Works whether apply_chat_template lives on the processor or the tokenizer,
    or not at all (falls back to a naive text join of 'text' contents).
    """
    tok = getattr(processor, "tokenizer", None)
    if hasattr(processor, "apply_chat_template"):
        return processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    if tok is not None and hasattr(tok, "apply_chat_template"):
        return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Fallback: concatenate visible text segments
    texts = []
    for m in messages:
        for c in m.get("content", []):
            if isinstance(c, dict) and c.get("type") == "text":
                texts.append(c.get("text", ""))
    return "\n".join(texts)


def batch_decode_compat(processor, token_id_batches, **kw):
    tok = getattr(processor, "tokenizer", None)
    if tok is not None and hasattr(tok, "batch_decode"):
        return tok.batch_decode(token_id_batches, **kw)
    if hasattr(processor, "batch_decode"):
        return processor.batch_decode(token_id_batches, **kw)
    raise AttributeError("No batch_decode available on processor or tokenizer.")


def get_image_proc_params(processor) -> Dict[str, int]:
    """
    Safely access image processor params with defaults that work for the Qwen2-VL family.
    """
    ip = getattr(processor, "image_processor", None)
    return {
        "patch_size": getattr(ip, "patch_size", 14),
        "merge_size": getattr(ip, "merge_size", 2),  # Qwen2-VL default merge size
        "min_pixels": getattr(ip, "min_pixels", 256 * 256),
        "max_pixels": getattr(ip, "max_pixels", 1280 * 1280),
    }


def trim_generated(generated_ids, inputs):
    """
    Trim prompt tokens from generated tokens when input_ids exist.
    """
    in_ids = getattr(inputs, "input_ids", None)
    if in_ids is None and isinstance(inputs, dict):
        in_ids = inputs.get("input_ids", None)
    if in_ids is None:
        return [out_ids for out_ids in generated_ids]
    return [out_ids[len(in_seq):] for in_seq, out_ids in zip(in_ids, generated_ids)]


# --- Model and Processor Loading (Load once) ---
print(f"Loading model and processor for {MODEL_ID} (CPU only)...")
model = None
processor = None
model_loaded = False
load_error_message = ""

try:
    # CPU-friendly dtype; bf16 on CPU is spotty, so prefer float32
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    ).to(pick_device())
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    model_loaded = True
    print("Model and processor loaded successfully.")
except Exception as e:
    load_error_message = (
        f"Error loading model/processor: {e}\n"
        "This might be due to network issues, an incorrect model ID, or incompatible library versions.\n"
        "Check the full traceback in the Space logs."
    )
    print(load_error_message)
    traceback.print_exc()
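
# Note on resizing (illustrative): get_image_proc_params() feeds smart_resize()
# further below, which snaps image dimensions to multiples of
# patch_size * merge_size (14 * 2 = 28 for Qwen2-VL-style processors) while
# keeping the total pixel count within [min_pixels, max_pixels]. For example, a
# 1000x600 screenshot would be snapped to roughly 1008x588 before reaching the
# model; the exact values depend on the processor config shipped with the
# checkpoint, so treat these numbers as an assumption rather than a guarantee.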


# --- Prompt builder ---
def get_localization_prompt(pil_image: Image.Image, instruction: str) -> List[dict]:
    guidelines: str = (
        "Localize an element on the GUI image according to my instructions and "
        "output a click position as Click(x, y) with x num pixels from the left edge "
        "and y num pixels from the top edge."
    )
    return [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": pil_image},
                {"type": "text", "text": f"{guidelines}\n{instruction}"},
            ],
        }
    ]


# --- Inference (CPU) ---
def run_inference_localization(
    messages_for_template: List[Dict[str, Any]],
    pil_image_for_processing: Image.Image,
) -> str:
    """
    CPU inference; robust to processor/tokenizer differences and logs the full
    traceback on failure.
    """
    try:
        model.to(pick_device())

        # 1) Build prompt text via robust helper
        text_prompt = apply_chat_template_compat(processor, messages_for_template)

        # 2) Prepare inputs (text + image)
        inputs = processor(
            text=[text_prompt],
            images=[pil_image_for_processing],
            padding=True,
            return_tensors="pt",
        )
        # Move tensor inputs to the same device as the model (CPU)
        if isinstance(inputs, dict):
            for k, v in list(inputs.items()):
                if hasattr(v, "to"):
                    inputs[k] = v.to(model.device)

        # 3) Generate (deterministic)
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=128,
            do_sample=False,
        )

        # 4) Trim prompt tokens if possible
        generated_ids_trimmed = trim_generated(generated_ids, inputs)

        # 5) Decode via robust helper
        decoded_output = batch_decode_compat(
            processor,
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return decoded_output[0] if decoded_output else ""
    except Exception as e:
        print(f"Error during model inference: {e}")
        traceback.print_exc()
        raise
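

# --- Optional helper (illustrative; not wired into the UI below) ---
# The model sees the *resized* screenshot, so its Click(x, y) answer is in
# resized-image pixels. If the click needs to be replayed on the original
# screenshot, a linear rescale is enough. Name and signature here are an
# assumption added for clarity, not part of the original demo.
def scale_click_to_original(x: int, y: int, resized_wh, original_wh):
    """Map an (x, y) click from resized (width, height) space back to the original image."""
    resized_w, resized_h = resized_wh
    original_w, original_h = original_wh
    return round(x * original_w / resized_w), round(y * original_h / resized_h)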


# --- Gradio processing function ---
def predict_click_location(input_pil_image: Image.Image, instruction: str):
    if not model_loaded or not processor or not model:
        return f"Model not loaded. Error: {load_error_message}", None
    if input_pil_image is None:
        return "No image provided. Please upload an image.", None
    if not instruction or instruction.strip() == "":
        return "No instruction provided. Please type an instruction.", input_pil_image.copy().convert("RGB")

    # 1) Resize according to image processor params (safe defaults if missing)
    try:
        ip = get_image_proc_params(processor)
        resized_height, resized_width = smart_resize(
            input_pil_image.height,
            input_pil_image.width,
            factor=ip["patch_size"] * ip["merge_size"],
            min_pixels=ip["min_pixels"],
            max_pixels=ip["max_pixels"],
        )
        resized_image = input_pil_image.resize(
            size=(resized_width, resized_height),
            resample=Image.Resampling.LANCZOS,
        )
    except Exception as e:
        print(f"Error resizing image: {e}")
        traceback.print_exc()
        return f"Error resizing image: {e}", input_pil_image.copy().convert("RGB")

    # 2) Build messages with image + instruction
    messages = get_localization_prompt(resized_image, instruction)

    # 3) Run inference
    try:
        coordinates_str = run_inference_localization(messages, resized_image)
    except Exception as e:
        return f"Error during model inference: {e}", resized_image.copy().convert("RGB")

    # 4) Parse coordinates and draw marker
    output_image_with_click = resized_image.copy().convert("RGB")
    match = re.search(r"Click\((\d+),\s*(\d+)\)", coordinates_str)
    if match:
        try:
            x = int(match.group(1))
            y = int(match.group(2))
            draw = ImageDraw.Draw(output_image_with_click)
            radius = max(5, min(resized_width // 100, resized_height // 100, 15))
            bbox = (x - radius, y - radius, x + radius, y + radius)
            draw.ellipse(bbox, outline="red", width=max(2, radius // 4))
            print(f"Predicted and drawn click at: ({x}, {y}) on resized image ({resized_width}x{resized_height})")
        except Exception as e:
            print(f"Error drawing on image: {e}")
            traceback.print_exc()
    else:
        print(f"Could not parse 'Click(x, y)' from model output: {coordinates_str}")

    return coordinates_str, output_image_with_click


# --- Load Example Data ---
example_image = None
example_instruction = "Select July 14th as the check-out date"
try:
    example_image_url = "https://huggingface.co/Hcompany/Holo1-7B/resolve/main/calendar_example.jpg"
    example_image = Image.open(requests.get(example_image_url, stream=True).raw)
except Exception as e:
    print(f"Could not load example image from URL: {e}")
    traceback.print_exc()
    try:
        # Fallback placeholder so the Examples block still renders
        example_image = Image.new("RGB", (200, 150), color="lightgray")
        draw = ImageDraw.Draw(example_image)
        draw.text((10, 10), "Example image\nfailed to load", fill="black")
    except Exception:
        pass


# --- Gradio UI ---
title = "Holo1-3B: Action VLM Localization Demo (CPU)"
article = f"""

Model: {MODEL_ID} by HCompany | Paper: HCompany Tech Report | Blog: Surfer-H Blog Post

""" if not model_loaded: with gr.Blocks() as demo: gr.Markdown(f"#
⚠️ Error: Model Failed to Load ⚠️
") gr.Markdown(f"
{load_error_message}
") gr.Markdown("
See Space logs for the full traceback.
") else: with gr.Blocks(theme=gr.themes.Soft()) as demo: gr.Markdown(f"

{title}

") with gr.Row(): with gr.Column(scale=1): input_image_component = gr.Image(type="pil", label="Input UI Image", height=400) instruction_component = gr.Textbox( label="Instruction", placeholder="e.g., Click the 'Login' button", info="Type the action you want the model to localize on the image." ) submit_button = gr.Button("Localize Click", variant="primary") with gr.Column(scale=1): output_coords_component = gr.Textbox( label="Predicted Coordinates (Format: Click(x, y))", interactive=False ) output_image_component = gr.Image( type="pil", label="Image with Predicted Click Point", height=400, interactive=False ) if example_image: gr.Examples( examples=[[example_image, example_instruction]], inputs=[input_image_component, instruction_component], outputs=[output_coords_component, output_image_component], fn=predict_click_location, cache_examples="lazy", ) gr.Markdown(article) submit_button.click( fn=predict_click_location, inputs=[input_image_component, instruction_component], outputs=[output_coords_component, output_image_component] ) if __name__ == "__main__": # CPU Spaces can be slow; keep debug True for logs demo.launch(debug=True)