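# Hugging Face Space: interactive side-by-side comparison of SAM vs SAM-HQ
# segmentation, with point/box prompts drawn via gradio-image-prompter.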
import os

# Install runtime-only dependencies at startup (a common pattern in Spaces
# when a package is not pinned in requirements.txt).
os.system('pip install gradio-image-prompter')
os.system('pip install pydantic==2.10.6')
import gradio as gr
import torch
import spaces
from transformers import SamHQModel, SamHQProcessor, SamModel, SamProcessor
from utils import *
from PIL import Image
from gradio_image_prompter import ImagePrompter
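
# Load both models once at import time. The larger ViT-huge checkpoints are
# left commented out below in favor of the lighter ViT-base variants.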
#sam_hq_model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-huge")
#sam_hq_processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-huge")
sam_hq_model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base")
sam_hq_processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-base")
#sam_model = SamModel.from_pretrained("facebook/sam-vit-huge")
#sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
sam_model = SamModel.from_pretrained("facebook/sam-vit-base")
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
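
# @spaces.GPU requests a GPU for the duration of the decorated call on ZeroGPU
# Spaces. Note (assumption about runtime behavior): nothing below moves the
# models or inputs to CUDA explicitly, so inference likely runs on the CPU
# tensors created above unless the Space runtime relocates them.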
@spaces.GPU
def predict_masks_and_scores(model, processor, raw_image, input_points=None, input_boxes=None):
    """Run a SAM-family model on one image and return post-processed masks and IoU scores."""
    if input_boxes is not None:
        input_boxes = [input_boxes]  # add the batch dimension the processor expects
    inputs = processor(raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # Upscale the low-resolution mask logits back to the original image size.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
    )
    scores = outputs.iou_scores
    return masks, scores
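
# A minimal usage sketch (coordinates are made up): a single point prompt is
# passed as a nested list, matching the batch layout built in process_inputs.
#   masks, scores = predict_masks_and_scores(sam_model, sam_processor, image,
#                                            input_points=[[[250, 187]]])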
def process_inputs(prompts):
    # ImagePrompter encodes every prompt as a 6-tuple; type 1 is a single
    # click, type 2 is a box whose opposite corner is (x2, y2).
    raw_entries = prompts["points"]
    input_points = []
    input_boxes = []
    for entry in raw_entries:
        x1, y1, type_, x2, y2, _ = entry
        if type_ == 1:
            input_points.append([int(x1), int(y1)])
        elif type_ == 2:
            x_min = int(min(x1, x2))
            y_min = int(min(y1, y2))
            x_max = int(max(x1, x2))
            y_max = int(max(y1, y2))
            input_boxes.append([x_min, y_min, x_max, y_max])
    # Wrap in a batch dimension (batch size 1), or pass None when empty.
    input_boxes = [input_boxes] if input_boxes else None
    input_points = [input_points] if input_points else None
    user_image = prompts['image']
    sam_masks, sam_scores = predict_masks_and_scores(sam_model, sam_processor, user_image, input_boxes=input_boxes, input_points=input_points)
    sam_hq_masks, sam_hq_scores = predict_masks_and_scores(sam_hq_model, sam_hq_processor, user_image, input_boxes=input_boxes, input_points=input_points)
    # Render both results; when no prompts were given, draw the plain image.
    boxes = input_boxes[0] if input_boxes else None
    points = input_points[0] if input_points else None
    has_prompts = boxes is not None or points is not None
    img1_b64 = show_all_annotations_on_image_base64(
        user_image,
        sam_masks[0][0] if has_prompts else None,
        sam_scores[:, 0, :] if has_prompts else None,
        boxes, points, model_name='SAM')
    img2_b64 = show_all_annotations_on_image_base64(
        user_image,
        sam_hq_masks[0][0] if has_prompts else None,
        sam_hq_scores[:, 0, :] if has_prompts else None,
        boxes, points, model_name='SAM_HQ')
    # Build a before/after slider: the SAM-HQ result is stacked on top of the
    # SAM result and revealed left-to-right as the range input moves. The top
    # image starts fully clipped so the view matches the slider's initial value of 0.
    html_code = f"""
    <div style="position: relative; width: 100%; max-width: 600px; margin: 0 auto;" id="imageCompareContainer">
      <div style="position: relative; width: 100%;">
        <img src="data:image/png;base64,{img1_b64}" style="width:100%; display:block;">
        <div id="topWrapper" style="position:absolute; top:0; left:0; width:100%; overflow:hidden;">
          <img id="topImage" src="data:image/png;base64,{img2_b64}" style="width:100%; clip-path: inset(0 100% 0 0);">
        </div>
        <div id="sliderLine" style="position:absolute; top:0; left:0; width:2px; height:100%; background-color:red; pointer-events:none;"></div>
      </div>
      <input type="range" min="0" max="100" value="0"
        style="width:100%; margin-top: 10px;"
        oninput="
          const val = this.value;
          const container = document.getElementById('imageCompareContainer');
          const width = container.offsetWidth;
          const clipValue = 100 - val;
          document.getElementById('topImage').style.clipPath = 'inset(0 ' + clipValue + '% 0 0)';
          document.getElementById('sliderLine').style.left = (width * val / 100) + 'px';
        ">
    </div>
    """
    return html_code
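
# For reference, the ImagePrompter payload consumed above looks roughly like
# this (field layout inferred from the unpacking in process_inputs; the values
# and the second label column are illustrative assumptions):
#   {"image": <image>,
#    "points": [[320, 240, 1, 0, 0, 0],        # type 1: a single click
#               [100, 80, 2, 400, 300, 3]]}    # type 2: a dragged box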
example_paths = [{"image": 'images/' + path} for path in os.listdir('images')]
theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="emerald")
with gr.Blocks(theme=theme, title="🔍 Compare SAM vs SAM-HQ") as demo:
    image_path_box = gr.Textbox(visible=False)
    gr.Markdown("## 🔍 Compare SAM vs SAM-HQ")
    gr.Markdown("Compare the performance of SAM and SAM-HQ on various images. Click on an example to load it.")
    gr.Markdown("[SAM-HQ](https://huggingface.co/syscv-community/sam-hq-vit-huge) - [SAM](https://huggingface.co/facebook/sam-vit-huge)")
    result_html = gr.HTML(elem_id="result-html")
    gr.Interface(
        fn=process_inputs,
        #examples=example_paths,
        inputs=ImagePrompter(show_label=False),
        outputs=result_html,
    )
    # Give the result panel a fixed minimum height so the layout does not
    # jump while masks are being computed.
    gr.HTML("""
    <style>
    #result-html {
        min-height: 500px;
        border: 1px solid #ccc;
        padding: 10px;
        box-sizing: border-box;
        background-color: #fff;
        border-radius: 8px;
        box-shadow: 0 2px 6px rgba(0, 0, 0, 0.1);
    }
    </style>
    """)
demo.launch()