|
import hashlib
import logging
import os
import pathlib
from typing import Literal, Optional

from huggingface_hub import login, snapshot_download

from cosmos_transfer1 import checkpoints
|
|
|
|
|
def download_checkpoint(checkpoint: str, output_dir: str) -> None:
    """Download a single checkpoint from HuggingFace Hub.

    ``checkpoint`` is a HuggingFace repo id, optionally followed by
    ``:<revision>`` to pin a branch, tag or commit. The repo is stored under
    ``<output_dir>/<repo id>``. The download is skipped when
    ``get_md5_checksum`` reports the local copy as valid.

    Errors are logged (with traceback) and swallowed, so a failure for one
    checkpoint does not abort the caller's remaining downloads.

    :param str checkpoint: repo id, optionally suffixed with ``:<revision>``
    :param str output_dir: root directory that holds all checkpoints
    """
    # Split only on the first ":"; an empty revision ("repo:") is treated the
    # same as no revision instead of passing "" to the Hub client.
    repo_id, _, revision = checkpoint.partition(":")
    revision = revision or None
    checkpoint_dir = os.path.join(output_dir, repo_id)

    try:
        if get_md5_checksum(output_dir, repo_id):
            logging.warning(f"Checkpoint {checkpoint_dir} EXISTS, skipping download... ")
            return

        print(f"Downloading {repo_id} to {checkpoint_dir}")
        os.makedirs(checkpoint_dir, exist_ok=True)
        snapshot_download(repo_id=repo_id, local_dir=checkpoint_dir, revision=revision)
        print(f"Successfully downloaded {repo_id}")
    except Exception:
        # Deliberate best-effort: report the failure (keeping the traceback)
        # and let the caller continue with the next checkpoint.
        logging.exception(f"Error downloading {repo_id}")
|
|
|
|
|
# Expected MD5 hex digests of selected checkpoint files. Each key is a file path
# relative to the checkpoint output directory (get_md5_checksum joins it onto
# output_dir). get_md5_checksum verifies every entry whose key starts with the
# requested model name. download_checkpoint passes the HuggingFace repo id as
# that name, so each key is expected to begin with its repo id.
# Files that do not appear here are never hashed.
# NOTE(review): the `checkpoints.*_PATH` values are assumed to be such
# "<org>/<repo>/<file>" paths; confirm in cosmos_transfer1.checkpoints.
MD5_CHECKSUM_LOOKUP = {
    f"{checkpoints.GROUNDING_DINO_MODEL_CHECKPOINT}/pytorch_model.bin": "0fcf0d965ca9baec14bb1607005e2512",
    f"{checkpoints.GROUNDING_DINO_MODEL_CHECKPOINT}/model.safetensors": "0739b040bb51f92464b4cd37f23405f9",
    f"{checkpoints.T5_MODEL_CHECKPOINT}/pytorch_model.bin": "f890878d8a162e0045a25196e27089a3",
    f"{checkpoints.T5_MODEL_CHECKPOINT}/tf_model.h5": "e081fc8bd5de5a6a9540568241ab8973",
    f"{checkpoints.SAM2_MODEL_CHECKPOINT}/sam2_hiera_large.pt": "08083462423be3260cd6a5eef94dc01c",
    f"{checkpoints.DEPTH_ANYTHING_MODEL_CHECKPOINT}/model.safetensors": "14e97d7ed2146d548c873623cdc965de",
    checkpoints.BASE_7B_CHECKPOINT_AV_SAMPLE_PATH: "2006e158f8a17a3b801c661f0c01e9f2",
    checkpoints.HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "2ddd781560d221418c2ed9258b6ca829",
    checkpoints.LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "184beee5414bcb6c0c5c0f09d8f8b481",
    checkpoints.UPSCALER_CONTROLNET_7B_CHECKPOINT_PATH: "b28378d13f323b49445dc469dfbbc317",
    checkpoints.BASE_7B_CHECKPOINT_PATH: "356497b415f3b0697f8bb034d22b6807",
    checkpoints.VIS2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "69fdffc5006bc5d6acb29449bb3ffdca",
    checkpoints.EDGE2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "a0642e300e9e184077d875e1b5920a61",
    checkpoints.DEPTH2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "80999ed60d89a8dfee785c544e0ccd54",
    checkpoints.SEG2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "3e4077a80c836bf102c7b2ac2cd5da8c",
    checkpoints.KEYPOINT2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "26619fb1686cff0e69606a9c97cac68e",
    # Tokenizer files are listed by literal repo path rather than a checkpoints constant.
    "nvidia/Cosmos-Tokenize1-CV8x8x8-720p/autoencoder.jit": "7f658580d5cf617ee1a1da85b1f51f0d",
    "nvidia/Cosmos-Tokenize1-CV8x8x8-720p/decoder.jit": "ff21a63ed817ffdbe4b6841111ec79a8",
    "nvidia/Cosmos-Tokenize1-CV8x8x8-720p/encoder.jit": "f5834d03645c379bc0f8ad14b9bc0299",
    f"{checkpoints.COSMOS_UPSAMPLER_CHECKPOINT}/consolidated.safetensors": "d06e6366e003126dcb351ce9b8bf3701",
    f"{checkpoints.COSMOS_GUARDRAIL_CHECKPOINT}/video_content_safety_filter/safety_filter.pt": "b46dc2ad821fc3b0d946549d7ade19cf",
    # Only the first shard of this multi-file model is verified.
    f"{checkpoints.LLAMA_GUARD_3_MODEL_CHECKPOINT}/model-00001-of-00004.safetensors": "5748060ae47b335dc19263060c921a54",
    checkpoints.SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "4f8a4340d48ebedaa9e7bab772e0203d",
    checkpoints.SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "89b82db1bc1dc859178154f88b6ca0f2",
    checkpoints.SV2MV_t2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "a9592d232a7e5f7971f39918c18eaae0",
    checkpoints.SV2MV_v2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "cb27af88ec7fb425faec32f4734d99cf",
    checkpoints.BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH: "a3fb13e8418d8bb366b58e4092bd91df",
    checkpoints.BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH: "48b2080ca5be66c05fac44dea4989a04",
}
|
|
|
|
|
def get_md5_checksum(output_dir, model_name, lookup=None):
    """Return True if the local copy of ``model_name`` looks complete and intact.

    Every entry of the MD5 lookup table that lies under ``model_name`` (the key
    equals it or starts with ``model_name + "/"``) must exist under
    ``output_dir`` and match its expected MD5 digest. Matching is done per path
    component, so ``"org/repo"`` does not also claim files of
    ``"org/repo-Sample-AV"``.

    When the table has no entry for ``model_name``, nothing can be verified. In
    that case the checkpoint counts as present only if its directory exists and
    is non-empty, so a never-downloaded repo is not skipped.

    :param output_dir: root directory holding the checkpoints
    :param model_name: repo id, e.g. ``"org/repo"``
    :param lookup: mapping of relative file path -> expected MD5 hex digest;
        defaults to the module-level ``MD5_CHECKSUM_LOOKUP``
    :returns: True if verification succeeded, False otherwise
    """
    if lookup is None:
        lookup = MD5_CHECKSUM_LOOKUP

    print("---------------------")
    prefix = model_name.rstrip("/") + "/"
    matched_any = False
    for key, expected_md5 in lookup.items():
        if key != model_name and not key.startswith(prefix):
            continue
        matched_any = True
        print(f"Verifying checkpoint {key}...")
        file_path = pathlib.Path(output_dir, key)

        if not file_path.exists():
            print(f"Checkpoint {key} does not exist.")
            return False

        # Hash in 1 MiB chunks: checkpoint files can be many GB, and reading
        # them in one go would load the whole file into memory.
        digest = hashlib.md5()
        with open(file_path, "rb") as f:
            for chunk in iter(lambda: f.read(1 << 20), b""):
                digest.update(chunk)
        if digest.hexdigest() != expected_md5:
            print(f"MD5 checksum of checkpoint {key} does not match.")
            return False

    if not matched_any:
        checkpoint_dir = pathlib.Path(output_dir, model_name)
        return checkpoint_dir.is_dir() and any(checkpoint_dir.iterdir())
    return True
