mrdbourke committed on
Commit ddd1e4a · verified · 1 Parent(s): 6f0c93e

Create app.py

Files changed (1)
app.py +171 -0
app.py ADDED
@@ -0,0 +1,171 @@
+ import torch
+ from PIL import Image, ImageDraw, ImageFont
+ from transformers import GroundingDinoProcessor
+ from modeling_grounding_dino import GroundingDinoForObjectDetection
+
+ from itertools import cycle
+
+ import gradio as gr
+
+ import spaces
+
+ # Load model and processor
+ model_id = "fushh7/llmdet_swin_large_hf"
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+ print(f"[INFO] Using device: {DEVICE}")
+ print(f"[INFO] Loading model from {model_id}...")
+
+ processor = GroundingDinoProcessor.from_pretrained(model_id)
+ model = GroundingDinoForObjectDetection.from_pretrained(model_id).to(DEVICE)
+ model.eval()
+
+ print("[INFO] Model loaded successfully.")
+
+ # Pre-defined palette (extend or tweak as you like)
+ BOX_COLORS = [
+     "deepskyblue", "red", "lime", "dodgerblue",
+     "cyan", "magenta", "yellow",
+     "orange", "chartreuse"
+ ]
+
+ def draw_boxes(image, boxes, labels, scores, colors=BOX_COLORS, font_path="arial.ttf", font_size=16):
+     """
+     Draw bounding boxes and labels on a PIL Image.
+
+     :param image: PIL Image object
+     :param boxes: Iterable of [x_min, y_min, x_max, y_max]
+     :param labels: Iterable of label strings
+     :param scores: Iterable of scalar confidences (0-1)
+     :param colors: List/tuple of colour names or RGB tuples
+     :param font_path: Path to a TTF font for labels
+     :param font_size: Int font size to use (default 16)
+     :return: PIL Image with drawn boxes
+     """
+     # Ensure we can iterate colours indefinitely
+     colour_cycle = cycle(colors)
+     draw = ImageDraw.Draw(image)
+
+     # Pick a font (fall back to the default if missing)
+     try:
+         font = ImageFont.truetype(font_path, size=font_size)
+     except IOError:
+         font = ImageFont.load_default(size=font_size)
+
+     # Assign a consistent colour per label
+     label_to_colour = {}
+
+     for box, label, score in zip(boxes, labels, scores):
+         # Reuse colour if label seen before, else take next from cycle
+         colour = label_to_colour.setdefault(label, next(colour_cycle))
+
+         x_min, y_min, x_max, y_max = map(int, box)
+
+         # Draw rectangle
+         draw.rectangle([x_min, y_min, x_max, y_max], outline=colour, width=2)
+
+         # Compose text
+         text = f"{label} ({score:.3f})"
+         text_size = draw.textbbox((0, 0), text, font=font)[2:]
+
+         # Draw text background for legibility
+         bg_coords = [x_min, y_min - text_size[1] - 4,
+                      x_min + text_size[0] + 4, y_min]
+         draw.rectangle(bg_coords, fill=colour)
+
+         # Draw text
+         draw.text((x_min + 2, y_min - text_size[1] - 2),
+                   text, fill="black", font=font)
+
+     return image
+
+ def resize_image_max_dimension(image, max_size=1024):
+     """
+     Resize an image so that the longest side is at most max_size pixels,
+     while maintaining the aspect ratio.
+
+     :param image: PIL Image object
+     :param max_size: Maximum dimension in pixels (default: 1024)
+     :return: PIL Image object (resized)
+     """
+     width, height = image.size
+
+     # Check if resizing is needed
+     if max(width, height) <= max_size:
+         return image
+
+     # Calculate new dimensions maintaining aspect ratio
+     ratio = max_size / max(width, height)
+     new_width = int(width * ratio)
+     new_height = int(height * ratio)
+
+     # Resize the image using high-quality resampling
+     return image.resize((new_width, new_height), Image.Resampling.LANCZOS)
+
+ @spaces.GPU(duration=120)
+ def detect_and_draw(
+     img: Image.Image,
+     text_query: str,
+     box_threshold: float = 0.4,
+     text_threshold: float = 0.3
+ ) -> Image.Image:
+     """
+     Detect objects described in `text_query`, draw boxes, return the image.
+     Note: `text_query` must be lowercase and each concept must end with a dot
+     (e.g. 'a cat. a remote control.').
+     """
+
+     # Make sure the text is lowercase
+     text_query = text_query.lower()
+
+     # If the image is too large, shrink it
+     img = resize_image_max_dimension(img, max_size=1024)
+
+     # Preprocess the image and text
+     inputs = processor(images=img, text=text_query, return_tensors="pt").to(DEVICE)
+
+     with torch.no_grad():
+         outputs = model(**inputs)
+
+     results = processor.post_process_grounded_object_detection(
+         outputs,
+         inputs.input_ids,
+         box_threshold=box_threshold,
+         text_threshold=text_threshold,
+         target_sizes=[img.size[::-1]]
+     )[0]
+
+     img_out = img.copy()
+     img_out = draw_boxes(
+         img_out,
+         boxes=results["boxes"].cpu().numpy(),
+         labels=results.get("text_labels", results.get("labels", [])),
+         scores=results["scores"]
+     )
+     return img_out
+
+ # Create Gradio demo
+ demo = gr.Interface(
+     fn=detect_and_draw,
+     inputs=[
+         gr.Image(type="pil", label="Image"),
+         gr.Textbox(value="",
+                    label="Text Query (lowercase, end each with '.', for example 'a bird. a tree.')"),
+         gr.Slider(0.0, 1.0, value=0.4, step=0.05, label="Box Threshold"),
+         gr.Slider(0.0, 1.0, value=0.3, step=0.05, label="Text Threshold")
+     ],
+     outputs=gr.Image(type="pil", label="Detections"),
+     title="LLMDet Demo: Open-Vocabulary Grounded Object Detection",
+     description="""Upload an image, enter a text query, and adjust the thresholds to see detections.
+
+ Adapted from the [Hugging Face demo](https://github.com/iSEE-Laboratory/LLMDet/tree/main/hf_model) in the LLMDet GitHub repo.
+
+ See the original:
+ * [LLMDet GitHub](https://github.com/iSEE-Laboratory/LLMDet/tree/main?tab=readme-ov-file)
+ * [LLMDet paper](https://arxiv.org/abs/2501.18954) - LLMDet: Learning Strong Open-Vocabulary Object Detectors under the Supervision of Large Language Models
+ * [LLMDet model checkpoint](https://huggingface.co/fushh7/llmdet_swin_large_hf)
+ """
+ )
+
+ demo.launch()
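
For reference, a minimal sketch of exercising the prediction function outside the Gradio UI. It assumes the script above is importable (so `model`, `processor`, and `detect_and_draw` are already loaded) and that a local image file exists; the filename `example.jpg` and the query text are illustrative only, not part of the commit.

from PIL import Image

# Load a test image (hypothetical filename, just for illustration)
image = Image.open("example.jpg").convert("RGB")

# Queries must be lowercase, with each concept ending in a dot,
# matching the format described in detect_and_draw's docstring.
annotated = detect_and_draw(
    img=image,
    text_query="a cat. a remote control.",
    box_threshold=0.4,
    text_threshold=0.3,
)

annotated.save("detections.jpg")

This is the same call the Gradio Interface wires to its inputs, so the threshold arguments behave like the sliders in the demo.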