from huggingface_hub import hf_hub_download


def from_pretrained(
    cls,
    model_name: str = "nari-labs/Dia-1.6B",
    device: torch.device = torch.device("cuda"),
    *,
    config_filename: str = "config.json",
    checkpoint_filename: str = "dia-v0_1.pth",
) -> "Dia":
    """Load the Dia model from a Hugging Face Hub repository.

    Fetches the configuration and checkpoint files from the repository
    ``model_name`` with ``hf_hub_download`` and passes the resulting local
    paths to ``cls.from_local``, which builds the model.

    Args:
        model_name: Hugging Face Hub repository ID (e.g. "nari-labs/Dia-1.6B").
        device: Device to load the model onto. The default ``cuda`` device
            object is created once, when this function is defined.
        config_filename: Name of the configuration file inside the
            repository. Defaults to "config.json".
        checkpoint_filename: Name of the checkpoint file inside the
            repository. Defaults to "dia-v0_1.pth".

    Returns:
        The model returned by ``cls.from_local``. Per the original contract,
        this is a Dia instance with weights loaded and set to eval mode.

    Raises:
        Any exception raised by ``hf_hub_download`` when a file cannot be
        fetched propagates unchanged. These are huggingface_hub / HTTP
        errors and are not guaranteed to be ``FileNotFoundError``.
        Any exception raised by ``cls.from_local`` while loading also
        propagates unchanged (presumably ``FileNotFoundError`` /
        ``RuntimeError``; confirm against ``from_local``).
    """
    # Download both files before loading, so from_local receives local paths.
    config_path = hf_hub_download(repo_id=model_name, filename=config_filename)
    checkpoint_path = hf_hub_download(repo_id=model_name, filename=checkpoint_filename)
    return cls.from_local(config_path, checkpoint_path, device)