import os
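# Install extra dependencies at startup: the ImagePrompter component and a pinned pydantic version.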
os.system('pip install gradio-image-prompter')
os.system('pip install pydantic==2.10.6')
import gradio as gr
import torch
import spaces
import json
import base64
from io import BytesIO
from transformers import SamHQModel, SamHQProcessor, SamModel, SamProcessor
import pandas as pd
from utils import *
from PIL import Image
from gradio_image_prompter import ImagePrompter
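
# Load both checkpoints once at startup; device_map="auto" and torch_dtype="auto" place each model
# on the available device with an appropriate dtype.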
sam_hq_model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base", device_map="auto", torch_dtype="auto")
sam_hq_processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-base")
sam_model = SamModel.from_pretrained("facebook/sam-vit-base", device_map="auto", torch_dtype="auto")
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
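
# @spaces.GPU requests a GPU for the duration of each call on ZeroGPU hardware
# (it is a no-op on Spaces with a dedicated GPU).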
@spaces.GPU
def predict_masks_and_scores(model_id, raw_image, input_points=None, input_boxes=None):
    """Run SAM or SAM-HQ on a raw image with optional point and/or box prompts."""
    if model_id == 'sam':
        model, processor = sam_model, sam_processor
    else:
        model, processor = sam_hq_model, sam_hq_processor
    if input_boxes is not None:
        input_boxes = [input_boxes]
    inputs = processor(raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt")
    original_sizes = inputs["original_sizes"]
    reshaped_sizes = inputs["reshaped_input_sizes"]
    inputs = inputs.to(model.device)
    with torch.no_grad():
        outputs = model(**inputs)
    # Upscale the low-resolution predicted masks back to the original image size.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(), original_sizes, reshaped_sizes
    )
    scores = outputs.iou_scores
    return masks, scores
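
# Gradio callback: parse the prompts drawn in ImagePrompter, run both models,
# and return an HTML comparison slider.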
def process_inputs(prompts):
    # ImagePrompter returns each prompt as [x1, y1, label, x2, y2, label2]:
    # label 1 marks a clicked point, label 2 a drawn box.
    raw_entries = prompts["points"]
    input_points = []
    input_boxes = []
    for entry in raw_entries:
        x1, y1, type_, x2, y2, _ = entry
        if type_ == 1:
            input_points.append([int(x1), int(y1)])
        elif type_ == 2:
            x_min = int(min(x1, x2))
            y_min = int(min(y1, y2))
            x_max = int(max(x1, x2))
            y_max = int(max(y1, y2))
            input_boxes.append([x_min, y_min, x_max, y_max])

    # The processors expect one list of prompts per image, or None when that prompt type is absent.
    input_boxes = [input_boxes] if input_boxes else None
    input_points = [input_points] if input_points else None

    user_image = prompts['image']
    sam_masks, sam_scores = predict_masks_and_scores('sam', user_image, input_boxes=input_boxes, input_points=input_points)
    sam_hq_masks, sam_hq_scores = predict_masks_and_scores('sam_hq', user_image, input_boxes=input_boxes, input_points=input_points)

    # Render each model's predicted masks and the prompts onto the image as a base64 PNG.
    boxes = input_boxes[0] if input_boxes else None
    points = input_points[0] if input_points else None
    if boxes is None and points is None:
        # No prompts were drawn: just show the raw image for both panes.
        img1_b64 = show_all_annotations_on_image_base64(user_image, None, None, None, None, model_name='SAM')
        img2_b64 = show_all_annotations_on_image_base64(user_image, None, None, None, None, model_name='SAM_HQ')
    else:
        img1_b64 = show_all_annotations_on_image_base64(user_image, sam_masks[0][0], sam_scores[:, 0, :], boxes, points, model_name='SAM')
        img2_b64 = show_all_annotations_on_image_base64(user_image, sam_hq_masks[0][0], sam_hq_scores[:, 0, :], boxes, points, model_name='SAM_HQ')

    # Debug output (visible in the Space logs).
    print('sam_masks', sam_masks)
    print('sam_scores', sam_scores)
    print('sam_hq_masks', sam_hq_masks)
    print('sam_hq_scores', sam_hq_scores)
    print('input_boxes', input_boxes)
    print('input_points', input_points)
    print('user_image', user_image)
    print("img1_b64", img1_b64)
    print("img2_b64", img2_b64)
    # Build an interactive comparison: the SAM result sits underneath and the SAM-HQ result on top,
    # initially fully clipped to match the slider's starting value of 0; dragging the slider reveals
    # the SAM-HQ image from left to right.
    html_code = f"""
    <div style="position: relative; width: 100%; max-width: 600px; margin: 0 auto;" id="imageCompareContainer">
      <div style="position: relative; width: 100%;">
        <img src="data:image/png;base64,{img1_b64}" style="width:100%; display:block;">
        <div id="topWrapper" style="position:absolute; top:0; left:0; width:100%; overflow:hidden;">
          <img id="topImage" src="data:image/png;base64,{img2_b64}" style="width:100%; clip-path: inset(0 100% 0 0);">
        </div>
        <div id="sliderLine" style="position:absolute; top:0; left:0; width:2px; height:100%; background-color:red; pointer-events:none;"></div>
      </div>
      <input type="range" min="0" max="100" value="0"
        style="width:100%; margin-top: 10px;"
        oninput="
          const val = this.value;
          const container = document.getElementById('imageCompareContainer');
          const width = container.offsetWidth;
          const clipValue = 100 - val;
          document.getElementById('topImage').style.clipPath = 'inset(0 ' + clipValue + '% 0 0)';
          document.getElementById('sliderLine').style.left = (width * val / 100) + 'px';
        ">
    </div>
    """
    return html_code
process_inputs.zerogpu = True
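
# Every file in images/ becomes an ImagePrompter example ({"image": <path>}).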
example_paths = [{"image": 'images/' + path} for path in os.listdir('images')]
theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="emerald")

with gr.Blocks(theme=theme, title="Compare SAM vs SAM-HQ") as demo:
    image_path_box = gr.Textbox(visible=False)
    gr.Markdown("## Compare SAM vs SAM-HQ")
    gr.Markdown("Compare the performance of SAM and SAM-HQ on various images. Click an example to load it, or upload your own image.")
    gr.Markdown("Draw boxes and/or points on the image, then click Submit!")
    gr.Markdown("[SAM-HQ](https://huggingface.co/syscv-community/sam-hq-vit-huge) - [SAM](https://huggingface.co/facebook/sam-vit-huge)")
    print('example_paths', example_paths)
    result_html = gr.HTML(elem_id="result-html")
    gr.Interface(
        fn=process_inputs,
        examples=example_paths,
        cache_examples=False,
        inputs=ImagePrompter(show_label=False),
        outputs=result_html,
    )
    gr.HTML("""
    <style>
    #result-html {
        min-height: 500px;
        border: 1px solid #ccc;
        padding: 10px;
        box-sizing: border-box;
        background-color: #fff;
        border-radius: 8px;
        box-shadow: 0 2px 6px rgba(0, 0, 0, 0.1);
    }
    </style>
    """)
demo.launch()