import os

# Install runtime-only dependencies (Hugging Face Spaces workaround).
os.system('pip install gradio-image-prompter')
os.system('pip install pydantic==2.10.6')

import base64
import json
from io import BytesIO

import gradio as gr
import pandas as pd
import spaces
import torch
from gradio_image_prompter import ImagePrompter
from PIL import Image
from transformers import SamHQModel, SamHQProcessor, SamModel, SamProcessor

from utils import *

# Load both models once at startup so each request only pays for inference.
sam_hq_model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base", device_map="auto", torch_dtype="auto")
sam_hq_processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-base")
sam_model = SamModel.from_pretrained("facebook/sam-vit-base", device_map="auto", torch_dtype="auto")
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")


@spaces.GPU
def predict_masks_and_scores(model_id, raw_image, input_points=None, input_boxes=None):
    """Run SAM or SAM-HQ on an image with optional point/box prompts."""
    if input_boxes is not None:
        input_boxes = [input_boxes]

    model = sam_model if model_id == 'sam' else sam_hq_model
    processor = sam_processor if model_id == 'sam' else sam_hq_processor

    inputs = processor(raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt")
    original_sizes = inputs["original_sizes"]
    reshaped_sizes = inputs["reshaped_input_sizes"]

    inputs = inputs.to(model.device)
    with torch.no_grad():
        outputs = model(**inputs)

    # Upscale the low-resolution mask logits back to the original image size.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(), original_sizes, reshaped_sizes
    )
    scores = outputs.iou_scores
    return masks, scores
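# Illustrative usage (not executed by the app; the coordinates and path are
# made up). Prompts follow the transformers SAM convention of nested lists:
#
#   image = Image.open('images/example.png')        # hypothetical path
#   point_prompt = [[[250, 375]]]                   # (batch, prompt set, point, xy)
#   masks, scores = predict_masks_and_scores('sam_hq', image, input_points=point_prompt)
#   # masks is a list with one tensor per image, roughly shaped
#   # (num_prompt_sets, 3, H, W); scores is roughly (batch, num_prompt_sets, 3).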
def process_inputs(prompts):
    # Each ImagePrompter entry is [x1, y1, type, x2, y2, label]:
    # type 1 marks a clicked point, type 2 marks a dragged box.
    raw_entries = prompts["points"]
    input_points = []
    input_boxes = []
    for entry in raw_entries:
        x1, y1, type_, x2, y2, cls = entry
        if type_ == 1:
            input_points.append([int(x1), int(y1)])
        elif type_ == 2:
            x_min = int(min(x1, x2))
            y_min = int(min(y1, y2))
            x_max = int(max(x1, x2))
            y_max = int(max(y1, y2))
            input_boxes.append([x_min, y_min, x_max, y_max])

    # Wrap in a batch dimension, or pass None when the user drew nothing.
    input_boxes = [input_boxes] if input_boxes else None
    input_points = [input_points] if input_points else None

    user_image = prompts['image']
    sam_masks, sam_scores = predict_masks_and_scores('sam', user_image, input_boxes=input_boxes, input_points=input_points)
    sam_hq_masks, sam_hq_scores = predict_masks_and_scores('sam_hq', user_image, input_boxes=input_boxes, input_points=input_points)

    if input_boxes or input_points:
        boxes = input_boxes[0] if input_boxes else None
        points = input_points[0] if input_points else None
        img1_b64 = show_all_annotations_on_image_base64(user_image, sam_masks[0][0], sam_scores[:, 0, :], boxes, points, model_name='SAM')
        img2_b64 = show_all_annotations_on_image_base64(user_image, sam_hq_masks[0][0], sam_hq_scores[:, 0, :], boxes, points, model_name='SAM_HQ')
    else:
        # No prompts: render the bare image for both models.
        img1_b64 = show_all_annotations_on_image_base64(user_image, None, None, None, None, model_name='SAM')
        img2_b64 = show_all_annotations_on_image_base64(user_image, None, None, None, None, model_name='SAM_HQ')

    print('sam_masks', sam_masks)
    print('sam_scores', sam_scores)
    print('sam_hq_masks', sam_hq_masks)
    print('sam_hq_scores', sam_hq_scores)
    print('input_boxes', input_boxes)
    print('input_points', input_points)
    print('user_image', user_image)
    print('img1_b64', img1_b64)
    print('img2_b64', img2_b64)

    html_code = f"""
""" return html_code process_inputs.zerogpu = True example_paths = [{"image": 'images/' + path} for path in os.listdir('images')] theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="emerald") with gr.Blocks(theme=theme, title="🔍 Compare SAM vs SAM-HQ") as demo: image_path_box = gr.Textbox(visible=False) gr.Markdown("## 🔍 Compare SAM vs SAM-HQ") gr.Markdown("Compare the performance of SAM and SAM-HQ on various images. Click on an example to load it or upload your unique image.") gr.Markdown("Draw boxes and/or points over the image and click Submit!") gr.Markdown("[SAM-HQ](https://huggingface.co/syscv-community/sam-hq-vit-huge) - [SAM](https://huggingface.co/facebook/sam-vit-huge)") print('example_paths', example_paths) result_html = gr.HTML(elem_id="result-html") gr.Interface( fn=process_inputs, examples=example_paths, cache_examples=False, inputs=ImagePrompter(show_label=False), outputs=result_html, ) gr.HTML(""" """) demo.launch()