import os

# Install extra dependencies at startup (a common pattern for Hugging Face Spaces).
os.system('pip install gradio-image-prompter')
os.system('pip install pydantic==2.10.6')

import base64
import json
from io import BytesIO

import gradio as gr
import pandas as pd
import spaces
import torch
from PIL import Image
from transformers import SamHQModel, SamHQProcessor, SamModel, SamProcessor

from gradio_image_prompter import ImagePrompter
from utils import *
# The larger ViT-Huge checkpoints (commented out) trade speed for quality;
# the ViT-Base checkpoints below keep the demo light.
# sam_hq_model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-huge")
# sam_hq_processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-huge")
sam_hq_model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base")
sam_hq_processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-base")

# sam_model = SamModel.from_pretrained("facebook/sam-vit-huge")
# sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
sam_model = SamModel.from_pretrained("facebook/sam-vit-base")
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
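# Note: `spaces` is imported for ZeroGPU support. If this Space runs on ZeroGPU
# hardware, the inference entry point would typically be decorated with
# @spaces.GPU so a GPU is attached for the duration of the call; as written,
# the models above run on CPU.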
def predict_masks_and_scores(model, processor, raw_image, input_points=None, input_boxes=None):
    """Run a SAM/SAM-HQ forward pass and return post-processed masks and IoU scores."""
    if input_boxes is not None:
        # Add an extra nesting level around the boxes before handing them to the processor.
        input_boxes = [input_boxes]
    inputs = processor(raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # Upscale the low-resolution predicted masks back to the original image size.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
    )
    scores = outputs.iou_scores
    return masks, scores
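# Minimal usage sketch (hypothetical path and coordinates):
#   raw = Image.open("images/example.png").convert("RGB")
#   masks, scores = predict_masks_and_scores(
#       sam_model, sam_processor, raw, input_points=[[[450, 600]]]
#   )
#   # masks[0] holds the upscaled masks for the first (and only) image in the batch.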
def process_inputs(prompts):
    """Build SAM/SAM-HQ prompts from an ImagePrompter payload and return comparison HTML.

    `prompts` is the dict produced by ImagePrompter: `prompts["image"]` is the raw
    image and `prompts["points"]` is a list of entries of the form
    [x1, y1, type, x2, y2, cls], where (as parsed below) type 1 marks a point
    click and type 2 a box drag.
    """
    raw_entries = prompts["points"]
    input_points = []
    input_boxes = []
    for entry in raw_entries:
        x1, y1, type_, x2, y2, cls = entry
        if type_ == 1:  # point prompt
            input_points.append([int(x1), int(y1)])
        elif type_ == 2:  # box prompt; normalize so (x_min, y_min) is the top-left corner
            x_min = int(min(x1, x2))
            y_min = int(min(y1, y2))
            x_max = int(max(x1, x2))
            y_max = int(max(y1, y2))
            input_boxes.append([x_min, y_min, x_max, y_max])

    # Wrap each prompt list in a batch dimension, or pass None when that prompt type is absent.
    input_boxes = [input_boxes] if input_boxes else None
    input_points = [input_points] if input_points else None

    user_image = prompts['image']
    sam_masks, sam_scores = predict_masks_and_scores(sam_model, sam_processor, user_image, input_boxes=input_boxes, input_points=input_points)
    sam_hq_masks, sam_hq_scores = predict_masks_and_scores(sam_hq_model, sam_hq_processor, user_image, input_boxes=input_boxes, input_points=input_points)

    # Render both predictions over the image; hand None to the annotator for
    # whichever prompt type was not drawn.
    boxes = input_boxes[0] if input_boxes else None
    points = input_points[0] if input_points else None
    if boxes or points:
        img1_b64 = show_all_annotations_on_image_base64(user_image, sam_masks[0][0], sam_scores[:, 0, :], boxes, points, model_name='SAM')
        img2_b64 = show_all_annotations_on_image_base64(user_image, sam_hq_masks[0][0], sam_hq_scores[:, 0, :], boxes, points, model_name='SAM_HQ')
    else:
        img1_b64 = show_all_annotations_on_image_base64(user_image, None, None, None, None, model_name='SAM')
        img2_b64 = show_all_annotations_on_image_base64(user_image, None, None, None, None, model_name='SAM_HQ')
    # Slider comparison widget: the SAM-HQ render sits on top of the SAM render
    # and is revealed as the range input moves. The top image starts fully
    # clipped so its initial state matches the slider's value="0".
    html_code = f"""
    <div style="position: relative; width: 100%; max-width: 600px; margin: 0 auto;" id="imageCompareContainer">
      <div style="position: relative; width: 100%;">
        <img src="data:image/png;base64,{img1_b64}" style="width:100%; display:block;">
        <div id="topWrapper" style="position:absolute; top:0; left:0; width:100%; overflow:hidden;">
          <img id="topImage" src="data:image/png;base64,{img2_b64}" style="width:100%; clip-path: inset(0 100% 0 0);">
        </div>
        <div id="sliderLine" style="position:absolute; top:0; left:0; width:2px; height:100%; background-color:red; pointer-events:none;"></div>
      </div>
      <input type="range" min="0" max="100" value="0"
        style="width:100%; margin-top: 10px;"
        oninput="
          const val = this.value;
          const container = document.getElementById('imageCompareContainer');
          const width = container.offsetWidth;
          const clipValue = 100 - val;
          document.getElementById('topImage').style.clipPath = 'inset(0 ' + clipValue + '% 0 0)';
          document.getElementById('sliderLine').style.left = (width * val / 100) + 'px';
        ">
    </div>
    """
    return html_code
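# Quick smoke-test sketch (hypothetical file name; assumes at least one image
# exists under images/):
#   html = process_inputs({
#       "image": Image.open("images/example.jpg").convert("RGB"),
#       "points": [[100, 150, 1, 0, 0, 0]],  # one point click at (100, 150)
#   })
#   print(len(html))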
example_paths = [{"image": 'images/' + path} for path in os.listdir('images')]

theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="emerald")

with gr.Blocks(theme=theme, title="Compare SAM vs SAM-HQ") as demo:
    image_path_box = gr.Textbox(visible=False)
    gr.Markdown("## Compare SAM vs SAM-HQ")
    gr.Markdown("Compare the performance of SAM and SAM-HQ on various images. Draw points or boxes on an image to segment it.")
    gr.Markdown("[SAM-HQ](https://huggingface.co/syscv-community/sam-hq-vit-huge) - [SAM](https://huggingface.co/facebook/sam-vit-huge)")
    result_html = gr.HTML(elem_id="result-html")
    gr.Interface(
        fn=process_inputs,
        # examples=example_paths,
        inputs=ImagePrompter(show_label=False),
        outputs=result_html,
    )
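    # Note: the examples row is left disabled above; `example_paths` already has
    # the {"image": path} shape that ImagePrompter-based examples would need if
    # it were re-enabled.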
    gr.HTML("""
    <style>
    #result-html {
        min-height: 500px;
        border: 1px solid #ccc;
        padding: 10px;
        box-sizing: border-box;
        background-color: #fff;
        border-radius: 8px;
        box-shadow: 0 2px 6px rgba(0, 0, 0, 0.1);
    }
    </style>
    """)
demo.launch()