|
|
|
|
|
def main(
    hf_token: Optional[str] = None,
    output_dir: str = "./checkpoints",
    model: Literal["all", "7b", "7b_av"] = "all",
):
    """
    Download checkpoints from HuggingFace Hub

    Every string constant in ``cosmos_transfer1.checkpoints`` whose name
    contains ``CHECKPOINT`` but not ``PATH`` is treated as a repo id and passed
    to ``download_checkpoint``. Repo ids that belong to a single model variant
    are filtered by ``model``. All other repo ids are shared dependencies and
    are always downloaded.

    :param str hf_token: HuggingFace token. When omitted (``None``) the
        ``HF_TOKEN`` environment variable is read at call time. No login is
        performed if neither provides a token.
    :param str output_dir: Directory to store the downloaded checkpoints
    :param str model: Model type to download: ``"all"``, ``"7b"`` or ``"7b_av"``
    :raises ValueError: if ``model`` is not one of the supported values
    """
    # Constants that belong to exactly one model variant. Anything else that
    # passes the name filter below is downloaded for every variant.
    variant_checkpoints = {
        "7b": {"COSMOS_TRANSFER1_7B_CHECKPOINT"},
        "7b_av": {
            "COSMOS_TRANSFER1_7B_SAMPLE_AV_CHECKPOINT",
            "COSMOS_TRANSFER1_7B_MV_SAMPLE_AV_CHECKPOINT",
        },
    }
    if model != "all" and model not in variant_checkpoints:
        raise ValueError(f"Unsupported model {model!r}; expected one of 'all', '7b', '7b_av'")

    # Read the environment at call time, not at import time.
    if hf_token is None:
        hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        login(token=hf_token)

    variant_names = set().union(*variant_checkpoints.values())
    wanted_variants = variant_names if model == "all" else variant_checkpoints[model]

    checkpoint_vars = []
    for name in dir(checkpoints):
        obj = getattr(checkpoints, name)
        # "*_PATH" constants are excluded; presumably they name files inside a
        # repo (they are used as MD5 lookup keys) rather than repo ids.
        if not (isinstance(obj, str) and "CHECKPOINT" in name and "PATH" not in name):
            continue
        if name in variant_names and name not in wanted_variants:
            continue
        checkpoint_vars.append(obj)

    print(f"Found {len(checkpoint_vars)} checkpoints to download")
    print(checkpoint_vars)

    for checkpoint in checkpoint_vars:
        download_checkpoint(checkpoint, output_dir)
|
|
|
|
|
if __name__ == "__main__":
    import sys

    # Store checkpoints next to this script regardless of the working directory.
    PWD = os.path.dirname(os.path.abspath(__file__))
    CHECKPOINTS_PATH = os.path.join(PWD, "checkpoints")
    os.makedirs(CHECKPOINTS_PATH, exist_ok=True)

    # The token argument is optional: without it, fall back to the HF_TOKEN
    # environment variable instead of crashing with IndexError.
    token = sys.argv[1] if len(sys.argv) > 1 else os.environ.get("HF_TOKEN")
    main(hf_token=token, output_dir=CHECKPOINTS_PATH, model="all")
|
|