File size: 2,325 Bytes
8fbac9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac86774
 
 
8fbac9e
 
 
19d656f
8fbac9e
 
 
 
 
995510e
8fbac9e
ac86774
 
8fbac9e
 
 
 
19d656f
 
 
 
8fbac9e
 
 
 
 
 
 
 
19d656f
 
 
 
8fbac9e
 
 
 
 
ac86774
 
 
 
 
 
 
 
 
8fbac9e
19d656f
8fbac9e
 
 
 
 
7593627
 
8fbac9e
 
 
9f4163a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import gradio as gr
import numpy as np
from CLIP.clip import ClipWrapper, saliency_configs
from time import time
from matplotlib import pyplot as plt
import io
from PIL import Image


def plot_to_png(fig):
    """Rasterize a Matplotlib figure to PNG and return the pixels as an array.

    Args:
        fig: The figure to render; it is saved via its own ``savefig`` method.

    Returns:
        A ``np.uint8`` array holding the decoded PNG image.
    """
    # Save through the figure that was passed in instead of pyplot's implicit
    # "current figure", so the result does not depend on global pyplot state.
    with io.BytesIO() as buf:
        fig.savefig(buf, format="png")
        buf.seek(0)
        # np.array() copies the decoded pixels out before the buffer is closed.
        return np.array(Image.open(buf)).astype(np.uint8)


def generate_relevancy(
    img: np.array, labels: str, prompt: str, saliency_config: str, subtract_mean: bool
):
    """Render one relevancy-map overlay per text label for the given image.

    Args:
        img: Input image as a ``np.uint8`` array.
        labels: Comma-separated label names. Surrounding whitespace is
            stripped and empty entries are ignored.
        prompt: Prompt forwarded (as a one-element list) to
            ``ClipWrapper.get_clip_saliency``.
        saliency_config: Key into ``saliency_configs`` selecting the saliency
            settings factory.
        subtract_mean: If true, subtract the mean taken over axis 0 of the
            saliency maps from every map before plotting.

    Returns:
        A list of ``np.uint8`` image arrays, one per (saliency map, label)
        pair, each showing the resized image with a colored overlay. The list
        is empty when no non-empty label was supplied.
    """
    # Tolerate "a, b" style input and stray/trailing commas.
    labels = [name.strip() for name in labels.split(",")]
    labels = [name for name in labels if name]
    if not labels:
        return []
    prompts = [prompt]
    assert img.dtype == np.uint8
    # PIL performs the resize, but the code below needs an ndarray (``.shape``),
    # so convert the resized PIL image back to a NumPy array.
    img = np.asarray(Image.fromarray(img).resize((244 * 2, 244 * 2)))
    height = img.shape[0]
    grads = ClipWrapper.get_clip_saliency(
        img=img,
        text_labels=np.array(labels),
        prompts=prompts,
        **saliency_configs[saliency_config](height),
    )[0]
    if subtract_mean:
        grads -= grads.mean(axis=0)
    grads = grads.cpu().numpy()

    # Saliency values are linearly rescaled from [vmin, vmax] to [0, 1] and
    # clipped before being color-mapped.
    vmin = 0.002
    vmax = 0.008
    cmap = plt.get_cmap("jet")

    returns = []
    for label_grad, label in zip(grads, labels):
        fig, ax = plt.subplots(1, 1)
        ax.axis("off")
        ax.imshow(img)
        ax.set_title(label, fontsize=12)
        grad = np.clip((label_grad - vmin) / (vmax - vmin), a_min=0.0, a_max=1.0)
        colored_grad = cmap(grad)
        # Overlay alpha is 0.7 * (1 - normalized value): the higher the
        # normalized saliency, the more transparent the overlay.
        colored_grad[..., -1] = (1 - grad) * 0.7
        ax.imshow(colored_grad)
        plt.tight_layout(pad=0)
        returns.append(plot_to_png(fig))
        # Close each figure so repeated requests do not accumulate figures.
        plt.close(fig)
    return returns


# Input widgets, listed in the positional order of generate_relevancy's
# parameters (img, labels, prompt, saliency_config, subtract_mean).
relevancy_inputs = [
    gr.Image(type="numpy", label="Image"),
    gr.Textbox(label="Labels (comma separated)"),
    gr.Textbox(label="Prompt"),
    gr.Dropdown(
        value="ours",
        choices=["ours", "chefer_et_al"],
        label="Relevancy Configuration",
    ),
    gr.Checkbox(value=True, label="subtract mean"),
]

# Gallery that displays the list of images returned by generate_relevancy.
relevancy_output = gr.Gallery(label="Relevancy Maps", type="numpy")

# One pre-filled example row; values line up with relevancy_inputs.
relevancy_examples = [
    [
        "https://semantic-abstraction.cs.columbia.edu/downloads/matterport.png",
        "basketball jersey,nintendo switch,television,ping pong table,vase,fireplace,abstract painting of a vespa,carpet,wall",
        "a photograph of a {} in a home.",
        "ours",
        True,
    ]
]

iface = gr.Interface(
    fn=generate_relevancy,
    inputs=relevancy_inputs,
    outputs=relevancy_output,
    examples=relevancy_examples,
)
# Start the web UI as soon as the script is executed.
iface.launch()