jev-aleks commited on
Commit
e55f275
·
1 Parent(s): 79fa6fc

fix model loading

Browse files
Files changed (1) hide show
  1. app.py +11 -19
app.py CHANGED
@@ -29,8 +29,17 @@ from demo_utils.utils import (load_modules,
29
  download_scenedino_checkpoint("ssc-kitti-360-dino")
30
  download_scenedino_checkpoint("ssc-kitti-360-dinov2")
31
 
32
- net_v1, renderer_v1, ray_sampler_v1 = None, None, None
33
- net_v2, renderer_v2, ray_sampler_v2 = None, None, None
 
 
 
 
 
 
 
 
 
34
 
35
 
36
  def convert_voxels(arr, map_dict):
@@ -41,20 +50,6 @@ with open("sscbench/label_maps.yaml", "r") as f:
41
  label_maps = yaml.safe_load(f)
42
 
43
 
44
- @spaces.GPU
45
- def load_scenedino():
46
- # Load model, ray sampler, datasets
47
- ckpt_path = "out/scenedino-pretrained/seg-best-dino/"
48
- ckpt_name = "checkpoint.pt"
49
- net_v1, renderer_v1, ray_sampler_v1 = load_modules(ckpt_path, ckpt_name)
50
- renderer_v1.eval()
51
-
52
- ckpt_path = "out/scenedino-pretrained/seg-best-dinov2/"
53
- ckpt_name = "checkpoint.pt"
54
- net_v2, renderer_v2, ray_sampler_v2 = load_modules(ckpt_path, ckpt_name)
55
- renderer_v2.eval()
56
-
57
-
58
  @spaces.GPU
59
  def demo_run(image: str,
60
  backbone: str,
@@ -65,9 +60,6 @@ def demo_run(image: str,
65
  y_range: int,
66
  z_range: int):
67
 
68
- if net_v1 is None:
69
- load_scenedino()
70
-
71
  if backbone == "DINO (ViT-B)":
72
  net, renderer, ray_sampler = net_v1, renderer_v1, ray_sampler_v1
73
  elif backbone == "DINOv2 (ViT-B)":
 
29
  download_scenedino_checkpoint("ssc-kitti-360-dino")
30
  download_scenedino_checkpoint("ssc-kitti-360-dinov2")
31
 
32
+
33
+ # Load model, ray sampler, datasets
34
+ ckpt_path = "out/scenedino-pretrained/seg-best-dino/"
35
+ ckpt_name = "checkpoint.pt"
36
+ net_v1, renderer_v1, ray_sampler_v1 = load_modules(ckpt_path, ckpt_name)
37
+ renderer_v1.eval()
38
+
39
+ ckpt_path = "out/scenedino-pretrained/seg-best-dinov2/"
40
+ ckpt_name = "checkpoint.pt"
41
+ net_v2, renderer_v2, ray_sampler_v2 = load_modules(ckpt_path, ckpt_name)
42
+ renderer_v2.eval()
43
 
44
 
45
  def convert_voxels(arr, map_dict):
 
50
  label_maps = yaml.safe_load(f)
51
 
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  @spaces.GPU
54
  def demo_run(image: str,
55
  backbone: str,
 
60
  y_range: int,
61
  z_range: int):
62
 
 
 
 
63
  if backbone == "DINO (ViT-B)":
64
  net, renderer, ray_sampler = net_v1, renderer_v1, ray_sampler_v1
65
  elif backbone == "DINOv2 (ViT-B)":