Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import os | |
from loguru import logger | |
from torchvision import transforms | |
from torchvision.transforms import v2 | |
from diffusers.utils.torch_utils import randn_tensor | |
from transformers import AutoTokenizer, AutoModel, ClapTextModelWithProjection | |
from ..models.dac_vae.model.dac import DAC | |
from ..models.synchformer import Synchformer | |
from ..models.hifi_foley import HunyuanVideoFoley | |
from .config_utils import load_yaml, AttributeDict | |
from .schedulers import FlowMatchDiscreteScheduler | |
from tqdm import tqdm | |
def load_state_dict(model, model_path):
    """Load a checkpoint into ``model`` non-strictly and report key mismatches.

    Keys missing from the checkpoint and keys the model does not expect are
    logged as warnings (one line per key) instead of raising, because the
    load is performed with ``strict=False``.

    Args:
        model: Module whose ``load_state_dict`` receives the checkpoint.
        model_path: Path of the checkpoint file passed to ``torch.load``.

    Returns:
        The same ``model`` object, to allow call chaining.
    """
    logger.info(f"Loading model state dict from: {model_path}")
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects; only point this at checkpoints from a trusted source.
    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=False)
    missing, unexpected = model.load_state_dict(checkpoint, strict=False)

    # Report both mismatch categories the same way, missing first.
    reports = (
        (missing, "Missing", "No missing keys found"),
        (unexpected, "Unexpected", "No unexpected keys found"),
    )
    for keys, category, clean_message in reports:
        if not keys:
            logger.info(clean_message)
            continue
        logger.warning(f"{category} keys in state dict ({len(keys)} keys):")
        for key in keys:
            logger.warning(f" - {key}")

    logger.info("Model state dict loaded successfully")
    return model
def _load_foley_model(cfg, model_path, device):
    """Build HunyuanVideoFoley in bfloat16 on ``device``, load its weights, and set eval mode."""
    logger.info("Loading HunyuanVideoFoley main model...")
    foley_model = HunyuanVideoFoley(cfg, dtype=torch.bfloat16, device=device).to(device=device, dtype=torch.bfloat16)
    foley_model = load_state_dict(foley_model, os.path.join(model_path, "hunyuanvideo_foley.pth"))
    foley_model.eval()
    logger.info("HunyuanVideoFoley model loaded and set to evaluation mode")
    return foley_model


def _load_dac_model(model_path, device):
    """Load the DAC VAE from ``vae_128d_48k.pth``, freeze its parameters, and set eval mode."""
    dac_path = os.path.join(model_path, "vae_128d_48k.pth")
    logger.info(f"Loading DAC VAE model from: {dac_path}")
    dac_model = DAC.load(dac_path)
    dac_model = dac_model.to(device)
    dac_model.requires_grad_(False)
    dac_model.eval()
    logger.info("DAC VAE model loaded successfully")
    return dac_model


def _load_siglip2(device):
    """Return ``(preprocess, model)`` for the SigLIP2 visual encoder."""
    logger.info("Loading SigLIP2 visual encoder...")
    preprocess = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    model = AutoModel.from_pretrained("google/siglip2-base-patch16-512").to(device).eval()
    logger.info("SigLIP2 model and preprocessing pipeline loaded successfully")
    return preprocess, model


def _load_clap(device):
    """Return ``(tokenizer, model)`` for the CLAP text encoder."""
    logger.info("Loading CLAP text encoder...")
    tokenizer = AutoTokenizer.from_pretrained("laion/larger_clap_general")
    # Explicit .eval() keeps CLAP consistent with the other encoders loaded here.
    model = ClapTextModelWithProjection.from_pretrained("laion/larger_clap_general").to(device).eval()
    logger.info("CLAP tokenizer and model loaded successfully")
    return tokenizer, model


