""" Demonstrates integrating Rerun visualization with Gradio and HF ZeroGPU. """ import uuid import gradio as gr import rerun as rr import rerun.blueprint as rrb from gradio_rerun import Rerun import spaces from transformers import DetrImageProcessor, DetrForObjectDetection import torch processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50") model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50") # Whenever we need a recording, we construct a new recording stream. # As long as the app and recording IDs remain the same, the data # will be merged by the Viewer. def get_recording(recording_id: str) -> rr.RecordingStream: return rr.RecordingStream( application_id="rerun_example_gradio", recording_id=recording_id ) # A task can directly log to a binary stream, which is routed to the embedded viewer. # Incremental chunks are yielded to the viewer using `yield stream.read()`. # # This is the preferred way to work with Rerun in Gradio since your data can be immediately and # incrementally seen by the viewer. Also, there are no ephemeral RRDs to cleanup or manage. @spaces.GPU def streaming_object_detection(recording_id: str, img): # Here we get a recording using the provided recording id. rec = get_recording(recording_id) stream = rec.binary_stream() if img is None: raise gr.Error("Must provide an image detect objects in.") blueprint = rrb.Blueprint( rrb.Horizontal( rrb.Spatial2DView(origin="image"), ), collapse_panels=True, ) rec.send_blueprint(blueprint) rec.set_time("iteration", sequence=0) rec.log("image", rr.Image(img)) yield stream.read() with torch.inference_mode(): inputs = processor(images=img, return_tensors="pt") outputs = model(**inputs) # convert outputs (bounding boxes and class logits) to COCO API # let's only keep detections with score > 0.85 height, width = img.shape[:2] target_sizes = torch.tensor([[height, width]]) # [height, width] order results = processor.post_process_object_detection( outputs, target_sizes=target_sizes, threshold=0.85 )[0] rec.log( "image/objects", rr.Boxes2D( array=results["boxes"], array_format=rr.Box2DFormat.XYXY, labels=[model.config.id2label[label.item()] for label in results["labels"]], colors=[ ( label.item() * 50 % 255, (label.item() * 80 + 40) % 255, (label.item() * 120 + 100) % 255, ) for label in results["labels"] ], ), ) # Ensure we consume everything from the recording. stream.flush() yield stream.read() with gr.Blocks() as demo: with gr.Row(): with gr.Column(scale=1): with gr.Accordion("Your image", open=True): img = gr.Image(interactive=True, label="Image") detect_objects = gr.Button("Detect objects") with gr.Column(scale=4): viewer = Rerun( streaming=True, panel_states={ "time": "collapsed", "blueprint": "hidden", "selection": "hidden", }, height=700, ) # We make a new recording id, and store it in a Gradio's session state. recording_id = gr.State(uuid.uuid4()) # When registering the event listeners, we pass the `recording_id` in as input in order to create # a recording stream using that id. detect_objects.click( # Using the `viewer` as an output allows us to stream data to it by yielding bytes from the callback. streaming_object_detection, inputs=[recording_id, img], outputs=[viewer], ) if __name__ == "__main__": demo.launch(ssr_mode=False)