Severian committed on
Commit 23aca6c · verified · 1 Parent(s): 21e7629

Update app.py

Files changed (1)
  1. app.py +659 -242
app.py CHANGED
@@ -1,249 +1,666 @@
1
  # coding: utf-8
2
 
3
  """
4
- The entrance of the gradio for animal
5
  """
6
 
 
7
  import os
8
- import tyro
9
- import subprocess
10
  import gradio as gr
11
- import os.path as osp
12
- from src.utils.helper import load_description
13
- from src.gradio_pipeline import GradioPipelineAnimal
14
- from src.config.crop_config import CropConfig
15
- from src.config.argument_config import ArgumentConfig
16
- from src.config.inference_config import InferenceConfig
17
-
18
-
19
- def partial_fields(target_class, kwargs):
20
- return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
21
-
22
-
23
- def fast_check_ffmpeg():
24
- try:
25
- subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
26
- return True
27
- except:
28
- return False
29
-
30
-
31
- # set tyro theme
32
- tyro.extras.set_accent_color("bright_cyan")
33
- args = tyro.cli(ArgumentConfig)
34
-
35
- ffmpeg_dir = os.path.join(os.getcwd(), "ffmpeg")
36
- if osp.exists(ffmpeg_dir):
37
- os.environ["PATH"] += (os.pathsep + ffmpeg_dir)
38
-
39
- if not fast_check_ffmpeg():
40
- raise ImportError(
41
- "FFmpeg is not installed. Please install FFmpeg (including ffmpeg and ffprobe) before running this script. https://ffmpeg.org/download.html"
42
- )
43
- # specify configs for inference
44
- inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig
45
- crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig
46
-
47
- gradio_pipeline_animal: GradioPipelineAnimal = GradioPipelineAnimal(
48
- inference_cfg=inference_cfg,
49
- crop_cfg=crop_cfg,
50
- args=args
51
- )
52
-
53
- if args.gradio_temp_dir not in (None, ''):
54
- os.environ["GRADIO_TEMP_DIR"] = args.gradio_temp_dir
55
- os.makedirs(args.gradio_temp_dir, exist_ok=True)
56
-
57
- def gpu_wrapped_execute_video(*args, **kwargs):
58
- return gradio_pipeline_animal.execute_video(*args, **kwargs)
59
-
60
-
61
- # assets
62
- title_md = "assets/gradio/gradio_title.md"
63
- example_portrait_dir = "assets/examples/source"
64
- example_video_dir = "assets/examples/driving"
65
- data_examples_i2v = [
66
- [osp.join(example_portrait_dir, "s41.jpg"), osp.join(example_video_dir, "d3.mp4"), True, False, False, False],
67
- [osp.join(example_portrait_dir, "s40.jpg"), osp.join(example_video_dir, "d6.mp4"), True, False, False, False],
68
- [osp.join(example_portrait_dir, "s25.jpg"), osp.join(example_video_dir, "d19.mp4"), True, False, False, False],
69
- ]
70
- data_examples_i2v_pickle = [
71
- [osp.join(example_portrait_dir, "s25.jpg"), osp.join(example_video_dir, "wink.pkl"), True, False, False, False],
72
- [osp.join(example_portrait_dir, "s40.jpg"), osp.join(example_video_dir, "talking.pkl"), True, False, False, False],
73
- [osp.join(example_portrait_dir, "s41.jpg"), osp.join(example_video_dir, "aggrieved.pkl"), True, False, False, False],
74
- ]
75
- #################### interface logic ####################
76
-
77
- # Define components first
78
- output_image = gr.Image(type="numpy")
79
- output_image_paste_back = gr.Image(type="numpy")
80
- output_video_i2v = gr.Video(autoplay=False)
81
- output_video_concat_i2v = gr.Video(autoplay=False)
82
- output_video_i2v_gif = gr.Image(type="numpy")
83
-
84
-
85
- with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta Sans")])) as demo:
86
- gr.HTML(load_description(title_md))
87
-
88
- gr.Markdown(load_description("assets/gradio/gradio_description_upload_animal.md"))
89
- with gr.Row():
90
- with gr.Column():
91
- with gr.Accordion(open=True, label="🐱 Source Animal Image"):
92
- source_image_input = gr.Image(type="filepath")
93
- gr.Examples(
94
- examples=[
95
- [osp.join(example_portrait_dir, "s25.jpg")],
96
- [osp.join(example_portrait_dir, "s30.jpg")],
97
- [osp.join(example_portrait_dir, "s31.jpg")],
98
- [osp.join(example_portrait_dir, "s32.jpg")],
99
- [osp.join(example_portrait_dir, "s33.jpg")],
100
- [osp.join(example_portrait_dir, "s39.jpg")],
101
- [osp.join(example_portrait_dir, "s40.jpg")],
102
- [osp.join(example_portrait_dir, "s41.jpg")],
103
- [osp.join(example_portrait_dir, "s38.jpg")],
104
- [osp.join(example_portrait_dir, "s36.jpg")],
105
- ],
106
- inputs=[source_image_input],
107
- cache_examples=False,
108
- )
109
-
110
- with gr.Accordion(open=True, label="Cropping Options for Source Image"):
111
- with gr.Row():
112
- flag_do_crop_input = gr.Checkbox(value=True, label="do crop (source)")
113
- scale = gr.Number(value=2.3, label="source crop scale", minimum=1.8, maximum=3.2, step=0.05)
114
- vx_ratio = gr.Number(value=0.0, label="source crop x", minimum=-0.5, maximum=0.5, step=0.01)
115
- vy_ratio = gr.Number(value=-0.125, label="source crop y", minimum=-0.5, maximum=0.5, step=0.01)
116
-
117
- with gr.Column():
118
- with gr.Tabs():
119
- with gr.TabItem("📁 Driving Pickle") as tab_pickle:
120
- with gr.Accordion(open=True, label="Driving Pickle"):
121
- driving_video_pickle_input = gr.File()
122
- gr.Examples(
123
- examples=[
124
- [osp.join(example_video_dir, "wink.pkl")],
125
- [osp.join(example_video_dir, "shy.pkl")],
126
- [osp.join(example_video_dir, "aggrieved.pkl")],
127
- [osp.join(example_video_dir, "open_lip.pkl")],
128
- [osp.join(example_video_dir, "laugh.pkl")],
129
- [osp.join(example_video_dir, "talking.pkl")],
130
- [osp.join(example_video_dir, "shake_face.pkl")],
131
- ],
132
- inputs=[driving_video_pickle_input],
133
- cache_examples=False,
134
- )
135
- with gr.TabItem("🎞️ Driving Video") as tab_video:
136
- with gr.Accordion(open=True, label="Driving Video"):
137
- driving_video_input = gr.Video()
138
- gr.Examples(
139
- examples=[
140
- # [osp.join(example_video_dir, "d0.mp4")],
141
- # [osp.join(example_video_dir, "d18.mp4")],
142
- [osp.join(example_video_dir, "d19.mp4")],
143
- [osp.join(example_video_dir, "d14.mp4")],
144
- [osp.join(example_video_dir, "d6.mp4")],
145
- [osp.join(example_video_dir, "d3.mp4")],
146
- ],
147
- inputs=[driving_video_input],
148
- cache_examples=False,
149
- )
150
-
151
- tab_selection = gr.Textbox(visible=False)
152
- tab_pickle.select(lambda: "Pickle", None, tab_selection)
153
- tab_video.select(lambda: "Video", None, tab_selection)
154
- with gr.Accordion(open=True, label="Cropping Options for Driving Video"):
155
- with gr.Row():
156
- flag_crop_driving_video_input = gr.Checkbox(value=False, label="do crop (driving)")
157
- scale_crop_driving_video = gr.Number(value=2.2, label="driving crop scale", minimum=1.8, maximum=3.2, step=0.05)
158
- vx_ratio_crop_driving_video = gr.Number(value=0.0, label="driving crop x", minimum=-0.5, maximum=0.5, step=0.01)
159
- vy_ratio_crop_driving_video = gr.Number(value=-0.1, label="driving crop y", minimum=-0.5, maximum=0.5, step=0.01)
160
-
161
- with gr.Row():
162
- with gr.Accordion(open=False, label="Animation Options"):
163
- with gr.Row():
164
- flag_stitching = gr.Checkbox(value=False, label="stitching (not recommended)")
165
- flag_remap_input = gr.Checkbox(value=False, label="paste-back (not recommended)")
166
- driving_multiplier = gr.Number(value=1.0, label="driving multiplier", minimum=0.0, maximum=2.0, step=0.02)
167
-
168
- gr.Markdown(load_description("assets/gradio/gradio_description_animate_clear.md"))
169
- with gr.Row():
170
- process_button_animation = gr.Button("🚀 Animate", variant="primary")
171
- with gr.Row():
172
- with gr.Column():
173
- with gr.Accordion(open=True, label="The animated video in the cropped image space"):
174
- output_video_i2v.render()
175
- with gr.Column():
176
- with gr.Accordion(open=True, label="The animated gif in the cropped image space"):
177
- output_video_i2v_gif.render()
178
- with gr.Column():
179
- with gr.Accordion(open=True, label="The animated video"):
180
- output_video_concat_i2v.render()
181
- with gr.Row():
182
- process_button_reset = gr.ClearButton([source_image_input, driving_video_input, output_video_i2v, output_video_concat_i2v, output_video_i2v_gif], value="🧹 Clear")
183
-
184
- with gr.Row():
185
- # Examples
186
- gr.Markdown("## You could also choose the examples below by one click ⬇️")
187
- with gr.Row():
188
- with gr.Tabs():
189
- with gr.TabItem("📁 Driving Pickle") as tab_video:
190
- gr.Examples(
191
- examples=data_examples_i2v_pickle,
192
- fn=gpu_wrapped_execute_video,
193
- inputs=[
194
- source_image_input,
195
- driving_video_pickle_input,
196
- flag_do_crop_input,
197
- flag_stitching,
198
- flag_remap_input,
199
- flag_crop_driving_video_input,
200
- ],
201
- outputs=[output_image, output_image_paste_back, output_video_i2v_gif],
202
- examples_per_page=len(data_examples_i2v_pickle),
203
- cache_examples=False,
204
- )
205
- with gr.TabItem("🎞️ Driving Video") as tab_video:
206
- gr.Examples(
207
- examples=data_examples_i2v,
208
- fn=gpu_wrapped_execute_video,
209
- inputs=[
210
- source_image_input,
211
- driving_video_input,
212
- flag_do_crop_input,
213
- flag_stitching,
214
- flag_remap_input,
215
- flag_crop_driving_video_input,
216
- ],
217
- outputs=[output_image, output_image_paste_back, output_video_i2v_gif],
218
- examples_per_page=len(data_examples_i2v),
219
- cache_examples=False,
220
- )
221
-
222
- process_button_animation.click(
223
- fn=gpu_wrapped_execute_video,
224
- inputs=[
225
- source_image_input,
226
- driving_video_input,
227
- driving_video_pickle_input,
228
- flag_do_crop_input,
229
- flag_remap_input,
230
- driving_multiplier,
231
- flag_stitching,
232
- flag_crop_driving_video_input,
233
- scale,
234
- vx_ratio,
235
- vy_ratio,
236
- scale_crop_driving_video,
237
- vx_ratio_crop_driving_video,
238
- vy_ratio_crop_driving_video,
239
- tab_selection,
240
- ],
241
- outputs=[output_video_i2v, output_video_concat_i2v, output_video_i2v_gif],
242
- show_progress=True
243
- )
244
-
245
- demo.launch(
246
- server_port=args.server_port,
247
- share=args.share,
248
- server_name=args.server_name
249
- )
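The removed app.py above builds its config objects by filtering the parsed CLI arguments through partial_fields, which keeps only the keys that already exist as attributes on the target class. A minimal, self-contained sketch of that pattern follows; DemoConfig and the sample argument dict are hypothetical stand-ins for the project's InferenceConfig/CropConfig and the tyro-parsed args, and only partial_fields itself is copied from the removed lines.

from dataclasses import dataclass

@dataclass
class DemoConfig:
    # hypothetical stand-in for InferenceConfig / CropConfig
    scale: float = 2.3
    vx_ratio: float = 0.0

def partial_fields(target_class, kwargs):
    # same helper as in the removed app.py: drop keys the target class does not declare
    return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})

cli_args = {"scale": 2.0, "vx_ratio": 0.1, "server_port": 7860}  # e.g. vars(tyro.cli(ArgumentConfig))
cfg = partial_fields(DemoConfig, cli_args)                       # server_port is silently dropped
print(cfg)                                                       # DemoConfig(scale=2.0, vx_ratio=0.1)

Note that hasattr only sees fields that carry class-level defaults, so this filtering presumably relies on every field of the real config dataclasses having a default value.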
1
  # coding: utf-8
2
 
3
  """
4
+ Pipeline for gradio
5
  """
6
 
7
+ import os.path as osp
8
  import os
9
+ import cv2
10
+ from rich.progress import track
11
  import gradio as gr
12
+ import numpy as np
13
+ import torch
14
+
15
+ from .config.argument_config import ArgumentConfig
16
+ from .live_portrait_pipeline import LivePortraitPipeline
17
+ from .live_portrait_pipeline_animal import LivePortraitPipelineAnimal
18
+ from .utils.io import load_img_online, load_video, resize_to_limit
19
+ from .utils.filter import smooth
20
+ from .utils.rprint import rlog as log
21
+ from .utils.crop import prepare_paste_back, paste_back
22
+ from .utils.camera import get_rotation_matrix
23
+ from .utils.video import get_fps, has_audio_stream, concat_frames, images2video, add_audio_to_video
24
+ from .utils.helper import is_square_video, mkdir, dct2device, basename
25
+ from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
26
+
27
+
28
+ def update_args(args, user_args):
29
+ """update the args according to user inputs
30
+ """
31
+ for k, v in user_args.items():
32
+ if hasattr(args, k):
33
+ setattr(args, k, v)
34
+ return args
35
+
36
+
37
+ class GradioPipeline(LivePortraitPipeline):
38
+ """gradio for human
39
+ """
40
+
41
+ def __init__(self, inference_cfg, crop_cfg, args: ArgumentConfig):
42
+ super().__init__(inference_cfg, crop_cfg)
43
+ # self.live_portrait_wrapper = self.live_portrait_wrapper
44
+ self.args = args
45
+
46
+ @torch.no_grad()
47
+ def update_delta_new_eyeball_direction(self, eyeball_direction_x, eyeball_direction_y, delta_new, **kwargs):
48
+ if eyeball_direction_x > 0:
49
+ delta_new[0, 11, 0] += eyeball_direction_x * 0.0007
50
+ delta_new[0, 15, 0] += eyeball_direction_x * 0.001
51
+ else:
52
+ delta_new[0, 11, 0] += eyeball_direction_x * 0.001
53
+ delta_new[0, 15, 0] += eyeball_direction_x * 0.0007
54
+
55
+ delta_new[0, 11, 1] += eyeball_direction_y * -0.001
56
+ delta_new[0, 15, 1] += eyeball_direction_y * -0.001
57
+ blink = -eyeball_direction_y / 2.
58
+
59
+ delta_new[0, 11, 1] += blink * -0.001
60
+ delta_new[0, 13, 1] += blink * 0.0003
61
+ delta_new[0, 15, 1] += blink * -0.001
62
+ delta_new[0, 16, 1] += blink * 0.0003
63
+
64
+ return delta_new
65
+
66
+ @torch.no_grad()
67
+ def update_delta_new_smile(self, smile, delta_new, **kwargs):
68
+ delta_new[0, 20, 1] += smile * -0.01
69
+ delta_new[0, 14, 1] += smile * -0.02
70
+ delta_new[0, 17, 1] += smile * 0.0065
71
+ delta_new[0, 17, 2] += smile * 0.003
72
+ delta_new[0, 13, 1] += smile * -0.00275
73
+ delta_new[0, 16, 1] += smile * -0.00275
74
+ delta_new[0, 3, 1] += smile * -0.0035
75
+ delta_new[0, 7, 1] += smile * -0.0035
76
+
77
+ return delta_new
78
+
79
+ @torch.no_grad()
80
+ def update_delta_new_wink(self, wink, delta_new, **kwargs):
81
+ delta_new[0, 11, 1] += wink * 0.001
82
+ delta_new[0, 13, 1] += wink * -0.0003
83
+ delta_new[0, 17, 0] += wink * 0.0003
84
+ delta_new[0, 17, 1] += wink * 0.0003
85
+ delta_new[0, 3, 1] += wink * -0.0003
86
+
87
+ return delta_new
88
+
89
+ @torch.no_grad()
90
+ def update_delta_new_eyebrow(self, eyebrow, delta_new, **kwargs):
91
+ if eyebrow > 0:
92
+ delta_new[0, 1, 1] += eyebrow * 0.001
93
+ delta_new[0, 2, 1] += eyebrow * -0.001
94
+ else:
95
+ delta_new[0, 1, 0] += eyebrow * -0.001
96
+ delta_new[0, 2, 0] += eyebrow * 0.001
97
+ delta_new[0, 1, 1] += eyebrow * 0.0003
98
+ delta_new[0, 2, 1] += eyebrow * -0.0003
99
+ return delta_new
100
+
101
+ @torch.no_grad()
102
+ def update_delta_new_lip_variation_zero(self, lip_variation_zero, delta_new, **kwargs):
103
+ delta_new[0, 19, 0] += lip_variation_zero
104
+
105
+ return delta_new
106
+
107
+ @torch.no_grad()
108
+ def update_delta_new_lip_variation_one(self, lip_variation_one, delta_new, **kwargs):
109
+ delta_new[0, 14, 1] += lip_variation_one * 0.001
110
+ delta_new[0, 3, 1] += lip_variation_one * -0.0005
111
+ delta_new[0, 7, 1] += lip_variation_one * -0.0005
112
+ delta_new[0, 17, 2] += lip_variation_one * -0.0005
113
+
114
+ return delta_new
115
+
116
+ @torch.no_grad()
117
+ def update_delta_new_lip_variation_two(self, lip_variation_two, delta_new, **kwargs):
118
+ delta_new[0, 20, 2] += lip_variation_two * -0.001
119
+ delta_new[0, 20, 1] += lip_variation_two * -0.001
120
+ delta_new[0, 14, 1] += lip_variation_two * -0.001
121
+
122
+ return delta_new
123
+
124
+ @torch.no_grad()
125
+ def update_delta_new_lip_variation_three(self, lip_variation_three, delta_new, **kwargs):
126
+ delta_new[0, 19, 1] += lip_variation_three * 0.001
127
+ delta_new[0, 19, 2] += lip_variation_three * 0.0001
128
+ delta_new[0, 17, 1] += lip_variation_three * -0.0001
129
+
130
+ return delta_new
131
+
132
+ @torch.no_grad()
133
+ def update_delta_new_mov_x(self, mov_x, delta_new, **kwargs):
134
+ delta_new[0, 5, 0] += mov_x
135
+
136
+ return delta_new
137
+
138
+ @torch.no_grad()
139
+ def update_delta_new_mov_y(self, mov_y, delta_new, **kwargs):
140
+ delta_new[0, 5, 1] += mov_y
141
+
142
+ return delta_new
143
+
144
+ @torch.no_grad()
145
+ def execute_video(
146
+ self,
147
+ input_source_image_path=None,
148
+ input_source_video_path=None,
149
+ input_driving_video_path=None,
150
+ input_driving_image_path=None,
151
+ input_driving_video_pickle_path=None,
152
+ flag_normalize_lip=False,
153
+ flag_relative_input=True,
154
+ flag_do_crop_input=True,
155
+ flag_remap_input=True,
156
+ flag_stitching_input=True,
157
+ animation_region="all",
158
+ driving_option_input="pose-friendly",
159
+ driving_multiplier=1.0,
160
+ flag_crop_driving_video_input=True,
161
+ # flag_video_editing_head_rotation=False,
162
+ scale=2.3,
163
+ vx_ratio=0.0,
164
+ vy_ratio=-0.125,
165
+ scale_crop_driving_video=2.2,
166
+ vx_ratio_crop_driving_video=0.0,
167
+ vy_ratio_crop_driving_video=-0.1,
168
+ driving_smooth_observation_variance=3e-7,
169
+ tab_selection=None,
170
+ v_tab_selection=None
171
+ ):
172
+ """ for video-driven portrait animation or video editing
173
+ """
174
+ if tab_selection == 'Image':
175
+ input_source_path = input_source_image_path
176
+ elif tab_selection == 'Video':
177
+ input_source_path = input_source_video_path
178
+ else:
179
+ input_source_path = input_source_image_path
180
+
181
+ if v_tab_selection == 'Video':
182
+ input_driving_path = input_driving_video_path
183
+ elif v_tab_selection == 'Image':
184
+ input_driving_path = input_driving_image_path
185
+ elif v_tab_selection == 'Pickle':
186
+ input_driving_path = input_driving_video_pickle_path
187
+ else:
188
+ input_driving_path = input_driving_video_path
189
+
190
+ if input_source_path is not None and input_driving_path is not None:
191
+ if osp.exists(input_driving_path) and v_tab_selection == 'Video' and not flag_crop_driving_video_input and is_square_video(input_driving_path) is False:
192
+ flag_crop_driving_video_input = True
193
+ log("The driving video is not square, it will be cropped to square automatically.")
194
+ gr.Info("The driving video is not square, it will be cropped to square automatically.", duration=2)
195
+
196
+ args_user = {
197
+ 'source': input_source_path,
198
+ 'driving': input_driving_path,
199
+ 'flag_normalize_lip' : flag_normalize_lip,
200
+ 'flag_relative_motion': flag_relative_input,
201
+ 'flag_do_crop': flag_do_crop_input,
202
+ 'flag_pasteback': flag_remap_input,
203
+ 'flag_stitching': flag_stitching_input,
204
+ 'animation_region': animation_region,
205
+ 'driving_option': driving_option_input,
206
+ 'driving_multiplier': driving_multiplier,
207
+ 'flag_crop_driving_video': flag_crop_driving_video_input,
208
+ 'scale': scale,
209
+ 'vx_ratio': vx_ratio,
210
+ 'vy_ratio': vy_ratio,
211
+ 'scale_crop_driving_video': scale_crop_driving_video,
212
+ 'vx_ratio_crop_driving_video': vx_ratio_crop_driving_video,
213
+ 'vy_ratio_crop_driving_video': vy_ratio_crop_driving_video,
214
+ 'driving_smooth_observation_variance': driving_smooth_observation_variance,
215
+ }
216
+ # update config from user input
217
+ self.args = update_args(self.args, args_user)
218
+ self.live_portrait_wrapper.update_config(self.args.__dict__)
219
+ self.cropper.update_config(self.args.__dict__)
220
+
221
+ output_path, output_path_concat = self.execute(self.args)
222
+ gr.Info("Run successfully!", duration=2)
223
+ if output_path.endswith(".jpg"):
224
+ return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), output_path, gr.update(visible=True), output_path_concat, gr.update(visible=True)
225
+ else:
226
+ return output_path, gr.update(visible=True), output_path_concat, gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
227
+ else:
228
+ raise gr.Error("Please upload the source portrait or source video, and driving video 🤗🤗🤗", duration=5)
229
+
230
+ @torch.no_grad()
231
+ def execute_image_retargeting(
232
+ self,
233
+ input_eye_ratio: float,
234
+ input_lip_ratio: float,
235
+ input_head_pitch_variation: float,
236
+ input_head_yaw_variation: float,
237
+ input_head_roll_variation: float,
238
+ mov_x: float,
239
+ mov_y: float,
240
+ mov_z: float,
241
+ lip_variation_zero: float,
242
+ lip_variation_one: float,
243
+ lip_variation_two: float,
244
+ lip_variation_three: float,
245
+ smile: float,
246
+ wink: float,
247
+ eyebrow: float,
248
+ eyeball_direction_x: float,
249
+ eyeball_direction_y: float,
250
+ input_image,
251
+ retargeting_source_scale: float,
252
+ flag_stitching_retargeting_input=True,
253
+ flag_do_crop_input_retargeting_image=True):
254
+ """ for single image retargeting
255
+ """
256
+ if input_head_pitch_variation is None or input_head_yaw_variation is None or input_head_roll_variation is None:
257
+ raise gr.Error("Invalid relative pose input 💥!", duration=5)
258
+ # disposable feature
259
+ f_s_user, x_s_user, R_s_user, R_d_user, x_s_info, source_lmk_user, crop_M_c2o, mask_ori, img_rgb = \
260
+ self.prepare_retargeting_image(
261
+ input_image, input_head_pitch_variation, input_head_yaw_variation, input_head_roll_variation, retargeting_source_scale, flag_do_crop=flag_do_crop_input_retargeting_image)
262
+
263
+ if input_eye_ratio is None or input_lip_ratio is None:
264
+ raise gr.Error("Invalid ratio input 💥!", duration=5)
265
+ else:
266
+ device = self.live_portrait_wrapper.device
267
+ # inference_cfg = self.live_portrait_wrapper.inference_cfg
268
+ x_s_user = x_s_user.to(device)
269
+ f_s_user = f_s_user.to(device)
270
+ R_s_user = R_s_user.to(device)
271
+ R_d_user = R_d_user.to(device)
272
+ mov_x = torch.tensor(mov_x).to(device)
273
+ mov_y = torch.tensor(mov_y).to(device)
274
+ mov_z = torch.tensor(mov_z).to(device)
275
+ eyeball_direction_x = torch.tensor(eyeball_direction_x).to(device)
276
+ eyeball_direction_y = torch.tensor(eyeball_direction_y).to(device)
277
+ smile = torch.tensor(smile).to(device)
278
+ wink = torch.tensor(wink).to(device)
279
+ eyebrow = torch.tensor(eyebrow).to(device)
280
+ lip_variation_zero = torch.tensor(lip_variation_zero).to(device)
281
+ lip_variation_one = torch.tensor(lip_variation_one).to(device)
282
+ lip_variation_two = torch.tensor(lip_variation_two).to(device)
283
+ lip_variation_three = torch.tensor(lip_variation_three).to(device)
284
+
285
+ x_c_s = x_s_info['kp'].to(device)
286
+ delta_new = x_s_info['exp'].to(device)
287
+ scale_new = x_s_info['scale'].to(device)
288
+ t_new = x_s_info['t'].to(device)
289
+ R_d_new = (R_d_user @ R_s_user.permute(0, 2, 1)) @ R_s_user
290
+
291
+ if eyeball_direction_x != 0 or eyeball_direction_y != 0:
292
+ delta_new = self.update_delta_new_eyeball_direction(eyeball_direction_x, eyeball_direction_y, delta_new)
293
+ if smile != 0:
294
+ delta_new = self.update_delta_new_smile(smile, delta_new)
295
+ if wink != 0:
296
+ delta_new = self.update_delta_new_wink(wink, delta_new)
297
+ if eyebrow != 0:
298
+ delta_new = self.update_delta_new_eyebrow(eyebrow, delta_new)
299
+ if lip_variation_zero != 0:
300
+ delta_new = self.update_delta_new_lip_variation_zero(lip_variation_zero, delta_new)
301
+ if lip_variation_one != 0:
302
+ delta_new = self.update_delta_new_lip_variation_one(lip_variation_one, delta_new)
303
+ if lip_variation_two != 0:
304
+ delta_new = self.update_delta_new_lip_variation_two(lip_variation_two, delta_new)
305
+ if lip_variation_three != 0:
306
+ delta_new = self.update_delta_new_lip_variation_three(lip_variation_three, delta_new)
307
+ if mov_x != 0:
308
+ delta_new = self.update_delta_new_mov_x(-mov_x, delta_new)
309
+ if mov_y !=0 :
310
+ delta_new = self.update_delta_new_mov_y(mov_y, delta_new)
311
+
312
+ x_d_new = mov_z * scale_new * (x_c_s @ R_d_new + delta_new) + t_new
313
+ eyes_delta, lip_delta = None, None
314
+ if input_eye_ratio != self.source_eye_ratio:
315
+ combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[float(input_eye_ratio)]], source_lmk_user)
316
+ eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s_user, combined_eye_ratio_tensor)
317
+ if input_lip_ratio != self.source_lip_ratio:
318
+ combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[float(input_lip_ratio)]], source_lmk_user)
319
+ lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor)
320
+ print(lip_delta)
321
+ x_d_new = x_d_new + \
322
+ (eyes_delta if eyes_delta is not None else 0) + \
323
+ (lip_delta if lip_delta is not None else 0)
324
+
325
+ if flag_stitching_retargeting_input:
326
+ x_d_new = self.live_portrait_wrapper.stitching(x_s_user, x_d_new)
327
+ out = self.live_portrait_wrapper.warp_decode(f_s_user, x_s_user, x_d_new)
328
+ out = self.live_portrait_wrapper.parse_output(out['out'])[0]
329
+ if flag_do_crop_input_retargeting_image:
330
+ out_to_ori_blend = paste_back(out, crop_M_c2o, img_rgb, mask_ori)
331
+ else:
332
+ out_to_ori_blend = out
333
+ return out, out_to_ori_blend
334
+
335
+ @torch.no_grad()
336
+ def prepare_retargeting_image(
337
+ self,
338
+ input_image,
339
+ input_head_pitch_variation, input_head_yaw_variation, input_head_roll_variation,
340
+ retargeting_source_scale,
341
+ flag_do_crop=True):
342
+ """ for single image retargeting
343
+ """
344
+ if input_image is not None:
345
+ # gr.Info("Upload successfully!", duration=2)
346
+ args_user = {'scale': retargeting_source_scale}
347
+ self.args = update_args(self.args, args_user)
348
+ self.cropper.update_config(self.args.__dict__)
349
+ inference_cfg = self.live_portrait_wrapper.inference_cfg
350
+ ######## process source portrait ########
351
+ img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=2)
352
+ if flag_do_crop:
353
+ crop_info = self.cropper.crop_source_image(img_rgb, self.cropper.crop_cfg)
354
+ I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256'])
355
+ source_lmk_user = crop_info['lmk_crop']
356
+ crop_M_c2o = crop_info['M_c2o']
357
+ mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
358
+ else:
359
+ I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
360
+ source_lmk_user = self.cropper.calc_lmk_from_cropped_image(img_rgb)
361
+ crop_M_c2o = None
362
+ mask_ori = None
363
+ x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
364
+ x_d_info_user_pitch = x_s_info['pitch'] + input_head_pitch_variation
365
+ x_d_info_user_yaw = x_s_info['yaw'] + input_head_yaw_variation
366
+ x_d_info_user_roll = x_s_info['roll'] + input_head_roll_variation
367
+ R_s_user = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
368
+ R_d_user = get_rotation_matrix(x_d_info_user_pitch, x_d_info_user_yaw, x_d_info_user_roll)
369
+ ############################################
370
+ f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
371
+ x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
372
+ return f_s_user, x_s_user, R_s_user, R_d_user, x_s_info, source_lmk_user, crop_M_c2o, mask_ori, img_rgb
373
+ else:
374
+ raise gr.Error("Please upload a source portrait as the retargeting input 🤗🤗🤗", duration=5)
375
+
376
+ @torch.no_grad()
377
+ def init_retargeting_image(self, retargeting_source_scale: float, source_eye_ratio: float, source_lip_ratio:float, input_image = None):
378
+ """ initialize the retargeting slider
379
+ """
380
+ if input_image != None:
381
+ args_user = {'scale': retargeting_source_scale}
382
+ self.args = update_args(self.args, args_user)
383
+ self.cropper.update_config(self.args.__dict__)
384
+ # inference_cfg = self.live_portrait_wrapper.inference_cfg
385
+ ######## process source portrait ########
386
+ img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=16)
387
+ log(f"Load source image from {input_image}.")
388
+ crop_info = self.cropper.crop_source_image(img_rgb, self.cropper.crop_cfg)
389
+ if crop_info is None:
390
+ raise gr.Error("Source portrait NO face detected", duration=2)
391
+ source_eye_ratio = calc_eye_close_ratio(crop_info['lmk_crop'][None])
392
+ source_lip_ratio = calc_lip_close_ratio(crop_info['lmk_crop'][None])
393
+ self.source_eye_ratio = round(float(source_eye_ratio.mean()), 2)
394
+ self.source_lip_ratio = round(float(source_lip_ratio[0][0]), 2)
395
+ log("Calculating eyes-open and lip-open ratios successfully!")
396
+ return self.source_eye_ratio, self.source_lip_ratio
397
+ else:
398
+ return source_eye_ratio, source_lip_ratio
399
+
400
+ @torch.no_grad()
401
+ def execute_video_retargeting(self, input_lip_ratio: float, input_video, retargeting_source_scale: float, driving_smooth_observation_variance_retargeting: float, video_retargeting_silence=False, flag_do_crop_input_retargeting_video=True):
402
+ """ retargeting the lip-open ratio of each source frame
403
+ """
404
+ # disposable feature
405
+ device = self.live_portrait_wrapper.device
406
+
407
+ if not video_retargeting_silence:
408
+ f_s_user_lst, x_s_user_lst, source_lmk_crop_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, lip_delta_retargeting_lst_smooth, source_fps, n_frames = \
409
+ self.prepare_retargeting_video(input_video, retargeting_source_scale, device, input_lip_ratio, driving_smooth_observation_variance_retargeting, flag_do_crop=flag_do_crop_input_retargeting_video)
410
+ if input_lip_ratio is None:
411
+ raise gr.Error("Invalid ratio input 💥!", duration=5)
412
+ else:
413
+ inference_cfg = self.live_portrait_wrapper.inference_cfg
414
+
415
+ I_p_pstbk_lst = None
416
+ if flag_do_crop_input_retargeting_video:
417
+ I_p_pstbk_lst = []
418
+ I_p_lst = []
419
+ for i in track(range(n_frames), description='Retargeting video...', total=n_frames):
420
+ x_s_user_i = x_s_user_lst[i].to(device)
421
+ f_s_user_i = f_s_user_lst[i].to(device)
422
+
423
+ lip_delta_retargeting = lip_delta_retargeting_lst_smooth[i]
424
+ x_d_i_new = x_s_user_i + lip_delta_retargeting
425
+ x_d_i_new = self.live_portrait_wrapper.stitching(x_s_user_i, x_d_i_new)
426
+ out = self.live_portrait_wrapper.warp_decode(f_s_user_i, x_s_user_i, x_d_i_new)
427
+ I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
428
+ I_p_lst.append(I_p_i)
429
+
430
+ if flag_do_crop_input_retargeting_video:
431
+ I_p_pstbk = paste_back(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], mask_ori_lst[i])
432
+ I_p_pstbk_lst.append(I_p_pstbk)
433
+ else:
434
+ inference_cfg = self.live_portrait_wrapper.inference_cfg
435
+ f_s_user_lst, x_s_user_lst, x_d_i_new_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, source_fps, n_frames = \
436
+ self.prepare_video_lip_silence(input_video, device, flag_do_crop=flag_do_crop_input_retargeting_video)
437
+
438
+ I_p_pstbk_lst = None
439
+ if flag_do_crop_input_retargeting_video:
440
+ I_p_pstbk_lst = []
441
+ I_p_lst = []
442
+ for i in track(range(n_frames), description='Silencing lip...', total=n_frames):
443
+ x_s_user_i = x_s_user_lst[i].to(device)
444
+ f_s_user_i = f_s_user_lst[i].to(device)
445
+ x_d_i_new = x_d_i_new_lst[i]
446
+ x_d_i_new = self.live_portrait_wrapper.stitching(x_s_user_i, x_d_i_new)
447
+ out = self.live_portrait_wrapper.warp_decode(f_s_user_i, x_s_user_i, x_d_i_new)
448
+ I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
449
+ I_p_lst.append(I_p_i)
450
+
451
+ if flag_do_crop_input_retargeting_video:
452
+ I_p_pstbk = paste_back(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], mask_ori_lst[i])
453
+ I_p_pstbk_lst.append(I_p_pstbk)
454
+
455
+ mkdir(self.args.output_dir)
456
+ flag_source_has_audio = has_audio_stream(input_video)
457
+
458
+ ######### build the final concatenation result #########
459
+ # source frame | generation
460
+ frames_concatenated = concat_frames(driving_image_lst=None, source_image_lst=img_crop_256x256_lst, I_p_lst=I_p_lst)
461
+ wfp_concat = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_concat.mp4')
462
+ images2video(frames_concatenated, wfp=wfp_concat, fps=source_fps)
463
+
464
+ if flag_source_has_audio:
465
+ # final result with concatenation
466
+ wfp_concat_with_audio = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_concat_with_audio.mp4')
467
+ add_audio_to_video(wfp_concat, input_video, wfp_concat_with_audio)
468
+ os.replace(wfp_concat_with_audio, wfp_concat)
469
+ log(f"Replace {wfp_concat_with_audio} with {wfp_concat}")
470
+
471
+ # save the animated result
472
+ wfp = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting.mp4')
473
+ if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0:
474
+ images2video(I_p_pstbk_lst, wfp=wfp, fps=source_fps)
475
+ else:
476
+ images2video(I_p_lst, wfp=wfp, fps=source_fps)
477
+
478
+ ######### build the final result #########
479
+ if flag_source_has_audio:
480
+ wfp_with_audio = osp.join(self.args.output_dir, f'{basename(input_video)}_retargeting_with_audio.mp4')
481
+ add_audio_to_video(wfp, input_video, wfp_with_audio)
482
+ os.replace(wfp_with_audio, wfp)
483
+ log(f"Replace {wfp_with_audio} with {wfp}")
484
+ gr.Info("Run successfully!", duration=2)
485
+ return wfp_concat, wfp
486
+
487
+ @torch.no_grad()
488
+ def prepare_retargeting_video(self, input_video, retargeting_source_scale, device, input_lip_ratio, driving_smooth_observation_variance_retargeting, flag_do_crop=True):
489
+ """ for video retargeting
490
+ """
491
+ if input_video is not None:
492
+ # gr.Info("Upload successfully!", duration=2)
493
+ args_user = {'scale': retargeting_source_scale}
494
+ self.args = update_args(self.args, args_user)
495
+ self.cropper.update_config(self.args.__dict__)
496
+ inference_cfg = self.live_portrait_wrapper.inference_cfg
497
+ ######## process source video ########
498
+ source_rgb_lst = load_video(input_video)
499
+ source_rgb_lst = [resize_to_limit(img, inference_cfg.source_max_dim, inference_cfg.source_division) for img in source_rgb_lst]
500
+ source_fps = int(get_fps(input_video))
501
+ n_frames = len(source_rgb_lst)
502
+ log(f"Load source video from {input_video}. FPS is {source_fps}")
503
+
504
+ if flag_do_crop:
505
+ ret_s = self.cropper.crop_source_video(source_rgb_lst, self.cropper.crop_cfg)
506
+ log(f'Source video is cropped, {len(ret_s["frame_crop_lst"])} frames are processed.')
507
+ if len(ret_s["frame_crop_lst"]) != n_frames:
508
+ n_frames = min(len(source_rgb_lst), len(ret_s["frame_crop_lst"]))
509
+ img_crop_256x256_lst, source_lmk_crop_lst, source_M_c2o_lst = ret_s['frame_crop_lst'], ret_s['lmk_crop_lst'], ret_s['M_c2o_lst']
510
+ mask_ori_lst = [prepare_paste_back(inference_cfg.mask_crop, source_M_c2o, dsize=(source_rgb_lst[0].shape[1], source_rgb_lst[0].shape[0])) for source_M_c2o in source_M_c2o_lst]
511
+ else:
512
+ source_lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(source_rgb_lst)
513
+ img_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in source_rgb_lst] # force to resize to 256x256
514
+ source_M_c2o_lst, mask_ori_lst = None, None
515
+
516
+ c_s_eyes_lst, c_s_lip_lst = self.live_portrait_wrapper.calc_ratio(source_lmk_crop_lst)
517
+ # save the motion template
518
+ I_s_lst = self.live_portrait_wrapper.prepare_videos(img_crop_256x256_lst)
519
+ source_template_dct = self.make_motion_template(I_s_lst, c_s_eyes_lst, c_s_lip_lst, output_fps=source_fps)
520
+
521
+ c_d_lip_retargeting = [input_lip_ratio]
522
+ f_s_user_lst, x_s_user_lst, lip_delta_retargeting_lst = [], [], []
523
+ for i in track(range(n_frames), description='Preparing retargeting video...', total=n_frames):
524
+ x_s_info = source_template_dct['motion'][i]
525
+ x_s_info = dct2device(x_s_info, device)
526
+ x_s_user = x_s_info['x_s']
527
+
528
+ source_lmk = source_lmk_crop_lst[i]
529
+ img_crop_256x256 = img_crop_256x256_lst[i]
530
+ I_s = I_s_lst[i]
531
+ f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
532
+
533
+ combined_lip_ratio_tensor_retargeting = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_retargeting, source_lmk)
534
+ lip_delta_retargeting = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor_retargeting)
535
+ f_s_user_lst.append(f_s_user); x_s_user_lst.append(x_s_user); lip_delta_retargeting_lst.append(lip_delta_retargeting.cpu().numpy().astype(np.float32))
536
+ lip_delta_retargeting_lst_smooth = smooth(lip_delta_retargeting_lst, lip_delta_retargeting_lst[0].shape, device, driving_smooth_observation_variance_retargeting)
537
+
538
+ return f_s_user_lst, x_s_user_lst, source_lmk_crop_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, lip_delta_retargeting_lst_smooth, source_fps, n_frames
539
+ else:
540
+ # when press the clear button, go here
541
+ raise gr.Error("Please upload a source video as the retargeting input 🤗🤗🤗", duration=5)
542
+
543
+ @torch.no_grad()
544
+ def prepare_video_lip_silence(self, input_video, device, flag_do_crop=True):
545
+ """ for keeping lips in the source video silent
546
+ """
547
+ if input_video is not None:
548
+ inference_cfg = self.live_portrait_wrapper.inference_cfg
549
+ ######## process source video ########
550
+ source_rgb_lst = load_video(input_video)
551
+ source_rgb_lst = [resize_to_limit(img, inference_cfg.source_max_dim, inference_cfg.source_division) for img in source_rgb_lst]
552
+ source_fps = int(get_fps(input_video))
553
+ n_frames = len(source_rgb_lst)
554
+ log(f"Load source video from {input_video}. FPS is {source_fps}")
555
+
556
+ if flag_do_crop:
557
+ ret_s = self.cropper.crop_source_video(source_rgb_lst, self.cropper.crop_cfg)
558
+ log(f'Source video is cropped, {len(ret_s["frame_crop_lst"])} frames are processed.')
559
+ if len(ret_s["frame_crop_lst"]) != n_frames:
560
+ n_frames = min(len(source_rgb_lst), len(ret_s["frame_crop_lst"]))
561
+ img_crop_256x256_lst, source_lmk_crop_lst, source_M_c2o_lst = ret_s['frame_crop_lst'], ret_s['lmk_crop_lst'], ret_s['M_c2o_lst']
562
+ mask_ori_lst = [prepare_paste_back(inference_cfg.mask_crop, source_M_c2o, dsize=(source_rgb_lst[0].shape[1], source_rgb_lst[0].shape[0])) for source_M_c2o in source_M_c2o_lst]
563
+ else:
564
+ source_lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(source_rgb_lst)
565
+ img_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in source_rgb_lst] # force to resize to 256x256
566
+ source_M_c2o_lst, mask_ori_lst = None, None
567
+
568
+ c_s_eyes_lst, c_s_lip_lst = self.live_portrait_wrapper.calc_ratio(source_lmk_crop_lst)
569
+ # save the motion template
570
+ I_s_lst = self.live_portrait_wrapper.prepare_videos(img_crop_256x256_lst)
571
+ source_template_dct = self.make_motion_template(I_s_lst, c_s_eyes_lst, c_s_lip_lst, output_fps=source_fps)
572
+
573
+ f_s_user_lst, x_s_user_lst, x_d_i_new_lst = [], [], []
574
+ for i in track(range(n_frames), description='Preparing silencing lip...', total=n_frames):
575
+ x_s_info = source_template_dct['motion'][i]
576
+ x_s_info = dct2device(x_s_info, device)
577
+ scale_s = x_s_info['scale']
578
+ x_s_user = x_s_info['x_s']
579
+ x_c_s = x_s_info['kp']
580
+ R_s = x_s_info['R']
581
+ t_s = x_s_info['t']
582
+ delta_new = torch.zeros_like(x_s_info['exp']) + torch.from_numpy(inference_cfg.lip_array).to(dtype=torch.float32, device=device)
583
+ for eyes_idx in [11, 13, 15, 16, 18]:
584
+ delta_new[:, eyes_idx, :] = x_s_info['exp'][:, eyes_idx, :]
585
+ source_lmk = source_lmk_crop_lst[i]
586
+ img_crop_256x256 = img_crop_256x256_lst[i]
587
+ I_s = I_s_lst[i]
588
+ f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
589
+ x_d_i_new = scale_s * (x_c_s @ R_s + delta_new) + t_s
590
+ f_s_user_lst.append(f_s_user); x_s_user_lst.append(x_s_user); x_d_i_new_lst.append(x_d_i_new)
591
+ return f_s_user_lst, x_s_user_lst, x_d_i_new_lst, source_M_c2o_lst, mask_ori_lst, source_rgb_lst, img_crop_256x256_lst, source_fps, n_frames
592
+ else:
593
+ # when press the clear button, go here
594
+ raise gr.Error("Please upload a source video as the input 🤗🤗🤗", duration=5)
595
+
596
+ class GradioPipelineAnimal(LivePortraitPipelineAnimal):
597
+ """gradio for animal
598
+ """
599
+ def __init__(self, inference_cfg, crop_cfg, args: ArgumentConfig):
600
+ inference_cfg.flag_crop_driving_video = True # ensure the face_analysis_wrapper is enabled
601
+ super().__init__(inference_cfg, crop_cfg)
602
+ # self.live_portrait_wrapper_animal = self.live_portrait_wrapper_animal
603
+ self.args = args
604
+
605
+
606
+ @torch.no_grad()
607
+ def execute_video(
608
+ self,
609
+ input_source_image_path=None,
610
+ input_driving_video_path=None,
611
+ input_driving_video_pickle_path=None,
612
+ flag_do_crop_input=False,
613
+ flag_remap_input=False,
614
+ driving_multiplier=1.0,
615
+ flag_stitching=False,
616
+ flag_crop_driving_video_input=False,
617
+ scale=2.3,
618
+ vx_ratio=0.0,
619
+ vy_ratio=-0.125,
620
+ scale_crop_driving_video=2.2,
621
+ vx_ratio_crop_driving_video=0.0,
622
+ vy_ratio_crop_driving_video=-0.1,
623
+ tab_selection=None,
624
+ ):
625
+ """ for video-driven potrait animation
626
+ """
627
+ input_source_path = input_source_image_path
628
+
629
+ if tab_selection == 'Video':
630
+ input_driving_path = input_driving_video_path
631
+ elif tab_selection == 'Pickle':
632
+ input_driving_path = input_driving_video_pickle_path
633
+ else:
634
+ input_driving_path = input_driving_video_pickle_path
635
+
636
+ if input_source_path is not None and input_driving_path is not None:
637
+ if osp.exists(input_driving_path) and tab_selection == 'Video' and is_square_video(input_driving_path) is False:
638
+ flag_crop_driving_video_input = True
639
+ log("The driving video is not square, it will be cropped to square automatically.")
640
+ gr.Info("The driving video is not square, it will be cropped to square automatically.", duration=2)
641
+
642
+ args_user = {
643
+ 'source': input_source_path,
644
+ 'driving': input_driving_path,
645
+ 'flag_do_crop': flag_do_crop_input,
646
+ 'flag_pasteback': flag_remap_input,
647
+ 'driving_multiplier': driving_multiplier,
648
+ 'flag_stitching': flag_stitching,
649
+ 'flag_crop_driving_video': flag_crop_driving_video_input,
650
+ 'scale': scale,
651
+ 'vx_ratio': vx_ratio,
652
+ 'vy_ratio': vy_ratio,
653
+ 'scale_crop_driving_video': scale_crop_driving_video,
654
+ 'vx_ratio_crop_driving_video': vx_ratio_crop_driving_video,
655
+ 'vy_ratio_crop_driving_video': vy_ratio_crop_driving_video,
656
+ }
657
+ # update config from user input
658
+ self.args = update_args(self.args, args_user)
659
+ self.live_portrait_wrapper_animal.update_config(self.args.__dict__)
660
+ self.cropper.update_config(self.args.__dict__)
661
+ # video driven animation
662
+ video_path, video_path_concat, video_gif_path = self.execute(self.args)
663
+ gr.Info("Run successfully!", duration=2)
664
+ return video_path, video_path_concat, video_gif_path
665
+ else:
666
+ raise gr.Error("Please upload the source animal image, and driving video 🤗🤗🤗", duration=5)