oxkitsune committed
Commit 79c4fdc
Parent(s): 453d9af

Actually run inference on the image

Files changed (1):
  1. app.py +26 -101
app.py CHANGED
@@ -12,6 +12,7 @@ import time
 import uuid
 
 import subprocess
+
 subprocess.run(
     "pip install gradio_rerun-0.23.0a2.tar.gz",
     shell=True,
@@ -40,11 +41,14 @@ image = Image.open(requests.get(url, stream=True).raw)
 processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
 model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
 
+
 # Whenever we need a recording, we construct a new recording stream.
 # As long as the app and recording IDs remain the same, the data
 # will be merged by the Viewer.
 def get_recording(recording_id: str) -> rr.RecordingStream:
-    return rr.RecordingStream(application_id="rerun_example_gradio", recording_id=recording_id)
+    return rr.RecordingStream(
+        application_id="rerun_example_gradio", recording_id=recording_id
+    )
 
 
 # A task can directly log to a binary stream, which is routed to the embedded viewer.
@@ -53,7 +57,7 @@ def get_recording(recording_id: str) -> rr.RecordingStream:
 # This is the preferred way to work with Rerun in Gradio since your data can be immediately and
 # incrementally seen by the viewer. Also, there are no ephemeral RRDs to cleanup or manage.
 @spaces.GPU
-def streaming_repeated_blur(recording_id: str, img):
+def streaming_object_detection(recording_id: str, img):
     # Here we get a recording using the provided recording id.
     rec = get_recording(recording_id)
     stream = rec.binary_stream()
@@ -73,108 +77,38 @@ def streaming_repeated_blur(recording_id: str, img):
     rec.log("image", rr.Image(img))
     yield stream.read()
 
-
-    inputs = processor(images=image, return_tensors="pt")
-    outputs = model(**inputs)
+    with torch.inference_mode():
+        inputs = processor(images=img, return_tensors="pt")
+        outputs = model(**inputs)
 
     # convert outputs (bounding boxes and class logits) to COCO API
     # let's only keep detections with score > 0.9
-    target_sizes = torch.tensor([image.size[::-1]])
-    results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
-
-    print(results)
-    rec.log("image/objects", rr.Boxes2D(sizes=results["boxes"], labels=[model.config.id2label[label.item()] for label in results["labels"]]))
-
-    # Ensure we consume everything from the recording.
-    stream.flush()
-    yield stream.read()
-
-
-# In this example the user is able to add keypoints to an image visualized in Rerun.
-# These keypoints are stored in the global state, we use the session id to keep track of which keypoints belong
-# to a specific session (https://www.gradio.app/guides/state-in-blocks).
-#
-# The current session can be obtained by adding a parameter of type `gradio.Request` to your event listener functions.
-Keypoint = tuple[float, float]
-keypoints_per_session_per_sequence_index: dict[str, dict[int, list[Keypoint]]] = {}
-
-
-def get_keypoints_for_user_at_sequence_index(request: gr.Request, sequence: int) -> list[Keypoint]:
-    per_sequence = keypoints_per_session_per_sequence_index[request.session_hash]
-    if sequence not in per_sequence:
-        per_sequence[sequence] = []
-
-    return per_sequence[sequence]
-
-
-def initialize_instance(request: gr.Request) -> None:
-    keypoints_per_session_per_sequence_index[request.session_hash] = {}
 
+    target_sizes = torch.tensor([(img.height, img.width)])
+    results = processor.post_process_object_detection(
+        outputs, target_sizes=target_sizes, threshold=0.9
+    )[0]
 
-def cleanup_instance(request: gr.Request) -> None:
-    if request.session_hash in keypoints_per_session_per_sequence_index:
-        del keypoints_per_session_per_sequence_index[request.session_hash]
-
-
-# In this function, the `request` and `evt` parameters will be automatically injected by Gradio when this
-# event listener is fired.
-#
-# `SelectionChange` is a subclass of `EventData`: https://www.gradio.app/docs/gradio/eventdata
-# `gr.Request`: https://www.gradio.app/main/docs/gradio/request
-def register_keypoint(
-    active_recording_id: str,
-    current_timeline: str,
-    current_time: float,
-    request: gr.Request,
-    evt: SelectionChange,
-):
-    if active_recording_id == "":
-        return
-
-    if current_timeline != "iteration":
-        return
-
-    # We can only log a keypoint if the user selected only a single item.
-    if len(evt.items) != 1:
-        return
-    item = evt.items[0]
-
-    # If the selected item isn't an entity, or we don't have its position, then bail out.
-    if item.kind != "entity" or item.position is None:
-        return
-
-    # Now we can produce a valid keypoint.
-    rec = get_recording(active_recording_id)
-    stream = rec.binary_stream()
-
-    # We round `current_time` toward 0, because that gives us the sequence index
-    # that the user is currently looking at, due to the Viewer's latest-at semantics.
-    index = math.floor(current_time)
-
-    # We keep track of the keypoints per sequence index for each user manually.
-    keypoints = get_keypoints_for_user_at_sequence_index(request, index)
-    keypoints.append(item.position[0:2])
-
-    rec.set_time("iteration", sequence=index)
-    rec.log(f"{item.entity_path}/keypoint", rr.Points2D(keypoints, radii=2))
+    print(results)
+    rec.log(
+        "image/objects",
+        rr.Boxes2D(
+            array=results["boxes"],
+            array_format=rr.Box2DFormat.XYXY,
+            labels=[model.config.id2label[label.item()] for label in results["labels"]],
+        ),
+    )
 
     # Ensure we consume everything from the recording.
     stream.flush()
     yield stream.read()
 
 
-def track_current_time(evt: TimeUpdate):
-    return evt.time
-
-
-def track_current_timeline_and_time(evt: TimelineChange):
-    return evt.timeline, evt.time
-
 with gr.Blocks() as demo:
     with gr.Row():
         img = gr.Image(interactive=True, label="Image")
     with gr.Column():
-        stream_blur = gr.Button("Stream Repeated Blur")
+        detect_objects = gr.Button("Detect objects")
 
     with gr.Row():
         viewer = Rerun(
@@ -195,20 +129,11 @@ with gr.Blocks() as demo:
 
     # When registering the event listeners, we pass the `recording_id` in as input in order to create
     # a recording stream using that id.
-    stream_blur.click(
+    detect_objects.click(
         # Using the `viewer` as an output allows us to stream data to it by yielding bytes from the callback.
-        streaming_repeated_blur,
+        streaming_object_detection,
         inputs=[recording_id, img],
         outputs=[viewer],
     )
-    viewer.selection_change(
-        register_keypoint,
-        inputs=[recording_id, current_timeline, current_time],
-        outputs=[viewer],
-    )
-    viewer.time_update(track_current_time, outputs=[current_time])
-    viewer.timeline_change(track_current_timeline_and_time, outputs=[current_timeline, current_time])
-
-
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch(ssr_mode=False)
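
For reference, the detection path that the new streaming_object_detection callback runs is the standard DETR recipe from transformers, and it can be exercised without Gradio or Rerun. A minimal standalone sketch, assuming the usual COCO sample image from the DETR docs; the URL and the print loop are illustrative, not part of this commit:

import requests
import torch
from PIL import Image
from transformers import DetrForObjectDetection, DetrImageProcessor

processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # illustrative sample
img = Image.open(requests.get(url, stream=True).raw)

with torch.inference_mode():
    inputs = processor(images=img, return_tensors="pt")
    outputs = model(**inputs)

# post_process_object_detection expects one (height, width) pair per image in the batch.
target_sizes = torch.tensor([(img.height, img.width)])
results = processor.post_process_object_detection(
    outputs, target_sizes=target_sizes, threshold=0.9
)[0]

# Boxes come back as (xmin, ymin, xmax, ymax) pixel coordinates, which is why
# the app logs them with array_format=rr.Box2DFormat.XYXY.
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
    print(f"{model.config.id2label[label.item()]}: {score:.2f} at {box.tolist()}")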
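The streaming side is independent of the model: the callback writes into a RecordingStream, and each `yield stream.read()` forwards whatever bytes have been encoded so far to the embedded viewer, so results appear incrementally. A minimal sketch of that generator shape, reusing only calls that already appear in the diff (rr.RecordingStream, binary_stream, set_time, log); the entity path and the loop are placeholders:

import rerun as rr

def stream_to_viewer(recording_id: str):
    # As long as the app and recording IDs stay the same, the Viewer
    # merges everything into a single recording.
    rec = rr.RecordingStream(
        application_id="rerun_example_gradio", recording_id=recording_id
    )
    stream = rec.binary_stream()

    for i in range(3):
        rec.set_time("iteration", sequence=i)
        rec.log("points", rr.Points2D([(i, i)], radii=2))
        # Ship the bytes produced so far; the viewer picks them up incrementally.
        yield stream.read()

    # Ensure we consume everything from the recording.
    stream.flush()
    yield stream.read()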