SkalskiP committed on
Commit
d88b494
·
verified ·
1 Parent(s): 5b55299

Update app.py to allow new nano, small, and medium checkpoints

Browse files
Files changed (1) hide show
  1. app.py +42 -11
app.py CHANGED
@@ -6,7 +6,7 @@ import gradio as gr
6
  import numpy as np
7
  import supervision as sv
8
  from PIL import Image
9
- from rfdetr import RFDETRBase, RFDETRLarge
10
  from rfdetr.detr import RFDETR
11
  from rfdetr.util.coco_classes import COCO_CLASSES
12
 
@@ -25,13 +25,16 @@ by [Roboflow](https://roboflow.com/) and released under the Apache 2.0 license.
25
  """
26
 
27
  IMAGE_PROCESSING_EXAMPLES = [
28
- ['https://media.roboflow.com/supervision/image-examples/people-walking.png', 0.3, 728, "large"],
29
- ['https://media.roboflow.com/supervision/image-examples/vehicles.png', 0.3, 728, "large"],
30
- ['https://media.roboflow.com/notebooks/examples/dog-2.jpeg', 0.5, 560, "base"],
 
 
 
31
  ]
32
  VIDEO_PROCESSING_EXAMPLES = [
33
- ["videos/people-walking.mp4", 0.3, 728, "large"],
34
- ["videos/vehicles.mp4", 0.3, 728, "large"],
35
  ]
36
 
37
  COLOR = sv.ColorPalette.from_hex([
@@ -77,6 +80,12 @@ def detect_and_annotate(
77
 
78
 
79
  def load_model(resolution: int, checkpoint: str) -> RFDETR:
 
 
 
 
 
 
80
  if checkpoint == "base":
81
  return RFDETRBase(resolution=resolution)
82
  elif checkpoint == "large":
@@ -84,12 +93,33 @@ def load_model(resolution: int, checkpoint: str) -> RFDETR:
84
  raise TypeError("Checkpoint must be a base or large.")
85
 
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  def image_processing_inference(
88
  input_image: Image.Image,
89
  confidence: float,
90
  resolution: int,
91
  checkpoint: str
92
  ):
 
93
  model = load_model(resolution=resolution, checkpoint=checkpoint)
94
  return detect_and_annotate(model=model, image=input_image, confidence=confidence)
95
 
@@ -100,6 +130,7 @@ def video_processing_inference(
100
  resolution: int,
101
  checkpoint: str,
102
  ):
 
103
  model = load_model(resolution=resolution, checkpoint=checkpoint)
104
 
105
  name = generate_unique_name()
@@ -151,14 +182,14 @@ with gr.Blocks() as demo:
151
  )
152
  image_processing_resolution_slider = gr.Slider(
153
  label="Inference resolution",
154
- minimum=560,
155
- maximum=1120,
156
- step=56,
157
- value=728,
158
  )
159
  image_processing_checkpoint_dropdown = gr.Dropdown(
160
  label="Checkpoint",
161
- choices=["base", "large"],
162
  value="base"
163
  )
164
  with gr.Column():
 
6
  import numpy as np
7
  import supervision as sv
8
  from PIL import Image
9
+ from rfdetr import RFDETRNano, RFDETRSmall, RFDETRMedium, RFDETRBase, RFDETRLarge
10
  from rfdetr.detr import RFDETR
11
  from rfdetr.util.coco_classes import COCO_CLASSES
12
 
 
25
  """
26
 
27
  IMAGE_PROCESSING_EXAMPLES = [
28
+ ['https://media.roboflow.com/supervision/image-examples/people-walking.png', 0.3, 1024, "medium"],
29
+ ['https://media.roboflow.com/supervision/image-examples/vehicles.png', 0.3, 1024, "medium"],
30
+ ['https://media.roboflow.com/supervision/image-examples/motorbike.png', 0.3, 1024, "medium"],
31
+ ['https://media.roboflow.com/notebooks/examples/dog-2.jpeg', 0.5, 512, "nano"],
32
+ ['https://media.roboflow.com/notebooks/examples/dog-3.jpeg', 0.5, 512, "nano"],
33
+ ['https://media.roboflow.com/supervision/image-examples/basketball-1.png', 0.5, 512, "nano"],
34
  ]
35
  VIDEO_PROCESSING_EXAMPLES = [
36
+ ["videos/people-walking.mp4", 0.3, 1024, "medium"],
37
+ ["videos/vehicles.mp4", 0.3, 1024, "medium"],
38
  ]
39
 
40
  COLOR = sv.ColorPalette.from_hex([
 
80
 
81
 
82
  def load_model(resolution: int, checkpoint: str) -> RFDETR:
83
+ if checkpoint == "nano":
84
+ return RFDETRNano(resolution=resolution)
85
+ if checkpoint == "small":
86
+ return RFDETRSmall(resolution=resolution)
87
+ if checkpoint == "medium":
88
+ return RFDETRMedium(resolution=resolution)
89
  if checkpoint == "base":
90
  return RFDETRBase(resolution=resolution)
91
  elif checkpoint == "large":
 
93
  raise TypeError("Checkpoint must be a base or large.")
94
 
95
 
def adjust_resolution(checkpoint: str, resolution: int) -> int:
    """Snap ``resolution`` to the nearest multiple accepted by ``checkpoint``.

    The "nano", "small" and "medium" checkpoints use multiples of 32;
    "base" and "large" use multiples of 56. A request exactly half-way
    between two multiples is rounded up.

    Args:
        checkpoint: Checkpoint name; one of "nano", "small", "medium",
            "base" or "large".
        resolution: Requested inference resolution.

    Returns:
        The multiple of the checkpoint's divisor closest to ``resolution``,
        never smaller than one divisor.

    Raises:
        ValueError: If ``checkpoint`` is not one of the known names.
    """
    if checkpoint in {"nano", "small", "medium"}:
        divisor = 32
    elif checkpoint in {"base", "large"}:
        divisor = 56
    else:
        raise ValueError(f"Unknown checkpoint: {checkpoint}")

    # Adding half a divisor before floor-dividing rounds to the nearest
    # multiple with ties going up. int() keeps the result an int even if the
    # caller hands in a float.
    multiples = int((resolution + divisor // 2) // divisor)

    # Clamp to one divisor: a tiny (or zero) request would otherwise snap to
    # 0, which is not a usable resolution.
    return max(1, multiples) * divisor
114
+
115
+
def image_processing_inference(
    input_image: Image.Image,
    confidence: float,
    resolution: int,
    checkpoint: str
):
    """Run detection on one image and return the annotated result.

    The requested resolution is first passed through ``adjust_resolution``
    for the chosen checkpoint, the model is then loaded at that resolution,
    and the image is handed to ``detect_and_annotate``.

    Args:
        input_image: Image to run detection on.
        confidence: Confidence value forwarded to ``detect_and_annotate``.
        resolution: Requested inference resolution (adjusted before use).
        checkpoint: Checkpoint name forwarded to ``load_model``.

    Returns:
        Whatever ``detect_and_annotate`` returns for the image.
    """
    snapped_resolution = adjust_resolution(
        checkpoint=checkpoint,
        resolution=resolution,
    )
    detector = load_model(resolution=snapped_resolution, checkpoint=checkpoint)
    annotated = detect_and_annotate(
        model=detector,
        image=input_image,
        confidence=confidence,
    )
    return annotated
125
 
 
130
  resolution: int,
131
  checkpoint: str,
132
  ):
133
+ resolution = adjust_resolution(checkpoint=checkpoint, resolution=resolution)
134
  model = load_model(resolution=resolution, checkpoint=checkpoint)
135
 
136
  name = generate_unique_name()
 
182
  )
183
  image_processing_resolution_slider = gr.Slider(
184
  label="Inference resolution",
185
+ minimum=224,
186
+ maximum=2240,
187
+ step=1,
188
+ value=896,
189
  )
190
  image_processing_checkpoint_dropdown = gr.Dropdown(
191
  label="Checkpoint",
192
+ choices=["nano", "small", "medium", "base", "large"],
193
  value="base"
194
  )
195
  with gr.Column():