# PodCastIt / dia_use.py
# Author: HaiderAUT
# Commit 83a5f2b (verified): "Create dia_use.py"
import torch
from huggingface_hub import hf_hub_download
# NOTE(review): this function takes `cls` but is defined at module level without
# @classmethod; it appears to be lifted from a `Dia` class. Either move it back into
# that class under @classmethod, or bind it explicitly, e.g.
# `Dia.from_pretrained = classmethod(from_pretrained)`.
def from_pretrained(
    cls,
    model_name: str = "nari-labs/Dia-1.6B",
    device: torch.device = torch.device("cuda"),
    *,
    config_filename: str = "config.json",
    checkpoint_filename: str = "dia-v0_1.pth",
) -> "Dia":
    """Load a Dia model from a Hugging Face Hub repository.

    Downloads the configuration file and the checkpoint file from the given
    repository with ``hf_hub_download``, then delegates model construction to
    ``cls.from_local(config_path, checkpoint_path, device)``.

    Args:
        cls: Class providing ``from_local(config_path, checkpoint_path, device)``.
        model_name: Hugging Face Hub repository ID (e.g. ``"nari-labs/Dia-1.6B"``).
        device: Device forwarded to ``cls.from_local``; defaults to CUDA.
        config_filename: Name of the configuration file inside the repository.
        checkpoint_filename: Name of the checkpoint file inside the repository.
            Override this for repositories that publish differently named weights.

    Returns:
        The value returned by ``cls.from_local``. Per the return annotation this is
        presumably a ``Dia`` instance with weights loaded (the previous docstring
        also says it is set to eval mode) — confirm against ``from_local``.

    Raises:
        Exceptions from ``hf_hub_download`` (e.g. missing repository or file,
        network failures) and from ``cls.from_local`` propagate unchanged. The
        previous docstring listed ``FileNotFoundError`` and ``RuntimeError``;
        presumably these originate in ``from_local`` — confirm.
    """
    # Both files are fetched before the model is built, so a bad repo id or
    # filename fails early, before any weights are loaded.
    config_path = hf_hub_download(repo_id=model_name, filename=config_filename)
    checkpoint_path = hf_hub_download(repo_id=model_name, filename=checkpoint_filename)
    return cls.from_local(config_path, checkpoint_path, device)