venkyvicky's picture
Upload app.py
9483059 verified
raw
history blame
1.84 kB
import torch
from PIL import Image
from transformers import DetrImageProcessor, DetrForObjectDetection
import gradio as gr
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import io
# DETR checkpoint shared by the processor and the model. Both are loaded a
# single time, at import, so every call to detect_objects reuses them.
_CHECKPOINT = "facebook/detr-resnet-50"
processor = DetrImageProcessor.from_pretrained(_CHECKPOINT)
model = DetrForObjectDetection.from_pretrained(_CHECKPOINT)
def _run_detection(image, threshold):
    """Run DETR on *image* and return the detections scoring above *threshold*.

    Returns the dict that ``processor.post_process_object_detection`` yields
    for the single input image. It has ``"scores"``, ``"labels"`` and
    ``"boxes"`` entries, with each box as (xmin, ymin, xmax, ymax) rescaled
    to the size of *image*.
    """
    inputs = processor(images=image, return_tensors="pt")
    # Inference only: disable autograd so no computation graph is built/kept.
    with torch.no_grad():
        outputs = model(**inputs)
    # PIL reports (width, height); the post-processor wants (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    return processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold
    )[0]


def _draw_detections(image, results):
    """Render *image* with a labelled red box for every detection in *results*.

    Returns the rendering as a new, fully decoded PIL image (via an in-memory PNG).
    """
    fig, ax = plt.subplots(1)
    try:
        ax.imshow(image)
        for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
            xmin, ymin, xmax, ymax = box.tolist()
            ax.add_patch(patches.Rectangle(
                (xmin, ymin), xmax - xmin, ymax - ymin,
                linewidth=2, edgecolor='red', facecolor='none'))
            ax.text(xmin, ymin, f"{model.config.id2label[label.item()]}: {round(score.item(), 2)}",
                    bbox=dict(facecolor='yellow', alpha=0.5), fontsize=8)
        # Use this figure/axes explicitly rather than pyplot's implicit
        # "current figure", which is process-global state.
        ax.axis("off")
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure, even if drawing fails, so figures do not
        # pile up in pyplot's registry across requests.
        plt.close(fig)
    buf.seek(0)
    rendered = Image.open(buf)
    # Image.open is lazy; decode now so the result no longer depends on buf.
    rendered.load()
    return rendered


def detect_objects(image, threshold=0.9):
    """Detect objects in a PIL image with DETR and return an annotated copy.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the form is submitted without an image.
        threshold: minimum confidence score for a detection to be drawn.
            Defaults to 0.9, the value this app originally hard-coded.

    Returns:
        A PIL image of *image* with labelled bounding boxes, or ``None`` when
        no image was provided.
    """
    if image is None:
        return None
    # DETR expects 3-channel RGB input. Convert palette/greyscale/RGBA uploads
    # so they are not rejected or mis-normalised during preprocessing.
    if image.mode != "RGB":
        image = image.convert("RGB")
    results = _run_detection(image, threshold)
    return _draw_detections(image, results)
# Gradio UI: one PIL image in, the annotated rendering from detect_objects out.
_TITLE = "DETR Object Detection"
_DESCRIPTION = "Upload an image to detect objects using Facebook's DETR model."

interface = gr.Interface(
    fn=detect_objects,
    inputs=gr.Image(type="pil"),
    outputs="image",
    title=_TITLE,
    description=_DESCRIPTION,
)

if __name__ == "__main__":
    # Start the server when run as a script. Note: share=True also asks
    # Gradio for a temporary public share link, so the app becomes reachable
    # beyond this machine. Drop it for a purely local launch.
    interface.launch(share=True)