# Hunyuan-Avatar/hymm_sp/data_kits/audio_preprocessor.py
import torch
import numpy as np
from PIL import Image
from einops import rearrange


def get_facemask(ref_image, align_instance, area=1.25):
    """Build a per-frame binary face mask from detected face boxes.

    ref_image: (b, f, c, h, w); returns a mask of shape (b, 1, f, h, w)
    on the same device and with the same dtype as ref_image.
    """
    bsz, f, c, h, w = ref_image.shape
    images = rearrange(ref_image, "b f c h w -> (b f) h w c").detach().cpu().numpy().astype(np.uint8)
    face_masks = []
    for image in images:
        image_pil = Image.fromarray(image).convert("RGB")
        # The channel flip converts RGB to BGR for the aligner.
        _, _, bboxes_list = align_instance(np.array(image_pil)[:, :, [2, 1, 0]], maxface=True)
        try:
            bboxSrc = bboxes_list[0]
        except (IndexError, TypeError):
            # No face detected: fall back to masking the whole frame.
            bboxSrc = [0, 0, w, h]
        x1, y1, ww, hh = bboxSrc
        x2, y2 = x1 + ww, y1 + hh
        # Enlarge the box about its center by `area`, clamped to the image bounds.
        ww, hh = (x2 - x1) * area, (y2 - y1) * area
        center = [(x2 + x1) // 2, (y2 + y1) // 2]
        x1 = max(center[0] - ww // 2, 0)
        y1 = max(center[1] - hh // 2, 0)
        x2 = min(center[0] + ww // 2, w)
        y2 = min(center[1] + hh // 2, h)
        face_mask = np.zeros_like(np.array(image_pil))
        face_mask[int(y1):int(y2), int(x1):int(x2)] = 1.0
        face_masks.append(torch.from_numpy(face_mask[..., :1]))
    face_masks = torch.stack(face_masks, dim=0)  # (b*f, h, w, 1)
    face_masks = rearrange(face_masks, "(b f) h w c -> b c f h w", b=bsz, f=f)
    face_masks = face_masks.to(device=ref_image.device, dtype=ref_image.dtype)
    return face_masks
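

# A minimal usage sketch for get_facemask, assuming only what the code above
# shows: `align_instance` is a callable returning a 3-tuple whose third element
# is a list of [x, y, w, h] face boxes. The aligner below is a hypothetical
# stand-in, not the repo's real detector.
def _demo_get_facemask():
    def fake_aligner(image_bgr, maxface=True):
        # Pretend one face was found, covering the middle of the frame.
        h, w = image_bgr.shape[:2]
        return None, None, [[w // 4, h // 4, w // 2, h // 2]]

    # One clip of two 64x64 RGB frames: (b=1, f=2, c=3, h=64, w=64).
    ref_image = torch.randint(0, 256, (1, 2, 3, 64, 64), dtype=torch.uint8).float()
    mask = get_facemask(ref_image, fake_aligner, area=1.25)
    print(mask.shape)  # torch.Size([1, 1, 2, 64, 64])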

def encode_audio(wav2vec, audio_feats, fps, num_frames=129):
    """Slice speech-encoder hidden states into a 10-step window per video frame.

    Returns a tensor of shape (1, num_frames, 10, num_layers, dim).
    """
    if fps == 25:
        start_ts = [0]
        step_ts = [1]  # one video frame per pair of encoder feature steps
    elif fps == 12.5:
        start_ts = [0]
        step_ts = [2]  # sample every other feature pair
        num_frames = min(num_frames, 400)  # cap clip length at 12.5 fps
    else:
        raise ValueError(f"Unsupported fps: {fps} (expected 25 or 12.5)")
    # Encode at most 3000 input steps and keep the hidden states of every layer.
    audio_feats = wav2vec.encoder(audio_feats.unsqueeze(0)[:, :, :3000], output_hidden_states=True).hidden_states
    audio_feats = torch.stack(audio_feats, dim=2)  # (1, t, num_layers, dim)
    # Left-pad with four zero steps so the first frames still get a full window.
    audio_feats = torch.cat([torch.zeros_like(audio_feats[:, :4]), audio_feats], 1)
    audio_prompts = []
    for bb in range(1):  # batch size is fixed to 1 here
        audio_feats_list = []
        for f in range(num_frames):
            # Each video frame maps to two encoder feature steps, hence the * 2.
            cur_t = (start_ts[bb] + f * step_ts[bb]) * 2
            audio_clip = audio_feats[bb:bb + 1, cur_t:cur_t + 10]
            audio_feats_list.append(audio_clip)
        audio_feats_list = torch.stack(audio_feats_list, 1)
        audio_prompts.append(audio_feats_list)
    audio_prompts = torch.cat(audio_prompts)
    return audio_prompts
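

# A minimal sketch of driving encode_audio with a dummy speech encoder. The
# real `wav2vec` model's .encoder() returns an object carrying a tuple of
# per-layer hidden states; the 2x downsampling, layer count (5), and feature
# dim (384) below are illustrative assumptions, not the real model's values.
def _demo_encode_audio():
    import types

    class DummyEncoder:
        def encoder(self, feats, output_hidden_states=True):
            # Halve the time axis and return 5 layers of random hidden states.
            b, _, t = feats.shape
            hidden = tuple(torch.randn(b, t // 2, 384) for _ in range(5))
            return types.SimpleNamespace(hidden_states=hidden)

    audio_feats = torch.randn(80, 3000)  # (feature_bins, time), pre-extracted
    prompts = encode_audio(DummyEncoder(), audio_feats, fps=25, num_frames=129)
    print(prompts.shape)  # torch.Size([1, 129, 10, 5, 384])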