Spaces: Running on Zero

add paper

- app.py  +151 -17
- packages.txt  +2 -0

app.py CHANGED
@@ -183,6 +183,29 @@ downscaled_outputs = default_outputs
 example_items = downscaled_images[:3] + downscaled_outputs[:3]
 
 
+def run_alignedthreemodelattnnodes(images, model, batch_size=1):
+
+    use_cuda = torch.cuda.is_available()
+    device = torch.device("cuda" if use_cuda else "cpu")
+
+    if use_cuda:
+        model = model.to(device)
+
+    chunked_idxs = torch.split(torch.arange(images.shape[0]), batch_size)
+
+    outputs = []
+    for idxs in chunked_idxs:
+        inp = images[idxs]
+        if use_cuda:
+            inp = inp.to(device)
+        out = model(inp)
+        # normalize before save
+        out = F.normalize(out, dim=-1)
+        outputs.append(out.cpu().float())
+    outputs = torch.cat(outputs, dim=0)
+
+    return outputs
+
 
 def ncut_run(
     model,
@@ -212,7 +235,11 @@ def ncut_run(
     video_output=False,
 ):
     logging_str = ""
-    resolution = RES_DICT[model_name]
+    if "AlignedThreeModelAttnNodes" == model_name:
+        # dirty patch for the alignedcut paper
+        resolution = (672, 672)
+    else:
+        resolution = RES_DICT[model_name]
     logging_str += f"Resolution: {resolution}\n"
     if perplexity >= num_sample_tsne or n_neighbors >= num_sample_tsne:
         # raise gr.Error("Perplexity must be less than the number of samples for t-SNE.")
@@ -227,9 +254,13 @@ def ncut_run(
     node_type = node_type.split(":")[0].strip()
 
     start = time.time()
-    features = extract_features(
-        images, model, node_type=node_type, layer=layer-1, batch_size=BATCH_SIZE
-    )
+    if "AlignedThreeModelAttnNodes" == model_name:
+        # dirty patch for the alignedcut paper
+        features = run_alignedthreemodelattnnodes(images, model, batch_size=BATCH_SIZE)
+    else:
+        features = extract_features(
+            images, model, node_type=node_type, layer=layer-1, batch_size=BATCH_SIZE
+        )
     # print(f"Feature extraction time (gpu): {time.time() - start:.2f}s")
     logging_str += f"Backbone time: {time.time() - start:.2f}s\n"
 
@@ -301,8 +332,25 @@ def ncut_run(
     )
     logging_str += _logging_str
 
+    if "AlignedThreeModelAttnNodes" == model_name:
+        # dirty patch for the alignedcut paper
+        galleries = []
+        for i_node in range(rgb.shape[1]):
+            _rgb = rgb[:, i_node]
+            galleries.append(to_pil_images(_rgb))
+        return *galleries, logging_str
+
     rgb = dont_use_too_much_green(rgb)
 
+    if "AlignedThreeModelAttnNodes" == model_name:
+        # dirty patch for the alignedcut paper
+        print("AlignedThreeModelAttnNodes")
+        galleries = []
+        for i_node in range(rgb.shape[1]):
+            _rgb = rgb[:, i_node]
+            print(_rgb.shape)
+            galleries.append(to_pil_images(_rgb))
+        return *galleries, logging_str
 
     if video_output:
         video_path = get_random_path()
@@ -313,16 +361,19 @@ def ncut_run(
     return to_pil_images(rgb), logging_str
 
 def _ncut_run(*args, **kwargs):
-    try:
-        ret = ncut_run(*args, **kwargs)
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        return ret
-    except Exception as e:
-        gr.Error(str(e))
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        return [], "Error: " + str(e)
+    # try:
+    #     ret = ncut_run(*args, **kwargs)
+    #     if torch.cuda.is_available():
+    #         torch.cuda.empty_cache()
+    #     return ret
+    # except Exception as e:
+    #     gr.Error(str(e))
+    #     if torch.cuda.is_available():
+    #         torch.cuda.empty_cache()
+    #     return [], "Error: " + str(e)
+
+    ret = ncut_run(*args, **kwargs)
+    return ret
 
 if USE_HUGGINGFACE_ZEROGPU:
     @spaces.GPU(duration=20)
@@ -376,6 +427,28 @@ def transform_image(image, resolution=(1024, 1024)):
     image = (image - 0.5) / 0.5
     return image
 
+def load_alignedthreemodel():
+
+    os.system("git clone https://huggingface.co/huzey/alignedthreeattn >> /dev/null 2>&1")
+    # pull
+    os.system("git -C alignedthreeattn pull >> /dev/null 2>&1")
+    # add to path
+    import sys
+    sys.path.append("alignedthreeattn")
+
+
+    from alignedthreeattn.alignedthreeattn_model import ThreeAttnNodes
+
+    align_weights = torch.load("alignedthreeattn/align_weights.pth")
+    model = ThreeAttnNodes(align_weights)
+
+    # url = 'https://huggingface.co/huzey/aligned_model_test/resolve/main/3attn_nodes.pth'
+    # save_path = "alignedthreemodel.pth"
+    # if not os.path.exists(save_path):
+    #     os.system(f"wget {url} -O {save_path} -q")
+    # model = torch.load(save_path)
+    return model
+
 def run_fn(
     images,
     model_name="SAM(sam_vit_b)",
@@ -416,12 +489,21 @@ def run_fn(
     sampling_method = "farthest"
 
     # resize the images before acquiring GPU
-    resolution = RES_DICT[model_name]
+    if "AlignedThreeModelAttnNodes" == model_name:
+        # dirty patch for the alignedcut paper
+        resolution = (672, 672)
+    else:
+        resolution = RES_DICT[model_name]
     images = [tup[0] for tup in images]
    images = [transform_image(image, resolution=resolution) for image in images]
     images = torch.stack(images)
 
-    model = load_model(model_name)
+    if "AlignedThreeModelAttnNodes" == model_name:
+        # dirty patch for the alignedcut paper
+        model = load_alignedthreemodel()
+    else:
+        model = load_model(model_name)
+
     if "stable" in model_name.lower() and "diffusion" in model_name.lower():
         model.timestep = layer
         layer = 1
@@ -932,7 +1014,59 @@ with demo:
     # Last button only reveals the last row and hides itself
     buttons[-1].click(fn=lambda x: gr.update(visible=True), outputs=rows[-1])
     buttons[-1].click(fn=lambda x: gr.update(visible=False), outputs=buttons[-1])
-
+
+    with gr.Tab('Compare (Aligned)'):
+        gr.Markdown('This page reproduces the results from the paper [AlignedCut](https://arxiv.org/abs/2406.18344)')
+        gr.Markdown('---')
+        gr.Markdown('**Features are aligned across models and layers.** A linear alignment transform is trained for each model/layer; the learning signal comes from 1) fMRI brain activation and 2) segmentation-preserving eigen-constraints.')
+        gr.Markdown('NCUT is computed on the concatenated graph of all models, layers, and images. Color is **aligned** across all models and layers.')
+        gr.Markdown('---')
+        with gr.Row():
+            with gr.Column(scale=5, min_width=200):
+                input_gallery, submit_button, clear_images_button = make_input_images_section()
+
+                dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_dataset_images_section(advanced=True)
+                num_images_slider.value = 100
+
+            with gr.Column(scale=5, min_width=200):
+                gr.Markdown('Model: CLIP(ViT-B-16/openai), DiNOv2reg(dinov2_vitb14_reg), MAE(vit_base)')
+                gr.Markdown('Layer type: attention output (attn), without sum of residual')
+                [
+                    model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
+                    affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
+                    embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                    perplexity_slider, n_neighbors_slider, min_dist_slider,
+                    sampling_method_dropdown
+                ] = make_parameters_section()
+                model_dropdown.value = "AlignedThreeModelAttnNodes"
+                model_dropdown.visible = False
+                layer_slider.visible = False
+                node_type_dropdown.visible = False
+                # logging text box
+                logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
+
+        galleries = []
+        for i_model, model_name in enumerate(["CLIP", "DINO", "MAE"]):
+            with gr.Row():
+                for i_layer in range(1, 13):
+                    with gr.Column(scale=5, min_width=200):
+                        gr.Markdown(f'### {model_name} Layer {i_layer}')
+                        output_gallery = gr.Gallery(value=[], label="NCUT Embedding", show_label=False, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="auto")
+                        galleries.append(output_gallery)
+
+
+        clear_images_button.click(lambda x: [] * (len(galleries) + 1), outputs=[input_gallery] + galleries)
+        submit_button.click(
+            run_fn,
+            inputs=[
+                input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
+                affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
+                embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
+            ],
+            outputs=galleries + [logging_text],
+        )
+
     with gr.Row():
         with gr.Column():
             gr.Markdown("##### POWERED BY [ncut-pytorch](https://ncut-pytorch.readthedocs.io/) ")
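For orientation: in the `AlignedThreeModelAttnNodes` branch above, `ncut_run` returns one gallery per attention node via `rgb[:, i_node]`, and the new "Compare (Aligned)" tab builds 3 x 12 = 36 `gr.Gallery` components (CLIP, DINO, MAE across layers 1-12) plus a logging textbox, so the node dimension of `rgb` is presumably expected to match that count. A minimal sketch of the fan-out, using a dummy tensor in place of the real NCUT output (the shapes 5, 36, 56 are illustrative assumptions, not values from the commit):

import torch

# Dummy stand-in for the NCUT RGB output of the aligned model:
# (num_images, num_nodes, H, W, 3), num_nodes = 3 models x 12 layers.
num_images, num_nodes, h, w = 5, 36, 56, 56
rgb = torch.rand(num_images, num_nodes, h, w, 3)

# Fan out one gallery per node, mirroring rgb[:, i_node] in ncut_run;
# app.py converts each _rgb with to_pil_images() before returning it.
galleries = [rgb[:, i_node] for i_node in range(rgb.shape[1])]

# Matches outputs=galleries + [logging_text] wired up in the new tab.
assert len(galleries) == num_nodes

Note also that, as committed, the first `AlignedThreeModelAttnNodes` block in `ncut_run` returns before `rgb = dont_use_too_much_green(rgb)`, so the second, nearly identical block added after that line is unreachable.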
packages.txt ADDED

@@ -0,0 +1,2 @@
+git-all
+git-lfs
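The system packages are added because `load_alignedthreemodel()` shells out to plain `git` to clone and pull the `huzey/alignedthreeattn` model repo, whose `.pth` weights are stored with Git LFS. A hypothetical alternative (not what this commit does) would be to fetch the repo with `huggingface_hub` instead of system git; the repo id, module name, and weight filename below are taken from the commit, while the hub-based loader itself is only a sketch:

# Hypothetical alternative to the git clone/pull in load_alignedthreemodel().
import sys
import torch
from huggingface_hub import snapshot_download

def load_alignedthreemodel_hub():
    # Download the whole repo (code + LFS weights) into the HF cache.
    local_dir = snapshot_download(repo_id="huzey/alignedthreeattn")
    sys.path.append(local_dir)
    # alignedthreeattn_model.py ships at the repo root.
    from alignedthreeattn_model import ThreeAttnNodes
    align_weights = torch.load(f"{local_dir}/align_weights.pth")
    return ThreeAttnNodes(align_weights)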