| """ | |
| Demonstrates integrating Rerun visualization with Gradio and HF ZeroGPU. | |
| """ | |
| import uuid | |
| import gradio as gr | |
| import rerun as rr | |
| import rerun.blueprint as rrb | |
| from gradio_rerun import Rerun | |
| import spaces | |
| from transformers import DetrImageProcessor, DetrForObjectDetection | |
| import torch | |
| processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50") | |
| model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50") | |
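
# DETR (DEtection TRansformer) predicts a fixed-size set of candidate boxes and
# class logits in a single forward pass; low-confidence candidates are filtered
# out further below with a score threshold.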

# Whenever we need a recording, we construct a new recording stream.
# As long as the app and recording IDs remain the same, the data
# will be merged by the Viewer.
def get_recording(recording_id: str) -> rr.RecordingStream:
    return rr.RecordingStream(
        application_id="rerun_example_gradio", recording_id=recording_id
    )
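
# Illustrative sketch (not part of the original app): any other callback could
# log into the same viewer by building a stream from the same recording id, e.g.:
#
#     def log_status(recording_id: str, text: str):
#         rec = get_recording(recording_id)
#         stream = rec.binary_stream()
#         rec.log("status", rr.TextLog(text))
#         stream.flush()
#         yield stream.read()
#
# Because the application and recording IDs match, the Viewer merges this data
# with whatever the detection callback has already streamed.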

# A task can directly log to a binary stream, which is routed to the embedded viewer.
# Incremental chunks are yielded to the viewer using `yield stream.read()`.
#
# This is the preferred way to work with Rerun in Gradio, since the viewer can show
# your data immediately and incrementally. There are also no ephemeral RRD files to
# clean up or manage.
@spaces.GPU  # Run this callback on ZeroGPU hardware (the reason `spaces` is imported).
def streaming_object_detection(recording_id: str, img):
    # Here we get a recording using the provided recording id.
    rec = get_recording(recording_id)
    stream = rec.binary_stream()

    if img is None:
        raise gr.Error("Must provide an image to detect objects in.")

    blueprint = rrb.Blueprint(
        rrb.Horizontal(
            rrb.Spatial2DView(origin="image"),
        ),
        collapse_panels=True,
    )
    rec.send_blueprint(blueprint)

    rec.set_time("iteration", sequence=0)

    rec.log("image", rr.Image(img))
    yield stream.read()

    with torch.inference_mode():
        inputs = processor(images=img, return_tensors="pt")
        outputs = model(**inputs)

    # Convert the outputs (bounding boxes and class logits) to the COCO API format,
    # and only keep detections with score > 0.85.
    height, width = img.shape[:2]
    target_sizes = torch.tensor([[height, width]])  # [height, width] order
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.85
    )[0]
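
    # `results` is a dict of "scores", "labels", and "boxes" tensors; since we
    # passed `target_sizes`, the boxes are absolute-pixel
    # [x_min, y_min, x_max, y_max] values, which is why `Box2DFormat.XYXY`
    # is used below.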

    rec.log(
        "image/objects",
        rr.Boxes2D(
            array=results["boxes"],
            array_format=rr.Box2DFormat.XYXY,
            labels=[model.config.id2label[label.item()] for label in results["labels"]],
            colors=[
                (
                    label.item() * 50 % 255,
                    (label.item() * 80 + 40) % 255,
                    (label.item() * 120 + 100) % 255,
                )
                for label in results["labels"]
            ],
        ),
    )

    # Ensure we consume everything from the recording.
    stream.flush()
    yield stream.read()
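
# Hypothetical variation (not in the original): for multi-step pipelines you can
# advance the timeline between yields, letting the viewer scrub through
# intermediate results:
#
#     for i, frame in enumerate(frames):
#         rec.set_time("iteration", sequence=i)
#         rec.log("image", rr.Image(frame))
#         yield stream.read()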

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Accordion("Your image", open=True):
                img = gr.Image(interactive=True, label="Image")
            detect_objects = gr.Button("Detect objects")
        with gr.Column(scale=4):
            viewer = Rerun(
                streaming=True,
                panel_states={
                    "time": "collapsed",
                    "blueprint": "hidden",
                    "selection": "hidden",
                },
                height=700,
            )

    # We make a new recording id and store it in Gradio's session state.
    # Note: converted to `str` so it matches the `recording_id: str` signature above.
    recording_id = gr.State(str(uuid.uuid4()))

    # When registering the event listener, we pass the `recording_id` in as an input
    # so that the callback can create a recording stream using that id.
    detect_objects.click(
        # Using the `viewer` as an output allows us to stream data to it by yielding bytes from the callback.
        streaming_object_detection,
        inputs=[recording_id, img],
        outputs=[viewer],
    )

if __name__ == "__main__":
    demo.launch(ssr_mode=False)
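
# To run this locally (assumed dependencies, inferred from the imports above):
#
#     pip install gradio rerun-sdk gradio-rerun transformers torch spaces
#     python app.py
#
# On a Hugging Face Space, this file can be served as `app.py`, with the same
# packages listed in `requirements.txt`.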