viranchi123 commited on
Commit
b82f700
·
verified ·
1 Parent(s): 73224de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -1
app.py CHANGED
@@ -10,6 +10,22 @@ import plotly.graph_objects as go
10
  import plotly.express as px
11
  import open3d as o3d
12
  from depth_anything_v2.dpt import DepthAnythingV2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  css = """
15
  #img-display-container {
@@ -41,7 +57,7 @@ model_configs = {
41
  }
42
  encoder = 'vitl'
43
  model = DepthAnythingV2(**model_configs[encoder])
44
- state_dict = torch.load(f'checkpoints/depth_anything_v2_{encoder}.pth', map_location="cpu")
45
  model.load_state_dict(state_dict)
46
  model = model.to(DEVICE).eval()
47
 
 
10
  import plotly.express as px
11
  import open3d as o3d
12
  from depth_anything_v2.dpt import DepthAnythingV2
13
+ import os
14
+ import gdown
15
+
16
+
17
+
18
+ # Define path and file ID
19
+ checkpoint_dir = "checkpoints"
20
+ os.makedirs(checkpoint_dir, exist_ok=True)
21
+
22
+ model_file = os.path.join(checkpoint_dir, "depth_anything_v2_vitl.pth")
23
+ gdrive_url = "https://drive.google.com/uc?id=141Mhq2jonkUBcVBnNqNSeyIZYtH5l4K5"
24
+
25
+ # Download if not already present
26
+ if not os.path.exists(model_file):
27
+ print("Downloading model from Google Drive...")
28
+ gdown.download(gdrive_url, model_file, quiet=False)
29
 
30
  css = """
31
  #img-display-container {
 
57
  }
58
  encoder = 'vitl'
59
  model = DepthAnythingV2(**model_configs[encoder])
60
+ state_dict = torch.load(f'depth_anything_v2_{encoder}.pth', map_location="cpu")
61
  model.load_state_dict(state_dict)
62
  model = model.to(DEVICE).eval()
63