# click2mask/scripts/similarity_tests/paper_similarities_tests.py
# (author: omeregev; commit 6df18f5, "Initial commit")
"""Similarity tests as done in the Click2Mask paper https://arxiv.org/abs/2409.08272.
Usage:
python paper_similarities_tests.py --path PATH [--edited_alpha_clip_outs OUTPUT_PATH]
Arguments:
path: A path to a directory which should contain:
1. A sub-directory for each item we wish to compare the methods on.
Each sub-directory should be named as the item's index, and include the files:
img.png - The *input* file from Emu Edit benchmark
emu.png - Emu Edit output
mb.png - MagicBrush output
ours.png - Our output
2. A json named "captions_and_prompts.json", with the following structure:
{
"<item_index>": {
"input_caption": "<Input_caption from Emu Edit benchmark>.",
"output_caption": "<Output_caption from Emu Edit benchmark>",
"prompt": "<Un-localized instruction from Emu Edit benchmark:
instruction from Emu Edit benchmark
without the word indicating addition ('add', 'insert', etc.),
and without the location to be edited>"
},
"<another item_index>": {
...
},
...
}
edited_alpha_clip_outs:
Optional output path for visualization of Edited Alpha-CLIP mask extractions.
See further explanations in methods below.
"""
import argparse
import torch
import os
import json
from similarities import SimTests
from edited_alpha_clip import EditedAlphaCLip
# Module-level shorthand for os.path.join.
join = os.path.join
# Edited Alpha-CLIP similarity
def calc_edited_alpha_clip_sim(edited_alpha_clip, path, methods, texts, save_outs=None):
    """Compute, print and return the mean Edited Alpha-CLIP similarity per method.

    Every sub-directory of ``path`` that contains ``img.png`` is one benchmark
    item. Other sub-directories (e.g. a visualization folder placed inside
    ``path``) are skipped with a message instead of crashing the run.

    Args:
        edited_alpha_clip: Object exposing ``edited_alpha_clip_sim(image_in_path,
            image_out_path, prompt, save_outs=...)``. Its return values are
            pooled with ``torch.cat``, so they are expected to be tensors.
        path: Benchmark directory (layout described in the module docstring).
        methods: Method names; ``<item>/<method>.png`` is scored for each item.
        texts: Mapping item name -> dict with at least a ``"prompt"`` entry.
        save_outs: Optional visualization directory. When set,
            ``save_outs/<item>`` is created and ``save_outs/<item>/<method>``
            is forwarded as the ``save_outs`` argument of the scorer.

    Returns:
        dict mapping each method that has at least one item to its mean
        similarity as a Python float (higher is better).
    """
    print("\nCalculating Edited Alpha-CLIP similarities...")

    # Sorted for a deterministic processing order across runs/platforms.
    dirs = []
    for d in sorted(os.listdir(path)):
        if os.path.isfile(os.path.join(path, d, "img.png")):
            dirs.append(d)
        elif os.path.isdir(os.path.join(path, d)):
            # Not a benchmark item (e.g. the visualization output folder);
            # it has no input image and no caption entry to score.
            print(f"Skipping {d}: no img.png")

    res_edited_alpha_clip = {m: {} for m in methods}
    for method in methods:
        for d in dirs:
            save_to = None
            if save_outs:
                os.makedirs(os.path.join(save_outs, d), exist_ok=True)
                save_to = os.path.join(save_outs, d, method)
            image_in_p = os.path.join(path, d, "img.png")
            image_out_p = os.path.join(path, d, f"{method}.png")
            prompt = texts[d]["prompt"]
            changed_sim_out = edited_alpha_clip.edited_alpha_clip_sim(
                image_in_p, image_out_p, prompt, save_outs=save_to
            )
            res_edited_alpha_clip[method][d] = changed_sim_out

    means = {}
    print(f'{"*" * 4}\nEdited Alpha-CLIP similarities: (higher is better)')
    for method in methods:
        scores = list(res_edited_alpha_clip[method].values())
        if not scores:
            # torch.cat raises on an empty list; report instead of crashing.
            print(f"{method}: n/a (no items)")
            continue
        # reshape(-1) lets 0-dim and multi-element score tensors pool alike
        # without changing the mean of the original 1-D case.
        mean = torch.cat([s.reshape(-1) for s in scores]).mean()
        means[method] = mean.item()
        print(f"{method}: {mean}")
    print(f'{"*" * 4}')
    if save_outs:
        print(f"Extracted masks saved to {save_outs}")
    print("\n")
    return means
