import torch
import numpy as np
from PIL import Image
from einops import rearrange



def get_facemask(ref_image, align_instance, area=1.25):
    """Build a per-frame binary face mask from detected face boxes.

    ref_image: (b, f, c, h, w) tensor. `area` scales the detected box
    around its center before the box is rasterized into the mask.
    """
    bsz, f, c, h, w = ref_image.shape
    images = rearrange(ref_image, "b f c h w -> (b f) h w c").detach().cpu().numpy().astype(np.uint8)
    face_masks = []
    for image in images:
        image_pil = Image.fromarray(image).convert("RGB")
        # The aligner expects BGR input, so flip the channel order.
        _, _, bboxes_list = align_instance(np.array(image_pil)[:, :, [2, 1, 0]], maxface=True)
        try:
            bboxSrc = bboxes_list[0]
        except (IndexError, TypeError):
            # No face detected: fall back to a full-frame box.
            bboxSrc = [0, 0, w, h]
        x1, y1, ww, hh = bboxSrc
        x2, y2 = x1 + ww, y1 + hh
        # Enlarge the box by `area` around its center, clipped to the frame.
        ww, hh = (x2 - x1) * area, (y2 - y1) * area
        center = [(x2 + x1) // 2, (y2 + y1) // 2]
        x1 = max(center[0] - ww // 2, 0)
        y1 = max(center[1] - hh // 2, 0)
        x2 = min(center[0] + ww // 2, w)
        y2 = min(center[1] + hh // 2, h)

        # Rasterize the box into a single-channel {0, 1} mask.
        face_mask = np.zeros_like(np.array(image_pil))
        face_mask[int(y1):int(y2), int(x1):int(x2)] = 1
        face_masks.append(torch.from_numpy(face_mask[..., :1]))
    face_masks = torch.stack(face_masks, dim=0)     # (b*f, h, w, c)
    face_masks = rearrange(face_masks, "(b f) h w c -> b c f h w", b=bsz, f=f)
    face_masks = face_masks.to(device=ref_image.device, dtype=ref_image.dtype)
    return face_masks
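
# Usage sketch (hypothetical shapes; `align_instance` is the repo's face
# alignment wrapper and is not constructed here):
#
#   ref_image = torch.randint(0, 256, (1, 8, 3, 512, 512)).float()
#   masks = get_facemask(ref_image, align_instance, area=1.25)
#   masks.shape  # -> (1, 1, 8, 512, 512), i.e. (b, c, f, h, w)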


def encode_audio(wav2vec, audio_feats, fps, num_frames=129):
    """Extract per-frame audio embeddings from the encoder's hidden states.

    Returns a tensor of shape (1, num_frames, 10, num_hidden_states, dim):
    a 10-step window of stacked hidden states for each video frame.
    """
    # Stride over the audio feature timeline: 2 encoder steps per frame at
    # 25 fps, 4 at 12.5 fps (the `* 2` below applies the common factor).
    if fps == 25:
        start_ts = [0]
        step_ts = [1]
    elif fps == 12.5:
        start_ts = [0]
        step_ts = [2]
    else:
        raise ValueError(f"Unsupported fps: {fps} (expected 25 or 12.5)")
    num_frames = min(num_frames, 400)
    # Encode at most 3000 input feature steps and keep every layer's
    # hidden states, stacked along a new layer axis.
    audio_feats = wav2vec.encoder(audio_feats.unsqueeze(0)[:, :, :3000], output_hidden_states=True).hidden_states
    audio_feats = torch.stack(audio_feats, dim=2)
    # Left-pad with 4 zero steps so early frames still get a full window.
    audio_feats = torch.cat([torch.zeros_like(audio_feats[:, :4]), audio_feats], 1)

    audio_prompts = []
    for bb in range(1):  # batch size is fixed at 1 here
        audio_feats_list = []
        for f in range(num_frames):
            cur_t = (start_ts[bb] + f * step_ts[bb]) * 2
            audio_clip = audio_feats[bb:bb + 1, cur_t: cur_t + 10]
            audio_feats_list.append(audio_clip)
        audio_feats_list = torch.stack(audio_feats_list, 1)
        audio_prompts.append(audio_feats_list)
    audio_prompts = torch.cat(audio_prompts)
    return audio_prompts
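
# Usage sketch (hypothetical; assumes `wav2vec` is a Whisper-style encoder
# model and `audio_feats` its (n_mels, T) input features):
#
#   audio_prompts = encode_audio(wav2vec, audio_feats, fps=25, num_frames=129)
#   audio_prompts.shape  # -> (1, 129, 10, num_hidden_states, dim)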