aharley commited on
Commit
09e82bb
·
1 Parent(s): c0337cc

added gpu-based drawing

Browse files
Files changed (1) hide show
  1. app.py +234 -7
app.py CHANGED
@@ -4,6 +4,8 @@
4
  import os
5
  import sys
6
  import uuid
 
 
7
 
8
  import gradio as gr
9
  import mediapy
@@ -93,6 +95,211 @@ def get_points_on_a_grid(
93
  )
94
  return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2)
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  def paint_point_track(
97
  frames: np.ndarray,
98
  point_tracks: np.ndarray,
@@ -330,7 +537,8 @@ def preprocess_video_input(video_path):
330
  def track(
331
  video_preview,
332
  video_input,
333
- video_fps,
 
334
  query_points,
335
  query_points_color,
336
  query_count,
@@ -338,6 +546,10 @@ def track(
338
  # tracking_mode = 'selected'
339
  # if query_count == 0:
340
  # tracking_mode = 'grid'
 
 
 
 
341
 
342
  device = "cuda" if torch.cuda.is_available() else "cpu"
343
  dtype = torch.float if device == "cuda" else torch.float
@@ -407,7 +619,7 @@ def track(
407
  # add_support_grid=True
408
 
409
 
410
- query_frame = 0
411
 
412
  torch.cuda.empty_cache()
413
 
@@ -444,11 +656,17 @@ def track(
444
  # colors.extend(frame_colors)
445
  # colors = np.array(colors)
446
 
447
- traj_maps_e = traj_maps_e[:,:,:,::4,::4] # subsample
448
- visconf_maps_e = visconf_maps_e[:,:,:,::4,::4] # subsample
 
 
449
 
450
  tracks = traj_maps_e.permute(0,3,4,1,2).reshape(-1,T,2).numpy()
451
- visibs = visconf_maps_e.permute(0,3,4,1,2).reshape(-1,T,2)[:,:,0].numpy() > 0.9
 
 
 
 
452
 
453
  # sc = (np.array([video_preview.shape[2], video_preview.shape[1]]) / np.array([VIDEO_INPUT_RESO[1], VIDEO_INPUT_RESO[0]])).reshape(1,1,2)
454
  # print('sc', sc)
@@ -467,8 +685,15 @@ def track(
467
  for frame_colors in query_points_color:
468
  colors.extend(frame_colors)
469
  colors = np.array(colors)
 
 
 
 
 
470
 
471
- painted_video = paint_point_track(video_preview,tracks,visibs,colors)
 
 
472
  print("7 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
473
 
474
  # save video
@@ -546,7 +771,8 @@ with gr.Blocks() as demo:
546
 
547
  with gr.Row():
548
  current_frame = gr.Image(
549
- label="Click to add query points",
 
550
  type="numpy",
551
  interactive=False
552
  )
@@ -679,6 +905,7 @@ with gr.Blocks() as demo:
679
  video_preview,
680
  video_input,
681
  video_fps,
 
682
  query_points,
683
  query_points_color,
684
  query_count,
 
4
  import os
5
  import sys
6
  import uuid
7
+ from concurrent.futures import ThreadPoolExecutor
8
+
9
 
10
  import gradio as gr
11
  import mediapy
 
95
  )
96
  return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2)
97
 
98
+ def paint_point_track_gpu_scatter(
99
+ frames: np.ndarray,
100
+ point_tracks: np.ndarray,
101
+ visibles: np.ndarray,
102
+ colormap: Optional[List[Tuple[int, int, int]]] = None,
103
+ radius: int = 2,
104
+ sharpness: float = 0.15,
105
+ ) -> np.ndarray:
106
+ print('starting vis')
107
+ device = "cuda" if torch.cuda.is_available() else "cpu"
108
+ frames_t = torch.from_numpy(frames).float().permute(0, 3, 1, 2).to(device) # [T,C,H,W]
109
+ point_tracks_t = torch.from_numpy(point_tracks).to(device) # [P,T,2]
110
+ visibles_t = torch.from_numpy(visibles).to(device) # [P,T]
111
+ T, C, H, W = frames_t.shape
112
+ P = point_tracks.shape[0]
113
+ if colormap is None:
114
+ colormap = get_colors(P)
115
+ colors = torch.tensor(colormap, dtype=torch.float32, device=device) # [P,3]
116
+ D = radius * 2 + 1
117
+ y = torch.arange(D, device=device).float()[:, None] - radius
118
+ x = torch.arange(D, device=device).float()[None, :] - radius
119
+ dist2 = x**2 + y**2
120
+ icon = torch.clamp(1 - (dist2 - (radius**2) / 2.0) / (radius * 2 * sharpness), 0, 1) # [D,D]
121
+ icon = icon.view(1, D, D)
122
+ dx = torch.arange(-radius, radius + 1, device=device)
123
+ dy = torch.arange(-radius, radius + 1, device=device)
124
+ disp_y, disp_x = torch.meshgrid(dy, dx, indexing="ij") # [D,D]
125
+ for t in range(T):
126
+ mask = visibles_t[:, t] # [P]
127
+ if mask.sum() == 0:
128
+ continue
129
+ xy = point_tracks_t[mask, t] + 0.5 # [N,2]
130
+ xy[:, 0] = xy[:, 0].clamp(0, W - 1)
131
+ xy[:, 1] = xy[:, 1].clamp(0, H - 1)
132
+ colors_now = colors[mask] # [N,3]
133
+ N = xy.shape[0]
134
+ cx = xy[:, 0].long() # [N]
135
+ cy = xy[:, 1].long()
136
+ x_grid = cx[:, None, None] + disp_x # [N,D,D]
137
+ y_grid = cy[:, None, None] + disp_y # [N,D,D]
138
+ valid = (x_grid >= 0) & (x_grid < W) & (y_grid >= 0) & (y_grid < H)
139
+ x_valid = x_grid[valid] # [K]
140
+ y_valid = y_grid[valid]
141
+ icon_weights = icon.expand(N, D, D)[valid] # [K]
142
+ colors_valid = colors_now[:, :, None, None].expand(N, 3, D, D).permute(1, 0, 2, 3)[
143
+ :, valid
144
+ ] # [3, K]
145
+ idx_flat = (y_valid * W + x_valid).long() # [K]
146
+
147
+ accum = torch.zeros_like(frames_t[t]) # [3, H, W]
148
+ weight = torch.zeros(1, H * W, device=device) # [1, H*W]
149
+ img_flat = accum.view(C, -1) # [3, H*W]
150
+ weighted_colors = colors_valid * icon_weights # [3, K]
151
+ img_flat.scatter_add_(1, idx_flat.unsqueeze(0).expand(C, -1), weighted_colors)
152
+ weight.scatter_add_(1, idx_flat.unsqueeze(0), icon_weights.unsqueeze(0))
153
+ weight = weight.view(1, H, W)
154
+ # accum = accum / (weight + 1e-6) # avoid division by 0
155
+ # frames_t[t] = torch.where(weight > 0, accum, frames_t[t])
156
+ # frames_t[t] = frames_t[t] * (1 - weight) + accum
157
+
158
+ # alpha = weight.clamp(0, 1)
159
+ alpha = weight.clamp(0, 1) * 0.75 # transparency
160
+ accum = accum / (weight + 1e-6) # [3, H, W]
161
+ frames_t[t] = frames_t[t] * (1 - alpha) + accum * alpha
162
+
163
+ # img_flat = frames_t[t].view(C, -1) # [3, H*W]
164
+ # weighted_colors = colors_valid * icon_weights # [3, K]
165
+ # img_flat.scatter_add_(1, idx_flat.unsqueeze(0).expand(C, -1), weighted_colors)
166
+ print('done vis')
167
+ return frames_t.clamp(0, 255).byte().permute(0, 2, 3, 1).cpu().numpy()
168
+
169
+ def paint_point_track_gpu(
170
+ frames: np.ndarray,
171
+ point_tracks: np.ndarray,
172
+ visibles: np.ndarray,
173
+ colormap: Optional[List[Tuple[int, int, int]]] = None,
174
+ radius: int = 2,
175
+ sharpness: float = 0.15,
176
+ ) -> np.ndarray:
177
+ device = "cuda" if torch.cuda.is_available() else "cpu"
178
+ # Setup
179
+ frames_t = torch.from_numpy(frames).float().permute(0, 3, 1, 2).to(device) # [T,C,H,W]
180
+ point_tracks_t = torch.from_numpy(point_tracks).to(device) # [P,T,2]
181
+ visibles_t = torch.from_numpy(visibles).to(device) # [P,T]
182
+ T, C, H, W = frames_t.shape
183
+ P = point_tracks.shape[0]
184
+
185
+ # Colors
186
+ if colormap is None:
187
+ colormap = get_colors(P) # or any fixed list of RGB
188
+ colors = torch.tensor(colormap, dtype=torch.float32, device=device) # [P,3]
189
+
190
+ # Icon kernel [K,K]
191
+ D = radius * 2 + 1
192
+ y = torch.arange(D, device=device).float()[:, None] - radius - 1
193
+ x = torch.arange(D, device=device).float()[None, :] - radius - 1
194
+ dist2 = x**2 + y**2
195
+ icon = torch.clamp(1 - (dist2 - (radius**2) / 2.0) / (radius * 2 * sharpness), 0, 1) # [D,D]
196
+ icon = icon.unsqueeze(0) # [1,D,D] for broadcasting
197
+
198
+ # Create coordinate grids
199
+ for t in range(T):
200
+ image = frames_t[t]
201
+ # Select visible points
202
+ visible_mask = visibles_t[:, t]
203
+ pt_xy = point_tracks_t[visible_mask, t] # [N,2]
204
+ colors_t = colors[visible_mask] # [N,3]
205
+ N = pt_xy.shape[0]
206
+ if N == 0:
207
+ continue
208
+
209
+ # Integer centers
210
+ pt_xy = pt_xy + 0.5 # correct center offset
211
+ pt_xy[:, 0] = pt_xy[:, 0].clamp(0, W - 1)
212
+ pt_xy[:, 1] = pt_xy[:, 1].clamp(0, H - 1)
213
+ ix = pt_xy[:, 0].long() # [N]
214
+ iy = pt_xy[:, 1].long()
215
+
216
+ # Build grid of indices for patch around each point
217
+ dx = torch.arange(-radius, radius + 1, device=device)
218
+ dy = torch.arange(-radius, radius + 1, device=device)
219
+ dx_grid, dy_grid = torch.meshgrid(dx, dy, indexing='ij')
220
+ dx_flat = dx_grid.reshape(-1)
221
+ dy_flat = dy_grid.reshape(-1)
222
+ patch_x = ix[:, None] + dx_flat[None, :] # [N,K*K]
223
+ patch_y = iy[:, None] + dy_flat[None, :] # [N,K*K]
224
+
225
+ # Mask out-of-bounds
226
+ valid = (patch_x >= 0) & (patch_x < W) & (patch_y >= 0) & (patch_y < H)
227
+ flat_idx = (patch_y * W + patch_x).long() # [N,K*K]
228
+
229
+ # Flatten icon and colors
230
+ icon_flat = icon.view(1, -1) # [1, K*K]
231
+ color_patches = colors_t[:, :, None] * icon_flat[:, None, :] # [N,3,K*K]
232
+
233
+ # Flatten to write into 1D image
234
+ img_flat = image.view(C, -1) # [3, H*W]
235
+ for i in range(N):
236
+ valid_mask = valid[i]
237
+ idxs = flat_idx[i][valid_mask]
238
+ vals = color_patches[i, :, valid_mask] # [3, valid_count]
239
+ img_flat[:, idxs] += vals
240
+
241
+ out_frames = frames_t.clamp(0, 255).byte().permute(0, 2, 3, 1).cpu().numpy()
242
+ return out_frames
243
+
244
+
245
+ def paint_point_track_parallel(
246
+ frames: np.ndarray,
247
+ point_tracks: np.ndarray,
248
+ visibles: np.ndarray,
249
+ colormap: Optional[List[Tuple[int, int, int]]] = None,
250
+ max_workers: int = 8,
251
+ ) -> np.ndarray:
252
+ num_points, num_frames = point_tracks.shape[:2]
253
+ if colormap is None:
254
+ colormap = get_colors(num_colors=num_points)
255
+ height, width = frames.shape[1:3]
256
+ radius = 2
257
+ print('radius', radius)
258
+ diam = radius * 2 + 1
259
+ # Precompute the icon and its bilinear components
260
+ quadratic_y = np.square(np.arange(diam)[:, np.newaxis] - radius - 1)
261
+ quadratic_x = np.square(np.arange(diam)[np.newaxis, :] - radius - 1)
262
+ icon = (quadratic_y + quadratic_x) - (radius**2) / 2.0
263
+ sharpness = 0.15
264
+ icon = np.clip(icon / (radius * 2 * sharpness), 0, 1)
265
+ icon = 1 - icon[:, :, np.newaxis]
266
+ icon1 = np.pad(icon, [(0, 1), (0, 1), (0, 0)])
267
+ icon2 = np.pad(icon, [(1, 0), (0, 1), (0, 0)])
268
+ icon3 = np.pad(icon, [(0, 1), (1, 0), (0, 0)])
269
+ icon4 = np.pad(icon, [(1, 0), (1, 0), (0, 0)])
270
+
271
+ def draw_point(image, i, t):
272
+ if not visibles[i, t]:
273
+ return
274
+ x, y = point_tracks[i, t, :] + 0.5
275
+ x = min(max(x, 0.0), width)
276
+ y = min(max(y, 0.0), height)
277
+ x1, y1 = np.floor(x).astype(np.int32), np.floor(y).astype(np.int32)
278
+ x2, y2 = x1 + 1, y1 + 1
279
+ patch = (
280
+ icon1 * (x2 - x) * (y2 - y)
281
+ + icon2 * (x2 - x) * (y - y1)
282
+ + icon3 * (x - x1) * (y2 - y)
283
+ + icon4 * (x - x1) * (y - y1)
284
+ )
285
+ x_ub = x1 + 2 * radius + 2
286
+ y_ub = y1 + 2 * radius + 2
287
+ image[y1:y_ub, x1:x_ub, :] = (1 - patch) * image[y1:y_ub, x1:x_ub, :] + patch * np.array(colormap[i])[np.newaxis, np.newaxis, :]
288
+
289
+ video = frames.copy()
290
+ for t in range(num_frames):
291
+ image = np.pad(
292
+ video[t],
293
+ [(radius + 1, radius + 1), (radius + 1, radius + 1), (0, 0)],
294
+ )
295
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
296
+ futures = [executor.submit(draw_point, image, i, t) for i in range(num_points)]
297
+ _ = [f.result() for f in futures] # wait for all threads
298
+ video[t] = image[radius + 1 : -radius - 1, radius + 1 : -radius - 1].astype(np.uint8)
299
+
300
+ return video
301
+
302
+
303
  def paint_point_track(
304
  frames: np.ndarray,
305
  point_tracks: np.ndarray,
 
537
  def track(
538
  video_preview,
539
  video_input,
540
+ video_fps,
541
+ query_frame,
542
  query_points,
543
  query_points_color,
544
  query_count,
 
546
  # tracking_mode = 'selected'
547
  # if query_count == 0:
548
  # tracking_mode = 'grid'
549
+
550
+ # print('query_frames', query_frames)
551
+ # query_frame = int(query_frames[0])
552
+ # # query_frame = 0
553
 
554
  device = "cuda" if torch.cuda.is_available() else "cpu"
555
  dtype = torch.float if device == "cuda" else torch.float
 
619
  # add_support_grid=True
620
 
621
 
622
+ # query_frame = 0
623
 
624
  torch.cuda.empty_cache()
625
 
 
656
  # colors.extend(frame_colors)
657
  # colors = np.array(colors)
658
 
659
+ # traj_maps_e = traj_maps_e[:,:,:,::4,::4] # subsample
660
+ # visconf_maps_e = visconf_maps_e[:,:,:,::4,::4] # subsample
661
+ traj_maps_e = traj_maps_e[:,:,:,::2,::2] # subsample
662
+ visconf_maps_e = visconf_maps_e[:,:,:,::2,::2] # subsample
663
 
664
  tracks = traj_maps_e.permute(0,3,4,1,2).reshape(-1,T,2).numpy()
665
+ visibs = visconf_maps_e.permute(0,3,4,1,2).reshape(-1,T,2)[:,:,0].numpy()
666
+ confs = visconf_maps_e.permute(0,3,4,1,2).reshape(-1,T,2)[:,:,0].numpy()
667
+
668
+ visibs = (visibs * confs) > 0.9 # N,T
669
+
670
 
671
  # sc = (np.array([video_preview.shape[2], video_preview.shape[1]]) / np.array([VIDEO_INPUT_RESO[1], VIDEO_INPUT_RESO[0]])).reshape(1,1,2)
672
  # print('sc', sc)
 
685
  for frame_colors in query_points_color:
686
  colors.extend(frame_colors)
687
  colors = np.array(colors)
688
+
689
+ inds = np.sum(visibs * 1.0, axis=1) >= min(T//4,3)
690
+ tracks = tracks[inds]
691
+ visibs = visibs[inds]
692
+ colors = colors[inds]
693
 
694
+ # painted_video = paint_point_track_parallel(video_preview,tracks,visibs,colors)
695
+ # painted_video = paint_point_track_gpu(video_preview,tracks,visibs,colors)
696
+ painted_video = paint_point_track_gpu_scatter(video_preview,tracks,visibs,colors)
697
  print("7 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
698
 
699
  # save video
 
771
 
772
  with gr.Row():
773
  current_frame = gr.Image(
774
+ # label="Click to add query points",
775
+ label="Query frame",
776
  type="numpy",
777
  interactive=False
778
  )
 
905
  video_preview,
906
  video_input,
907
  video_fps,
908
+ query_frames,
909
  query_points,
910
  query_points_color,
911
  query_count,