aharley committed on
Commit f7f5275 · 1 Parent(s): 376df90
Files changed (1)
  1. app.py +694 -0
app.py ADDED
@@ -0,0 +1,694 @@
+ # this is built from https://huggingface.co/spaces/facebook/cotracker/blob/main/app.py
+ # which was built from https://github.com/cvlab-kaist/locotrack/blob/main/demo/demo.py
+
+ import os
+ import sys
+ import uuid
+
+ import gradio as gr
+ import mediapy
+ import numpy as np
+ import cv2
+ import matplotlib
+ import torch
+ import colorsys
+ import random
+ from typing import List, Optional, Sequence, Tuple
+ import spaces
+ import utils.basic
+ import utils.improc
+
+
+ # Generate a random colormap (one RGB color per point) for visualizing different points.
+ def get_colors(num_colors: int) -> List[Tuple[int, int, int]]:
+     """Gets colormap for points."""
+     colors = []
+     for i in np.arange(0.0, 360.0, 360.0 / num_colors):
+         hue = i / 360.0
+         lightness = (50 + np.random.rand() * 10) / 100.0
+         saturation = (90 + np.random.rand() * 10) / 100.0
+         color = colorsys.hls_to_rgb(hue, lightness, saturation)
+         colors.append(
+             (int(color[0] * 255), int(color[1] * 255), int(color[2] * 255))
+         )
+     random.shuffle(colors)
+     return colors
+
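+ # Usage sketch (colors are randomized, so exact values vary):
+ #   palette = get_colors(5)   # 5 RGB tuples with hues evenly spaced around the wheel
+ #   len(palette) == 5; each entry is an (r, g, b) tuple of ints in [0, 255]
+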
+ def get_points_on_a_grid(
+     size: int,
+     extent: Tuple[float, ...],
+     center: Optional[Tuple[float, ...]] = None,
+     device: Optional[torch.device] = torch.device("cpu"),
+ ):
+     r"""Get a grid of points covering a rectangular region.
+
+     `get_points_on_a_grid(size, extent)` generates a :attr:`size` by
+     :attr:`size` grid of points distributed to cover a rectangular area
+     specified by `extent`.
+
+     The `extent` is a pair of integers :math:`(H,W)` specifying the height
+     and width of the rectangle.
+
+     Optionally, the :attr:`center` can be specified as a pair :math:`(c_y,c_x)`
+     specifying the vertical and horizontal center coordinates. The center
+     defaults to the middle of the extent.
+
+     Points are distributed uniformly within the rectangle, leaving a margin
+     :math:`m=W/64` from the border.
+
+     It returns a :math:`(1, \text{size} \times \text{size}, 2)` tensor of
+     points :math:`P_{ij}=(x_i, y_i)` where
+
+     .. math::
+         P_{ij} = \left(
+             c_x + m -\frac{W}{2} + \frac{W - 2m}{\text{size} - 1}\, j,~
+             c_y + m -\frac{H}{2} + \frac{H - 2m}{\text{size} - 1}\, i
+         \right)
+
+     Points are returned in row-major order.
+
+     Args:
+         size (int): grid size.
+         extent (tuple): height and width of the grid extent.
+         center (tuple, optional): grid center.
+         device (torch.device, optional): device for the returned tensor. Defaults to `"cpu"`.
+
+     Returns:
+         Tensor: grid.
+     """
+     if size == 1:
+         return torch.tensor([extent[1] / 2, extent[0] / 2], device=device)[None, None]
+
+     if center is None:
+         center = [extent[0] / 2, extent[1] / 2]
+
+     margin = extent[1] / 64
+     range_y = (margin - extent[0] / 2 + center[0], extent[0] / 2 + center[0] - margin)
+     range_x = (margin - extent[1] / 2 + center[1], extent[1] / 2 + center[1] - margin)
+     grid_y, grid_x = torch.meshgrid(
+         torch.linspace(*range_y, size, device=device),
+         torch.linspace(*range_x, size, device=device),
+         indexing="ij",
+     )
+     return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2)
+
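+ # Usage sketch (hypothetical values, shown here rather than executed):
+ #   xy = get_points_on_a_grid(15, (384, 512))
+ #   xy.shape -> torch.Size([1, 225, 2]), points in (x, y) order, row-major,
+ #   with an 8px margin (512 / 64) from the borders of the 384x512 extent.
+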
+ def paint_point_track(
+     frames: np.ndarray,
+     point_tracks: np.ndarray,
+     visibles: np.ndarray,
+     colormap: Optional[List[Tuple[int, int, int]]] = None,
+ ) -> np.ndarray:
+     """Converts a sequence of points to a color-coded video.
+
+     Args:
+         frames: [num_frames, height, width, 3], np.uint8, [0, 255]
+         point_tracks: [num_points, num_frames, 2], np.float32, [0, width / height]
+         visibles: [num_points, num_frames], bool
+         colormap: colormap for points; each point has a different RGB color.
+
+     Returns:
+         video: [num_frames, height, width, 3], np.uint8, [0, 255]
+     """
+     num_points, num_frames = point_tracks.shape[0:2]
+     if colormap is None:
+         colormap = get_colors(num_colors=num_points)
+     height, width = frames.shape[1:3]
+     dot_size_as_fraction_of_min_edge = 0.015
+     # radius = int(round(min(height, width) * dot_size_as_fraction_of_min_edge))
+     radius = 2
+     diam = radius * 2 + 1
+     quadratic_y = np.square(np.arange(diam)[:, np.newaxis] - radius - 1)
+     quadratic_x = np.square(np.arange(diam)[np.newaxis, :] - radius - 1)
+     icon = (quadratic_y + quadratic_x) - (radius**2) / 2.0
+     sharpness = 0.15
+     icon = np.clip(icon / (radius * 2 * sharpness), 0, 1)
+     icon = 1 - icon[:, :, np.newaxis]
+     icon1 = np.pad(icon, [(0, 1), (0, 1), (0, 0)])
+     icon2 = np.pad(icon, [(1, 0), (0, 1), (0, 0)])
+     icon3 = np.pad(icon, [(0, 1), (1, 0), (0, 0)])
+     icon4 = np.pad(icon, [(1, 0), (1, 0), (0, 0)])
+
+     video = frames.copy()
+     for t in range(num_frames):
+         # Pad so that points that extend outside the image frame don't crash us.
+         image = np.pad(
+             video[t],
+             [
+                 (radius + 1, radius + 1),
+                 (radius + 1, radius + 1),
+                 (0, 0),
+             ],
+         )
+         for i in range(num_points):
+             # The icon is centered at the center of a pixel, but the input coordinates
+             # are raster coordinates. Therefore, to render a point at (1,1) (which
+             # lies on the corner between four pixels), we need 1/4 of the icon placed
+             # centered on the 0'th row, 0'th column, etc. We need to subtract
+             # 0.5 to make the fractional position come out right.
+             x, y = point_tracks[i, t, :] + 0.5
+             x = min(max(x, 0.0), width)
+             y = min(max(y, 0.0), height)
+
+             if visibles[i, t]:
+                 x1, y1 = np.floor(x).astype(np.int32), np.floor(y).astype(np.int32)
+                 x2, y2 = x1 + 1, y1 + 1
+
+                 # bilinear interpolation
+                 patch = (
+                     icon1 * (x2 - x) * (y2 - y)
+                     + icon2 * (x2 - x) * (y - y1)
+                     + icon3 * (x - x1) * (y2 - y)
+                     + icon4 * (x - x1) * (y - y1)
+                 )
+                 x_ub = x1 + 2 * radius + 2
+                 y_ub = y1 + 2 * radius + 2
+                 image[y1:y_ub, x1:x_ub, :] = (1 - patch) * image[
+                     y1:y_ub, x1:x_ub, :
+                 ] + patch * np.array(colormap[i])[np.newaxis, np.newaxis, :]
+
+         # Remove the pad.
+         video[t] = image[
+             radius + 1 : -radius - 1, radius + 1 : -radius - 1
+         ].astype(np.uint8)
+     return video
+
+
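+ # Usage sketch (hypothetical shapes): given a 24-frame video at 256x256,
+ #   frames:       np.uint8,   (24, 256, 256, 3)
+ #   point_tracks: np.float32, (N, 24, 2), in (x, y) pixel coordinates
+ #   visibles:     bool,       (N, 24)
+ #   painted = paint_point_track(frames, point_tracks, visibles)  # same shape as frames
+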
+ PREVIEW_WIDTH = 768 # Width of the preview video
+ PREVIEW_HEIGHT = 768 # Height of the preview video
+ # VIDEO_INPUT_RESO = (384, 512) # Resolution of the input video
+ POINT_SIZE = 1 # Size of the query point in the preview video
+ FRAME_LIMIT = 300 # Limit the number of frames to process
+
+
+ def get_point(frame_num, video_queried_preview, query_points, query_points_color, query_count, evt: gr.SelectData):
+     print(f"You selected {(evt.index[0], evt.index[1], frame_num)}")
+
+     current_frame = video_queried_preview[int(frame_num)]
+
+     # Get the mouse click
+     query_points[int(frame_num)].append((evt.index[0], evt.index[1], frame_num))
+
+     # Choose the color for the point from the matplotlib colormap
+     color = matplotlib.colormaps.get_cmap("gist_rainbow")(query_count % 20 / 20)
+     color = (int(color[0] * 255), int(color[1] * 255), int(color[2] * 255))
+     query_points_color[int(frame_num)].append(color)
+
+     # Draw the point on the frame
+     x, y = evt.index
+     current_frame_draw = cv2.circle(current_frame, (x, y), POINT_SIZE, color, -1)
+
+     # Update the frame
+     video_queried_preview[int(frame_num)] = current_frame_draw
+
+     # Update the query count
+     query_count += 1
+     return (
+         current_frame_draw, # Updated frame for preview
+         video_queried_preview, # Updated preview video
+         query_points, # Updated query points
+         query_points_color, # Updated query points color
+         query_count # Updated query count
+     )
+
+
+ def undo_point(frame_num, video_preview, video_queried_preview, query_points, query_points_color, query_count):
+     if len(query_points[int(frame_num)]) == 0:
+         return (
+             video_queried_preview[int(frame_num)],
+             video_queried_preview,
+             query_points,
+             query_points_color,
+             query_count
+         )
+
+     # Remove the last point
+     query_points[int(frame_num)].pop(-1)
+     query_points_color[int(frame_num)].pop(-1)
+
+     # Redraw the frame
+     current_frame_draw = video_preview[int(frame_num)].copy()
+     for point, color in zip(query_points[int(frame_num)], query_points_color[int(frame_num)]):
+         x, y, _ = point
+         current_frame_draw = cv2.circle(current_frame_draw, (x, y), POINT_SIZE, color, -1)
+
+     # Update the query count
+     query_count -= 1
+
+     # Update the frame
+     video_queried_preview[int(frame_num)] = current_frame_draw
+     return (
+         current_frame_draw, # Updated frame for preview
+         video_queried_preview, # Updated preview video
+         query_points, # Updated query points
+         query_points_color, # Updated query points color
+         query_count # Updated query count
+     )
+
+
+ def clear_frame_fn(frame_num, video_preview, video_queried_preview, query_points, query_points_color, query_count):
+     query_count -= len(query_points[int(frame_num)])
+
+     query_points[int(frame_num)] = []
+     query_points_color[int(frame_num)] = []
+
+     video_queried_preview[int(frame_num)] = video_preview[int(frame_num)].copy()
+
+     return (
+         video_preview[int(frame_num)], # Set the preview frame to the original frame
+         video_queried_preview,
+         query_points, # Cleared query points
+         query_points_color, # Cleared query points color
+         query_count # New query count
+     )
+
+
+ def clear_all_fn(frame_num, video_preview):
+     return (
+         video_preview[int(frame_num)],
+         video_preview.copy(),
+         [[] for _ in range(len(video_preview))],
+         [[] for _ in range(len(video_preview))],
+         0
+     )
+
+
+ def choose_frame(frame_num, video_preview_array):
+     return video_preview_array[int(frame_num)]
+
+
+ def preprocess_video_input(video_path):
+     video_arr = mediapy.read_video(video_path)
+     video_fps = video_arr.metadata.fps
+     num_frames = video_arr.shape[0]
+     if num_frames > FRAME_LIMIT:
+         gr.Warning(f"The video is too long. Only the first {FRAME_LIMIT} frames will be used.", duration=5)
+         video_arr = video_arr[:FRAME_LIMIT]
+         num_frames = FRAME_LIMIT
+
+     # Resize to the preview size for faster processing; the long edge becomes PREVIEW_WIDTH/PREVIEW_HEIGHT
+     height, width = video_arr.shape[1:3]
+     if height > width:
+         new_height, new_width = PREVIEW_HEIGHT, int(PREVIEW_HEIGHT * width / height)
+     else:
+         new_height, new_width = int(PREVIEW_WIDTH * height / width), PREVIEW_WIDTH
+     preview_video = mediapy.resize_video(video_arr, (new_height, new_width))
+     # input_video = mediapy.resize_video(video_arr, VIDEO_INPUT_RESO)
+     input_video = preview_video
+
+     preview_video = np.array(preview_video)
+     input_video = np.array(input_video)
+
+     interactive = True
+
+     return (
+         video_arr, # Original video
+         preview_video, # Original preview video, resized for faster processing
+         preview_video.copy(), # Copy of preview video for visualization
+         input_video, # Resized video input for model
+         video_fps, # Set the video FPS
+         gr.update(open=False), # Close the video input drawer
+         preview_video[0], # Set the preview frame to the first frame
+         gr.update(minimum=0, maximum=num_frames - 1, value=0, interactive=interactive), # Make the slider interactive
+         [[] for _ in range(num_frames)], # Set query_points to empty
+         [[] for _ in range(num_frames)], # Set query_points_color to empty
+         [[] for _ in range(num_frames)], # Set is_tracked_query to empty
+         0, # Set query count to 0
+         gr.update(interactive=interactive), # Make the buttons interactive
+         gr.update(interactive=interactive),
+         gr.update(interactive=interactive),
+         gr.update(interactive=True),
+     )
+
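+ # Resize sketch: a 1920x1080 landscape input maps to
+ #   new_height = int(768 * 1080 / 1920) = 432, new_width = 768,
+ # i.e. the long edge is pinned to 768 and the aspect ratio is preserved.
+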
+ def print_mem(step):
+     # Print GPU memory usage for debugging; a no-op on CPU-only machines,
+     # where torch.cuda.memory_allocated would raise.
+     if torch.cuda.is_available():
+         print("%s torch.cuda.memory_allocated: %.1fGB" % (step, torch.cuda.memory_allocated(0)/1024/1024/1024))
+
+
+ @spaces.GPU
+ def track(
+     video_preview,
+     video_input,
+     video_fps,
+     query_points,
+     query_points_color,
+     query_count,
+ ):
+     # tracking_mode = 'selected'
+     # if query_count == 0:
+     #     tracking_mode = 'grid'
+
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     dtype = torch.float
+
+     print_mem(0)
+
+     # # Convert query points to tensor, normalize to input resolution
+     # if tracking_mode != 'grid':
+     #     query_points_tensor = []
+     #     for frame_points in query_points:
+     #         query_points_tensor.extend(frame_points)
+
+     #     query_points_tensor = torch.tensor(query_points_tensor).float()
+     #     query_points_tensor *= torch.tensor([
+     #         VIDEO_INPUT_RESO[1], VIDEO_INPUT_RESO[0], 1
+     #     ]) / torch.tensor([
+     #         [video_preview.shape[2], video_preview.shape[1], 1]
+     #     ])
+     #     query_points_tensor = query_points_tensor[None].flip(-1).to(device, dtype) # xyt -> tyx
+     #     query_points_tensor = query_points_tensor[:, :, [0, 2, 1]] # tyx -> txy
+
+     video_input = torch.tensor(video_input).unsqueeze(0).to(dtype)
+     print_mem(1)
+
+     # model = torch.hub.load("facebookresearch/co-tracker", "cotracker3_online")
+     # model = model.to(device)
+
+     from nets.alltracker import Net
+     model = Net(16)
+     url = "https://huggingface.co/aharley/alltracker/resolve/main/alltracker.pth"
+     state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu')
+     model.load_state_dict(state_dict['model'], strict=True)
+     print('loaded weights from', url)
+     model = model.to(device)
+     print_mem(2)
+
+     video_input = video_input.permute(0, 1, 4, 2, 3) # B,T,C,H,W
+
+     print('video_input', video_input.shape)
+     # model(video_input, iters=4, sw=None, is_training=False)
+     # model(video_chunk=video_input, is_first_step=True, grid_size=0, queries=queries, add_support_grid=add_support_grid)
+
+     _, T, _, H, W = video_input.shape
+     utils.basic.print_stats('video_input', video_input)
+     print_mem(3)
+
+     grid_xy = utils.basic.gridcloud2d(1, H, W, norm=False, device='cpu:0').float() # 1,H*W,2
+     grid_xy = grid_xy.permute(0,2,1).reshape(1,1,2,H,W) # 1,1,2,H,W
+     print_mem(4)
+
+     # if tracking_mode=='grid':
+     #     xy = get_points_on_a_grid(15, video_input.shape[3:], device=device)
+     #     queries = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device)
+     #     add_support_grid = False
+     #     cmap = matplotlib.colormaps.get_cmap("gist_rainbow")
+     #     query_points_color = [[]]
+     #     query_count = queries.shape[1]
+     #     for i in range(query_count):
+     #         # Choose the color for the point from the matplotlib colormap
+     #         color = cmap(i / float(query_count))
+     #         color = (int(color[0] * 255), int(color[1] * 255), int(color[2] * 255))
+     #         query_points_color[0].append(color)
+     # else:
+     #     queries = query_points_tensor
+     #     add_support_grid = True
+
+     query_frame = 0
+
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+
+     with torch.no_grad():
+         flows_e, visconf_maps_e, _, _ = \
+             model.forward_sliding(video_input[:, query_frame:], iters=4, sw=None, is_training=False)
+         traj_maps_e = flows_e + grid_xy # B,Tf,2,H,W
+         print_mem(5)
+
+         if query_frame > 0:
+             backward_flows_e, backward_visconf_maps_e, _, _ = \
+                 model.forward_sliding(video_input[:, :query_frame+1].flip([1]), iters=4, sw=None, is_training=False)
+             backward_traj_maps_e = backward_flows_e + grid_xy # B,Tb,2,H,W, time-reversed
+             backward_traj_maps_e = backward_traj_maps_e.flip([1])[:, :-1] # flip time and drop the overlapped frame
+             backward_visconf_maps_e = backward_visconf_maps_e.flip([1])[:, :-1] # flip time and drop the overlapped frame
+             traj_maps_e = torch.cat([backward_traj_maps_e, traj_maps_e], dim=1) # B,T,2,H,W
+             visconf_maps_e = torch.cat([backward_visconf_maps_e, visconf_maps_e], dim=1) # B,T,2,H,W
+     print_mem(6)
+
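+     # Walkthrough of the stitching above (a sketch, with T=5 and query_frame=2):
+     #   forward pass  : frames [2, 3, 4] -> Tf=3 trajectory maps
+     #   backward pass : frames [2, 1, 0] (frames [0..2] time-flipped) -> Tb=3 maps, reversed
+     #   flip([1])     : reorders the backward maps to frames [0, 1, 2]
+     #   [:, :-1]      : drops frame 2, which the forward pass already covers
+     #   cat(dim=1)    : [0, 1] + [2, 3, 4] -> full T=5 trajectory maps
+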
+     # for ind in range(0, video_input.shape[1] - model.step, model.step):
+     #     pred_tracks, pred_visibility = model(
+     #         video_chunk=video_input[:, ind : ind + model.step * 2],
+     #         grid_size=0,
+     #         queries=queries,
+     #         add_support_grid=add_support_grid
+     #     ) # B T N 2, B T N 1
+     # tracks = (pred_tracks * torch.tensor([video_preview.shape[2], video_preview.shape[1]]).to(device) / torch.tensor([VIDEO_INPUT_RESO[1], VIDEO_INPUT_RESO[0]]).to(device))[0].permute(1, 0, 2).cpu().numpy()
+     # pred_occ = pred_visibility[0].permute(1, 0).cpu().numpy()
+
+     # # make color array
+     # colors = []
+     # for frame_colors in query_points_color:
+     #     colors.extend(frame_colors)
+     # colors = np.array(colors)
+
+     traj_maps_e = traj_maps_e[:,:,:,::4,::4] # subsample the dense maps 4x in each spatial dim
+     visconf_maps_e = visconf_maps_e[:,:,:,::4,::4] # subsample
+
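+     # Subsampling sketch: with, e.g., H=432 and W=768, the dense maps hold
+     # 432*768 = 331776 trajectories; the ::4 stride keeps every 4th row and
+     # column, leaving 108*192 = 20736 tracks to paint, which keeps the
+     # visualization tractable.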
+     tracks = traj_maps_e.permute(0,3,4,1,2).reshape(-1,T,2).cpu().numpy() # B*H*W,T,2
+     visibs = visconf_maps_e.permute(0,3,4,1,2).reshape(-1,T,2)[:,:,0].cpu().numpy() > 0.9 # threshold the visibility channel
+
+     # sc = (np.array([video_preview.shape[2], video_preview.shape[1]]) / np.array([VIDEO_INPUT_RESO[1], VIDEO_INPUT_RESO[0]])).reshape(1,1,2)
+     # tracks = tracks * sc
+
+     query_count = tracks.shape[0]
+     cmap = matplotlib.colormaps.get_cmap("gist_rainbow")
+     query_points_color = [[]]
+     for i in range(query_count):
+         # Choose the color for the point from the matplotlib colormap
+         color = cmap(i / float(query_count))
+         color = (int(color[0] * 255), int(color[1] * 255), int(color[2] * 255))
+         query_points_color[0].append(color)
+     # make color array
+     colors = []
+     for frame_colors in query_points_color:
+         colors.extend(frame_colors)
+     colors = np.array(colors)
+
+     painted_video = paint_point_track(video_preview, tracks, visibs, colors)
+     print_mem(7)
+
+     # save video
+     video_file_name = uuid.uuid4().hex + ".mp4"
+     video_path = os.path.join(os.path.dirname(__file__), "tmp")
+     video_file_path = os.path.join(video_path, video_file_name)
+     os.makedirs(video_path, exist_ok=True)
+
+     mediapy.write_video(video_file_path, painted_video, fps=video_fps)
+
+     return video_file_path
+
+
+ with gr.Blocks() as demo:
+     video = gr.State()
+     video_queried_preview = gr.State()
+     video_preview = gr.State()
+     video_input = gr.State()
+     video_fps = gr.State(24)
+
+     query_points = gr.State([])
+     query_points_color = gr.State([])
+     is_tracked_query = gr.State([])
+     query_count = gr.State(0)
+
+     gr.Markdown("# 🎨 AllTracker: Efficient Dense Point Tracking at High Resolution")
+     gr.Markdown("<div style='text-align: left;'> \
+         <p>Welcome to <a href='https://github.com/aharley/alltracker' target='_blank'>AllTracker</a>! This space demonstrates point (pixel) tracking in videos. \
+         The model tracks the pixels of the video, densely. </p> \
+         <p> To get started, simply upload your <b>.mp4</b> video or click on one of the example videos to load them. The shorter the video, the faster the processing. We recommend submitting short videos of length <b>2-7 seconds</b>.</p> \
+         <p> After you upload a video, please click \"Submit\" and then click \"Track\". Enjoy the results! </p>\
+         <p style='text-align: left'>For more details, check out the <a href='https://github.com/aharley/alltracker' target='_blank'>GitHub Repo</a> ⭐. We thank the authors of CoTracker and LocoTrack, whose interactive demos this one builds on.</p> \
+         </div>"
+     )
+
+     gr.Markdown("## First step: upload your video or select an example video, and click Submit.")
+     with gr.Row():
+         with gr.Accordion("Your video input", open=True) as video_in_drawer:
+             video_in = gr.Video(label="Video Input", format="mp4")
+             submit = gr.Button("Submit", scale=0)
+
+         apple = os.path.join(os.path.dirname(__file__), "videos", "apple.mp4")
+         bear = os.path.join(os.path.dirname(__file__), "videos", "bear.mp4")
+         paragliding_launch = os.path.join(
+             os.path.dirname(__file__), "videos", "paragliding-launch.mp4"
+         )
+         paragliding = os.path.join(os.path.dirname(__file__), "videos", "paragliding.mp4")
+         cat = os.path.join(os.path.dirname(__file__), "videos", "cat.mp4")
+         pillow = os.path.join(os.path.dirname(__file__), "videos", "pillow.mp4")
+         teddy = os.path.join(os.path.dirname(__file__), "videos", "teddy.mp4")
+         backpack = os.path.join(os.path.dirname(__file__), "videos", "backpack.mp4")
+
+         gr.Examples(examples=[bear, apple, paragliding, paragliding_launch, cat, pillow, teddy, backpack],
+             inputs = [
+                 video_in
+             ],
+         )
+
+     gr.Markdown("## Second step: simply click \"Track\" to densely track the pixels of the video.")
+     with gr.Row():
+         with gr.Column():
+             with gr.Row():
+                 query_frames = gr.Slider(
+                     minimum=0, maximum=100, value=0, step=1, label="Choose Frame", interactive=False)
+             with gr.Row():
+                 undo = gr.Button("Undo", interactive=False)
+                 clear_frame = gr.Button("Clear Frame", interactive=False)
+                 clear_all = gr.Button("Clear All", interactive=False)
+
+             with gr.Row():
+                 current_frame = gr.Image(
+                     label="Click to add query points",
+                     type="numpy",
+                     interactive=False
+                 )
+
+             with gr.Row():
+                 track_button = gr.Button("Track", interactive=False)
+
+         with gr.Column():
+             output_video = gr.Video(
+                 label="Output Video",
+                 interactive=False,
+                 autoplay=True,
+                 loop=True,
+             )
+
+     submit.click(
+         fn = preprocess_video_input,
+         inputs = [video_in],
+         outputs = [
+             video,
+             video_preview,
+             video_queried_preview,
+             video_input,
+             video_fps,
+             video_in_drawer,
+             current_frame,
+             query_frames,
+             query_points,
+             query_points_color,
+             is_tracked_query,
+             query_count,
+             undo,
+             clear_frame,
+             clear_all,
+             track_button,
+         ],
+         queue = False
+     )
+
+     query_frames.change(
+         fn = choose_frame,
+         inputs = [query_frames, video_queried_preview],
+         outputs = [
+             current_frame,
+         ],
+         queue = False
+     )
+
+     current_frame.select(
+         fn = get_point,
+         inputs = [
+             query_frames,
+             video_queried_preview,
+             query_points,
+             query_points_color,
+             query_count,
+         ],
+         outputs = [
+             current_frame,
+             video_queried_preview,
+             query_points,
+             query_points_color,
+             query_count
+         ],
+         queue = False
+     )
+
+     undo.click(
+         fn = undo_point,
+         inputs = [
+             query_frames,
+             video_preview,
+             video_queried_preview,
+             query_points,
+             query_points_color,
+             query_count
+         ],
+         outputs = [
+             current_frame,
+             video_queried_preview,
+             query_points,
+             query_points_color,
+             query_count
+         ],
+         queue = False
+     )
+
+     clear_frame.click(
+         fn = clear_frame_fn,
+         inputs = [
+             query_frames,
+             video_preview,
+             video_queried_preview,
+             query_points,
+             query_points_color,
+             query_count
+         ],
+         outputs = [
+             current_frame,
+             video_queried_preview,
+             query_points,
+             query_points_color,
+             query_count
+         ],
+         queue = False
+     )
+
+     clear_all.click(
+         fn = clear_all_fn,
+         inputs = [
+             query_frames,
+             video_preview,
+         ],
+         outputs = [
+             current_frame,
+             video_queried_preview,
+             query_points,
+             query_points_color,
+             query_count
+         ],
+         queue = False
+     )
+
+     track_button.click(
+         fn = track,
+         inputs = [
+             video_preview,
+             video_input,
+             video_fps,
+             query_points,
+             query_points_color,
+             query_count,
+         ],
+         outputs = [
+             output_video,
+         ],
+         queue = True,
+     )
+
+
+ # demo.launch(show_api=False, show_error=True, debug=False, share=False)
+ demo.launch(show_api=False, show_error=True, debug=False, share=True)