# Model loading and denoising utilities for the HunyuanVideo-Foley demo Space
# (Hugging Face Space, running on ZeroGPU; file revision 9867d34).

import os

import torch
from diffusers.utils.torch_utils import randn_tensor
from loguru import logger
from torchvision import transforms
from torchvision.transforms import v2
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel, ClapTextModelWithProjection

from ..models.dac_vae.model.dac import DAC
from ..models.hifi_foley import HunyuanVideoFoley
from ..models.synchformer import Synchformer
from .config_utils import load_yaml, AttributeDict
from .schedulers import FlowMatchDiscreteScheduler

def load_state_dict(model, model_path):
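    """Load a checkpoint into ``model`` non-strictly, logging any missing or unexpected keys."""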
logger.info(f"Loading model state dict from: {model_path}")
state_dict = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=False)
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
if missing_keys:
logger.warning(f"Missing keys in state dict ({len(missing_keys)} keys):")
for key in missing_keys:
logger.warning(f" - {key}")
else:
logger.info("No missing keys found")
if unexpected_keys:
logger.warning(f"Unexpected keys in state dict ({len(unexpected_keys)} keys):")
for key in unexpected_keys:
logger.warning(f" - {key}")
else:
logger.info("No unexpected keys found")
logger.info("Model state dict loaded successfully")
return model
def load_model(model_path, config_path, device):
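    """Load all model components (Foley model, DAC-VAE, SigLIP2, CLAP, Synchformer) onto ``device``.

    Returns an ``AttributeDict`` of components plus the parsed YAML config.
    """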
logger.info("Starting model loading process...")
logger.info(f"Configuration file: {config_path}")
logger.info(f"Model weights dir: {model_path}")
logger.info(f"Target device: {device}")
cfg = load_yaml(config_path)
logger.info("Configuration loaded successfully")
# HunyuanVideoFoley
logger.info("Loading HunyuanVideoFoley main model...")
foley_model = HunyuanVideoFoley(cfg, dtype=torch.bfloat16, device=device).to(device=device, dtype=torch.bfloat16)
foley_model = load_state_dict(foley_model, os.path.join(model_path, "hunyuanvideo_foley.pth"))
foley_model.eval()
logger.info("HunyuanVideoFoley model loaded and set to evaluation mode")
# DAC-VAE
dac_path = os.path.join(model_path, "vae_128d_48k.pth")
logger.info(f"Loading DAC VAE model from: {dac_path}")
dac_model = DAC.load(dac_path)
dac_model = dac_model.to(device)
dac_model.requires_grad_(False)
dac_model.eval()
logger.info("DAC VAE model loaded successfully")
# Siglip2 visual-encoder
logger.info("Loading SigLIP2 visual encoder...")
siglip2_preprocess = transforms.Compose([
transforms.Resize((512, 512)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
siglip2_model = AutoModel.from_pretrained("google/siglip2-base-patch16-512").to(device).eval()
logger.info("SigLIP2 model and preprocessing pipeline loaded successfully")
# clap text-encoder
logger.info("Loading CLAP text encoder...")
clap_tokenizer = AutoTokenizer.from_pretrained("laion/larger_clap_general")
clap_model = ClapTextModelWithProjection.from_pretrained("laion/larger_clap_general").to(device)
logger.info("CLAP tokenizer and model loaded successfully")
# syncformer
syncformer_path = os.path.join(model_path, "synchformer_state_dict.pth")
logger.info(f"Loading Synchformer model from: {syncformer_path}")
syncformer_preprocess = v2.Compose(
[
v2.Resize(224, interpolation=v2.InterpolationMode.BICUBIC),
v2.CenterCrop(224),
v2.ToImage(),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
)
syncformer_model = Synchformer()
syncformer_model.load_state_dict(torch.load(syncformer_path, weights_only=False, map_location="cpu"))
syncformer_model = syncformer_model.to(device).eval()
logger.info("Synchformer model and preprocessing pipeline loaded successfully")
logger.info("Creating model dictionary with attribute access...")
model_dict = AttributeDict({
'foley_model': foley_model,
'dac_model': dac_model,
'siglip2_preprocess': siglip2_preprocess,
'siglip2_model': siglip2_model,
'clap_tokenizer': clap_tokenizer,
'clap_model': clap_model,
'syncformer_preprocess': syncformer_preprocess,
'syncformer_model': syncformer_model,
'device': device,
})
logger.info("All models loaded successfully!")
logger.info("Available model components:")
for key in model_dict.keys():
logger.info(f" - {key}")
logger.info("Models can be accessed via attribute notation (e.g., models.foley_model)")
return model_dict, cfg
def retrieve_timesteps(
    scheduler,
    num_inference_steps,
    device,
    **kwargs,
):
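    """Set the scheduler's timesteps on ``device`` and return ``(timesteps, num_inference_steps)``."""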
    scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
    timesteps = scheduler.timesteps
    return timesteps, num_inference_steps

def prepare_latents(scheduler, batch_size, num_channels_latents, length, dtype, device):
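    """Draw initial Gaussian latents of shape ``(batch, channels, length)`` for the sampler."""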
    shape = (batch_size, num_channels_latents, int(length))
    latents = randn_tensor(shape, device=device, dtype=dtype)

    # Check existence to make it compatible with FlowMatchEulerDiscreteScheduler
    if hasattr(scheduler, "init_noise_sigma"):
        # Scale the initial noise by the standard deviation required by the scheduler
        latents = latents * scheduler.init_noise_sigma

    return latents

@torch.no_grad()
def denoise_process(visual_feats, text_feats, audio_len_in_s, model_dict, cfg, guidance_scale=4.5, num_inference_steps=50, batch_size=1):
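    """Run the classifier-free-guided flow-matching denoising loop and decode latents to audio.

    Returns ``(audio, sample_rate)``, where ``audio`` is a CPU float tensor.
    """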
    target_dtype = model_dict.foley_model.dtype
    autocast_enabled = target_dtype != torch.float32
    device = model_dict.device

    scheduler = FlowMatchDiscreteScheduler(
        shift=cfg.diffusion_config.sample_flow_shift,
        reverse=cfg.diffusion_config.flow_reverse,
        solver=cfg.diffusion_config.flow_solver,
        use_flux_shift=cfg.diffusion_config.sample_use_flux_shift,
        flux_base_shift=cfg.diffusion_config.flux_base_shift,
        flux_max_shift=cfg.diffusion_config.flux_max_shift,
    )

    timesteps, num_inference_steps = retrieve_timesteps(
        scheduler,
        num_inference_steps,
        device,
    )

    latents = prepare_latents(
        scheduler,
        batch_size=batch_size,
        num_channels_latents=cfg.model_config.model_kwargs.audio_vae_latent_dim,
        length=audio_len_in_s * cfg.model_config.model_kwargs.audio_frame_rate,
        dtype=target_dtype,
        device=device,
    )

    # Denoising loop
    for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Denoising steps"):
        # Duplicate the latents so one forward pass covers both the unconditional
        # and conditional branches of classifier-free guidance
        latent_input = torch.cat([latents] * 2) if (guidance_scale is not None and guidance_scale > 1.0) else latents
        latent_input = scheduler.scale_model_input(latent_input, t)

        t_expand = t.repeat(latent_input.shape[0])

        # SigLIP2 features (loop-invariant; recomputed each step for simplicity)
        siglip2_feat = visual_feats.siglip2_feat.repeat(batch_size, 1, 1)  # repeat for batch_size
        uncond_siglip2_feat = model_dict.foley_model.get_empty_clip_sequence(
            bs=batch_size, len=siglip2_feat.shape[1]
        ).to(device)
        if guidance_scale is not None and guidance_scale > 1.0:
            siglip2_feat_input = torch.cat([uncond_siglip2_feat, siglip2_feat], dim=0)
        else:
            siglip2_feat_input = siglip2_feat

        # Synchformer features
        syncformer_feat = visual_feats.syncformer_feat.repeat(batch_size, 1, 1)  # repeat for batch_size
        uncond_syncformer_feat = model_dict.foley_model.get_empty_sync_sequence(
            bs=batch_size, len=syncformer_feat.shape[1]
        ).to(device)
        if guidance_scale is not None and guidance_scale > 1.0:
            syncformer_feat_input = torch.cat([uncond_syncformer_feat, syncformer_feat], dim=0)
        else:
            syncformer_feat_input = syncformer_feat

        # Text features
        text_feat_repeated = text_feats.text_feat.repeat(batch_size, 1, 1)  # repeat for batch_size
        uncond_text_feat_repeated = text_feats.uncond_text_feat.repeat(batch_size, 1, 1)  # repeat for batch_size
        if guidance_scale is not None and guidance_scale > 1.0:
            text_feat_input = torch.cat([uncond_text_feat_repeated, text_feat_repeated], dim=0)
        else:
            text_feat_input = text_feat_repeated

        with torch.autocast(device_type=device.type, enabled=autocast_enabled, dtype=target_dtype):
            # Predict the noise residual
            noise_pred = model_dict.foley_model(
                x=latent_input,
                t=t_expand,
                cond=text_feat_input,
                clip_feat=siglip2_feat_input,
                sync_feat=syncformer_feat_input,
                return_dict=True,
            )["x"]

        noise_pred = noise_pred.to(dtype=torch.float32)

        if guidance_scale is not None and guidance_scale > 1.0:
            # Perform classifier-free guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # Compute the previous noisy sample x_t -> x_t-1
        latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

    # Post-process the latents to audio and trim to the requested duration
    # (the inner no_grad is redundant under @torch.no_grad(), but kept for safety)
    with torch.no_grad():
        audio = model_dict.dac_model.decode(latents)
    audio = audio.float().cpu()
    audio = audio[:, :int(audio_len_in_s * model_dict.dac_model.sample_rate)]

    return audio, model_dict.dac_model.sample_rate
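
# --- Usage sketch (not part of the original file) ---
# A minimal example of how these utilities compose, assuming a local checkpoint
# directory and config YAML. The paths and the `extract_features` helper below
# are hypothetical placeholders; in this repo, visual/text features would come
# from its own feature-extraction utilities.
#
#   import torch
#   from .model_utils import load_model, denoise_process
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   models, cfg = load_model("./pretrained", "./config.yaml", device)  # hypothetical paths
#   visual_feats, text_feats = extract_features(video_path, prompt, models)  # hypothetical helper
#   audio, sr = denoise_process(
#       visual_feats, text_feats, audio_len_in_s=8.0,
#       model_dict=models, cfg=cfg,
#       guidance_scale=4.5, num_inference_steps=50,
#   )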