# cosmos_transfer1_av / download_checkpoints.py
# Last commit dcc4583 (harry900000): move assignment of CHECKPOINTS_PATH to `app.py`
import hashlib
import logging
import os
import pathlib
from typing import Literal
# Import the checkpoint paths
from cosmos_transfer1 import checkpoints
from huggingface_hub import login, snapshot_download
def download_checkpoint(checkpoint: str, output_dir: str) -> None:
    """Fetch one HuggingFace Hub repository snapshot into ``output_dir``.

    ``checkpoint`` is a repo id, optionally followed by ``:revision``. The
    snapshot is written to ``output_dir/<repo id>``. Nothing is downloaded
    when ``get_md5_checksum`` reports the local files as valid. Any exception
    raised along the way is printed and swallowed; the function always
    returns ``None``.
    """
    try:
        # Separate the optional ":revision" suffix from the repo id.
        if ":" in checkpoint:
            checkpoint, revision = checkpoint.split(":")
        else:
            revision = None
        checkpoint_dir = os.path.join(output_dir, checkpoint)

        # Guard clause: valid local files mean there is nothing to do.
        already_valid = get_md5_checksum(output_dir, checkpoint)
        if already_valid:
            logging.warning(f"Checkpoint {checkpoint_dir} EXISTS, skipping download... ")
            return

        print(f"Downloading {checkpoint} to {checkpoint_dir}")
        # The target directory may not exist yet on a first download.
        os.makedirs(checkpoint_dir, exist_ok=True)
        print(f"Downloading {checkpoint}...")
        snapshot_download(repo_id=checkpoint, local_dir=checkpoint_dir, revision=revision)
        print(f"Successfully downloaded {checkpoint}")
    except Exception as err:
        print(f"Error downloading {checkpoint}: {str(err)}")
# Expected MD5 hex digest for every checkpoint file that gets verified.
# Keys are file paths relative to the download root (get_md5_checksum joins
# them onto its `output_dir` argument); values are the digests the files on
# disk must match. Entries are visited in insertion order during verification.
MD5_CHECKSUM_LOOKUP = {
# Keys built from a `*_CHECKPOINT` constant plus a file name below it.
f"{checkpoints.GROUNDING_DINO_MODEL_CHECKPOINT}/pytorch_model.bin": "0fcf0d965ca9baec14bb1607005e2512",
f"{checkpoints.GROUNDING_DINO_MODEL_CHECKPOINT}/model.safetensors": "0739b040bb51f92464b4cd37f23405f9",
f"{checkpoints.T5_MODEL_CHECKPOINT}/pytorch_model.bin": "f890878d8a162e0045a25196e27089a3",
f"{checkpoints.T5_MODEL_CHECKPOINT}/tf_model.h5": "e081fc8bd5de5a6a9540568241ab8973",
f"{checkpoints.SAM2_MODEL_CHECKPOINT}/sam2_hiera_large.pt": "08083462423be3260cd6a5eef94dc01c",
f"{checkpoints.DEPTH_ANYTHING_MODEL_CHECKPOINT}/model.safetensors": "14e97d7ed2146d548c873623cdc965de",
# Keys taken verbatim from `*_PATH` constants of the checkpoints module
# (presumably "<repo id>/<file name>" strings -- confirm in that module).
checkpoints.BASE_7B_CHECKPOINT_AV_SAMPLE_PATH: "2006e158f8a17a3b801c661f0c01e9f2",
checkpoints.HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "2ddd781560d221418c2ed9258b6ca829",
checkpoints.LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "184beee5414bcb6c0c5c0f09d8f8b481",
checkpoints.UPSCALER_CONTROLNET_7B_CHECKPOINT_PATH: "b28378d13f323b49445dc469dfbbc317",
checkpoints.BASE_7B_CHECKPOINT_PATH: "356497b415f3b0697f8bb034d22b6807",
checkpoints.VIS2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "69fdffc5006bc5d6acb29449bb3ffdca",
checkpoints.EDGE2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "a0642e300e9e184077d875e1b5920a61",
checkpoints.DEPTH2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "80999ed60d89a8dfee785c544e0ccd54",
checkpoints.SEG2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "3e4077a80c836bf102c7b2ac2cd5da8c",
checkpoints.KEYPOINT2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "26619fb1686cff0e69606a9c97cac68e",
# Keys written out as literal "<repo id>/<file name>" strings.
"nvidia/Cosmos-Tokenize1-CV8x8x8-720p/autoencoder.jit": "7f658580d5cf617ee1a1da85b1f51f0d",
"nvidia/Cosmos-Tokenize1-CV8x8x8-720p/decoder.jit": "ff21a63ed817ffdbe4b6841111ec79a8",
"nvidia/Cosmos-Tokenize1-CV8x8x8-720p/encoder.jit": "f5834d03645c379bc0f8ad14b9bc0299",
f"{checkpoints.COSMOS_UPSAMPLER_CHECKPOINT}/consolidated.safetensors": "d06e6366e003126dcb351ce9b8bf3701",
f"{checkpoints.COSMOS_GUARDRAIL_CHECKPOINT}/video_content_safety_filter/safety_filter.pt": "b46dc2ad821fc3b0d946549d7ade19cf",
# Only the first of the four shards named by this file pattern is listed.
f"{checkpoints.LLAMA_GUARD_3_MODEL_CHECKPOINT}/model-00001-of-00004.safetensors": "5748060ae47b335dc19263060c921a54",
checkpoints.SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "4f8a4340d48ebedaa9e7bab772e0203d",
checkpoints.SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "89b82db1bc1dc859178154f88b6ca0f2",
checkpoints.SV2MV_t2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "a9592d232a7e5f7971f39918c18eaae0",
checkpoints.SV2MV_v2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "cb27af88ec7fb425faec32f4734d99cf",
checkpoints.BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH: "a3fb13e8418d8bb366b58e4092bd91df",
checkpoints.BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH: "48b2080ca5be66c05fac44dea4989a04",
}
def get_md5_checksum(output_dir, model_name, checksums=None):
    """Return True when every known file of ``model_name`` is present and intact.

    Each entry of the checksum table that belongs to ``model_name`` is looked
    up under ``output_dir`` and its MD5 digest is compared with the expected
    value.

    :param str output_dir: Root directory the checkpoints were downloaded to
    :param str model_name: Repo id (directory prefix) of the checkpoint to verify
    :param checksums: Optional mapping of relative file path to expected MD5
        hex digest; defaults to the module-level ``MD5_CHECKSUM_LOOKUP``
    :return: ``True`` only if at least one entry belongs to ``model_name`` and
        all such files exist with matching digests, otherwise ``False``
    """
    if checksums is None:
        checksums = MD5_CHECKSUM_LOOKUP

    print("---------------------")
    # Match on a path-component boundary. A bare ``startswith(model_name)``
    # would also pick up sibling repos that merely share a prefix, e.g.
    # "nvidia/Cosmos-Transfer1-7B-Sample-AV" versus
    # "nvidia/Cosmos-Transfer1-7B-Sample-AV-Single2MultiView".
    prefix = model_name.rstrip("/") + "/"
    chunk_size = 1 << 20  # 1 MiB; checkpoint files are too large to read at once
    verified_any = False

    for key, value in checksums.items():
        if key != model_name and not key.startswith(prefix):
            continue
        verified_any = True
        print(f"Verifying checkpoint {key}...")
        file_path = os.path.join(output_dir, key)

        # File must exist (a directory at that path does not count).
        if not pathlib.Path(file_path).is_file():
            print(f"Checkpoint {key} does not exist.")
            return False

        # File must match the given MD5 checksum; hash it incrementally so
        # memory use stays constant regardless of file size.
        digest = hashlib.md5()
        with open(file_path, "rb") as f:
            while chunk := f.read(chunk_size):
                digest.update(chunk)
        if digest.hexdigest() != value:
            print(f"MD5 checksum of checkpoint {key} does not match.")
            return False

    if not verified_any:
        # Nothing to verify is not the same as verified: report the checkpoint
        # as not validated so that the caller downloads it.
        print(f"No checksum entries found for {model_name}.")
        return False
    return True
