# -*- coding: utf-8 -*-
# @Time    : 2024/12/15
# @Author  : wenshao
# @Email   : [email protected]
# @Project : FasterLivePortrait
# @FileName: joyvasa_audio_to_motion_pipeline.py

import math
import pdb
import torch
import torchaudio
import numpy as np
import torch.nn.functional as F
import pickle
from tqdm import tqdm
import pathlib
import os

from ..models.JoyVASA.dit_talking_head import DitTalkingHead
from ..models.JoyVASA.helper import NullableArgs
from ..utils import utils

class JoyVASAAudio2MotionPipeline:
    """
    JoyVASA: generate LivePortrait motion sequences from audio.
    """
    def __init__(self, **kwargs):
        self.device, self.dtype = utils.get_opt_device_dtype()
        # On Windows, temporarily remap PosixPath so checkpoints pickled on
        # POSIX systems can still be unpickled.
        if os.name == 'nt':
            temp = pathlib.PosixPath
            pathlib.PosixPath = pathlib.WindowsPath
        motion_model_path = kwargs.get("motion_model_path", "")
        audio_model_path = kwargs.get("audio_model_path", "")
        motion_template_path = kwargs.get("motion_template_path", "")
        model_data = torch.load(motion_model_path, map_location="cpu")
        model_args = NullableArgs(model_data['args'])
        model = DitTalkingHead(motion_feat_dim=model_args.motion_feat_dim,
                               n_motions=model_args.n_motions,
                               n_prev_motions=model_args.n_prev_motions,
                               feature_dim=model_args.feature_dim,
                               audio_model=model_args.audio_model,
                               n_diff_steps=model_args.n_diff_steps,
                               audio_encoder_path=audio_model_path)
        # Drop the checkpoint's positional-encoding buffer; load_state_dict(strict=False)
        # keeps the buffer the model created for itself.
        model_data['model'].pop('denoising_net.TE.pe')
        model.load_state_dict(model_data['model'], strict=False)
        model.to(self.device, dtype=self.dtype)
        model.eval()
        # Restore the original PosixPath if it was changed
        if os.name == 'nt':
            pathlib.PosixPath = temp

        self.motion_generator = model
        self.n_motions = model_args.n_motions
        self.n_prev_motions = model_args.n_prev_motions
        self.fps = model_args.fps
        self.audio_unit = 16000. / self.fps  # num of audio samples per frame
        self.n_audio_samples = round(self.audio_unit * self.n_motions)
        self.pad_mode = model_args.pad_mode
        self.use_indicator = model_args.use_indicator
        # classifier-free guidance settings
        self.cfg_mode = kwargs.get("cfg_mode", "incremental")
        self.cfg_cond = kwargs.get("cfg_cond", None)
        self.cfg_scale = kwargs.get("cfg_scale", 2.8)
        with open(motion_template_path, 'rb') as fin:
            self.templete_dict = pickle.load(fin)

    def gen_motion_sequence(self, audio_path, **kwargs):
        # preprocess audio
        audio, sample_rate = torchaudio.load(audio_path)
        if sample_rate != 16000:
            audio = torchaudio.functional.resample(
                audio,
                orig_freq=sample_rate,
                new_freq=16000,
            )
        audio = audio.mean(0).to(self.device, dtype=self.dtype)
        # audio = F.pad(audio, (1280, 640), "constant", 0)
        # audio_mean, audio_std = torch.mean(audio), torch.std(audio)
        # audio = (audio - audio_mean) / (audio_std + 1e-5)

        # crop audio into n_subdivision according to n_motions
        clip_len = int(len(audio) / 16000 * self.fps)
        stride = self.n_motions
        if clip_len <= self.n_motions:
            n_subdivision = 1
        else:
            n_subdivision = math.ceil(clip_len / stride)

        # padding
        n_padding_audio_samples = self.n_audio_samples * n_subdivision - len(audio)
        n_padding_frames = math.ceil(n_padding_audio_samples / self.audio_unit)
        if n_padding_audio_samples > 0:
            if self.pad_mode == 'zero':
                padding_value = 0
            elif self.pad_mode == 'replicate':
                padding_value = audio[-1]
            else:
                raise ValueError(f'Unknown pad mode: {self.pad_mode}')
            audio = F.pad(audio, (0, n_padding_audio_samples), value=padding_value)
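
        # Worked example of the window arithmetic above (hypothetical numbers,
        # not values read from the checkpoint): with fps = 25 and n_motions = 100,
        # audio_unit = 640 samples per frame and n_audio_samples = 64000 (4 s of
        # audio per window); a 10 s clip gives clip_len = 250, n_subdivision = 3,
        # and 32000 padded samples, i.e. 50 padded frames dropped at the end.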
        # generate motions window by window
        coef_list = []
        for i in range(0, n_subdivision):
            start_idx = i * stride
            end_idx = start_idx + self.n_motions
            indicator = torch.ones((1, self.n_motions)).to(self.device) if self.use_indicator else None
            # zero out the indicator for padded frames in the last window
            if indicator is not None and i == n_subdivision - 1 and n_padding_frames > 0:
                indicator[:, -n_padding_frames:] = 0
            audio_in = audio[round(start_idx * self.audio_unit):round(end_idx * self.audio_unit)].unsqueeze(0)

            if i == 0:
                motion_feat, noise, prev_audio_feat = self.motion_generator.sample(audio_in,
                                                                                   indicator=indicator,
                                                                                   cfg_mode=self.cfg_mode,
                                                                                   cfg_cond=self.cfg_cond,
                                                                                   cfg_scale=self.cfg_scale,
                                                                                   dynamic_threshold=0)
            else:
                # later windows are warm-started with the tail of the previous window
                motion_feat, noise, prev_audio_feat = self.motion_generator.sample(audio_in,
                                                                                   prev_motion_feat.to(self.dtype),
                                                                                   prev_audio_feat.to(self.dtype),
                                                                                   noise.to(self.dtype),
                                                                                   indicator=indicator,
                                                                                   cfg_mode=self.cfg_mode,
                                                                                   cfg_cond=self.cfg_cond,
                                                                                   cfg_scale=self.cfg_scale,
                                                                                   dynamic_threshold=0)
            prev_motion_feat = motion_feat[:, -self.n_prev_motions:].clone()
            prev_audio_feat = prev_audio_feat[:, -self.n_prev_motions:]

            motion_coef = motion_feat
            if i == n_subdivision - 1 and n_padding_frames > 0:
                motion_coef = motion_coef[:, :-n_padding_frames]  # delete padded frames
            coef_list.append(motion_coef)
        motion_coef = torch.cat(coef_list, dim=1)
        # motion_coef = self.reformat_motion(args, motion_coef)
        motion_coef = motion_coef.squeeze().cpu().numpy().astype(np.float32)
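
        # Per-frame coefficient layout assumed by the slicing below:
        # [0:63] expression offsets (21 keypoints x 3), [63] scale, [64:67]
        # translation, [67] pitch, [68] yaw, [69] roll. Expression is
        # de-normalized with the template's mean/std; the remaining terms are
        # mapped back through the template's min/max ranges.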
        motion_list = []
        for idx in tqdm(range(motion_coef.shape[0]), total=motion_coef.shape[0]):
            exp = motion_coef[idx][:63] * self.templete_dict["std_exp"] + self.templete_dict["mean_exp"]
            scale = motion_coef[idx][63:64] * (
                    self.templete_dict["max_scale"] - self.templete_dict["min_scale"]) + self.templete_dict["min_scale"]
            t = motion_coef[idx][64:67] * (self.templete_dict["max_t"] - self.templete_dict["min_t"]) + \
                self.templete_dict["min_t"]
            pitch = motion_coef[idx][67:68] * (
                    self.templete_dict["max_pitch"] - self.templete_dict["min_pitch"]) + self.templete_dict["min_pitch"]
            yaw = motion_coef[idx][68:69] * (self.templete_dict["max_yaw"] - self.templete_dict["min_yaw"]) + \
                  self.templete_dict["min_yaw"]
            roll = motion_coef[idx][69:70] * (self.templete_dict["max_roll"] - self.templete_dict["min_roll"]) + \
                   self.templete_dict["min_roll"]

            R = utils.get_rotation_matrix(pitch, yaw, roll)
            R = R.reshape(1, 3, 3).astype(np.float32)
            exp = exp.reshape(1, 21, 3).astype(np.float32)
            scale = scale.reshape(1, 1).astype(np.float32)
            t = t.reshape(1, 3).astype(np.float32)
            pitch = pitch.reshape(1, 1).astype(np.float32)
            yaw = yaw.reshape(1, 1).astype(np.float32)
            roll = roll.reshape(1, 1).astype(np.float32)
            motion_list.append({"exp": exp, "scale": scale, "R": R, "t": t, "pitch": pitch, "yaw": yaw, "roll": roll})

        tgt_motion = {'n_frames': motion_coef.shape[0], 'output_fps': self.fps, 'motion': motion_list,
                      'c_eyes_lst': [], 'c_lip_lst': []}
        return tgt_motion
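

# --------------------------------------------------------------------------- #
# Minimal usage sketch (illustrative only): the checkpoint, audio-encoder and
# motion-template paths below are placeholders, not files shipped with this
# module. Run it as a module (python -m ...) from the project root so the
# relative imports resolve.
# --------------------------------------------------------------------------- #
if __name__ == "__main__":
    pipeline = JoyVASAAudio2MotionPipeline(
        motion_model_path="checkpoints/JoyVASA/motion_generator.pt",       # placeholder
        audio_model_path="checkpoints/JoyVASA/audio_encoder",              # placeholder
        motion_template_path="checkpoints/JoyVASA/motion_template.pkl",    # placeholder
        cfg_scale=2.8,
    )
    motions = pipeline.gen_motion_sequence("assets/examples/driving/a1.wav")  # placeholder audio
    print(motions["n_frames"], motions["output_fps"], len(motions["motion"]))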