|
import torch |
|
import os |
|
from loguru import logger |
|
from torchvision import transforms |
|
from torchvision.transforms import v2 |
|
from diffusers.utils.torch_utils import randn_tensor |
|
from transformers import AutoTokenizer, AutoModel, ClapTextModelWithProjection |
|
from ..models.dac_vae.model.dac import DAC |
|
from ..models.synchformer import Synchformer |
|
from ..models.hifi_foley import HunyuanVideoFoley |
|
from .config_utils import load_yaml, AttributeDict |
|
from .schedulers import FlowMatchDiscreteScheduler |
|
from tqdm import tqdm |
|
|
|
def load_state_dict(model, model_path):
    """Non-strictly load the checkpoint at ``model_path`` into ``model``.

    Every key the model expected but the checkpoint lacked, and every key the
    checkpoint carried that the model does not know, is logged individually;
    a loader that silently drops weights is otherwise hard to spot.

    Args:
        model: ``torch.nn.Module`` whose parameters are overwritten in place.
        model_path: Path of the serialized state dict.

    Returns:
        The same ``model`` object that was passed in.
    """
    logger.info(f"Loading model state dict from: {model_path}")

    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects -- only point this at trusted checkpoint files.
    # Returning each storage untouched from map_location keeps tensors on CPU.
    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=False)

    load_result = model.load_state_dict(checkpoint, strict=False)
    absent, surplus = load_result

    if not absent:
        logger.info("No missing keys found")
    else:
        logger.warning(f"Missing keys in state dict ({len(absent)} keys):")
        for name in absent:
            logger.warning(f" - {name}")

    if not surplus:
        logger.info("No unexpected keys found")
    else:
        logger.warning(f"Unexpected keys in state dict ({len(surplus)} keys):")
        for name in surplus:
            logger.warning(f" - {name}")

    logger.info("Model state dict loaded successfully")
    return model
|
|
|
def _load_foley_model(cfg, model_path, device):
    """Build the HunyuanVideoFoley backbone in bfloat16 on ``device`` and load its weights."""
    logger.info("Loading HunyuanVideoFoley main model...")
    foley_model = HunyuanVideoFoley(cfg, dtype=torch.bfloat16, device=device).to(device=device, dtype=torch.bfloat16)
    foley_model = load_state_dict(foley_model, os.path.join(model_path, "hunyuanvideo_foley.pth"))
    foley_model.eval()
    logger.info("HunyuanVideoFoley model loaded and set to evaluation mode")
    return foley_model


def _load_dac_model(model_path, device):
    """Load the frozen DAC VAE audio decoder from ``model_path``."""
    dac_path = os.path.join(model_path, "vae_128d_48k.pth")
    logger.info(f"Loading DAC VAE model from: {dac_path}")
    dac_model = DAC.load(dac_path)
    dac_model = dac_model.to(device)
    dac_model.requires_grad_(False)
    dac_model.eval()
    logger.info("DAC VAE model loaded successfully")
    return dac_model


def _load_siglip2(device):
    """Return ``(preprocess, model)`` for the SigLIP2 visual encoder (512x512 input)."""
    logger.info("Loading SigLIP2 visual encoder...")
    siglip2_preprocess = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    siglip2_model = AutoModel.from_pretrained("google/siglip2-base-patch16-512").to(device).eval()
    logger.info("SigLIP2 model and preprocessing pipeline loaded successfully")
    return siglip2_preprocess, siglip2_model


def _load_clap(device):
    """Return ``(tokenizer, model)`` for the CLAP text encoder."""
    logger.info("Loading CLAP text encoder...")
    clap_tokenizer = AutoTokenizer.from_pretrained("laion/larger_clap_general")
    # Explicit .eval() so every returned encoder is in inference mode, matching the others.
    clap_model = ClapTextModelWithProjection.from_pretrained("laion/larger_clap_general").to(device).eval()
    logger.info("CLAP tokenizer and model loaded successfully")
    return clap_tokenizer, clap_model


