File size: 7,046 Bytes
f6c5f6c
296f4e8
8497449
 
296f4e8
d6ff06e
 
f9f0b68
d6ff06e
 
 
 
 
 
 
 
a09a58c
 
2569c33
941f0af
579099e
941f0af
579099e
d6ff06e
f9f0b68
938994c
579099e
 
 
938994c
579099e
938994c
579099e
fce10b4
5c3423d
 
 
579099e
 
 
 
 
 
 
 
938994c
579099e
 
 
 
 
 
 
 
9c9f88f
 
 
d6ff06e
5c3423d
a09a58c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d6ff06e
938994c
 
d6ff06e
 
 
 
 
 
 
 
 
 
bf13da5
 
 
d6ff06e
2cb1b66
65083f0
2cb1b66
65083f0
2cb1b66
 
d6ff06e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b097347
 
a09a58c
d6ff06e
 
 
 
 
adb9df9
 
d6ff06e
 
a09a58c
d6ff06e
a09a58c
 
e0683ff
ebbe0c2
a09a58c
 
d6ff06e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a09a58c
d6ff06e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import os

# Install runtime-only dependencies before the imports below need them.
# `sys.executable -m pip` installs into the interpreter actually running this
# script (a bare `pip` on PATH may belong to a different environment), and
# `check=True` makes a failed install stop the app here instead of surfacing
# later as a confusing ImportError or version mismatch.
import subprocess
import sys

subprocess.run([sys.executable, "-m", "pip", "install", "gradio-image-prompter"], check=True)
subprocess.run([sys.executable, "-m", "pip", "install", "pydantic==2.10.6"], check=True)

import gradio as gr
import torch
import spaces
import json
import base64
from io import BytesIO
from transformers import SamHQModel, SamHQProcessor, SamModel, SamProcessor
import os
import pandas as pd
from utils import *
from PIL import Image
from gradio_image_prompter import ImagePrompter


# Hub checkpoints for the two models being compared. Each model is paired with
# the processor loaded from the same checkpoint, so pre- and post-processing
# match the weights. `device_map="auto"` / `torch_dtype="auto"` leave device
# placement and dtype to transformers' defaults for that checkpoint.
_SAM_HQ_CHECKPOINT = "syscv-community/sam-hq-vit-base"
_SAM_CHECKPOINT = "facebook/sam-vit-base"

sam_hq_model = SamHQModel.from_pretrained(
    _SAM_HQ_CHECKPOINT,
    device_map="auto",
    torch_dtype="auto",
)
sam_hq_processor = SamHQProcessor.from_pretrained(_SAM_HQ_CHECKPOINT)

sam_model = SamModel.from_pretrained(
    _SAM_CHECKPOINT,
    device_map="auto",
    torch_dtype="auto",
)
sam_processor = SamProcessor.from_pretrained(_SAM_CHECKPOINT)

@spaces.GPU
def predict_masks_and_scores(model_id, raw_image, input_points=None, input_boxes=None):
    """Run SAM or SAM-HQ on one image with the given prompts.

    Args:
        model_id: ``'sam'`` selects the base SAM model/processor; any other
            value selects SAM-HQ.
        raw_image: Image handed unchanged to the processor.
        input_points: Point prompts in the nested-list layout the processor
            accepts, or ``None``.
        input_boxes: Box prompts, or ``None``. This list is wrapped in one
            extra list level before being passed to the processor.

    Returns:
        ``(masks, scores)``. ``masks`` is the processor's
        ``post_process_masks`` output (masks resized back using the original
        image sizes). ``scores`` is the model's ``iou_scores`` tensor. Both
        live on the CPU.
    """
    # Pick the model/processor pair once instead of branching at every step.
    if model_id == 'sam':
        model, processor = sam_model, sam_processor
    else:
        model, processor = sam_hq_model, sam_hq_processor

    if input_boxes is not None:
        # NOTE(review): process_inputs in this file already wraps its box list
        # once, so boxes arrive here one level deeper than a plain per-image
        # list — confirm the processor's expected nesting before changing
        # either side.
        input_boxes = [input_boxes]

    inputs = processor(raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt")

    # Grab the size metadata before moving inputs to the model device; these
    # are what post_process_masks uses to map masks back to the input image.
    original_sizes = inputs["original_sizes"]
    reshaped_sizes = inputs["reshaped_input_sizes"]

    inputs = inputs.to(model.device)
    with torch.no_grad():
        outputs = model(**inputs)

    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(), original_sizes, reshaped_sizes
    )
    # Move the scores off the accelerator too, like the masks above, so
    # callers never hold device tensors after this (GPU-scoped) call returns.
    scores = outputs.iou_scores.cpu()
    return masks, scores



