import math
import time
import torch
import random
from loguru import logger
from einops import rearrange
from hymm_sp.diffusion import load_diffusion_pipeline
from hymm_sp.helpers import get_nd_rotary_pos_embed_new
from hymm_sp.inference import Inference
from hymm_sp.diffusion.schedulers import FlowMatchDiscreteScheduler
from hymm_sp.data_kits.audio_preprocessor import encode_audio, get_facemask

def align_to(value, alignment):
    """Round `value` up to the nearest multiple of `alignment`."""
    return int(math.ceil(value / alignment) * alignment)
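# For example, align_to(270, 16) rounds 270 up to 272, the next multiple of 16.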

class HunyuanVideoSampler(Inference):
    def __init__(self, args, vae, vae_kwargs, text_encoder, model, text_encoder_2=None, pipeline=None,
                 device=0, logger=None):
        super().__init__(args, vae, vae_kwargs, text_encoder, model, text_encoder_2=text_encoder_2,
                         pipeline=pipeline,  device=device, logger=logger)
        
        self.args = args
        self.pipeline = load_diffusion_pipeline(
            args, 0, self.vae, self.text_encoder, self.text_encoder_2, self.model,
            device=self.device)
        print('Loaded hunyuan model successfully.')

    def get_rotary_pos_embed(self, video_length, height, width, concat_dict={}):
        """Compute the rotary position embedding (cos, sin) tables for a latent video of the given size."""
        target_ndim = 3
        ndim = 5 - 2  # latents are (B, C, F, H, W); RoPE covers the last three (F, H, W) axes
        if '884' in self.args.vae:
            # the '884' VAE variant compresses time by 4x and space by 8x
            latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
        else:
            latents_size = [video_length, height // 8, width // 8]

        if isinstance(self.model.patch_size, int):
            assert all(s % self.model.patch_size == 0 for s in latents_size), \
                f"Latent size (last {ndim} dimensions) should be divisible by patch size ({self.model.patch_size}), " \
                f"but got {latents_size}."
            rope_sizes = [s // self.model.patch_size for s in latents_size]
        elif isinstance(self.model.patch_size, list):
            assert all(s % self.model.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), \
                f"Latent size (last {ndim} dimensions) should be divisible by patch size ({self.model.patch_size}), " \
                f"but got {latents_size}."
            rope_sizes = [s // self.model.patch_size[idx] for idx, s in enumerate(latents_size)]
        else:
            raise TypeError(f"Unsupported patch_size type: {type(self.model.patch_size)}")

        if len(rope_sizes) != target_ndim:
            rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes  # time axis
        head_dim = self.model.hidden_size // self.model.num_heads
        rope_dim_list = self.model.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal the head_dim of the attention layer"
        freqs_cos, freqs_sin = get_nd_rotary_pos_embed_new(rope_dim_list,
                                                           rope_sizes,
                                                           theta=self.args.rope_theta,
                                                           use_real=True,
                                                           theta_rescale_factor=1,
                                                           concat_dict=concat_dict)
        return freqs_cos, freqs_sin
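
    # Example for get_rotary_pos_embed (assuming the '884' VAE and a patch_size of [1, 2, 2];
    # check the model config): a 129-frame, 704x1280 input yields latents_size [33, 88, 160]
    # and rope_sizes [33, 44, 80].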

    @torch.no_grad()
    def predict(self, 
                args, batch, wav2vec, feature_extractor, align_instance,
                **kwargs):
        """
        Predict the image from the given text.

        Args:
            prompt (str or List[str]): The input text.
            kwargs:
                size (int): The (height, width) of the output image/video. Default is (256, 256).
                video_length (int): The frame number of the output video. Default is 1.
                seed (int or List[str]): The random seed for the generation. Default is a random integer.
                negative_prompt (str or List[str]): The negative text prompt. Default is an empty string.
                infer_steps (int): The number of inference steps. Default is 100.
                guidance_scale (float): The guidance scale for the generation. Default is 6.0.
                num_videos_per_prompt (int): The number of videos per prompt. Default is 1.    
                verbose (int): 0 for no log, 1 for all log, 2 for fewer log. Default is 1.
                output_type (str): The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
                    Default is 'pil'.
        """
        
        out_dict = dict()

        prompt = batch['text_prompt'][0]
        image_path = str(batch["image_path"][0])
        audio_path = str(batch["audio_path"][0])
        neg_prompt = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, Lens changes"
        # videoid = batch['videoid'][0]
        fps = batch["fps"].to(self.device)
        audio_prompts = batch["audio_prompts"].to(self.device)
        weight_dtype = audio_prompts.dtype
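
        # ========== Encode and pad audio features ==========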

        audio_prompts = [encode_audio(wav2vec, audio_feat.to(dtype=wav2vec.dtype), fps.item(),
                                      num_frames=batch["audio_len"][0]) for audio_feat in audio_prompts]
        audio_prompts = torch.cat(audio_prompts, dim=0).to(device=self.device, dtype=weight_dtype)
        # Zero-pad along the frame axis: clips shorter than 129 frames are padded up to 129,
        # longer clips get 5 extra padding frames appended.
        if audio_prompts.shape[1] <= 129:
            audio_prompts = torch.cat([audio_prompts,
                                       torch.zeros_like(audio_prompts[:, :1]).repeat(1, 129 - audio_prompts.shape[1], 1, 1, 1)], dim=1)
        else:
            audio_prompts = torch.cat([audio_prompts,
                                       torch.zeros_like(audio_prompts[:, :1]).repeat(1, 5, 1, 1, 1)], dim=1)
        
        wav2vec.to("cpu")
        torch.cuda.empty_cache()

        uncond_audio_prompts = torch.zeros_like(audio_prompts[:,:129])
        motion_exp = batch["motion_bucket_id_exps"].to(self.device)
        motion_pose = batch["motion_bucket_id_heads"].to(self.device)
        
        pixel_value_ref = batch['pixel_value_ref'].to(self.device)  # (b, f, c, h, w), values in [0, 255]
        face_masks = get_facemask(pixel_value_ref.clone(), align_instance, area=3.0) 

        pixel_value_ref = pixel_value_ref.clone().repeat(1, 129, 1, 1, 1)  # tile the reference frame across 129 frames
        uncond_pixel_value_ref = torch.zeros_like(pixel_value_ref)
        pixel_value_ref = pixel_value_ref / 127.5 - 1.              # [0, 255] -> [-1, 1]
        uncond_pixel_value_ref = uncond_pixel_value_ref * 2 - 1     # zeros -> -1 (black unconditional reference)

        pixel_value_ref_for_vae = rearrange(pixel_value_ref, "b f c h w -> b c f h w")
        uncond_pixel_value_ref_for_vae = rearrange(uncond_pixel_value_ref, "b f c h w -> b c f h w")

        pixel_value_llava = batch["pixel_value_ref_llava"].to(self.device)
        pixel_value_llava = rearrange(pixel_value_llava, "b f c h w -> (b f) c h w")
        uncond_pixel_value_llava = pixel_value_llava.clone()
    
        # ========== Encode reference latents ==========
        vae_dtype = self.vae.dtype
        with torch.autocast(device_type="cuda", dtype=vae_dtype, enabled=vae_dtype != torch.float32):

            if args.cpu_offload:
                self.vae.to('cuda')

            self.vae.enable_tiling()
            ref_latents = self.vae.encode(pixel_value_ref_for_vae.clone()).latent_dist.sample()
            uncond_ref_latents = self.vae.encode(uncond_pixel_value_ref_for_vae).latent_dist.sample()
            self.vae.disable_tiling()
            if hasattr(self.vae.config, 'shift_factor') and self.vae.config.shift_factor:
                ref_latents.sub_(self.vae.config.shift_factor).mul_(self.vae.config.scaling_factor)
                uncond_ref_latents.sub_(self.vae.config.shift_factor).mul_(self.vae.config.scaling_factor)
            else:
                ref_latents.mul_(self.vae.config.scaling_factor)
                uncond_ref_latents.mul_(self.vae.config.scaling_factor)
            
            if args.cpu_offload:
                self.vae.to('cpu')
                torch.cuda.empty_cache()
                
        # Resize the face mask to the latent spatial resolution
        face_masks = torch.nn.functional.interpolate(face_masks.float().squeeze(2),
                                                     (ref_latents.shape[-2], ref_latents.shape[-1]),
                                                     mode="bilinear").unsqueeze(2).to(dtype=ref_latents.dtype)


        # Derive the generation size from the reference image, aligned to multiples of 16
        size = (batch['pixel_value_ref'].shape[-2], batch['pixel_value_ref'].shape[-1])
        target_length = 129
        target_height = align_to(size[0], 16)
        target_width = align_to(size[1], 16)
        concat_dict = {'mode': 'timecat', 'bias': -1}
        # concat_dict = {}
        freqs_cos, freqs_sin = self.get_rotary_pos_embed(
            target_length,
            target_height,
            target_width,
            concat_dict)
        n_tokens = freqs_cos.shape[0]

        generator = torch.Generator(device=self.device).manual_seed(args.seed)

        debug_str = f"""
                    prompt: {prompt}
                image_path: {image_path}
                audio_path: {audio_path}
           negative_prompt: {neg_prompt}
                      seed: {args.seed}
                       fps: {fps.item()}
               infer_steps: {args.infer_steps}
             target_height: {target_height}
              target_width: {target_width}
             target_length: {target_length}
            guidance_scale: {args.cfg_scale}
            """
        self.logger.info(debug_str)
        pipeline_kwargs = {
            "cpu_offload": args.cpu_offload
        }
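
        # ========== Run the diffusion pipeline ==========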
        start_time = time.time()
        samples = self.pipeline(prompt=prompt,                                
                                height=target_height,
                                width=target_width,
                                frame=target_length,
                                num_inference_steps=args.infer_steps,
                                guidance_scale=args.cfg_scale,                      # cfg scale
                         
                                negative_prompt=neg_prompt,
                                num_images_per_prompt=args.num_images,
                                generator=generator,
                                prompt_embeds=None,

                                ref_latents=ref_latents,                            # [1, 16, 1, h//8, w//8]
                                uncond_ref_latents=uncond_ref_latents,
                                pixel_value_llava=pixel_value_llava,                # [1, 3, 336, 336]
                                uncond_pixel_value_llava=uncond_pixel_value_llava,
                                face_masks=face_masks,                              # [b f h w]
                                audio_prompts=audio_prompts, 
                                uncond_audio_prompts=uncond_audio_prompts, 
                                motion_exp=motion_exp, 
                                motion_pose=motion_pose, 
                                fps=fps, 
                                
                                num_videos_per_prompt=1,
                                attention_mask=None,
                                negative_prompt_embeds=None,
                                negative_attention_mask=None,
                                output_type="pil",
                                freqs_cis=(freqs_cos, freqs_sin),
                                n_tokens=n_tokens,
                                data_type='video',
                                is_progress_bar=True,
                                vae_ver=self.args.vae,
                                enable_tiling=self.args.vae_tiling,
                                **pipeline_kwargs
                                )[0]
        if samples is None:
            return None
        out_dict['samples'] = samples
        gen_time = time.time() - start_time
        logger.info(f"Success, time: {gen_time}")
        
        # Move wav2vec back onto the GPU for subsequent calls
        wav2vec.to(self.device)
        
        return out_dict
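

# Minimal usage sketch (assumed driver code; `load_models`, `make_batches`, and `save_video`
# are hypothetical helpers, not part of this module's confirmed API):
#
#   models = load_models(args)  # vae, vae_kwargs, text encoders, transformer, wav2vec, etc.
#   sampler = HunyuanVideoSampler(args, models.vae, models.vae_kwargs, models.text_encoder,
#                                 models.model, text_encoder_2=models.text_encoder_2)
#   for batch in make_batches(args):
#       out = sampler.predict(args, batch, models.wav2vec, models.feature_extractor, models.align_instance)
#       if out is not None:
#           save_video(out['samples'], fps=batch['fps'].item())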