File size: 7,038 Bytes
f54e7d4 e22a639 f54e7d4 226c7c9 f54e7d4 e22a639 f54e7d4 e22a639 f54e7d4 6d7fc1c f54e7d4 dcc4583 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import hashlib
import logging
import os
import pathlib
from typing import Literal
# Import the checkpoint paths
from cosmos_transfer1 import checkpoints
from huggingface_hub import login, snapshot_download
def download_checkpoint(checkpoint: str, output_dir: str) -> None:
    """Fetch one HuggingFace Hub repository snapshot into ``output_dir``.

    ``checkpoint`` is a repo id, optionally followed by ``:revision``. The
    snapshot is requested with ``local_dir`` set to ``output_dir/<repo id>``.
    When ``get_md5_checksum`` reports the files as verified, nothing is
    downloaded. Any exception is caught and printed, so the function always
    returns ``None`` and the caller can continue with other checkpoints.
    """
    try:
        # Separate an optional ":revision" suffix from the repo id.
        if ":" in checkpoint:
            checkpoint, revision = checkpoint.split(":")
        else:
            revision = None

        target_dir = os.path.join(output_dir, checkpoint)

        # Guard clause: verified files on disk mean there is nothing to do.
        already_verified = get_md5_checksum(output_dir, checkpoint)
        if already_verified:
            logging.warning(f"Checkpoint {target_dir} EXISTS, skipping download... ")
            return

        print(f"Downloading {checkpoint} to {target_dir}")

        # Make sure the destination exists before the download starts.
        os.makedirs(target_dir, exist_ok=True)
        print(f"Downloading {checkpoint}...")

        snapshot_download(repo_id=checkpoint, local_dir=target_dir, revision=revision)
        print(f"Successfully downloaded {checkpoint}")
    except Exception as err:
        # Best effort: report the failure instead of propagating it.
        print(f"Error downloading {checkpoint}: {err!s}")
# Expected MD5 hex digests for the checkpoint files that get_md5_checksum
# verifies. Each key is a file path that is joined onto the download output
# directory before hashing; each value is the digest the file on disk must
# match for the checkpoint to count as already downloaded.
# NOTE(review): keys appear to start with the HuggingFace repo id, which is
# what get_md5_checksum uses to select the entries of one model -- confirm
# against the definitions in cosmos_transfer1.checkpoints.
MD5_CHECKSUM_LOOKUP = {
    f"{checkpoints.GROUNDING_DINO_MODEL_CHECKPOINT}/pytorch_model.bin": "0fcf0d965ca9baec14bb1607005e2512",
    f"{checkpoints.GROUNDING_DINO_MODEL_CHECKPOINT}/model.safetensors": "0739b040bb51f92464b4cd37f23405f9",
    f"{checkpoints.T5_MODEL_CHECKPOINT}/pytorch_model.bin": "f890878d8a162e0045a25196e27089a3",
    f"{checkpoints.T5_MODEL_CHECKPOINT}/tf_model.h5": "e081fc8bd5de5a6a9540568241ab8973",
    f"{checkpoints.SAM2_MODEL_CHECKPOINT}/sam2_hiera_large.pt": "08083462423be3260cd6a5eef94dc01c",
    f"{checkpoints.DEPTH_ANYTHING_MODEL_CHECKPOINT}/model.safetensors": "14e97d7ed2146d548c873623cdc965de",
    checkpoints.BASE_7B_CHECKPOINT_AV_SAMPLE_PATH: "2006e158f8a17a3b801c661f0c01e9f2",
    checkpoints.HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "2ddd781560d221418c2ed9258b6ca829",
    checkpoints.LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "184beee5414bcb6c0c5c0f09d8f8b481",
    checkpoints.UPSCALER_CONTROLNET_7B_CHECKPOINT_PATH: "b28378d13f323b49445dc469dfbbc317",
    checkpoints.BASE_7B_CHECKPOINT_PATH: "356497b415f3b0697f8bb034d22b6807",
    checkpoints.VIS2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "69fdffc5006bc5d6acb29449bb3ffdca",
    checkpoints.EDGE2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "a0642e300e9e184077d875e1b5920a61",
    checkpoints.DEPTH2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "80999ed60d89a8dfee785c544e0ccd54",
    checkpoints.SEG2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "3e4077a80c836bf102c7b2ac2cd5da8c",
    checkpoints.KEYPOINT2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "26619fb1686cff0e69606a9c97cac68e",
    "nvidia/Cosmos-Tokenize1-CV8x8x8-720p/autoencoder.jit": "7f658580d5cf617ee1a1da85b1f51f0d",
    "nvidia/Cosmos-Tokenize1-CV8x8x8-720p/decoder.jit": "ff21a63ed817ffdbe4b6841111ec79a8",
    "nvidia/Cosmos-Tokenize1-CV8x8x8-720p/encoder.jit": "f5834d03645c379bc0f8ad14b9bc0299",
    f"{checkpoints.COSMOS_UPSAMPLER_CHECKPOINT}/consolidated.safetensors": "d06e6366e003126dcb351ce9b8bf3701",
    f"{checkpoints.COSMOS_GUARDRAIL_CHECKPOINT}/video_content_safety_filter/safety_filter.pt": "b46dc2ad821fc3b0d946549d7ade19cf",
    f"{checkpoints.LLAMA_GUARD_3_MODEL_CHECKPOINT}/model-00001-of-00004.safetensors": "5748060ae47b335dc19263060c921a54",
    checkpoints.SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "4f8a4340d48ebedaa9e7bab772e0203d",
    checkpoints.SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "89b82db1bc1dc859178154f88b6ca0f2",
    checkpoints.SV2MV_t2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "a9592d232a7e5f7971f39918c18eaae0",
    checkpoints.SV2MV_v2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "cb27af88ec7fb425faec32f4734d99cf",
    checkpoints.BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH: "a3fb13e8418d8bb366b58e4092bd91df",
    checkpoints.BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH: "48b2080ca5be66c05fac44dea4989a04",
}
def get_md5_checksum(output_dir, model_name):
    """Check the files of one model on disk against ``MD5_CHECKSUM_LOOKUP``.

    Only table entries located inside the ``model_name`` directory are
    considered, i.e. keys that start with ``model_name + "/"``. A plain
    prefix match would also pick up sibling repositories whose name merely
    begins with ``model_name``.

    :param output_dir: Directory the table keys are joined onto.
    :param model_name: Repository id whose files should be verified.
    :return: ``True`` when at least one entry belongs to ``model_name`` and
        every such file exists with the expected MD5 digest. ``False`` when
        a file is missing, a digest differs, or the table holds no entry
        for ``model_name`` (nothing could be verified, so the caller should
        download).
    """
    chunk_size = 1024 * 1024  # hash in 1 MiB pieces; checkpoints can be many GB

    print("---------------------")
    prefix = model_name.rstrip("/") + "/"
    verified_any = False
    for key, value in MD5_CHECKSUM_LOOKUP.items():
        if not key.startswith(prefix):
            continue
        print(f"Verifying checkpoint {key}...")
        file_path = os.path.join(output_dir, key)

        # File must exist (a directory of that name does not count)
        if not pathlib.Path(file_path).is_file():
            print(f"Checkpoint {key} does not exist.")
            return False

        # File must match the given MD5 checksum
        digest = hashlib.md5()
        with open(file_path, "rb") as f:
            while chunk := f.read(chunk_size):
                digest.update(chunk)
        if digest.hexdigest() != value:
            print(f"MD5 checksum of checkpoint {key} does not match.")
            return False
        verified_any = True

    # An empty match set proves nothing about what is on disk.
    return verified_any
