Spaces:
Paused
Paused
CausalVideoAutoencoder: made neater load_ckpt.
Browse files
xora/examples/image_to_video.py
CHANGED
|
@@ -19,12 +19,12 @@ vae_ckpt_path = vae_dir / "diffusion_pytorch_model.safetensors"
|
|
| 19 |
vae_config_path = vae_dir / "config.json"
|
| 20 |
with open(vae_config_path, 'r') as f:
|
| 21 |
vae_config = json.load(f)
|
|
|
|
| 22 |
vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
|
| 23 |
-
vae = CausalVideoAutoencoder.from_pretrained_conf(
|
| 24 |
-
config=vae_config,
|
| 25 |
state_dict=vae_state_dict,
|
| 26 |
-
|
| 27 |
-
|
| 28 |
|
| 29 |
# Load UNet (Transformer) from separate mode
|
| 30 |
unet_ckpt_path = unet_dir / "diffusion_pytorch_model.safetensors"
|
|
|
|
| 19 |
vae_config_path = vae_dir / "config.json"
|
| 20 |
with open(vae_config_path, 'r') as f:
|
| 21 |
vae_config = json.load(f)
|
| 22 |
+
vae = CausalVideoAutoencoder.from_config(vae_config)
|
| 23 |
vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
|
| 24 |
+
vae.load_state_dict(
|
|
|
|
| 25 |
state_dict=vae_state_dict,
|
| 26 |
+
)
|
| 27 |
+
vae = vae.cuda().to(torch.bfloat16)
|
| 28 |
|
| 29 |
# Load UNet (Transformer) from separate mode
|
| 30 |
unet_ckpt_path = unet_dir / "diffusion_pytorch_model.safetensors"
|
xora/examples/text_to_video.py
CHANGED
|
@@ -10,7 +10,7 @@ import safetensors.torch
|
|
| 10 |
import json
|
| 11 |
|
| 12 |
# Paths for the separate mode directories
|
| 13 |
-
separate_dir = Path("/opt/models/xora-
|
| 14 |
unet_dir = separate_dir / 'unet'
|
| 15 |
vae_dir = separate_dir / 'vae'
|
| 16 |
scheduler_dir = separate_dir / 'scheduler'
|
|
@@ -20,12 +20,12 @@ vae_ckpt_path = vae_dir / "diffusion_pytorch_model.safetensors"
|
|
| 20 |
vae_config_path = vae_dir / "config.json"
|
| 21 |
with open(vae_config_path, 'r') as f:
|
| 22 |
vae_config = json.load(f)
|
|
|
|
| 23 |
vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
|
| 24 |
-
vae = CausalVideoAutoencoder.from_pretrained_conf(
|
| 25 |
-
config=vae_config,
|
| 26 |
state_dict=vae_state_dict,
|
| 27 |
-
|
| 28 |
-
|
| 29 |
|
| 30 |
# Load UNet (Transformer) from separate mode
|
| 31 |
unet_ckpt_path = unet_dir / "diffusion_pytorch_model.safetensors"
|
|
|
|
| 10 |
import json
|
| 11 |
|
| 12 |
# Paths for the separate mode directories
|
| 13 |
+
separate_dir = Path("/opt/models/xora-img2video")
|
| 14 |
unet_dir = separate_dir / 'unet'
|
| 15 |
vae_dir = separate_dir / 'vae'
|
| 16 |
scheduler_dir = separate_dir / 'scheduler'
|
|
|
|
| 20 |
vae_config_path = vae_dir / "config.json"
|
| 21 |
with open(vae_config_path, 'r') as f:
|
| 22 |
vae_config = json.load(f)
|
| 23 |
+
vae = CausalVideoAutoencoder.from_config(vae_config)
|
| 24 |
vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
|
| 25 |
+
vae.load_state_dict(
|
|
|
|
| 26 |
state_dict=vae_state_dict,
|
| 27 |
+
)
|
| 28 |
+
vae = vae.cuda().to(torch.bfloat16)
|
| 29 |
|
| 30 |
# Load UNet (Transformer) from separate mode
|
| 31 |
unet_ckpt_path = unet_dir / "diffusion_pytorch_model.safetensors"
|
xora/models/autoencoders/causal_video_autoencoder.py
CHANGED
|
@@ -41,35 +41,6 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
|
|
| 41 |
|
| 42 |
return video_vae
|
| 43 |
|
| 44 |
-
@classmethod
|
| 45 |
-
def from_pretrained_conf(cls, config, state_dict, torch_dtype=torch.float32):
|
| 46 |
-
video_vae = cls.from_config(config)
|
| 47 |
-
video_vae.to(torch_dtype)
|
| 48 |
-
|
| 49 |
-
per_channel_statistics_prefix = "per_channel_statistics."
|
| 50 |
-
ckpt_state_dict = {
|
| 51 |
-
key: value
|
| 52 |
-
for key, value in state_dict.items()
|
| 53 |
-
if not key.startswith(per_channel_statistics_prefix)
|
| 54 |
-
}
|
| 55 |
-
video_vae.load_state_dict(ckpt_state_dict)
|
| 56 |
-
|
| 57 |
-
data_dict = {
|
| 58 |
-
key.removeprefix(per_channel_statistics_prefix): value
|
| 59 |
-
for key, value in state_dict.items()
|
| 60 |
-
if key.startswith(per_channel_statistics_prefix)
|
| 61 |
-
}
|
| 62 |
-
if len(data_dict) > 0:
|
| 63 |
-
video_vae.register_buffer("std_of_means", data_dict["std-of-means"])
|
| 64 |
-
video_vae.register_buffer(
|
| 65 |
-
"mean_of_means",
|
| 66 |
-
data_dict.get(
|
| 67 |
-
"mean-of-means", torch.zeros_like(data_dict["std-of-means"])
|
| 68 |
-
),
|
| 69 |
-
)
|
| 70 |
-
|
| 71 |
-
return video_vae
|
| 72 |
-
|
| 73 |
@staticmethod
|
| 74 |
def from_config(config):
|
| 75 |
assert config["_class_name"] == "CausalVideoAutoencoder", "config must have _class_name=CausalVideoAutoencoder"
|
|
@@ -155,6 +126,13 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
|
|
| 155 |
return json.dumps(self.config.__dict__)
|
| 156 |
|
| 157 |
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
model_keys = set(name for name, _ in self.named_parameters())
|
| 159 |
|
| 160 |
key_mapping = {
|
|
@@ -162,9 +140,8 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
|
|
| 162 |
"downsamplers.0": "downsample",
|
| 163 |
"upsamplers.0": "upsample",
|
| 164 |
}
|
| 165 |
-
|
| 166 |
converted_state_dict = {}
|
| 167 |
-
for key, value in state_dict.items():
|
| 168 |
for k, v in key_mapping.items():
|
| 169 |
key = key.replace(k, v)
|
| 170 |
|
|
@@ -176,6 +153,20 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
|
|
| 176 |
|
| 177 |
super().load_state_dict(converted_state_dict, strict=strict)
|
| 178 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
def last_layer(self):
|
| 180 |
if hasattr(self.decoder, "conv_out"):
|
| 181 |
if isinstance(self.decoder.conv_out, nn.Sequential):
|
|
|
|
| 41 |
|
| 42 |
return video_vae
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
@staticmethod
|
| 45 |
def from_config(config):
|
| 46 |
assert config["_class_name"] == "CausalVideoAutoencoder", "config must have _class_name=CausalVideoAutoencoder"
|
|
|
|
| 126 |
return json.dumps(self.config.__dict__)
|
| 127 |
|
| 128 |
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
|
| 129 |
+
per_channel_statistics_prefix = "per_channel_statistics."
|
| 130 |
+
ckpt_state_dict = {
|
| 131 |
+
key: value
|
| 132 |
+
for key, value in state_dict.items()
|
| 133 |
+
if not key.startswith(per_channel_statistics_prefix)
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
model_keys = set(name for name, _ in self.named_parameters())
|
| 137 |
|
| 138 |
key_mapping = {
|
|
|
|
| 140 |
"downsamplers.0": "downsample",
|
| 141 |
"upsamplers.0": "upsample",
|
| 142 |
}
|
|
|
|
| 143 |
converted_state_dict = {}
|
| 144 |
+
for key, value in ckpt_state_dict.items():
|
| 145 |
for k, v in key_mapping.items():
|
| 146 |
key = key.replace(k, v)
|
| 147 |
|
|
|
|
| 153 |
|
| 154 |
super().load_state_dict(converted_state_dict, strict=strict)
|
| 155 |
|
| 156 |
+
data_dict = {
|
| 157 |
+
key.removeprefix(per_channel_statistics_prefix): value
|
| 158 |
+
for key, value in state_dict.items()
|
| 159 |
+
if key.startswith(per_channel_statistics_prefix)
|
| 160 |
+
}
|
| 161 |
+
if len(data_dict) > 0:
|
| 162 |
+
self.register_buffer("std_of_means", data_dict["std-of-means"])
|
| 163 |
+
self.register_buffer(
|
| 164 |
+
"mean_of_means",
|
| 165 |
+
data_dict.get(
|
| 166 |
+
"mean-of-means", torch.zeros_like(data_dict["std-of-means"])
|
| 167 |
+
),
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
def last_layer(self):
|
| 171 |
if hasattr(self.decoder, "conv_out"):
|
| 172 |
if isinstance(self.decoder.conv_out, nn.Sequential):
|