|
import torch
from huggingface_hub import hf_hub_download
|
|
|
# NOTE(review): this takes ``cls`` and calls ``cls.from_local``, so it is
# presumably meant to be a ``@classmethod`` on ``Dia``; no decorator is
# visible in this fragment — confirm.
def from_pretrained(
    cls,
    model_name: str = "nari-labs/Dia-1.6B",
    device: torch.device = torch.device("cuda"),
    *,
    config_filename: str = "config.json",
    checkpoint_filename: str = "dia-v0_1.pth",
) -> "Dia":
    """Load the Dia model from a Hugging Face Hub repository.

    Downloads the configuration file and the checkpoint file from the given
    Hub repository with ``huggingface_hub.hf_hub_download``. The function
    returns local file paths and reuses the local Hugging Face cache when
    the files are already present. Construction is then delegated to
    ``cls.from_local``.

    Args:
        model_name: The Hugging Face Hub repository ID, e.g.
            ``"nari-labs/Dia-1.6B"`` (the default).
        device: The device to load the model onto. It is forwarded
            unchanged to ``cls.from_local``. The default
            ``torch.device("cuda")`` object is created once, when this
            function is defined, not on every call.
        config_filename: Name of the configuration file inside the
            repository. Keyword-only; defaults to ``"config.json"``.
        checkpoint_filename: Name of the checkpoint file inside the
            repository. Keyword-only; defaults to ``"dia-v0_1.pth"``.

    Returns:
        Whatever ``cls.from_local`` returns. Per the annotation this is a
        ``Dia`` instance, presumably with weights loaded and in eval mode
        (confirm in ``from_local``).

    Raises:
        Errors raised by ``hf_hub_download`` propagate unchanged. Examples
        include a repository or file that does not exist, or the Hub being
        unreachable while the file is not in the local cache.
        Any exception raised by ``cls.from_local`` while loading the
        downloaded files also propagates. Its errors presumably include
        ``FileNotFoundError`` / ``RuntimeError``; confirm there.
    """
    # Both downloads hit the same repository; the config is fetched first.
    config_path = hf_hub_download(repo_id=model_name, filename=config_filename)
    checkpoint_path = hf_hub_download(
        repo_id=model_name, filename=checkpoint_filename
    )
    return cls.from_local(config_path, checkpoint_path, device)