File size: 7,038 Bytes
f54e7d4
e22a639
f54e7d4
 
 
 
 
226c7c9
f54e7d4
 
 
 
 
 
 
 
 
 
e22a639
f54e7d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e22a639
 
 
 
 
 
 
 
 
 
 
 
f54e7d4
6d7fc1c
f54e7d4
 
 
 
dcc4583
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import hashlib
import logging
import os
import pathlib
from typing import Literal

# Import the checkpoint paths
from cosmos_transfer1 import checkpoints
from huggingface_hub import login, snapshot_download


def download_checkpoint(checkpoint: str, output_dir: str) -> None:
    """Download a single checkpoint repository from HuggingFace Hub.

    :param str checkpoint: Repo id, optionally suffixed with ``:<revision>``
        (e.g. ``"org/repo:main"``).
    :param str output_dir: Root directory; the repo is stored under
        ``output_dir/<repo id>``.

    Errors raised while checking or downloading are printed rather than
    raised, so a caller looping over many checkpoints continues with the rest.
    """
    # Split off an optional ":<revision>" suffix. partition() never raises on
    # extra colons, and an empty suffix ("repo:") means "default revision".
    repo_id, _, revision = checkpoint.partition(":")
    revision = revision or None
    checkpoint_dir = os.path.join(output_dir, repo_id)
    try:
        # get_md5_checksum returns True when no digest is known for a repo, so
        # also require the directory to exist before treating it as present.
        if os.path.isdir(checkpoint_dir) and get_md5_checksum(output_dir, repo_id):
            logging.warning(f"Checkpoint {checkpoint_dir} EXISTS, skipping download... ")
            return
        # Create the output directory if it doesn't exist
        os.makedirs(checkpoint_dir, exist_ok=True)
        print(f"Downloading {repo_id} to {checkpoint_dir}")
        snapshot_download(repo_id=repo_id, local_dir=checkpoint_dir, revision=revision)
        print(f"Successfully downloaded {repo_id}")
    except Exception as e:
        # Deliberately best-effort: report the failure and let the caller move
        # on to the next checkpoint.
        print(f"Error downloading {repo_id}: {e}")


# Known-good MD5 digests for selected checkpoint files, consulted by get_md5_checksum.
# Keys are paths relative to the download directory in the form
# "<repo id>/<file inside the repo>" (get_md5_checksum joins them onto output_dir);
# values are lowercase hex MD5 digests compared against hashlib's hexdigest().
# Only the files listed here are verified -- e.g. just the first of the four
# Llama-Guard-3 safetensors shards.
# NOTE(review): the bare *_CHECKPOINT_PATH constants are presumably already of the
# form "<repo id>/<file>" -- confirm in cosmos_transfer1.checkpoints.
MD5_CHECKSUM_LOOKUP = {
    f"{checkpoints.GROUNDING_DINO_MODEL_CHECKPOINT}/pytorch_model.bin": "0fcf0d965ca9baec14bb1607005e2512",
    f"{checkpoints.GROUNDING_DINO_MODEL_CHECKPOINT}/model.safetensors": "0739b040bb51f92464b4cd37f23405f9",
    f"{checkpoints.T5_MODEL_CHECKPOINT}/pytorch_model.bin": "f890878d8a162e0045a25196e27089a3",
    f"{checkpoints.T5_MODEL_CHECKPOINT}/tf_model.h5": "e081fc8bd5de5a6a9540568241ab8973",
    f"{checkpoints.SAM2_MODEL_CHECKPOINT}/sam2_hiera_large.pt": "08083462423be3260cd6a5eef94dc01c",
    f"{checkpoints.DEPTH_ANYTHING_MODEL_CHECKPOINT}/model.safetensors": "14e97d7ed2146d548c873623cdc965de",
    checkpoints.BASE_7B_CHECKPOINT_AV_SAMPLE_PATH: "2006e158f8a17a3b801c661f0c01e9f2",
    checkpoints.HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "2ddd781560d221418c2ed9258b6ca829",
    checkpoints.LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "184beee5414bcb6c0c5c0f09d8f8b481",
    checkpoints.UPSCALER_CONTROLNET_7B_CHECKPOINT_PATH: "b28378d13f323b49445dc469dfbbc317",
    checkpoints.BASE_7B_CHECKPOINT_PATH: "356497b415f3b0697f8bb034d22b6807",
    checkpoints.VIS2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "69fdffc5006bc5d6acb29449bb3ffdca",
    checkpoints.EDGE2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "a0642e300e9e184077d875e1b5920a61",
    checkpoints.DEPTH2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "80999ed60d89a8dfee785c544e0ccd54",
    checkpoints.SEG2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "3e4077a80c836bf102c7b2ac2cd5da8c",
    checkpoints.KEYPOINT2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "26619fb1686cff0e69606a9c97cac68e",
    # Tokenizer entries use the literal repo id rather than a checkpoints constant.
    "nvidia/Cosmos-Tokenize1-CV8x8x8-720p/autoencoder.jit": "7f658580d5cf617ee1a1da85b1f51f0d",
    "nvidia/Cosmos-Tokenize1-CV8x8x8-720p/decoder.jit": "ff21a63ed817ffdbe4b6841111ec79a8",
    "nvidia/Cosmos-Tokenize1-CV8x8x8-720p/encoder.jit": "f5834d03645c379bc0f8ad14b9bc0299",
    f"{checkpoints.COSMOS_UPSAMPLER_CHECKPOINT}/consolidated.safetensors": "d06e6366e003126dcb351ce9b8bf3701",
    f"{checkpoints.COSMOS_GUARDRAIL_CHECKPOINT}/video_content_safety_filter/safety_filter.pt": "b46dc2ad821fc3b0d946549d7ade19cf",
    f"{checkpoints.LLAMA_GUARD_3_MODEL_CHECKPOINT}/model-00001-of-00004.safetensors": "5748060ae47b335dc19263060c921a54",
    checkpoints.SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "4f8a4340d48ebedaa9e7bab772e0203d",
    checkpoints.SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "89b82db1bc1dc859178154f88b6ca0f2",
    checkpoints.SV2MV_t2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "a9592d232a7e5f7971f39918c18eaae0",
    checkpoints.SV2MV_v2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "cb27af88ec7fb425faec32f4734d99cf",
    checkpoints.BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH: "a3fb13e8418d8bb366b58e4092bd91df",
    checkpoints.BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH: "48b2080ca5be66c05fac44dea4989a04",
}


