import os
import cv2
import json
import time
import decord
import einops
import librosa
import torch
import random
import argparse
import traceback
import numpy as np
from tqdm import tqdm
from PIL import Image
from einops import rearrange


def get_facemask(ref_image, align_instance, area=1.25):
    """Build a binary face-region mask for every frame of `ref_image`.

    Args:
        ref_image: tensor of shape (b, f, c, h, w) with uint8-range pixel values.
        align_instance: face detector/aligner; called on a BGR frame, it returns
            (points, scores, bboxes), where each bbox is (x, y, w, h).
        area: factor by which each detected box is enlarged around its center.

    Returns:
        Tensor of shape (b, 1, f, h, w): 1 inside the face box, 0 elsewhere.
    """
    bsz, f, c, h, w = ref_image.shape
    images = rearrange(ref_image, "b f c h w -> (b f) h w c").data.cpu().numpy().astype(np.uint8)
    face_masks = []
    for image in images:
        image_pil = Image.fromarray(image).convert("RGB")
        # The aligner expects BGR input, hence the channel flip; maxface=True
        # requests a single (most prominent) face per frame.
        _, _, bboxes_list = align_instance(np.array(image_pil)[:, :, [2, 1, 0]], maxface=True)
        try:
            bboxSrc = bboxes_list[0]
        except Exception:
            # No face detected: fall back to the whole frame.
            bboxSrc = [0, 0, w, h]
        x1, y1, ww, hh = bboxSrc
        x2, y2 = x1 + ww, y1 + hh
        # Enlarge the box by `area` around its center, clamped to image bounds.
        ww, hh = (x2 - x1) * area, (y2 - y1) * area
        center = [(x2 + x1) // 2, (y2 + y1) // 2]
        x1 = max(center[0] - ww // 2, 0)
        y1 = max(center[1] - hh // 2, 0)
        x2 = min(center[0] + ww // 2, w)
        y2 = min(center[1] + hh // 2, h)

        face_mask = np.zeros_like(np.array(image_pil))
        face_mask[int(y1):int(y2), int(x1):int(x2)] = 1.0
        face_masks.append(torch.from_numpy(face_mask[..., :1]))
    face_masks = torch.stack(face_masks, dim=0)  # (b*f, h, w, 1)
    face_masks = rearrange(face_masks, "(b f) h w c -> b c f h w", b=bsz, f=f)
    face_masks = face_masks.to(device=ref_image.device, dtype=ref_image.dtype)
    return face_masks
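
# A minimal usage sketch (illustrative, not part of this file): shapes follow the
# (b, f, c, h, w) layout documented above; `align_instance` stands in for
# whatever face-alignment object the caller has constructed.
#
#   ref = torch.randint(0, 256, (1, 8, 3, 512, 512)).float()  # (b, f, c, h, w)
#   masks = get_facemask(ref, align_instance)                 # (1, 1, 8, 512, 512)
#   # masks is 1 inside each frame's enlarged face box and 0 elsewhere.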


def encode_audio(wav2vec, audio_feats, fps, num_frames=129):
    """Slice a per-video-frame window of encoder features out of `audio_feats`.

    start_ts/step_ts are expressed in 25-fps video frames; the indexing below
    assumes two encoder feature steps per such frame (the factor of 2 in cur_t).
    """
    if fps == 25:
        start_ts = [0]
        step_ts = [1]
    elif fps == 12.5:
        # Half the frame rate: advance two 25-fps steps per video frame.
        start_ts = [0]
        step_ts = [2]
    else:
        start_ts = [0]
        step_ts = [1]

    num_frames = min(num_frames, 400)
    # Truncate the input to its first 3000 feature frames (a 30 s window for a
    # Whisper-style 10 ms hop) and stack the hidden states of every encoder
    # layer along a new dimension.
    audio_feats = wav2vec.encoder(audio_feats.unsqueeze(0)[:, :, :3000], output_hidden_states=True).hidden_states
    audio_feats = torch.stack(audio_feats, dim=2)
    # Left-pad four zero steps so early frames still get a full context window.
    audio_feats = torch.cat([torch.zeros_like(audio_feats[:, :4]), audio_feats], 1)

    audio_prompts = []
    for bb in range(1):  # single-element batch
        audio_feats_list = []
        for f in range(num_frames):
            # Pair each video frame with a 10-step window of audio features
            # starting at its timestamp in feature-rate units.
            cur_t = (start_ts[bb] + f * step_ts[bb]) * 2
            audio_clip = audio_feats[bb:bb + 1, cur_t:cur_t + 10]
            audio_feats_list.append(audio_clip)
        audio_feats_list = torch.stack(audio_feats_list, 1)
        audio_prompts.append(audio_feats_list)
    audio_prompts = torch.cat(audio_prompts)
    return audio_prompts
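
# A minimal usage sketch (hedged): despite the `wav2vec` parameter name, the
# `.encoder(...)` call above with a (1, n_mels, 3000) input matches a
# Whisper-style encoder from `transformers`; the checkpoint below is an
# illustrative assumption, not something this file pins down.
#
#   from transformers import WhisperModel, WhisperFeatureExtractor
#   extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")
#   model = WhisperModel.from_pretrained("openai/whisper-tiny")
#   mel = extractor(wav, sampling_rate=16000, return_tensors="pt").input_features[0]
#   prompts = encode_audio(model, mel, fps=25, num_frames=129)
#   # prompts: (1, num_frames, 10, num_layers, hidden), one audio window per frame.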