MH0386's picture
Upload folder using huggingface_hub
3e165b2 verified
import torch
import torch.nn.functional as F
from espnet.nets.pytorch_backend.conformer.encoder import Encoder
from torch import nn
from visualizr import logger
from visualizr.model.base import BaseModule
class LSTM(nn.Module):
def __init__(self, motion_dim, output_dim, num_layers=2, hidden_dim=128):
super().__init__()
self.lstm = nn.LSTM(
input_size=motion_dim,
hidden_size=hidden_dim,
num_layers=num_layers,
batch_first=True,
)
self.fc = nn.Linear(hidden_dim, output_dim)
def forward(self, x):
x, _ = self.lstm(x)
return self.fc(x)
class DiffusionPredictor(BaseModule):
def __init__(self, conf):
super(DiffusionPredictor, self).__init__()
self.infer_type = conf.infer_type
self.initialize_layers(conf)
logger.info(f"infer_type: {self.infer_type}")
def create_conformer_encoder(self, attention_dim, num_blocks):
return Encoder(
idim=0,
attention_dim=attention_dim,
attention_heads=2,
linear_units=attention_dim,
num_blocks=num_blocks,
input_layer=None,
dropout_rate=0.2,
positional_dropout_rate=0.2,
attention_dropout_rate=0.2,
normalize_before=False,
concat_after=False,
positionwise_layer_type="linear",
positionwise_conv_kernel_size=3,
macaron_style=True,
pos_enc_layer_type="rel_pos",
selfattention_layer_type="rel_selfattn",
use_cnn_module=True,
cnn_module_kernel=13,
)
def initialize_layers(
self,
conf,
mfcc_dim=39,
hubert_dim=1024,
speech_layers=4,
speech_dim=512,
decoder_dim=1024,
motion_start_dim=512,
HAL_layers=25,
):
self.conf = conf
# Speech downsampling
if self.infer_type.startswith("mfcc"):
# from 100 hz to 25 hz
self.down_sample1 = nn.Conv1d(
mfcc_dim, 256, kernel_size=3, stride=2, padding=1
)
self.down_sample2 = nn.Conv1d(
256, speech_dim, kernel_size=3, stride=2, padding=1
)
elif self.infer_type.startswith("hubert"):
# from 50 hz to 25 hz
self.down_sample1 = nn.Conv1d(
hubert_dim, speech_dim, kernel_size=3, stride=2, padding=1
)
self.weights = nn.Parameter(torch.zeros(HAL_layers))
self.speech_encoder = self.create_conformer_encoder(
speech_dim, speech_layers
)
else:
logger.exception("infer_type not supported")
# Encoders & Decoders
self.coarse_decoder = self.create_conformer_encoder(
decoder_dim, conf.decoder_layers
)
# LSTM predictors for Variance Adapter
if self.infer_type != "hubert_audio_only":
self.pose_predictor = LSTM(speech_dim, 3)
self.pose_encoder = LSTM(3, speech_dim)
if "full_control" in self.infer_type:
self.location_predictor = LSTM(speech_dim, 1)
self.location_encoder = LSTM(1, speech_dim)
self.face_scale_predictor = LSTM(speech_dim, 1)
self.face_scale_encoder = LSTM(1, speech_dim)
# Linear transformations
self.init_code_proj = nn.Sequential(nn.Linear(motion_start_dim, 128))
self.noisy_encoder = nn.Sequential(nn.Linear(conf.motion_dim, 128))
self.t_encoder = nn.Sequential(nn.Linear(1, 128))
self.encoder_direction_code = nn.Linear(conf.motion_dim, 128)
self.out_proj = nn.Linear(decoder_dim, conf.motion_dim)
def forward(
self,
initial_code,
direction_code,
seq_input_vector,
face_location,
face_scale,
yaw_pitch_roll,
noisy_x,
t_emb,
control_flag=False,
):
global x
if self.infer_type.startswith("mfcc"):
x = self.mfcc_speech_downsample(seq_input_vector)
elif self.infer_type.startswith("hubert"):
norm_weights = F.softmax(self.weights, dim=-1)
weighted_feature = (
norm_weights.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) * seq_input_vector
).sum(dim=1)
x = self.down_sample1(weighted_feature.transpose(1, 2)).transpose(1, 2)
x, _ = self.speech_encoder(x, masks=None)
predicted_location, predicted_scale, predicted_pose = (
face_location,
face_scale,
yaw_pitch_roll,
)
if self.infer_type != "hubert_audio_only":
logger.info(f"pose controllable. control_flag: {control_flag}")
x, predicted_location, predicted_scale, predicted_pose = (
self.adjust_features(
x, face_location, face_scale, yaw_pitch_roll, control_flag
)
)
# Variable initial_code and direction_code serve as a motion guide
# extracted from the reference image.
# This aims to tell the model what the starting motion should be.
concatenated_features = self.combine_features(
x, initial_code, direction_code, noisy_x, t_emb
)
outputs = self.decode_features(concatenated_features)
return outputs, predicted_location, predicted_scale, predicted_pose
def mfcc_speech_downsample(self, seq_input_vector):
x = self.down_sample1(seq_input_vector.transpose(1, 2))
return self.down_sample2(x).transpose(1, 2)
def adjust_features(
self, x, face_location, face_scale, yaw_pitch_roll, control_flag
):
predicted_location, predicted_scale = 0, 0
if "full_control" in self.infer_type:
logger.info(f"full controllable. control_flag: {control_flag}")
x_residual, predicted_location = self.adjust_location(
x, face_location, control_flag
)
x = x + x_residual
x_residual, predicted_scale = self.adjust_scale(x, face_scale, control_flag)
x = x + x_residual
x_residual, predicted_pose = self.adjust_pose(x, yaw_pitch_roll, control_flag)
x = x + x_residual
return x, predicted_location, predicted_scale, predicted_pose
def adjust_location(self, x, face_location, control_flag):
if control_flag:
predicted_location = face_location
else:
predicted_location = self.location_predictor(x)
return self.location_encoder(predicted_location), predicted_location
def adjust_scale(self, x, face_scale, control_flag):
if control_flag:
predicted_face_scale = face_scale
else:
predicted_face_scale = self.face_scale_predictor(x)
return self.face_scale_encoder(predicted_face_scale), predicted_face_scale
def adjust_pose(self, x, yaw_pitch_roll, control_flag):
if control_flag:
predicted_pose = yaw_pitch_roll
else:
predicted_pose = self.pose_predictor(x)
return self.pose_encoder(predicted_pose), predicted_pose
def combine_features(self, x, initial_code, direction_code, noisy_x, t_emb):
init_code_proj = (
self.init_code_proj(initial_code).unsqueeze(1).repeat(1, x.size(1), 1)
)
noisy_feature = self.noisy_encoder(noisy_x)
t_emb_feature = (
self.t_encoder(t_emb.unsqueeze(1).float())
.unsqueeze(1)
.repeat(1, x.size(1), 1)
)
direction_code_feature = (
self.encoder_direction_code(direction_code)
.unsqueeze(1)
.repeat(1, x.size(1), 1)
)
return torch.cat(
(x, direction_code_feature, init_code_proj, noisy_feature, t_emb_feature),
dim=-1,
)
def decode_features(self, concatenated_features):
outputs, _ = self.coarse_decoder(concatenated_features, masks=None)
return self.out_proj(outputs)