helloworld-S commited on
Commit
1979685
·
verified ·
1 Parent(s): 09a4d6f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -1
app.py CHANGED
@@ -33,6 +33,8 @@ import shutil
33
  import yaml
34
  import numpy as np
35
 
 
 
36
  dtype = torch.bfloat16
37
  device = "cuda"
38
 
@@ -52,7 +54,9 @@ config = get_train_config(config_path)
52
  model.config = config
53
  store_attn_map = False
54
 
55
- ckpt_root = "~/.cache/huggingface/hub/XVerse"
 
 
56
  modulation_adapter = load_modulation_adapter(model, config, dtype, device, f"{ckpt_root}/modulation_adapter", is_training=False)
57
  model.add_modulation_adapter(modulation_adapter)
58
  if config["model"]["use_dit_lora"]:
 
33
  import yaml
34
  import numpy as np
35
 
36
+ from huggingface_hub import hf_hub_download
37
+
38
  dtype = torch.bfloat16
39
  device = "cuda"
40
 
 
54
  model.config = config
55
  store_attn_map = False
56
 
57
+ file_path = hf_hub_download(repo_id="ByteDance/XVerse", force_download=False)
58
+ ckpt_root = os.path.dirname(file_path)
59
+
60
  modulation_adapter = load_modulation_adapter(model, config, dtype, device, f"{ckpt_root}/modulation_adapter", is_training=False)
61
  model.add_modulation_adapter(modulation_adapter)
62
  if config["model"]["use_dit_lora"]: