Spaces: Running on Zero (Hugging Face ZeroGPU)
""" | |
Demonstrates integrating Rerun visualization with Gradio and HF ZeroGPU. | |
""" | |
import uuid | |
import gradio as gr | |
import rerun as rr | |
import rerun.blueprint as rrb | |
from gradio_rerun import Rerun | |
import spaces | |
from transformers import DetrImageProcessor, DetrForObjectDetection | |
import torch | |
# Load the DETR image preprocessor and detection model once at import time,
# so every request served by this process reuses the same weights.
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
# Whenever we need a recording, we construct a new recording stream. | |
# As long as the app and recording IDs remain the same, the data | |
# will be merged by the Viewer. | |
def get_recording(recording_id: str) -> rr.RecordingStream:
    """Build a recording stream bound to *recording_id*.

    Streams that share the same application and recording IDs are merged
    by the Rerun Viewer, so repeated calls with one session's id all
    contribute to the same recording.
    """
    stream = rr.RecordingStream(
        application_id="rerun_example_gradio",
        recording_id=recording_id,
    )
    return stream
# A task can directly log to a binary stream, which is routed to the embedded viewer. | |
# Incremental chunks are yielded to the viewer using `yield stream.read()`. | |
# | |
# This is the preferred way to work with Rerun in Gradio since your data can be immediately and | |
# incrementally seen by the viewer. Also, there are no ephemeral RRDs to cleanup or manage. | |
# A task can directly log to a binary stream, which is routed to the embedded
# viewer. Incremental chunks are yielded to the viewer using `yield stream.read()`.
#
# The `spaces.GPU` decorator is required on ZeroGPU Spaces: it requests a GPU
# for the duration of the call (the module-level `import spaces` exists solely
# for this).
@spaces.GPU
def streaming_object_detection(recording_id: str, img):
    """Run DETR on *img* and stream detections to the embedded Rerun viewer.

    Args:
        recording_id: Session-scoped id; all chunks logged under the same id
            are merged into one recording by the viewer.
        img: Image from the ``gr.Image`` component — assumed to be an
            array-like with shape ``(height, width, channels)`` (``.shape``
            is read below).

    Yields:
        Binary RRD chunks consumed by the ``gradio_rerun.Rerun`` component.

    Raises:
        gr.Error: If no image was provided.
    """
    # Get a recording stream for the session's recording id.
    rec = get_recording(recording_id)
    stream = rec.binary_stream()

    if img is None:
        raise gr.Error("Must provide an image to detect objects in.")

    blueprint = rrb.Blueprint(
        rrb.Horizontal(
            rrb.Spatial2DView(origin="image"),
        ),
        collapse_panels=True,
    )
    rec.send_blueprint(blueprint)

    rec.set_time("iteration", sequence=0)
    rec.log("image", rr.Image(img))
    # Push the raw image to the viewer immediately, before slow inference.
    yield stream.read()

    with torch.inference_mode():
        inputs = processor(images=img, return_tensors="pt")
        outputs = model(**inputs)

    # Convert outputs (bounding boxes and class logits) to the COCO API
    # format, keeping only detections with score > 0.85.
    height, width = img.shape[:2]
    target_sizes = torch.tensor([[height, width]])  # [height, width] order
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.85
    )[0]

    rec.log(
        "image/objects",
        rr.Boxes2D(
            array=results["boxes"],
            array_format=rr.Box2DFormat.XYXY,
            labels=[model.config.id2label[label.item()] for label in results["labels"]],
            # Deterministic pseudo-color derived from the class id, so the
            # same class is always drawn in the same color.
            colors=[
                (
                    label.item() * 50 % 255,
                    (label.item() * 80 + 40) % 255,
                    (label.item() * 120 + 100) % 255,
                )
                for label in results["labels"]
            ],
        ),
    )

    # Ensure we consume everything from the recording.
    stream.flush()
    yield stream.read()
# UI: image input + detect button on the left, embedded Rerun viewer on the
# right; the click handler streams RRD chunks into the viewer.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Accordion("Your image", open=True):
                img = gr.Image(interactive=True, label="Image")
            detect_objects = gr.Button("Detect objects")
        with gr.Column(scale=4):
            viewer = Rerun(
                streaming=True,
                panel_states={
                    "time": "collapsed",
                    "blueprint": "hidden",
                    "selection": "hidden",
                },
                height=700,
            )

    # Create a fresh recording id per session. A callable default is
    # re-evaluated by Gradio for each new session; passing `uuid.uuid4()`
    # directly would evaluate it once at import time, making every session
    # share (a deepcopy of) the same id — defeating per-session recordings.
    recording_id = gr.State(lambda: uuid.uuid4())

    # When registering the event listeners, we pass the `recording_id` in as
    # input in order to create a recording stream using that id.
    detect_objects.click(
        # Using the `viewer` as an output allows us to stream data to it by
        # yielding bytes from the callback.
        streaming_object_detection,
        inputs=[recording_id, img],
        outputs=[viewer],
    )

if __name__ == "__main__":
    # ssr_mode=False: server-side rendering is disabled for this Space.
    demo.launch(ssr_mode=False)