def _parse_prompt_points(raw_entries):
    """Split ImagePrompter entries into click points and boxes.

    Each entry is unpacked as ``(x1, y1, type_, x2, y2, _)``:

    * ``type_ == 1`` -> point ``[x1, y1]``
    * ``type_ == 2`` -> box ``[x_min, y_min, x_max, y_max]``, with corners
      ordered so the box is valid whichever direction it was drawn.

    Entries of any other type are ignored (NOTE(review): possibly negative
    clicks — confirm against the ImagePrompter format). Coordinates are
    truncated with ``int``.

    Returns:
        ``(points, boxes)``, two possibly-empty lists.
    """
    points = []
    boxes = []
    for x1, y1, type_, x2, y2, _ in raw_entries:
        if type_ == 1:
            points.append([int(x1), int(y1)])
        elif type_ == 2:
            boxes.append([
                int(min(x1, x2)),
                int(min(y1, y2)),
                int(max(x1, x2)),
                int(max(y1, y2)),
            ])
    return points, boxes


def _compare_slider_html(base_b64, overlay_b64):
    """Return an HTML before/after slider stacking two base64-encoded PNGs.

    ``overlay_b64`` is drawn on top of ``base_b64`` and clipped from the
    right. The range input reveals ``value`` percent of the overlay from the
    left edge. The overlay starts fully clipped so the first render agrees
    with the slider's initial value of 0. The divider position is set in
    percent so it stays aligned when the container is resized.
    """
    return f"""
    <div style="position: relative; width: 100%; max-width: 600px; margin: 0 auto;" id="imageCompareContainer">
        <div style="position: relative; width: 100%;">
            <img src="data:image/png;base64,{base_b64}" style="width:100%; display:block;">
            <div id="topWrapper" style="position:absolute; top:0; left:0; width:100%; overflow:hidden;">
                <img id="topImage" src="data:image/png;base64,{overlay_b64}" style="width:100%; clip-path: inset(0 100% 0 0);">
            </div>
            <div id="sliderLine" style="position:absolute; top:0; left:0; width:2px; height:100%; background-color:red; pointer-events:none;"></div>
        </div>
        <input type="range" min="0" max="100" value="0"
            style="width:100%; margin-top: 10px;"
            oninput="
            const clipValue = 100 - this.value;
            document.getElementById('topImage').style.clipPath = 'inset(0 ' + clipValue + '% 0 0)';
            document.getElementById('sliderLine').style.left = this.value + '%';
            ">
    </div>
    """


