aharley committed · Commit 6cf1a23 · 1 Parent(s): dabf756

added colormap options

Files changed (2)
  1. app.py +45 -54
  2. utils/improc.py +1 -1
app.py CHANGED

@@ -517,15 +517,16 @@ def choose_rate8(video_preview, video_fps, tracks, visibs):
 # def choose_rate16(video_preview, video_fps, tracks, visibs):
 # return choose_rate(16, video_preview, video_fps, tracks, visibs)
 
-def choose_rate(rate, video_preview, video_fps, tracks, visibs):
+def update_vis(rate, cmap, video_preview, query_frame, video_fps, tracks, visibs):
     print('rate', rate)
+    print('cmap', cmap)
     print('video_preview', video_preview.shape)
     T, H, W,_ = video_preview.shape
     tracks_ = tracks.reshape(H,W,T,2)[::rate,::rate].reshape(-1,T,2)
     visibs_ = visibs.reshape(H,W,T)[::rate,::rate].reshape(-1,T)
-    return paint_video(video_preview, video_fps, tracks_, visibs_, rate=rate)
+    return paint_video(video_preview, query_frame, video_fps, tracks_, visibs_, rate=rate, cmap=cmap)
     # return video_preview_array[int(frame_num)]
-
+
 def preprocess_video_input(video_path):
     video_arr = mediapy.read_video(video_path)
     video_fps = video_arr.metadata.fps

@@ -553,27 +554,15 @@ def preprocess_video_input(video_path):
     preview_video = np.array(preview_video)
     input_video = np.array(input_video)
 
-    interactive = True
-
     return (
         video_arr, # Original video
         preview_video, # Original preview video, resized for faster processing
         preview_video.copy(), # Copy of preview video for visualization
         input_video, # Resized video input for model
-        # None, # video_feature, # Extracted feature
         video_fps, # Set the video FPS
-        # gr.update(open=True), # open/close the video input drawer
-        # tracking_mode, # Set the tracking mode
         preview_video[0], # Set the preview frame to the first frame
-        gr.update(minimum=0, maximum=num_frames - 1, value=0, interactive=interactive), # Set slider interactive
-        [[] for _ in range(num_frames)], # Set query_points to empty
-        [[] for _ in range(num_frames)], # Set query_points_color to empty
-        [[] for _ in range(num_frames)],
-        0, # Set query count to 0
-        gr.update(interactive=interactive), # Make the buttons interactive
-        gr.update(interactive=interactive),
-        gr.update(interactive=interactive),
-        gr.update(interactive=True),
+        gr.update(minimum=0, maximum=num_frames - 1, value=0, interactive=True), # Set slider interactive
+        gr.update(interactive=True), # make track button interactive
         # gr.update(interactive=True),
         # gr.update(interactive=True),
         # gr.update(interactive=True),

@@ -581,22 +570,30 @@ def preprocess_video_input(video_path):
     )
 
 
-def paint_video(video_preview, video_fps, tracks, visibs, rate=1):
+def paint_video(video_preview, query_frame, video_fps, tracks, visibs, rate=1, cmap="gist_rainbow"):
     print('video_preview', video_preview.shape)
+    print('tracks', tracks.shape)
     T, H, W, _ = video_preview.shape
     query_count = tracks.shape[0]
-    cmap = matplotlib.colormaps.get_cmap("gist_rainbow")
-    query_points_color = [[]]
-    for i in range(query_count):
-        # Choose the color for the point from matplotlib colormap
-        color = cmap(i / float(query_count))
-        color = (int(color[0] * 255), int(color[1] * 255), int(color[2] * 255))
-        query_points_color[0].append(color)
-    # make color array
-    colors = []
-    for frame_colors in query_points_color:
-        colors.extend(frame_colors)
-    colors = np.array(colors)
+    print('cmap', cmap)
+
+    if cmap=="bremm":
+        xy0 = tracks[:,query_frame] # N,2
+        colors = utils.improc.get_2d_colors(xy0, H, W)
+    else:
+        cmap_ = matplotlib.colormaps.get_cmap(cmap)
+        query_points_color = [[]]
+        for i in range(query_count):
+            # Choose the color for the point from matplotlib colormap
+            color = cmap_(i / float(query_count))
+            color = (int(color[0] * 255), int(color[1] * 255), int(color[2] * 255))
+            query_points_color[0].append(color)
+        # make color array
+        colors = []
+        for frame_colors in query_points_color:
+            colors.extend(frame_colors)
+        colors = np.array(colors)
+
     painted_video = paint_point_track_gpu_scatter(video_preview,tracks,visibs,colors,rate=rate)#=max(rate//2,1))
     # save video
     video_file_name = uuid.uuid4().hex + ".mp4"

@@ -630,9 +627,6 @@ def track(
     video_input,
     video_fps,
     query_frame,
-    query_points,
-    query_points_color,
-    query_count,
 ):
     # tracking_mode = 'selected'
     # if query_count == 0:

@@ -788,7 +782,7 @@ def track(
     # print('sc', sc)
     # tracks = tracks * sc
 
-    return paint_video(video_preview, video_fps, tracks, visibs), tracks, visibs, gr.update(interactive=True, value=1)
+    return paint_video(video_preview, query_frame, video_fps, tracks, visibs), tracks, visibs, gr.update(interactive=True), gr.update(interactive=True)
     # gr.update(interactive=True),
     # gr.update(interactive=True),
     # gr.update(interactive=True),

@@ -863,11 +857,6 @@ with gr.Blocks() as demo:
     video_input = gr.State()
     video_fps = gr.State(24)
 
-    query_points = gr.State([])
-    query_points_color = gr.State([])
-    is_tracked_query = gr.State([])
-    query_count = gr.State(0)
-
     # rate = gr.State([])
     tracks = gr.State([])
    visibs = gr.State([])

@@ -875,14 +864,13 @@ with gr.Blocks() as demo:
     gr.Markdown("# ⚡ AllTracker: Efficient Dense Point Tracking at High Resolution")
     gr.Markdown("<div style='text-align: left;'> \
         <p>Welcome to <a href='https://alltracker.github.io/' target='_blank'>AllTracker</a>! This demo runs our model to perform all-pixel tracking in a video of your choice.</p> \
-        <p>To get started, simply upload your <b>.mp4</b> video, or select one of the example videos. The shorter the video, the faster the processing. We recommend submitting videos under 20 seconds long.</p> \
+        <p>To get started, simply upload an mp4, or select one of the example videos. The shorter the video, the faster the processing. We recommend submitting videos under 20 seconds long.</p> \
         <p>After picking a video, click \"Submit\" to load the frames into the app, and optionally choose a query frame (using the slider), and then click \"Track\".</p> \
         <p>For full info on how this works, check out our <a href='https://github.com/aharley/alltracker/' target='_blank'>GitHub repo</a>, or <a href='https://arxiv.org/abs/2506.07310' target='_blank'>paper</a>.</p> \
         <p>Initial code for this Gradio app came from LocoTrack and CoTracker -- big thanks to those authors!</p> \
         </div>"
     )
 
-
     gr.Markdown("## Step 1: Select a video, and click \"Submit\".")
     with gr.Row():
         with gr.Column():

@@ -891,7 +879,6 @@ with gr.Blocks() as demo:
             with gr.Row():
                 submit = gr.Button("Submit")
         with gr.Column():
-            # with gr.Accordion("Sample videos", open=True) as video_in_drawer:
             with gr.Row():
                 butterfly = os.path.join(os.path.dirname(__file__), "videos", "butterfly_800.mp4")
                 monkey = os.path.join(os.path.dirname(__file__), "videos", "monkey_800.mp4")

@@ -951,6 +938,9 @@ with gr.Blocks() as demo:
             # rate_slider = gr.Slider(
             #     minimum=1, maximum=16, value=1, step=1, label="Choose subsampling rate", interactive=False)
             rate_radio = gr.Radio([1, 2, 4, 8, 16], value=1, label="Choose visualization subsampling", interactive=False)
+
+            with gr.Row():
+                cmap_radio = gr.Radio(["gist_rainbow", "rainbow", "jet", "turbo", "bremm"], value="gist_rainbow", label="Choose colormap", interactive=False)
 
             with gr.Row():
                 output_video = gr.Video(

@@ -971,13 +961,8 @@ with gr.Blocks() as demo:
             video_queried_preview,
             video_input,
             video_fps,
-            # video_in_drawer,
             current_frame,
             query_frame_slider,
-            query_points,
-            query_points_color,
-            is_tracked_query,
-            query_count,
             # undo,
             # clear_frame,
             # clear_all,

@@ -1081,15 +1066,13 @@ with gr.Blocks() as demo:
            video_input,
            video_fps,
            query_frame_slider,
-           query_points,
-           query_points_color,
-           query_count,
        ],
        outputs = [
            output_video,
            tracks,
            visibs,
            rate_radio,
+           cmap_radio,
            # rate1_button,
            # rate2_button,
            # rate4_button,

@@ -1108,8 +1091,16 @@ with gr.Blocks() as demo:
    #     queue = False
    # )
    rate_radio.change(
-       fn = choose_rate,
-       inputs = [rate_radio, video_preview, video_fps, tracks, visibs],
+       fn = update_vis,
+       inputs = [rate_radio, cmap_radio, video_preview, query_frame_slider, video_fps, tracks, visibs],
+       outputs = [
+           output_video,
+       ],
+       queue = False
+   )
+   cmap_radio.change(
+       fn = update_vis,
+       inputs = [rate_radio, cmap_radio, video_preview, query_frame_slider, video_fps, tracks, visibs],
        outputs = [
            output_video,
        ],

@@ -1153,5 +1144,5 @@ with gr.Blocks() as demo:
 
 
 # demo.launch(show_api=False, show_error=True, debug=False, share=False)
-# demo.launch(show_api=False, show_error=True, debug=False, share=True)
-demo.launch(show_api=False, show_error=True, debug=False, share=False)
+demo.launch(show_api=False, show_error=True, debug=False, share=True)
+# demo.launch(show_api=False, show_error=True, debug=False, share=False)
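
To see the gist of the new colormap logic in isolation: paint_video now picks per-track colors either from a 1D matplotlib colormap (spread over the track index, as before) or, for the new "bremm" option, from each track's (x, y) position in the query frame. The sketch below is a minimal standalone illustration of that branch, not the app's code; choose_track_colors is a hypothetical helper, and its position-based coloring is only a stand-in for utils.improc.get_2d_colors, which actually samples the bremm.png 2D colormap.

```python
import numpy as np
import matplotlib

def choose_track_colors(tracks, query_frame, H, W, cmap="gist_rainbow"):
    """Hypothetical helper: return one (R, G, B) uint8 color per track.

    tracks: (N, T, 2) array of xy coordinates over T frames.
    The "bremm" branch is a simplified stand-in for utils.improc.get_2d_colors,
    which looks colors up in the bremm.png 2D colormap instead.
    """
    N = tracks.shape[0]
    if cmap == "bremm":
        xy0 = tracks[:, query_frame]  # (N, 2) positions in the query frame
        # Stand-in 2D coloring: x drives red, y drives green, blue is fixed.
        r = np.clip(xy0[:, 0] / max(W - 1, 1), 0, 1)
        g = np.clip(xy0[:, 1] / max(H - 1, 1), 0, 1)
        b = np.full(N, 0.5)
        colors = np.stack([r, g, b], axis=1)
    else:
        # Same idea as the app's else-branch: spread a 1D colormap over the track index.
        cmap_ = matplotlib.colormaps.get_cmap(cmap)
        colors = np.array([cmap_(i / float(N))[:3] for i in range(N)])
    return (colors * 255).astype(np.uint8)

# Example: 100 tracks over 24 frames of a 256x256 preview video.
tracks = np.random.rand(100, 24, 2) * 255
print(choose_track_colors(tracks, query_frame=0, H=256, W=256, cmap="bremm").shape)  # (100, 3)
print(choose_track_colors(tracks, query_frame=0, H=256, W=256, cmap="turbo").shape)  # (100, 3)
```

With position-based coloring, neighboring pixels keep similar colors across the whole video, which makes dense (all-pixel) tracks easier to read than an arbitrary per-index rainbow.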
utils/improc.py CHANGED

@@ -58,7 +58,7 @@ def flow2color(flow, clip=0.0):
     flow = (flow*255.0).type(torch.ByteTensor)
     return flow
 
-COLORMAP_FILE = "./utils/bremm.png"
+COLORMAP_FILE = "./bremm.png"
 class ColorMap2d:
     def __init__(self, filename=None):
         self._colormap_file = filename or COLORMAP_FILE
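
The utils/improc.py edit only changes where the bremm colormap image is expected to live, presumably so it resolves from the app's working directory rather than utils/. For readers unfamiliar with 2D colormaps: the idea is to normalize a point's (x, y) position into the colormap image and read off the pixel there, so nearby points get similar colors. Below is a rough, generic sketch of that lookup under those assumptions; it is not the actual ColorMap2d implementation, whose internals are not shown in this diff.

```python
import numpy as np
from PIL import Image

def sample_2d_colormap(xy, H, W, colormap_path="./bremm.png"):
    """Generic 2D-colormap lookup (illustrative only, not the ColorMap2d code).

    xy: (N, 2) point coordinates inside an H x W frame.
    Returns an (N, 3) uint8 color per point, read from the colormap image.
    """
    cmap_img = np.array(Image.open(colormap_path).convert("RGB"))  # (h, w, 3)
    h, w = cmap_img.shape[:2]
    # Normalize coordinates to [0, 1], then scale to colormap pixel indices.
    u = np.clip(xy[:, 0] / max(W - 1, 1), 0.0, 1.0)
    v = np.clip(xy[:, 1] / max(H - 1, 1), 0.0, 1.0)
    cols = np.round(u * (w - 1)).astype(int)
    rows = np.round(v * (h - 1)).astype(int)
    return cmap_img[rows, cols]
```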