def get_md5_checksum(output_dir, model_name, lookup=None, chunk_size=8 * 1024 * 1024):
    """Verify the downloaded files of one checkpoint against known MD5 digests.

    Every entry of ``lookup`` whose key is ``model_name`` itself or a file
    inside it (``"<model_name>/..."``) is checked under ``output_dir``.

    :param str output_dir: Root directory the checkpoints were downloaded into.
    :param str model_name: HuggingFace repo id, e.g. ``"org/repo"``.
    :param dict lookup: Mapping ``"<repo id>/<relative file>" -> hex MD5``;
        defaults to ``MD5_CHECKSUM_LOOKUP``.
    :param int chunk_size: Number of bytes read per step while hashing.
    :return: ``False`` if any listed file is missing or its digest differs,
        ``True`` otherwise -- including when no entry belongs to ``model_name``.
    """
    if lookup is None:
        lookup = MD5_CHECKSUM_LOOKUP
    prefix = f"{model_name}/"
    print("---------------------")
    for key, value in lookup.items():
        # Match whole path components only: a plain startswith(model_name)
        # would make "org/repo" also claim the files of "org/repo-extra".
        if key != model_name and not key.startswith(prefix):
            continue
        print(f"Verifying checkpoint {key}...")
        file_path = os.path.join(output_dir, key)
        # File must exist
        if not pathlib.Path(file_path).is_file():
            print(f"Checkpoint {key} does not exist.")
            return False
        # File must match the given MD5 checksum. Hash in chunks: checkpoint
        # files can be tens of GB and must not be read into memory at once.
        digest = hashlib.md5()
        with open(file_path, "rb") as f:
            for chunk in iter(lambda: f.read(chunk_size), b""):
                digest.update(chunk)
        if digest.hexdigest() != value:
            print(f"MD5 checksum of checkpoint {key} does not match.")
            return False
    return True


def main(hf_token: "str | None" = None, output_dir: str = "./checkpoints", model: Literal["all", "7b", "7b_av"] = "all"):
    """
    Download checkpoints from HuggingFace Hub

    Every string attribute of the ``checkpoints`` module whose name contains
    ``CHECKPOINT`` but not ``PATH`` is treated as a repo id and downloaded.
    The Cosmos-Transfer1 model repos are filtered by ``model``; all other
    repos (tokenizer, guardrail, auxiliary models) are always downloaded.

    :param str hf_token: HuggingFace token; when omitted, the ``HF_TOKEN``
        environment variable is read at call time
    :param str output_dir: Directory to store the downloaded checkpoints
    :param str model: ``"all"`` for every model repo, ``"7b"`` for the base
        7B repo only, ``"7b_av"`` for the AV sample repos only
    :raises ValueError: if ``model`` is not one of the supported values
    """
    # Repos that are only downloaded when the matching model is selected.
    model_specific = {
        "7b": {"COSMOS_TRANSFER1_7B_CHECKPOINT"},
        "7b_av": {
            "COSMOS_TRANSFER1_7B_SAMPLE_AV_CHECKPOINT",
            "COSMOS_TRANSFER1_7B_MV_SAMPLE_AV_CHECKPOINT",
        },
    }
    if model != "all" and model not in model_specific:
        raise ValueError(f"Unsupported model {model!r}; expected one of 'all', '7b', '7b_av'")
    gated_names = set().union(*model_specific.values())

    # Resolve the token at call time rather than at import time.
    if hf_token is None:
        hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        login(token=hf_token)

    checkpoint_vars = []
    # Get all variables from the checkpoints module
    for name in dir(checkpoints):
        obj = getattr(checkpoints, name)
        # *_PATH constants name files inside a repo, not downloadable repos.
        if not isinstance(obj, str) or "CHECKPOINT" not in name or "PATH" in name:
            continue
        if model == "all" or name not in gated_names or name in model_specific[model]:
            checkpoint_vars.append(obj)

    print(f"Found {len(checkpoint_vars)} checkpoints to download")
    print(checkpoint_vars)

    # Download each checkpoint
    for checkpoint in checkpoint_vars:
        download_checkpoint(checkpoint, output_dir)


if __name__ == "__main__":
    import sys

    # Store checkpoints next to this script, independent of the working directory.
    PWD = os.path.dirname(os.path.abspath(__file__))
    CHECKPOINTS_PATH = os.path.join(PWD, "checkpoints")
    os.makedirs(CHECKPOINTS_PATH, exist_ok=True)
    # The token argument is optional; fall back to the HF_TOKEN environment
    # variable instead of crashing with IndexError when it is omitted.
    token = sys.argv[1] if len(sys.argv) > 1 else os.environ.get("HF_TOKEN")
    main(hf_token=token, output_dir=CHECKPOINTS_PATH, model="all")