File size: 2,650 Bytes
184fd64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8bd0d3
184fd64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84fc022
184fd64
 
477e5a7
184fd64
b8bd0d3
 
 
184fd64
 
 
477e5a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184fd64
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
import cv2
import gradio as gr
import numpy as np
import requests
from PIL import Image
from io import BytesIO
from transformers import OwlViTProcessor, OwlViTForObjectDetection


# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The model and its processor are loaded from the same pretrained checkpoint.
_CHECKPOINT = "google/owlvit-large-patch14"

model = OwlViTForObjectDetection.from_pretrained(_CHECKPOINT).to(device)
model.eval()  # put the model in inference mode; it is never trained here
processor = OwlViTProcessor.from_pretrained(_CHECKPOINT)


def query_image(img, text_queries, score_threshold):
    """Run zero-shot OWL-ViT detection and draw the matching boxes on the image.

    Args:
        img: Input image from the Gradio ``Image`` component. It is converted
            with ``np.array`` before use (presumably an H x W x C RGB uint8
            array -- TODO confirm against the component configuration).
        text_queries: Comma-separated free-text object descriptions, e.g.
            ``"cat, remote control"``. Whitespace around each description is
            stripped and empty entries (stray/trailing commas) are ignored.
        score_threshold: Minimum detection score for a box to be drawn.

    Returns:
        The image as a numpy array with a rectangle and the matching query
        text drawn for every detection scoring at least ``score_threshold``.
        If no non-empty query is given, the image is returned undrawn.
    """
    # Strip so "cat, dog" queries "dog" rather than " dog"; drop empty pieces.
    queries = [q.strip() for q in (text_queries or "").split(",")]
    queries = [q for q in queries if q]

    img = np.array(img)
    if not queries:
        return img

    height = img.shape[0]
    target_sizes = torch.Tensor([img.shape[:2]])
    inputs = processor(text=queries, images=img, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # target_sizes is created on the CPU, so bring the predictions there too
    # before post-processing.
    outputs.logits = outputs.logits.cpu()
    outputs.pred_boxes = outputs.pred_boxes.cpu()
    # NOTE(review): newer transformers releases deprecate `post_process` in
    # favour of `post_process_object_detection`; kept for compatibility.
    results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
    boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]

    font = cv2.FONT_HERSHEY_SIMPLEX
    color = (255, 0, 0)

    for box, score, label in zip(boxes, scores, labels):
        if score >= score_threshold:
            x0, y0, x1, y1 = (int(v) for v in box.tolist())
            img = cv2.rectangle(img, (x0, y0), (x1, y1), color, 5)

            # Draw the label just below the box, unless that would fall past
            # the bottom of this image; then draw it inside the box's bottom edge.
            y = y1 - 10 if y1 + 25 > height else y1 + 25

            img = cv2.putText(
                img, queries[int(label)], (x0, y), font, 1, color, 2, cv2.LINE_AA
            )
    return img


description = """
\n\nYou can use OWL-ViT to query images with text descriptions of any object.
To use it, simply upload an image and enter comma separated text descriptions of objects you want to query the image for. You
can also use the score threshold slider to set a threshold to filter out low probability predictions.
"""

# Each example row must supply a value for every input component
# (image, text queries, score threshold); Gradio only accepts a flat list of
# values when an interface has a single input.
# NOTE(review): the query text is a placeholder -- set it to objects that
# actually appear in the example image.
_EXAMPLES = [["./examples/IMGP0178.jpg", "person", 0.1]]


def _build_interface(source):
    """Build one detection tab whose image input is taken from ``source``.

    ``source`` is passed straight to ``gr.Image`` ("upload" or "webcam" here;
    this is the Gradio 3.x keyword).
    """
    return gr.Interface(
        query_image,
        inputs=[
            gr.Image(source=source),
            "text",
            gr.Slider(0, 1, value=0.1),
        ],
        outputs="image",
        title="Zero-Shot Object Detection with OWL-ViT",
        description=description,
        examples=_EXAMPLES,
    )


upload = _build_interface("upload")
web = _build_interface("webcam")

demo = gr.TabbedInterface(
    interface_list=[upload, web],
    tab_names=["From a File", "From your Webcam"],
)

# Start the server only when executed as a script (e.g. `python app.py`),
# not when this module is merely imported.
if __name__ == "__main__":
    demo.launch()