def _load_synchformer(model_path, device):
    """Return ``(preprocess, model)`` for the Synchformer encoder (224x224 center crop)."""
    syncformer_path = os.path.join(model_path, "synchformer_state_dict.pth")
    logger.info(f"Loading Synchformer model from: {syncformer_path}")
    syncformer_preprocess = v2.Compose(
        [
            v2.Resize(224, interpolation=v2.InterpolationMode.BICUBIC),
            v2.CenterCrop(224),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
    )
    syncformer_model = Synchformer()
    # NOTE(review): weights_only=False unpickles arbitrary objects; trusted files only.
    syncformer_model.load_state_dict(torch.load(syncformer_path, weights_only=False, map_location="cpu"))
    syncformer_model = syncformer_model.to(device).eval()
    logger.info("Synchformer model and preprocessing pipeline loaded successfully")
    return syncformer_preprocess, syncformer_model


def load_model(model_path, config_path, device):
    """Load every model component needed for Foley inference.

    Args:
        model_path: Directory holding ``hunyuanvideo_foley.pth``,
            ``vae_128d_48k.pth`` and ``synchformer_state_dict.pth``.
        config_path: YAML configuration file path.
        device: Target device (string such as ``"cuda"``, int ordinal, or
            ``torch.device``). It is normalized to ``torch.device`` so that
            downstream code can rely on ``model_dict.device.type``.

    Returns:
        ``(model_dict, cfg)``. ``model_dict`` is an ``AttributeDict`` with
        keys foley_model, dac_model, siglip2_preprocess, siglip2_model,
        clap_tokenizer, clap_model, syncformer_preprocess, syncformer_model
        and device. ``cfg`` is the loaded YAML config.
    """
    logger.info("Starting model loading process...")
    logger.info(f"Configuration file: {config_path}")
    logger.info(f"Model weights dir: {model_path}")
    logger.info(f"Target device: {device}")

    # denoise_process calls `device.type` for autocast, which a plain string
    # does not have; torch.device() accepts str, int and torch.device alike.
    device = torch.device(device)

    cfg = load_yaml(config_path)
    logger.info("Configuration loaded successfully")

    foley_model = _load_foley_model(cfg, model_path, device)
    dac_model = _load_dac_model(model_path, device)
    siglip2_preprocess, siglip2_model = _load_siglip2(device)
    clap_tokenizer, clap_model = _load_clap(device)
    syncformer_preprocess, syncformer_model = _load_synchformer(model_path, device)

    logger.info("Creating model dictionary with attribute access...")
    model_dict = AttributeDict({
        'foley_model': foley_model,
        'dac_model': dac_model,
        'siglip2_preprocess': siglip2_preprocess,
        'siglip2_model': siglip2_model,
        'clap_tokenizer': clap_tokenizer,
        'clap_model': clap_model,
        'syncformer_preprocess': syncformer_preprocess,
        'syncformer_model': syncformer_model,
        'device': device,
    })

    logger.info("All models loaded successfully!")
    logger.info("Available model components:")
    for key in model_dict.keys():
        logger.info(f" - {key}")
    logger.info("Models can be accessed via attribute notation (e.g., models.foley_model)")

    return model_dict, cfg
|
|
|
def retrieve_timesteps(
    scheduler,
    num_inference_steps,
    device,
    **kwargs,
):
    """Configure ``scheduler`` for ``num_inference_steps`` and return its schedule.

    Any extra keyword arguments are forwarded verbatim to
    ``scheduler.set_timesteps``.

    Returns:
        ``(scheduler.timesteps, num_inference_steps)`` -- the step count is
        echoed back exactly as it was passed in.
    """
    scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
    return scheduler.timesteps, num_inference_steps
|
|
|
|
|
def prepare_latents(scheduler, batch_size, num_channels_latents, length, dtype, device, generator=None):
    """Sample the initial Gaussian latents of shape ``(batch, channels, frames)``.

    Args:
        scheduler: Scheduler; if it has ``init_noise_sigma`` the noise is
            scaled by it.
        batch_size: Number of latent sequences to draw.
        num_channels_latents: Latent channel count.
        length: Number of latent frames. May be a float (e.g. seconds times
            frame rate); the fractional part is truncated.
        dtype: dtype of the returned tensor.
        device: Device of the returned tensor.
        generator: Optional ``torch.Generator`` (or list of them) forwarded to
            ``randn_tensor`` for reproducible sampling. Defaults to ``None``,
            i.e. the global RNG, as before.

    Returns:
        The latent tensor.
    """
    # Round away float noise before truncating: 4.1 * 50 evaluates to
    # 204.99999999999997, which a bare int() would turn into 204 frames
    # instead of 205. Genuine fractions (e.g. 206.5) still truncate.
    num_frames = int(round(float(length), 6))
    shape = (batch_size, num_channels_latents, num_frames)
    latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    if hasattr(scheduler, "init_noise_sigma"):
        latents = latents * scheduler.init_noise_sigma

    return latents
|
|
|
|
|
def _cfg_concat(uncond, cond, do_cfg):
    """Stack ``[uncond, cond]`` on dim 0 when classifier-free guidance is on, else return ``cond``."""
    return torch.cat([uncond, cond], dim=0) if do_cfg else cond


@torch.no_grad()
def denoise_process(visual_feats, text_feats, audio_len_in_s, model_dict, cfg, guidance_scale=4.5, num_inference_steps=50, batch_size=1):
    """Run the flow-matching denoising loop and decode the result to audio.

    Args:
        visual_feats: Object exposing ``siglip2_feat`` and ``syncformer_feat``
            tensors; presumably each is ``(1, seq, dim)`` since they are
            repeated with ``.repeat(batch_size, 1, 1)`` -- TODO confirm.
        text_feats: Object exposing ``text_feat`` and ``uncond_text_feat``
            tensors, repeated the same way.
        audio_len_in_s: Requested audio duration in seconds.
        model_dict: Mapping from ``load_model`` (needs ``foley_model``,
            ``dac_model`` and ``device``).
        cfg: Config providing ``diffusion_config`` scheduler settings and
            ``model_config.model_kwargs`` latent dim / frame rate.
        guidance_scale: Classifier-free guidance weight; guidance is applied
            only when it is not ``None`` and greater than 1.0.
        num_inference_steps: Number of scheduler steps.
        batch_size: Number of samples generated in parallel.

    Returns:
        ``(audio, sample_rate)``: float32 CPU tensor from the DAC decoder,
        trimmed along its last (time) axis to ``audio_len_in_s`` seconds, and
        the DAC sample rate.
    """
    foley_model = model_dict.foley_model
    target_dtype = foley_model.dtype
    autocast_enabled = target_dtype != torch.float32
    # torch.autocast needs `device.type`; accept a str or a torch.device here.
    device = torch.device(model_dict.device)
    # One flag for every CFG branch (a None scale previously crashed the
    # latent-duplication line, which lacked the `is not None` guard).
    do_cfg = guidance_scale is not None and guidance_scale > 1.0

    scheduler = FlowMatchDiscreteScheduler(
        shift=cfg.diffusion_config.sample_flow_shift,
        reverse=cfg.diffusion_config.flow_reverse,
        solver=cfg.diffusion_config.flow_solver,
        use_flux_shift=cfg.diffusion_config.sample_use_flux_shift,
        flux_base_shift=cfg.diffusion_config.flux_base_shift,
        flux_max_shift=cfg.diffusion_config.flux_max_shift,
    )

    timesteps, _ = retrieve_timesteps(scheduler, num_inference_steps, device)

    latents = prepare_latents(
        scheduler,
        batch_size=batch_size,
        num_channels_latents=cfg.model_config.model_kwargs.audio_vae_latent_dim,
        length=audio_len_in_s * cfg.model_config.model_kwargs.audio_frame_rate,
        dtype=target_dtype,
        device=device,
    )

    # Conditioning inputs do not depend on the timestep; build them once.
    siglip2_feat = visual_feats.siglip2_feat.repeat(batch_size, 1, 1)
    uncond_siglip2_feat = foley_model.get_empty_clip_sequence(
        bs=batch_size, len=siglip2_feat.shape[1]
    ).to(device)
    siglip2_feat_input = _cfg_concat(uncond_siglip2_feat, siglip2_feat, do_cfg)

    syncformer_feat = visual_feats.syncformer_feat.repeat(batch_size, 1, 1)
    uncond_syncformer_feat = foley_model.get_empty_sync_sequence(
        bs=batch_size, len=syncformer_feat.shape[1]
    ).to(device)
    syncformer_feat_input = _cfg_concat(uncond_syncformer_feat, syncformer_feat, do_cfg)

    text_feat_repeated = text_feats.text_feat.repeat(batch_size, 1, 1)
    uncond_text_feat_repeated = text_feats.uncond_text_feat.repeat(batch_size, 1, 1)
    text_feat_input = _cfg_concat(uncond_text_feat_repeated, text_feat_repeated, do_cfg)

    for t in tqdm(timesteps, total=len(timesteps), desc="Denoising steps"):
        # Unconditional half first, matching the [uncond, cond] feature order.
        latent_input = torch.cat([latents] * 2) if do_cfg else latents
        latent_input = scheduler.scale_model_input(latent_input, t)
        t_expand = t.repeat(latent_input.shape[0])

        with torch.autocast(device_type=device.type, enabled=autocast_enabled, dtype=target_dtype):
            noise_pred = foley_model(
                x=latent_input,
                t=t_expand,
                cond=text_feat_input,
                clip_feat=siglip2_feat_input,
                sync_feat=syncformer_feat_input,
                return_dict=True,
            )["x"]

        # Guidance arithmetic and the scheduler step run in float32.
        noise_pred = noise_pred.to(dtype=torch.float32)
        if do_cfg:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

    dac_model = model_dict.dac_model
    sample_rate = dac_model.sample_rate
    audio = dac_model.decode(latents).float().cpu()

    # Trim along the LAST axis: `audio[:, :n]` hit axis 1, which is the channel
    # axis if decode returns (B, C, T), making the trim a no-op. `...` is
    # identical for 2-D output. Round first to absorb float noise in s * rate.
    num_samples = int(round(float(audio_len_in_s * sample_rate), 6))
    audio = audio[..., :num_samples]

    return audio, sample_rate
|
|
|
|
|
|