jhj0517 committed
Commit: 45d5794
Parent(s): e5db983

raise error
modules/live_portrait/live_portrait_inferencer.py
CHANGED
@@ -252,76 +252,79 @@ class LivePortraitInferencer:
             model_type=model_type
         )
 
+        try:
+            vid_info = get_video_info(vid_input=driving_vid_path)
 
+            if src_image is not None:
+                if id(src_image) != id(self.src_image) or self.crop_factor != crop_factor:
+                    self.crop_factor = crop_factor
+                    self.src_image = src_image
 
+                    self.psi_list = [self.prepare_source(src_image, crop_factor)]
 
+            progress(0, desc="Extracting frames from the video..")
+            driving_images, vid_sound = extract_frames(driving_vid_path, os.path.join(self.output_dir, "temp", "video_frames")), extract_sound(driving_vid_path)
 
+            driving_length = 0
+            if driving_images is not None:
+                if id(driving_images) != id(self.driving_images):
+                    self.driving_images = driving_images
+                    self.driving_values = self.prepare_driving_video(driving_images)
+                driving_length = len(self.driving_values)
 
+            total_length = len(driving_images)
 
+            c_i_es = ExpressionSet()
+            c_o_es = ExpressionSet()
+            d_0_es = None
 
+            psi = None
+            with torch.autocast(device_type=self.device, enabled=(self.device == "cuda")):
+                for i in range(total_length):
 
+                    if i == 0:
+                        psi = self.psi_list[i]
+                        s_info = psi.x_s_info
+                        s_es = ExpressionSet(erst=(s_info['kp'] + s_info['exp'], torch.Tensor([0, 0, 0]), s_info['scale'], s_info['t']))
 
+                    new_es = ExpressionSet(es=s_es)
 
+                    if i < driving_length:
+                        d_i_info = self.driving_values[i]
+                        d_i_r = torch.Tensor([d_i_info['pitch'], d_i_info['yaw'], d_i_info['roll']])  # .float().to(device="cuda:0")
 
+                        if d_0_es is None:
+                            d_0_es = ExpressionSet(erst=(d_i_info['exp'], d_i_r, d_i_info['scale'], d_i_info['t']))
 
+                            self.retargeting(s_es.e, d_0_es.e, retargeting_eyes, (11, 13, 15, 16))
+                            self.retargeting(s_es.e, d_0_es.e, retargeting_mouth, (14, 17, 19, 20))
 
+                        new_es.e += d_i_info['exp'] - d_0_es.e
+                        new_es.r += d_i_r - d_0_es.r
+                        new_es.t += d_i_info['t'] - d_0_es.t
 
+                    r_new = get_rotation_matrix(
+                        s_info['pitch'] + new_es.r[0], s_info['yaw'] + new_es.r[1], s_info['roll'] + new_es.r[2])
+                    d_new = new_es.s * (new_es.e @ r_new) + new_es.t
+                    d_new = self.pipeline.stitching(psi.x_s_user, d_new)
+                    crop_out = self.pipeline.warp_decode(psi.f_s_user, psi.x_s_user, d_new)
+                    crop_out = self.pipeline.parse_output(crop_out['out'])[0]
 
+                    crop_with_fullsize = cv2.warpAffine(crop_out, psi.crop_trans_m, get_rgb_size(psi.src_rgb),
+                                                        cv2.INTER_LINEAR)
+                    out = np.clip(psi.mask_ori * crop_with_fullsize + (1 - psi.mask_ori) * psi.src_rgb, 0, 255).astype(
+                        np.uint8)
 
+                    out_frame_path = get_auto_incremental_file_path(os.path.join(self.output_dir, "temp", "video_frames", "out"), "png")
+                    save_image(out, out_frame_path)
 
+                    progress(i/total_length, desc=f"Generating frames {i}/{total_length} ..")
 
+            video_path = create_video_from_frames(TEMP_VIDEO_OUT_FRAMES_DIR, frame_rate=vid_info.frame_rate, output_dir=os.path.join(self.output_dir, "videos"))
 
+            return video_path
+        except Exception as e:
+            raise
 
     def download_if_no_models(self,
                               model_type: str = ModelType.HUMAN.value,
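
The substantive change, matching the commit message, is the try/except wrapper: the whole frame-generation body now runs inside a try, and the handler does a bare raise, so any failure propagates to the caller with its original traceback while leaving one obvious place to hang cleanup or logging later. A minimal sketch of the same pattern, using hypothetical stand-in functions (generate_frames, render_video) rather than the module's real method:

import traceback

def generate_frames():
    # Hypothetical stand-in for the wrapped frame-generation body.
    raise RuntimeError("simulated failure during warp/decode")

def render_video():
    try:
        return generate_frames()
    except Exception:
        # A bare `raise` re-raises the active exception unchanged, so the
        # traceback still points at the original failure site.
        raise

try:
    render_video()
except RuntimeError:
    traceback.print_exc()  # traceback ends inside generate_frames()

Note that the diff binds the exception (except Exception as e:) even though e is never used; a plain except Exception: would behave identically here.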
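
For context on the wrapped loop itself: it animates by relative offsets, so each output frame's expression, rotation, and translation are the source's values plus the delta between driving frame i and the first driving frame (the new_es.e += d_i_info['exp'] - d_0_es.e lines and their siblings). A minimal NumPy sketch of that offsetting scheme, with all names hypothetical:

import numpy as np

# Hypothetical per-frame motion parameters: e = expression, r = rotation
# (pitch/yaw/roll), t = translation, as extracted from each driving frame.
rng = np.random.default_rng(0)
source = {"e": np.zeros(3), "r": np.zeros(3), "t": np.zeros(3)}
driving = [{"e": rng.random(3), "r": rng.random(3), "t": rng.random(3)}
           for _ in range(5)]

d0 = driving[0]  # the first driving frame acts as the neutral reference
frames = []
for d_i in driving:
    # Apply only the change relative to the reference frame, so the source's
    # own pose and expression remain the baseline for the whole clip.
    frames.append({k: source[k] + (d_i[k] - d0[k]) for k in ("e", "r", "t")})

# Frame 0 equals the source exactly, since d_i - d0 is zero there.
assert all(np.allclose(frames[0][k], source[k]) for k in ("e", "r", "t"))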