kairunwen commited on
Commit
c022669
·
1 Parent(s): 2c2ef94
Files changed (1) hide show
  1. app.py +27 -14
app.py CHANGED
@@ -20,22 +20,31 @@ subprocess.run(shlex.split("pip install wheel/pointops-1.0-cp310-cp310-linux_x86
20
  from src.utils.visualization_utils import render_video_from_file
21
  from src.model import LSM_MASt3R
22
 
23
- # Assuming your model has been uploaded to HuggingFace
24
- model_repo = "kairunwen/LSM" # Replace with the actual repository name
25
- model_filename = "checkpoint-40.pth" # Model filename
26
-
27
- # Download model from HuggingFace
28
- # model_path = hf_hub_download(repo_id=model_repo, filename=model_filename)
29
-
30
- # Load model
31
- # model = LSM_MASt3R.from_pretrained(model_path)
32
- # model = model.eval()
33
 
 
 
34
 
35
  try:
36
- # 下载模型文件
37
- model_path = hf_hub_download(repo_id=model_repo, filename=model_filename)
38
- print(f"模型文件已下载到: {model_path}")
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  # 加载模型
41
  model = LSM_MASt3R.from_pretrained(model_path, device='cuda')
@@ -43,11 +52,15 @@ try:
43
  print("模型加载成功并设置为评估模式!")
44
 
45
  except FileNotFoundError:
46
- print(f"错误: 无法找到文件 {model_filename},请检查仓库 {model_repo} 是否正确上传文件。")
47
  except KeyError as e:
48
  print(f"错误: 检查点文件格式不正确,缺少键 {e}。请确认 checkpoint-40.pth 包含 'args' 和 'model'。")
49
  except Exception as e:
50
  print(f"发生未知错误: {e}")
 
 
 
 
51
 
52
 
53
 
 
20
  from src.utils.visualization_utils import render_video_from_file
21
  from src.model import LSM_MASt3R
22
 
23
+ # 定义相对路径和 Hugging Face 仓库信息
24
+ relative_model_dir = "checkpoints" # 文件夹名称
25
+ relative_model_path = os.path.join(relative_model_dir, "checkpoint-40.pth") # 相对路径
26
+ model_repo = "kairunwen/LSM" # Hugging Face 仓库
27
+ model_filename = "checkpoint-40.pth" # 仓库中的文件名
 
 
 
 
 
28
 
29
+ # 转换为绝对路径
30
+ model_path = os.path.abspath(relative_model_path)
31
 
32
  try:
33
+ # 创建 checkpoints 文件夹(如果不存在)
34
+ os.makedirs(relative_model_dir, exist_ok=True)
35
+ print(f"确保 {relative_model_dir} 文件夹存在")
36
+
37
+ # 验证文件是否存在
38
+ if os.path.exists(model_path):
39
+ print(f"找到本地模型文件: {model_path}")
40
+ else:
41
+ print(f"本地模型文件 {model_path} 不存在,正在从 Hugging Face 下载...")
42
+ model_path = hf_hub_download(repo_id=model_repo, filename=model_filename)
43
+ # 可选:将下载的文件移动到 checkpoints 文件夹
44
+ import shutil
45
+ shutil.move(model_path, os.path.abspath(relative_model_path))
46
+ model_path = os.path.abspath(relative_model_path)
47
+ print(f"模型文件已下载并移动到: {model_path}")
48
 
49
  # 加载模型
50
  model = LSM_MASt3R.from_pretrained(model_path, device='cuda')
 
52
  print("模型加载成功并设置为评估模式!")
53
 
54
  except FileNotFoundError:
55
+ print(f"错误: 无法找到或下载文件 {model_filename},请检查路径或仓库 {model_repo}")
56
  except KeyError as e:
57
  print(f"错误: 检查点文件格式不正确,缺少键 {e}。请确认 checkpoint-40.pth 包含 'args' 和 'model'。")
58
  except Exception as e:
59
  print(f"发生未知错误: {e}")
60
+ # 调试:检查检查点内容
61
+ ckpt = torch.load(model_path, map_location='cpu')
62
+ print("检查点键:", ckpt.keys())
63
+ print("config.model:", ckpt['args'].model)
64
 
65
 
66