Update app.py
Browse files
app.py
CHANGED
@@ -7,7 +7,7 @@ import argparse
|
|
7 |
import gradio as gr
|
8 |
import uuid
|
9 |
import spaces
|
10 |
-
from huggingface_hub import
|
11 |
#
|
12 |
|
13 |
subprocess.run(shlex.split("pip install wheel/torch_scatter-2.1.2+pt21cu121-cp310-cp310-linux_x86_64.whl"))
|
@@ -20,50 +20,27 @@ subprocess.run(shlex.split("pip install wheel/pointops-1.0-cp310-cp310-linux_x86
|
|
20 |
from src.utils.visualization_utils import render_video_from_file
|
21 |
from src.model import LSM_MASt3R
|
22 |
|
23 |
-
#
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
import shutil
|
45 |
-
shutil.move(model_path, os.path.abspath(relative_model_path))
|
46 |
-
model_path = os.path.abspath(relative_model_path)
|
47 |
-
print(f"模型文件已下载并移动到: {model_path}")
|
48 |
-
|
49 |
-
# 加载模型
|
50 |
-
model = LSM_MASt3R.from_pretrained(model_path, device='cuda')
|
51 |
-
model = model.eval()
|
52 |
-
print("模型加载成功并设置为评估模式!")
|
53 |
-
|
54 |
-
except FileNotFoundError:
|
55 |
-
print(f"错误: 无法找到或下载文件 {model_filename},请检查路径或仓库 {model_repo}。")
|
56 |
-
except KeyError as e:
|
57 |
-
print(f"错误: 检查点文件格式不正确,缺少键 {e}。请确认 checkpoint-40.pth 包含 'args' 和 'model'。")
|
58 |
-
except Exception as e:
|
59 |
-
print(f"发生未知错误: {e}")
|
60 |
-
# 调试:检查检查点内容
|
61 |
-
ckpt = torch.load(model_path, map_location='cpu')
|
62 |
-
print("检查点键:", ckpt.keys())
|
63 |
-
print("config.model:", ckpt['args'].model)
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
|
68 |
@spaces.GPU(duration=80)
|
69 |
def process(inputfiles, input_path=None):
|
|
|
7 |
import gradio as gr
|
8 |
import uuid
|
9 |
import spaces
|
10 |
+
from huggingface_hub import snapshot_download
|
11 |
#
|
12 |
|
13 |
subprocess.run(shlex.split("pip install wheel/torch_scatter-2.1.2+pt21cu121-cp310-cp310-linux_x86_64.whl"))
|
|
|
20 |
from src.utils.visualization_utils import render_video_from_file
|
21 |
from src.model import LSM_MASt3R
|
22 |
|
23 |
+
# Download the model checkpoint from Hugging Face Hub
|
24 |
+
repo_id = "Journey9ni/LSM"
|
25 |
+
remote_dir = "checkpoints/pretrained_models"
|
26 |
+
local_dir = "checkpoints/pretrained_model"
|
27 |
+
model_path_map = {
|
28 |
+
"MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth": "MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth",
|
29 |
+
"checkpoint-40.pth":"checkpoint-40.pth",
|
30 |
+
"demo_e200.ckpt":"lang_seg.ckpt"
|
31 |
+
}
|
32 |
+
os.makedirs(local_dir, exist_ok=True)
|
33 |
+
# download remote repo
|
34 |
+
snapshot_download(repo_id=repo_id, local_dir='./')
|
35 |
+
|
36 |
+
# rename the files
|
37 |
+
for remote_name, local_name in model_path_map.items():
|
38 |
+
os.rename(os.path.join(remote_dir, remote_name), os.path.join(local_dir, local_name))
|
39 |
+
|
40 |
+
# load the model
|
41 |
+
model_path = "checkpoints/pretrained_model/checkpoint-40.pth"
|
42 |
+
model = LSM_MASt3R.from_pretrained(model_path, device='cuda')
|
43 |
+
model = model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
@spaces.GPU(duration=80)
|
46 |
def process(inputfiles, input_path=None):
|