# NOTE: page residue from the Hugging Face Spaces file viewer, kept as a comment
# so it does not break the script: "Spaces: Running on Zero";
# file size: 2,155 bytes; revision 9e15541 (presumably the commit hash);
# line-number gutter 1-70.
import os
import sys
import argparse
from huggingface_hub import hf_hub_download
from pathlib import Path
def download_scenedino_checkpoint(model_name):
    """Download a pretrained SceneDINO checkpoint and its training config.

    Requests ``checkpoint.pt`` and ``training_config.yaml`` of the selected
    model from the ``jev-aleks/SceneDINO`` Hugging Face repository, using
    ``out/scenedino-pretrained`` (relative to the current working directory)
    as the local download directory.

    Args:
        model_name: Key of one of the entries in ``model_configs`` below.

    Raises:
        ValueError: If ``model_name`` is not one of the known model names.
    """
    print("----------------------- Downloading pretrained model -----------------------")

    # Maps the user-facing model name to its sub-directory inside the repo.
    model_configs = {
        "ssc-kitti-360-dino": {
            "model-dir": "seg-best-dino"
        },
        "ssc-kitti-360-dino-orb-slam": {
            "model-dir": "seg-best-dino-orb-slam"
        },
        "ssc-kitti-360-dinov2": {
            "model-dir": "seg-best-dinov2"
        }
    }

    repo_id = "jev-aleks/SceneDINO"
    checkpoint_name = "checkpoint.pt"
    config_name = "training_config.yaml"

    if model_name not in model_configs:
        raise ValueError(f"Unknown model: {model_name}. Possible options: {', '.join(model_configs)}")

    model_dir = model_configs[model_name]["model-dir"]
    output_dir = Path("out/scenedino-pretrained")

    # Paths inside the Hub repo always use forward slashes. Building them with
    # pathlib and str() would yield backslashes on Windows, so they are
    # assembled as plain strings instead.
    repo_checkpoint = f"{model_dir}/{checkpoint_name}"
    repo_config = f"{model_dir}/{config_name}"

    # Local (OS-native) locations, used for the log messages below.
    checkpoint_path = output_dir / model_dir / checkpoint_name
    config_path = output_dir / model_dir / config_name

    print(f"Operating in \"{os.getcwd()}\".")
    print(f"Creating directories: {output_dir}")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Download checkpoint
    print(f"Downloading checkpoint from HF repo \"{repo_id}\" to \"{checkpoint_path}\".")
    hf_hub_download(
        repo_id=repo_id,
        filename=repo_checkpoint,
        local_dir=str(output_dir),
    )

    # Download config
    print(f"Downloading config from HF repo \"{repo_id}\" to \"{config_path}\".")
    hf_hub_download(
        repo_id=repo_id,
        filename=repo_config,
        local_dir=str(output_dir),
    )

    print("Download completed successfully!")
def main():
    """Command-line entry point.

    Reads a single positional ``model`` argument from the command line and
    passes it to ``download_scenedino_checkpoint``.
    """
    cli = argparse.ArgumentParser(
        description="Download pretrained models from Hugging Face Hub"
    )
    cli.add_argument("model", help="Model name to download")
    requested_model = cli.parse_args().model
    download_scenedino_checkpoint(requested_model)
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()