kairunwen commited on
Commit
230b87b
·
1 Parent(s): c022669

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -45
app.py CHANGED
@@ -7,7 +7,7 @@ import argparse
7
  import gradio as gr
8
  import uuid
9
  import spaces
10
- from huggingface_hub import hf_hub_download
11
  #
12
 
13
  subprocess.run(shlex.split("pip install wheel/torch_scatter-2.1.2+pt21cu121-cp310-cp310-linux_x86_64.whl"))
@@ -20,50 +20,27 @@ subprocess.run(shlex.split("pip install wheel/pointops-1.0-cp310-cp310-linux_x86
20
  from src.utils.visualization_utils import render_video_from_file
21
  from src.model import LSM_MASt3R
22
 
23
- # 定义相对路径和 Hugging Face 仓库信息
24
- relative_model_dir = "checkpoints" # 文件夹名称
25
- relative_model_path = os.path.join(relative_model_dir, "checkpoint-40.pth") # 相对路径
26
- model_repo = "kairunwen/LSM" # Hugging Face 仓库
27
- model_filename = "checkpoint-40.pth" # 仓库中的文件名
28
-
29
- # 转换为绝对路径
30
- model_path = os.path.abspath(relative_model_path)
31
-
32
- try:
33
- # 创建 checkpoints 文件夹(如果不存在)
34
- os.makedirs(relative_model_dir, exist_ok=True)
35
- print(f"确保 {relative_model_dir} 文件夹存在")
36
-
37
- # 验证文件是否存在
38
- if os.path.exists(model_path):
39
- print(f"找到本地模型文件: {model_path}")
40
- else:
41
- print(f"本地模型文件 {model_path} 不存在,正在从 Hugging Face 下载...")
42
- model_path = hf_hub_download(repo_id=model_repo, filename=model_filename)
43
- # 可选:将下载的文件移动到 checkpoints 文件夹
44
- import shutil
45
- shutil.move(model_path, os.path.abspath(relative_model_path))
46
- model_path = os.path.abspath(relative_model_path)
47
- print(f"模型文件已下载并移动到: {model_path}")
48
-
49
- # 加载模型
50
- model = LSM_MASt3R.from_pretrained(model_path, device='cuda')
51
- model = model.eval()
52
- print("模型加载成功并设置为评估模式!")
53
-
54
- except FileNotFoundError:
55
- print(f"错误: 无法找到或下载文件 {model_filename},请检查路径或仓库 {model_repo}。")
56
- except KeyError as e:
57
- print(f"错误: 检查点文件格式不正确,缺少键 {e}。请确认 checkpoint-40.pth 包含 'args' 和 'model'。")
58
- except Exception as e:
59
- print(f"发生未知错误: {e}")
60
- # 调试:检查检查点内容
61
- ckpt = torch.load(model_path, map_location='cpu')
62
- print("检查点键:", ckpt.keys())
63
- print("config.model:", ckpt['args'].model)
64
-
65
-
66
-
67
 
68
  @spaces.GPU(duration=80)
69
  def process(inputfiles, input_path=None):
 
7
  import gradio as gr
8
  import uuid
9
  import spaces
10
+ from huggingface_hub import snapshot_download
11
  #
12
 
13
  subprocess.run(shlex.split("pip install wheel/torch_scatter-2.1.2+pt21cu121-cp310-cp310-linux_x86_64.whl"))
 
20
  from src.utils.visualization_utils import render_video_from_file
21
  from src.model import LSM_MASt3R
22
 
23
+ # Download the model checkpoint from Hugging Face Hub
24
+ repo_id = "Journey9ni/LSM"
25
+ remote_dir = "checkpoints/pretrained_models"
26
+ local_dir = "checkpoints/pretrained_model"
27
+ model_path_map = {
28
+ "MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth": "MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth",
29
+ "checkpoint-40.pth":"checkpoint-40.pth",
30
+ "demo_e200.ckpt":"lang_seg.ckpt"
31
+ }
32
+ os.makedirs(local_dir, exist_ok=True)
33
+ # download remote repo
34
+ snapshot_download(repo_id=repo_id, local_dir='./')
35
+
36
+ # rename the files
37
+ for remote_name, local_name in model_path_map.items():
38
+ os.rename(os.path.join(remote_dir, remote_name), os.path.join(local_dir, local_name))
39
+
40
+ # load the model
41
+ model_path = "checkpoints/pretrained_model/checkpoint-40.pth"
42
+ model = LSM_MASt3R.from_pretrained(model_path, device='cuda')
43
+ model = model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  @spaces.GPU(duration=80)
46
  def process(inputfiles, input_path=None):