Create dia_use.py
Browse files- dia_use.py +20 -0
dia_use.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from huggingface_hub import hf_hub_download
|
2 |
+
|
3 |
+
def from_pretrained(
|
4 |
+
cls, model_name: str = "nari-labs/Dia-1.6B", device: torch.device = torch.device("cuda")
|
5 |
+
) -> "Dia":
|
6 |
+
"""Loads the Dia model from a Hugging Face Hub repository.
|
7 |
+
Downloads the configuration and checkpoint files from the specified
|
8 |
+
repository ID and then loads the model.
|
9 |
+
Args:
|
10 |
+
model_name: The Hugging Face Hub repository ID (e.g., "NariLabs/Dia-1.6B").
|
11 |
+
device: The device to load the model onto.
|
12 |
+
Returns:
|
13 |
+
An instance of the Dia model loaded with weights and set to eval mode.
|
14 |
+
Raises:
|
15 |
+
FileNotFoundError: If config or checkpoint download/loading fails.
|
16 |
+
RuntimeError: If there is an error loading the checkpoint.
|
17 |
+
"""
|
18 |
+
config_path = hf_hub_download(repo_id=model_name, filename="config.json")
|
19 |
+
checkpoint_path = hf_hub_download(repo_id=model_name, filename="dia-v0_1.pth")
|
20 |
+
return cls.from_local(config_path, checkpoint_path, device)
|