Spaces: Running on Zero
add application
app.py CHANGED
@@ -419,7 +419,7 @@ def segment_fg_bg(images):
     # transform the input images
     input_images = (input_images - means) / stds
     # output = model(input_images)[:, 5]
-    output = model(input_images)['attn'][6]
+    output = model(input_images)['attn'][6]  # [B, H=14, W=14, C]
     fg_act = output[:, 6, 6].mean(0)
     bg_act = output[:, 0, 0].mean(0)
     fg_acts.append(fg_act)
@@ -455,8 +455,8 @@ def segment_fg_bg(images):
     # output = model(input_images)[:, 5]
     output = model(input_images)['attn'][6]
     output = F.normalize(output, dim=-1)
-    heatmap_fg = output @ fg_act[:, None]
-    heatmap_bg = output @ bg_act[:, None]
+    heatmap_fg = output @ fg_act[:, None]  # [B, H, W, 1]
+    heatmap_bg = output @ bg_act[:, None]  # [B, H, W, 1]
     heatmap_fgs.append(heatmap_fg.cpu())
     heatmap_bgs.append(heatmap_bg.cpu())
     heatmap_fg = torch.cat(heatmap_fgs, dim=0)
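
The shape comments added in these two hunks can be sanity-checked in isolation. A minimal sketch with a random tensor standing in for model(input_images)['attn'][6] (the [B, 14, 14, C] layout is taken from the comment above; the channel count here is illustrative):

import torch
import torch.nn.functional as F

B, H, W, C = 2, 14, 14, 768
output = torch.randn(B, H, W, C)        # stand-in for model(input_images)['attn'][6]
fg_act = output[:, 6, 6].mean(0)        # [C]: center patch, averaged over the batch
bg_act = output[:, 0, 0].mean(0)        # [C]: corner patch, averaged over the batch

output = F.normalize(output, dim=-1)    # unit-norm features along the channel axis
heatmap_fg = output @ fg_act[:, None]   # [B, H, W, 1] similarity to the foreground probe
heatmap_bg = output @ bg_act[:, None]   # [B, H, W, 1] similarity to the background probe
print(heatmap_fg.shape)                 # torch.Size([2, 14, 14, 1])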
@@ -498,8 +498,8 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
     left = F.normalize(left, dim=-1)
     right = F.normalize(right, dim=-1)
     heatmap = left @ right.T
-    heatmap = F.normalize(heatmap, dim=-1)
-    num_samples = clusters + 20
+    heatmap = F.normalize(heatmap, dim=-1)  # [300, N_pixel] PCA-> [300, 8]
+    num_samples = clusters + 20  # 100/120
     if num_samples > fps_idx.shape[0]:
         num_samples = fps_idx.shape[0]
     r2_fps_idx = farthest_point_sampling(heatmap, num_samples)
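
For reference, farthest_point_sampling at the end of this hunk greedily picks rows that are maximally spread out in feature space. The app imports its own implementation; the sketch below is only a generic version to show what the call computes:

import torch

def farthest_point_sampling(features, num_samples):
    # Greedy FPS: repeatedly take the point farthest from all points chosen so far.
    idx = [0]  # deterministic start, good enough for a sketch
    dist = (features - features[0]).norm(dim=-1)
    for _ in range(num_samples - 1):
        nxt = int(dist.argmax())
        idx.append(nxt)
        dist = torch.minimum(dist, (features - features[nxt]).norm(dim=-1))
    return torch.tensor(idx)

heatmap = torch.rand(300, 8)  # e.g. the [300, 8] matrix noted in the comment above
print(farthest_point_sampling(heatmap, 100).shape)  # torch.Size([100])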
@@ -939,7 +939,7 @@ def ncut_run(
         return video_path, logging_str

     cluster_images = None
-    if plot_clusters:
+    if plot_clusters and kwargs.get("n_ret", 1) > 1:
         start = time.time()
         progress_start = 0.6
         progress(progress_start, desc="Plotting Clusters")
@@ -955,7 +955,7 @@ def ncut_run(
     logging_str += f"plot time: {time.time() - start:.2f}s\n"

     norm_images = None
-    if alignedcut_eig_norm_plot:
+    if alignedcut_eig_norm_plot and kwargs.get("n_ret", 1) > 1:
         norm_images = []
         # eig_magnitude = torch.clamp(eig_magnitude, 0, 1)
         vmin, vmax = eig_magnitude.min(), eig_magnitude.max()
@@ -977,7 +977,7 @@ def ncut_run(


 def _ncut_run(*args, **kwargs):
-    n_ret = kwargs.pop("n_ret", 1)
+    n_ret = kwargs.get("n_ret", 1)
     try:
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
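
These three hunks work together: _ncut_run now reads n_ret with kwargs.get instead of consuming it, so the new kwargs.get("n_ret", 1) > 1 checks in ncut_run can skip the extra cluster and norm plots when the caller only wants one output. A toy version of the pattern (signatures and bodies are stand-ins, not the app's real code):

def ncut_run(*args, plot_clusters=True, **kwargs):
    cluster_images = None
    # only produce the expensive extra plot when more than one output is wanted
    if plot_clusters and kwargs.get("n_ret", 1) > 1:
        cluster_images = "cluster plot"
    return "embedding", cluster_images

def _ncut_run(*args, **kwargs):
    n_ret = kwargs.get("n_ret", 1)   # read without consuming, so ncut_run sees it too
    outputs = ncut_run(*args, **kwargs)
    return outputs[:n_ret]

print(_ncut_run(n_ret=1))  # ('embedding',)
print(_ncut_run(n_ret=2))  # ('embedding', 'cluster plot')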
@@ -1653,8 +1653,9 @@ def load_and_append(existing_images, *args, **kwargs):
     gr.Info(f"Total images: {len(existing_images)}")
     return existing_images

-def make_input_images_section(rows=1, cols=3, height="auto", advanced=False, is_random=False, allow_download=False):
-    gr.Markdown('### Input Images')
+def make_input_images_section(rows=1, cols=3, height="auto", advanced=False, is_random=False, allow_download=False, markdown=True):
+    if markdown:
+        gr.Markdown('### Input Images')
     input_gallery = gr.Gallery(value=None, label="Input images", show_label=True, elem_id="input_images", columns=[cols], rows=[rows], object_fit="contain", height=height, type="pil", show_share_button=False,
                                format="webp")

@@ -2020,10 +2021,12 @@ def add_download_button(gallery, filename_prefix="output"):
     return create_file_button, download_button


-def make_output_images_section():
-    gr.Markdown('### Output Images')
+def make_output_images_section(markdown=True, button=True):
+    if markdown:
+        gr.Markdown('### Output Images')
     output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, interactive=False)
-    add_rotate_flip_buttons(output_gallery)
+    if button:
+        add_rotate_flip_buttons(output_gallery)
     return output_gallery

 def make_parameters_section(is_lisa=False, model_ratio=True):
@@ -2133,6 +2136,8 @@ demo = gr.Blocks(
     css=custom_css,
 )
 with demo:
+
+
     with gr.Tab('AlignedCut'):

         with gr.Row():
@@ -3081,7 +3086,304 @@ with demo:
     buttons[-1].click(fn=lambda x: gr.update(visible=True), outputs=rows[-1])
     buttons[-1].click(fn=lambda x: gr.update(visible=False), outputs=buttons[-1])

+    with gr.Tab('Application'):
+        gr.Markdown("Draw some points on the image to find corresponding segments in other images. E.g. click on one face to segment all the faces. [Video Tutorial (coming...)]()")
+        with gr.Row():
+            with gr.Column(scale=5, min_width=200):
+                gr.Markdown("### Step 0: Load Images")
+                input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(markdown=False)
+                submit_button.visible = False
+                num_images_slider.value = 30
+                logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
+            with gr.Column(scale=5, min_width=200):
+                gr.Markdown("### Step 1: NCUT Embedding")
+                output_gallery = make_output_images_section(markdown=False, button=False)
+                submit_button = gr.Button("🔴 RUN", elem_id="submit_button", variant='primary')
+                add_rotate_flip_buttons(output_gallery)
+                [
+                    model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
+                    affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
+                    embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                    perplexity_slider, n_neighbors_slider, min_dist_slider,
+                    sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
+                ] = make_parameters_section()
+
+                false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
+                no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
+
+                submit_button.click(
+                    partial(run_fn, n_ret=1),
+                    inputs=[
+                        input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
+                        positive_prompt, negative_prompt,
+                        false_placeholder, no_prompt, no_prompt, no_prompt,
+                        affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
+                        embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                        perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown
+                    ],
+                    outputs=[output_gallery, logging_text],
+                )
+
+            with gr.Column(scale=5, min_width=200):
+                gr.Markdown("### Step 2a: Pick an Image")
+                from gradio_image_prompter import ImagePrompter
+                image_type_radio = gr.Radio(["Original", "NCUT"], label="Image Display Type", value="Original", elem_id="image_type_radio")
+                with gr.Row():
+                    image1_slider = gr.Slider(0, 100, step=1, label="Image#1 Index", value=0, elem_id="image1_slider", interactive=True)
+                    image2_slider = gr.Slider(0, 100, step=1, label="Image#2 Index", value=1, elem_id="image2_slider", interactive=True)
+                    image3_slider = gr.Slider(0, 100, step=1, label="Image#3 Index", value=2, elem_id="image3_slider", interactive=True)
+                load_one_image_button = gr.Button("🔴 Load", elem_id="load_one_image_button", variant='primary')
+                gr.Markdown("### Step 2b: Draw Points")
+                gr.Markdown("##### 🖱️ Left Click: Foreground")
+                gr.Markdown("##### 🖱️ Middle Click: Background")
+                gr.Markdown("""
+                <h5>
+                Top Right
+                <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none"
+                    stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"
+                    style="vertical-align: middle; height: 1em; width: 1em; display: inline;">
+                    <polyline points="1 4 1 10 7 10"></polyline>
+                    <path d="M3.51 15a9 9 0 1 0 2.13-9.36L1 10"></path>
+                </svg> :
+                Remove Last Point
+                </h5>
+                """)
+                prompt_image1 = ImagePrompter(show_label=False, elem_id="prompt_image", interactive=True)
+                prompt_image2 = ImagePrompter(show_label=False, elem_id="prompt_image", interactive=True)
+                prompt_image3 = ImagePrompter(show_label=False, elem_id="prompt_image", interactive=True)
+                # def update_number_of_images(images):
+                #     if images is None:
+                #         return gr.update(max=0, value=0)
+                #     return gr.update(max=len(images)-1, value=1)
+                # input_gallery.change(update_number_of_images, inputs=input_gallery, outputs=image1_slider)
+
+                def update_prompt_image(original_images, ncut_images, image_type, index):
+                    if image_type == "Original":
+                        images = original_images
+                    else:
+                        images = ncut_images
+                    if images is None:
+                        return
+                    total_len = len(images)
+                    if total_len == 0:
+                        return
+                    if index >= total_len:
+                        index = total_len - 1
+                    return gr.update(value={'image': images[index][0]})
+                load_one_image_button.click(update_prompt_image, inputs=[input_gallery, output_gallery, image_type_radio, image1_slider], outputs=[prompt_image1])
+                load_one_image_button.click(update_prompt_image, inputs=[input_gallery, output_gallery, image_type_radio, image2_slider], outputs=[prompt_image2])
+                load_one_image_button.click(update_prompt_image, inputs=[input_gallery, output_gallery, image_type_radio, image3_slider], outputs=[prompt_image3])
+
+                image3_slider.visible = False
+                prompt_image3.visible = False
+
+
+            with gr.Column(scale=5, min_width=200):
+                gr.Markdown("### Step 3: Segment and Crop")
+                mask_gallery = gr.Gallery(value=[], label="Segmentation Masks", show_label=True, elem_id="mask_gallery", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, interactive=False)
+                run_crop_button = gr.Button("🔴 RUN", elem_id="run_crop_button", variant='primary')
+                add_download_button(mask_gallery, "mask")
+                distance_threshold_slider = gr.Slider(0, 1, step=0.01, label="Mask Threshold", value=0.5, elem_id="distance_threshold", info="increase for smaller mask")
+                # filter_small_area_checkbox = gr.Checkbox(label="Noise Reduction", value=True, elem_id="filter_small_area_checkbox")
+                distance_power_slider = gr.Slider(-3, 3, step=0.01, label="Distance Power", value=0.5, elem_id="distance_power", info="d = d^p", visible=False)
+                crop_gallery = gr.Gallery(value=[], label="Cropped Images", show_label=True, elem_id="crop_gallery", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, interactive=False)
+                add_download_button(crop_gallery, "cropped")
+                crop_expand_slider = gr.Slider(1.0, 2.0, step=0.1, label="Crop bbox Expand Factor", value=1.0, elem_id="crop_expand", info="increase for larger crop", visible=True)
+                area_threshold_slider = gr.Slider(0, 100, step=0.1, label="Area Threshold (%)", value=3, elem_id="area_threshold", info="for noise filtering (area of connected components)", visible=False)
+
+                # logging_image = gr.Image(value=None, label="Logging Image", elem_id="logging_image", interactive=False)
+
+                # prompt_image.change(lambda x: gr.update(value=x.get('image', None)), inputs=prompt_image, outputs=[logging_image])
+
+                def relative_xy(prompts):
+                    image = prompts['image']
+                    points = np.asarray(prompts['points'])
+                    if points.shape[0] == 0:
+                        return [], []
+                    is_point = points[:, 5] == 4.0
+                    points = points[is_point]
+                    is_positive = points[:, 2] == 1.0
+                    is_negative = points[:, 2] == 0.0
+                    xy = points[:, :2].tolist()
+                    if isinstance(image, str):
+                        image = Image.open(image)
+                        image = np.array(image)
+                    h, w = image.shape[:2]
+                    new_xy = [(x/w, y/h) for x, y in xy]
+                    # print(new_xy)
+                    return new_xy, is_positive
+
+                def xy_rgb(prompts, image_idx, ncut_images):
+                    image = ncut_images[image_idx]
+                    xy, is_positive = relative_xy(prompts)
+                    rgbs = []
+                    for i, (x, y) in enumerate(xy):
+                        rgb = image.getpixel((int(x*image.width), int(y*image.height)))
+                        rgbs.append((rgb, is_positive[i]))
+                    return rgbs
+
+                def run_crop(original_images, ncut_images, prompts1, prompts2, prompts3, image_idx1, image_idx2, image_idx3,
+                             crop_expand, distance_threshold, distance_power, area_threshold):
+                    ncut_images = [image[0] for image in ncut_images]
+                    if len(ncut_images) == 0:
+                        return []
+                    if isinstance(ncut_images[0], str):
+                        ncut_images = [Image.open(image) for image in ncut_images]
+
+                    rgbs = xy_rgb(prompts1, image_idx1, ncut_images) + \
+                        xy_rgb(prompts2, image_idx2, ncut_images) + \
+                        xy_rgb(prompts3, image_idx3, ncut_images)
+                    # print(rgbs)
+
+                    ncut_images = [np.array(image).astype(np.float32) for image in ncut_images]
+                    ncut_pixels = [image.reshape(-1, 3) for image in ncut_images]
+                    h, w = ncut_images[0].shape[:2]
+                    ncut_pixels = torch.tensor(np.array(ncut_pixels).reshape(-1, 3)) / 255
+                    # normalized_ncut_pixels = F.normalize(ncut_pixels, p=2, dim=-1)
+                    positive_distances = []
+                    negative_distances = []
+                    for rgb, is_positive in rgbs:
+                        rgb = torch.tensor(rgb).float() / 255
+                        # rgb = F.normalize(rgb, p=2, dim=-1)
+                        distance = (ncut_pixels - rgb[None]).norm(dim=-1)
+                        distance = distance.squeeze(-1)
+                        if is_positive:
+                            positive_distances.append(distance)
+                        else:
+                            negative_distances.append(distance)
+                    if len(positive_distances) == 0:
+                        raise gr.Error("No prompt points. Please draw some points on the image.")
+                    positive_distances = torch.stack(positive_distances)
+                    negative_flag = len(negative_distances) > 0
+                    if len(negative_distances) == 0:
+                        negative_distances = positive_distances * 0  # dummy
+                    else:
+                        negative_distances = torch.stack(negative_distances)
+
+                    positive_distance = positive_distances.min(dim=0).values
+                    negative_distance = negative_distances.min(dim=0).values
+                    # positive_distance = positive_distances.mean(dim=0)
+                    # negative_distance = negative_distances.mean(dim=0)
+
+                    def to_mask(heatmap, threshold):
+                        heatmap = 1 / (heatmap + 1e-6)
+                        heatmap = heatmap.reshape(len(ncut_images), h, w)
+                        vmin, vmax = heatmap.quantile(0.01), heatmap.quantile(0.99)
+                        heatmap = (heatmap - vmin) / (vmax - vmin)
+                        mask = heatmap > threshold
+                        return mask
+
+                    positive_mask = to_mask(positive_distance, distance_threshold)
+                    if negative_flag:
+                        negative_mask = to_mask(negative_distance, distance_threshold)
+                        positive_mask = positive_mask & ~negative_mask
+
+                    # convert to PIL
+                    mask = positive_mask.cpu().numpy()
+                    mask = mask.astype(np.uint8) * 255
+                    mask = [Image.fromarray(mask[i]) for i in range(len(mask))]
+
+                    import cv2
+                    def get_bboxes_and_clean_mask(pil_mask, min_area=500):
+                        """
+                        Args:
+                        - pil_mask: A Pillow image of a binary mask with 255 for the object and 0 for the background.
+                        - min_area: Minimum area for a connected component to be considered valid (default 500).
+
+                        Returns:
+                        - bounding_boxes: List of bounding boxes for valid objects (x, y, width, height).
+                        - cleaned_pil_mask: A Pillow image of the cleaned mask, with small components removed.
+                        """
+                        # Convert the Pillow image to a NumPy array
+                        mask = np.array(pil_mask)
+
+                        # Ensure the mask is binary (0 or 255)
+                        mask = np.where(mask > 127, 255, 0).astype(np.uint8)
+
+                        # Remove small noise using morphological operations (denoising)
+                        kernel = np.ones((5, 5), np.uint8)
+                        cleaned_mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
+
+                        # Find connected components in the cleaned mask
+                        num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(cleaned_mask, connectivity=8)
+
+                        # Initialize an empty mask to store the final cleaned mask
+                        final_cleaned_mask = np.zeros_like(cleaned_mask)
+
+                        # Collect bounding boxes for components that are larger than the threshold and update the cleaned mask
+                        bounding_boxes = []
+                        for i in range(1, num_labels):  # Skip label 0 (background)
+                            x, y, w, h, area = stats[i]
+                            if area >= min_area:
+                                # Add the bounding box of the valid component
+                                bounding_boxes.append((x, y, w, h))
+                                # Keep the valid components in the final cleaned mask
+                                final_cleaned_mask[labels == i] = 255
+
+                        # Convert the final cleaned mask back to a Pillow image
+                        cleaned_pil_mask = Image.fromarray(final_cleaned_mask)
+
+                        return bounding_boxes, cleaned_pil_mask
+
+                    bboxs, filtered_masks = zip(*[get_bboxes_and_clean_mask(_mask) for _mask in mask])
+
+                    # combine the masks, also draw the bounding boxes
+                    combined_masks = []
+                    for i_image in range(len(mask)):
+                        noisy_mask = np.array(mask[i_image].convert("RGB"))
+                        bbox = bboxs[i_image]
+                        clean_mask = np.array(filtered_masks[i_image].convert("RGB"))
+                        combined_mask = noisy_mask * 0.4 + clean_mask
+                        combined_mask = np.clip(combined_mask, 0, 255).astype(np.uint8)
+                        for x, y, w, h in bbox:
+                            cv2.rectangle(combined_mask, (x-1, y-1), (x + w+2, y + h+2), (255, 0, 0), 2)
+                        combined_mask = Image.fromarray(combined_mask)
+                        combined_masks.append(combined_mask)
+
+                    def extend_the_mask(xywh, factor=1.5):
+                        x, y, w, h = xywh
+                        x -= w * (factor - 1) / 2
+                        y -= h * (factor - 1) / 2
+                        w *= factor
+                        h *= factor
+                        return x, y, w, h
+
+                    def resize_the_mask(xywh, original_size, target_size):
+                        x, y, w, h = xywh
+                        x *= target_size[0] / original_size[0]
+                        y *= target_size[1] / original_size[1]
+                        w *= target_size[0] / original_size[0]
+                        h *= target_size[1] / original_size[1]
+                        x, y, w, h = int(x), int(y), int(w), int(h)
+                        return x, y, w, h
+
+                    def crop_image(image, xywh, mask_h, mask_w, factor=1.0):
+                        x, y, w, h = xywh
+                        x, y, w, h = resize_the_mask((x, y, w, h), (mask_h, mask_w), image.size)
+                        _x, _y, _w, _h = extend_the_mask((x, y, w, h), factor=factor)
+                        crop = image.crop((_x, _y, _x + _w, _y + _h))
+                        return crop
+
+                    original_images = [image[0] for image in original_images]
+                    if isinstance(original_images[0], str):
+                        original_images = [Image.open(image) for image in original_images]
+
+                    mask_h, mask_w = filtered_masks[0].size
+                    cropped_images = []
+                    for _image, _bboxs in zip(original_images, bboxs):
+                        for _bbox in _bboxs:
+                            cropped_images.append(crop_image(_image, _bbox, mask_h, mask_w, factor=crop_expand))
+
+                    return combined_masks, cropped_images
+
+                run_crop_button.click(run_crop,
+                                      inputs=[input_gallery, output_gallery, prompt_image1, prompt_image2, prompt_image3, image1_slider, image2_slider, image3_slider,
+                                              crop_expand_slider, distance_threshold_slider, distance_power_slider, area_threshold_slider],
+                                      outputs=[mask_gallery, crop_gallery])
+

     with gr.Tab('📄About'):
         with gr.Column():
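
The heart of the new Application tab is the masking math in run_crop and its inner to_mask: every pixel of the NCUT embedding image is scored by its distance to the nearest clicked color, and the inverted, quantile-normalized distance is thresholded. A self-contained sketch of just that step on random data (the app feeds in real embedding pixels and prompt-point RGBs):

import torch

n_images, h, w = 2, 8, 8
pixels = torch.rand(n_images * h * w, 3)  # stand-in for ncut_pixels: flattened RGB in [0, 1]
clicked = [torch.rand(3), torch.rand(3)]  # stand-in for the clicked prompt-point colors

# distance from every pixel to its nearest clicked color
distances = torch.stack([(pixels - c[None]).norm(dim=-1) for c in clicked])
distance = distances.min(dim=0).values

# same transform as to_mask(): invert, normalize by robust quantiles, threshold
heatmap = 1 / (distance + 1e-6)
heatmap = heatmap.reshape(n_images, h, w)
vmin, vmax = heatmap.quantile(0.01), heatmap.quantile(0.99)
heatmap = (heatmap - vmin) / (vmax - vmin)
mask = heatmap > 0.5                      # 0.5 is the Mask Threshold slider's default
print(mask.shape)                         # torch.Size([2, 8, 8])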