Spaces:
Build error
Build error
import gradio as gr | |
import numpy as np | |
from CLIP.clip import ClipWrapper, saliency_configs | |
from time import time | |
from matplotlib import pyplot as plt | |
import io | |
from PIL import Image | |
def plot_to_png(fig):
    """Rasterise a matplotlib figure into a pixel array via an in-memory PNG.

    Args:
        fig: The matplotlib figure to render.

    Returns:
        A ``np.uint8`` array decoded from the PNG encoding of ``fig``.
    """
    # Save through the figure itself rather than ``plt.savefig``: the latter
    # writes pyplot's *current* figure, which silently ignores the argument
    # whenever ``fig`` is not the active one.
    with io.BytesIO() as buf:
        fig.savefig(buf, format="png")
        buf.seek(0)
        # np.array() forces the image to be decoded while the buffer is
        # still open, so returning from inside the ``with`` is safe.
        return np.array(Image.open(buf)).astype(np.uint8)
def generate_relevancy(
    img: np.array, labels: str, prompt: str, saliency_config: str, subtract_mean: bool
):
    """Plot one saliency overlay per label on top of ``img``.

    Args:
        img: Input image as a 3-dimensional ``np.uint8`` array.
        labels: Comma-separated label names. Surrounding whitespace is
            stripped and empty entries are dropped.
        prompt: Prompt string forwarded to ``ClipWrapper.get_clip_saliency``
            as a one-element list.
        saliency_config: Key into ``saliency_configs`` selecting the
            configuration factory, which is called with the image height.
        subtract_mean: When true, subtract the mean over axis 0 of the
            saliency result before plotting.

    Returns:
        The rendered figure as a ``np.uint8`` array (see ``plot_to_png``).

    Raises:
        ValueError: If ``labels`` contains no non-empty label.
        AssertionError: If ``img`` is not of dtype ``np.uint8``.
    """
    # "a, b," should mean the labels "a" and "b", not "a", " b" and "".
    stripped = (label.strip() for label in labels.split(","))
    label_list = [label for label in stripped if label]
    if not label_list:
        raise ValueError("labels must contain at least one non-empty label")

    assert img.dtype == np.uint8, "img must have dtype uint8"
    h, _, _ = img.shape

    # Only the first element of the result is used.
    # NOTE(review): presumably indexed as (label, row, column) -- confirm
    # against ClipWrapper.get_clip_saliency.
    grads = ClipWrapper.get_clip_saliency(
        img=img,
        text_labels=np.array(label_list),
        prompts=[prompt],
        **saliency_configs[saliency_config](h),
    )[0]
    if subtract_mean:
        grads -= grads.mean(axis=0)
    grads = grads.cpu().numpy()

    # Smallest square grid with room for every label.
    num_axes = int(np.ceil(np.sqrt(len(label_list))))
    # squeeze=False always yields a 2-D array of axes, so the 1x1 grid needs
    # no special case.
    fig, axes = plt.subplots(num_axes, num_axes, squeeze=False)
    try:
        axes = axes.flatten()
        # Hide every cell up front so unused grid cells do not show empty
        # frames and ticks when the label count is not a perfect square.
        for ax in axes:
            ax.axis("off")

        vmin, vmax = 0.002, 0.008
        cmap = plt.get_cmap("jet")
        for ax, label_grad, label in zip(axes, grads, label_list):
            ax.imshow(img)
            ax.set_title(label, fontsize=12)
            # Normalise into [0, 1] over the fixed [vmin, vmax] window.
            grad = np.clip((label_grad - vmin) / (vmax - vmin), a_min=0.0, a_max=1.0)
            colored_grad = cmap(grad)
            # Alpha is inverted: the overlay is most opaque (0.7) where the
            # normalised value is 0 and fully transparent where it is 1.
            colored_grad[..., -1] = (1 - grad) * 0.7
            ax.imshow(colored_grad)

        fig.tight_layout(pad=0)
        return plot_to_png(fig)
    finally:
        # Release the figure even if plotting or encoding fails.
        plt.close(fig)
# Gradio front end for generate_relevancy: five inputs (image, labels, prompt,
# relevancy configuration, mean subtraction) mapped positionally onto the
# function's parameters, and a single image output.
iface = gr.Interface(
    fn=generate_relevancy,
    inputs=[
        gr.Image(type="numpy", label="Image"),
        gr.Textbox(label="Labels (comma separated)"),
        gr.Textbox(label="Prompt"),
        gr.Dropdown(
            value="ours",
            choices=["ours", "chefer_et_al"],
            label="Relevancy Configuration",
        ),
        gr.Checkbox(value=True, label="subtract mean"),
    ],
    outputs=gr.Image(type="numpy"),
    # The example row supplies values for the first three inputs only.
    examples=[
        [
            "https://semantic-abstraction.cs.columbia.edu/downloads/matterport.png",
            "basketball jersey,nintendo switch,television,ping pong table,vase,fireplace,abstract painting of a vespa,carpet,wall",
            "a photograph of a {} in a home.",
        ]
    ],
)

# Start the server only when executed as a script, so importing this module
# (e.g. to reuse generate_relevancy) does not launch the app as a side effect.
if __name__ == "__main__":
    iface.launch(inbrowser=True)