James Zhou
[init]
9867d34
raw
history blame
9.67 kB
import torch
import os
from loguru import logger
from torchvision import transforms
from torchvision.transforms import v2
from diffusers.utils.torch_utils import randn_tensor
from transformers import AutoTokenizer, AutoModel, ClapTextModelWithProjection
from ..models.dac_vae.model.dac import DAC
from ..models.synchformer import Synchformer
from ..models.hifi_foley import HunyuanVideoFoley
from .config_utils import load_yaml, AttributeDict
from .schedulers import FlowMatchDiscreteScheduler
from tqdm import tqdm
def load_state_dict(model, model_path):
logger.info(f"Loading model state dict from: {model_path}")
state_dict = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=False)
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
if missing_keys:
logger.warning(f"Missing keys in state dict ({len(missing_keys)} keys):")
for key in missing_keys:
logger.warning(f" - {key}")
else:
logger.info("No missing keys found")
if unexpected_keys:
logger.warning(f"Unexpected keys in state dict ({len(unexpected_keys)} keys):")
for key in unexpected_keys:
logger.warning(f" - {key}")
else:
logger.info("No unexpected keys found")
logger.info("Model state dict loaded successfully")
return model
def load_model(model_path, config_path, device):
logger.info("Starting model loading process...")
logger.info(f"Configuration file: {config_path}")
logger.info(f"Model weights dir: {model_path}")
logger.info(f"Target device: {device}")
cfg = load_yaml(config_path)
logger.info("Configuration loaded successfully")
# HunyuanVideoFoley
logger.info("Loading HunyuanVideoFoley main model...")
foley_model = HunyuanVideoFoley(cfg, dtype=torch.bfloat16, device=device).to(device=device, dtype=torch.bfloat16)
foley_model = load_state_dict(foley_model, os.path.join(model_path, "hunyuanvideo_foley.pth"))
foley_model.eval()
logger.info("HunyuanVideoFoley model loaded and set to evaluation mode")
# DAC-VAE
dac_path = os.path.join(model_path, "vae_128d_48k.pth")
logger.info(f"Loading DAC VAE model from: {dac_path}")
dac_model = DAC.load(dac_path)
dac_model = dac_model.to(device)
dac_model.requires_grad_(False)
dac_model.eval()
logger.info("DAC VAE model loaded successfully")
# Siglip2 visual-encoder
logger.info("Loading SigLIP2 visual encoder...")
siglip2_preprocess = transforms.Compose([
transforms.Resize((512, 512)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
siglip2_model = AutoModel.from_pretrained("google/siglip2-base-patch16-512").to(device).eval()
logger.info("SigLIP2 model and preprocessing pipeline loaded successfully")
# clap text-encoder
logger.info("Loading CLAP text encoder...")
clap_tokenizer = AutoTokenizer.from_pretrained("laion/larger_clap_general")
clap_model = ClapTextModelWithProjection.from_pretrained("laion/larger_clap_general").to(device)
logger.info("CLAP tokenizer and model loaded successfully")
# syncformer
syncformer_path = os.path.join(model_path, "synchformer_state_dict.pth")
logger.info(f"Loading Synchformer model from: {syncformer_path}")
syncformer_preprocess = v2.Compose(
[
v2.Resize(224, interpolation=v2.InterpolationMode.BICUBIC),
v2.CenterCrop(224),
v2.ToImage(),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
)
syncformer_model = Synchformer()
syncformer_model.load_state_dict(torch.load(syncformer_path, weights_only=False, map_location="cpu"))
syncformer_model = syncformer_model.to(device).eval()
logger.info("Synchformer model and preprocessing pipeline loaded successfully")
logger.info("Creating model dictionary with attribute access...")
model_dict = AttributeDict({
'foley_model': foley_model,
'dac_model': dac_model,
'siglip2_preprocess': siglip2_preprocess,
'siglip2_model': siglip2_model,
'clap_tokenizer': clap_tokenizer,
'clap_model': clap_model,
'syncformer_preprocess': syncformer_preprocess,
'syncformer_model': syncformer_model,
'device': device,
})
logger.info("All models loaded successfully!")
logger.info("Available model components:")
for key in model_dict.keys():
logger.info(f" - {key}")
logger.info("Models can be accessed via attribute notation (e.g., models.foley_model)")
return model_dict, cfg
def retrieve_timesteps(
scheduler,
num_inference_steps,
device,
**kwargs,
):
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
def prepare_latents(scheduler, batch_size, num_channels_latents, length, dtype, device):
shape = (batch_size, num_channels_latents, int(length))
latents = randn_tensor(shape, device=device, dtype=dtype)
# Check existence to make it compatible with FlowMatchEulerDiscreteScheduler
if hasattr(scheduler, "init_noise_sigma"):
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * scheduler.init_noise_sigma
return latents
@torch.no_grad()
def denoise_process(visual_feats, text_feats, audio_len_in_s, model_dict, cfg, guidance_scale=4.5, num_inference_steps=50, batch_size=1):
target_dtype = model_dict.foley_model.dtype
autocast_enabled = target_dtype != torch.float32
device = model_dict.device
scheduler = FlowMatchDiscreteScheduler(
shift=cfg.diffusion_config.sample_flow_shift,
reverse=cfg.diffusion_config.flow_reverse,
solver=cfg.diffusion_config.flow_solver,
use_flux_shift=cfg.diffusion_config.sample_use_flux_shift,
flux_base_shift=cfg.diffusion_config.flux_base_shift,
flux_max_shift=cfg.diffusion_config.flux_max_shift,
)
timesteps, num_inference_steps = retrieve_timesteps(
scheduler,
num_inference_steps,
device,
)
latents = prepare_latents(
scheduler,
batch_size=batch_size,
num_channels_latents=cfg.model_config.model_kwargs.audio_vae_latent_dim,
length=audio_len_in_s * cfg.model_config.model_kwargs.audio_frame_rate,
dtype=target_dtype,
device=device,
)
# Denoise loop
for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Denoising steps"):
# noise latents
latent_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
latent_input = scheduler.scale_model_input(latent_input, t)
t_expand = t.repeat(latent_input.shape[0])
# siglip2 features
siglip2_feat = visual_feats.siglip2_feat.repeat(batch_size, 1, 1) # Repeat for batch_size
uncond_siglip2_feat = model_dict.foley_model.get_empty_clip_sequence(
bs=batch_size, len=siglip2_feat.shape[1]
).to(device)
if guidance_scale is not None and guidance_scale > 1.0:
siglip2_feat_input = torch.cat([uncond_siglip2_feat, siglip2_feat], dim=0)
else:
siglip2_feat_input = siglip2_feat
# syncformer features
syncformer_feat = visual_feats.syncformer_feat.repeat(batch_size, 1, 1) # Repeat for batch_size
uncond_syncformer_feat = model_dict.foley_model.get_empty_sync_sequence(
bs=batch_size, len=syncformer_feat.shape[1]
).to(device)
if guidance_scale is not None and guidance_scale > 1.0:
syncformer_feat_input = torch.cat([uncond_syncformer_feat, syncformer_feat], dim=0)
else:
syncformer_feat_input = syncformer_feat
# text features
text_feat_repeated = text_feats.text_feat.repeat(batch_size, 1, 1) # Repeat for batch_size
uncond_text_feat_repeated = text_feats.uncond_text_feat.repeat(batch_size, 1, 1) # Repeat for batch_size
if guidance_scale is not None and guidance_scale > 1.0:
text_feat_input = torch.cat([uncond_text_feat_repeated, text_feat_repeated], dim=0)
else:
text_feat_input = text_feat_repeated
with torch.autocast(device_type=device.type, enabled=autocast_enabled, dtype=target_dtype):
# Predict the noise residual
noise_pred = model_dict.foley_model(
x=latent_input,
t=t_expand,
cond=text_feat_input,
clip_feat=siglip2_feat_input,
sync_feat=syncformer_feat_input,
return_dict=True,
)["x"]
noise_pred = noise_pred.to(dtype=torch.float32)
if guidance_scale is not None and guidance_scale > 1.0:
# Perform classifier-free guidance
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# Compute the previous noisy sample x_t -> x_t-1
latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
# Post-process the latents to audio
with torch.no_grad():
audio = model_dict.dac_model.decode(latents)
audio = audio.float().cpu()
audio = audio[:, :int(audio_len_in_s*model_dict.dac_model.sample_rate)]
return audio, model_dict.dac_model.sample_rate