aharley committed on
Commit
c7646a0
·
1 Parent(s): 6cf1a23

added overlay option; fixed bugs

Browse files
Files changed (2) hide show
  1. app.py +48 -22
  2. utils/improc.py +1 -0
app.py CHANGED
@@ -20,6 +20,7 @@ import random
20
  from typing import List, Optional, Sequence, Tuple
21
  import spaces
22
  import numpy as np
 
23
  import utils.basic
24
  import utils.improc
25
 
@@ -105,12 +106,16 @@ def paint_point_track_gpu_scatter(
105
  visibles: np.ndarray,
106
  colormap: Optional[List[Tuple[int, int, int]]] = None,
107
  rate: int = 1,
 
108
  # sharpness: float = 0.1,
109
  ) -> np.ndarray:
110
  print('starting vis')
111
  device = "cuda" if torch.cuda.is_available() else "cpu"
112
  frames_t = torch.from_numpy(frames).float().permute(0, 3, 1, 2).to(device) # [T,C,H,W]
113
- frames_t = frames_t * 0.5 # darken, to see the point tracks better
 
 
 
114
  point_tracks_t = torch.from_numpy(point_tracks).to(device) # [P,T,2]
115
  visibles_t = torch.from_numpy(visibles).to(device) # [P,T]
116
  T, C, H, W = frames_t.shape
@@ -517,14 +522,14 @@ def choose_rate8(video_preview, video_fps, tracks, visibs):
517
  # def choose_rate16(video_preview, video_fps, tracks, visibs):
518
  # return choose_rate(16, video_preview, video_fps, tracks, visibs)
519
 
520
- def update_vis(rate, cmap, video_preview, query_frame, video_fps, tracks, visibs):
521
  print('rate', rate)
522
  print('cmap', cmap)
523
  print('video_preview', video_preview.shape)
524
  T, H, W,_ = video_preview.shape
525
  tracks_ = tracks.reshape(H,W,T,2)[::rate,::rate].reshape(-1,T,2)
526
  visibs_ = visibs.reshape(H,W,T)[::rate,::rate].reshape(-1,T)
527
- return paint_video(video_preview, query_frame, video_fps, tracks_, visibs_, rate=rate, cmap=cmap)
528
  # return video_preview_array[int(frame_num)]
529
 
530
  def preprocess_video_input(video_path):
@@ -570,15 +575,19 @@ def preprocess_video_input(video_path):
570
  )
571
 
572
 
573
- def paint_video(video_preview, query_frame, video_fps, tracks, visibs, rate=1, cmap="gist_rainbow"):
574
  print('video_preview', video_preview.shape)
575
  print('tracks', tracks.shape)
576
  T, H, W, _ = video_preview.shape
577
  query_count = tracks.shape[0]
578
  print('cmap', cmap)
579
-
580
  if cmap=="bremm":
 
581
  xy0 = tracks[:,query_frame] # N,2
 
 
 
582
  colors = utils.improc.get_2d_colors(xy0, H, W)
583
  else:
584
  cmap_ = matplotlib.colormaps.get_cmap(cmap)