def main(
    hf_token: str = os.environ.get("HF_TOKEN"),
    output_dir: str = "./checkpoints",
    model: Literal["all", "7b", "7b_av"] = "all",
):
    """
    Download checkpoints from HuggingFace Hub

    Every string attribute of the ``checkpoints`` module whose name contains
    ``CHECKPOINT`` but not ``PATH`` is treated as a repository to download.
    A few of them belong to one model variant only and are skipped when a
    different ``model`` is requested; all others are always downloaded.

    :param str hf_token: HuggingFace token; no login is attempted when it is
        empty or None. The default is taken from the ``HF_TOKEN`` environment
        variable at the time this function is defined.
    :param str output_dir: Directory to store the downloaded checkpoints
    :param str model: Model type to download
    """
    if hf_token:
        login(token=hf_token)

    # Checkpoints tied to a single model variant. The multi-view sample-AV
    # entry is listed here so that it is really restricted to "7b_av";
    # previously its check was unreachable and it was downloaded for every
    # value of ``model``.
    variant_of = {
        "COSMOS_TRANSFER1_7B_CHECKPOINT": "7b",
        "COSMOS_TRANSFER1_7B_SAMPLE_AV_CHECKPOINT": "7b_av",
        "COSMOS_TRANSFER1_7B_MV_SAMPLE_AV_CHECKPOINT": "7b_av",
    }

    checkpoint_vars = []
    # Get all variables from the checkpoints module
    for name in dir(checkpoints):
        obj = getattr(checkpoints, name)
        if not isinstance(obj, str) or "CHECKPOINT" not in name or "PATH" in name:
            continue
        required_model = variant_of.get(name)
        if model == "all" or required_model is None or required_model == model:
            checkpoint_vars.append(obj)

    print(f"Found {len(checkpoint_vars)} checkpoints to download")
    print(checkpoint_vars)

    # Download each checkpoint
    for checkpoint in checkpoint_vars:
        download_checkpoint(checkpoint, output_dir)
if __name__ == "__main__":
    import sys

    # Store the checkpoints in a "checkpoints" folder next to this script.
    PWD = os.path.dirname(__file__)
    CHECKPOINTS_PATH = os.path.join(PWD, "checkpoints")
    os.makedirs(CHECKPOINTS_PATH, exist_ok=True)

    # The token is an optional first command-line argument. Without it,
    # fall back to the HF_TOKEN environment variable instead of crashing
    # with an IndexError.
    token = sys.argv[1] if len(sys.argv) > 1 else os.environ.get("HF_TOKEN")
    main(hf_token=token, output_dir=CHECKPOINTS_PATH, model="all")
|