HaiderAUT commited on
Commit
83a5f2b
·
verified ·
1 Parent(s): 7568705

Create dia_use.py

Browse files
Files changed (1) hide show
  1. dia_use.py +20 -0
dia_use.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import hf_hub_download
2
+
3
+ def from_pretrained(
4
+ cls, model_name: str = "nari-labs/Dia-1.6B", device: torch.device = torch.device("cuda")
5
+ ) -> "Dia":
6
+ """Loads the Dia model from a Hugging Face Hub repository.
7
+ Downloads the configuration and checkpoint files from the specified
8
+ repository ID and then loads the model.
9
+ Args:
10
+ model_name: The Hugging Face Hub repository ID (e.g., "NariLabs/Dia-1.6B").
11
+ device: The device to load the model onto.
12
+ Returns:
13
+ An instance of the Dia model loaded with weights and set to eval mode.
14
+ Raises:
15
+ FileNotFoundError: If config or checkpoint download/loading fails.
16
+ RuntimeError: If there is an error loading the checkpoint.
17
+ """
18
+ config_path = hf_hub_download(repo_id=model_name, filename="config.json")
19
+ checkpoint_path = hf_hub_download(repo_id=model_name, filename="dia-v0_1.pth")
20
+ return cls.from_local(config_path, checkpoint_path, device)