def main(hf_token: str = os.environ.get("HF_TOKEN"), output_dir: str = "./checkpoints", model: Literal["all", "7b", "7b_av"] = "all"):
    """
    Download checkpoints from HuggingFace Hub

    Every string attribute of the ``checkpoints`` module whose name contains
    ``CHECKPOINT`` but not ``PATH`` is treated as a repo to download. The
    Cosmos-Transfer1 model repos are additionally filtered by ``model``; all
    other repos are always downloaded.

    :param str hf_token: HuggingFace token
    :param str output_dir: Directory to store the downloaded checkpoints
    :param str model: Model type to download
    """
    if hf_token:
        login(token=hf_token)

    # Repos that are only needed for one particular ``model`` choice. The
    # multi-view AV repo used to be missing from the gate, so the "7b_av"
    # branch never applied to it and it was downloaded even for model="7b".
    model_specific = {
        "COSMOS_TRANSFER1_7B_CHECKPOINT": "7b",
        "COSMOS_TRANSFER1_7B_SAMPLE_AV_CHECKPOINT": "7b_av",
        "COSMOS_TRANSFER1_7B_MV_SAMPLE_AV_CHECKPOINT": "7b_av",
    }

    checkpoint_vars = []
    # Get all variables from the checkpoints module
    for name in dir(checkpoints):
        obj = getattr(checkpoints, name)
        if not (isinstance(obj, str) and "CHECKPOINT" in name and "PATH" not in name):
            continue
        required_by = model_specific.get(name)
        # Shared repos (no entry in the table) are always needed.
        if model == "all" or required_by is None or required_by == model:
            checkpoint_vars.append(obj)

    print(f"Found {len(checkpoint_vars)} checkpoints to download")
    print(checkpoint_vars)

    # Download each checkpoint
    for checkpoint in checkpoint_vars:
        download_checkpoint(checkpoint, output_dir)
if __name__ == "__main__":
    import sys

    # Store checkpoints next to this script, independent of the current
    # working directory.
    PWD = os.path.dirname(os.path.abspath(__file__))
    CHECKPOINTS_PATH = os.path.join(PWD, "checkpoints")
    os.makedirs(CHECKPOINTS_PATH, exist_ok=True)

    # The token may be given as the first CLI argument; without one, fall back
    # to the HF_TOKEN environment variable instead of failing with IndexError.
    cli_token = sys.argv[1] if len(sys.argv) > 1 else os.environ.get("HF_TOKEN")
    main(hf_token=cli_token, output_dir=CHECKPOINTS_PATH, model="all")