brandonsmart commited on
Commit
c64f462
·
1 Parent(s): c65a25a

Another attempt to solve the pickling issues

Browse files
Files changed (1) hide show
  1. demo.py +4 -3
demo.py CHANGED
@@ -24,9 +24,11 @@ import main
24
  import utils.export as export
25
 
26
  @spaces.GPU(duration=15)
27
- def get_reconstructed_scene(outdir, model, silent, image_size, ios_mode, filelist):
28
 
 
29
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
30
 
31
  assert len(filelist) == 1 or len(filelist) == 2, "Please provide one or two images"
32
  if ios_mode:
@@ -57,7 +59,6 @@ if __name__ == '__main__':
57
  model_name = "brandonsmart/splatt3r_v1.0"
58
  filename = "epoch=19-step=1200.ckpt"
59
  weights_path = hf_hub_download(repo_id=model_name, filename=filename)
60
- model = main.MAST3RGaussians.load_from_checkpoint(weights_path, 'cpu')
61
  chkpt_tag = hash_md5(weights_path)
62
 
63
  # Define example inputs and their corresponding precalculated outputs
@@ -89,7 +90,7 @@ if __name__ == '__main__':
89
  cache_path = os.path.join(tmpdirname, chkpt_tag)
90
  os.makedirs(cache_path, exist_ok=True)
91
 
92
- recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, model, silent, image_size, ios_mode)
93
 
94
  if not ios_mode:
95
  for i in range(len(examples)):
 
24
  import utils.export as export
25
 
26
  @spaces.GPU(duration=15)
27
+ def get_reconstructed_scene(outdir, weights_path, silent, image_size, ios_mode, filelist):
28
 
29
+ # @TEMP: Temporarily instantiating the model here every time to avoid pickling issues with Hugging Face Spaces
30
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
31
+ model = main.MAST3RGaussians.load_from_checkpoint(weights_path, device)
32
 
33
  assert len(filelist) == 1 or len(filelist) == 2, "Please provide one or two images"
34
  if ios_mode:
 
59
  model_name = "brandonsmart/splatt3r_v1.0"
60
  filename = "epoch=19-step=1200.ckpt"
61
  weights_path = hf_hub_download(repo_id=model_name, filename=filename)
 
62
  chkpt_tag = hash_md5(weights_path)
63
 
64
  # Define example inputs and their corresponding precalculated outputs
 
90
  cache_path = os.path.join(tmpdirname, chkpt_tag)
91
  os.makedirs(cache_path, exist_ok=True)
92
 
93
+ recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, weights_path, silent, image_size, ios_mode)
94
 
95
  if not ios_mode:
96
  for i in range(len(examples)):