oxkitsune committed
Commit b4a4fd5 · 1 Parent(s): 57aeafc

add detr model

Files changed (2):
  1. app.py +27 -24
  2. requirements.txt +10 -2
app.py CHANGED
@@ -11,6 +11,13 @@ import tempfile
 import time
 import uuid
 
+import subprocess
+subprocess.run(
+    "pip install gradio_rerun-0.23.0a2.tar.gz",
+    shell=True,
+)
+
+
 import cv2
 import gradio as gr
 import rerun as rr
@@ -22,13 +29,16 @@ from gradio_rerun.events import (
     TimeUpdate,
 )
 import spaces
+from transformers import DetrImageProcessor, DetrForObjectDetection
+import torch
+from PIL import Image
+import requests
 
-import subprocess
-subprocess.run(
-    "pip install gradio_rerun-0.23.0a2-py3-none-any.whl",
-    shell=True,
-)
+url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+image = Image.open(requests.get(url, stream=True).raw)
 
+processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
+model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
 
 # Whenever we need a recording, we construct a new recording stream.
 # As long as the app and recording IDs remain the same, the data
@@ -42,6 +52,7 @@ def get_recording(recording_id: str) -> rr.RecordingStream:
 #
 # This is the preferred way to work with Rerun in Gradio since your data can be immediately and
 # incrementally seen by the viewer. Also, there are no ephemeral RRDs to cleanup or manage.
+@spaces.GPU
 def streaming_repeated_blur(recording_id: str, img):
     # Here we get a recording using the provided recording id.
     rec = get_recording(recording_id)
@@ -52,30 +63,27 @@ def streaming_repeated_blur(recording_id: str, img):
 
     blueprint = rrb.Blueprint(
         rrb.Horizontal(
-            rrb.Spatial2DView(origin="image/original"),
-            rrb.Spatial2DView(origin="image/blurred"),
+            rrb.Spatial2DView(origin="image"),
         ),
         collapse_panels=True,
     )
 
     rec.send_blueprint(blueprint)
     rec.set_time("iteration", sequence=0)
-    rec.log("image/original", rr.Image(img))
+    rec.log("image", rr.Image(img))
     yield stream.read()
 
-    blur = img
-    for i in range(100):
-        rec.set_time("iteration", sequence=i)
 
-        # Pretend blurring takes a while so we can see streaming in action.
-        time.sleep(0.1)
-        blur = cv2.GaussianBlur(blur, (5, 5), 0)
-        rec.log("image/blurred", rr.Image(blur))
+    inputs = processor(images=image, return_tensors="pt")
+    outputs = model(**inputs)
 
-        # Each time we yield bytes from the stream back to Gradio, they
-        # are incrementally sent to the viewer. Make sure to yield any time
-        # you want the user to be able to see progress.
-        yield stream.read()
+    # convert outputs (bounding boxes and class logits) to COCO API
+    # let's only keep detections with score > 0.9
+    target_sizes = torch.tensor([image.size[::-1]])
+    results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
+
+    print(results)
+    rec.log("image/objects", rr.Boxes2D(sizes=results["boxes"], labels=[model.config.id2label[label.item()] for label in results["labels"]]))
 
     # Ensure we consume everything from the recording.
     stream.flush()
@@ -162,11 +170,6 @@ def track_current_time(evt: TimeUpdate):
 def track_current_timeline_and_time(evt: TimelineChange):
     return evt.timeline, evt.time
 
-@spaces.GPU
-def run_inference(img):
-    print("running inference")
-    pass
-
 with gr.Blocks() as demo:
     with gr.Row():
         img = gr.Image(interactive=True, label="Image")
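
Note on the new inference hunk: it hands `results["boxes"]` to `rr.Boxes2D(sizes=...)`, but `post_process_object_detection` returns absolute `(xmin, ymin, xmax, ymax)` corners, so logging them as sizes likely misplaces the boxes. A minimal standalone sketch of the same detect-and-log step in corner format, assuming rerun-sdk's `array`/`array_format` parameters on `Boxes2D` and the hypothetical app id "detr_boxes_sketch":

import numpy as np
import requests
import rerun as rr
import torch
from PIL import Image
from transformers import DetrForObjectDetection, DetrImageProcessor

processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Rescale to pixel coordinates and keep detections with score > 0.9;
# the returned boxes are (xmin, ymin, xmax, ymax) corners.
target_sizes = torch.tensor([image.size[::-1]])
results = processor.post_process_object_detection(
    outputs, target_sizes=target_sizes, threshold=0.9
)[0]

rr.init("detr_boxes_sketch", spawn=True)  # hypothetical app id, standalone viewer
rr.log("image", rr.Image(np.asarray(image)))
rr.log(
    "image/objects",
    rr.Boxes2D(
        array=results["boxes"].numpy(),    # corner coordinates, not sizes
        array_format=rr.Box2DFormat.XYXY,  # assumed corner-format enum in rerun-sdk
        labels=[model.config.id2label[i.item()] for i in results["labels"]],
    ),
)

Inside the Gradio handler, the same `rr.Boxes2D(...)` call would go through `rec.log(...)` on the per-session recording stream rather than the global `rr.log`.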
requirements.txt CHANGED
@@ -13,6 +13,7 @@ cffi==1.17.1
 charset-normalizer==3.4.1
 click==8.1.8
 cryptography==44.0.2
+datasets==3.5.0
 decorator==5.2.1
 dill==0.3.8
 exceptiongroup==1.2.2
@@ -53,6 +54,7 @@ nvidia-cufft-cu12==11.2.1.3
 nvidia-curand-cu12==10.3.5.147
 nvidia-cusolver-cu12==11.6.1.9
 nvidia-cusparse-cu12==12.3.1.170
+nvidia-cusparselt-cu12==0.6.2
 nvidia-nccl-cu12==2.21.5
 nvidia-nvjitlink-cu12==12.4.127
 nvidia-nvtx-cu12==12.4.127
@@ -79,11 +81,13 @@ python-dateutil==2.9.0.post0
 python-multipart==0.0.20
 pytz==2025.2
 PyYAML==6.0.2
+regex==2024.11.6
 requests==2.32.3
 rerun-sdk==0.23.0a2
 rich==14.0.0
 ruff==0.11.5
 safehttpx==0.1.6
+safetensors==0.5.3
 semantic-version==2.10.0
 shellingham==1.5.4
 six==1.17.0
@@ -92,11 +96,15 @@ spaces==0.34.2
 stack-data==0.6.3
 starlette==0.46.2
 sympy==1.13.1
+timm==1.0.15
+tokenizers==0.21.1
 tomlkit==0.13.2
-torch==2.5.1
+torch==2.6.0
+torchvision==0.21.0
 tqdm==4.67.1
 traitlets==5.14.3
-triton==3.1.0
+transformers==4.51.3
+triton==3.2.0
 typer==0.15.2
 typing-inspection==0.4.0
 typing_extensions==4.13.2
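
requirements.txt still does not pin `gradio_rerun` itself; the app installs the pre-release tarball at import time via `subprocess`. If the archive is committed alongside the app (an assumption, the diff only shows the changed install string), pip accepts a plain local-path line in requirements.txt, which would let the `subprocess.run("pip install ...")` block be dropped:

# requirements.txt (sketch): install a local sdist/wheel by relative path
./gradio_rerun-0.23.0a2.tar.gz

On Hugging Face Spaces this is resolved at build time rather than on first import, so the package is available before app.py executes.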