File size: 5,996 Bytes
6df18f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
"""Similarity tests as done in the Click2Mask paper https://arxiv.org/abs/2409.08272.

Usage:
    python clip_similarity_tests.py --path PATH [--edited_alpha_clip_outs OUTPUT_PATH]

    Arguments:
        path: A path to a directory which should contain:

            1. A sub-directory for each item we wish to compare the methods on.
            Each sub-directory should be named as the item's index, and include the files:
            img.png  - The *input* file from Emu Edit benchmark
            emu.png  - Emu Edit output
            mb.png   - MagicBrush output
            ours.png - Our output

            2. A JSON file named "captions_and_prompts.json", with the following structure:
            {
                "<item_index>": {
                        "input_caption": "<Input_caption from Emu Edit benchmark>",
                        "output_caption": "<Output_caption from Emu Edit benchmark>",
                        "prompt": "<Un-localized instruction from Emu Edit benchmark:
                                    instruction from Emu Edit benchmark
                                    without the word indicating addition ('add', 'insert', etc.),
                                    and without the location to be edited>"
                },
                "<another item_index>": {
                        ...
                },
                ...
            }

        edited_alpha_clip_outs:
            Optional output path for visualization of Edited Alpha-CLIP mask extractions.

See further explanations in methods below.
"""

import argparse
import torch
import os
import json
from similarities import SimTests
from edited_alpha_clip import EditedAlphaCLip


# Short alias: every path in this module is built with `join(...)`.
join = os.path.join


# Edited Alpha-CLIP similarity
def calc_edited_alpha_clip_sim(edited_alpha_clip, path, methods, texts, save_outs=None):
    """Print and return the mean Edited Alpha-CLIP similarity of each method.

    For every item directory under ``path`` and every method, this calls
    ``edited_alpha_clip.edited_alpha_clip_sim`` with the path of the item's
    input image (``img.png``), the path of the method's output image
    (``<method>.png``) and the item's ``"prompt"`` text, then averages the
    returned scores per method.

    Args:
        edited_alpha_clip: Object exposing
            ``edited_alpha_clip_sim(image_in_p, image_out_p, prompt, save_outs=...)``.
            Each call must return a tensor accepted by ``torch.cat``.
        path: Root directory holding one sub-directory per item.
        methods: Sequence of method names; each one names a ``<method>.png``
            file expected inside every item directory.
        texts: Mapping keyed by item directory name whose values contain a
            ``"prompt"`` entry.
        save_outs: Optional output directory. When given,
            ``<save_outs>/<item>`` is created and
            ``<save_outs>/<item>/<method>`` is forwarded to the scorer as its
            ``save_outs`` argument; otherwise ``None`` is forwarded.

    Returns:
        dict mapping each method name to its mean similarity (0-dim tensor).

    Raises:
        RuntimeError: If ``path`` contains no item directory.
        KeyError: If an item directory has no ``"prompt"`` entry in ``texts``.
    """
    print("\nCalculating Edited Alpha-CLIP similarities...")

    # An item is a sub-directory that holds the input image. Anything else
    # living under `path` (e.g. a `save_outs` folder from a previous run) is
    # not an item and is skipped instead of crashing on a missing caption.
    # Sorting makes the evaluation order independent of the file system.
    items = sorted(
        d for d in os.listdir(path) if os.path.isfile(join(path, d, "img.png"))
    )
    if not items:
        raise RuntimeError(
            f"No item directories containing 'img.png' found in {path!r}"
        )

    res_edited_alpha_clip = {m: {} for m in methods}
    for method in methods:
        for d in items:
            image_in_p = join(path, d, "img.png")
            image_out_p = join(path, d, f"{method}.png")
            prompt = texts[d]["prompt"]

            save_to = None
            if save_outs:
                os.makedirs(join(save_outs, d), exist_ok=True)
                save_to = join(save_outs, d, method)
            res_edited_alpha_clip[method][d] = edited_alpha_clip.edited_alpha_clip_sim(
                image_in_p, image_out_p, prompt, save_outs=save_to
            )

    # Per-item scores are concatenated so the mean weighs every score equally.
    means = {
        m: torch.cat(list(res_edited_alpha_clip[m].values())).mean() for m in methods
    }

    print(f'{"*" * 4}\nEdited Alpha-CLIP similarities: (higher is better)')
    for method in methods:
        print(f"{method}: {means[method]}")
    print(f'{"*" * 4}')
    if save_outs:
        print(f"Extracted masks saved to {save_outs}")
    print("\n")

    return means


