File size: 1,041 Bytes
83a5f2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from typing import Optional

import torch
from huggingface_hub import hf_hub_download

def from_pretrained(
    cls,
    model_name: str = "nari-labs/Dia-1.6B",
    device: torch.device = torch.device("cuda"),
    *,
    revision: Optional[str] = None,
) -> "Dia":
    """Load a Dia model from a Hugging Face Hub repository.

    Downloads ``config.json`` and the ``dia-v0_1.pth`` checkpoint from the
    given repository with ``hf_hub_download``, then builds the model by
    calling ``cls.from_local(config_path, checkpoint_path, device)``.

    Args:
        model_name: Hugging Face Hub repository ID, e.g. "nari-labs/Dia-1.6B"
            (the default).
        device: Device forwarded unchanged to ``cls.from_local``. Defaults to
            ``torch.device("cuda")``.
        revision: Optional branch name, tag, or commit hash forwarded to
            ``hf_hub_download`` for BOTH files. ``None`` (the default) keeps
            the original behavior of using the repository's default
            revision. Pinning a commit hash ensures the config and the
            checkpoint come from the same snapshot of the repository.

    Returns:
        The model object returned by ``cls.from_local``.

    Raises:
        This function catches nothing: any exception raised by
        ``hf_hub_download`` (e.g. missing repository/file or network errors)
        or by ``cls.from_local`` propagates to the caller unchanged.
    """
    # NOTE(review): takes ``cls`` but carries no @classmethod decorator here;
    # presumably it is bound as a classmethod on ``Dia`` elsewhere — confirm.
    config_path = hf_hub_download(
        repo_id=model_name, filename="config.json", revision=revision
    )
    # NOTE(review): .pth checkpoints are commonly deserialized with torch.load
    # (pickle); if from_local does so, only load repositories you trust.
    checkpoint_path = hf_hub_download(
        repo_id=model_name, filename="dia-v0_1.pth", revision=revision
    )
    return cls.from_local(config_path, checkpoint_path, device)