Delete predict.py
predict.py (DELETED, +0 -192)
@@ -1,192 +0,0 @@
-"""run bash scripts/download_models.sh first to prepare the weights file"""
-import os
-import shutil
-from argparse import Namespace
-from src.utils.preprocess import CropAndExtract
-from src.test_audio2coeff import Audio2Coeff
-from src.facerender.animate import AnimateFromCoeff
-from src.generate_batch import get_data
-from src.generate_facerender_batch import get_facerender_data
-from src.utils.init_path import init_path
-from cog import BasePredictor, Input, Path
-
-checkpoints = "checkpoints"
-
-
-class Predictor(BasePredictor):
-    def setup(self):
-        """Load the model into memory to make running multiple predictions efficient"""
-        device = "cuda"
-
-        sadtalker_paths = init_path(checkpoints, os.path.join("src", "config"))
-
-        # init model
-        self.preprocess_model = CropAndExtract(sadtalker_paths, device)
-
-        self.audio_to_coeff = Audio2Coeff(
-            sadtalker_paths,
-            device,
-        )
-
-        self.animate_from_coeff = {
-            "full": AnimateFromCoeff(
-                sadtalker_paths,
-                device,
-            ),
-            "others": AnimateFromCoeff(
-                sadtalker_paths,
-                device,
-            ),
-        }
-
-    def predict(
-        self,
-        source_image: Path = Input(
-            description="Upload the source image, it can be video.mp4 or picture.png",
-        ),
-        driven_audio: Path = Input(
-            description="Upload the driven audio, accepts .wav and .mp4 files",
-        ),
-        enhancer: str = Input(
-            description="Choose a face enhancer",
-            choices=["gfpgan", "RestoreFormer"],
-            default="gfpgan",
-        ),
-        preprocess: str = Input(
-            description="how to preprocess the images",
-            choices=["crop", "resize", "full"],
-            default="full",
-        ),
-        ref_eyeblink: Path = Input(
-            description="path to reference video providing eye blinking",
-            default=None,
-        ),
-        ref_pose: Path = Input(
-            description="path to reference video providing pose",
-            default=None,
-        ),
-        still: bool = Input(
-            description="can crop back to the original video for the full-body animation when preprocess is full",
-            default=True,
-        ),
-    ) -> Path:
-        """Run a single prediction on the model"""
-
-        animate_from_coeff = (
-            self.animate_from_coeff["full"]
-            if preprocess == "full"
-            else self.animate_from_coeff["others"]
-        )
-
-        args = load_default()
-        args.pic_path = str(source_image)
-        args.audio_path = str(driven_audio)
-        device = "cuda"
-        args.still = still
-        args.ref_eyeblink = None if ref_eyeblink is None else str(ref_eyeblink)
-        args.ref_pose = None if ref_pose is None else str(ref_pose)
-
-        # crop image and extract 3dmm from image
-        results_dir = "results"
-        if os.path.exists(results_dir):
-            shutil.rmtree(results_dir)
-        os.makedirs(results_dir)
-        first_frame_dir = os.path.join(results_dir, "first_frame_dir")
-        os.makedirs(first_frame_dir)
-
-        print("3DMM Extraction for source image")
-        first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(
-            args.pic_path, first_frame_dir, preprocess, source_image_flag=True
-        )
-        if first_coeff_path is None:
-            print("Can't get the coeffs of the input")
-            return
-
-        if ref_eyeblink is not None:
-            ref_eyeblink_videoname = os.path.splitext(os.path.split(ref_eyeblink)[-1])[
-                0
-            ]
-            ref_eyeblink_frame_dir = os.path.join(results_dir, ref_eyeblink_videoname)
-            os.makedirs(ref_eyeblink_frame_dir, exist_ok=True)
-            print("3DMM Extraction for the reference video providing eye blinking")
-            ref_eyeblink_coeff_path, _, _ = self.preprocess_model.generate(
-                ref_eyeblink, ref_eyeblink_frame_dir
-            )
-        else:
-            ref_eyeblink_coeff_path = None
-
-        if ref_pose is not None:
-            if ref_pose == ref_eyeblink:
-                ref_pose_coeff_path = ref_eyeblink_coeff_path
-            else:
-                ref_pose_videoname = os.path.splitext(os.path.split(ref_pose)[-1])[0]
-                ref_pose_frame_dir = os.path.join(results_dir, ref_pose_videoname)
-                os.makedirs(ref_pose_frame_dir, exist_ok=True)
-                print("3DMM Extraction for the reference video providing pose")
-                ref_pose_coeff_path, _, _ = self.preprocess_model.generate(
-                    ref_pose, ref_pose_frame_dir
-                )
-        else:
-            ref_pose_coeff_path = None
-
-        # audio2coeff
-        batch = get_data(
-            first_coeff_path,
-            args.audio_path,
-            device,
-            ref_eyeblink_coeff_path,
-            still=still,
-        )
-        coeff_path = self.audio_to_coeff.generate(
-            batch, results_dir, args.pose_style, ref_pose_coeff_path
-        )
-        # coeff2video
-        print("coeff2video")
-        data = get_facerender_data(
-            coeff_path,
-            crop_pic_path,
-            first_coeff_path,
-            args.audio_path,
-            args.batch_size,
-            args.input_yaw,
-            args.input_pitch,
-            args.input_roll,
-            expression_scale=args.expression_scale,
-            still_mode=still,
-            preprocess=preprocess,
-        )
-        animate_from_coeff.generate(
-            data, results_dir, args.pic_path, crop_info,
-            enhancer=enhancer, background_enhancer=args.background_enhancer,
-            preprocess=preprocess,
-        )
-
-        output = "/tmp/out.mp4"
-        mp4_path = os.path.join(
-            results_dir, [f for f in os.listdir(results_dir) if "enhanced.mp4" in f][0]
-        )
-        shutil.copy(mp4_path, output)
-
-        return Path(output)
-
-
-def load_default():
-    return Namespace(
-        pose_style=0,
-        batch_size=2,
-        expression_scale=1.0,
-        input_yaw=None,
-        input_pitch=None,
-        input_roll=None,
-        background_enhancer=None,
-        face3dvis=False,
-        net_recon="resnet50",
-        init_path=None,
-        use_last_fc=False,
-        bfm_folder="./src/config/",
-        bfm_model="BFM_model_front.mat",
-        focal=1015.0,
-        center=112.0,
-        camera_d=10.0,
-        z_near=5.0,
-        z_far=15.0,
-    )
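
For context, the deleted class followed the standard Cog interface: setup() loaded the SadTalker preprocessing, audio-to-coefficient, and face-render models once, and predict() ran the crop/3DMM extraction, audio2coeff, and coeff2video pipeline per request. A minimal sketch of how the now-deleted class could have been exercised directly in Python, assuming the checkpoints from scripts/download_models.sh, a CUDA device, and the repository's src package are available (the input file names below are hypothetical examples, not files from the repo):

# Sketch only: drives the deleted Predictor outside the Cog HTTP server.
# "face.png" and "speech.wav" are hypothetical example inputs.
from predict import Predictor

predictor = Predictor()
predictor.setup()  # one-time model load: CropAndExtract, Audio2Coeff, AnimateFromCoeff

out = predictor.predict(
    source_image="face.png",    # hypothetical source image
    driven_audio="speech.wav",  # hypothetical driving audio
    enhancer="gfpgan",
    preprocess="full",
    ref_eyeblink=None,
    ref_pose=None,
    still=True,
)
print(out)  # the "*enhanced.mp4" result, copied to /tmp/out.mp4

Under Cog itself the equivalent call would go through the CLI, along the lines of cog predict -i source_image=@face.png -i driven_audio=@speech.wav.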