Update app.py
app.py CHANGED
@@ -1,387 +1,118 @@
[old app.py lines 1-29 (the import block and the opening of a model wrapper class) are truncated in this diff view]
-        nn.init.constant_(self.mapping1.bias, 0.)
-        self.use_ref = use_ref
-
-    def forward(self, x, ref, use_tanh=False):
-        x = self.audio_encoder.forward_feature(x).view(x.size(0), -1)
-        ref_reshape = ref.reshape(x.size(0), -1)  # 20, -1
-
-        y = self.mapping1(torch.cat([x, ref_reshape], dim=1))
-
-        if self.use_ref:
-            out = y.reshape(ref.shape[0], ref.shape[1], -1) + ref  # residual
-        else:
-            out = y.reshape(ref.shape[0], ref.shape[1], -1)
-
-        if use_tanh:
-            out[:, :50] = torch.tanh(out[:, :50]) * 3
-
-        return out
-
-class Audio2Mesh(object):
-    def __init__(self, args) -> None:
-        self.args = args
-
-        spectre_cfg.model.use_tex = True
-        spectre_cfg.model.mask_type = args.mask_type
-        spectre_cfg.debug = self.args.debug
-        spectre_cfg.model.netA_sync = 'ressesync'
-        spectre_cfg.model.gpu_ids = [0]
-
-        self.spectre = SPECTRE(spectre_cfg)
-        self.spectre.eval()
-        self.face_tracker = None  # FaceTrackerV2()  # face landmark detection
-        self.mel_step_size = 16
-        self.fps = args.fps
-        self.Nw = args.tframes
-        self.device = self.args.device
-        self.image_size = self.args.image_size
-
-        ### only audio
-        args.netA_sync = 'ressesync'
-        args.gpu_ids = [0]
-        args.exp_dim = 53
-        args.use_tanh = False
-        args.K = 20
-
-        self.audio2exp = 'pcavs'
-
-        #
-        self.avmodel = SimpleWrapperV2(args, exp_dim=args.exp_dim).cuda()
-        self.avmodel.load_state_dict(torch.load('../packages/pretrained/audio2expression_v2_model.tar')['opt'])
-
-        # 5, 160 = 25 fps
-        self.audio = AudioConfig(frame_rate=args.fps, num_frames_per_clip=5, hop_size=160)
-
-        with open(os.path.join(args.source_dir, 'deca_infos.pkl'), 'rb') as f:  # ?
-            self.fitting_coeffs = pickle.load(f, encoding='bytes')
-
-        self.coeffs_dict = {key: torch.Tensor(self.fitting_coeffs[key]).cuda().squeeze(1) for key in ['cam', 'pose', 'light', 'tex', 'shape', 'exp']}
-
-        #### find the most closed-mouth frame
-        exp_tensors = torch.sum(self.coeffs_dict['exp'], dim=1)
-        ssss, sorted_indices = torch.sort(exp_tensors)
-        self.exp_id = sorted_indices[0].item()
-
-        if '.ts' in args.render_path:
-            self.render = torch.jit.load(args.render_path).cuda()
-            self.trt = True
-        else:
-            self.render = define_G(self.Nw*6, 3, args.ngf, args.netR).eval().cuda()
-            self.render.load_state_dict(torch.load(args.render_path))
-            self.trt = False
-
-        print('loaded cached images...')
-
-    @torch.no_grad()
-    def cg2real(self, rendedimages, start_frame=0):
-
-        ## load the original images and masks
-        self.source_images = np.concatenate(load_image_from_dir(os.path.join(self.args.source_dir, 'original_frame'),\
-                                            resize=self.image_size, limit=len(rendedimages)+start_frame))[start_frame:]
-        self.source_masks = np.concatenate(load_image_from_dir(os.path.join(self.args.source_dir, 'original_mask'),\
-                                           resize=self.image_size, limit=len(rendedimages)+start_frame))[start_frame:]
-
-        self.source_masks = torch.FloatTensor(np.transpose(self.source_masks, (0, 3, 1, 2)) / 255.)
-        self.padded_real_tensor = torch.FloatTensor(np.transpose(self.source_images, (0, 3, 1, 2)) / 255.)
-
-        ## pad the rendered images
-        paded_tensor = torch.cat([rendedimages[0:1]] * (self.Nw // 2) + [rendedimages] + [rendedimages[-1:]] * (self.Nw // 2)).contiguous()
-        paded_mask_tensor = torch.cat([self.source_masks[0:1]] * (self.Nw // 2) + [self.source_masks] + [self.source_masks[-1:]] * (self.Nw // 2)).contiguous()
-        paded_real_tensor = torch.cat([self.padded_real_tensor[0:1]] * (self.Nw // 2) + [self.padded_real_tensor] + [self.padded_real_tensor[-1:]] * (self.Nw // 2)).contiguous()
-
-        # paded_mask_tensor = maskErosion(paded_mask_tensor, offY=self.args.mask)
-        padded_input = (paded_real_tensor - 0.5) * 2  # *(1-paded_mask_tensor)
-        padded_input = torch.nn.functional.interpolate(padded_input, (self.image_size, self.image_size), mode='bilinear', align_corners=False)
-        paded_tensor = torch.nn.functional.interpolate(paded_tensor, (self.image_size, self.image_size), mode='bilinear', align_corners=False)
-        paded_tensor = (paded_tensor - 0.5) * 2
-
-        result = []
-        for index in tqdm(range(0, len(rendedimages), self.args.renderbs), desc='CG2REAL:'):
-            list_A = []
-            list_R = []
-            list_M = []
-            for i in range(self.args.renderbs):
-                idx = index + i
-                if idx + self.Nw > len(padded_input):
-                    list_A.append(torch.zeros(self.Nw*3, self.image_size, self.image_size).unsqueeze(0))
-                    list_R.append(torch.zeros(self.Nw*3, self.image_size, self.image_size).unsqueeze(0))
-                    list_M.append(torch.zeros(self.Nw*3, self.image_size, self.image_size).unsqueeze(0))
-                else:
-                    list_A.append(padded_input[idx:idx+self.Nw].view(-1, self.image_size, self.image_size).unsqueeze(0))
-                    list_R.append(paded_tensor[idx:idx+self.Nw].view(-1, self.image_size, self.image_size).unsqueeze(0))
-                    list_M.append(paded_mask_tensor[idx:idx+self.Nw].view(-1, self.image_size, self.image_size).unsqueeze(0))
-
-            list_A = torch.cat(list_A)
-            list_R = torch.cat(list_R)
-            list_M = torch.cat(list_M)
-
-            idx = (self.Nw // 2) * 3
-            mask = list_M[:, idx:idx+3]
-
-            # list_A = padded_input
-            mask = maskErosion(mask, offY=self.args.mask)
-            list_A = list_A * (1 - mask[:, 0:1])
-            A = torch.cat([list_A, list_R], 1)
-
-            if self.trt:
-                B = self.render(A.half().cuda())
-            elif self.args.netR == 'unet_256':
-                # import pdb; pdb.set_trace()
-                idx = (self.Nw // 2) * 3
-                mask = list_M[:, idx:idx+3].cuda()
-                mask = maskErosion(mask, offY=self.args.mask)
-                B0 = list_A[:, idx:idx+3].cuda()
-                B = self.render(A.cuda()) * mask[:, 0:1] + (1 - mask[:, 0:1]) * B0
-            elif self.args.netR == 's2am':
-                # import pdb; pdb.set_trace()
-                idx = (self.Nw // 2) * 3
-                mask = list_M[:, idx:idx+3].cuda()
-                mask = maskErosion(mask, offY=self.args.mask)
-                B0 = list_A[:, idx:idx+3].cuda()
-                B = self.render(A.cuda(), mask[:, 0:1]) * mask[:, 0:1] + (1 - mask[:, 0:1]) * B0
-            else:
-                B = self.render(A.cuda())
-
-            result.append((B.cpu() + 1) * 0.5)  # -1,1 -> 0,1
-
-        return torch.cat(result)[:len(rendedimages)]
-
-    @torch.no_grad()
-    def coeffs_to_img(self, vertices, coeffs, zero_pose=False, XK=20):
-
-        xlen = vertices.shape[0]
-        all_shape_images = []
-        landmark2d = []
-
-        #### find the largest jaw pose value in the coeffs
-        max_pose_51 = torch.max(self.coeffs_dict['pose'][..., 3:4].squeeze(-1))
-
-        for i in tqdm(range(0, xlen, XK)):
-
-            if i + XK > xlen:
-                XK = xlen - i
-
-            codedictdecoder = {}
-            codedictdecoder['shape'] = torch.zeros_like(self.coeffs_dict['shape'][i:i+XK].cuda())
-            codedictdecoder['tex'] = self.coeffs_dict['tex'][i:i+XK].cuda()
-            codedictdecoder['exp'] = torch.zeros_like(self.coeffs_dict['exp'][i:i+XK].cuda())  # all_exps[i:i+XK, :50].cuda()
-            codedictdecoder['pose'] = self.coeffs_dict['pose'][i:i+XK]  # vid_poses[i:i+1].cuda()
-            codedictdecoder['cam'] = self.coeffs_dict['cam'][i:i+XK].cuda()  # vid_poses[i:i+1].cuda()
-            codedictdecoder['light'] = self.coeffs_dict['light'][i:i+XK].cuda()  # vid_poses[i:i+1].cuda()
-            codedictdecoder['images'] = torch.zeros((XK, 3, 256, 256)).cuda()
-
-            codedictdecoder['pose'][..., 3:4] = torch.clip(coeffs[i:i+XK, 50:51], 0, max_pose_51 * 0.9)  # torch.zeros_like(self.coeffs_dict['pose'][i:i+XK, 3:])
-            codedictdecoder['pose'][..., 4:6] = 0  # coeffs[i:i+XK, 50:] * (-0.25)
-
-            sub_vertices = vertices[i:i+XK].cuda()
-
-            opdict = self.spectre.decode_verts(codedictdecoder, sub_vertices, rendering=True, vis_lmk=False, return_vis=False)
-
-            landmark2d.append(opdict['landmarks2d'].cpu())
-
-            all_shape_images.append(opdict['rendered_images'].cpu())
-
-        rendedimages = torch.cat(all_shape_images)
-
-        lmk2d = torch.cat(landmark2d)
-
-        return rendedimages, lmk2d
-
-
-    @torch.no_grad()
-    def run_spectre_v3(self, wav=None, ds_features=None, L=20):
-
-        wav = audio_normalize(wav)
-        all_mel = self.audio.melspectrogram(wav).astype(np.float32).T
-        frames_from_audio = np.arange(2, len(all_mel) // self.audio.num_bins_per_frame - 2)
-        audio_inds = frame2audio_indexs(frames_from_audio, self.audio.num_frames_per_clip, self.audio.num_bins_per_frame)
-
-        vid_exps = self.coeffs_dict['exp'][self.exp_id:self.exp_id+1]
-        vid_poses = self.coeffs_dict['pose'][self.exp_id:self.exp_id+1]
[old app.py lines 230-298 are truncated in this diff view]
-        prediction = model.predict(audio_feature, template, one_hot, 1.0)  # (1, seq_len, V*3)
-
-        return prediction.squeeze()
-
-    @torch.no_grad()
-    def run(self, face, audio, start_frame=0):
-
-        wav, sr = librosa.load(audio, sr=16000)  # 16*80 ? 20*80
-        wav_tensor = torch.FloatTensor(wav).unsqueeze(0) if len(wav.shape) == 1 else torch.FloatTensor(wav)
-        _, frames = parse_audio_length(wav_tensor.shape[1], 16000, self.args.fps)
-
-        ##### audio-guided, only use the jaw movement
-        all_exps = self.run_spectre_v3(wav)
-
-        # #### temporal interpolation
-        all_exps = torch.nn.functional.interpolate(all_exps.unsqueeze(0).permute([0, 2, 1]), size=frames, mode='linear')
-        all_exps = all_exps.permute([0, 2, 1]).squeeze(0)
-
-        # run FaceFormer for face mesh generation
-        predicted_vertices = self.test_model(audio)
-        predicted_vertices = predicted_vertices.view(-1, 5023*3)
-
-        #### temporal interpolation
-        predicted_vertices = torch.nn.functional.interpolate(predicted_vertices.unsqueeze(0).permute([0, 2, 1]), size=frames, mode='linear')
-        predicted_vertices = predicted_vertices.permute([0, 2, 1]).squeeze(0).view(-1, 5023, 3)
-
-        all_exps = torch.Tensor(savgol_filter(all_exps.cpu().numpy(), 5, 3, axis=0)).cpu()  # smooth GT
-
-        rendedimages, lm2d = self.coeffs_to_img(predicted_vertices, all_exps, zero_pose=True)
-        debug_video_gen(rendedimages, self.args.result_dir + "/debug_before_ff.mp4", wav_tensor, self.args.fps, sr)
-
-        # cg2real
-        debug_video_gen(self.cg2real(rendedimages, start_frame=start_frame), self.args.result_dir + "/debug_cg2real_raw.mp4", wav_tensor, self.args.fps, sr)
-
-        exit()
-
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='Stylization and Seamless Video Dubbing')
-    parser.add_argument('--face', default='examples', type=str, help='')
-    parser.add_argument('--audio', default='examples', type=str, help='')
-    parser.add_argument('--source_dir', default='examples', type=str, help='TODO')
-    parser.add_argument('--result_dir', default='examples', type=str, help='TODO')
-    parser.add_argument('--backend', default='wav2lip', type=str, help='wav2lip or pcavs')
-    parser.add_argument('--result_tag', default='result', type=str, help='TODO')
-    parser.add_argument('--netR', default='unet_256', type=str, help='TODO')
-    parser.add_argument('--render_path', default='', type=str, help='TODO')
-    parser.add_argument('--ngf', default=16, type=int, help='TODO')
-    parser.add_argument('--fps', default=20, type=int, help='TODO')
-    parser.add_argument('--mask', default=100, type=int, help='TODO')
-    parser.add_argument('--mask_type', default='v3', type=str, help='TODO')
-    parser.add_argument('--image_size', default=256, type=int, help='TODO')
-    parser.add_argument('--input_nc', default=21, type=int, help='TODO')
-    parser.add_argument('--output_nc', default=3, type=int, help='TODO')
-    parser.add_argument('--renderbs', default=16, type=int, help='TODO')
-    parser.add_argument('--tframes', default=1, type=int, help='TODO')
-    parser.add_argument('--debug', action='store_true')
-    parser.add_argument('--enhance', action='store_true')
-    parser.add_argument('--phone', action='store_true')
-
-    #### FaceFormer
-    parser.add_argument("--model_name", type=str, default="VOCA")
-    parser.add_argument("--dataset", type=str, default="vocaset", help='vocaset or BIWI')
-    parser.add_argument("--feature_dim", type=int, default=64, help='64 for vocaset; 128 for BIWI')
-    parser.add_argument("--period", type=int, default=30, help='period in PPE - 30 for vocaset; 25 for BIWI')
-    parser.add_argument("--vertice_dim", type=int, default=5023*3, help='number of vertices - 5023*3 for vocaset; 23370*3 for BIWI')
-    parser.add_argument("--device", type=str, default="cuda")
-    parser.add_argument("--train_subjects", type=str, default="FaceTalk_170728_03272_TA ")
-    parser.add_argument("--test_subjects", type=str, default="FaceTalk_170809_00138_TA FaceTalk_170731_00024_TA")
-    parser.add_argument("--condition", type=str, default="FaceTalk_170904_00128_TA", help='select a conditioning subject from train_subjects')
-    parser.add_argument("--subject", type=str, default="FaceTalk_170731_00024_TA", help='select a subject from test_subjects or train_subjects')
-    parser.add_argument("--background_black", type=bool, default=True, help='whether to use black background')
-    parser.add_argument("--template_path", type=str, default="templates.pkl", help='path of the personalized templates')
-    parser.add_argument("--render_template_path", type=str, default="templates", help='path of the mesh in BIWI/FLAME topology')

[old app.py lines 374-380 are truncated in this diff view]

-    a2m = Audio2Mesh(opt)
-
-    print('link start!')
-    t = time.time()
-    # 02780
-    a2m.run(opt.face, opt.audio, 0)
-    print(time.time() - t)
+import os, sys
+import tempfile
+import gradio as gr
+from modules.text2speech import text2speech
+from modules.gfpgan_inference import gfpgan
+from modules.sadtalker_test import SadTalker
+
+def get_driven_audio(audio):
+    if os.path.isfile(audio):
+        return audio
+    else:
+        save_path = tempfile.NamedTemporaryFile(
+            delete=False,
+            suffix=("." + "wav"),
+        )
+        gen_audio = text2speech(audio, save_path.name)
+        return gen_audio, gen_audio
+
+def get_source_image(image):
+    return image
+
+def sadtalker_demo(result_dir):
+
+    sad_talker = SadTalker()
+    with gr.Blocks(analytics_enabled=False) as sadtalker_interface:
+        gr.Markdown("<div align='center'> <h2> 😭 SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation (CVPR 2023) </span> </h2> \
+                    <a style='font-size:18px;color: #efefef' href='https://arxiv.org/abs/2211.12194'>Arxiv</a> \
+                    <a style='font-size:18px;color: #efefef' href='https://sadtalker.github.io'>Homepage</a> \
+                    <a style='font-size:18px;color: #efefef' href='https://github.com/Winfredy/SadTalker'> Github </div>")
+
+        with gr.Row().style(equal_height=False):
+            with gr.Column(variant='panel'):
+                with gr.Tabs(elem_id="sadtalker_source_image"):
+                    with gr.TabItem('Upload image'):
+                        with gr.Row():
+                            source_image = gr.Image(label="Source image", source="upload", type="filepath").style(height=256, width=256)
+
+                with gr.Tabs(elem_id="sadtalker_driven_audio"):
+                    with gr.TabItem('Upload audio'):
+                        with gr.Column(variant='panel'):
+                            driven_audio = gr.Audio(label="Input audio", source="upload", type="filepath")
+                            # submit_audio_1 = gr.Button('Submit', variant='primary')
+                            # submit_audio_1.click(fn=get_driven_audio, inputs=input_audio1, outputs=driven_audio)
+
+            with gr.Column(variant='panel'):
+                with gr.Tabs(elem_id="sadtalker_checkbox"):
+                    with gr.TabItem('Settings'):
+                        with gr.Column(variant='panel'):
+                            is_still_mode = gr.Checkbox(label="w/ Still Mode (fewer hand motion)")
+                            enhancer = gr.Checkbox(label="w/ GFPGAN as Face enhancer")
+                            submit = gr.Button('Generate', elem_id="sadtalker_generate", variant='primary')
+
+                with gr.Tabs(elem_id="sadtalker_genearted"):
+                    gen_video = gr.Video(label="Generated video", format="mp4").style(height=256, width=256)
+                    gen_text = gr.Textbox(visible=False)
+
+        with gr.Row():
+            examples = [
+                [
+                    'examples/source_image/art_10.png',
+                    'examples/driven_audio/deyu.wav',
+                    True,
+                    False
+                ],
+                [
+                    'examples/source_image/art_1.png',
+                    'examples/driven_audio/chinese_poem1.wav',
+                    True,
+                    False
+                ],
+                [
+                    'examples/source_image/art_13.png',
+                    'examples/driven_audio/fayu.wav',
+                    True,
+                    False
+                ],
+                [
+                    'examples/source_image/art_5.png',
+                    'examples/driven_audio/chinese_news.wav',
+                    True,
+                    False
+                ],
+            ]
+            gr.Examples(examples=examples,
+                        inputs=[
+                            source_image,
+                            driven_audio,
+                            is_still_mode,
+                            enhancer,
+                            gr.Textbox(value=result_dir, visible=False)],
+                        outputs=[gen_video, gen_text],
+                        fn=sad_talker.test,
+                        cache_examples=os.getenv('SYSTEM') == 'spaces')
+
+        submit.click(
+            fn=sad_talker.test,
+            inputs=[source_image,
+                    driven_audio,
+                    is_still_mode,
+                    enhancer,
+                    gr.Textbox(value=result_dir, visible=False)],
+            outputs=[gen_video, gen_text]
+        )
+
+    return sadtalker_interface
+
+
+if __name__ == "__main__":
+
+    current_code_path = sys.argv[0]
+    current_root_dir = os.path.split(current_code_path)[0]
+    sadtalker_result_dir = os.path.join(current_root_dir, 'results', 'sadtalker')
+    demo = sadtalker_demo(sadtalker_result_dir)
+    demo.launch()
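To try the updated app outside of the Space, a minimal launch sketch might look like the following. This is an assumption-laden sketch, not part of the commit: it presumes the new file above is saved as app.py next to the SadTalker repo's modules/ package and examples/ assets, and it uses the standard Gradio Blocks queue() and launch() calls; the port and concurrency values are arbitrary.

# Minimal local-run sketch for the updated Gradio demo (assumptions noted above).
import os
from app import sadtalker_demo  # hypothetical import; assumes the added file is saved as app.py

# Build the interface with the same results/sadtalker directory layout used in __main__ above.
demo = sadtalker_demo(os.path.join('results', 'sadtalker'))

# Serialize GPU-bound generation requests, then start a local server.
demo.queue(concurrency_count=1)
demo.launch(server_name='0.0.0.0', server_port=7860)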