|
import spaces |
|
import gradio as gr |
|
import cv2 |
|
import tempfile |
|
from PIL import Image, ImageDraw, ImageFont |
|
from transformers import RTDetrForObjectDetection, RTDetrImageProcessor |
|
import torch |
|
import requests |
|
|
|
# Hugging Face Hub id shared by the preprocessor and the detector.
_CHECKPOINT = "PekingU/rtdetr_r50vd"

image_processor = RTDetrImageProcessor.from_pretrained(_CHECKPOINT)

# Load the detector with float16 weights, move it to the GPU, and wrap it with
# torch.compile ("reduce-overhead" mode). Callers must feed it CUDA float16
# inputs to match.
model = torch.compile(
    RTDetrForObjectDetection.from_pretrained(
        _CHECKPOINT, torch_dtype=torch.float16
    ).to("cuda"),
    mode="reduce-overhead",
)
|
|
|
|
|
# Warm-up: run one detection on a sample COCO image at import time —
# presumably so the torch.compile'd model is compiled before the first webcam
# frame arrives.
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'

# The timeout keeps a stalled download from hanging start-up indefinitely, and
# raise_for_status turns an HTTP error into a clear exception instead of a
# confusing "cannot identify image file" error from PIL. The `with` block
# makes sure the streamed connection is released.
with requests.get(url, stream=True, timeout=30) as response:
    response.raise_for_status()
    image = Image.open(response.raw)
    # Image.open is lazy; decode the pixels while the response is still open.
    image.load()

# Same device/dtype conversion the model expects (CUDA, float16).
inputs = image_processor(images=image, return_tensors="pt").to("cuda", torch.float16)

with torch.no_grad():
    outputs = model(**inputs)
|
|
|
def draw_bounding_boxes(image, results, model, threshold=0.3):
    """Draw labelled detection boxes onto ``image`` and return it.

    The image is modified in place; the returned object is ``image`` itself.

    Args:
        image: PIL image to annotate.
        results: Iterable of per-image detection dicts with ``"scores"``,
            ``"labels"`` and ``"boxes"`` tensor entries, as returned by
            ``image_processor.post_process_object_detection``.
        model: Model whose ``config.id2label`` maps label ids to class names.
        threshold: Only detections whose score is strictly greater than this
            value are drawn.

    Returns:
        The annotated ``image``.
    """
    draw = ImageDraw.Draw(image)

    for result in results:
        # Convert each tensor to Python values once per image instead of
        # pulling every individual score/label/box off the device (one sync
        # per element in the original loop).
        scores = result["scores"].tolist()
        label_ids = result["labels"].tolist()
        boxes = result["boxes"].tolist()

        for score, label_id, box in zip(scores, label_ids, boxes):
            if score > threshold:
                label = model.config.id2label[label_id]
                box = [round(coord) for coord in box]
                draw.rectangle(box, outline="red", width=3)
                # Caption anchored at the box's first corner.
                draw.text((box[0], box[1]), f"{label}: {score:.2f}", fill="red")

    return image
|
|
|
import time |
|
|
|
@spaces.GPU
def inference(image, conf_threshold):
    """Run RT-DETR on one webcam frame and return the frame with boxes drawn.

    Args:
        image: PIL image for the current frame (the UI uses
            ``gr.Image(type="pil")``). May be ``None`` if no frame is
            available, in which case ``None`` is returned.
        conf_threshold: Minimum detection score to keep (UI slider value).

    Returns:
        The annotated frame (the same PIL object, drawn on in place), or
        ``None`` when ``image`` is ``None``.
    """
    if image is None:
        return None

    # The model was loaded in float16 and moved to CUDA, so the inputs must be
    # converted to the same device and dtype (as in the warm-up call).
    # Otherwise the forward pass fails with a device/dtype mismatch.
    inputs = image_processor(images=image, return_tensors="pt").to("cuda", torch.float16)

    start = time.perf_counter()
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL's ``size`` is (width, height); reversing it yields (height, width)
    # for ``target_sizes``.
    results = image_processor.post_process_object_detection(
        outputs, target_sizes=torch.tensor([image.size[::-1]]), threshold=conf_threshold
    )
    end = time.perf_counter()
    print("time: ", end - start)

    bbs = draw_bounding_boxes(image, results, model, threshold=conf_threshold)
    print("bbs: ", time.perf_counter() - end)

    return bbs
|
|
|
|
|
# Custom styles for the `my-group` / `my-column` elem_classes: cap the webcam
# group at 600x600 px and centre the column's contents. Every length needs a
# unit (a bare "600" is invalid CSS and is dropped by the browser), and no stray
# ";" may follow a closing brace.
css = """.my-group {max-width: 600px !important; max-height: 600px !important;}
.my-column {display: flex !important; justify-content: center !important; align-items: center !important;}"""
|
|
|
# Header markup shown above the webcam widget, rendered in this order.
_TITLE_HTML = """
<h1 style='text-align: center'>
Near Real-Time Webcam Stream with RT-DETR
</h1>
"""

_LINKS_HTML = """
<h3 style='text-align: center'>
<a href='https://arxiv.org/abs/2304.08069' target='_blank'>arXiv</a> | <a href='https://github.com/lyuwenyu/RT-DETR' target='_blank'>github</a>
</h3>
"""

with gr.Blocks(css=css) as app:
    # Title, then the paper / code links.
    for header_html in (_TITLE_HTML, _LINKS_HTML):
        gr.HTML(header_html)

    with gr.Column(elem_classes=["my-column"]):
        with gr.Group(elem_classes=["my-group"]):
            # Webcam feed; the same component is used as the stream's input
            # and as its output, so annotated frames replace the live view.
            image = gr.Image(type="pil", label="Image", sources="webcam")
            conf_threshold = gr.Slider(
                label="Confidence Threshold",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=0.85,
            )

    # Send a frame to `inference` every 0.1 s, for at most 30 s per session.
    image.stream(
        fn=inference,
        inputs=[image, conf_threshold],
        outputs=[image],
        stream_every=0.1,
        time_limit=30,
    )

if __name__ == "__main__":
    app.launch()
|
|