sergiopaniego (HF Staff) committed
Commit a09a58c · 1 Parent(s): ca815e1
app.py CHANGED
@@ -9,8 +9,9 @@ import os
 import pandas as pd
 from utils import *
 from PIL import Image
+from gradio_image_prompter import ImagePrompter
+
 
-# Model loading
 sam_hq_model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-huge")
 sam_hq_processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-huge")
 
@@ -31,23 +32,26 @@ def predict_masks_and_scores(model, processor, raw_image, input_points=None, inp
     scores = outputs.iou_scores
     return masks, scores
 
-def encode_pil_to_base64(pil_image):
-    buffer = BytesIO()
-    pil_image.save(buffer, format="PNG")
-    return base64.b64encode(buffer.getvalue()).decode("utf-8")
-
-def compare_images_points_and_masks(user_image, input_boxes, input_points):
-    for example_path, example_data in example_data_map.items():
-        if example_data["size"] == list(user_image.size):
-            user_image = Image.open(example_data['original_image_path'])
-    input_boxes = input_boxes.values.tolist()
-    input_points = input_points.values.tolist()
-
-    input_boxes = [[[int(coord) for coord in box] for box in input_boxes if any(box)]]
-    input_points = [[[int(coord) for coord in point] for point in input_points if any(point)]]
-
-    input_boxes = input_boxes if input_boxes[0] else None
-    input_points = input_points if input_points[0] else None
+def process_inputs(prompts):
+    raw_entries = prompts["points"]
+
+    input_points = []
+    input_boxes = []
+
+    for entry in raw_entries:
+        x1, y1, type_, x2, y2, cls = entry
+        if type_ == 1:
+            input_points.append([int(x1), int(y1)])
+        elif type_ == 2:
+            x_min = int(min(x1, x2))
+            y_min = int(min(y1, y2))
+            x_max = int(max(x1, x2))
+            y_max = int(max(y1, y2))
+            input_boxes.append([x_min, y_min, x_max, y_max])
+
+    input_boxes = [input_boxes] if input_boxes else None
+    input_points = [input_points] if input_points else None
+    user_image = prompts['image']
 
     sam_masks, sam_scores = predict_masks_and_scores(sam_model, sam_processor, user_image, input_boxes=input_boxes, input_points=input_points)
     sam_hq_masks, sam_hq_scores = predict_masks_and_scores(sam_hq_model, sam_hq_processor, user_image, input_boxes=input_boxes, input_points=input_points)
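For reference, here is a standalone sketch of the prompt-parsing step introduced in process_inputs above. It assumes, based only on the unpacking shown in the diff, that each ImagePrompter entry is a six-value list [x1, y1, type, x2, y2, cls], with type 1 for a single click and type 2 for a drawn box; parse_prompts and the sample coordinates are illustrative helpers, not part of the repo.

def parse_prompts(raw_entries):
    """Split ImagePrompter entries into SAM-style point and box prompts."""
    input_points, input_boxes = [], []
    for x1, y1, type_, x2, y2, _cls in raw_entries:
        if type_ == 1:  # single click -> point prompt
            input_points.append([int(x1), int(y1)])
        elif type_ == 2:  # drag -> box prompt, normalised to [x_min, y_min, x_max, y_max]
            input_boxes.append([int(min(x1, x2)), int(min(y1, y2)),
                                int(max(x1, x2)), int(max(y1, y2))])
    # The SAM processors expect a leading batch dimension, hence the extra list.
    return ([input_boxes] if input_boxes else None,
            [input_points] if input_points else None)

# Illustrative input: one click at (120, 80) and one box dragged from (10, 20) to (200, 240).
boxes, points = parse_prompts([[120, 80, 1, 0, 0, 0], [10, 20, 2, 200, 240, 0]])
print(boxes)   # [[[10, 20, 200, 240]]]
print(points)  # [[[120, 80]]]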
@@ -89,22 +93,7 @@ def compare_images_points_and_masks(user_image, input_boxes, input_points):
     """
     return html_code
 
-def load_examples(json_file="examples.json"):
-    with open(json_file, "r") as f:
-        examples = json.load(f)
-    return examples
-
-examples = load_examples()
-example_paths = [example["image_path"] for example in examples]
-example_data_map = {
-    example["image_path"]: {
-        "original_image_path": example["original_image_path"],
-        "points": example["points"],
-        "boxes": example["boxes"],
-        "size": example["size"]
-    }
-    for example in examples
-}
+example_paths = [{"image": 'images/' + path} for path in os.listdir('images')]
 
 theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="emerald")
 with gr.Blocks(theme=theme, title="🔍 Compare SAM vs SAM-HQ") as demo:
@@ -113,53 +102,15 @@ with gr.Blocks(theme=theme, title="🔍 Compare SAM vs SAM-HQ") as demo:
     gr.Markdown("Compare the performance of SAM and SAM-HQ on various images. Click on an example to load it")
     gr.Markdown("[SAM-HQ](https://huggingface.co/syscv-community/sam-hq-vit-huge) - [SAM](https://huggingface.co/facebook/sam-vit-huge)")
 
-    with gr.Row():
-        image_input = gr.Image(
-            type="pil",
-            label="Example image (click below to load)",
-            interactive=False,
-            height=500,
-            show_label=True
-        )
-
-        gr.Examples(
-            examples=example_paths,
-            inputs=[image_input],
-            label="Click an example to try 👇",
-        )
-
+    print('example_paths', example_paths)
     result_html = gr.HTML(elem_id="result-html")
-
-    with gr.Row():
-        points_input = gr.Dataframe(
-            headers=["x", "y"],
-            label="Points",
-            datatype=["number", "number"],
-            col_count=(2, "fixed")
-        )
-        boxes_input = gr.Dataframe(
-            headers=["x0", "y0", "x1", "y1"],
-            label="Boxes",
-            datatype=["number", "number", "number", "number"],
-            col_count=(4, "fixed")
-        )
-
-    def on_image_change(image):
-        for example_path, example_data in example_data_map.items():
-            print(image.size)
-            if example_data["size"] == list(image.size):
-                return example_data["points"], example_data["boxes"]
-        return [], []
-
-    image_input.change(
-        fn=on_image_change,
-        inputs=[image_input],
-        outputs=[points_input, boxes_input]
+    gr.Interface(
+        fn=process_inputs,
+        #examples=example_paths,
+        inputs=ImagePrompter(show_label=False),
+        outputs=result_html,
     )
 
-    compare_button = gr.Button("Compare points and masks")
-    compare_button.click(fn=compare_images_points_and_masks, inputs=[image_input, boxes_input, points_input], outputs=result_html)
-
     gr.HTML("""
     <style>
    #result-html {
@@ -174,4 +125,5 @@ with gr.Blocks(theme=theme, title="🔍 Compare SAM vs SAM-HQ") as demo:
     </style>
     """)
 
+
 demo.launch()
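The body of predict_masks_and_scores lives in utils.py and is not part of this diff. Below is a minimal, hedged sketch of how the transformers SAM API is typically driven with the nested point/box lists that process_inputs builds; the checkpoint names mirror those in app.py, while the image path and click coordinates are placeholders.

import torch
from PIL import Image
from transformers import SamModel, SamProcessor

# Placeholder inputs: any RGB image plus one positive click, batched the same way process_inputs batches prompts.
image = Image.open("images/image_0.png").convert("RGB")
input_points = [[[450, 600]]]   # [[x, y]] wrapped in a batch list
input_boxes = None              # or [[[x_min, y_min, x_max, y_max]]]

model = SamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

inputs = processor(image, input_points=input_points, input_boxes=input_boxes, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Upscale the predicted low-resolution masks back to the original image size.
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu(),
)
scores = outputs.iou_scores
print(masks[0].shape, scores.shape)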
 
images/image_0.png CHANGED

Git LFS Details (before)
  • SHA256: 8acfddb52061db75859d2452ba8e168ee96245ac31618729edaa31548854f0ea
  • Pointer size: 132 Bytes
  • Size of remote file: 1.97 MB

Git LFS Details (after)
  • SHA256: 75b113e521d89addb6c48344ef27fefd0f494eafc703e9d0657978929fce4601
  • Pointer size: 132 Bytes
  • Size of remote file: 2.32 MB

images/image_1.png CHANGED

Git LFS Details (before)
  • SHA256: 6082fb3d5849b14ad082d234c1c1c7ff8de195af6a062ad152386d75d39fb9bf
  • Pointer size: 131 Bytes
  • Size of remote file: 628 kB

Git LFS Details (after)
  • SHA256: 7e5ccc2cbc51e4849bba6d8984b5705835f332506a187dda680b207cc7a1fab2
  • Pointer size: 131 Bytes
  • Size of remote file: 613 kB

images/image_2.png CHANGED

Git LFS Details (before)
  • SHA256: fb89b1ab049b0bbf11943b11c644e2a7970eccf1ac2bcde73c68ed6bb53096c9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.61 MB

Git LFS Details (after)
  • SHA256: d42a70173297297b654cd067e7ed3de717c3d2b37fd6d13b0396e5fc58449850
  • Pointer size: 132 Bytes
  • Size of remote file: 1.54 MB

images/image_3.png CHANGED

Git LFS Details (before)
  • SHA256: 9e53cb4a9b443b74a1b7991d9c198e7435181c061e7802b10184ef055da3e384
  • Pointer size: 131 Bytes
  • Size of remote file: 433 kB

Git LFS Details (after)
  • SHA256: 23fe057297248971db5dc01f17b6c631636cc462711ee52c8d221b131c8a456d
  • Pointer size: 131 Bytes
  • Size of remote file: 470 kB

images/image_4.png DELETED

Git LFS Details
  • SHA256: b4aa85548f056d5717742e20f81f7a55674e3ce05db3b0b6ac70c0dfeffdaa56
  • Pointer size: 132 Bytes
  • Size of remote file: 6.14 MB

images/image_5.png CHANGED

Git LFS Details (before)
  • SHA256: 9944d95815a122c048b9deec0ae28443f7d64f915455ad7036c852a00a8ad4bd
  • Pointer size: 131 Bytes
  • Size of remote file: 951 kB

Git LFS Details (after)
  • SHA256: 453a3e1627effb4d8ed6049e8d457ebe7f869537acf8e2846b36cc62ee23d1a6
  • Pointer size: 132 Bytes
  • Size of remote file: 1.22 MB

images/image_6.png CHANGED

Git LFS Details (before)
  • SHA256: f440f6ba3b9bb265740edddcdb803836b54add9609282464d7e82ba5f452237d
  • Pointer size: 131 Bytes
  • Size of remote file: 307 kB

Git LFS Details (after)
  • SHA256: 631bc19a9b5a3bd291de6375abf63c33234aaee5194ca95245d418581ff294d1
  • Pointer size: 131 Bytes
  • Size of remote file: 384 kB

images/original_image_0.png DELETED

Git LFS Details
  • SHA256: 75b113e521d89addb6c48344ef27fefd0f494eafc703e9d0657978929fce4601
  • Pointer size: 132 Bytes
  • Size of remote file: 2.32 MB

images/original_image_1.png DELETED

Git LFS Details
  • SHA256: 7e5ccc2cbc51e4849bba6d8984b5705835f332506a187dda680b207cc7a1fab2
  • Pointer size: 131 Bytes
  • Size of remote file: 613 kB

images/original_image_2.png DELETED

Git LFS Details
  • SHA256: d42a70173297297b654cd067e7ed3de717c3d2b37fd6d13b0396e5fc58449850
  • Pointer size: 132 Bytes
  • Size of remote file: 1.54 MB

images/original_image_3.png DELETED

Git LFS Details
  • SHA256: 23fe057297248971db5dc01f17b6c631636cc462711ee52c8d221b131c8a456d
  • Pointer size: 131 Bytes
  • Size of remote file: 470 kB

images/original_image_4.png DELETED

Git LFS Details
  • SHA256: e697f853c0cdc07e3bf4982e96e38b77707f91256aef087147c8897784fe90bc
  • Pointer size: 132 Bytes
  • Size of remote file: 6.05 MB

images/original_image_5.png DELETED

Git LFS Details
  • SHA256: 453a3e1627effb4d8ed6049e8d457ebe7f869537acf8e2846b36cc62ee23d1a6
  • Pointer size: 132 Bytes
  • Size of remote file: 1.22 MB

images/original_image_6.png DELETED

Git LFS Details
  • SHA256: 631bc19a9b5a3bd291de6375abf63c33234aaee5194ca95245d418581ff294d1
  • Pointer size: 131 Bytes
  • Size of remote file: 384 kB
requirements.txt CHANGED
@@ -5,4 +5,6 @@ pillow
 torch
 git+https://github.com/huggingface/transformers.git
 matplotlib
-numpy
+numpy
+gradio-image-prompter
+pydantic==2.10.6
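A small optional sanity check, not part of the commit, that the newly pinned dependencies resolve in the environment; the distribution names are taken verbatim from this requirements.txt.

import importlib.metadata as md

# Prints the installed version of each pin; raises PackageNotFoundError if one is missing.
for dist in ("gradio-image-prompter", "pydantic", "numpy"):
    print(dist, md.version(dist))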