File size: 1,843 Bytes
b1c60a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import gradio as gr
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch
import numpy as np
import cv2

# Both the model weights and its matching preprocessor come from the same
# Hugging Face checkpoint, so the identifier is spelled out only once.
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"

processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)
model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)

def apply_gradcam(image, text):
    """Overlay a Grad-CAM style heatmap showing which image regions match *text*.

    The relevance map is computed on the CLIP vision transformer: the cosine
    similarity between the image and text embeddings is differentiated with
    respect to the patch tokens that feed the last encoder layer, the
    gradients are averaged into per-channel weights, and the weighted
    activations are summed into one score per patch.

    Args:
        image: PIL image to analyse. Any mode is accepted; it is converted to
            RGB before processing.
        text: Description to localise in the image. Text longer than the
            tokenizer's maximum length is truncated.

    Returns:
        A ``uint8`` NumPy array of shape ``(height, width, 3)`` in RGB order:
        the input image blended 60/40 with a JET-coloured heatmap.

    Raises:
        ValueError: If *image* is ``None``.
        RuntimeError: If the number of patch tokens is not a perfect square,
            so the tokens cannot be laid out as a 2-D grid.
    """
    if image is None:
        raise ValueError("apply_gradcam requires an image, got None")

    # The final blend needs two 3-channel uint8 arrays, so drop any alpha
    # channel and expand grayscale/palette images up front.
    rgb_image = image.convert("RGB")
    width, height = rgb_image.size

    inputs = processor(
        text=[text],
        images=rgb_image,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

    # Make sure autograd is recording even if a caller disabled it globally.
    with torch.enable_grad():
        outputs = model(**inputs, output_hidden_states=True)

        # The image embedding is pooled from the class token of the final
        # hidden state only, so patch tokens of that last state receive no
        # gradient. The state feeding the last encoder layer still influences
        # the class token through attention, hence hidden_states[-2].
        activations = outputs.vision_model_output.hidden_states[-2]

        similarity = torch.nn.functional.cosine_similarity(
            outputs.image_embeds, outputs.text_embeds
        )
        # autograd.grad returns the gradient directly and, unlike backward(),
        # does not accumulate .grad on the model parameters between requests.
        (gradients,) = torch.autograd.grad(similarity.sum(), activations)

    # Token 0 is the class token; the remaining tokens are the image patches.
    patch_activations = activations[:, 1:, :]
    patch_gradients = gradients[:, 1:, :]

    # Grad-CAM: per-channel weight = gradient averaged over spatial positions.
    channel_weights = patch_gradients.mean(dim=1, keepdim=True)
    cam = torch.relu((patch_activations * channel_weights).sum(dim=-1))

    num_patches = cam.shape[-1]
    grid_side = int(round(num_patches ** 0.5))
    if grid_side * grid_side != num_patches:
        raise RuntimeError(
            f"cannot arrange {num_patches} patch tokens into a square grid"
        )

    heatmap = cam[0].reshape(grid_side, grid_side).detach().cpu().numpy()
    heatmap = heatmap.astype(np.float32)

    # Scale to [0, 1]; an all-zero map is left as is to avoid dividing by zero.
    peak = float(heatmap.max())
    if peak > 0:
        heatmap = heatmap / peak

    # cv2.resize takes (width, height), which is also PIL's size ordering.
    heatmap = cv2.resize(heatmap, (width, height))
    heatmap = np.uint8(255 * np.clip(heatmap, 0.0, 1.0))
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    # applyColorMap emits BGR; convert so it matches the RGB image below.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    superimposed_img = cv2.addWeighted(np.array(rgb_image), 0.6, heatmap, 0.4, 0)
    return superimposed_img

def highlight_image(image, text):
    """Gradio callback: return *image* with the regions matching *text* highlighted.

    Args:
        image: PIL image supplied by the UI, or ``None`` when the user
            submitted the form without uploading one.
        text: Description of what to highlight.

    Returns:
        A PIL image containing the heatmap overlay, or ``None`` when no image
        was provided (which leaves the output component empty).
    """
    # Submitting without an upload yields None; return an empty result
    # instead of failing deep inside the model pipeline.
    if image is None:
        return None

    highlighted_image = apply_gradcam(image, text)
    return Image.fromarray(highlighted_image)

# Describe the UI: an uploaded picture plus a free-text prompt go in,
# the annotated picture comes out.
_image_input = gr.Image(type="pil")
_prompt_input = gr.Textbox(label="Text Description")
_result_output = gr.Image(type="pil")

iface = gr.Interface(
    fn=highlight_image,
    inputs=[_image_input, _prompt_input],
    outputs=_result_output,
    title="Image Text Highlight",
    description="Upload an image and provide a text description to highlight the relevant part of the image.",
)

# Start serving the app
iface.launch()