import os
import cv2
import json
import time
import decord
import einops
import librosa
import torch
import random
import argparse
import traceback
import numpy as np
from tqdm import tqdm
from PIL import Image
from einops import rearrange

def get_facemask(ref_image, align_instance, area=1.25):
    # ref_image: (b, f, c, h, w)
    bsz, f, c, h, w = ref_image.shape
    images = rearrange(ref_image, "b f c h w -> (b f) h w c").data.cpu().numpy().astype(np.uint8)
    face_masks = []
    for image in images:
        image_pil = Image.fromarray(image).convert("RGB")
        # Convert RGB -> BGR for the detector and request a single (max) face.
        _, _, bboxes_list = align_instance(np.array(image_pil)[:, :, [2, 1, 0]], maxface=True)
        try:
            bboxSrc = bboxes_list[0]
        except Exception:
            # No face detected: fall back to the full frame.
            bboxSrc = [0, 0, w, h]
        x1, y1, ww, hh = bboxSrc
        x2, y2 = x1 + ww, y1 + hh
        # Expand the detected box by `area` around its center, then clamp to image bounds.
        ww, hh = (x2 - x1) * area, (y2 - y1) * area
        center = [(x2 + x1) // 2, (y2 + y1) // 2]
        x1 = max(center[0] - ww // 2, 0)
        y1 = max(center[1] - hh // 2, 0)
        x2 = min(center[0] + ww // 2, w)
        y2 = min(center[1] + hh // 2, h)
        face_mask = np.zeros_like(np.array(image_pil))
        face_mask[int(y1):int(y2), int(x1):int(x2)] = 1.0
        face_masks.append(torch.from_numpy(face_mask[..., :1]))
    face_masks = torch.stack(face_masks, dim=0)  # (b*f, h, w, 1)
    face_masks = rearrange(face_masks, "(b f) h w c -> b c f h w", b=bsz, f=f)
    face_masks = face_masks.to(device=ref_image.device, dtype=ref_image.dtype)
    return face_masks
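
# --- Usage sketch (illustrative, not part of the original pipeline) ---
# A minimal example of how get_facemask might be called, assuming a face-alignment
# callable that returns (points, scores, bboxes) with each bbox given as [x, y, w, h].
# `align_instance` stands for whatever detector the surrounding code loads, and the
# tensor sizes below are placeholders.
#
#   ref_images = torch.randint(0, 255, (1, 8, 3, 512, 512), dtype=torch.uint8)
#   masks = get_facemask(ref_images.float(), align_instance, area=1.25)
#   # masks: (1, 1, 8, 512, 512); 1.0 inside the expanded face box, 0.0 elsewhere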

def encode_audio(wav2vec, audio_feats, fps, num_frames=129):
    # Map the video frame rate to a stride over audio tokens
    # (the windowing below assumes 2 encoder tokens per 25-fps video frame).
    if fps == 25:
        start_ts = [0]
        step_ts = [1]
    elif fps == 12.5:
        start_ts = [0]
        step_ts = [2]
    else:
        raise ValueError(f"Unsupported fps: {fps}")

    num_frames = min(num_frames, 400)
    # Encode the (truncated) audio features and keep every encoder layer's hidden states.
    audio_feats = wav2vec.encoder(audio_feats.unsqueeze(0)[:, :, :3000], output_hidden_states=True).hidden_states
    audio_feats = torch.stack(audio_feats, dim=2)  # (1, T, n_layers, D)
    # Prepend 4 zero-valued tokens before the encoded sequence.
    audio_feats = torch.cat([torch.zeros_like(audio_feats[:, :4]), audio_feats], 1)

    audio_prompts = []
    for bb in range(1):
        audio_feats_list = []
        for f in range(num_frames):
            # For each video frame, take a window of 10 consecutive audio tokens.
            cur_t = (start_ts[bb] + f * step_ts[bb]) * 2
            audio_clip = audio_feats[bb:bb + 1, cur_t:cur_t + 10]
            audio_feats_list.append(audio_clip)
        audio_feats_list = torch.stack(audio_feats_list, 1)
        audio_prompts.append(audio_feats_list)
    audio_prompts = torch.cat(audio_prompts)  # (1, num_frames, 10, n_layers, D)
    return audio_prompts
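
# --- Usage sketch (illustrative, not part of the original pipeline) ---
# A hedged example of calling encode_audio, assuming `wav2vec` is a Whisper-style
# audio encoder wrapper whose .encoder(...) takes (1, n_mels, T) log-mel features
# and returns per-layer hidden states. The `feature_extractor` below is an assumed
# preprocessing step, not necessarily the project's own.
#
#   audio, _ = librosa.load("speech.wav", sr=16000)
#   mel = feature_extractor(audio, sampling_rate=16000, return_tensors="pt").input_features[0]
#   audio_prompts = encode_audio(wav2vec, mel.to(device), fps=25, num_frames=129)
#   # audio_prompts: (1, num_frames, 10, n_layers, hidden_dim)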