kevinwang676 committed on
Commit
7acaf6f
·
verified ·
1 Parent(s): 0f16690

Delete predict.py

Browse files
Files changed (1) hide show
  1. predict.py +0 -192
predict.py DELETED
@@ -1,192 +0,0 @@
1
- """run bash scripts/download_models.sh first to prepare the weights file"""
2
- import os
3
- import shutil
4
- from argparse import Namespace
5
- from src.utils.preprocess import CropAndExtract
6
- from src.test_audio2coeff import Audio2Coeff
7
- from src.facerender.animate import AnimateFromCoeff
8
- from src.generate_batch import get_data
9
- from src.generate_facerender_batch import get_facerender_data
10
- from src.utils.init_path import init_path
11
- from cog import BasePredictor, Input, Path
12
-
13
- checkpoints = "checkpoints"
14
-
15
-
16
- class Predictor(BasePredictor):
17
- def setup(self):
18
- """Load the model into memory to make running multiple predictions efficient"""
19
- device = "cuda"
20
-
21
-
22
- sadtalker_paths = init_path(checkpoints,os.path.join("src","config"))
23
-
24
- # init model
25
- self.preprocess_model = CropAndExtract(sadtalker_paths, device
26
- )
27
-
28
- self.audio_to_coeff = Audio2Coeff(
29
- sadtalker_paths,
30
- device,
31
- )
32
-
33
- self.animate_from_coeff = {
34
- "full": AnimateFromCoeff(
35
- sadtalker_paths,
36
- device,
37
- ),
38
- "others": AnimateFromCoeff(
39
- sadtalker_paths,
40
- device,
41
- ),
42
- }
43
-
44
- def predict(
45
- self,
46
- source_image: Path = Input(
47
- description="Upload the source image, it can be video.mp4 or picture.png",
48
- ),
49
- driven_audio: Path = Input(
50
- description="Upload the driven audio, accepts .wav and .mp4 file",
51
- ),
52
- enhancer: str = Input(
53
- description="Choose a face enhancer",
54
- choices=["gfpgan", "RestoreFormer"],
55
- default="gfpgan",
56
- ),
57
- preprocess: str = Input(
58
- description="how to preprocess the images",
59
- choices=["crop", "resize", "full"],
60
- default="full",
61
- ),
62
- ref_eyeblink: Path = Input(
63
- description="path to reference video providing eye blinking",
64
- default=None,
65
- ),
66
- ref_pose: Path = Input(
67
- description="path to reference video providing pose",
68
- default=None,
69
- ),
70
- still: bool = Input(
71
- description="can crop back to the original videos for the full body aniamtion when preprocess is full",
72
- default=True,
73
- ),
74
- ) -> Path:
75
- """Run a single prediction on the model"""
76
-
77
- animate_from_coeff = (
78
- self.animate_from_coeff["full"]
79
- if preprocess == "full"
80
- else self.animate_from_coeff["others"]
81
- )
82
-
83
- args = load_default()
84
- args.pic_path = str(source_image)
85
- args.audio_path = str(driven_audio)
86
- device = "cuda"
87
- args.still = still
88
- args.ref_eyeblink = None if ref_eyeblink is None else str(ref_eyeblink)
89
- args.ref_pose = None if ref_pose is None else str(ref_pose)
90
-
91
- # crop image and extract 3dmm from image
92
- results_dir = "results"
93
- if os.path.exists(results_dir):
94
- shutil.rmtree(results_dir)
95
- os.makedirs(results_dir)
96
- first_frame_dir = os.path.join(results_dir, "first_frame_dir")
97
- os.makedirs(first_frame_dir)
98
-
99
- print("3DMM Extraction for source image")
100
- first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(
101
- args.pic_path, first_frame_dir, preprocess, source_image_flag=True
102
- )
103
- if first_coeff_path is None:
104
- print("Can't get the coeffs of the input")
105
- return
106
-
107
- if ref_eyeblink is not None:
108
- ref_eyeblink_videoname = os.path.splitext(os.path.split(ref_eyeblink)[-1])[
109
- 0
110
- ]
111
- ref_eyeblink_frame_dir = os.path.join(results_dir, ref_eyeblink_videoname)
112
- os.makedirs(ref_eyeblink_frame_dir, exist_ok=True)
113
- print("3DMM Extraction for the reference video providing eye blinking")
114
- ref_eyeblink_coeff_path, _, _ = self.preprocess_model.generate(
115
- ref_eyeblink, ref_eyeblink_frame_dir
116
- )
117
- else:
118
- ref_eyeblink_coeff_path = None
119
-
120
- if ref_pose is not None:
121
- if ref_pose == ref_eyeblink:
122
- ref_pose_coeff_path = ref_eyeblink_coeff_path
123
- else:
124
- ref_pose_videoname = os.path.splitext(os.path.split(ref_pose)[-1])[0]
125
- ref_pose_frame_dir = os.path.join(results_dir, ref_pose_videoname)
126
- os.makedirs(ref_pose_frame_dir, exist_ok=True)
127
- print("3DMM Extraction for the reference video providing pose")
128
- ref_pose_coeff_path, _, _ = self.preprocess_model.generate(
129
- ref_pose, ref_pose_frame_dir
130
- )
131
- else:
132
- ref_pose_coeff_path = None
133
-
134
- # audio2ceoff
135
- batch = get_data(
136
- first_coeff_path,
137
- args.audio_path,
138
- device,
139
- ref_eyeblink_coeff_path,
140
- still=still,
141
- )
142
- coeff_path = self.audio_to_coeff.generate(
143
- batch, results_dir, args.pose_style, ref_pose_coeff_path
144
- )
145
- # coeff2video
146
- print("coeff2video")
147
- data = get_facerender_data(
148
- coeff_path,
149
- crop_pic_path,
150
- first_coeff_path,
151
- args.audio_path,
152
- args.batch_size,
153
- args.input_yaw,
154
- args.input_pitch,
155
- args.input_roll,
156
- expression_scale=args.expression_scale,
157
- still_mode=still,
158
- preprocess=preprocess,
159
- )
160
- animate_from_coeff.generate(
161
- data, results_dir, args.pic_path, crop_info,
162
- enhancer=enhancer, background_enhancer=args.background_enhancer,
163
- preprocess=preprocess)
164
-
165
- output = "/tmp/out.mp4"
166
- mp4_path = os.path.join(results_dir, [f for f in os.listdir(results_dir) if "enhanced.mp4" in f][0])
167
- shutil.copy(mp4_path, output)
168
-
169
- return Path(output)
170
-
171
-
172
- def load_default():
173
- return Namespace(
174
- pose_style=0,
175
- batch_size=2,
176
- expression_scale=1.0,
177
- input_yaw=None,
178
- input_pitch=None,
179
- input_roll=None,
180
- background_enhancer=None,
181
- face3dvis=False,
182
- net_recon="resnet50",
183
- init_path=None,
184
- use_last_fc=False,
185
- bfm_folder="./src/config/",
186
- bfm_model="BFM_model_front.mat",
187
- focal=1015.0,
188
- center=112.0,
189
- camera_d=10.0,
190
- z_near=5.0,
191
- z_far=15.0,
192
- )