File size: 9,666 Bytes
9867d34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
import torch
import os
from loguru import logger
from torchvision import transforms
from torchvision.transforms import v2
from diffusers.utils.torch_utils import randn_tensor
from transformers import AutoTokenizer, AutoModel, ClapTextModelWithProjection
from ..models.dac_vae.model.dac import DAC
from ..models.synchformer import Synchformer
from ..models.hifi_foley import HunyuanVideoFoley
from .config_utils import load_yaml, AttributeDict
from .schedulers import FlowMatchDiscreteScheduler
from tqdm import tqdm

def load_state_dict(model, model_path):
    """Load a checkpoint into ``model`` non-strictly and log key mismatches.

    Args:
        model: ``torch.nn.Module`` whose parameters are populated in place.
        model_path: Path of the checkpoint file read with ``torch.load``.

    Returns:
        The same ``model`` object, now holding the loaded weights.
    """
    logger.info(f"Loading model state dict from: {model_path}")
    # The identity map_location keeps every storage where torch.load creates it (CPU);
    # the caller is responsible for moving the model afterwards.
    # NOTE(review): weights_only=False unpickles arbitrary Python objects, so
    # only point this at trusted checkpoint files.
    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=False)

    # strict=False: mismatched keys are reported below instead of raising.
    missing, unexpected = model.load_state_dict(checkpoint, strict=False)

    for label, keys in (("missing", missing), ("unexpected", unexpected)):
        if not keys:
            logger.info(f"No {label} keys found")
            continue
        logger.warning(f"{label.capitalize()} keys in state dict ({len(keys)} keys):")
        for key in keys:
            logger.warning(f"  - {key}")

    logger.info("Model state dict loaded successfully")
    return model

