aharley committed
Commit 574fdd2 · 1 Parent(s): 09e82bb

updated comments
Files changed (4)
  1. README.md +20 -8
  2. app.py +321 -260
  3. nets/alltracker.py +11 -11
  4. requirements.txt +17 -0
README.md CHANGED
@@ -1,14 +1,26 @@
1
  ---
2
- title: Alltracker
3
- emoji: 📈
4
- colorFrom: indigo
5
  colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 5.35.0
 
 
8
  app_file: app.py
9
- pinned: false
10
- license: mit
11
- short_description: Efficient dense tracking
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
+ title: AllTracker
3
+ emoji:
4
+ colorFrom: blue
5
  colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 5.34.2
8
+ suggested_hardware: a100-large
9
+ suggested_storage: large
10
  app_file: app.py
11
+ pinned: true
12
+ license: cc-by-nc-4.0
 
13
  ---
14
 
15
+ This is a demo for ["AllTracker: Efficient Dense Point Tracking at High Resolution"](https://alltracker.github.io/)
16
+
17
+ Paper page: https://huggingface.co/papers/2506.07310
18
+
19
+ ```
20
+ @inproceedings{harley2025alltracker,
21
+ author = {Adam W. Harley and Yang You and Xinglong Sun and Yang Zheng and Nikhil Raghuraman and Yunqi Gu and Sheldon Liang and Wen-Hsuan Chu and Achal Dave and Pavel Tokmakov and Suya You and Rares Ambrus and Katerina Fragkiadaki and Leonidas J. Guibas},
22
+ title = {All{T}racker: {E}fficient Dense Point Tracking at High Resolution},
23
+ booktitle = {ICCV},
24
+ year = {2025}
25
+ }
26
+ ```
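
For reference, the front matter above names `app.py` as the entry point, pins the Gradio SDK to 5.34.2, and suggests A100 hardware with large storage. A minimal, hypothetical sketch for trying the demo locally, assuming the Space repository is cloned and the pinned dependencies from `requirements.txt` (added in this commit) are installed:

```python
import subprocess
import sys

# Hypothetical local launch of the entry point named in the front matter above.
# Assumes `pip install -r requirements.txt` has already been run in this environment.
subprocess.run([sys.executable, "app.py"], check=True)
```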
app.py CHANGED
@@ -5,7 +5,9 @@ import os
5
  import sys
6
  import uuid
7
  from concurrent.futures import ThreadPoolExecutor
 
8
 
 
9
 
10
  import gradio as gr
11
  import mediapy
@@ -21,6 +23,7 @@ import numpy as np
21
  import utils.basic
22
  import utils.improc
23
 
 
24
 
25
  # Generate random colormaps for visualizing different points.
26
  def get_colors(num_colors: int) -> List[Tuple[int, int, int]]:
@@ -37,63 +40,63 @@ def get_colors(num_colors: int) -> List[Tuple[int, int, int]]:
37
  random.shuffle(colors)
38
  return colors
39
 
40
- def get_points_on_a_grid(
41
- size: int,
42
- extent: Tuple[float, ...],
43
- center: Optional[Tuple[float, ...]] = None,
44
- device: Optional[torch.device] = torch.device("cpu"),
45
- ):
46
- r"""Get a grid of points covering a rectangular region
47
-
48
- `get_points_on_a_grid(size, extent)` generates a :attr:`size` by
49
- :attr:`size` grid fo points distributed to cover a rectangular area
50
- specified by `extent`.
51
-
52
- The `extent` is a pair of integer :math:`(H,W)` specifying the height
53
- and width of the rectangle.
54
-
55
- Optionally, the :attr:`center` can be specified as a pair :math:`(c_y,c_x)`
56
- specifying the vertical and horizontal center coordinates. The center
57
- defaults to the middle of the extent.
58
-
59
- Points are distributed uniformly within the rectangle leaving a margin
60
- :math:`m=W/64` from the border.
61
-
62
- It returns a :math:`(1, \text{size} \times \text{size}, 2)` tensor of
63
- points :math:`P_{ij}=(x_i, y_i)` where
64
-
65
- .. math::
66
- P_{ij} = \left(
67
- c_x + m -\frac{W}{2} + \frac{W - 2m}{\text{size} - 1}\, j,~
68
- c_y + m -\frac{H}{2} + \frac{H - 2m}{\text{size} - 1}\, i
69
- \right)
70
-
71
- Points are returned in row-major order.
72
-
73
- Args:
74
- size (int): grid size.
75
- extent (tuple): height and with of the grid extent.
76
- center (tuple, optional): grid center.
77
- device (str, optional): Defaults to `"cpu"`.
78
-
79
- Returns:
80
- Tensor: grid.
81
- """
82
- if size == 1:
83
- return torch.tensor([extent[1] / 2, extent[0] / 2], device=device)[None, None]
84
-
85
- if center is None:
86
- center = [extent[0] / 2, extent[1] / 2]
87
-
88
- margin = extent[1] / 64
89
- range_y = (margin - extent[0] / 2 + center[0], extent[0] / 2 + center[0] - margin)
90
- range_x = (margin - extent[1] / 2 + center[1], extent[1] / 2 + center[1] - margin)
91
- grid_y, grid_x = torch.meshgrid(
92
- torch.linspace(*range_y, size, device=device),
93
- torch.linspace(*range_x, size, device=device),
94
- indexing="ij",
95
- )
96
- return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2)
97
 
98
  def paint_point_track_gpu_scatter(
99
  frames: np.ndarray,
@@ -382,105 +385,105 @@ def paint_point_track(
382
  return video
383
 
384
 
385
- PREVIEW_WIDTH = 768 # Width of the preview video
386
- PREVIEW_HEIGHT = 768
387
  # VIDEO_INPUT_RESO = (384, 512) # Resolution of the input video
388
  POINT_SIZE = 1 # Size of the query point in the preview video
389
- FRAME_LIMIT = 300 # Limit the number of frames to process
390
 
391
 
392
- def get_point(frame_num, video_queried_preview, query_points, query_points_color, query_count, evt: gr.SelectData):
393
- print(f"You selected {(evt.index[0], evt.index[1], frame_num)}")
394
 
395
- current_frame = video_queried_preview[int(frame_num)]
396
 
397
- # Get the mouse click
398
- query_points[int(frame_num)].append((evt.index[0], evt.index[1], frame_num))
399
 
400
- # Choose the color for the point from matplotlib colormap
401
- color = matplotlib.colormaps.get_cmap("gist_rainbow")(query_count % 20 / 20)
402
- color = (int(color[0] * 255), int(color[1] * 255), int(color[2] * 255))
403
- # print(f"Color: {color}")
404
- query_points_color[int(frame_num)].append(color)
405
 
406
- # Draw the point on the frame
407
- x, y = evt.index
408
- current_frame_draw = cv2.circle(current_frame, (x, y), POINT_SIZE, color, -1)
409
 
410
- # Update the frame
411
- video_queried_preview[int(frame_num)] = current_frame_draw
412
 
413
- # Update the query count
414
- query_count += 1
415
- return (
416
- current_frame_draw, # Updated frame for preview
417
- video_queried_preview, # Updated preview video
418
- query_points, # Updated query points
419
- query_points_color, # Updated query points color
420
- query_count # Updated query count
421
- )
422
 
423
 
424
- def undo_point(frame_num, video_preview, video_queried_preview, query_points, query_points_color, query_count):
425
- if len(query_points[int(frame_num)]) == 0:
426
- return (
427
- video_queried_preview[int(frame_num)],
428
- video_queried_preview,
429
- query_points,
430
- query_points_color,
431
- query_count
432
- )
433
 
434
- # Get the last point
435
- query_points[int(frame_num)].pop(-1)
436
- query_points_color[int(frame_num)].pop(-1)
437
 
438
- # Redraw the frame
439
- current_frame_draw = video_preview[int(frame_num)].copy()
440
- for point, color in zip(query_points[int(frame_num)], query_points_color[int(frame_num)]):
441
- x, y, _ = point
442
- current_frame_draw = cv2.circle(current_frame_draw, (x, y), POINT_SIZE, color, -1)
443
 
444
- # Update the query count
445
- query_count -= 1
446
 
447
- # Update the frame
448
- video_queried_preview[int(frame_num)] = current_frame_draw
449
- return (
450
- current_frame_draw, # Updated frame for preview
451
- video_queried_preview, # Updated preview video
452
- query_points, # Updated query points
453
- query_points_color, # Updated query points color
454
- query_count # Updated query count
455
- )
456
 
457
 
458
- def clear_frame_fn(frame_num, video_preview, video_queried_preview, query_points, query_points_color, query_count):
459
- query_count -= len(query_points[int(frame_num)])
460
 
461
- query_points[int(frame_num)] = []
462
- query_points_color[int(frame_num)] = []
463
 
464
- video_queried_preview[int(frame_num)] = video_preview[int(frame_num)].copy()
465
 
466
- return (
467
- video_preview[int(frame_num)], # Set the preview frame to the original frame
468
- video_queried_preview,
469
- query_points, # Cleared query points
470
- query_points_color, # Cleared query points color
471
- query_count # New query count
472
- )
473
 
474
 
475
 
476
- def clear_all_fn(frame_num, video_preview):
477
- return (
478
- video_preview[int(frame_num)],
479
- video_preview.copy(),
480
- [[] for _ in range(len(video_preview))],
481
- [[] for _ in range(len(video_preview))],
482
- 0
483
- )
484
 
485
 
486
  def choose_frame(frame_num, video_preview_array):
@@ -502,6 +505,11 @@ def preprocess_video_input(video_path):
502
  new_height, new_width = PREVIEW_HEIGHT, int(PREVIEW_WIDTH * width / height)
503
  else:
504
  new_height, new_width = int(PREVIEW_WIDTH * height / width), PREVIEW_WIDTH
505
  preview_video = mediapy.resize_video(video_arr, (new_height, new_width))
506
  # input_video = mediapy.resize_video(video_arr, VIDEO_INPUT_RESO)
507
  # input_video = video_arr
@@ -519,7 +527,7 @@ def preprocess_video_input(video_path):
519
  input_video, # Resized video input for model
520
  # None, # video_feature, # Extracted feature
521
  video_fps, # Set the video FPS
522
- gr.update(open=False), # Close the video input drawer
523
  # tracking_mode, # Set the tracking mode
524
  preview_video[0], # Set the preview frame to the first frame
525
  gr.update(minimum=0, maximum=num_frames - 1, value=0, interactive=interactive), # Set slider interactive
@@ -624,20 +632,47 @@ def track(
624
  torch.cuda.empty_cache()
625
 
626
  with torch.no_grad():
627
- # model.forward_sliding(
628
- flows_e, visconf_maps_e, _, _ = \
629
- model.forward_sliding(video_input[:, query_frame:], iters=4, sw=None, is_training=False)
630
- traj_maps_e = flows_e + grid_xy # B,Tf,2,H,W
631
- print("5 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
632
-
 
 
 
633
  if query_frame > 0:
634
  backward_flows_e, backward_visconf_maps_e, _, _ = \
635
- model.forward_sliding(video_input[:, :query_frame+1].flip([1]), iters=4, sw=None, is_training=False)
636
- backward_traj_maps_e = backward_flows_e + grid_xy # B,Tb,2,H,W, reversed
637
- backward_traj_maps_e = backward_traj_maps_e.flip([1])[:, :-1] # flip time and drop the overlapped frame
638
- backward_visconf_maps_e = backward_visconf_maps_e.flip([1])[:, :-1] # flip time and drop the overlapped frame
639
  traj_maps_e = torch.cat([backward_traj_maps_e, traj_maps_e], dim=1) # B,T,2,H,W
640
  visconf_maps_e = torch.cat([backward_visconf_maps_e, visconf_maps_e], dim=1) # B,T,2,H,W
641
  print("6 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
642
 
643
  # for ind in range(0, video_input.shape[1] - model.step, model.step):
@@ -665,7 +700,8 @@ def track(
665
  visibs = visconf_maps_e.permute(0,3,4,1,2).reshape(-1,T,2)[:,:,0].numpy()
666
  confs = visconf_maps_e.permute(0,3,4,1,2).reshape(-1,T,2)[:,:,0].numpy()
667
 
668
- visibs = (visibs * confs) > 0.9 # N,T
 
669
 
670
 
671
  # sc = (np.array([video_preview.shape[2], video_preview.shape[1]]) / np.array([VIDEO_INPUT_RESO[1], VIDEO_INPUT_RESO[0]])).reshape(1,1,2)
@@ -700,9 +736,31 @@ def track(
700
  video_file_name = uuid.uuid4().hex + ".mp4"
701
  video_path = os.path.join(os.path.dirname(__file__), "tmp")
702
  video_file_path = os.path.join(video_path, video_file_name)
703
- os.makedirs(video_path, exist_ok=True)
704
 
705
- mediapy.write_video(video_file_path, painted_video, fps=video_fps)
706
 
707
  return video_file_path
708
 
@@ -719,55 +777,58 @@ with gr.Blocks() as demo:
719
  is_tracked_query = gr.State([])
720
  query_count = gr.State(0)
721
 
722
- gr.Markdown("# 🎨 CoTracker3: Simpler and Better Point Tracking by Pseudo-Labelling Real Videos")
723
  gr.Markdown("<div style='text-align: left;'> \
724
- <p>Welcome to <a href='https://cotracker3.github.io/' target='_blank'>CoTracker</a>! This space demonstrates point (pixel) tracking in videos. \
725
- The model tracks points on a grid or points selected by you. </p> \
726
- <p> To get started, simply upload your <b>.mp4</b> video or click on one of the example videos to load them. The shorter the video, the faster the processing. We recommend submitting short videos of length <b>2-7 seconds</b>.</p> \
727
- <p> After you uploaded a video, please click \"Submit\" and then click \"Track\" for grid tracking or specify points you want to track before clicking. Enjoy the results! </p>\
728
- <p style='text-align: left'>For more details, check out our <a href='https://github.com/facebookresearch/co-tracker' target='_blank'>GitHub Repo</a> ⭐. We thank the authors of LocoTrack for their interactive demo.</p> \
 
729
  </div>"
730
  )
731
 
732
 
733
- gr.Markdown("## First step: upload your video or select an example video, and click submit.")
734
  with gr.Row():
735
-
736
-
737
- with gr.Accordion("Your video input", open=True) as video_in_drawer:
738
- video_in = gr.Video(label="Video Input", format="mp4")
739
- submit = gr.Button("Submit", scale=0)
740
-
741
- import os
742
- apple = os.path.join(os.path.dirname(__file__), "videos", "apple.mp4")
743
- bear = os.path.join(os.path.dirname(__file__), "videos", "bear.mp4")
744
- paragliding_launch = os.path.join(
745
- os.path.dirname(__file__), "videos", "paragliding-launch.mp4"
746
- )
747
- paragliding = os.path.join(os.path.dirname(__file__), "videos", "paragliding.mp4")
748
- cat = os.path.join(os.path.dirname(__file__), "videos", "cat.mp4")
749
- pillow = os.path.join(os.path.dirname(__file__), "videos", "pillow.mp4")
750
- teddy = os.path.join(os.path.dirname(__file__), "videos", "teddy.mp4")
751
- backpack = os.path.join(os.path.dirname(__file__), "videos", "backpack.mp4")
752
-
753
-
754
- gr.Examples(examples=[bear, apple, paragliding, paragliding_launch, cat, pillow, teddy, backpack],
755
- inputs = [
756
- video_in
757
- ],
758
- )
759
-
 
 
760
 
761
- gr.Markdown("## Second step: Simply click \"Track\" to track a grid of points or select query points on the video before clicking")
762
  with gr.Row():
763
  with gr.Column():
764
  with gr.Row():
765
  query_frames = gr.Slider(
766
  minimum=0, maximum=100, value=0, step=1, label="Choose Frame", interactive=False)
767
- with gr.Row():
768
- undo = gr.Button("Undo", interactive=False)
769
- clear_frame = gr.Button("Clear Frame", interactive=False)
770
- clear_all = gr.Button("Clear All", interactive=False)
771
 
772
  with gr.Row():
773
  current_frame = gr.Image(
@@ -799,16 +860,16 @@ with gr.Blocks() as demo:
799
  video_queried_preview,
800
  video_input,
801
  video_fps,
802
- video_in_drawer,
803
  current_frame,
804
  query_frames,
805
  query_points,
806
  query_points_color,
807
  is_tracked_query,
808
  query_count,
809
- undo,
810
- clear_frame,
811
- clear_all,
812
  track_button,
813
  ],
814
  queue = False
@@ -823,80 +884,80 @@ with gr.Blocks() as demo:
823
  queue = False
824
  )
825
 
826
- current_frame.select(
827
- fn = get_point,
828
- inputs = [
829
- query_frames,
830
- video_queried_preview,
831
- query_points,
832
- query_points_color,
833
- query_count,
834
- ],
835
- outputs = [
836
- current_frame,
837
- video_queried_preview,
838
- query_points,
839
- query_points_color,
840
- query_count
841
- ],
842
- queue = False
843
- )
844
 
845
- undo.click(
846
- fn = undo_point,
847
- inputs = [
848
- query_frames,
849
- video_preview,
850
- video_queried_preview,
851
- query_points,
852
- query_points_color,
853
- query_count
854
- ],
855
- outputs = [
856
- current_frame,
857
- video_queried_preview,
858
- query_points,
859
- query_points_color,
860
- query_count
861
- ],
862
- queue = False
863
- )
864
-
865
- clear_frame.click(
866
- fn = clear_frame_fn,
867
- inputs = [
868
- query_frames,
869
- video_preview,
870
- video_queried_preview,
871
- query_points,
872
- query_points_color,
873
- query_count
874
- ],
875
- outputs = [
876
- current_frame,
877
- video_queried_preview,
878
- query_points,
879
- query_points_color,
880
- query_count
881
- ],
882
- queue = False
883
- )
884
-
885
- clear_all.click(
886
- fn = clear_all_fn,
887
- inputs = [
888
- query_frames,
889
- video_preview,
890
- ],
891
- outputs = [
892
- current_frame,
893
- video_queried_preview,
894
- query_points,
895
- query_points_color,
896
- query_count
897
- ],
898
- queue = False
899
- )
900
 
901
 
902
  track_button.click(
 
5
  import sys
6
  import uuid
7
  from concurrent.futures import ThreadPoolExecutor
8
+ import subprocess
9
 
10
+ from nets.blocks import InputPadder
11
 
12
  import gradio as gr
13
  import mediapy
 
23
  import utils.basic
24
  import utils.improc
25
 
26
+ import PIL.Image
27
 
28
  # Generate random colormaps for visualizing different points.
29
  def get_colors(num_colors: int) -> List[Tuple[int, int, int]]:
 
40
  random.shuffle(colors)
41
  return colors
42
 
43
+ # def get_points_on_a_grid(
44
+ # size: int,
45
+ # extent: Tuple[float, ...],
46
+ # center: Optional[Tuple[float, ...]] = None,
47
+ # device: Optional[torch.device] = torch.device("cpu"),
48
+ # ):
49
+ # r"""Get a grid of points covering a rectangular region
50
+
51
+ # `get_points_on_a_grid(size, extent)` generates a :attr:`size` by
52
+ # :attr:`size` grid fo points distributed to cover a rectangular area
53
+ # specified by `extent`.
54
+
55
+ # The `extent` is a pair of integer :math:`(H,W)` specifying the height
56
+ # and width of the rectangle.
57
+
58
+ # Optionally, the :attr:`center` can be specified as a pair :math:`(c_y,c_x)`
59
+ # specifying the vertical and horizontal center coordinates. The center
60
+ # defaults to the middle of the extent.
61
+
62
+ # Points are distributed uniformly within the rectangle leaving a margin
63
+ # :math:`m=W/64` from the border.
64
+
65
+ # It returns a :math:`(1, \text{size} \times \text{size}, 2)` tensor of
66
+ # points :math:`P_{ij}=(x_i, y_i)` where
67
+
68
+ # .. math::
69
+ # P_{ij} = \left(
70
+ # c_x + m -\frac{W}{2} + \frac{W - 2m}{\text{size} - 1}\, j,~
71
+ # c_y + m -\frac{H}{2} + \frac{H - 2m}{\text{size} - 1}\, i
72
+ # \right)
73
+
74
+ # Points are returned in row-major order.
75
+
76
+ # Args:
77
+ # size (int): grid size.
78
+ # extent (tuple): height and with of the grid extent.
79
+ # center (tuple, optional): grid center.
80
+ # device (str, optional): Defaults to `"cpu"`.
81
+
82
+ # Returns:
83
+ # Tensor: grid.
84
+ # """
85
+ # if size == 1:
86
+ # return torch.tensor([extent[1] / 2, extent[0] / 2], device=device)[None, None]
87
+
88
+ # if center is None:
89
+ # center = [extent[0] / 2, extent[1] / 2]
90
+
91
+ # margin = extent[1] / 64
92
+ # range_y = (margin - extent[0] / 2 + center[0], extent[0] / 2 + center[0] - margin)
93
+ # range_x = (margin - extent[1] / 2 + center[1], extent[1] / 2 + center[1] - margin)
94
+ # grid_y, grid_x = torch.meshgrid(
95
+ # torch.linspace(*range_y, size, device=device),
96
+ # torch.linspace(*range_x, size, device=device),
97
+ # indexing="ij",
98
+ # )
99
+ # return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2)
100
 
101
  def paint_point_track_gpu_scatter(
102
  frames: np.ndarray,
 
385
  return video
386
 
387
 
388
+ PREVIEW_WIDTH = 1024 # Width of the preview video
389
+ PREVIEW_HEIGHT = 1024
390
  # VIDEO_INPUT_RESO = (384, 512) # Resolution of the input video
391
  POINT_SIZE = 1 # Size of the query point in the preview video
392
+ FRAME_LIMIT = 600 # Limit the number of frames to process
393
 
394
 
395
+ # def get_point(frame_num, video_queried_preview, query_points, query_points_color, query_count, evt: gr.SelectData):
396
+ # print(f"You selected {(evt.index[0], evt.index[1], frame_num)}")
397
 
398
+ # current_frame = video_queried_preview[int(frame_num)]
399
 
400
+ # # Get the mouse click
401
+ # query_points[int(frame_num)].append((evt.index[0], evt.index[1], frame_num))
402
 
403
+ # # Choose the color for the point from matplotlib colormap
404
+ # color = matplotlib.colormaps.get_cmap("gist_rainbow")(query_count % 20 / 20)
405
+ # color = (int(color[0] * 255), int(color[1] * 255), int(color[2] * 255))
406
+ # # print(f"Color: {color}")
407
+ # query_points_color[int(frame_num)].append(color)
408
 
409
+ # # Draw the point on the frame
410
+ # x, y = evt.index
411
+ # current_frame_draw = cv2.circle(current_frame, (x, y), POINT_SIZE, color, -1)
412
 
413
+ # # Update the frame
414
+ # video_queried_preview[int(frame_num)] = current_frame_draw
415
 
416
+ # # Update the query count
417
+ # query_count += 1
418
+ # return (
419
+ # current_frame_draw, # Updated frame for preview
420
+ # video_queried_preview, # Updated preview video
421
+ # query_points, # Updated query points
422
+ # query_points_color, # Updated query points color
423
+ # query_count # Updated query count
424
+ # )
425
 
426
 
427
+ # def undo_point(frame_num, video_preview, video_queried_preview, query_points, query_points_color, query_count):
428
+ # if len(query_points[int(frame_num)]) == 0:
429
+ # return (
430
+ # video_queried_preview[int(frame_num)],
431
+ # video_queried_preview,
432
+ # query_points,
433
+ # query_points_color,
434
+ # query_count
435
+ # )
436
 
437
+ # # Get the last point
438
+ # query_points[int(frame_num)].pop(-1)
439
+ # query_points_color[int(frame_num)].pop(-1)
440
 
441
+ # # Redraw the frame
442
+ # current_frame_draw = video_preview[int(frame_num)].copy()
443
+ # for point, color in zip(query_points[int(frame_num)], query_points_color[int(frame_num)]):
444
+ # x, y, _ = point
445
+ # current_frame_draw = cv2.circle(current_frame_draw, (x, y), POINT_SIZE, color, -1)
446
 
447
+ # # Update the query count
448
+ # query_count -= 1
449
 
450
+ # # Update the frame
451
+ # video_queried_preview[int(frame_num)] = current_frame_draw
452
+ # return (
453
+ # current_frame_draw, # Updated frame for preview
454
+ # video_queried_preview, # Updated preview video
455
+ # query_points, # Updated query points
456
+ # query_points_color, # Updated query points color
457
+ # query_count # Updated query count
458
+ # )
459
 
460
 
461
+ # def clear_frame_fn(frame_num, video_preview, video_queried_preview, query_points, query_points_color, query_count):
462
+ # query_count -= len(query_points[int(frame_num)])
463
 
464
+ # query_points[int(frame_num)] = []
465
+ # query_points_color[int(frame_num)] = []
466
 
467
+ # video_queried_preview[int(frame_num)] = video_preview[int(frame_num)].copy()
468
 
469
+ # return (
470
+ # video_preview[int(frame_num)], # Set the preview frame to the original frame
471
+ # video_queried_preview,
472
+ # query_points, # Cleared query points
473
+ # query_points_color, # Cleared query points color
474
+ # query_count # New query count
475
+ # )
476
 
477
 
478
 
479
+ # def clear_all_fn(frame_num, video_preview):
480
+ # return (
481
+ # video_preview[int(frame_num)],
482
+ # video_preview.copy(),
483
+ # [[] for _ in range(len(video_preview))],
484
+ # [[] for _ in range(len(video_preview))],
485
+ # 0
486
+ # )
487
 
488
 
489
  def choose_frame(frame_num, video_preview_array):
 
505
  new_height, new_width = PREVIEW_HEIGHT, int(PREVIEW_WIDTH * width / height)
506
  else:
507
  new_height, new_width = int(PREVIEW_WIDTH * height / width), PREVIEW_WIDTH
508
+ if height*width > 768*768:
509
+ new_height = new_height*3//4
510
+ new_width = new_width*3//4
511
+
512
+
513
  preview_video = mediapy.resize_video(video_arr, (new_height, new_width))
514
  # input_video = mediapy.resize_video(video_arr, VIDEO_INPUT_RESO)
515
  # input_video = video_arr
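
To make the preview sizing above concrete, here is a small worked example using the landscape branch shown in the hunk, with an assumed 1920×1080 input (numbers are illustrative only):

```python
# Worked example of the preview sizing above for an assumed 1920x1080 landscape input.
PREVIEW_WIDTH = 1024
width, height = 1920, 1080

# Landscape branch shown above: cap the long side at PREVIEW_WIDTH.
new_height, new_width = int(PREVIEW_WIDTH * height / width), PREVIEW_WIDTH   # 576, 1024

# Large inputs (area above 768*768) are shrunk a further 3/4 to keep the preview light.
if height * width > 768 * 768:
    new_height = new_height * 3 // 4   # 432
    new_width = new_width * 3 // 4     # 768

print(new_height, new_width)           # 432 768
```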
 
527
  input_video, # Resized video input for model
528
  # None, # video_feature, # Extracted feature
529
  video_fps, # Set the video FPS
530
+ # gr.update(open=True), # open/close the video input drawer
531
  # tracking_mode, # Set the tracking mode
532
  preview_video[0], # Set the preview frame to the first frame
533
  gr.update(minimum=0, maximum=num_frames - 1, value=0, interactive=interactive), # Set slider interactive
 
632
  torch.cuda.empty_cache()
633
 
634
  with torch.no_grad():
635
+ utils.basic.print_stats('video_input', video_input)
636
+ if query_frame < T-1:
637
+ flows_e, visconf_maps_e, _, _ = \
638
+ model(video_input[:, query_frame:], iters=4, sw=None, is_training=False)
639
+ traj_maps_e = flows_e.cpu() + grid_xy # B,Tf,2,H,W
640
+ visconf_maps_e = visconf_maps_e.cpu()
641
+ else:
642
+ traj_maps_e = torch.zeros((1,0,2,H,W), dtype=torch.float32)
643
+ visconf_maps_e = torch.zeros((1,0,2,H,W), dtype=torch.float32)
644
  if query_frame > 0:
645
  backward_flows_e, backward_visconf_maps_e, _, _ = \
646
+ model(video_input[:, :query_frame+1].flip([1]), iters=4, sw=None, is_training=False)
647
+ backward_traj_maps_e = backward_flows_e.cpu() + grid_xy # B,Tb,2,H,W, reversed
648
+ backward_visconf_maps_e = backward_visconf_maps_e.cpu()
649
+ backward_traj_maps_e = backward_traj_maps_e.flip([1]) # flip time
650
+ backward_visconf_maps_e = backward_visconf_maps_e.flip([1]) # flip time
651
+ if query_frame < T-1:
652
+ backward_traj_maps_e = backward_traj_maps_e[:, :-1] # drop the overlapped frame
653
+ backward_visconf_maps_e = backward_visconf_maps_e[:, :-1] # drop the overlapped frame
654
  traj_maps_e = torch.cat([backward_traj_maps_e, traj_maps_e], dim=1) # B,T,2,H,W
655
  visconf_maps_e = torch.cat([backward_visconf_maps_e, visconf_maps_e], dim=1) # B,T,2,H,W
656
+ # if query_frame < T-1:
657
+ # flows_e, visconf_maps_e, _, _ = \
658
+ # model.forward_sliding(video_input[:, query_frame:], iters=4, sw=None, is_training=False)
659
+ # traj_maps_e = flows_e + grid_xy # B,Tf,2,H,W
660
+ # print("5 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
661
+ # else:
662
+ # traj_maps_e = torch.zeros((1,0,2,H,W), dtype=torch.float32)
663
+ # visconf_maps_e = torch.zeros((1,0,2,H,W), dtype=torch.float32)
664
+
665
+ # if query_frame > 0:
666
+ # backward_flows_e, backward_visconf_maps_e, _, _ = \
667
+ # model.forward_sliding(video_input[:, :query_frame+1].flip([1]), iters=4, sw=None, is_training=False)
668
+ # backward_traj_maps_e = backward_flows_e + grid_xy # B,Tb,2,H,W, reversed
669
+ # backward_traj_maps_e = backward_traj_maps_e.flip([1]) # flip time
670
+ # backward_visconf_maps_e = backward_visconf_maps_e.flip([1]) # flip time
671
+ # if query_frame < T-1:
672
+ # backward_traj_maps_e = backward_traj_maps_e[:, :-1] # drop the overlapped frame
673
+ # backward_visconf_maps_e = backward_visconf_maps_e[:, :-1] # drop the overlapped frame
674
+ # traj_maps_e = torch.cat([backward_traj_maps_e, traj_maps_e], dim=1) # B,T,2,H,W
675
+ # visconf_maps_e = torch.cat([backward_visconf_maps_e, visconf_maps_e], dim=1) # B,T,2,H,W
676
  print("6 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
677
 
678
  # for ind in range(0, video_input.shape[1] - model.step, model.step):
 
700
  visibs = visconf_maps_e.permute(0,3,4,1,2).reshape(-1,T,2)[:,:,0].numpy()
701
  confs = visconf_maps_e.permute(0,3,4,1,2).reshape(-1,T,2)[:,:,0].numpy()
702
 
703
+ visibs = (visibs * confs) > 0.2 # N,T
704
+ # visibs = (confs) > 0.1 # N,T
705
 
706
 
707
  # sc = (np.array([video_preview.shape[2], video_preview.shape[1]]) / np.array([VIDEO_INPUT_RESO[1], VIDEO_INPUT_RESO[0]])).reshape(1,1,2)
 
736
  video_file_name = uuid.uuid4().hex + ".mp4"
737
  video_path = os.path.join(os.path.dirname(__file__), "tmp")
738
  video_file_path = os.path.join(video_path, video_file_name)
 
739
 
740
+ os.makedirs(video_path, exist_ok=True)
741
+ if False:
742
+ mediapy.write_video(video_file_path, painted_video, fps=video_fps)
743
+ else:
744
+ for ti in range(T):
745
+ temp_out_f = '%s/%03d.jpg' % (video_path, ti)
746
+ # temp_out_f = '%s/%03d.png' % (video_path, ti)
747
+ im = PIL.Image.fromarray(painted_video[ti])
748
+ # im.save(temp_out_f, "PNG", subsampling=0, quality=80)
749
+ im.save(temp_out_f)
750
+ print('saved', temp_out_f)
751
+ # os.system('/usr/bin/ffmpeg -y -hide_banner -loglevel error -f image2 -framerate %d -pattern_type glob -i "%s/*.png" -c:v libx264 -crf 20 -pix_fmt yuv420p %s' % (video_fps, video_path, video_file_path))
752
+ os.system('/usr/bin/ffmpeg -y -hide_banner -loglevel error -f image2 -framerate %d -pattern_type glob -i "%s/*.jpg" -c:v libx264 -crf 20 -pix_fmt yuv420p %s' % (video_fps, video_path, video_file_path))
753
+ print('saved', video_file_path)
754
+ for ti in range(T):
755
+ # temp_out_f = '%s/%03d.png' % (video_path, ti)
756
+ temp_out_f = '%s/%03d.jpg' % (video_path, ti)
757
+ os.remove(temp_out_f)
758
+ print('deleted', temp_out_f)
759
+
760
+ # out_file = tempfile.NamedTemporaryFile(suffix="out.mp4", delete=False)
761
+ # subprocess.run(f"ffmpeg -y -loglevel quiet -stats -i {painted_video} -c:v libx264 {out_file.name}".split())
762
+
763
+
764
 
765
  return video_file_path
766
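
A note on the encoding step above: the frames are written as JPEGs and assembled with ffmpeg (libx264, CRF 20, yuv420p for broad player compatibility). Since the commit also imports `subprocess` and leaves a commented-out `subprocess.run` variant, a hypothetical equivalent of the same `os.system` call without shell string formatting would look roughly like this (paths and fps are illustrative):

```python
import subprocess

# Hypothetical subprocess.run equivalent of the os.system ffmpeg call above.
video_path = "tmp"               # directory holding the numbered .jpg frames
video_file_path = "tmp/out.mp4"  # output file
video_fps = 24                   # frame rate of the source video

cmd = [
    "ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
    "-f", "image2", "-framerate", str(video_fps),
    "-pattern_type", "glob", "-i", f"{video_path}/*.jpg",
    "-c:v", "libx264", "-crf", "20",      # H.264 at high quality
    "-pix_fmt", "yuv420p",                # broad player compatibility
    video_file_path,
]
subprocess.run(cmd, check=True)
```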
 
 
777
  is_tracked_query = gr.State([])
778
  query_count = gr.State(0)
779
 
780
+ gr.Markdown("# AllTracker: Efficient Dense Point Tracking at High Resolution")
781
  gr.Markdown("<div style='text-align: left;'> \
782
+ <p>Welcome to <a href='https://alltracker.github.io/' target='_blank'>AllTracker</a>! This space demonstrates point (pixel) tracking in videos. \
783
+ The model tracks all pixels in a frame that you select. </p> \
784
+ <p>To get started, simply upload your <b>.mp4</b> video, or click on one of the example videos. The shorter the video, the faster the processing. We recommend submitting videos under 20 seconds long.</p> \
785
+ <p>After picking a video, click \"Submit\" to load the frames into the app, and optionally choose a frame (using the slider), and then click \"Track\".</p> \
786
+ <p>For full info on how this works, check out our <a href='https://github.com/aharley/alltracker/' target='_blank'>GitHub Repo</a>!</p> \
787
+ <p>Initial code for this Gradio app came from LocoTrack and CoTracker.</p> \
788
  </div>"
789
  )
790
 
791
 
792
+ gr.Markdown("## Step 1: Select a video, and click \"Submit\".")
793
  with gr.Row():
794
+ with gr.Column():
795
+ with gr.Row():
796
+ video_in = gr.Video(label="Video Input", format="mp4")
797
+ with gr.Row():
798
+ submit = gr.Button("Submit")
799
+ with gr.Column():
800
+ # with gr.Accordion("Sample videos", open=True) as video_in_drawer:
801
+ with gr.Row():
802
+ dog = os.path.join(os.path.dirname(__file__), "videos", "dog.mp4")
803
+ monkey = os.path.join(os.path.dirname(__file__), "videos", "monkey_28.mp4")
804
+ apple = os.path.join(os.path.dirname(__file__), "videos", "apple.mp4")
805
+ bear = os.path.join(os.path.dirname(__file__), "videos", "bear.mp4")
806
+ paragliding_launch = os.path.join(
807
+ os.path.dirname(__file__), "videos", "paragliding-launch.mp4"
808
+ )
809
+ paragliding = os.path.join(os.path.dirname(__file__), "videos", "paragliding.mp4")
810
+ cat = os.path.join(os.path.dirname(__file__), "videos", "cat.mp4")
811
+ pillow = os.path.join(os.path.dirname(__file__), "videos", "pillow.mp4")
812
+ teddy = os.path.join(os.path.dirname(__file__), "videos", "teddy.mp4")
813
+ backpack = os.path.join(os.path.dirname(__file__), "videos", "backpack.mp4")
814
+ gr.Examples(examples=[dog, monkey, bear, apple, paragliding, paragliding_launch, cat, pillow, teddy, backpack],
815
+ inputs = [
816
+ video_in
817
+ ],
818
+ )
819
+ # with gr.Column():
820
+ # gr.Markdown("Choose a video or upload one of your own.")
821
 
822
+ gr.Markdown("## Step 2: Select a frame, and click \"Track\"")
823
  with gr.Row():
824
  with gr.Column():
825
  with gr.Row():
826
  query_frames = gr.Slider(
827
  minimum=0, maximum=100, value=0, step=1, label="Choose Frame", interactive=False)
828
+ # with gr.Row():
829
+ # undo = gr.Button("Undo", interactive=False)
830
+ # clear_frame = gr.Button("Clear Frame", interactive=False)
831
+ # clear_all = gr.Button("Clear All", interactive=False)
832
 
833
  with gr.Row():
834
  current_frame = gr.Image(
 
860
  video_queried_preview,
861
  video_input,
862
  video_fps,
863
+ # video_in_drawer,
864
  current_frame,
865
  query_frames,
866
  query_points,
867
  query_points_color,
868
  is_tracked_query,
869
  query_count,
870
+ # undo,
871
+ # clear_frame,
872
+ # clear_all,
873
  track_button,
874
  ],
875
  queue = False
 
884
  queue = False
885
  )
886
 
887
+ # current_frame.select(
888
+ # fn = get_point,
889
+ # inputs = [
890
+ # query_frames,
891
+ # video_queried_preview,
892
+ # query_points,
893
+ # query_points_color,
894
+ # query_count,
895
+ # ],
896
+ # outputs = [
897
+ # current_frame,
898
+ # video_queried_preview,
899
+ # query_points,
900
+ # query_points_color,
901
+ # query_count
902
+ # ],
903
+ # queue = False
904
+ # )
905
 
906
+ # undo.click(
907
+ # fn = undo_point,
908
+ # inputs = [
909
+ # query_frames,
910
+ # video_preview,
911
+ # video_queried_preview,
912
+ # query_points,
913
+ # query_points_color,
914
+ # query_count
915
+ # ],
916
+ # outputs = [
917
+ # current_frame,
918
+ # video_queried_preview,
919
+ # query_points,
920
+ # query_points_color,
921
+ # query_count
922
+ # ],
923
+ # queue = False
924
+ # )
925
+
926
+ # clear_frame.click(
927
+ # fn = clear_frame_fn,
928
+ # inputs = [
929
+ # query_frames,
930
+ # video_preview,
931
+ # video_queried_preview,
932
+ # query_points,
933
+ # query_points_color,
934
+ # query_count
935
+ # ],
936
+ # outputs = [
937
+ # current_frame,
938
+ # video_queried_preview,
939
+ # query_points,
940
+ # query_points_color,
941
+ # query_count
942
+ # ],
943
+ # queue = False
944
+ # )
945
+
946
+ # clear_all.click(
947
+ # fn = clear_all_fn,
948
+ # inputs = [
949
+ # query_frames,
950
+ # video_preview,
951
+ # ],
952
+ # outputs = [
953
+ # current_frame,
954
+ # video_queried_preview,
955
+ # query_points,
956
+ # query_points_color,
957
+ # query_count
958
+ # ],
959
+ # queue = False
960
+ # )
961
 
962
 
963
  track_button.click(
nets/alltracker.py CHANGED
@@ -236,7 +236,7 @@ class Net(nn.Module):
236
  std = torch.as_tensor([0.229, 0.224, 0.225], device=device).reshape(1,1,3,1,1).to(images.dtype)
237
  images = images / 255.0
238
  images = (images - mean)/std
239
- print("a0 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
240
 
241
  T_bak = T
242
  if stride is not None:
@@ -250,7 +250,7 @@ class Net(nn.Module):
250
  padder = InputPadder(images_.shape)
251
  images_ = padder.pad(images_)[0]
252
 
253
- print("a1 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
254
 
255
  _, _, H_pad, W_pad = images_.shape # revised HW
256
  C, H8, W8 = self.dim*2, H_pad//8, W_pad//8
@@ -261,7 +261,7 @@ class Net(nn.Module):
261
 
262
  fmaps = self.get_fmaps(images_, B, T, sw, is_training).reshape(B,T,C,H8,W8)
263
  device = fmaps.device
264
- print("a2 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
265
 
266
  fmap_anchor = fmaps[:,0]
267
 
@@ -285,11 +285,11 @@ class Net(nn.Module):
285
  if self.use_feats8:
286
  full_feats8 = torch.zeros((B,T,C2,H_pad//8,W_pad//8), dtype=dtype, device=device)
287
  visits = np.zeros((T))
288
- print("a3 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
289
 
290
  for ii, ind in enumerate(indices):
291
  ara = np.arange(ind,ind+S)
292
- print('ara', ara)
293
  if ii < len(indices)-1:
294
  next_ind = indices[ii+1]
295
  next_ara = np.arange(next_ind,next_ind+S)
@@ -306,12 +306,12 @@ class Net(nn.Module):
306
  feats8 = full_feats8[:,ara].reshape(B*(S),C2,H_pad//8,W_pad//8).detach()
307
  else:
308
  feats8 = None
309
- print("a4 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
310
 
311
  flow_predictions, visconf_predictions, flows8, visconfs8, feats8 = self.forward_window(
312
  fmap_anchor, fmaps2, visconfs8, iters=iters, flowfeat=feats8, flows8=flows8,
313
  is_training=is_training)
314
- print("a5 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
315
 
316
  unpad_flow_predictions = []
317
  unpad_visconf_predictions = []
@@ -320,7 +320,7 @@ class Net(nn.Module):
320
  unpad_flow_predictions.append(flow_predictions[i].reshape(B,S,2,H,W))
321
  visconf_predictions[i] = padder.unpad(torch.sigmoid(visconf_predictions[i]))
322
  unpad_visconf_predictions.append(visconf_predictions[i].reshape(B,S,2,H,W))
323
- print("a6 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
324
 
325
  full_flows[:,ara] = unpad_flow_predictions[-1].reshape(B,S,2,H,W)
326
  full_flows8[:,ara] = flows8.reshape(B,S,2,H_pad//8,W_pad//8)
@@ -329,7 +329,7 @@ class Net(nn.Module):
329
  if self.use_feats8:
330
  full_feats8[:,ara] = feats8.reshape(B,S,C2,H_pad//8,W_pad//8)
331
  visits[ara] += 1
332
- print("a7 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
333
 
334
  if is_training:
335
  all_flow_preds.append(unpad_flow_predictions)
@@ -348,7 +348,7 @@ class Net(nn.Module):
348
  full_visconfs8[:,idx] = full_visconfs8[:,nearest]
349
  if self.use_feats8:
350
  full_feats8[:,idx] = full_feats8[:,nearest]
351
- print("a8 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
352
  else: # flow
353
 
354
  flows8 = torch.zeros((B,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
@@ -370,7 +370,7 @@ class Net(nn.Module):
370
  if (not is_training) and (T > 2):
371
  full_flows = full_flows[:,:T_bak]
372
  full_visconfs = full_visconfs[:,:T_bak]
373
- print("a9 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
374
 
375
  return full_flows, full_visconfs, all_flow_preds, all_visconf_preds
376
 
 
236
  std = torch.as_tensor([0.229, 0.224, 0.225], device=device).reshape(1,1,3,1,1).to(images.dtype)
237
  images = images / 255.0
238
  images = (images - mean)/std
239
+ # print("a0 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
240
 
241
  T_bak = T
242
  if stride is not None:
 
250
  padder = InputPadder(images_.shape)
251
  images_ = padder.pad(images_)[0]
252
 
253
+ # print("a1 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
254
 
255
  _, _, H_pad, W_pad = images_.shape # revised HW
256
  C, H8, W8 = self.dim*2, H_pad//8, W_pad//8
 
261
 
262
  fmaps = self.get_fmaps(images_, B, T, sw, is_training).reshape(B,T,C,H8,W8)
263
  device = fmaps.device
264
+ # print("a2 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
265
 
266
  fmap_anchor = fmaps[:,0]
267
 
 
285
  if self.use_feats8:
286
  full_feats8 = torch.zeros((B,T,C2,H_pad//8,W_pad//8), dtype=dtype, device=device)
287
  visits = np.zeros((T))
288
+ # print("a3 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
289
 
290
  for ii, ind in enumerate(indices):
291
  ara = np.arange(ind,ind+S)
292
+ # print('ara', ara)
293
  if ii < len(indices)-1:
294
  next_ind = indices[ii+1]
295
  next_ara = np.arange(next_ind,next_ind+S)
 
306
  feats8 = full_feats8[:,ara].reshape(B*(S),C2,H_pad//8,W_pad//8).detach()
307
  else:
308
  feats8 = None
309
+ # print("a4 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
310
 
311
  flow_predictions, visconf_predictions, flows8, visconfs8, feats8 = self.forward_window(
312
  fmap_anchor, fmaps2, visconfs8, iters=iters, flowfeat=feats8, flows8=flows8,
313
  is_training=is_training)
314
+ # print("a5 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
315
 
316
  unpad_flow_predictions = []
317
  unpad_visconf_predictions = []
 
320
  unpad_flow_predictions.append(flow_predictions[i].reshape(B,S,2,H,W))
321
  visconf_predictions[i] = padder.unpad(torch.sigmoid(visconf_predictions[i]))
322
  unpad_visconf_predictions.append(visconf_predictions[i].reshape(B,S,2,H,W))
323
+ # print("a6 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
324
 
325
  full_flows[:,ara] = unpad_flow_predictions[-1].reshape(B,S,2,H,W)
326
  full_flows8[:,ara] = flows8.reshape(B,S,2,H_pad//8,W_pad//8)
 
329
  if self.use_feats8:
330
  full_feats8[:,ara] = feats8.reshape(B,S,C2,H_pad//8,W_pad//8)
331
  visits[ara] += 1
332
+ # print("a7 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
333
 
334
  if is_training:
335
  all_flow_preds.append(unpad_flow_predictions)
 
348
  full_visconfs8[:,idx] = full_visconfs8[:,nearest]
349
  if self.use_feats8:
350
  full_feats8[:,idx] = full_feats8[:,nearest]
351
+ # print("a8 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
352
  else: # flow
353
 
354
  flows8 = torch.zeros((B,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
 
370
  if (not is_training) and (T > 2):
371
  full_flows = full_flows[:,:T_bak]
372
  full_visconfs = full_visconfs[:,:T_bak]
373
+ # print("a9 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
374
 
375
  return full_flows, full_visconfs, all_flow_preds, all_visconf_preds
376
 
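For context on the windowed loop visible in this file (the commit itself only silences the debug prints): a minimal, hypothetical sketch of how overlapping windows of length S can tile T frames, with a `visits` counter confirming coverage. The exact construction of `indices` and the step size in the real model are assumptions here, loosely following the commented-out loop in app.py.

```python
import numpy as np

# Hypothetical tiling of T frames by overlapping windows of length S advancing by `step`
# (the real construction of `indices` is not shown in this diff).
T, S, step = 20, 8, 6

indices = list(range(0, T - S + 1, step))
if indices[-1] != T - S:
    indices.append(T - S)           # make sure the last window reaches the final frame

visits = np.zeros(T, dtype=int)
for ind in indices:
    ara = np.arange(ind, ind + S)   # frame indices covered by this window
    visits[ara] += 1                # overlapped frames are visited more than once

assert (visits > 0).all()           # every frame falls inside at least one window
print(visits)                       # [1 1 1 1 1 1 2 2 1 1 1 1 2 2 1 1 1 1 1 1]
```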
requirements.txt ADDED
@@ -0,0 +1,17 @@
1
+ numpy==1.26.4
2
+ imageio==2.19.3
3
+ imageio-ffmpeg==0.4.7
4
+ tqdm
5
+ gradio
6
+ spaces
7
+ matplotlib
8
+ pillow
9
+ torch==2.2.0
10
+ torchvision==0.17.0
11
+ albumentations
12
+ pytorch-lightning==2.2.5
13
+ opencv-python
14
+ scikit-learn
15
+ scikit-image
16
+ einops
17
+ transformers