def _load_syncformer(model_path, device):
    """Return ``(preprocess, model)`` for Synchformer loaded from ``synchformer_state_dict.pth``."""
    syncformer_path = os.path.join(model_path, "synchformer_state_dict.pth")
    logger.info(f"Loading Synchformer model from: {syncformer_path}")
    preprocess = v2.Compose(
        [
            v2.Resize(224, interpolation=v2.InterpolationMode.BICUBIC),
            v2.CenterCrop(224),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
    )
    model = Synchformer()
    # NOTE(review): weights_only=False allows arbitrary unpickling; only load
    # checkpoints from a trusted source.
    model.load_state_dict(torch.load(syncformer_path, weights_only=False, map_location="cpu"))
    model = model.to(device).eval()
    logger.info("Synchformer model and preprocessing pipeline loaded successfully")
    return preprocess, model


def load_model(model_path, config_path, device):
    """Load every model component needed for Foley inference.

    Args:
        model_path: Directory containing ``hunyuanvideo_foley.pth``,
            ``vae_128d_48k.pth`` and ``synchformer_state_dict.pth``.
        config_path: YAML config file read with ``load_yaml``.
        device: Device every model is moved to.

    Returns:
        ``(model_dict, cfg)``. ``model_dict`` is an ``AttributeDict`` with keys
        ``foley_model``, ``dac_model``, ``siglip2_preprocess``, ``siglip2_model``,
        ``clap_tokenizer``, ``clap_model``, ``syncformer_preprocess``,
        ``syncformer_model`` and ``device``. ``cfg`` is the parsed config.
    """
    logger.info("Starting model loading process...")
    logger.info(f"Configuration file: {config_path}")
    logger.info(f"Model weights dir: {model_path}")
    logger.info(f"Target device: {device}")

    cfg = load_yaml(config_path)
    logger.info("Configuration loaded successfully")

    # Each component is loaded independently; order is kept for stable logs.
    foley_model = _load_foley_model(cfg, model_path, device)
    dac_model = _load_dac_model(model_path, device)
    siglip2_preprocess, siglip2_model = _load_siglip2(device)
    clap_tokenizer, clap_model = _load_clap(device)
    syncformer_preprocess, syncformer_model = _load_syncformer(model_path, device)

    logger.info("Creating model dictionary with attribute access...")
    model_dict = AttributeDict({
        'foley_model': foley_model,
        'dac_model': dac_model,
        'siglip2_preprocess': siglip2_preprocess,
        'siglip2_model': siglip2_model,
        'clap_tokenizer': clap_tokenizer,
        'clap_model': clap_model,
        'syncformer_preprocess': syncformer_preprocess,
        'syncformer_model': syncformer_model,
        'device': device,
    })

    logger.info("All models loaded successfully!")
    logger.info("Available model components:")
    for key in model_dict.keys():
        logger.info(f" - {key}")
    logger.info("Models can be accessed via attribute notation (e.g., models.foley_model)")
    return model_dict, cfg
def retrieve_timesteps(
    scheduler,
    num_inference_steps,
    device,
    **kwargs,
):
    """Configure ``scheduler`` for sampling and return its timestep schedule.

    Args:
        scheduler: Object exposing ``set_timesteps`` and ``timesteps``.
        num_inference_steps: Number of sampling steps requested.
        device: Forwarded to ``scheduler.set_timesteps``.
        **kwargs: Extra keyword arguments forwarded to ``set_timesteps``.

    Returns:
        ``(scheduler.timesteps, num_inference_steps)``; the step count is
        passed through unchanged.
    """
    scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
    return scheduler.timesteps, num_inference_steps
def prepare_latents(scheduler, batch_size, num_channels_latents, length, dtype, device):
    """Sample the initial Gaussian latents of shape ``(batch, channels, frames)``.

    Args:
        scheduler: Sampling scheduler; if it defines ``init_noise_sigma`` the
            noise is scaled by it.
        batch_size: Number of latent sequences to draw.
        num_channels_latents: Latent channel count.
        length: Number of latent frames, possibly a float (e.g. seconds times
            frame rate); it is floored to an int.
        dtype: dtype of the returned tensor.
        device: Device of the returned tensor.

    Returns:
        The (optionally scaled) random latent tensor.
    """
    # Products such as 4.1 * 50 evaluate to 204.99999999999997, which a plain
    # int() truncates to 204. Rounding to 6 decimals first absorbs that
    # floating-point noise while still flooring real fractions (416.5 -> 416).
    num_frames = int(round(float(length), 6))
    shape = (batch_size, num_channels_latents, num_frames)
    latents = randn_tensor(shape, device=device, dtype=dtype)
    # Check existence to make it compatible with FlowMatchEulerDiscreteScheduler
    if hasattr(scheduler, "init_noise_sigma"):
        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * scheduler.init_noise_sigma
    return latents
@torch.no_grad()
def denoise_process(visual_feats, text_feats, audio_len_in_s, model_dict, cfg, guidance_scale=4.5, num_inference_steps=50, batch_size=1):
    """Run flow-matching sampling with the Foley model and decode to audio.

    The whole function runs under ``torch.no_grad()``. The loop feeds
    ``latents`` back into itself, so without it an autograd graph would
    accumulate across every denoising step.

    Args:
        visual_feats: Object with ``siglip2_feat`` and ``syncformer_feat``
            tensors; each is repeated ``batch_size`` times along dim 0.
        text_feats: Object with ``text_feat`` and ``uncond_text_feat`` tensors.
        audio_len_in_s: Requested audio duration in seconds.
        model_dict: Mapping from ``load_model`` (uses ``foley_model``,
            ``dac_model`` and ``device``).
        cfg: Config providing ``diffusion_config`` and
            ``model_config.model_kwargs``.
        guidance_scale: Classifier-free guidance weight. CFG is applied only
            when it is not ``None`` and greater than 1.0.
        num_inference_steps: Number of scheduler steps.
        batch_size: Number of samples generated per call.

    Returns:
        ``(audio, sample_rate)``. ``audio`` is a float32 CPU tensor trimmed
        along its last axis to ``int(audio_len_in_s * sample_rate)`` samples.
    """
    foley_model = model_dict.foley_model
    target_dtype = foley_model.dtype
    autocast_enabled = target_dtype != torch.float32
    # Accept either a torch.device or a device string such as "cuda".
    device = torch.device(model_dict.device)
    do_cfg = guidance_scale is not None and guidance_scale > 1.0

    diffusion_cfg = cfg.diffusion_config
    scheduler = FlowMatchDiscreteScheduler(
        shift=diffusion_cfg.sample_flow_shift,
        reverse=diffusion_cfg.flow_reverse,
        solver=diffusion_cfg.flow_solver,
        use_flux_shift=diffusion_cfg.sample_use_flux_shift,
        flux_base_shift=diffusion_cfg.flux_base_shift,
        flux_max_shift=diffusion_cfg.flux_max_shift,
    )
    timesteps, num_inference_steps = retrieve_timesteps(
        scheduler,
        num_inference_steps,
        device,
    )

    model_kwargs = cfg.model_config.model_kwargs
    latents = prepare_latents(
        scheduler,
        batch_size=batch_size,
        num_channels_latents=model_kwargs.audio_vae_latent_dim,
        length=audio_len_in_s * model_kwargs.audio_frame_rate,
        dtype=target_dtype,
        device=device,
    )

    # Conditioning inputs do not depend on the timestep: build them once.
    siglip2_feat = visual_feats.siglip2_feat.repeat(batch_size, 1, 1)
    syncformer_feat = visual_feats.syncformer_feat.repeat(batch_size, 1, 1)
    text_feat = text_feats.text_feat.repeat(batch_size, 1, 1)
    if do_cfg:
        # Unconditional batch first, so noise_pred.chunk(2) yields (uncond, cond).
        uncond_siglip2_feat = foley_model.get_empty_clip_sequence(
            bs=batch_size, len=siglip2_feat.shape[1]
        ).to(device)
        uncond_syncformer_feat = foley_model.get_empty_sync_sequence(
            bs=batch_size, len=syncformer_feat.shape[1]
        ).to(device)
        uncond_text_feat = text_feats.uncond_text_feat.repeat(batch_size, 1, 1)
        siglip2_feat_input = torch.cat([uncond_siglip2_feat, siglip2_feat], dim=0)
        syncformer_feat_input = torch.cat([uncond_syncformer_feat, syncformer_feat], dim=0)
        text_feat_input = torch.cat([uncond_text_feat, text_feat], dim=0)
    else:
        siglip2_feat_input = siglip2_feat
        syncformer_feat_input = syncformer_feat
        text_feat_input = text_feat

    # Denoise loop
    for t in tqdm(timesteps, total=len(timesteps), desc="Denoising steps"):
        latent_input = torch.cat([latents] * 2) if do_cfg else latents
        latent_input = scheduler.scale_model_input(latent_input, t)
        t_expand = t.repeat(latent_input.shape[0])

        with torch.autocast(device_type=device.type, enabled=autocast_enabled, dtype=target_dtype):
            noise_pred = foley_model(
                x=latent_input,
                t=t_expand,
                cond=text_feat_input,
                clip_feat=siglip2_feat_input,
                sync_feat=syncformer_feat_input,
                return_dict=True,
            )["x"]
        noise_pred = noise_pred.to(dtype=torch.float32)

        if do_cfg:
            # Perform classifier-free guidance
            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

        # Compute the previous noisy sample x_t -> x_t-1
        latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

    # Post-process the latents to audio
    dac_model = model_dict.dac_model
    sample_rate = dac_model.sample_rate
    audio = dac_model.decode(latents).float().cpu()
    # Trim the LAST axis (time). The previous `audio[:, :N]` sliced dim 1,
    # which is the channel axis for a (B, 1, T) decoder output.
    audio = audio[..., :int(audio_len_in_s * sample_rate)]
    return audio, sample_rate