huy-ha's picture
return results in gallery
19d656f
raw
history blame
2.33 kB
import gradio as gr
import numpy as np
from CLIP.clip import ClipWrapper, saliency_configs
from time import time
from matplotlib import pyplot as plt
import io
from PIL import Image
def plot_to_png(fig):
    """Rasterize a matplotlib figure and return its pixels as a uint8 array.

    The figure is encoded as PNG into an in-memory buffer and decoded again
    with PIL, so nothing is written to disk.

    Args:
        fig: The matplotlib figure to render.

    Returns:
        A ``np.uint8`` NumPy array holding the decoded PNG pixels.
    """
    # Save through the figure itself: `plt.savefig` writes whichever figure
    # pyplot currently considers active and would silently ignore `fig`.
    with io.BytesIO() as buf:
        fig.savefig(buf, format="png")
        buf.seek(0)
        # np.array forces PIL to decode the pixels while the buffer is open.
        return np.array(Image.open(buf)).astype(np.uint8)
def generate_relevancy(
    img: np.ndarray, labels: str, prompt: str, saliency_config: str, subtract_mean: bool
) -> list:
    """Render one relevancy overlay per label for the given image.

    The image is resized to 488x488, passed to
    ``ClipWrapper.get_clip_saliency`` together with the labels and the prompt
    template, and each resulting per-label map is drawn over the image as a
    semi-transparent "jet" heat map.

    Args:
        img: Input image as a ``np.uint8`` array (height, width, channels).
        labels: Comma-separated label names. Surrounding whitespace is
            stripped and empty entries are dropped.
        prompt: Prompt template forwarded to the saliency call as a
            single-element list.
        saliency_config: Key into ``saliency_configs`` selecting the saliency
            settings.
        subtract_mean: When true, subtract the mean over the first axis of the
            saliency maps (one entry per label) from every map.

    Returns:
        A list of uint8 image arrays, one per label, in label order. The list
        is empty when ``labels`` contains no non-blank entries.
    """
    # Tolerate "a, b" style input and stray commas; a blank label carries no
    # information for the saliency call and would only yield an untitled plot.
    labels = [label.strip() for label in labels.split(",")]
    labels = [label for label in labels if label]
    if not labels:
        return []
    prompts = [prompt]
    assert img.dtype == np.uint8
    # `Image.resize` returns a PIL image, which has no `.shape`; convert back
    # to an ndarray so the unpacking below (and downstream code) works.
    img = np.array(Image.fromarray(img).resize((244 * 2, 244 * 2)))
    # Unpacking three values also checks that the image is (H, W, C).
    h, _, _ = img.shape
    # NOTE(review): the result is treated as a tensor with `.cpu()` below,
    # presumably a torch tensor indexed by label on axis 0 -- confirm.
    grads = ClipWrapper.get_clip_saliency(
        img=img,
        text_labels=np.array(labels),
        prompts=prompts,
        **saliency_configs[saliency_config](h),
    )[0]
    if subtract_mean:
        grads -= grads.mean(axis=0)
    grads = grads.cpu().numpy()
    # Saliency values are normalized into [0, 1] using this fixed window.
    vmin = 0.002
    vmax = 0.008
    cmap = plt.get_cmap("jet")
    overlays = []
    for label_grad, label in zip(grads, labels):
        fig, ax = plt.subplots(1, 1)
        try:
            ax.axis("off")
            ax.imshow(img)
            ax.set_title(label, fontsize=12)
            grad = np.clip((label_grad - vmin) / (vmax - vmin), a_min=0.0, a_max=1.0)
            colored_grad = cmap(grad)
            # Alpha is driven by the inverted map: the overlay is most opaque
            # (0.7) where the normalized value is 0 and transparent where it
            # is 1.
            colored_grad[..., -1] = (1 - grad) * 0.7
            ax.imshow(colored_grad)
            fig.tight_layout(pad=0)
            overlays.append(plot_to_png(fig))
        finally:
            # Always release the figure so a rendering failure does not leak
            # pyplot state across requests.
            plt.close(fig)
    return overlays
# Input widgets, in the same order as generate_relevancy's parameters:
# image, labels, prompt, saliency configuration, subtract-mean flag.
relevancy_inputs = [
    gr.Image(type="numpy", label="Image"),
    gr.Textbox(label="Labels (comma separated)"),
    gr.Textbox(label="Prompt"),
    gr.Dropdown(
        value="ours",
        choices=["ours", "chefer_et_al"],
        label="Relevancy Configuration",
    ),
    gr.Checkbox(value=True, label="subtract mean"),
]

# The list of rendered overlays returned by generate_relevancy is shown as a
# gallery.
relevancy_output = gr.Gallery(label="Relevancy Maps", type="numpy")

# A single pre-filled example row; values line up with `relevancy_inputs`.
example_rows = [
    [
        "https://semantic-abstraction.cs.columbia.edu/downloads/matterport.png",
        "basketball jersey,nintendo switch,television,ping pong table,vase,fireplace,abstract painting of a vespa,carpet,wall",
        "a photograph of a {} in a home.",
        "ours",
        True,
    ]
]

# Wire the widgets to the relevancy function and start the app.
iface = gr.Interface(
    fn=generate_relevancy,
    inputs=relevancy_inputs,
    outputs=relevancy_output,
    examples=example_rows,
)
iface.launch()