aharley committed
Commit 5426cac · Parent(s): 574fdd2

made fancy

Files changed (1):
  1. app.py (+25 −14)
app.py CHANGED
@@ -103,12 +103,13 @@ def paint_point_track_gpu_scatter(
     point_tracks: np.ndarray,
     visibles: np.ndarray,
     colormap: Optional[List[Tuple[int, int, int]]] = None,
-    radius: int = 2,
+    radius: int = 1,
     sharpness: float = 0.15,
 ) -> np.ndarray:
     print('starting vis')
     device = "cuda" if torch.cuda.is_available() else "cpu"
     frames_t = torch.from_numpy(frames).float().permute(0, 3, 1, 2).to(device) # [T,C,H,W]
+    frames_t = frames_t * 0.5 # darken, to see the point tracks better
     point_tracks_t = torch.from_numpy(point_tracks).to(device) # [P,T,2]
     visibles_t = torch.from_numpy(visibles).to(device) # [P,T]
     T, C, H, W = frames_t.shape
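
A quick self-contained sketch of the setup this hunk touches (toy data, not the real video): frames arrive as [T,H,W,C] uint8 numpy, become a [T,C,H,W] float tensor, and are now halved in brightness so the colored tracks stand out.

    import numpy as np
    import torch

    frames = (np.random.rand(2, 8, 8, 3) * 255).astype(np.uint8)  # [T,H,W,C]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    frames_t = torch.from_numpy(frames).float().permute(0, 3, 1, 2).to(device)  # [T,C,H,W]
    frames_t = frames_t * 0.5  # darken, so overlaid tracks read clearly
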
@@ -159,7 +160,8 @@ def paint_point_track_gpu_scatter(
         # frames_t[t] = frames_t[t] * (1 - weight) + accum
 
         # alpha = weight.clamp(0, 1)
-        alpha = weight.clamp(0, 1) * 0.75 # transparency
+        # alpha = weight.clamp(0, 1) * 0.75 # transparency
+        alpha = weight.clamp(0, 1) # transparency
         accum = accum / (weight + 1e-6) # [3, H, W]
         frames_t[t] = frames_t[t] * (1 - alpha) + accum * alpha
 
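
This change replaces the fixed 0.75 transparency with full opacity wherever the splat weight saturates. A minimal sketch, assuming `accum` is the per-pixel sum of color times Gaussian weight and `weight` the per-pixel sum of weights (as built earlier in this function); shapes are toy values:

    import torch

    H, W = 8, 8
    accum = torch.rand(3, H, W)         # sum over points of color * weight
    weight = torch.rand(1, H, W) * 2.0  # sum over points of weight
    frame = torch.rand(3, H, W)         # darkened background frame

    alpha = weight.clamp(0, 1)                   # opacity 1 where weight >= 1
    color = accum / (weight + 1e-6)              # weighted-average track color
    frame = frame * (1 - alpha) + color * alpha  # alpha-composite onto the frame
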
@@ -256,7 +258,7 @@ def paint_point_track_parallel(
     if colormap is None:
         colormap = get_colors(num_colors=num_points)
     height, width = frames.shape[1:3]
-    radius = 2
+    radius = 1
     print('radius', radius)
     diam = radius * 2 + 1
     # Precompute the icon and its bilinear components
@@ -499,15 +501,15 @@ def preprocess_video_input(video_path):
         video_arr = video_arr[:FRAME_LIMIT]
         num_frames = FRAME_LIMIT
 
-    # Resize to preview size for faster processing, width = PREVIEW_WIDTH
     height, width = video_arr.shape[1:3]
     if height > width:
         new_height, new_width = PREVIEW_HEIGHT, int(PREVIEW_WIDTH * width / height)
     else:
         new_height, new_width = int(PREVIEW_WIDTH * height / width), PREVIEW_WIDTH
-    if height*width > 768*768:
+    if height*width > 768*1024:
         new_height = new_height*3//4
         new_width = new_width*3//4
+    new_height, new_width = new_height//8 * 8, new_width//8 * 8 # make it divisible by 8, partly to satisfy ffmpeg
 
 
     preview_video = mediapy.resize_video(video_arr, (new_height, new_width))
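
To see the effect of the raised threshold and the new divisible-by-8 rounding, here is the sizing logic as a standalone function. PREVIEW_HEIGHT and PREVIEW_WIDTH are assumed to be 768 for illustration only; the real constants are defined elsewhere in app.py.

    PREVIEW_HEIGHT = PREVIEW_WIDTH = 768  # assumed values, for illustration

    def preview_size(height, width):
        if height > width:
            new_h, new_w = PREVIEW_HEIGHT, int(PREVIEW_WIDTH * width / height)
        else:
            new_h, new_w = int(PREVIEW_WIDTH * height / width), PREVIEW_WIDTH
        if height * width > 768 * 1024:        # shrink very large inputs
            new_h, new_w = new_h * 3 // 4, new_w * 3 // 4
        return new_h // 8 * 8, new_w // 8 * 8  # multiples of 8, for ffmpeg

    print(preview_size(1080, 1920))  # -> (320, 576)
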
@@ -693,6 +695,7 @@ def track(
 
     # traj_maps_e = traj_maps_e[:,:,:,::4,::4] # subsample
     # visconf_maps_e = visconf_maps_e[:,:,:,::4,::4] # subsample
+
     traj_maps_e = traj_maps_e[:,:,:,::2,::2] # subsample
     visconf_maps_e = visconf_maps_e[:,:,:,::2,::2] # subsample
 
@@ -722,7 +725,9 @@
         colors.extend(frame_colors)
     colors = np.array(colors)
 
-    inds = np.sum(visibs * 1.0, axis=1) >= min(T//4,3)
+    visibs_ = visibs * 1.0
+    visibs_ = visibs_[:,1:] * visibs_[:,:-1]
+    inds = np.sum(visibs_, axis=1) >= min(T//4,8)
     tracks = tracks[inds]
     visibs = visibs[inds]
     colors = colors[inds]
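
The filter now counts consecutive-frame visibility rather than total visible frames: a track must be visible at both t and t+1 in at least min(T//4, 8) adjacent frame pairs to survive. A toy-shaped sketch:

    import numpy as np

    P, T = 5, 16
    visibs = np.random.rand(P, T) > 0.5         # [P,T] per-frame visibility

    visibs_ = visibs * 1.0
    visibs_ = visibs_[:, 1:] * visibs_[:, :-1]  # 1 where visible at t and t+1
    inds = np.sum(visibs_, axis=1) >= min(T // 4, 8)

    tracks_kept = inds.sum()                    # number of surviving tracks
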
@@ -779,8 +784,7 @@ with gr.Blocks() as demo:
 
     gr.Markdown("# ⚡ AllTracker: Efficient Dense Point Tracking at High Resolution")
     gr.Markdown("<div style='text-align: left;'> \
-        <p>Welcome to <a href='https://alltracker.github.io/' target='_blank'>AllTracker</a>! This space demonstrates point (pixel) tracking in videos. \
-        The model tracks all pixels in a frame that you select. </p> \
+        <p>Welcome to <a href='https://alltracker.github.io/' target='_blank'>AllTracker</a>! This space demonstrates all-pixel tracking in videos.</p> \
         <p>To get started, simply upload your <b>.mp4</b> video, or click on one of the example videos. The shorter the video, the faster the processing. We recommend submitting videos under 20 seconds long.</p> \
         <p>After picking a video, click \"Submit\" to load the frames into the app, and optionally choose a frame (using the slider), and then click \"Track\".</p> \
         <p>For full info on how this works, check out our <a href='https://github.com/aharley/alltracker/' target='_blank'>GitHub Repo</a>!</p> \
@@ -819,11 +823,11 @@
     # with gr.Column():
     #     gr.Markdown("Choose a video or upload one of your own.")
 
-    gr.Markdown("## Step 2: Select a frame, and click \"Track\"")
+    gr.Markdown("## Step 2: Select a frame, and click \"Track\".")
     with gr.Row():
         with gr.Column():
             with gr.Row():
-                query_frames = gr.Slider(
+                query_frame_slider = gr.Slider(
                     minimum=0, maximum=100, value=0, step=1, label="Choose Frame", interactive=False)
     # with gr.Row():
     #     undo = gr.Button("Undo", interactive=False)
@@ -842,6 +846,10 @@
             track_button = gr.Button("Track", interactive=False)
 
         with gr.Column():
+            # with gr.Row():
+            #     rate_slider = gr.Slider(
+            #         minimum=1, maximum=16, value=1, step=1, label="Choose subsampling rate", interactive=False)
+            # with gr.Row():
             output_video = gr.Video(
                 label="Output Video",
                 interactive=False,
@@ -862,7 +870,7 @@
             video_fps,
             # video_in_drawer,
             current_frame,
-            query_frames,
+            query_frame_slider,
             query_points,
             query_points_color,
             is_tracked_query,
@@ -875,9 +883,9 @@
         queue = False
     )
 
-    query_frames.change(
+    query_frame_slider.change(
         fn = choose_frame,
-        inputs = [query_frames, video_queried_preview],
+        inputs = [query_frame_slider, video_queried_preview],
         outputs = [
             current_frame,
         ],
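
The rename from query_frames to query_frame_slider runs through this and the following hunks. For context, a minimal Gradio pattern of the slider-to-frame wiring used here (toy callback and state; the real choose_frame lives elsewhere in app.py):

    import gradio as gr

    def choose_frame(frame_idx, frames):
        # return the preview frame at the slider index
        return frames[int(frame_idx)]

    with gr.Blocks() as demo:
        query_frame_slider = gr.Slider(minimum=0, maximum=100, value=0, step=1,
                                       label="Choose Frame")
        current_frame = gr.Image(label="Current Frame")
        video_queried_preview = gr.State([])  # per-frame previews, filled on Submit
        query_frame_slider.change(fn=choose_frame,
                                  inputs=[query_frame_slider, video_queried_preview],
                                  outputs=[current_frame])
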
@@ -959,6 +967,7 @@
     #     queue = False
     # )
 
+    # output_video = None
 
     track_button.click(
         fn = track,
@@ -966,7 +975,7 @@
             video_preview,
             video_input,
             video_fps,
-            query_frames,
+            query_frame_slider,
             query_points,
             query_points_color,
             query_count,
@@ -978,5 +987,7 @@
     )
 
 
+
+
 # demo.launch(show_api=False, show_error=True, debug=False, share=False)
 demo.launch(show_api=False, show_error=True, debug=False, share=True)
 