File size: 2,084 Bytes
e4c1abf
1e2b659
 
 
b4b347f
1e2b659
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4c1abf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e81a5a2
b4b347f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from transformers import pipeline
import io 
import matplotlib.pyplot as plt 
import torch
from PIL import Image

def render_results_in_image(in_pil_img, in_results):
    """Draw object-detection results over an image and return the rendering.

    Args:
        in_pil_img: The source image, handed straight to ``Axes.imshow``
            (a PIL image in this app).
        in_results: Iterable of prediction dicts. Each must provide
            ``prediction['box']`` with ``'xmin'``, ``'ymin'``, ``'xmax'`` and
            ``'ymax'`` coordinates, plus ``'label'`` and ``'score'`` entries.

    Returns:
        A PIL ``Image`` (decoded from an in-memory PNG) showing the input with
        each box drawn as a green rectangle and a red "label: score%" caption
        anchored at the box's (xmin, ymin) corner.
    """
    # Keep an explicit handle to the figure rather than relying on pyplot's
    # global "current figure": Gradio may serve several requests at once, and
    # a bare plt.close() could close another request's figure. The try/finally
    # guarantees the figure is released even if drawing or saving fails.
    fig = plt.figure(figsize=(16, 10))
    try:
        ax = fig.gca()
        ax.imshow(in_pil_img)

        for prediction in in_results:
            box = prediction['box']
            x, y = box['xmin'], box['ymin']
            w = box['xmax'] - box['xmin']
            h = box['ymax'] - box['ymin']

            ax.add_patch(plt.Rectangle((x, y),
                                       w,
                                       h,
                                       fill=False,
                                       color="green",
                                       linewidth=2))
            ax.text(
                x,
                y,
                f"{prediction['label']}: {round(prediction['score']*100, 1)}%",
                color='red'
            )

        ax.axis("off")

        # Render the annotated figure to PNG in memory and reopen it as a PIL image.
        img_buf = io.BytesIO()
        fig.savefig(img_buf, format='png',
                    bbox_inches='tight',
                    pad_inches=0)
        img_buf.seek(0)
        modified_image = Image.open(img_buf)
    finally:
        # Close exactly the figure created above so it is never displayed or leaked.
        plt.close(fig)

    return modified_image


# Object-detection pipeline using the "facebook/detr-resnet-50" checkpoint.
# It is built once at module level and reused for every request below.
od_pipe = pipeline(task="object-detection", model="facebook/detr-resnet-50")

import gradio as gr

def get_pipeline_prediction(pil_image):
    """Run object detection on an image and return it annotated with the results.

    Args:
        pil_image: The image from the Gradio input component (``type="pil"``),
            or ``None`` when the user submits without uploading anything.

    Returns:
        The image produced by ``render_results_in_image`` for the pipeline's
        predictions.

    Raises:
        gr.Error: If no image was provided; Gradio shows the message in the UI.
    """
    # An empty Gradio image input arrives as None; report it clearly instead of
    # letting the detection pipeline fail with an opaque error.
    if pil_image is None:
        raise gr.Error("Please upload an image first.")

    # First get the pipeline output for the image, then draw it onto the image.
    pipeline_output = od_pipe(pil_image)
    return render_results_in_image(pil_image, pipeline_output)

# Input and output components for the web UI: both exchange PIL images.
_input_image = gr.Image(label="Input Image", type="pil")
_output_image = gr.Image(label="Output Image with predictions", type="pil")

# Wire the detection function into a single-input/single-output Gradio app.
demo = gr.Interface(
    fn=get_pipeline_prediction,
    inputs=_input_image,
    outputs=_output_image,
    title="Object Detection API",
    description="Just upload your image and let ObjectDetect API work its magic, revealing the objects waiting to be discovered",
)

# Start the server (runs whenever this module is executed or imported).
demo.launch()