@@ -594,7 +603,7 @@ def paint_video(video_preview, query_frame, video_fps, tracks, visibs, rate=1, c
594
  colors.extend(frame_colors)
595
  colors = np.array(colors)
596
 
597
- painted_video = paint_point_track_gpu_scatter(video_preview,tracks,visibs,colors,rate=rate)#=max(rate//2,1))
598
  # save video
599
  video_file_name = uuid.uuid4().hex + ".mp4"
600
  video_path = os.path.join(os.path.dirname(__file__), "tmp")
@@ -609,7 +618,7 @@ def paint_video(video_preview, query_frame, video_fps, tracks, visibs, rate=1, c
609
  im = PIL.Image.fromarray(painted_video[ti])
610
  # im.save(temp_out_f, "PNG", subsampling=0, quality=80)
611
  im.save(temp_out_f)
612
- print('saved', temp_out_f)
613
  # os.system('/usr/bin/ffmpeg -y -hide_banner -loglevel error -f image2 -framerate %d -pattern_type glob -i "%s/*.png" -c:v libx264 -crf 20 -pix_fmt yuv420p %s' % (video_fps, video_path, video_file_path))
614
  os.system('/usr/bin/ffmpeg -y -hide_banner -loglevel error -f image2 -framerate %d -pattern_type glob -i "%s/*.jpg" -c:v libx264 -crf 20 -pix_fmt yuv420p %s' % (video_fps, video_path, video_file_path))
615
  print('saved', video_file_path)
@@ -617,16 +626,19 @@ def paint_video(video_preview, query_frame, video_fps, tracks, visibs, rate=1, c
617
  # temp_out_f = '%s/%03d.png' % (video_path, ti)
618
  temp_out_f = '%s/%03d.jpg' % (video_path, ti)
619
  os.remove(temp_out_f)
620
- print('deleted', temp_out_f)
621
  return video_file_path
622
 
623
 
624
  @spaces.GPU
625
  def track(
626
- video_preview,
627
- video_input,
628
- video_fps,
629
- query_frame,
 
 
 
630
  ):
631
  # tracking_mode = 'selected'
632
  # if query_count == 0:
@@ -774,7 +786,8 @@ def track(
774
  tracks = traj_maps_e.permute(0,3,4,1,2).reshape(-1,T,2).numpy()
775
  visibs = visconf_maps_e.permute(0,3,4,1,2).reshape(-1,T,2)[:,:,0].numpy()
776
  confs = visconf_maps_e.permute(0,3,4,1,2).reshape(-1,T,2)[:,:,0].numpy()
777
- visibs = (visibs * confs) > 0.3 # N,T
 
778
  # visibs = (confs) > 0.1 # N,T
779
 
780
 
@@ -782,7 +795,7 @@ def track(
782
  # print('sc', sc)
783
  # tracks = tracks * sc
784
 
785
- return paint_video(video_preview, query_frame, video_fps, tracks, visibs), tracks, visibs, gr.update(interactive=True), gr.update(interactive=True)
786
  # gr.update(interactive=True),
787
  # gr.update(interactive=True),
788
  # gr.update(interactive=True),
@@ -863,7 +876,7 @@ with gr.Blocks() as demo:
863
 
864
  gr.Markdown("# ⚡ AllTracker: Efficient Dense Point Tracking at High Resolution")
865
  gr.Markdown("<div style='text-align: left;'> \
866
- <p>Welcome to <a href='https://alltracker.github.io/' target='_blank'>AllTracker</a>! This demo runs our model to perform all-pixel tracking in a video of your choice.</p> \
867
  <p>To get started, simply upload an mp4, or select one of the example videos. The shorter the video, the faster the processing. We recommend submitting videos under 20 seconds long.</p> \
868
  <p>After picking a video, click \"Submit\" to load the frames into the app, and optionally choose a query frame (using the slider), and then click \"Track\".</p> \
869
  <p>For full info on how this works, check out our <a href='https://github.com/aharley/alltracker/' target='_blank'>GitHub repo</a>, or <a href='https://arxiv.org/abs/2506.07310' target='_blank'>paper</a>.</p> \
@@ -909,7 +922,7 @@ with gr.Blocks() as demo:
909
  with gr.Column():
910
  with gr.Row():
911
  query_frame_slider = gr.Slider(
912
- minimum=0, maximum=100, value=0, step=1, label="Choose frame", interactive=False)
913
  # with gr.Row():
914
  # undo = gr.Button("Undo", interactive=False)
915
  # clear_frame = gr.Button("Clear Frame", interactive=False)
@@ -937,11 +950,12 @@ with gr.Blocks() as demo:
937
  with gr.Row():
938
  # rate_slider = gr.Slider(
939
  # minimum=1, maximum=16, value=1, step=1, label="Choose subsampling rate", interactive=False)
940
- rate_radio = gr.Radio([1, 2, 4, 8, 16], value=1, label="Choose visualization subsampling", interactive=False)
941
-
942
  with gr.Row():
943
- cmap_radio = gr.Radio(["gist_rainbow", "rainbow", "jet", "turbo", "bremm"], value="gist_rainbow", label="Choose colormap", interactive=False)
944
-
 
 
945
  with gr.Row():
946
  output_video = gr.Video(
947
  label="Output video",
@@ -1066,12 +1080,16 @@ with gr.Blocks() as demo:
1066
  video_input,
1067
  video_fps,
1068
  query_frame_slider,
 
 
 
1069
  ],
1070
  outputs = [
1071
  output_video,
1072
  tracks,
1073
  visibs,
1074
  rate_radio,
 
1075
  cmap_radio,
1076
  # rate1_button,
1077
  # rate2_button,
@@ -1092,7 +1110,7 @@ with gr.Blocks() as demo:
1092
  # )
1093
  rate_radio.change(
1094
  fn = update_vis,
1095
- inputs = [rate_radio, cmap_radio, video_preview, query_frame_slider, video_fps, tracks, visibs],
1096
  outputs = [
1097
  output_video,
1098
  ],
@@ -1100,7 +1118,15 @@ with gr.Blocks() as demo:
1100
  )
1101
  cmap_radio.change(
1102
  fn = update_vis,
1103
- inputs = [rate_radio, cmap_radio, video_preview, query_frame_slider, video_fps, tracks, visibs],
 
 
 
 
 
 
 
 
1104
  outputs = [
1105
  output_video,
1106
  ],
 
20
  from typing import List, Optional, Sequence, Tuple
21
  import spaces
22
  import numpy as np
23
+ import utils.py
24
  import utils.basic
25
  import utils.improc
26
 
 
106
  visibles: np.ndarray,
107
  colormap: Optional[List[Tuple[int, int, int]]] = None,
108
  rate: int = 1,
109
+ show_bkg=True,
110
  # sharpness: float = 0.1,
111
  ) -> np.ndarray:
112
  print('starting vis')
113
  device = "cuda" if torch.cuda.is_available() else "cpu"
114
  frames_t = torch.from_numpy(frames).float().permute(0, 3, 1, 2).to(device) # [T,C,H,W]
115
+ if show_bkg:
116
+ frames_t = frames_t * 0.5 # darken, to see the point tracks better
117
+ else:
118
+ frames_t = frames_t * 0.0 # black out
119
  point_tracks_t = torch.from_numpy(point_tracks).to(device) # [P,T,2]
120
  visibles_t = torch.from_numpy(visibles).to(device) # [P,T]
121
  T, C, H, W = frames_t.shape
 
522
  # def choose_rate16(video_preview, video_fps, tracks, visibs):
523
  # return choose_rate(16, video_preview, video_fps, tracks, visibs)
524
 
525
+ def update_vis(rate, show_bkg, cmap, video_preview, query_frame, video_fps, tracks, visibs):
526
  print('rate', rate)
527
  print('cmap', cmap)
528
  print('video_preview', video_preview.shape)
529
  T, H, W,_ = video_preview.shape
530
  tracks_ = tracks.reshape(H,W,T,2)[::rate,::rate].reshape(-1,T,2)
531
  visibs_ = visibs.reshape(H,W,T)[::rate,::rate].reshape(-1,T)
532
+ return paint_video(video_preview, query_frame, video_fps, tracks_, visibs_, rate=rate, show_bkg=show_bkg, cmap=cmap)
533
  # return video_preview_array[int(frame_num)]
534
 
535
  def preprocess_video_input(video_path):
 
575
  )
576
 
577
 
578
+ def paint_video(video_preview, query_frame, video_fps, tracks, visibs, rate=1, show_bkg=True, cmap="gist_rainbow"):
579
  print('video_preview', video_preview.shape)
580
  print('tracks', tracks.shape)
581
  T, H, W, _ = video_preview.shape
582
  query_count = tracks.shape[0]
583
  print('cmap', cmap)
584
+ print('query_frame', query_frame)
585
  if cmap=="bremm":
586
+ # xy0 = tracks
587
  xy0 = tracks[:,query_frame] # N,2
588
+ # print('xyQ', xy0[:10])
589
+ # print('xy0', tracks[:10,0])
590
+ # print('xy1', tracks[:10,1])
591
  colors = utils.improc.get_2d_colors(xy0, H, W)
592
  else:
593
  cmap_ = matplotlib.colormaps.get_cmap(cmap)
 
603
  colors.extend(frame_colors)
604
  colors = np.array(colors)
605
 
606
+ painted_video = paint_point_track_gpu_scatter(video_preview,tracks,visibs,colors,rate=rate,show_bkg=show_bkg)#=max(rate//2,1))
607
  # save video
608
  video_file_name = uuid.uuid4().hex + ".mp4"
609
  video_path = os.path.join(os.path.dirname(__file__), "tmp")
 
618
  im = PIL.Image.fromarray(painted_video[ti])
619
  # im.save(temp_out_f, "PNG", subsampling=0, quality=80)
620
  im.save(temp_out_f)
621
+ # print('saved', temp_out_f)
622
  # os.system('/usr/bin/ffmpeg -y -hide_banner -loglevel error -f image2 -framerate %d -pattern_type glob -i "%s/*.png" -c:v libx264 -crf 20 -pix_fmt yuv420p %s' % (video_fps, video_path, video_file_path))
623
  os.system('/usr/bin/ffmpeg -y -hide_banner -loglevel error -f image2 -framerate %d -pattern_type glob -i "%s/*.jpg" -c:v libx264 -crf 20 -pix_fmt yuv420p %s' % (video_fps, video_path, video_file_path))
624
  print('saved', video_file_path)
 
626
  # temp_out_f = '%s/%03d.png' % (video_path, ti)
627
  temp_out_f = '%s/%03d.jpg' % (video_path, ti)
628
  os.remove(temp_out_f)
629
+ # print('deleted', temp_out_f)
630
  return video_file_path
631
 
632
 
633
  @spaces.GPU
634
  def track(
635
+ video_preview,
636
+ video_input,
637
+ video_fps,
638
+ query_frame,
639
+ rate,
640
+ show_bkg,
641
+ cmap,
642
  ):
643
  # tracking_mode = 'selected'
644
  # if query_count == 0:
 
786
  tracks = traj_maps_e.permute(0,3,4,1,2).reshape(-1,T,2).numpy()
787
  visibs = visconf_maps_e.permute(0,3,4,1,2).reshape(-1,T,2)[:,:,0].numpy()
788
  confs = visconf_maps_e.permute(0,3,4,1,2).reshape(-1,T,2)[:,:,0].numpy()
789
+ # visibs = (visibs * confs) > 0.2 # N,T
790
+ visibs = (confs) > 0.1 # N,T
791
  # visibs = (confs) > 0.1 # N,T
792
 
793
 
 
795
  # print('sc', sc)
796
  # tracks = tracks * sc
797
 
798
+ return update_vis(rate, show_bkg, cmap, video_preview, query_frame, video_fps, tracks, visibs), tracks, visibs, gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True)
799
  # gr.update(interactive=True),
800
  # gr.update(interactive=True),
801
  # gr.update(interactive=True),
 
876
 
877
  gr.Markdown("# ⚡ AllTracker: Efficient Dense Point Tracking at High Resolution")
878
  gr.Markdown("<div style='text-align: left;'> \
879
+ <p>This demo runs <a href='https://alltracker.github.io/' target='_blank'>AllTracker</a> to perform all-pixel tracking in a video of your choice.</p> \
880
  <p>To get started, simply upload an mp4, or select one of the example videos. The shorter the video, the faster the processing. We recommend submitting videos under 20 seconds long.</p> \
881
  <p>After picking a video, click \"Submit\" to load the frames into the app, and optionally choose a query frame (using the slider), and then click \"Track\".</p> \
882
  <p>For full info on how this works, check out our <a href='https://github.com/aharley/alltracker/' target='_blank'>GitHub repo</a>, or <a href='https://arxiv.org/abs/2506.07310' target='_blank'>paper</a>.</p> \
 
922
  with gr.Column():
923
  with gr.Row():
924
  query_frame_slider = gr.Slider(
925
+ minimum=0, maximum=100, value=0, step=1, label="Query frame", interactive=False)
926
  # with gr.Row():
927
  # undo = gr.Button("Undo", interactive=False)
928
  # clear_frame = gr.Button("Clear Frame", interactive=False)
 
950
  with gr.Row():
951
  # rate_slider = gr.Slider(
952
  # minimum=1, maximum=16, value=1, step=1, label="Choose subsampling rate", interactive=False)
953
+ rate_radio = gr.Radio([1, 2, 4, 8, 16], value=1, label="Subsampling rate", interactive=False)
 
954
  with gr.Row():
955
+ cmap_radio = gr.Radio(["gist_rainbow", "rainbow", "jet", "turbo", "bremm"], value="gist_rainbow", label="Colormap", interactive=False)
956
+ with gr.Row():
957
+ bkg_check = gr.Checkbox(value=True, label="Overlay tracks on video", interactive=False)
958
+
959
  with gr.Row():
960
  output_video = gr.Video(
961
  label="Output video",
 
1080
  video_input,
1081
  video_fps,
1082
  query_frame_slider,
1083
+ rate_radio,
1084
+ bkg_check,
1085
+ cmap_radio,
1086
  ],
1087
  outputs = [
1088
  output_video,
1089
  tracks,
1090
  visibs,
1091
  rate_radio,
1092
+ bkg_check,
1093
  cmap_radio,
1094
  # rate1_button,
1095
  # rate2_button,
 
1110
  # )
1111
  rate_radio.change(
1112
  fn = update_vis,
1113
+ inputs = [rate_radio, bkg_check, cmap_radio, video_preview, query_frame_slider, video_fps, tracks, visibs],
1114
  outputs = [
1115
  output_video,
1116
  ],
 
1118
  )
1119
  cmap_radio.change(
1120
  fn = update_vis,
1121
+ inputs = [rate_radio, bkg_check, cmap_radio, video_preview, query_frame_slider, video_fps, tracks, visibs],
1122
+ outputs = [
1123
+ output_video,
1124
+ ],
1125
+ queue = False
1126
+ )
1127
+ bkg_check.change(
1128
+ fn = update_vis,
1129
+ inputs = [rate_radio, bkg_check, cmap_radio, video_preview, query_frame_slider, video_fps, tracks, visibs],
1130
  outputs = [
1131
  output_video,
1132
  ],
utils/improc.py CHANGED
@@ -81,6 +81,7 @@ class ColorMap2d:
81
 
82
  def get_2d_colors(xys, H, W):
83
  N,D = xys.shape
 
84
  assert(D==2)
85
  bremm = ColorMap2d()
86
  xys[:,0] /= float(W-1)
 
81
 
82
  def get_2d_colors(xys, H, W):
83
  N,D = xys.shape
84
+ xys = xys.copy()
85
  assert(D==2)
86
  bremm = ColorMap2d()
87
  xys[:,0] /= float(W-1)