quydoan committed on
Commit
20b303f
·
1 Parent(s): ce120d9

Add cat/dog detector app

Browse files
Files changed (2) hide show
  1. app.py +113 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import DetrImageProcessor, DetrForObjectDetection
4
+ from PIL import Image, ImageDraw, ImageFont
5
+ import requests # To handle image URLs if needed, but we focus on uploads
6
+
7
# Load the pretrained DETR checkpoint once, at import time, so every request
# served by the app reuses the same processor and model instances.
# NOTE: no ``revision`` argument is passed, so the checkpoint's default
# revision is used; timm is listed in requirements.txt for that reason
# (presumably the default backbone depends on it -- confirm before removing).
_CHECKPOINT = "facebook/detr-resnet-101-dc5"
processor = DetrImageProcessor.from_pretrained(_CHECKPOINT)
model = DetrForObjectDetection.from_pretrained(_CHECKPOINT)

# Resolve the class ids for the labels of interest from the model's own
# id -> label mapping instead of hard-coding dataset indices.
id2label = model.config.id2label
target_labels = ["cat", "dog"]
target_ids = [idx for idx, name in id2label.items() if name in target_labels]

# Outline colour used for each target label when drawing bounding boxes.
colors = dict(cat="red", dog="blue")
22
+
23
def _label_extent(font, text):
    """Return the (width, height) that *text* covers when drawn with *font*.

    ``ImageFont.getsize`` was deprecated in Pillow 9.2 and removed in 10.0,
    so ``getbbox`` is preferred. ``getsize`` is kept only as a fallback for
    older Pillow builds whose bitmap font does not provide ``getbbox``.
    """
    if hasattr(font, "getbbox"):
        # right/bottom are measured from the drawing origin, so they already
        # include any glyph offset and are what the background must cover.
        _, _, right, bottom = font.getbbox(text)
        return right, bottom
    return font.getsize(text)


def detect_objects(image_input, threshold=0.7):
    """
    Detect cats and dogs in the input image using DETR.

    Relies on the module-level ``processor``, ``model``, ``id2label``,
    ``target_ids`` and ``colors`` objects. Each matching detection is also
    reported on stdout.

    Args:
        image_input (PIL.Image.Image | numpy.ndarray | None): Input image.
            Arrays are converted with ``Image.fromarray``.
        threshold (float): Minimum confidence a detection needs to be kept.
            Lower values (e.g. 0.5) may find more objects.

    Returns:
        PIL.Image.Image | None: An RGB copy of the input with bounding boxes
        and labels drawn around detected cats/dogs, or ``None`` when no image
        was supplied.
    """
    if image_input is None:
        return None

    if not isinstance(image_input, Image.Image):
        image_input = Image.fromarray(image_input)
    # convert() always returns a new image, so the caller's object is never
    # drawn on, and forcing RGB keeps RGBA / greyscale / palette uploads from
    # reaching the processor with an unexpected channel count.
    image = image_input.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports (width, height); post-processing expects (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold
    )[0]

    draw = ImageDraw.Draw(image)
    try:
        font = ImageFont.load_default()
    except OSError:
        print("Default font not found. Using basic drawing without text.")
        font = None

    detections_found = False
    for score, label_id, box in zip(results["scores"], results["labels"], results["boxes"]):
        label_id = label_id.item()
        if label_id not in target_ids:
            continue

        detections_found = True
        confidence = score.item()
        box = [round(coord, 2) for coord in box.tolist()]
        label = id2label[label_id]
        box_color = colors.get(label, "green")  # Fallback for labels without a colour

        print(f"Detected {label} with confidence {round(confidence, 3)} at {box}")

        draw.rectangle(box, outline=box_color, width=3)

        if font:
            text = f"{label}: {confidence:.2f}"
            text_width, text_height = _label_extent(font, text)
            # Filled background (2px padding each side) keeps the white text
            # readable regardless of the underlying image content.
            text_bg_coords = [
                (box[0], box[1]),
                (box[0] + text_width + 4, box[1] + text_height + 4),
            ]
            draw.rectangle(text_bg_coords, fill=box_color)
            draw.text((box[0] + 2, box[1] + 2), text, fill="white", font=font)

    if not detections_found:
        print("No cats or dogs detected with the current threshold.")

    return image
91
+
92
# Text shown at the top of the Gradio page.
title = "Cat & Dog Detector (using DETR ResNet-101)"
description = (
    "Upload an image and the model will draw bounding boxes "
    "around detected cats and dogs. Uses the facebook/detr-resnet-101-dc5 model from Hugging Face."
)

# Example inputs offered in the UI; each inner list holds the value for the
# single image input. Paths to images uploaded to the Space would work too.
example_images = [
    ["http://images.cocodataset.org/val2017/000000039769.jpg"],  # Cats
    ["https://storage.googleapis.com/petbacker/images/blog/2017/dog-and-cat-cover.jpg"],  # Dog and cat
]

# Both components exchange PIL images with detect_objects.
upload_component = gr.Image(type="pil", label="Upload Image")
result_component = gr.Image(type="pil", label="Output Image with Detections")

# Wire the detector into a single-input, single-output interface.
iface = gr.Interface(
    fn=detect_objects,
    inputs=upload_component,
    outputs=result_component,
    title=title,
    description=description,
    examples=example_images,
    allow_flagging="never",  # Flagging is disabled for this demo
)

# Start the web app.
iface.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch>=1.9.0
2
+ transformers>=4.15.0
3
+ Pillow>=9.0.0
4
+ gradio>=3.0.0
5
+ requests>=2.25.0
6
+ timm>=0.5.4