Spaces:
Paused
Paused
Update inference/flovd_demo.py
Browse files- inference/flovd_demo.py +4 -1
inference/flovd_demo.py
CHANGED
@@ -71,6 +71,7 @@ from finetune.datasets.utils import (
|
|
71 |
from torch.utils.data import Dataset
|
72 |
from torchvision import transforms
|
73 |
|
|
|
74 |
|
75 |
import pdb
|
76 |
sys.path.append(os.path.abspath(os.path.join(sys.path[-1], 'finetune'))) # for camera flow generator
|
@@ -143,7 +144,9 @@ def load_cogvideox_flovd_OMSM_lora_pipeline(omsm_path, backbone_path, transforme
|
|
143 |
# 1) Load Lora weight
|
144 |
transformer.add_adapter(transformer_lora_config)
|
145 |
|
146 |
-
|
|
|
|
|
147 |
transformer_state_dict = {
|
148 |
f'{k.replace("transformer.", "")}': v
|
149 |
for k, v in lora_state_dict.items()
|
|
|
71 |
from torch.utils.data import Dataset
|
72 |
from torchvision import transforms
|
73 |
|
74 |
+
from safetensors.torch import load_file
|
75 |
|
76 |
import pdb
|
77 |
sys.path.append(os.path.abspath(os.path.join(sys.path[-1], 'finetune'))) # for camera flow generator
|
|
|
144 |
# 1) Load Lora weight
|
145 |
transformer.add_adapter(transformer_lora_config)
|
146 |
|
147 |
+
lora_path = os.path.join(omsm_path, "pytorch_lora_weights.safetensors")
|
148 |
+
lora_state_dict = load_file(lora_path)
|
149 |
+
|
150 |
transformer_state_dict = {
|
151 |
f'{k.replace("transformer.", "")}': v
|
152 |
for k, v in lora_state_dict.items()
|