# CLIP similarity
def calc_clip_sim(sim_tests, path, methods, texts):
    """Compute, print and return mean CLIP output and directional similarities.

    Every sub-directory of ``path`` that contains ``img.png`` is one benchmark
    item. Other sub-directories (e.g. a visualization folder placed inside
    ``path``) are skipped with a message instead of crashing the run.

    Args:
        sim_tests: Object exposing ``read_image(path)`` and
            ``clip_sim(image_in=..., image_out=..., text_in=..., text_out=...)``.
            The latter must return an ``(output_sim, direction_sim)`` pair;
            both are pooled with ``torch.cat``, so they are expected to be tensors.
        path: Benchmark directory (layout described in the module docstring).
        methods: Method names; ``<item>/<method>.png`` is scored for each item.
        texts: Mapping item name -> dict with ``"input_caption"`` and
            ``"output_caption"`` entries.

    Returns:
        dict with keys ``"clip_out"`` and ``"clip_direction"``. Each maps every
        method that has at least one item to its mean similarity as a Python
        float (higher is better).
    """
    print("Calculating CLIP similarities...")

    # Sorted for a deterministic processing order across runs/platforms.
    dirs = []
    for d in sorted(os.listdir(path)):
        if os.path.isfile(os.path.join(path, d, "img.png")):
            dirs.append(d)
        elif os.path.isdir(os.path.join(path, d)):
            # Not a benchmark item (e.g. the visualization output folder).
            print(f"Skipping {d}: no img.png")

    res_clip_out = {m: {} for m in methods}
    res_clip_direction = {m: {} for m in methods}
    for method in methods:
        for d in dirs:
            # Re-read the input per method, as before: clip_sim's handling of
            # its inputs is not visible here, so sharing one copy is not assumed safe.
            image_in = sim_tests.read_image(os.path.join(path, d, "img.png"))
            image_out = sim_tests.read_image(os.path.join(path, d, f"{method}.png"))
            sim_out, sim_direction = sim_tests.clip_sim(
                image_in=image_in,
                image_out=image_out,
                text_in=texts[d]["input_caption"],
                text_out=texts[d]["output_caption"],
            )
            res_clip_out[method][d] = sim_out
            res_clip_direction[method][d] = sim_direction

    def _report(title, per_method, trailer):
        # Print one metric section and return its per-method float means.
        print(f'{"*" * 4}\n{title}: (higher is better)')
        means = {}
        for method in methods:
            scores = list(per_method[method].values())
            if not scores:
                # torch.cat raises on an empty list; report instead of crashing.
                print(f"{method}: n/a (no items)")
                continue
            # reshape(-1) lets 0-dim and multi-element score tensors pool alike.
            mean = torch.cat([s.reshape(-1) for s in scores]).mean()
            means[method] = mean.item()
            print(f"{method}: {mean}")
        print(f'{"*" * 4}{trailer}')
        return means

    # Dict literal values are evaluated in order: output section prints first.
    return {
        "clip_out": _report("CLIP output similarities", res_clip_out, "\n"),
        "clip_direction": _report(
            "Directional CLIP similarities", res_clip_direction, "\n\n"
        ),
    }
# L1 distance
def calc_l1(sim_tests, path, methods, dest_size=(512, 512)):
    """Compute, print and return the mean L1 distance (input vs. output) per method.

    Every sub-directory of ``path`` that contains ``img.png`` is one benchmark
    item. Other sub-directories (e.g. a visualization folder placed inside
    ``path``) are skipped with a message instead of crashing the run.

    Args:
        sim_tests: Object exposing ``read_image(path, dest_size=...)`` and
            ``L1_dist(image_in=..., image_out=...)``. The distances are pooled
            with ``torch.cat``, so they are expected to be tensors.
        path: Benchmark directory (layout described in the module docstring).
        methods: Method names; ``<item>/<method>.png`` is compared for each item.
        dest_size: Size forwarded to ``read_image`` for both images (defaults to
            the previously hard-coded ``(512, 512)``).

    Returns:
        dict mapping each method that has at least one item to its mean L1
        distance as a Python float (lower is better).
    """
    print("Calculating L1 distances...")

    # Sorted for a deterministic processing order across runs/platforms.
    dirs = []
    for d in sorted(os.listdir(path)):
        if os.path.isfile(os.path.join(path, d, "img.png")):
            dirs.append(d)
        elif os.path.isdir(os.path.join(path, d)):
            # Not a benchmark item (e.g. the visualization output folder).
            print(f"Skipping {d}: no img.png")

    res_l1 = {m: {} for m in methods}
    for method in methods:
        for d in dirs:
            image_in = sim_tests.read_image(
                os.path.join(path, d, "img.png"), dest_size=dest_size
            )
            image_out = sim_tests.read_image(
                os.path.join(path, d, f"{method}.png"), dest_size=dest_size
            )
            res_l1[method][d] = sim_tests.L1_dist(
                image_in=image_in, image_out=image_out
            )

    means = {}
    print(f'{"*" * 4}\nL1 distances: (lower is better)')
    for method in methods:
        dists = list(res_l1[method].values())
        if not dists:
            # torch.cat raises on an empty list; report instead of crashing.
            print(f"{method}: n/a (no items)")
            continue
        # reshape(-1) lets 0-dim and multi-element distance tensors pool alike.
        mean = torch.cat([t.reshape(-1) for t in dists]).mean()
        means[method] = mean.item()
        print(f"{method}: {mean}")
    print(f'{"*" * 4}\n')
    return means
if __name__ == "__main__":
    # Show the module docstring (expected directory layout) in --help output,
    # keeping its line breaks.
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--path", required=True, help="Your path as explained above")
    parser.add_argument(
        "--edited_alpha_clip_outs",
        help="Optional output path for visualization of Edited Alpha-CLIP mask extractions.",
    )
    args = parser.parse_args()
    path = args.path
    edited_alpha_clip_outs = args.edited_alpha_clip_outs
    # Fail early with a usage message instead of a traceback from os.listdir/open.
    if not os.path.isdir(path):
        parser.error(f"--path {path!r} is not a directory")
    methods = ("emu", "mb", "ours")
    # A json with the dictionary explained in the module docstring. JSON text is
    # UTF-8, so don't depend on the platform's default encoding for captions.
    with open(join(path, "captions_and_prompts.json"), "r", encoding="utf-8") as f:
        texts = json.load(f)
    edited_alpha_clip = EditedAlphaCLip()
    calc_edited_alpha_clip_sim(
        edited_alpha_clip, path, methods, texts, save_outs=edited_alpha_clip_outs
    )
    # Drop this script's reference before constructing SimTests so that, if
    # nothing else holds it, its memory can be reclaimed first.
    del edited_alpha_clip
    sim_tests = SimTests()
    calc_clip_sim(sim_tests, path, methods, texts)
    calc_l1(sim_tests, path, methods)