Spaces:
Sleeping
Sleeping
Add cat/dog detector app
Browse files- app.py +113 -0
- requirements.txt +6 -0
app.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from transformers import DetrImageProcessor, DetrForObjectDetection
|
4 |
+
from PIL import Image, ImageDraw, ImageFont
|
5 |
+
import requests # To handle image URLs if needed, but we focus on uploads
|
6 |
+
|
7 |
+
# Load the DETR image processor and detection model once, at import time, so
# every request served by the app reuses the same objects.
# No `revision` argument is passed, so the checkpoint's default revision is
# loaded; keep `timm` in requirements.txt alongside it.
_CHECKPOINT = "facebook/detr-resnet-101-dc5"
processor = DetrImageProcessor.from_pretrained(_CHECKPOINT)
model = DetrForObjectDetection.from_pretrained(_CHECKPOINT)

# Resolve the class ids of interest by *name* from the model's own label map
# instead of hard-coding numeric ids, which depend on the checkpoint.
id2label = model.config.id2label
target_labels = ["cat", "dog"]
target_ids = [
    class_id
    for class_id, class_name in id2label.items()
    if class_name in target_labels
]

# Outline colour used for each target label when drawing bounding boxes.
colors = {"cat": "red", "dog": "blue"}
|
22 |
+
|
23 |
+
def _measure_text(font, text):
    """Return the ``(width, height)`` in pixels needed to draw *text* with *font*.

    ``ImageFont.getsize`` was removed in Pillow 10, so ``getbbox`` is preferred
    when the font object provides it. The legacy ``getsize`` is kept as a
    fallback for older Pillow releases, and a fixed estimate is used only when
    neither API is available.
    """
    if hasattr(font, "getbbox"):
        # Use the right/bottom edges measured from the text origin so the
        # background rectangle fully covers the rendered glyphs.
        _, _, right, bottom = font.getbbox(text)
        return right, bottom
    if hasattr(font, "getsize"):
        return font.getsize(text)
    return 50, 10


def detect_objects(image_input):
    """
    Detect cats and dogs in the input image using DETR.

    Args:
        image_input (PIL.Image.Image | array-like | None): Input image. Anything
            that is not already a PIL image is passed through
            ``Image.fromarray``.

    Returns:
        PIL.Image.Image | None: An RGB copy of the input with a coloured box and
        a "label: score" caption drawn on every detected cat/dog, or ``None``
        when no image was supplied. The caller's image is never modified.
    """
    if image_input is None:
        return None

    if not isinstance(image_input, Image.Image):
        image_input = Image.fromarray(image_input)

    # convert() returns a new image even when the mode is already RGB, so this
    # both gives us a private copy to draw on and guarantees three channels
    # for RGBA / greyscale / palette inputs.
    image = image_input.convert("RGB")

    # Preprocess and run the detector. Gradients are never needed here, so
    # disable autograd to avoid building a graph on every request.
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Rescale the predicted boxes to the image's pixel size. PIL reports
    # (width, height); the post-processor is given (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.7
    )[0]  # Lower threshold (e.g., 0.5) might find more objects

    draw = ImageDraw.Draw(image)
    try:
        # Use a default font or specify a path to a .ttf file if available in the Space
        font = ImageFont.load_default()
    except IOError:
        print("Default font not found. Using basic drawing without text.")
        font = None

    detections_found = False
    for score, label_id, box in zip(results["scores"], results["labels"], results["boxes"]):
        label_id = label_id.item()
        if label_id not in target_ids:
            continue  # Only cats and dogs are drawn; other classes are ignored.

        detections_found = True
        confidence = score.item()
        box = [round(coord, 2) for coord in box.tolist()]
        label = id2label[label_id]
        box_color = colors.get(label, "green")  # Default to green if label not in colors dict

        print(f"Detected {label} with confidence {round(confidence, 3)} at {box}")

        draw.rectangle(box, outline=box_color, width=3)

        if font:
            # Caption on a filled background anchored at the box's top-left corner.
            text = f"{label}: {confidence:.2f}"
            text_width, text_height = _measure_text(font, text)
            text_bg_coords = [
                (box[0], box[1]),
                (box[0] + text_width + 4, box[1] + text_height + 4),
            ]
            draw.rectangle(text_bg_coords, fill=box_color)
            draw.text((box[0] + 2, box[1] + 2), text, fill="white", font=font)

    if not detections_found:
        print("No cats or dogs detected with the current threshold.")

    return image
|
91 |
+
|
92 |
+
# Assemble the Gradio UI: a single uploaded image goes in, the annotated
# image produced by detect_objects comes out.
title = "Cat & Dog Detector (using DETR ResNet-101)"
description = (
    "Upload an image and the model will draw bounding boxes "
    "around detected cats and dogs. Uses the facebook/detr-resnet-101-dc5 model from Hugging Face."
)

_input_image = gr.Image(type="pil", label="Upload Image")
_output_image = gr.Image(type="pil", label="Output Image with Detections")

# Example inputs offered in the UI. These are remote URLs; paths to images
# uploaded to the Space would work here as well.
_example_images = [
    ["http://images.cocodataset.org/val2017/000000039769.jpg"],  # Example image URL with cats
    ["https://storage.googleapis.com/petbacker/images/blog/2017/dog-and-cat-cover.jpg"],  # Example image with dog and cat
]

iface = gr.Interface(
    fn=detect_objects,
    inputs=_input_image,
    outputs=_output_image,
    title=title,
    description=description,
    examples=_example_images,
    allow_flagging="never",  # You can change flagging options if needed
)

# Start serving the app.
iface.launch()
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch>=1.9.0
|
2 |
+
transformers>=4.15.0
|
3 |
+
Pillow>=9.0.0
|
4 |
+
gradio>=3.0.0
|
5 |
+
requests>=2.25.0
|
6 |
+
timm>=0.5.4
|