# CLIP similarity
def calc_clip_sim(sim_tests, path, methods, texts):
    """Print and return the mean CLIP similarities of each method.

    For every item directory under ``path`` and every method, the input image
    (``img.png``) and the method's output image (``<method>.png``) are loaded
    with ``sim_tests.read_image`` and passed, together with the item's input
    and output captions, to ``sim_tests.clip_sim``. The two scores it returns
    are reported as the "CLIP output" and "Directional CLIP" similarities and
    averaged per method.

    Args:
        sim_tests: Object exposing ``read_image(path)`` and
            ``clip_sim(image_in=..., image_out=..., text_in=..., text_out=...)``.
            ``clip_sim`` must return a pair of tensors accepted by
            ``torch.cat``.
        path: Root directory holding one sub-directory per item.
        methods: Sequence of method names; each one names a ``<method>.png``
            file expected inside every item directory.
        texts: Mapping keyed by item directory name whose values contain
            ``"input_caption"`` and ``"output_caption"`` entries.

    Returns:
        Tuple ``(output_means, direction_means)``; each is a dict mapping a
        method name to its mean score (0-dim tensor).

    Raises:
        RuntimeError: If ``path`` contains no item directory.
        KeyError: If an item directory has no caption entries in ``texts``.
    """
    print("Calculating CLIP similarities...")

    # An item is a sub-directory that holds the input image. Other folders
    # under `path` (e.g. saved mask visualizations) are skipped rather than
    # failing on a missing caption. Sorting keeps the order deterministic.
    items = sorted(
        d for d in os.listdir(path) if os.path.isfile(join(path, d, "img.png"))
    )
    if not items:
        raise RuntimeError(
            f"No item directories containing 'img.png' found in {path!r}"
        )

    res_clip_out = {m: {} for m in methods}
    res_clip_direction = {m: {} for m in methods}
    for method in methods:
        for d in items:
            image_in = sim_tests.read_image(join(path, d, "img.png"))
            image_out = sim_tests.read_image(join(path, d, f"{method}.png"))

            sim_out, sim_direction = sim_tests.clip_sim(
                image_in=image_in,
                image_out=image_out,
                text_in=texts[d]["input_caption"],
                text_out=texts[d]["output_caption"],
            )
            res_clip_out[method][d] = sim_out
            res_clip_direction[method][d] = sim_direction

    output_means = {
        m: torch.cat(list(res_clip_out[m].values())).mean() for m in methods
    }
    direction_means = {
        m: torch.cat(list(res_clip_direction[m].values())).mean() for m in methods
    }

    print(f'{"*" * 4}\nCLIP output similarities: (higher is better)')
    for method in methods:
        print(f"{method}: {output_means[method]}")
    print(f'{"*" * 4}\n')

    print(f'{"*" * 4}\nDirectional CLIP similarities: (higher is better)')
    for method in methods:
        print(f"{method}: {direction_means[method]}")
    print(f'{"*" * 4}\n\n')

    return output_means, direction_means


# L1 distance
def calc_l1(sim_tests, path, methods):
    """Print and return the mean L1 distance between inputs and each method's outputs.

    For every item directory under ``path`` and every method, the input image
    (``img.png``) and the method's output image (``<method>.png``) are loaded
    with ``sim_tests.read_image(..., dest_size=(512, 512))`` and compared with
    ``sim_tests.L1_dist``; the distances are averaged per method.

    Args:
        sim_tests: Object exposing ``read_image(path, dest_size=...)`` and
            ``L1_dist(image_in=..., image_out=...)``. ``L1_dist`` must return
            a tensor accepted by ``torch.cat``.
        path: Root directory holding one sub-directory per item.
        methods: Sequence of method names; each one names a ``<method>.png``
            file expected inside every item directory.

    Returns:
        dict mapping each method name to its mean L1 distance (0-dim tensor).

    Raises:
        RuntimeError: If ``path`` contains no item directory.
    """
    print("Calculating L1 distances...")

    # An item is a sub-directory that holds the input image. Other folders
    # under `path` (e.g. saved mask visualizations) are skipped rather than
    # failing on a missing image. Sorting keeps the order deterministic.
    items = sorted(
        d for d in os.listdir(path) if os.path.isfile(join(path, d, "img.png"))
    )
    if not items:
        raise RuntimeError(
            f"No item directories containing 'img.png' found in {path!r}"
        )

    # Both images are loaded at the same size so they can be compared directly.
    size = (512, 512)
    res_L1 = {m: {} for m in methods}
    for method in methods:
        for d in items:
            image_in = sim_tests.read_image(join(path, d, "img.png"), dest_size=size)
            image_out = sim_tests.read_image(
                join(path, d, f"{method}.png"), dest_size=size
            )
            res_L1[method][d] = sim_tests.L1_dist(
                image_in=image_in, image_out=image_out
            )

    means = {m: torch.cat(list(res_L1[m].values())).mean() for m in methods}

    print(f'{"*" * 4}\nL1 distances: (lower is better)')
    for method in methods:
        print(f"{method}: {means[method]}")
    print(f'{"*" * 4}\n')

    return means


def main():
    """Parse the command line and run all three evaluations on ``--path``.

    Loads ``captions_and_prompts.json`` from the given directory, then runs
    the Edited Alpha-CLIP similarity, the CLIP similarities and the L1
    distance for the methods ``emu``, ``mb`` and ``ours``. Results are
    printed by the individual ``calc_*`` functions.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--path", required=True, help="Your path as explained above")
    parser.add_argument(
        "--edited_alpha_clip_outs",
        help="Optional output path for visualization of Edited Alpha-CLIP mask extractions.",
    )
    args = parser.parse_args()
    path = args.path
    edited_alpha_clip_outs = args.edited_alpha_clip_outs

    methods = ("emu", "mb", "ours")

    # A json with the dictionary explained above.
    # The encoding is explicit so captions decode the same way on every
    # platform instead of depending on the locale's default encoding.
    with open(join(path, "captions_and_prompts.json"), "r", encoding="utf-8") as f:
        texts = json.load(f)

    edited_alpha_clip = EditedAlphaCLip()
    calc_edited_alpha_clip_sim(
        edited_alpha_clip, path, methods, texts, save_outs=edited_alpha_clip_outs
    )

    sim_tests = SimTests()
    calc_clip_sim(sim_tests, path, methods, texts)
    calc_l1(sim_tests, path, methods)


if __name__ == "__main__":
    main()