Spaces:
Runtime error
Runtime error
Commit
·
2bdad0b
1
Parent(s):
c022669
Refactor model loading to use snapshot_download from Hugging Face Hub and streamline file management for pretrained models.
Browse files
app.py
CHANGED
|
@@ -7,7 +7,7 @@ import argparse
|
|
| 7 |
import gradio as gr
|
| 8 |
import uuid
|
| 9 |
import spaces
|
| 10 |
-
from huggingface_hub import
|
| 11 |
#
|
| 12 |
|
| 13 |
subprocess.run(shlex.split("pip install wheel/torch_scatter-2.1.2+pt21cu121-cp310-cp310-linux_x86_64.whl"))
|
|
@@ -20,50 +20,27 @@ subprocess.run(shlex.split("pip install wheel/pointops-1.0-cp310-cp310-linux_x86
|
|
| 20 |
from src.utils.visualization_utils import render_video_from_file
|
| 21 |
from src.model import LSM_MASt3R
|
| 22 |
|
| 23 |
-
#
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
import shutil
|
| 45 |
-
shutil.move(model_path, os.path.abspath(relative_model_path))
|
| 46 |
-
model_path = os.path.abspath(relative_model_path)
|
| 47 |
-
print(f"模型文件已下载并移动到: {model_path}")
|
| 48 |
-
|
| 49 |
-
# 加载模型
|
| 50 |
-
model = LSM_MASt3R.from_pretrained(model_path, device='cuda')
|
| 51 |
-
model = model.eval()
|
| 52 |
-
print("模型加载成功并设置为评估模式!")
|
| 53 |
-
|
| 54 |
-
except FileNotFoundError:
|
| 55 |
-
print(f"错误: 无法找到或下载文件 {model_filename},请检查路径或仓库 {model_repo}。")
|
| 56 |
-
except KeyError as e:
|
| 57 |
-
print(f"错误: 检查点文件格式不正确,缺少键 {e}。请确认 checkpoint-40.pth 包含 'args' 和 'model'。")
|
| 58 |
-
except Exception as e:
|
| 59 |
-
print(f"发生未知错误: {e}")
|
| 60 |
-
# 调试:检查检查点内容
|
| 61 |
-
ckpt = torch.load(model_path, map_location='cpu')
|
| 62 |
-
print("检查点键:", ckpt.keys())
|
| 63 |
-
print("config.model:", ckpt['args'].model)
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
|
| 68 |
@spaces.GPU(duration=80)
|
| 69 |
def process(inputfiles, input_path=None):
|
|
|
|
| 7 |
import gradio as gr
|
| 8 |
import uuid
|
| 9 |
import spaces
|
| 10 |
+
from huggingface_hub import snapshot_download
|
| 11 |
#
|
| 12 |
|
| 13 |
subprocess.run(shlex.split("pip install wheel/torch_scatter-2.1.2+pt21cu121-cp310-cp310-linux_x86_64.whl"))
|
|
|
|
| 20 |
from src.utils.visualization_utils import render_video_from_file
|
| 21 |
from src.model import LSM_MASt3R
|
| 22 |
|
| 23 |
+
# Download the model checkpoint from Hugging Face Hub
|
| 24 |
+
repo_id = "Journey9ni/LSM"
|
| 25 |
+
remote_dir = "checkpoints/pretrained_models/"
|
| 26 |
+
local_dir = "checkpoints/pretrained_model"
|
| 27 |
+
model_path_map = {
|
| 28 |
+
"MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth": "MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth",
|
| 29 |
+
"checkpoint-40.pth":"checkpoint-40.pth",
|
| 30 |
+
"demo_e200.ckpt":"lang_seg.ckpt"
|
| 31 |
+
}
|
| 32 |
+
os.makedirs(local_dir, exist_ok=True)
|
| 33 |
+
# download remote repo
|
| 34 |
+
snapshot_download(repo_id=repo_id)
|
| 35 |
+
|
| 36 |
+
# rename the files
|
| 37 |
+
for remote_name, local_name in model_path_map.items():
|
| 38 |
+
os.rename(os.path.join(local_dir, remote_name), os.path.join(local_dir, local_name))
|
| 39 |
+
|
| 40 |
+
# load the model
|
| 41 |
+
model_path = "checkpoints/pretrained_model/checkpoint-40.pth"
|
| 42 |
+
model = LSM_MASt3R.from_pretrained(model_path, device='cuda')
|
| 43 |
+
model = model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
@spaces.GPU(duration=80)
|
| 46 |
def process(inputfiles, input_path=None):
|