boheng.xie committed
Commit 35cad43 · 1 Parent(s): b7a7858

upload infer_script file

Files changed (1)
  1. infer_script.py +332 -0
infer_script.py ADDED
@@ -0,0 +1,332 @@
+ import os
+ import imageio
+ import numpy as np
+ from PIL import Image
+ import cv2
+ from omegaconf import OmegaConf
+ from skimage.metrics import structural_similarity as ssim
+ from collections import deque
+
+ import torch
+ import gc
+ from diffusers import AutoencoderKL, DDIMScheduler
+ from diffusers.utils.import_utils import is_xformers_available
+
+ from transformers import CLIPVisionModelWithProjection
+
+ from models.guider import Guider
+ from models.referencenet import ReferenceNet2DConditionModel
+ from models.unet import UNet3DConditionModel
+ from models.video_pipeline import VideoPipeline
+
+ from dataset.val_dataset import ValDataset, val_collate_fn
+
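+ # Copy only the checkpoint entries whose keys exist in the target model, then
+ # report how many parameters were loaded and which keys are still missing.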
+ def load_model_state_dict(model, model_ckpt_path, name):
+     ckpt = torch.load(model_ckpt_path, map_location="cpu")
+     model_state_dict = model.state_dict()
+     model_new_sd = {}
+     count = 0
+     for k, v in ckpt.items():
+         if k in model_state_dict:
+             count += 1
+             model_new_sd[k] = v
+     miss, _ = model.load_state_dict(model_new_sd, strict=False)
+     print(f'load {name} from {model_ckpt_path}\n - load params: {count}\n - miss params: {miss}')
+
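+ # Compare two consecutive RGB frames: SSIM on their grayscale versions plus the
+ # mean absolute pixel difference.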
+ def frame_analysis(prev_frame, curr_frame):
+     prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_RGB2GRAY)
+     curr_gray = cv2.cvtColor(curr_frame, cv2.COLOR_RGB2GRAY)
+
+     ssim_score = ssim(prev_gray, curr_gray)
+     mean_diff = np.mean(np.abs(curr_frame.astype(float) - prev_frame.astype(float)))
+
+     return ssim_score, mean_diff
+
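+ # A frame is flagged as anomalous only after at least 5 comparisons have been
+ # collected, either on absolute thresholds (SSIM below 0.85 and mean diff above 6.0)
+ # or on a sudden deviation from the recent rolling averages.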
+ def is_anomaly(ssim_score, mean_diff, ssim_history, mean_diff_history):
+     if len(ssim_history) < 5:
+         return False
+
+     ssim_avg = np.mean(ssim_history)
+     mean_diff_avg = np.mean(mean_diff_history)
+
+     ssim_threshold = 0.85
+     mean_diff_threshold = 6.0
+
+     ssim_change_threshold = 0.05
+     mean_diff_change_threshold = 3.0
+
+     if (ssim_score < ssim_threshold and mean_diff > mean_diff_threshold) or \
+        (ssim_score < ssim_avg - ssim_change_threshold and mean_diff > mean_diff_avg + mean_diff_change_threshold):
+         return True
+
+     return False
+
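+ # Run the pipeline over the validation batches, save every generated frame to disk,
+ # optionally drop anomalous frames, and write two videos per sample:
+ # "<lmk_name>_oo.mp4" (output only) and "<lmk_name>_all.mp4" (side-by-side comparison).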
+ @torch.no_grad()
+ def visualize(dataloader, pipeline, generator, W, H, video_length, num_inference_steps, guidance_scale, output_path, output_fps=7, limit=1, show_stats=False, anomaly_action="none", callback_steps=1, context_frames=24, context_stride=1, context_overlap=4, context_batch_size=1, interpolation_factor=1):
+     oo_video_path = None
+     all_video_path = None
+
+     for i, batch in enumerate(dataloader):
+         ref_frame = batch['ref_frame'][0]
+         clip_image = batch['clip_image'][0]
+         motions = batch['motions'][0]
+         file_name = batch['file_name'][0]
+         if motions is None:
+             continue
+         if 'lmk_name' in batch:
+             lmk_name = batch['lmk_name'][0].split('.')[0]
+         else:
+             lmk_name = 'lmk'
+         print(file_name, lmk_name)
+
+         ref_frame = torch.clamp((ref_frame + 1.0) / 2.0, min=0, max=1)
+         ref_frame = ref_frame.permute((1, 2, 3, 0)).squeeze()
+         ref_frame = (ref_frame * 255).cpu().numpy().astype(np.uint8)
+         ref_image = Image.fromarray(ref_frame)
+
+         motions = motions.permute((1, 2, 3, 0))
+         motions = (motions * 255).cpu().numpy().astype(np.uint8)
+         lmk_images = [Image.fromarray(motion) for motion in motions]
+
+         preds = pipeline(ref_image=ref_image,
+                          lmk_images=lmk_images,
+                          width=W,
+                          height=H,
+                          video_length=video_length,
+                          num_inference_steps=num_inference_steps,
+                          guidance_scale=guidance_scale,
+                          generator=generator,
+                          clip_image=clip_image,
+                          callback_steps=callback_steps,
+                          context_frames=context_frames,
+                          context_stride=context_stride,
+                          context_overlap=context_overlap,
+                          context_batch_size=context_batch_size,
+                          interpolation_factor=interpolation_factor
+                          ).videos
+
+         preds = preds.permute((0, 2, 3, 4, 1)).squeeze(0)
+         preds = (preds * 255).cpu().numpy().astype(np.uint8)
+
+         # Save every generated frame to disk
+         frames_dir = os.path.join(output_path, "frames")
+         os.makedirs(frames_dir, exist_ok=True)
+         frame_paths = []
+         for idx, frame in enumerate(preds):
+             frame_path = os.path.join(frames_dir, f"frame_{idx:04d}.png")
+             imageio.imwrite(frame_path, frame)
+             frame_paths.append(frame_path)
+
+         # Anomaly handling
+         filtered_frame_paths = []
+         prev_frame = None
+         ssim_history = deque(maxlen=5)
+         mean_diff_history = deque(maxlen=5)
+
+         for idx, frame_path in enumerate(frame_paths):
+             frame = imageio.imread(frame_path)
+             if prev_frame is not None:
+                 ssim_score, mean_diff = frame_analysis(prev_frame, frame)
+                 ssim_history.append(ssim_score)
+                 mean_diff_history.append(mean_diff)
+
+                 if show_stats:
+                     print(f"Frame {idx}: SSIM: {ssim_score:.4f}, Mean Diff: {mean_diff:.4f}")
+
+                 if is_anomaly(ssim_score, mean_diff, ssim_history, mean_diff_history):
+                     if show_stats or anomaly_action != "none":
+                         print(f"Anomaly detected in frame {idx}")
+
+                     if anomaly_action == "remove":
+                         continue
+                     # If "none", keep the frame and take no further action
+
+             filtered_frame_paths.append(frame_path)
+             prev_frame = frame
+
+         # Build the output video from the processed frames
+         oo_video_path = os.path.join(output_path, f"{lmk_name}_oo.mp4")
+         imageio.mimsave(oo_video_path, [imageio.imread(frame_path) for frame_path in filtered_frame_paths], fps=output_fps)
+
+         if 'frames' in batch:
+             frames = batch['frames'][0]
+             frames = torch.clamp((frames + 1.0) / 2.0, min=0, max=1)
+             frames = frames.permute((1, 2, 3, 0))
+             frames = (frames * 255).cpu().numpy().astype(np.uint8)
+             combined = [np.concatenate((frame, motion, ref_frame, imageio.imread(pred_path)), axis=1)
+                         for frame, motion, pred_path in zip(frames, motions, filtered_frame_paths)]
+         else:
+             combined = [np.concatenate((motion, ref_frame, imageio.imread(pred_path)), axis=1)
+                         for motion, pred_path in zip(motions, filtered_frame_paths)]
+
+         all_video_path = os.path.join(output_path, f"{lmk_name}_all.mp4")
+         imageio.mimsave(all_video_path, combined, fps=output_fps)
+
+         if i >= limit:
+             break
+
+     return oo_video_path, all_video_path
+
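+ # Build the full inference stack (VAE, CLIP image encoder, ReferenceNet, 3D UNet,
+ # landmark guider), load the trained weights, and run visualize() on the validation set.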
+ @torch.no_grad()
+ def infer(config_path, model_path, input_path, lmk_path, output_path, model_step, seed,
+           resolution_w, resolution_h, video_length, num_inference_steps, guidance_scale, output_fps, show_stats,
+           anomaly_action, callback_steps, context_frames, context_stride, context_overlap, context_batch_size, interpolation_factor):
+
+     config = OmegaConf.load(config_path)
+     config.init_checkpoint = model_path
+     config.init_num = model_step
+     config.resolution_w = resolution_w
+     config.resolution_h = resolution_h
+     config.video_length = video_length
+
+     if config.weight_dtype == "fp16":
+         weight_dtype = torch.float16
+     elif config.weight_dtype == "fp32":
+         weight_dtype = torch.float32
+     else:
+         raise ValueError(f"Unsupported weight dtype: {config.weight_dtype}")
+
+     vae = AutoencoderKL.from_pretrained(config.vae_model_path).to(dtype=weight_dtype, device="cuda")
+     image_encoder = CLIPVisionModelWithProjection.from_pretrained(config.image_encoder_path).to(dtype=weight_dtype, device="cuda")
+     referencenet = ReferenceNet2DConditionModel.from_pretrained_2d(config.base_model_path,
+                                                                    referencenet_additional_kwargs=config.model.referencenet_additional_kwargs).to(device="cuda")
+     unet = UNet3DConditionModel.from_pretrained_2d(config.base_model_path,
+                                                    motion_module_path=config.motion_module_path,
+                                                    unet_additional_kwargs=config.model.unet_additional_kwargs).to(device="cuda")
+     lmk_guider = Guider(conditioning_embedding_channels=320, block_out_channels=(16, 32, 96, 256)).to(device="cuda")
+
+     load_model_state_dict(referencenet, f'{config.init_checkpoint}/referencenet.pth', 'referencenet')
+     load_model_state_dict(unet, f'{config.init_checkpoint}/unet.pth', 'unet')
+     load_model_state_dict(lmk_guider, f'{config.init_checkpoint}/lmk_guider.pth', 'lmk_guider')
+
+     if config.enable_xformers_memory_efficient_attention:
+         if is_xformers_available():
+             referencenet.enable_xformers_memory_efficient_attention()
+             unet.enable_xformers_memory_efficient_attention()
+         else:
+             raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+     unet.set_reentrant(use_reentrant=False)
+     referencenet.set_reentrant(use_reentrant=False)
+
+     vae.eval()
+     image_encoder.eval()
+     unet.eval()
+     referencenet.eval()
+     lmk_guider.eval()
+
+     sched_kwargs = OmegaConf.to_container(config.scheduler)
+     if config.enable_zero_snr:
+         sched_kwargs.update(rescale_betas_zero_snr=True,
+                             timestep_spacing="trailing",
+                             prediction_type="v_prediction")
+     noise_scheduler = DDIMScheduler(**sched_kwargs)
+
+     pipeline = VideoPipeline(vae=vae,
+                              image_encoder=image_encoder,
+                              referencenet=referencenet,
+                              unet=unet,
+                              lmk_guider=lmk_guider,
+                              scheduler=noise_scheduler).to(vae.device, dtype=weight_dtype)
+
+     val_dataset = ValDataset(
+         input_path=input_path,
+         lmk_path=lmk_path,
+         resolution_h=config.resolution_h,
+         resolution_w=config.resolution_w
+     )
+
+     val_dataloader = torch.utils.data.DataLoader(
+         val_dataset,
+         batch_size=1,
+         num_workers=0,
+         shuffle=False,
+         collate_fn=val_collate_fn,
+     )
+
+     generator = torch.Generator(device=vae.device)
+     generator.manual_seed(seed)
+
+     oo_video_path, all_video_path = visualize(
+         val_dataloader,
+         pipeline,
+         generator,
+         W=config.resolution_w,
+         H=config.resolution_h,
+         video_length=config.video_length,
+         num_inference_steps=num_inference_steps,
+         guidance_scale=guidance_scale,
+         output_path=output_path,
+         output_fps=output_fps,
+         show_stats=show_stats,
+         anomaly_action=anomaly_action,
+         callback_steps=callback_steps,
+         context_frames=context_frames,
+         context_stride=context_stride,
+         context_overlap=context_overlap,
+         context_batch_size=context_batch_size,
+         interpolation_factor=interpolation_factor,
+         limit=100000000
+     )
+
+     del vae, image_encoder, referencenet, unet, lmk_guider, pipeline
+     torch.cuda.empty_cache()
+     gc.collect()
+
+     return "Inference completed successfully", oo_video_path, all_video_path
+
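+ # Thin wrapper around infer() that clears CUDA memory before and after the run.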
+ def run_inference(config_path, model_path, input_path, lmk_path, output_path, model_step, seed,
+                   resolution_w, resolution_h, video_length, num_inference_steps=30, guidance_scale=3.5, output_fps=30,
+                   show_stats=False, anomaly_action="none", callback_steps=1, context_frames=24, context_stride=1,
+                   context_overlap=4, context_batch_size=1, interpolation_factor=1):
+     try:
+         # Clear memory
+         torch.cuda.empty_cache()
+         gc.collect()
+
+         return infer(config_path, model_path, input_path, lmk_path, output_path, model_step, seed,
+                      resolution_w, resolution_h, video_length, num_inference_steps, guidance_scale, output_fps,
+                      show_stats, anomaly_action, callback_steps, context_frames, context_stride, context_overlap,
+                      context_batch_size, interpolation_factor)
+     finally:
+         torch.cuda.empty_cache()
+         gc.collect()
+
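+ # Command-line entry point.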
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--config", type=str, required=True, help="Path to the config file")
+     parser.add_argument("--model", type=str, required=True, help="Path to the model checkpoint")
+     parser.add_argument("--input", type=str, required=True, help="Path to the input image")
+     parser.add_argument("--lmk", type=str, required=True, help="Path to the landmark file")
+     parser.add_argument("--output", type=str, required=True, help="Path to save the output")
+     parser.add_argument("--step", type=int, default=0, help="Model step")
+     parser.add_argument("--seed", type=int, default=42, help="Random seed")
+     parser.add_argument("--width", type=int, default=512, help="Output video width")
+     parser.add_argument("--height", type=int, default=512, help="Output video height")
+     parser.add_argument("--length", type=int, default=16, help="Output video length")
+     parser.add_argument("--steps", type=int, default=30, help="Number of inference steps")
+     parser.add_argument("--guidance", type=float, default=3.5, help="Guidance scale")
+     parser.add_argument("--fps", type=int, default=30, help="Output video FPS")
+     parser.add_argument("--show-stats", action="store_true", help="Show frame statistics")
+     parser.add_argument("--anomaly-action", type=str, default="none", choices=["none", "remove"], help="Action for anomaly frames")
+     parser.add_argument("--callback-steps", type=int, default=1, help="Callback steps")
+     parser.add_argument("--context-frames", type=int, default=24, help="Context frames")
+     parser.add_argument("--context-stride", type=int, default=1, help="Context stride")
+     parser.add_argument("--context-overlap", type=int, default=4, help="Context overlap")
+     parser.add_argument("--context-batch-size", type=int, default=1, help="Context batch size")
+     parser.add_argument("--interpolation-factor", type=int, default=1, help="Interpolation factor")
+
+     args = parser.parse_args()
+
+     status, oo_path, all_path = run_inference(
+         args.config, args.model, args.input, args.lmk, args.output, args.step, args.seed,
+         args.width, args.height, args.length, args.steps, args.guidance, args.fps,
+         args.show_stats, args.anomaly_action, args.callback_steps, args.context_frames,
+         args.context_stride, args.context_overlap, args.context_batch_size, args.interpolation_factor
+     )
+
+     print(status)
+     print(f"Output video (only output): {oo_path}")
+     print(f"Output video (all frames): {all_path}")
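+
+     # Example invocation (all file paths below are placeholders, not part of this repo):
+     #   python infer_script.py --config configs/infer.yaml --model ./checkpoints \
+     #       --input ./ref_image.png --lmk ./landmarks --output ./results --fps 30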