def process_inputs(prompts):
    """Segment the prompted image with SAM and SAM-HQ and return a comparison.

    Args:
        prompts: ImagePrompter value, a mapping with ``"image"`` and
            ``"points"`` entries (see ``_parse_prompt_points`` for the entry
            layout), or ``None`` when nothing has been supplied.

    Returns:
        An HTML string. With an image, it is a slider comparing the SAM
        rendering (below) with the SAM-HQ rendering (overlay). Without an
        image, it is a short instruction message.
    """
    # Guard against an empty submission instead of raising on a missing
    # image/prompts (`is None` avoids truth-testing array-like images).
    if not prompts or prompts.get("image") is None:
        return '<p style="text-align:center;">Upload an image, draw boxes and/or points on it, then click Submit.</p>'

    user_image = prompts["image"]
    points, boxes = _parse_prompt_points(prompts.get("points") or [])

    # One extra list level around the per-image prompts, as passed to
    # predict_masks_and_scores.
    input_boxes = [boxes] if boxes else None
    input_points = [points] if points else None

    if input_boxes is None and input_points is None:
        # Nothing to segment: skip both model runs and just render the image.
        img1_b64 = show_all_annotations_on_image_base64(user_image, None, None, None, None, model_name='SAM')
        img2_b64 = show_all_annotations_on_image_base64(user_image, None, None, None, None, model_name='SAM_HQ')
    else:
        sam_masks, sam_scores = predict_masks_and_scores('sam', user_image, input_boxes=input_boxes, input_points=input_points)
        sam_hq_masks, sam_hq_scores = predict_masks_and_scores('sam_hq', user_image, input_boxes=input_boxes, input_points=input_points)

        # Only the first prediction group ([0][0] / [:, 0, :]) is drawn.
        # NOTE(review): with several boxes this likely shows only the first
        # box's masks — confirm against the processor's output layout.
        shown_boxes = boxes or None
        shown_points = points or None
        img1_b64 = show_all_annotations_on_image_base64(user_image, sam_masks[0][0], sam_scores[:, 0, :], shown_boxes, shown_points, model_name='SAM')
        img2_b64 = show_all_annotations_on_image_base64(user_image, sam_hq_masks[0][0], sam_hq_scores[:, 0, :], shown_boxes, shown_points, model_name='SAM_HQ')

    return _compare_slider_html(img1_b64, img2_b64)

# NOTE(review): nothing in this file reads this attribute; presumably it is a
# hint for the Spaces/ZeroGPU tooling — confirm it is still needed.
process_inputs.zerogpu = True

# Every non-hidden file under ./images becomes a clickable example. Sorting
# gives the example gallery a stable order (os.listdir order is arbitrary),
# and skipping dotfiles (e.g. .gitkeep) avoids broken, non-image examples.
example_paths = [
    {"image": os.path.join("images", name)}
    for name in sorted(os.listdir("images"))
    if not name.startswith(".")
]

theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="emerald")

# Styling for the comparison output panel; targets the gr.HTML component
# created below through its elem_id. Passed via Blocks(css=...) rather than an
# extra <style>-only HTML component.
_RESULT_CSS = """
#result-html {
    min-height: 500px;
    border: 1px solid #ccc;
    padding: 10px;
    box-sizing: border-box;
    background-color: #fff;
    border-radius: 8px;
    box-shadow: 0 2px 6px rgba(0, 0, 0, 0.1);
}
"""

with gr.Blocks(theme=theme, title="🔍 Compare SAM vs SAM-HQ", css=_RESULT_CSS) as demo:
    image_path_box = gr.Textbox(visible=False)
    gr.Markdown("## 🔍 Compare SAM vs SAM-HQ")
    gr.Markdown("Compare the performance of SAM and SAM-HQ on various images. Click on an example to load it or upload your unique image.")
    gr.Markdown("Draw boxes and/or points over the image and click Submit!")
    # Link the checkpoints actually loaded by this app (the vit-base variants).
    gr.Markdown("[SAM-HQ](https://huggingface.co/syscv-community/sam-hq-vit-base) - [SAM](https://huggingface.co/facebook/sam-vit-base)")

    print('example_paths', example_paths)
    result_html = gr.HTML(elem_id="result-html")
    # The Interface is rendered inside the surrounding Blocks and writes its
    # output into result_html.
    gr.Interface(
        fn=process_inputs,
        examples=example_paths,
        cache_examples=False,
        inputs=ImagePrompter(show_label=False),
        outputs=result_html,
    )


if __name__ == "__main__":
    demo.launch()