Commit d6ff06e committed by sergiopaniego (HF Staff)
1 Parent(s): 6afc369

Started Space
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,177 @@
+ import gradio as gr
+ import torch
+ import spaces
+ import json
+ import base64
+ from io import BytesIO
+ from transformers import SamHQModel, SamHQProcessor, SamModel, SamProcessor
+ import os
+ import pandas as pd
+ from utils import *
+ from PIL import Image
+
+ # Load models
+ sam_hq_model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-huge")
+ sam_hq_processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-huge")
+
+ sam_model = SamModel.from_pretrained("facebook/sam-vit-huge")
+ sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
+
+ @spaces.GPU
+ def predict_masks_and_scores(model, processor, raw_image, input_points=None, input_boxes=None):
+     if input_boxes is not None:
+         input_boxes = [input_boxes]
+     inputs = processor(raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt")
+     with torch.no_grad():
+         outputs = model(**inputs)
+
+     masks = processor.image_processor.post_process_masks(
+         outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
+     )
+     scores = outputs.iou_scores
+     return masks, scores
+
+ def encode_pil_to_base64(pil_image):
+     buffer = BytesIO()
+     pil_image.save(buffer, format="PNG")
+     return base64.b64encode(buffer.getvalue()).decode("utf-8")
+
+ def compare_images_points_and_masks(user_image, input_boxes, input_points):
+     # Match the displayed example back to its original image by size
+     for example_path, example_data in example_data_map.items():
+         if example_data["size"] == list(user_image.size):
+             user_image = Image.open(example_data['original_image_path'])
+     input_boxes = input_boxes.values.tolist()
+     input_points = input_points.values.tolist()
+
+     input_boxes = [[[int(coord) for coord in box] for box in input_boxes if any(box)]]
+     input_points = [[[int(coord) for coord in point] for point in input_points if any(point)]]
+
+     input_boxes = input_boxes if input_boxes[0] else None
+     input_points = input_points if input_points[0] else None
+
+     sam_masks, sam_scores = predict_masks_and_scores(sam_model, sam_processor, user_image, input_boxes=input_boxes, input_points=input_points)
+     sam_hq_masks, sam_hq_scores = predict_masks_and_scores(sam_hq_model, sam_hq_processor, user_image, input_boxes=input_boxes, input_points=input_points)
+
+     if input_boxes and input_points:
+         img1_b64 = show_all_annotations_on_image_base64(user_image, sam_masks[0][0], sam_scores[:, 0, :], input_boxes[0], input_points[0], model_name='SAM')
+         img2_b64 = show_all_annotations_on_image_base64(user_image, sam_hq_masks[0][0], sam_hq_scores[:, 0, :], input_boxes[0], input_points[0], model_name='SAM_HQ')
+     elif input_boxes:
+         img1_b64 = show_all_annotations_on_image_base64(user_image, sam_masks[0][0], sam_scores[:, 0, :], input_boxes[0], None, model_name='SAM')
+         img2_b64 = show_all_annotations_on_image_base64(user_image, sam_hq_masks[0][0], sam_hq_scores[:, 0, :], input_boxes[0], None, model_name='SAM_HQ')
+     elif input_points:
+         img1_b64 = show_all_annotations_on_image_base64(user_image, sam_masks[0][0], sam_scores[:, 0, :], None, input_points[0], model_name='SAM')
+         img2_b64 = show_all_annotations_on_image_base64(user_image, sam_hq_masks[0][0], sam_hq_scores[:, 0, :], None, input_points[0], model_name='SAM_HQ')
+
+     print('user_image', user_image)
+     print("img1_b64", img1_b64)
+     print("img2_b64", img2_b64)
+
+     # The two overlays are compared with a simple HTML slider widget
+     html_code = f"""
+     <div style="position: relative; width: 100%; max-width: 600px; margin: 0 auto;" id="imageCompareContainer">
+         <div style="position: relative; width: 100%;">
+             <img src="data:image/png;base64,{img1_b64}" style="width:100%; display:block;">
+             <div id="topWrapper" style="position:absolute; top:0; left:0; width:100%; overflow:hidden;">
+                 <img id="topImage" src="data:image/png;base64,{img2_b64}" style="width:100%;">
+             </div>
+             <div id="sliderLine" style="position:absolute; top:0; left:0; width:2px; height:100%; background-color:red; pointer-events:none;"></div>
+         </div>
+         <input type="range" min="0" max="100" value="0"
+             style="width:100%; margin-top: 10px;"
+             oninput="
+                 const val = this.value;
+                 const container = document.getElementById('imageCompareContainer');
+                 const width = container.offsetWidth;
+                 const clipValue = 100 - val;
+                 document.getElementById('topImage').style.clipPath = 'inset(0 ' + clipValue + '% 0 0)';
+                 document.getElementById('sliderLine').style.left = (width * val / 100) + 'px';
+             ">
+     </div>
+     """
+     return html_code
+
+ def load_examples(json_file="examples.json"):
+     with open(json_file, "r") as f:
+         examples = json.load(f)
+     return examples
+
+ examples = load_examples()
+ example_paths = [example["image_path"] for example in examples]
+ example_data_map = {
+     example["image_path"]: {
+         "original_image_path": example["original_image_path"],
+         "points": example["points"],
+         "boxes": example["boxes"],
+         "size": example["size"]
+     }
+     for example in examples
+ }
+
+ theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="emerald")
+ with gr.Blocks(theme=theme, title="🔍 Compare SAM vs SAM-HQ") as demo:
+     image_path_box = gr.Textbox(visible=False)
+     gr.Markdown("## 🔍 Compare SAM vs SAM-HQ")
+     gr.Markdown("Compare the performance of SAM and SAM-HQ on various images. Click on an example to load it")
+     gr.Markdown("[SAM-HQ](https://huggingface.co/syscv-community/sam-hq-vit-huge) - [SAM](https://huggingface.co/facebook/sam-vit-huge)")
+
+     with gr.Row():
+         image_input = gr.Image(
+             type="pil",
+             label="Example image (click below to load)",
+             interactive=False,
+             height=500,
+             show_label=True
+         )
+
+     gr.Examples(
+         examples=example_paths,
+         inputs=[image_input],
+         label="Click an example to try 👇",
+     )
+
+     result_html = gr.HTML(elem_id="result-html")
+
+     with gr.Row():
+         points_input = gr.Dataframe(
+             headers=["x", "y"],
+             label="Points",
+             datatype=["number", "number"],
+             col_count=(2, "fixed")
+         )
+         boxes_input = gr.Dataframe(
+             headers=["x0", "y0", "x1", "y1"],
+             label="Boxes",
+             datatype=["number", "number", "number", "number"],
+             col_count=(4, "fixed")
+         )
+
+     def on_image_change(image):
+         # Fill the point/box tables for the example whose size matches the loaded image
+         for example_path, example_data in example_data_map.items():
+             print(image.size)
+             if example_data["size"] == list(image.size):
+                 return example_data["points"], example_data["boxes"]
+         return [], []
+
+     image_input.change(
+         fn=on_image_change,
+         inputs=[image_input],
+         outputs=[points_input, boxes_input]
+     )
+
+     compare_button = gr.Button("Compare points and masks")
+     compare_button.click(fn=compare_images_points_and_masks, inputs=[image_input, boxes_input, points_input], outputs=result_html)
+
+     gr.HTML("""
+     <style>
+     #result-html {
+         min-height: 500px;
+         border: 1px solid #ccc;
+         padding: 10px;
+         box-sizing: border-box;
+         background-color: #fff;
+         border-radius: 8px;
+         box-shadow: 0 2px 6px rgba(0, 0, 0, 0.1);
+     }
+     </style>
+     """)
+
+ demo.launch()
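
As a reference point, here is a minimal standalone sketch of the same segmentation call that predict_masks_and_scores wraps, assuming the facebook/sam-vit-huge checkpoint above and the first box prompt from examples.json; the box is passed in the (batch, boxes, 4) nesting expected by SamProcessor.

    # Standalone sketch (not part of the commit): run SAM on one example box prompt.
    # Assumes the repo's images/ folder is available in the working directory.
    import torch
    from PIL import Image
    from transformers import SamModel, SamProcessor

    model = SamModel.from_pretrained("facebook/sam-vit-huge")
    processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

    image = Image.open("./images/original_image_0.png").convert("RGB")
    inputs = processor(image, input_boxes=[[[4, 13, 1007, 1023]]], return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
    )
    print(masks[0].shape, outputs.iou_scores)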
examples.json ADDED
@@ -0,0 +1,51 @@
+ [
+     {
+         "image_path": "./images/image_0.png",
+         "original_image_path": "./images/original_image_0.png",
+         "points": null,
+         "boxes": [[4, 13, 1007, 1023]],
+         "size": [1024, 1024]
+     },
+     {
+         "image_path": "./images/image_1.png",
+         "original_image_path": "./images/original_image_1.png",
+         "points": null,
+         "boxes": [[230, 99, 694, 670]],
+         "size": [768, 768]
+     },
+     {
+         "image_path": "./images/image_2.png",
+         "original_image_path": "./images/original_image_2.png",
+         "points": [[495, 518], [217, 140]],
+         "boxes": null,
+         "size": [894, 1000]
+     },
+     {
+         "image_path": "./images/image_3.png",
+         "original_image_path": "./images/original_image_3.png",
+         "points": [[111, 241], [249, 317], [375, 190]],
+         "boxes": null,
+         "size": [512, 512]
+     },
+     {
+         "image_path": "./images/image_4.png",
+         "original_image_path": "./images/original_image_4.png",
+         "points": null,
+         "boxes": [[128, 152, 1880, 1838]],
+         "size": [2048, 2048]
+     },
+     {
+         "image_path": "./images/image_5.png",
+         "original_image_path": "./images/original_image_5.png",
+         "points": [[373, 363], [452, 575]],
+         "boxes": null,
+         "size": [1024, 683]
+     },
+     {
+         "image_path": "./images/image_6.png",
+         "original_image_path": "./images/original_image_6.png",
+         "points": null,
+         "boxes": [[181, 196, 757, 495]],
+         "size": [800, 533]
+     }
+ ]
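
Each entry pairs a display image with its original image plus either point or box prompts; a small sketch of how these fields are consumed, mirroring load_examples in app.py:

    # Sketch (not part of the commit): inspect the example prompts defined above.
    import json

    with open("examples.json") as f:
        examples = json.load(f)

    for ex in examples:
        # "size" is what app.py uses to match a clicked thumbnail back to its entry.
        print(ex["image_path"], ex["size"], "points:", ex["points"], "boxes:", ex["boxes"])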
images/image_0.png ADDED

Git LFS Details

  • SHA256: 8acfddb52061db75859d2452ba8e168ee96245ac31618729edaa31548854f0ea
  • Pointer size: 132 Bytes
  • Size of remote file: 1.97 MB
images/image_1.png ADDED

Git LFS Details

  • SHA256: 6082fb3d5849b14ad082d234c1c1c7ff8de195af6a062ad152386d75d39fb9bf
  • Pointer size: 131 Bytes
  • Size of remote file: 628 kB
images/image_2.png ADDED

Git LFS Details

  • SHA256: fb89b1ab049b0bbf11943b11c644e2a7970eccf1ac2bcde73c68ed6bb53096c9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.61 MB
images/image_3.png ADDED

Git LFS Details

  • SHA256: 9e53cb4a9b443b74a1b7991d9c198e7435181c061e7802b10184ef055da3e384
  • Pointer size: 131 Bytes
  • Size of remote file: 433 kB
images/image_4.png ADDED

Git LFS Details

  • SHA256: b4aa85548f056d5717742e20f81f7a55674e3ce05db3b0b6ac70c0dfeffdaa56
  • Pointer size: 132 Bytes
  • Size of remote file: 6.14 MB
images/image_5.png ADDED

Git LFS Details

  • SHA256: 9944d95815a122c048b9deec0ae28443f7d64f915455ad7036c852a00a8ad4bd
  • Pointer size: 131 Bytes
  • Size of remote file: 951 kB
images/image_6.png ADDED

Git LFS Details

  • SHA256: f440f6ba3b9bb265740edddcdb803836b54add9609282464d7e82ba5f452237d
  • Pointer size: 131 Bytes
  • Size of remote file: 307 kB
images/original_image_0.png ADDED

Git LFS Details

  • SHA256: 75b113e521d89addb6c48344ef27fefd0f494eafc703e9d0657978929fce4601
  • Pointer size: 132 Bytes
  • Size of remote file: 2.32 MB
images/original_image_1.png ADDED

Git LFS Details

  • SHA256: 7e5ccc2cbc51e4849bba6d8984b5705835f332506a187dda680b207cc7a1fab2
  • Pointer size: 131 Bytes
  • Size of remote file: 613 kB
images/original_image_2.png ADDED

Git LFS Details

  • SHA256: d42a70173297297b654cd067e7ed3de717c3d2b37fd6d13b0396e5fc58449850
  • Pointer size: 132 Bytes
  • Size of remote file: 1.54 MB
images/original_image_3.png ADDED

Git LFS Details

  • SHA256: 23fe057297248971db5dc01f17b6c631636cc462711ee52c8d221b131c8a456d
  • Pointer size: 131 Bytes
  • Size of remote file: 470 kB
images/original_image_4.png ADDED

Git LFS Details

  • SHA256: e697f853c0cdc07e3bf4982e96e38b77707f91256aef087147c8897784fe90bc
  • Pointer size: 132 Bytes
  • Size of remote file: 6.05 MB
images/original_image_5.png ADDED

Git LFS Details

  • SHA256: 453a3e1627effb4d8ed6049e8d457ebe7f869537acf8e2846b36cc62ee23d1a6
  • Pointer size: 132 Bytes
  • Size of remote file: 1.22 MB
images/original_image_6.png ADDED

Git LFS Details

  • SHA256: 631bc19a9b5a3bd291de6375abf63c33234aaee5194ca95245d418581ff294d1
  • Pointer size: 131 Bytes
  • Size of remote file: 384 kB
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ gradio
+ huggingface_hub
+ requests
+ pillow
+ torch
+ git+https://github.com/huggingface/transformers.git
+ matplotlib
+ numpy
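
A quick hedged sanity check that the dependencies above import cleanly (note that `spaces`, imported in app.py, is not listed here and appears to be supplied by the Spaces runtime):

    # Sketch (not part of the commit): verify the declared stack is importable.
    import gradio, torch, transformers, matplotlib, numpy, PIL, huggingface_hub, requests
    print(gradio.__version__, torch.__version__, transformers.__version__)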
utils.py ADDED
@@ -0,0 +1,148 @@
+ from io import BytesIO
+ import base64
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import torch
+
+
+ def fig_to_base64(fig):
+     # Serialize a matplotlib figure to a base64-encoded PNG string
+     buf = BytesIO()
+     fig.savefig(buf, format='png', bbox_inches='tight')
+     plt.close(fig)
+     buf.seek(0)
+     return base64.b64encode(buf.getvalue()).decode()
+
+ def show_mask(mask, ax, random_color=False):
+     # Overlay a single segmentation mask on the given axes
+     if random_color:
+         color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
+     else:
+         color = np.array([30/255, 144/255, 255/255, 0.6])
+     h, w = mask.shape[-2:]
+     mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
+     ax.imshow(mask_image)
+
+ def show_box(box, ax):
+     x0, y0 = box[0], box[1]
+     w, h = box[2] - box[0], box[3] - box[1]
+     ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
+
+ def show_points(coords, labels, ax, marker_size=375):
+     pos_points = coords[labels==1]
+     neg_points = coords[labels==0]
+     ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
+     ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
+
+ def show_boxes_on_image_base64(raw_image, boxes):
+     fig, ax = plt.subplots(figsize=(10,10))
+     ax.imshow(raw_image)
+     for box in boxes:
+         show_box(box, ax)
+     ax.axis('off')
+     return fig_to_base64(fig)
+
+ def show_points_on_image_base64(raw_image, input_points, input_labels=None):
+     fig, ax = plt.subplots(figsize=(10,10))
+     ax.imshow(raw_image)
+     input_points = np.array(input_points)
+     labels = np.ones_like(input_points[:, 0]) if input_labels is None else np.array(input_labels)
+     show_points(input_points, labels, ax)
+     ax.axis('off')
+     return fig_to_base64(fig)
+
+ def show_points_and_boxes_on_image_base64(raw_image, boxes, input_points, input_labels=None):
+     fig, ax = plt.subplots(figsize=(10,10))
+     ax.imshow(raw_image)
+     input_points = np.array(input_points)
+     labels = np.ones_like(input_points[:, 0]) if input_labels is None else np.array(input_labels)
+     show_points(input_points, labels, ax)
+     for box in boxes:
+         show_box(box, ax)
+     ax.axis('off')
+     return fig_to_base64(fig)
+
+ def show_masks_on_image_base64(raw_image, masks, scores):
+     if len(masks.shape) == 4:
+         masks = masks.squeeze()
+     if scores.shape[0] == 1:
+         scores = scores.squeeze()
+
+     nb_predictions = scores.shape[-1]
+     print(f"Number of predictions: {nb_predictions}")
+     fig, axes = plt.subplots(1, nb_predictions, figsize=(5 * nb_predictions, 5))
+
+     if nb_predictions == 1:
+         axes = [axes]
+
+     for i, (mask, score) in enumerate(zip(masks, scores)):
+         print(i)
+         mask = mask.cpu().detach().numpy()
+         axes[i].imshow(np.array(raw_image))
+         show_mask(mask, axes[i])
+         axes[i].title.set_text(f"Mask {i+1}, Score: {score.item():.3f}")
+         axes[i].axis("off")
+
+     return fig_to_base64(fig)
+
+ def show_first_mask_on_image_base64(raw_image, masks, scores):
+     if masks.ndim == 4:
+         mask = masks[0, 0]
+     elif masks.ndim == 3:
+         mask = masks[0]
+     else:
+         mask = masks
+
+     if isinstance(mask, torch.Tensor):
+         mask = mask.cpu().detach().numpy()
+
+     score_text = ""
+     if scores is not None:
+         if isinstance(scores, torch.Tensor):
+             scores = scores.flatten()
+             score = scores[0].item()
+         else:
+             score = float(np.array(scores).flatten()[0])
+         score_text = f"Score: {score:.3f}"
+
+     fig, ax = plt.subplots(figsize=(5, 5))
+     ax.imshow(np.array(raw_image))
+     show_mask(mask, ax)
+     ax.set_title(score_text)
+     ax.axis("off")
+
+     return fig_to_base64(fig)
+
+ def show_all_annotations_on_image_base64(raw_image, masks=None, scores=None, boxes=None, input_points=None, input_labels=None, model_name=None):
+     # Draw mask, score title, points and boxes on one image and return it as base64
+     fig, ax = plt.subplots(figsize=(10, 10))
+     ax.imshow(np.array(raw_image))
+
+     if masks is not None:
+         if masks.ndim == 4:
+             mask = masks[0, 0]
+         elif masks.ndim == 3:
+             mask = masks[0]
+         else:
+             mask = masks
+         if isinstance(mask, torch.Tensor):
+             mask = mask.cpu().detach().numpy()
+         show_mask(mask, ax)
+
+     if scores is not None:
+         if isinstance(scores, torch.Tensor):
+             scores = scores.flatten()
+             score = scores[0].item()
+         else:
+             score = float(np.array(scores).flatten()[0])
+         ax.set_title(f"{model_name} - Score: {score:.3f}")
+
+     if input_points is not None:
+         input_points = np.array(input_points)
+         labels = np.ones_like(input_points[:, 0]) if input_labels is None else np.array(input_labels)
+         show_points(input_points, labels, ax)
+
+     if boxes is not None:
+         for box in boxes:
+             show_box(box, ax)
+
+     ax.axis("off")
+     return fig_to_base64(fig)
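
A short hedged usage example for the annotation helper above, producing the base64 string that app.py embeds in its comparison HTML; the image path and box are taken from examples.json.

    # Sketch (not part of the commit): draw a box overlay and build an <img> tag.
    from PIL import Image
    from utils import show_all_annotations_on_image_base64

    image = Image.open("./images/original_image_1.png").convert("RGB")
    b64 = show_all_annotations_on_image_base64(image, boxes=[[230, 99, 694, 670]])
    html = f'<img src="data:image/png;base64,{b64}">'
    print(html[:80])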