Debug
Browse files
app.py
CHANGED
@@ -20,22 +20,31 @@ subprocess.run(shlex.split("pip install wheel/pointops-1.0-cp310-cp310-linux_x86
|
|
20 |
from src.utils.visualization_utils import render_video_from_file
|
21 |
from src.model import LSM_MASt3R
|
22 |
|
23 |
-
#
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
# model_path = hf_hub_download(repo_id=model_repo, filename=model_filename)
|
29 |
-
|
30 |
-
# Load model
|
31 |
-
# model = LSM_MASt3R.from_pretrained(model_path)
|
32 |
-
# model = model.eval()
|
33 |
|
|
|
|
|
34 |
|
35 |
try:
|
36 |
-
#
|
37 |
-
|
38 |
-
print(f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
# 加载模型
|
41 |
model = LSM_MASt3R.from_pretrained(model_path, device='cuda')
|
@@ -43,11 +52,15 @@ try:
|
|
43 |
print("模型加载成功并设置为评估模式!")
|
44 |
|
45 |
except FileNotFoundError:
|
46 |
-
print(f"错误:
|
47 |
except KeyError as e:
|
48 |
print(f"错误: 检查点文件格式不正确,缺少键 {e}。请确认 checkpoint-40.pth 包含 'args' 和 'model'。")
|
49 |
except Exception as e:
|
50 |
print(f"发生未知错误: {e}")
|
|
|
|
|
|
|
|
|
51 |
|
52 |
|
53 |
|
|
|
20 |
from src.utils.visualization_utils import render_video_from_file
|
21 |
from src.model import LSM_MASt3R
|
22 |
|
23 |
+
# 定义相对路径和 Hugging Face 仓库信息
|
24 |
+
relative_model_dir = "checkpoints" # 文件夹名称
|
25 |
+
relative_model_path = os.path.join(relative_model_dir, "checkpoint-40.pth") # 相对路径
|
26 |
+
model_repo = "kairunwen/LSM" # Hugging Face 仓库
|
27 |
+
model_filename = "checkpoint-40.pth" # 仓库中的文件名
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
+
# 转换为绝对路径
|
30 |
+
model_path = os.path.abspath(relative_model_path)
|
31 |
|
32 |
try:
|
33 |
+
# 创建 checkpoints 文件夹(如果不存在)
|
34 |
+
os.makedirs(relative_model_dir, exist_ok=True)
|
35 |
+
print(f"确保 {relative_model_dir} 文件夹存在")
|
36 |
+
|
37 |
+
# 验证文件是否存在
|
38 |
+
if os.path.exists(model_path):
|
39 |
+
print(f"找到本地模型文件: {model_path}")
|
40 |
+
else:
|
41 |
+
print(f"本地模型文件 {model_path} 不存在,正在从 Hugging Face 下载...")
|
42 |
+
model_path = hf_hub_download(repo_id=model_repo, filename=model_filename)
|
43 |
+
# 可选:将下载的文件移动到 checkpoints 文件夹
|
44 |
+
import shutil
|
45 |
+
shutil.move(model_path, os.path.abspath(relative_model_path))
|
46 |
+
model_path = os.path.abspath(relative_model_path)
|
47 |
+
print(f"模型文件已下载并移动到: {model_path}")
|
48 |
|
49 |
# 加载模型
|
50 |
model = LSM_MASt3R.from_pretrained(model_path, device='cuda')
|
|
|
52 |
print("模型加载成功并设置为评估模式!")
|
53 |
|
54 |
except FileNotFoundError:
|
55 |
+
print(f"错误: 无法找到或下载文件 {model_filename},请检查路径或仓库 {model_repo}。")
|
56 |
except KeyError as e:
|
57 |
print(f"错误: 检查点文件格式不正确,缺少键 {e}。请确认 checkpoint-40.pth 包含 'args' 和 'model'。")
|
58 |
except Exception as e:
|
59 |
print(f"发生未知错误: {e}")
|
60 |
+
# 调试:检查检查点内容
|
61 |
+
ckpt = torch.load(model_path, map_location='cpu')
|
62 |
+
print("检查点键:", ckpt.keys())
|
63 |
+
print("config.model:", ckpt['args'].model)
|
64 |
|
65 |
|
66 |
|