Spaces: Running on Zero
update app.py
app.py CHANGED
```diff
@@ -5,6 +5,8 @@ from PIL import Image
 import torchvision.transforms as transforms
 from torch import nn
 import numpy as np
+import os
+import time
 
 import gradio as gr
 
@@ -315,6 +317,7 @@ def compute_ncut(
 ):
     from ncut_pytorch import NCUT, rgb_from_tsne_3d
 
+    start = time.time()
     eigvecs, eigvals = NCUT(
         num_eig=num_eig,
         num_sample=num_sample_ncut,
@@ -322,12 +325,17 @@ def compute_ncut(
         affinity_focal_gamma=affinity_focal_gamma,
         knn=knn_ncut,
     ).fit_transform(features.reshape(-1, features.shape[-1]))
+    print(f"NCUT time: {time.time() - start:.2f}s")
+
+    start = time.time()
     X_3d, rgb = rgb_from_tsne_3d(
         eigvecs,
         num_sample=num_sample_tsne,
         perplexity=perplexity,
         knn=knn_tsne,
     )
+    print(f"t-SNE time: {time.time() - start:.2f}s")
+
     rgb = rgb.reshape(features.shape[:3] + (3,))
     return rgb
 
@@ -368,9 +376,13 @@ def main_fn(
     perplexity = num_sample_tsne - 1
 
     images = [image[0] for image in images]
+
+    start = time.time()
     features = extract_features(
         images, model_name=model_name, node_type=node_type, layer=layer
     )
+    print(f"Feature extraction time: {time.time() - start:.2f}s")
+
     rgb = compute_ncut(
         features,
         num_eig=num_eig,
@@ -391,7 +403,7 @@ demo = gr.Interface(
     main_fn,
     [
         gr.Gallery(value=default_images, label="Select images", show_label=False, elem_id="images", columns=[3], rows=[1], object_fit="contain", height="auto", type="pil"),
-        gr.Dropdown(["SAM(sam_vit_b)", "DiNO(dinov2_vitb14_reg)", "CLIP(openai/clip-vit-base-patch16"], label="Model", value="SAM(sam_vit_b)", elem_id="model_name"),
+        gr.Dropdown(["SAM(sam_vit_b)", "DiNO(dinov2_vitb14_reg)", "CLIP(openai/clip-vit-base-patch16)"], label="Model", value="SAM(sam_vit_b)", elem_id="model_name"),
         gr.Dropdown(["attn", "mlp", "block"], label="Node type", value="block", elem_id="node_type", info="attn: attention output, mlp: mlp output, block: sum of residual stream"),
         gr.Slider(0, 11, step=1, label="Layer", value=11, elem_id="layer", info="which layer of the image backbone features"),
         gr.Slider(1, 1000, step=1, label="Number of eigenvectors", value=100, elem_id="num_eig", info='increase for more object parts, decrease for whole object'),
```
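The added instrumentation repeats the same `start = time.time()` / `print(...)` pair around each stage (feature extraction, NCUT, t-SNE). A small context manager could express that measurement once; this is only a sketch of an alternative, not part of the commit, and the `timed` helper below is hypothetical.

```python
import time
from contextlib import contextmanager

@contextmanager
def timed(label):
    # Measure wall-clock time of the wrapped block and print it,
    # mirroring the start/print pairs added in this commit.
    start = time.time()
    try:
        yield
    finally:
        print(f"{label} time: {time.time() - start:.2f}s")

# Hypothetical usage inside compute_ncut:
# with timed("NCUT"):
#     eigvecs, eigvals = NCUT(...).fit_transform(...)
```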
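The only non-timing change closes the missing parenthesis in the CLIP dropdown entry, so all three model choices follow the same `Label(model_id)` pattern. Assuming the rest of app.py pulls the model id out of the parentheses (the `parse_model_id` helper below is hypothetical, not shown in this diff), the truncated string would otherwise fail to parse.

```python
import re

def parse_model_id(choice: str) -> str:
    # Hypothetical helper: extract "openai/clip-vit-base-patch16" from
    # "CLIP(openai/clip-vit-base-patch16)"; assumes app.py does something similar.
    match = re.match(r".+\((.+)\)$", choice)
    if match is None:
        raise ValueError(f"Unrecognized model choice: {choice!r}")
    return match.group(1)

print(parse_model_id("CLIP(openai/clip-vit-base-patch16)"))  # openai/clip-vit-base-patch16
```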