def load_model(model_path, config_path, device):
    """Build every model and preprocessing pipeline needed for inference.

    Local checkpoints are read from ``model_path``. The SigLIP2 and CLAP
    encoders are obtained by name through ``from_pretrained``.

    Args:
        model_path: Directory containing ``hunyuanvideo_foley.pth``,
            ``vae_128d_48k.pth`` and ``synchformer_state_dict.pth``.
        config_path: YAML file parsed with ``load_yaml``.
        device: Target device handed to every ``.to(...)`` call.

    Returns:
        A ``(models, cfg)`` tuple:
            - ``models`` is an ``AttributeDict`` holding ``foley_model``,
              ``dac_model``, ``siglip2_preprocess``, ``siglip2_model``,
              ``clap_tokenizer``, ``clap_model``, ``syncformer_preprocess``,
              ``syncformer_model`` and ``device``.
            - ``cfg`` is the parsed configuration.
    """
    for banner in (
        "Starting model loading process...",
        f"Configuration file: {config_path}",
        f"Model weights dir: {model_path}",
        f"Target device: {device}",
    ):
        logger.info(banner)

    cfg = load_yaml(config_path)
    logger.info("Configuration loaded successfully")

    # Main generator: instantiated and kept in bfloat16.
    logger.info("Loading HunyuanVideoFoley main model...")
    foley = HunyuanVideoFoley(cfg, dtype=torch.bfloat16, device=device)
    foley = foley.to(device=device, dtype=torch.bfloat16)
    foley = load_state_dict(foley, os.path.join(model_path, "hunyuanvideo_foley.pth"))
    foley.eval()
    logger.info("HunyuanVideoFoley model loaded and set to evaluation mode")

    # DAC audio VAE: frozen (no gradients) and in eval mode.
    vae_ckpt = os.path.join(model_path, "vae_128d_48k.pth")
    logger.info(f"Loading DAC VAE model from: {vae_ckpt}")
    vae = DAC.load(vae_ckpt).to(device)
    vae.requires_grad_(False)
    vae.eval()
    logger.info("DAC VAE model loaded successfully")

    # SigLIP2 visual encoder: 512x512 input, ToTensor output rescaled from [0, 1] to [-1, 1].
    logger.info("Loading SigLIP2 visual encoder...")
    siglip_tf = transforms.Compose(
        [
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
    )
    siglip = AutoModel.from_pretrained("google/siglip2-base-patch16-512").to(device).eval()
    logger.info("SigLIP2 model and preprocessing pipeline loaded successfully")

    # CLAP text encoder and its tokenizer share the same hub identifier.
    logger.info("Loading CLAP text encoder...")
    clap_tok = AutoTokenizer.from_pretrained("laion/larger_clap_general")
    clap = ClapTextModelWithProjection.from_pretrained("laion/larger_clap_general").to(device)
    logger.info("CLAP tokenizer and model loaded successfully")

    # Synchformer: 224x224 bicubic resize + center crop, float in [-1, 1].
    sync_ckpt = os.path.join(model_path, "synchformer_state_dict.pth")
    logger.info(f"Loading Synchformer model from: {sync_ckpt}")
    sync_tf = v2.Compose(
        [
            v2.Resize(224, interpolation=v2.InterpolationMode.BICUBIC),
            v2.CenterCrop(224),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
    )
    sync = Synchformer()
    # NOTE(review): weights_only=False unpickles arbitrary objects; trusted files only.
    sync.load_state_dict(torch.load(sync_ckpt, weights_only=False, map_location="cpu"))
    sync = sync.to(device).eval()
    logger.info("Synchformer model and preprocessing pipeline loaded successfully")

    logger.info("Creating model dictionary with attribute access...")
    models = AttributeDict(
        dict(
            foley_model=foley,
            dac_model=vae,
            siglip2_preprocess=siglip_tf,
            siglip2_model=siglip,
            clap_tokenizer=clap_tok,
            clap_model=clap,
            syncformer_preprocess=sync_tf,
            syncformer_model=sync,
            device=device,
        )
    )

    logger.info("All models loaded successfully!")
    logger.info("Available model components:")
    for component in models.keys():
        logger.info(f"  - {component}")
    logger.info("Models can be accessed via attribute notation (e.g., models.foley_model)")

    return models, cfg

def retrieve_timesteps(
    scheduler,
    num_inference_steps,
    device,
    **kwargs,
):
    """Configure ``scheduler`` for sampling and return its timestep schedule.

    Args:
        scheduler: Object exposing ``set_timesteps`` and ``timesteps``.
        num_inference_steps: Number of sampling steps requested.
        device: Forwarded to ``set_timesteps`` as ``device=``.
        **kwargs: Extra keyword arguments forwarded to ``set_timesteps``.

    Returns:
        ``(scheduler.timesteps, num_inference_steps)``; the step count is
        returned unchanged.
    """
    scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
    return scheduler.timesteps, num_inference_steps


def prepare_latents(scheduler, batch_size, num_channels_latents, length, dtype, device):
    """Draw the initial Gaussian latents for sampling.

    Args:
        scheduler: Sampling scheduler; if it defines ``init_noise_sigma``, the
            noise is multiplied by it.
        batch_size: Leading dimension of the latent tensor.
        num_channels_latents: Channel dimension of the latent tensor.
        length: Temporal length; truncated with ``int()`` (may be a float).
        dtype: dtype of the returned tensor.
        device: Device the noise is drawn on.

    Returns:
        Tensor of shape ``(batch_size, num_channels_latents, int(length))``.
    """
    noise = randn_tensor(
        (batch_size, num_channels_latents, int(length)),
        device=device,
        dtype=dtype,
    )

    # Schedulers without init_noise_sigma (e.g. FlowMatchEulerDiscreteScheduler)
    # take the unit-variance noise as-is.
    if not hasattr(scheduler, "init_noise_sigma"):
        return noise
    return noise * scheduler.init_noise_sigma


@torch.no_grad()
def denoise_process(visual_feats, text_feats, audio_len_in_s, model_dict, cfg, guidance_scale=4.5, num_inference_steps=50, batch_size=1):
    """Sample audio latents with the flow-matching scheduler and decode them.

    Args:
        visual_feats: Object exposing ``siglip2_feat`` and ``syncformer_feat``.
            Each is ``.repeat(batch_size, 1, 1)``-ed, so presumably
            (1, seq, dim); TODO confirm. They are not moved to the device
            here and must already live there.
        text_feats: Object exposing ``text_feat`` and, when guidance is active,
            ``uncond_text_feat``; both are repeated the same way.
        audio_len_in_s: Requested audio duration in seconds.
        model_dict: Components from ``load_model``; reads ``foley_model``,
            ``dac_model`` and ``device``.
        cfg: Parsed configuration; reads ``diffusion_config`` and
            ``model_config.model_kwargs``.
        guidance_scale: Classifier-free guidance weight. Guidance is applied
            only when it is not ``None`` and greater than 1.0.
        num_inference_steps: Number of scheduler steps.
        batch_size: Number of samples generated in parallel.

    Returns:
        ``(audio, sample_rate)``. ``audio`` is the decoded waveform as a
        float32 CPU tensor, truncated along its last axis to
        ``int(audio_len_in_s * sample_rate)`` samples. ``sample_rate`` is the
        DAC sample rate.
    """
    target_dtype = model_dict.foley_model.dtype
    autocast_enabled = target_dtype != torch.float32
    device = model_dict.device
    # torch.autocast wants a device-type string; accept a str or torch.device.
    device_type = torch.device(device).type
    # Single source of truth for every CFG decision; a None scale disables
    # guidance instead of raising on the comparison.
    do_cfg = guidance_scale is not None and guidance_scale > 1.0

    scheduler = FlowMatchDiscreteScheduler(
        shift=cfg.diffusion_config.sample_flow_shift,
        reverse=cfg.diffusion_config.flow_reverse,
        solver=cfg.diffusion_config.flow_solver,
        use_flux_shift=cfg.diffusion_config.sample_use_flux_shift,
        flux_base_shift=cfg.diffusion_config.flux_base_shift,
        flux_max_shift=cfg.diffusion_config.flux_max_shift,
    )

    timesteps, num_inference_steps = retrieve_timesteps(
        scheduler,
        num_inference_steps,
        device,
    )

    latents = prepare_latents(
        scheduler,
        batch_size=batch_size,
        num_channels_latents=cfg.model_config.model_kwargs.audio_vae_latent_dim,
        length=audio_len_in_s * cfg.model_config.model_kwargs.audio_frame_rate,
        dtype=target_dtype,
        device=device,
    )

    # Conditioning does not depend on the timestep, so build it once instead of
    # on every denoising step.
    siglip2_feat_input = visual_feats.siglip2_feat.repeat(batch_size, 1, 1)
    syncformer_feat_input = visual_feats.syncformer_feat.repeat(batch_size, 1, 1)
    text_feat_input = text_feats.text_feat.repeat(batch_size, 1, 1)
    if do_cfg:
        uncond_siglip2_feat = model_dict.foley_model.get_empty_clip_sequence(
            bs=batch_size, len=siglip2_feat_input.shape[1]
        ).to(device)
        uncond_syncformer_feat = model_dict.foley_model.get_empty_sync_sequence(
            bs=batch_size, len=syncformer_feat_input.shape[1]
        ).to(device)
        uncond_text_feat = text_feats.uncond_text_feat.repeat(batch_size, 1, 1)
        # Unconditional half first, matching the order of noise_pred.chunk(2) below.
        siglip2_feat_input = torch.cat([uncond_siglip2_feat, siglip2_feat_input], dim=0)
        syncformer_feat_input = torch.cat([uncond_syncformer_feat, syncformer_feat_input], dim=0)
        text_feat_input = torch.cat([uncond_text_feat, text_feat_input], dim=0)

    # Denoise loop
    for t in tqdm(timesteps, total=len(timesteps), desc="Denoising steps"):
        latent_input = torch.cat([latents] * 2) if do_cfg else latents
        latent_input = scheduler.scale_model_input(latent_input, t)

        t_expand = t.repeat(latent_input.shape[0])

        with torch.autocast(device_type=device_type, enabled=autocast_enabled, dtype=target_dtype):
            # Predict the velocity / noise residual.
            noise_pred = model_dict.foley_model(
                x=latent_input,
                t=t_expand,
                cond=text_feat_input,
                clip_feat=siglip2_feat_input,
                sync_feat=syncformer_feat_input,
                return_dict=True,
            )["x"]

        noise_pred = noise_pred.to(dtype=torch.float32)

        if do_cfg:
            # Classifier-free guidance.
            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

        # Compute the previous noisy sample x_t -> x_t-1
        latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

    # Decode latents to a waveform (the whole function already runs under no_grad).
    audio = model_dict.dac_model.decode(latents).float().cpu()
    sample_rate = model_dict.dac_model.sample_rate

    # Trim along the last (time) axis. Assumes decode returns (batch, channels,
    # samples) or (batch, samples); TODO confirm. With a channel axis,
    # slicing [:, :n] would have cut channels instead of samples.
    audio = audio[..., : int(audio_len_in_s * sample_rate)]

    return audio, sample_rate