roll-ai commited on
Commit
9ab115e
·
verified ·
1 Parent(s): d8bfbd8

Update inference/flovd_demo.py

Browse files
Files changed (1) hide show
  1. inference/flovd_demo.py +4 -1
inference/flovd_demo.py CHANGED
@@ -71,6 +71,7 @@ from finetune.datasets.utils import (
71
  from torch.utils.data import Dataset
72
  from torchvision import transforms
73
 
 
74
 
75
  import pdb
76
  sys.path.append(os.path.abspath(os.path.join(sys.path[-1], 'finetune'))) # for camera flow generator
@@ -143,7 +144,9 @@ def load_cogvideox_flovd_OMSM_lora_pipeline(omsm_path, backbone_path, transforme
143
  # 1) Load Lora weight
144
  transformer.add_adapter(transformer_lora_config)
145
 
146
- lora_state_dict = FloVDOMSMCogVideoXImageToVideoPipeline.lora_state_dict(omsm_path)
 
 
147
  transformer_state_dict = {
148
  f'{k.replace("transformer.", "")}': v
149
  for k, v in lora_state_dict.items()
 
71
  from torch.utils.data import Dataset
72
  from torchvision import transforms
73
 
74
+ from safetensors.torch import load_file
75
 
76
  import pdb
77
  sys.path.append(os.path.abspath(os.path.join(sys.path[-1], 'finetune'))) # for camera flow generator
 
144
  # 1) Load Lora weight
145
  transformer.add_adapter(transformer_lora_config)
146
 
147
+ lora_path = os.path.join(omsm_path, "pytorch_lora_weights.safetensors")
148
+ lora_state_dict = load_file(lora_path)
149
+
150
  transformer_state_dict = {
151
  f'{k.replace("transformer.", "")}': v
152
  for k, v in lora_state_dict.items()