aharley commited on
Commit
0afbdda
·
verified ·
1 Parent(s): 6d95ea1

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -18
app.py CHANGED
@@ -1,5 +1,7 @@
1
  import os
2
  import random
 
 
3
  import cv2
4
  import numpy as np
5
  import torch
@@ -17,6 +19,7 @@ import utils.misc
17
  import utils.saveload
18
  from nets.blocks import InputPadder
19
  from nets.net34 import Net
 
20
  import imageio
21
  from demo_dense_visualize import Tracker
22
  import spaces
@@ -47,25 +50,11 @@ seed_everything(42)
47
  torch.set_grad_enabled(False)
48
 
49
  # -------------------- Model Loading -------------------- #
50
- # Adjust these paths as needed.
51
- init_dir = '648Ai4i4i3n4s_1e-5m_c5c_stage3_from_kub_ns_wa_kk_lsh_dyk_46470'
52
- ckpt_dir = 'checkpoints'
53
- load_dir = os.path.join(ckpt_dir, init_dir)
54
-
55
- # Create the model and load weights.
56
  model = Net(16)
57
  count_parameters(model)
58
- _ = utils.saveload.load(
59
- None,
60
- load_dir,
61
- model,
62
- optimizer=None,
63
- scheduler=None,
64
- ignore_load=None,
65
- strict=True,
66
- verbose=False,
67
- weights_only=False,
68
- )
69
  model.cuda()
70
  for n, p in model.named_parameters():
71
  p.requires_grad = False
@@ -260,7 +249,7 @@ if __name__ == '__main__':
260
 
261
  with gr.Row():
262
  with gr.Column():
263
- video_input = gr.Video(label="Upload Video", value="data/172620-847860540_small.mp4")
264
  extract_btn = gr.Button("Extract First Frame")
265
  # Add sliders for resolution and sliding window length.
266
  resolution_slider = gr.Slider(minimum=512, maximum=1024, step=256, value=1024, label="Target Resolution")
 
1
  import os
2
  import random
3
+ import time
4
+ import datetime
5
  import cv2
6
  import numpy as np
7
  import torch
 
19
  import utils.saveload
20
  from nets.blocks import InputPadder
21
  from nets.net34 import Net
22
+ from tensorboardX import SummaryWriter
23
  import imageio
24
  from demo_dense_visualize import Tracker
25
  import spaces
 
50
  torch.set_grad_enabled(False)
51
 
52
  # -------------------- Model Loading -------------------- #
53
+ url = "https://huggingface.co/aharley/alltracker/resolve/main/alltracker.pth"
54
+ state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu')
 
 
 
 
55
  model = Net(16)
56
  count_parameters(model)
57
+ model.load_state_dict(state_dict, strict=True)
 
 
 
 
 
 
 
 
 
 
58
  model.cuda()
59
  for n, p in model.named_parameters():
60
  p.requires_grad = False
 
249
 
250
  with gr.Row():
251
  with gr.Column():
252
+ video_input = gr.Video(label="Upload Video", value="data/244754_medium.mp4")
253
  extract_btn = gr.Button("Extract First Frame")
254
  # Add sliders for resolution and sliding window length.
255
  resolution_slider = gr.Slider(minimum=512, maximum=1024, step=256, value=1024, label="Target Resolution")