alexnasa committed on
Commit 18aa80b · verified · 1 Parent(s): 468a4ed

Update app.py

Files changed (1)
  1. app.py +687 -687
app.py CHANGED
@@ -1,688 +1,688 @@
1
- import spaces
2
- import subprocess
3
- import gradio as gr
4
-
5
- import os, sys
6
- from glob import glob
7
- from datetime import datetime
8
- import math
9
- import random
10
- import librosa
11
- import numpy as np
12
- import uuid
13
- import shutil
14
-
15
- import importlib, site, sys
16
-
17
- import torch
18
-
19
- print(f'torch version:{torch.__version__}')
20
-
21
-
22
- import torch.nn as nn
23
- from tqdm import tqdm
24
- from functools import partial
25
- from omegaconf import OmegaConf
26
- from argparse import Namespace
27
-
28
- # load the one true config you dumped
29
- _args_cfg = OmegaConf.load("args_config.yaml")
30
- args = Namespace(**OmegaConf.to_container(_args_cfg, resolve=True))
31
-
32
- from OmniAvatar.utils.args_config import set_global_args
33
-
34
- set_global_args(args)
35
- # args = parse_args()
36
-
37
- from OmniAvatar.utils.io_utils import load_state_dict
38
- from peft import LoraConfig, inject_adapter_in_model
39
- from OmniAvatar.models.model_manager import ModelManager
40
- from OmniAvatar.schedulers.flow_match import FlowMatchScheduler
41
- from OmniAvatar.wan_video import WanVideoPipeline
42
- from OmniAvatar.utils.io_utils import save_video_as_grid_and_mp4
43
- import torchvision.transforms as TT
44
- from transformers import Wav2Vec2FeatureExtractor
45
- import torchvision.transforms as transforms
46
- import torch.nn.functional as F
47
- from OmniAvatar.utils.audio_preprocess import add_silence_to_audio_ffmpeg
48
- from huggingface_hub import hf_hub_download, snapshot_download
49
-
50
- os.environ["PROCESSED_RESULTS"] = f"{os.getcwd()}/preprocess_results"
51
-
52
- def tensor_to_pil(tensor):
53
- """
54
- Args:
55
- tensor: torch.Tensor with shape like
56
- (1, C, H, W), (1, C, 1, H, W), (C, H, W), etc.
57
- values in [-1, 1], on any device.
58
- Returns:
59
- A PIL.Image in RGB mode.
60
- """
61
- # 1) Remove batch dim if it exists
62
- if tensor.dim() > 3 and tensor.shape[0] == 1:
63
- tensor = tensor[0]
64
-
65
- # 2) Squeeze out any other singleton dims (e.g. that extra frame axis)
66
- tensor = tensor.squeeze()
67
-
68
- # Now we should have exactly 3 dims: (C, H, W)
69
- if tensor.dim() != 3:
70
- raise ValueError(f"Expected 3 dims after squeeze, got {tensor.dim()}")
71
-
72
- # 3) Move to CPU float32
73
- tensor = tensor.cpu().float()
74
-
75
- # 4) Undo normalization from [-1,1] -> [0,1]
76
- tensor = (tensor + 1.0) / 2.0
77
-
78
- # 5) Clamp to [0,1]
79
- tensor = torch.clamp(tensor, 0.0, 1.0)
80
-
81
- # 6) To NumPy H×W×C in [0,255]
82
- np_img = (tensor.permute(1, 2, 0).numpy() * 255.0).round().astype("uint8")
83
-
84
- # 7) Build PIL Image
85
- return Image.fromarray(np_img)
86
-
87
-
88
- def set_seed(seed: int = 42):
89
- random.seed(seed)
90
- np.random.seed(seed)
91
- torch.manual_seed(seed)
92
- torch.cuda.manual_seed(seed) # seed the current GPU
93
- torch.cuda.manual_seed_all(seed) # seed all GPUs
94
-
95
- def read_from_file(p):
96
- with open(p, "r") as fin:
97
- for l in fin:
98
- yield l.strip()
99
-
100
- def match_size(image_size, h, w):
101
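- # pick the candidate (H, W) whose aspect ratio (and then overall size) is closest to the input image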
- ratio_ = 9999
102
- size_ = 9999
103
- select_size = None
104
- for image_s in image_size:
105
- ratio_tmp = abs(image_s[0] / image_s[1] - h / w)
106
- size_tmp = abs(max(image_s) - max(w, h))
107
- if ratio_tmp < ratio_:
108
- ratio_ = ratio_tmp
109
- size_ = size_tmp
110
- select_size = image_s
111
- if ratio_ == ratio_tmp:
112
- if size_ == size_tmp:
113
- select_size = image_s
114
- return select_size
115
-
116
- def resize_pad(image, ori_size, tgt_size):
117
- h, w = ori_size
118
- scale_ratio = max(tgt_size[0] / h, tgt_size[1] / w)
119
- scale_h = int(h * scale_ratio)
120
- scale_w = int(w * scale_ratio)
121
-
122
- image = transforms.Resize(size=[scale_h, scale_w])(image)
123
-
124
- padding_h = tgt_size[0] - scale_h
125
- padding_w = tgt_size[1] - scale_w
126
- pad_top = padding_h // 2
127
- pad_bottom = padding_h - pad_top
128
- pad_left = padding_w // 2
129
- pad_right = padding_w - pad_left
130
-
131
- image = F.pad(image, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0)
132
- return image
133
-
134
- class WanInferencePipeline(nn.Module):
135
- def __init__(self, args):
136
- super().__init__()
137
- self.args = args
138
- self.device = torch.device(f"cuda")
139
- self.dtype = torch.bfloat16
140
- self.pipe = self.load_model()
141
- chained_transforms = []
142
- chained_transforms.append(TT.ToTensor())
143
- self.transform = TT.Compose(chained_transforms)
144
-
145
- if self.args.use_audio:
146
- from OmniAvatar.models.wav2vec import Wav2VecModel
147
- self.wav_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
148
- self.args.wav2vec_path
149
- )
150
- self.audio_encoder = Wav2VecModel.from_pretrained(self.args.wav2vec_path, local_files_only=True).to(device=self.device, dtype=self.dtype)
151
- self.audio_encoder.feature_extractor._freeze_parameters()
152
-
153
-
154
- def load_model(self):
155
- ckpt_path = f'{self.args.exp_path}/pytorch_model.pt'
156
- assert os.path.exists(ckpt_path), f"pytorch_model.pt not found in {self.args.exp_path}"
157
- if self.args.train_architecture == 'lora':
158
- self.args.pretrained_lora_path = pretrained_lora_path = ckpt_path
159
- else:
160
- resume_path = ckpt_path
161
-
162
- self.step = 0
163
-
164
- # Load models
165
- model_manager = ModelManager(device="cuda", infer=True)
166
-
167
- model_manager.load_models(
168
- [
169
- self.args.dit_path.split(","),
170
- self.args.vae_path,
171
- self.args.text_encoder_path
172
- ],
173
- torch_dtype=self.dtype,
174
- device='cuda',
175
- )
176
-
177
- pipe = WanVideoPipeline.from_model_manager(model_manager,
178
- torch_dtype=self.dtype,
179
- device="cuda",
180
- use_usp=False,
181
- infer=True)
182
-
183
- if self.args.train_architecture == "lora":
184
- print(f'Use LoRA: lora rank: {self.args.lora_rank}, lora alpha: {self.args.lora_alpha}')
185
- self.add_lora_to_model(
186
- pipe.denoising_model(),
187
- lora_rank=self.args.lora_rank,
188
- lora_alpha=self.args.lora_alpha,
189
- lora_target_modules=self.args.lora_target_modules,
190
- init_lora_weights=self.args.init_lora_weights,
191
- pretrained_lora_path=pretrained_lora_path,
192
- )
193
- print(next(pipe.denoising_model().parameters()).device)
194
- else:
195
- missing_keys, unexpected_keys = pipe.denoising_model().load_state_dict(load_state_dict(resume_path), strict=True)
196
- print(f"load from {resume_path}, {len(missing_keys)} missing keys, {len(unexpected_keys)} unexpected keys")
197
- pipe.requires_grad_(False)
198
- pipe.eval()
199
- # pipe.enable_vram_management(num_persistent_param_in_dit=args.num_persistent_param_in_dit)
200
- return pipe
201
-
202
- def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None):
203
- # Add LoRA to UNet
204
-
205
- self.lora_alpha = lora_alpha
206
- if init_lora_weights == "kaiming":
207
- init_lora_weights = True
208
-
209
- lora_config = LoraConfig(
210
- r=lora_rank,
211
- lora_alpha=lora_alpha,
212
- init_lora_weights=init_lora_weights,
213
- target_modules=lora_target_modules.split(","),
214
- )
215
- model = inject_adapter_in_model(lora_config, model)
216
-
217
- # Load pretrained LoRA weights
218
- if pretrained_lora_path is not None:
219
- state_dict = load_state_dict(pretrained_lora_path, torch_dtype=self.dtype)
220
- if state_dict_converter is not None:
221
- state_dict = state_dict_converter(state_dict)
222
- missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
223
- all_keys = [i for i, _ in model.named_parameters()]
224
- num_updated_keys = len(all_keys) - len(missing_keys)
225
- num_unexpected_keys = len(unexpected_keys)
226
-
227
- print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")
228
-
229
- def get_times(self, prompt,
230
- image_path=None,
231
- audio_path=None,
232
- seq_len=101, # not used while audio_path is not None
233
- height=720,
234
- width=720,
235
- overlap_frame=None,
236
- num_steps=None,
237
- negative_prompt=None,
238
- guidance_scale=None,
239
- audio_scale=None):
240
-
241
- overlap_frame = overlap_frame if overlap_frame is not None else self.args.overlap_frame
242
- num_steps = num_steps if num_steps is not None else self.args.num_steps
243
- negative_prompt = negative_prompt if negative_prompt is not None else self.args.negative_prompt
244
- guidance_scale = guidance_scale if guidance_scale is not None else self.args.guidance_scale
245
- audio_scale = audio_scale if audio_scale is not None else self.args.audio_scale
246
-
247
- if image_path is not None:
248
- from PIL import Image
249
- image = Image.open(image_path).convert("RGB")
250
-
251
- image = self.transform(image).unsqueeze(0).to(dtype=self.dtype)
252
-
253
- _, _, h, w = image.shape
254
- select_size = match_size(getattr( self.args, f'image_sizes_{ self.args.max_hw}'), h, w)
255
- image = resize_pad(image, (h, w), select_size)
256
- image = image * 2.0 - 1.0
257
- image = image[:, :, None]
258
-
259
- else:
260
- image = None
261
- select_size = [height, width]
262
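- # derive the clip length L (video frames, forced to L % 4 == 1) from the token budget, then T latent frames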
- num = self.args.max_tokens * 16 * 16 * 4
263
- den = select_size[0] * select_size[1]
264
- L0 = num // den
265
- diff = (L0 - 1) % 4
266
- L = L0 - diff
267
- if L < 1:
268
- L = 1
269
- T = (L + 3) // 4
270
-
271
-
272
- if self.args.random_prefix_frames:
273
- fixed_frame = overlap_frame
274
- assert fixed_frame % 4 == 1
275
- else:
276
- fixed_frame = 1
277
- prefix_lat_frame = (3 + fixed_frame) // 4
278
- first_fixed_frame = 1
279
-
280
-
281
- audio, sr = librosa.load(audio_path, sr= self.args.sample_rate)
282
-
283
- input_values = np.squeeze(
284
- self.wav_feature_extractor(audio, sampling_rate=16000).input_values
285
- )
286
- input_values = torch.from_numpy(input_values).float().to(dtype=self.dtype)
287
- audio_len = math.ceil(len(input_values) / self.args.sample_rate * self.args.fps)
288
-
289
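- # round the audio length up so it splits into a first window plus whole overlapping windows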
- if audio_len < L - first_fixed_frame:
290
- audio_len = audio_len + ((L - first_fixed_frame) - audio_len % (L - first_fixed_frame))
291
- elif (audio_len - (L - first_fixed_frame)) % (L - fixed_frame) != 0:
292
- audio_len = audio_len + ((L - fixed_frame) - (audio_len - (L - first_fixed_frame)) % (L - fixed_frame))
293
-
294
- seq_len = audio_len
295
-
296
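- # number of overlapping generation windows needed to cover the padded audio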
- times = (seq_len - L + first_fixed_frame) // (L-fixed_frame) + 1
297
- if times * (L-fixed_frame) + fixed_frame < seq_len:
298
- times += 1
299
-
300
- return times
301
-
302
- @torch.no_grad()
303
- def forward(self, prompt,
304
- image_path=None,
305
- audio_path=None,
306
- seq_len=101, # not used while audio_path is not None
307
- height=720,
308
- width=720,
309
- overlap_frame=None,
310
- num_steps=None,
311
- negative_prompt=None,
312
- guidance_scale=None,
313
- audio_scale=None):
314
- overlap_frame = overlap_frame if overlap_frame is not None else self.args.overlap_frame
315
- num_steps = num_steps if num_steps is not None else self.args.num_steps
316
- negative_prompt = negative_prompt if negative_prompt is not None else self.args.negative_prompt
317
- guidance_scale = guidance_scale if guidance_scale is not None else self.args.guidance_scale
318
- audio_scale = audio_scale if audio_scale is not None else self.args.audio_scale
319
-
320
- if image_path is not None:
321
- from PIL import Image
322
- image = Image.open(image_path).convert("RGB")
323
-
324
- image = self.transform(image).unsqueeze(0).to(self.device, dtype=self.dtype)
325
-
326
- _, _, h, w = image.shape
327
- select_size = match_size(getattr(self.args, f'image_sizes_{self.args.max_hw}'), h, w)
328
- image = resize_pad(image, (h, w), select_size)
329
- image = image * 2.0 - 1.0
330
- image = image[:, :, None]
331
-
332
- else:
333
- image = None
334
- select_size = [height, width]
335
- # L = int(self.args.max_tokens * 16 * 16 * 4 / select_size[0] / select_size[1])
336
- # L = L // 4 * 4 + 1 if L % 4 != 0 else L - 3 # video frames
337
- # T = (L + 3) // 4 # latent frames
338
-
339
- # step 1: numerator and denominator as ints
340
- num = self.args.max_tokens * 16 * 16 * 4
341
- den = select_size[0] * select_size[1]
342
-
343
- # step 2: integer division
344
- L0 = num // den # exact floor division, no float in sight
345
-
346
- # step 3: make it ≡ 1 mod 4
347
- # if L0 % 4 == 1, keep L0;
348
- # otherwise subtract the difference so that (L0 - diff) % 4 == 1,
349
- # but ensure the result stays positive.
350
- diff = (L0 - 1) % 4
351
- L = L0 - diff
352
- if L < 1:
353
- L = 1 # or whatever your minimal frame count is
354
-
355
- # step 4: latent frames
356
- T = (L + 3) // 4
357
-
358
-
359
- if self.args.i2v:
360
- if self.args.random_prefix_frames:
361
- fixed_frame = overlap_frame
362
- assert fixed_frame % 4 == 1
363
- else:
364
- fixed_frame = 1
365
- prefix_lat_frame = (3 + fixed_frame) // 4
366
- first_fixed_frame = 1
367
- else:
368
- fixed_frame = 0
369
- prefix_lat_frame = 0
370
- first_fixed_frame = 0
371
-
372
-
373
- if audio_path is not None and self.args.use_audio:
374
- audio, sr = librosa.load(audio_path, sr=self.args.sample_rate)
375
- input_values = np.squeeze(
376
- self.wav_feature_extractor(audio, sampling_rate=16000).input_values
377
- )
378
- input_values = torch.from_numpy(input_values).float().to(device=self.device, dtype=self.dtype)
379
- ori_audio_len = audio_len = math.ceil(len(input_values) / self.args.sample_rate * self.args.fps)
380
- input_values = input_values.unsqueeze(0)
381
- # padding audio
382
- if audio_len < L - first_fixed_frame:
383
- audio_len = audio_len + ((L - first_fixed_frame) - audio_len % (L - first_fixed_frame))
384
- elif (audio_len - (L - first_fixed_frame)) % (L - fixed_frame) != 0:
385
- audio_len = audio_len + ((L - fixed_frame) - (audio_len - (L - first_fixed_frame)) % (L - fixed_frame))
386
- input_values = F.pad(input_values, (0, audio_len * int(self.args.sample_rate / self.args.fps) - input_values.shape[1]), mode='constant', value=0)
387
- with torch.no_grad():
388
- hidden_states = self.audio_encoder(input_values, seq_len=audio_len, output_hidden_states=True)
389
- audio_embeddings = hidden_states.last_hidden_state
390
- for mid_hidden_states in hidden_states.hidden_states:
391
- audio_embeddings = torch.cat((audio_embeddings, mid_hidden_states), -1)
392
- seq_len = audio_len
393
- audio_embeddings = audio_embeddings.squeeze(0)
394
- audio_prefix = torch.zeros_like(audio_embeddings[:first_fixed_frame])
395
- else:
396
- audio_embeddings = None
397
-
398
- # how many overlapping generation windows are needed to cover seq_len
399
- times = (seq_len - L + first_fixed_frame) // (L-fixed_frame) + 1
400
- if times * (L-fixed_frame) + fixed_frame < seq_len:
401
- times += 1
402
- video = []
403
- image_emb = {}
404
- img_lat = None
405
- if self.args.i2v:
406
- self.pipe.load_models_to_device(['vae'])
407
- img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device, dtype=self.dtype)
408
-
409
- msk = torch.zeros_like(img_lat.repeat(1, 1, T, 1, 1)[:,:1], dtype=self.dtype)
410
- image_cat = img_lat.repeat(1, 1, T, 1, 1)
411
- msk[:, :, 1:] = 1
412
- image_emb["y"] = torch.cat([image_cat, msk], dim=1)
413
-
414
- for t in range(times):
415
- print(f"[{t+1}/{times}]")
416
- audio_emb = {}
417
- if t == 0:
418
- overlap = first_fixed_frame
419
- else:
420
- overlap = fixed_frame
421
- image_emb["y"][:, -1:, :prefix_lat_frame] = 0 # 第一次推理是mask只有1,往后都是mask overlap
422
- prefix_overlap = (3 + overlap) // 4
423
- if audio_embeddings is not None:
424
- if t == 0:
425
- audio_tensor = audio_embeddings[
426
- :min(L - overlap, audio_embeddings.shape[0])
427
- ]
428
- else:
429
- audio_start = L - first_fixed_frame + (t - 1) * (L - overlap)
430
- audio_tensor = audio_embeddings[
431
- audio_start: min(audio_start + L - overlap, audio_embeddings.shape[0])
432
- ]
433
-
434
- audio_tensor = torch.cat([audio_prefix, audio_tensor], dim=0)
435
- audio_prefix = audio_tensor[-fixed_frame:]
436
- audio_tensor = audio_tensor.unsqueeze(0).to(device=self.device, dtype=self.dtype)
437
- audio_emb["audio_emb"] = audio_tensor
438
- else:
439
- audio_prefix = None
440
- if image is not None and img_lat is None:
441
- self.pipe.load_models_to_device(['vae'])
442
- img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device, dtype=self.dtype)
443
- assert img_lat.shape[2] == prefix_overlap
444
- img_lat = torch.cat([img_lat, torch.zeros_like(img_lat[:, :, :1].repeat(1, 1, T - prefix_overlap, 1, 1), dtype=self.dtype)], dim=2)
445
- frames, _, latents = self.pipe.log_video(img_lat, prompt, prefix_overlap, image_emb, audio_emb,
446
- negative_prompt, num_inference_steps=num_steps,
447
- cfg_scale=guidance_scale, audio_cfg_scale=audio_scale if audio_scale is not None else guidance_scale,
448
- return_latent=True,
449
- tea_cache_l1_thresh=self.args.tea_cache_l1_thresh,tea_cache_model_id="Wan2.1-T2V-14B")
450
-
451
- torch.cuda.empty_cache()
452
- img_lat = None
453
- image = (frames[:, -fixed_frame:].clip(0, 1) * 2.0 - 1.0).permute(0, 2, 1, 3, 4).contiguous()
454
-
455
- if t == 0:
456
- video.append(frames)
457
- else:
458
- video.append(frames[:, overlap:])
459
- video = torch.cat(video, dim=1)
460
- video = video[:, :ori_audio_len + 1]
461
-
462
- return video
463
-
464
-
465
- snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-14B", local_dir="./pretrained_models/Wan2.1-T2V-14B")
466
- snapshot_download(repo_id="facebook/wav2vec2-base-960h", local_dir="./pretrained_models/wav2vec2-base-960h")
467
- snapshot_download(repo_id="OmniAvatar/OmniAvatar-14B", local_dir="./pretrained_models/OmniAvatar-14B")
468
-
469
-
470
- # snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-1.3B", local_dir="./pretrained_models/Wan2.1-T2V-1.3B")
471
- # snapshot_download(repo_id="facebook/wav2vec2-base-960h", local_dir="./pretrained_models/wav2vec2-base-960h")
472
- # snapshot_download(repo_id="OmniAvatar/OmniAvatar-1.3B", local_dir="./pretrained_models/OmniAvatar-1.3B")
473
-
474
- import tempfile
475
-
476
- from PIL import Image
477
-
478
-
479
- set_seed(args.seed)
480
- seq_len = args.seq_len
481
- inferpipe = WanInferencePipeline(args)
482
-
483
-
484
- def update_generate_button(image_path, audio_path, text, num_steps):
485
-
486
- if image_path is None or audio_path is None:
487
- return gr.update(value="⌚ Zero GPU Required: --")
488
-
489
- duration_s = get_duration(image_path, audio_path, text, num_steps, None, None)
490
-
491
- return gr.update(value=f"⌚ Zero GPU Required: ~{duration_s}.0s")
492
-
493
- def get_duration(image_path, audio_path, text, num_steps, session_id, progress):
494
-
495
- audio_chunks = inferpipe.get_times(
496
- prompt=text,
497
- image_path=image_path,
498
- audio_path=audio_path,
499
- seq_len=args.seq_len,
500
- num_steps=num_steps
501
- )
502
-
503
- warmup_s = 30
504
- duration_s = (20 * num_steps) + warmup_s
505
-
506
- if audio_chunks > 1:
507
- duration_s = (20 * num_steps * audio_chunks) + warmup_s
508
-
509
- print(f'{audio_chunks} audio chunk(s), estimated duration {duration_s}s')
510
-
511
- return int(duration_s)
512
-
513
- def preprocess_img(image_path, session_id = None):
514
-
515
- if session_id is None:
516
- session_id = uuid.uuid4().hex
517
-
518
- image = Image.open(image_path).convert("RGB")
519
-
520
- image = inferpipe.transform(image).unsqueeze(0).to(dtype=inferpipe.dtype)
521
-
522
- _, _, h, w = image.shape
523
- select_size = match_size(getattr( args, f'image_sizes_{ args.max_hw}'), h, w)
524
- image = resize_pad(image, (h, w), select_size)
525
- image = image * 2.0 - 1.0
526
- image = image[:, :, None]
527
-
528
- output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
529
-
530
- img_dir = output_dir + '/image'
531
- os.makedirs(img_dir, exist_ok=True)
532
- input_img_path = os.path.join(img_dir, f"img_input.jpg")
533
-
534
- image = tensor_to_pil(image)
535
- image.save(input_img_path)
536
-
537
- return input_img_path
538
-
539
-
540
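- # ZeroGPU: get_duration estimates the GPU time to reserve from the audio chunk count and step count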
- @spaces.GPU(duration=get_duration)
541
- def infer(image_path, audio_path, text, num_steps, session_id = None, progress=gr.Progress(track_tqdm=True),):
542
-
543
- if session_id is None:
544
- session_id = uuid.uuid4().hex
545
-
546
- output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
547
-
548
- audio_dir = output_dir + '/audio'
549
- os.makedirs(audio_dir, exist_ok=True)
550
- if args.silence_duration_s > 0:
551
- input_audio_path = os.path.join(audio_dir, f"audio_input.wav")
552
- else:
553
- input_audio_path = audio_path
554
- prompt_dir = output_dir + '/prompt'
555
- os.makedirs(prompt_dir, exist_ok=True)
556
-
557
- if args.silence_duration_s > 0:
558
- add_silence_to_audio_ffmpeg(audio_path, input_audio_path, args.silence_duration_s)
559
-
560
- tmp2_audio_path = os.path.join(audio_dir, f"audio_out.wav")
561
- prompt_path = os.path.join(prompt_dir, f"prompt.txt")
562
-
563
- video = inferpipe(
564
- prompt=text,
565
- image_path=image_path,
566
- audio_path=input_audio_path,
567
- seq_len=args.seq_len,
568
- num_steps=num_steps
569
- )
570
-
571
- torch.cuda.empty_cache()
572
-
573
- add_silence_to_audio_ffmpeg(audio_path, tmp2_audio_path, 1.0 / args.fps + args.silence_duration_s)
574
- video_paths = save_video_as_grid_and_mp4(video,
575
- output_dir,
576
- args.fps,
577
- prompt=text,
578
- prompt_path = prompt_path,
579
- audio_path=tmp2_audio_path if args.use_audio else None,
580
- prefix=f'result')
581
-
582
- return video_paths[0]
583
-
584
- def cleanup(request: gr.Request):
585
-
586
- sid = request.session_hash
587
- if sid:
588
- d1 = os.path.join(os.environ["PROCESSED_RESULTS"], sid)
589
- shutil.rmtree(d1, ignore_errors=True)
590
-
591
- def start_session(request: gr.Request):
592
-
593
- return request.session_hash
594
-
595
- css = """
596
- #col-container {
597
- margin: 0 auto;
598
- max-width: 1560px;
599
- }
600
- """
601
- theme = gr.themes.Ocean()
602
-
603
- with gr.Blocks(css=css, theme=theme) as demo:
604
-
605
- session_state = gr.State()
606
- demo.load(start_session, outputs=[session_state])
607
-
608
-
609
- with gr.Column(elem_id="col-container"):
610
- gr.HTML(
611
- """
612
- <div style="text-align: left;">
613
- <p style="font-size:16px; display: inline; margin: 0;">
614
- <strong>OmniAvatar</strong> – Efficient Audio-Driven Avatar Video Generation with Adaptive Body Animation
615
- </p>
616
- <a href="https://github.com/Omni-Avatar/OmniAvatar" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
617
- <img src="https://img.shields.io/badge/GitHub-Repo-blue" alt="GitHub Repo">
618
- </a>
619
- </div>
620
- <div style="text-align: left;">
621
- HF Space by :<a href="https://twitter.com/alexandernasa/" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
622
- <img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Follow Me" alt="GitHub Repo">
623
- </a>
624
- </div>
625
-
626
- <div style="text-align: left;">
627
- <a href="https://huggingface.co/alexnasa">
628
- <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
629
- </a>
630
- </div>
631
-
632
- """
633
- )
634
-
635
- with gr.Row():
636
-
637
- with gr.Column():
638
-
639
- image_input = gr.Image(label="Reference Image", type="filepath", height=512)
640
- audio_input = gr.Audio(label="Input Audio", type="filepath")
641
-
642
-
643
- with gr.Column():
644
-
645
- output_video = gr.Video(label="Avatar", height=512)
646
- num_steps = gr.Slider(1, 50, value=8, step=1, label="Steps")
647
- time_required = gr.Text(value="⌚ Zero GPU Required: --", show_label=False)
648
- infer_btn = gr.Button("🦜 Avatar Me", variant="primary")
649
- with gr.Accordion("Advanced Settings", open=False):
650
- text_input = gr.Textbox(label="Prompt Text", lines=4, value="A realistic video of a person speaking directly to the camera on a sofa, with dynamic and rhythmic hand gestures that complement their speech. Their hands are clearly visible, independent, and unobstructed. Their facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.")
651
-
652
- with gr.Column():
653
-
654
- examples = gr.Examples(
655
- examples=[
656
- [
657
- "examples/images/female-001.png",
658
- "examples/audios/mushroom.wav",
659
- "A realistic video of a woman speaking and sometimes looking directly to the camera, sitting on a sofa, with dynamic and rhythmic and extensive hand gestures that complement his speech. His hands are clearly visible, independent, and unobstructed. His facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.",
660
- 12
661
- ],
662
- [
663
- "examples/images/male-001.png",
664
- "examples/audios/tape.wav",
665
- "A realistic video of a man moving his hands extensively and speaking. The motion of his hands matches his speech. His facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.",
666
- 8
667
- ],
668
- ],
669
- inputs=[image_input, audio_input, text_input, num_steps],
670
- outputs=[output_video],
671
- fn=infer,
672
- cache_examples=True
673
- )
674
-
675
- infer_btn.click(
676
- fn=infer,
677
- inputs=[image_input, audio_input, text_input, num_steps, session_state],
678
- outputs=[output_video]
679
- )
680
- image_input.upload(fn=preprocess_img, inputs=[image_input, session_state], outputs=[image_input]).then(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps], outputs=[time_required])
681
- audio_input.upload(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps], outputs=[time_required])
682
- num_steps.change(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps], outputs=[time_required])
683
-
684
-
685
- if __name__ == "__main__":
686
- demo.unload(cleanup)
687
- demo.queue()
688
  demo.launch(ssr_mode=False)
 
1
+ import spaces
2
+ import subprocess
3
+ import gradio as gr
4
+
5
+ import os, sys
6
+ from glob import glob
7
+ from datetime import datetime
8
+ import math
9
+ import random
10
+ import librosa
11
+ import numpy as np
12
+ import uuid
13
+ import shutil
14
+
15
+ import importlib, site, sys
16
+
17
+ import torch
18
+
19
+ print(f'torch version:{torch.__version__}')
20
+
21
+
22
+ import torch.nn as nn
23
+ from tqdm import tqdm
24
+ from functools import partial
25
+ from omegaconf import OmegaConf
26
+ from argparse import Namespace
27
+
28
+ # load the one true config you dumped
29
+ _args_cfg = OmegaConf.load("args_config.yaml")
30
+ args = Namespace(**OmegaConf.to_container(_args_cfg, resolve=True))
31
+
32
+ from OmniAvatar.utils.args_config import set_global_args
33
+
34
+ set_global_args(args)
35
+ # args = parse_args()
36
+
37
+ from OmniAvatar.utils.io_utils import load_state_dict
38
+ from peft import LoraConfig, inject_adapter_in_model
39
+ from OmniAvatar.models.model_manager import ModelManager
40
+ from OmniAvatar.schedulers.flow_match import FlowMatchScheduler
41
+ from OmniAvatar.wan_video import WanVideoPipeline
42
+ from OmniAvatar.utils.io_utils import save_video_as_grid_and_mp4
43
+ import torchvision.transforms as TT
44
+ from transformers import Wav2Vec2FeatureExtractor
45
+ import torchvision.transforms as transforms
46
+ import torch.nn.functional as F
47
+ from OmniAvatar.utils.audio_preprocess import add_silence_to_audio_ffmpeg
48
+ from huggingface_hub import hf_hub_download, snapshot_download
49
+
50
+ os.environ["PROCESSED_RESULTS"] = f"{os.getcwd()}/preprocess_results"
51
+
52
+ def tensor_to_pil(tensor):
53
+ """
54
+ Args:
55
+ tensor: torch.Tensor with shape like
56
+ (1, C, H, W), (1, C, 1, H, W), (C, H, W), etc.
57
+ values in [-1, 1], on any device.
58
+ Returns:
59
+ A PIL.Image in RGB mode.
60
+ """
61
+ # 1) Remove batch dim if it exists
62
+ if tensor.dim() > 3 and tensor.shape[0] == 1:
63
+ tensor = tensor[0]
64
+
65
+ # 2) Squeeze out any other singleton dims (e.g. that extra frame axis)
66
+ tensor = tensor.squeeze()
67
+
68
+ # Now we should have exactly 3 dims: (C, H, W)
69
+ if tensor.dim() != 3:
70
+ raise ValueError(f"Expected 3 dims after squeeze, got {tensor.dim()}")
71
+
72
+ # 3) Move to CPU float32
73
+ tensor = tensor.cpu().float()
74
+
75
+ # 4) Undo normalization from [-1,1] -> [0,1]
76
+ tensor = (tensor + 1.0) / 2.0
77
+
78
+ # 5) Clamp to [0,1]
79
+ tensor = torch.clamp(tensor, 0.0, 1.0)
80
+
81
+ # 6) To NumPy H×W×C in [0,255]
82
+ np_img = (tensor.permute(1, 2, 0).numpy() * 255.0).round().astype("uint8")
83
+
84
+ # 7) Build PIL Image
85
+ return Image.fromarray(np_img)
86
+
87
+
88
+ def set_seed(seed: int = 42):
89
+ random.seed(seed)
90
+ np.random.seed(seed)
91
+ torch.manual_seed(seed)
92
+ torch.cuda.manual_seed(seed) # seed the current GPU
93
+ torch.cuda.manual_seed_all(seed) # seed all GPUs
94
+
95
+ def read_from_file(p):
96
+ with open(p, "r") as fin:
97
+ for l in fin:
98
+ yield l.strip()
99
+
100
+ def match_size(image_size, h, w):
101
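+ # pick the candidate (H, W) whose aspect ratio (and then overall size) is closest to the input image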
+ ratio_ = 9999
102
+ size_ = 9999
103
+ select_size = None
104
+ for image_s in image_size:
105
+ ratio_tmp = abs(image_s[0] / image_s[1] - h / w)
106
+ size_tmp = abs(max(image_s) - max(w, h))
107
+ if ratio_tmp < ratio_:
108
+ ratio_ = ratio_tmp
109
+ size_ = size_tmp
110
+ select_size = image_s
111
+ if ratio_ == ratio_tmp:
112
+ if size_ == size_tmp:
113
+ select_size = image_s
114
+ return select_size
115
+
116
+ def resize_pad(image, ori_size, tgt_size):
117
+ h, w = ori_size
118
+ scale_ratio = max(tgt_size[0] / h, tgt_size[1] / w)
119
+ scale_h = int(h * scale_ratio)
120
+ scale_w = int(w * scale_ratio)
121
+
122
+ image = transforms.Resize(size=[scale_h, scale_w])(image)
123
+
124
+ padding_h = tgt_size[0] - scale_h
125
+ padding_w = tgt_size[1] - scale_w
126
+ pad_top = padding_h // 2
127
+ pad_bottom = padding_h - pad_top
128
+ pad_left = padding_w // 2
129
+ pad_right = padding_w - pad_left
130
+
131
+ image = F.pad(image, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0)
132
+ return image
133
+
134
+ class WanInferencePipeline(nn.Module):
135
+ def __init__(self, args):
136
+ super().__init__()
137
+ self.args = args
138
+ self.device = torch.device(f"cuda")
139
+ self.dtype = torch.bfloat16
140
+ self.pipe = self.load_model()
141
+ chained_transforms = []
142
+ chained_transforms.append(TT.ToTensor())
143
+ self.transform = TT.Compose(chained_transforms)
144
+
145
+ if self.args.use_audio:
146
+ from OmniAvatar.models.wav2vec import Wav2VecModel
147
+ self.wav_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
148
+ self.args.wav2vec_path
149
+ )
150
+ self.audio_encoder = Wav2VecModel.from_pretrained(self.args.wav2vec_path, local_files_only=True).to(device=self.device, dtype=self.dtype)
151
+ self.audio_encoder.feature_extractor._freeze_parameters()
152
+
153
+
154
+ def load_model(self):
155
+ ckpt_path = f'{self.args.exp_path}/pytorch_model.pt'
156
+ assert os.path.exists(ckpt_path), f"pytorch_model.pt not found in {self.args.exp_path}"
157
+ if self.args.train_architecture == 'lora':
158
+ self.args.pretrained_lora_path = pretrained_lora_path = ckpt_path
159
+ else:
160
+ resume_path = ckpt_path
161
+
162
+ self.step = 0
163
+
164
+ # Load models
165
+ model_manager = ModelManager(device="cuda", infer=True)
166
+
167
+ model_manager.load_models(
168
+ [
169
+ self.args.dit_path.split(","),
170
+ self.args.vae_path,
171
+ self.args.text_encoder_path
172
+ ],
173
+ torch_dtype=self.dtype,
174
+ device='cuda',
175
+ )
176
+
177
+ pipe = WanVideoPipeline.from_model_manager(model_manager,
178
+ torch_dtype=self.dtype,
179
+ device="cuda",
180
+ use_usp=False,
181
+ infer=True)
182
+
183
+ if self.args.train_architecture == "lora":
184
+ print(f'Use LoRA: lora rank: {self.args.lora_rank}, lora alpha: {self.args.lora_alpha}')
185
+ self.add_lora_to_model(
186
+ pipe.denoising_model(),
187
+ lora_rank=self.args.lora_rank,
188
+ lora_alpha=self.args.lora_alpha,
189
+ lora_target_modules=self.args.lora_target_modules,
190
+ init_lora_weights=self.args.init_lora_weights,
191
+ pretrained_lora_path=pretrained_lora_path,
192
+ )
193
+ print(next(pipe.denoising_model().parameters()).device)
194
+ else:
195
+ missing_keys, unexpected_keys = pipe.denoising_model().load_state_dict(load_state_dict(resume_path), strict=True)
196
+ print(f"load from {resume_path}, {len(missing_keys)} missing keys, {len(unexpected_keys)} unexpected keys")
197
+ pipe.requires_grad_(False)
198
+ pipe.eval()
199
+ # pipe.enable_vram_management(num_persistent_param_in_dit=args.num_persistent_param_in_dit)
200
+ return pipe
201
+
202
+ def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None):
203
+ # Add LoRA to UNet
204
+
205
+ self.lora_alpha = lora_alpha
206
+ if init_lora_weights == "kaiming":
207
+ init_lora_weights = True
208
+
209
+ lora_config = LoraConfig(
210
+ r=lora_rank,
211
+ lora_alpha=lora_alpha,
212
+ init_lora_weights=init_lora_weights,
213
+ target_modules=lora_target_modules.split(","),
214
+ )
215
+ model = inject_adapter_in_model(lora_config, model)
216
+
217
+ # Load pretrained LoRA weights
218
+ if pretrained_lora_path is not None:
219
+ state_dict = load_state_dict(pretrained_lora_path, torch_dtype=self.dtype)
220
+ if state_dict_converter is not None:
221
+ state_dict = state_dict_converter(state_dict)
222
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
223
+ all_keys = [i for i, _ in model.named_parameters()]
224
+ num_updated_keys = len(all_keys) - len(missing_keys)
225
+ num_unexpected_keys = len(unexpected_keys)
226
+
227
+ print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")
228
+
229
+ def get_times(self, prompt,
230
+ image_path=None,
231
+ audio_path=None,
232
+ seq_len=101, # not used while audio_path is not None
233
+ height=720,
234
+ width=720,
235
+ overlap_frame=None,
236
+ num_steps=None,
237
+ negative_prompt=None,
238
+ guidance_scale=None,
239
+ audio_scale=None):
240
+
241
+ overlap_frame = overlap_frame if overlap_frame is not None else self.args.overlap_frame
242
+ num_steps = num_steps if num_steps is not None else self.args.num_steps
243
+ negative_prompt = negative_prompt if negative_prompt is not None else self.args.negative_prompt
244
+ guidance_scale = guidance_scale if guidance_scale is not None else self.args.guidance_scale
245
+ audio_scale = audio_scale if audio_scale is not None else self.args.audio_scale
246
+
247
+ if image_path is not None:
248
+ from PIL import Image
249
+ image = Image.open(image_path).convert("RGB")
250
+
251
+ image = self.transform(image).unsqueeze(0).to(dtype=self.dtype)
252
+
253
+ _, _, h, w = image.shape
254
+ select_size = match_size(getattr( self.args, f'image_sizes_{ self.args.max_hw}'), h, w)
255
+ image = resize_pad(image, (h, w), select_size)
256
+ image = image * 2.0 - 1.0
257
+ image = image[:, :, None]
258
+
259
+ else:
260
+ image = None
261
+ select_size = [height, width]
262
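+ # derive the clip length L (video frames, forced to L % 4 == 1) from the token budget, then T latent frames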
+ num = self.args.max_tokens * 16 * 16 * 4
263
+ den = select_size[0] * select_size[1]
264
+ L0 = num // den
265
+ diff = (L0 - 1) % 4
266
+ L = L0 - diff
267
+ if L < 1:
268
+ L = 1
269
+ T = (L + 3) // 4
270
+
271
+
272
+ if self.args.random_prefix_frames:
273
+ fixed_frame = overlap_frame
274
+ assert fixed_frame % 4 == 1
275
+ else:
276
+ fixed_frame = 1
277
+ prefix_lat_frame = (3 + fixed_frame) // 4
278
+ first_fixed_frame = 1
279
+
280
+
281
+ audio, sr = librosa.load(audio_path, sr= self.args.sample_rate)
282
+
283
+ input_values = np.squeeze(
284
+ self.wav_feature_extractor(audio, sampling_rate=16000).input_values
285
+ )
286
+ input_values = torch.from_numpy(input_values).float().to(dtype=self.dtype)
287
+ audio_len = math.ceil(len(input_values) / self.args.sample_rate * self.args.fps)
288
+
289
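+ # round the audio length up so it splits into a first window plus whole overlapping windows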
+ if audio_len < L - first_fixed_frame:
290
+ audio_len = audio_len + ((L - first_fixed_frame) - audio_len % (L - first_fixed_frame))
291
+ elif (audio_len - (L - first_fixed_frame)) % (L - fixed_frame) != 0:
292
+ audio_len = audio_len + ((L - fixed_frame) - (audio_len - (L - first_fixed_frame)) % (L - fixed_frame))
293
+
294
+ seq_len = audio_len
295
+
296
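+ # number of overlapping generation windows needed to cover the padded audio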
+ times = (seq_len - L + first_fixed_frame) // (L-fixed_frame) + 1
297
+ if times * (L-fixed_frame) + fixed_frame < seq_len:
298
+ times += 1
299
+
300
+ return times
301
+
302
+ @torch.no_grad()
303
+ def forward(self, prompt,
304
+ image_path=None,
305
+ audio_path=None,
306
+ seq_len=101, # not used while audio_path is not None
307
+ height=720,
308
+ width=720,
309
+ overlap_frame=None,
310
+ num_steps=None,
311
+ negative_prompt=None,
312
+ guidance_scale=None,
313
+ audio_scale=None):
314
+ overlap_frame = overlap_frame if overlap_frame is not None else self.args.overlap_frame
315
+ num_steps = num_steps if num_steps is not None else self.args.num_steps
316
+ negative_prompt = negative_prompt if negative_prompt is not None else self.args.negative_prompt
317
+ guidance_scale = guidance_scale if guidance_scale is not None else self.args.guidance_scale
318
+ audio_scale = audio_scale if audio_scale is not None else self.args.audio_scale
319
+
320
+ if image_path is not None:
321
+ from PIL import Image
322
+ image = Image.open(image_path).convert("RGB")
323
+
324
+ image = self.transform(image).unsqueeze(0).to(self.device, dtype=self.dtype)
325
+
326
+ _, _, h, w = image.shape
327
+ select_size = match_size(getattr(self.args, f'image_sizes_{self.args.max_hw}'), h, w)
328
+ image = resize_pad(image, (h, w), select_size)
329
+ image = image * 2.0 - 1.0
330
+ image = image[:, :, None]
331
+
332
+ else:
333
+ image = None
334
+ select_size = [height, width]
335
+ # L = int(self.args.max_tokens * 16 * 16 * 4 / select_size[0] / select_size[1])
336
+ # L = L // 4 * 4 + 1 if L % 4 != 0 else L - 3 # video frames
337
+ # T = (L + 3) // 4 # latent frames
338
+
339
+ # step 1: numerator and denominator as ints
340
+ num = self.args.max_tokens * 16 * 16 * 4
341
+ den = select_size[0] * select_size[1]
342
+
343
+ # step 2: integer division
344
+ L0 = num // den # exact floor division, no float in sight
345
+
346
+ # step 3: make it ≡ 1 mod 4
347
+ # if L0 % 4 == 1, keep L0;
348
+ # otherwise subtract the difference so that (L0 - diff) % 4 == 1,
349
+ # but ensure the result stays positive.
350
+ diff = (L0 - 1) % 4
351
+ L = L0 - diff
352
+ if L < 1:
353
+ L = 1 # or whatever your minimal frame count is
354
+
355
+ # step 4: latent frames
356
+ T = (L + 3) // 4
357
+
358
+
359
+ if self.args.i2v:
360
+ if self.args.random_prefix_frames:
361
+ fixed_frame = overlap_frame
362
+ assert fixed_frame % 4 == 1
363
+ else:
364
+ fixed_frame = 1
365
+ prefix_lat_frame = (3 + fixed_frame) // 4
366
+ first_fixed_frame = 1
367
+ else:
368
+ fixed_frame = 0
369
+ prefix_lat_frame = 0
370
+ first_fixed_frame = 0
371
+
372
+
373
+ if audio_path is not None and self.args.use_audio:
374
+ audio, sr = librosa.load(audio_path, sr=self.args.sample_rate)
375
+ input_values = np.squeeze(
376
+ self.wav_feature_extractor(audio, sampling_rate=16000).input_values
377
+ )
378
+ input_values = torch.from_numpy(input_values).float().to(device=self.device, dtype=self.dtype)
379
+ ori_audio_len = audio_len = math.ceil(len(input_values) / self.args.sample_rate * self.args.fps)
380
+ input_values = input_values.unsqueeze(0)
381
+ # padding audio
382
+ if audio_len < L - first_fixed_frame:
383
+ audio_len = audio_len + ((L - first_fixed_frame) - audio_len % (L - first_fixed_frame))
384
+ elif (audio_len - (L - first_fixed_frame)) % (L - fixed_frame) != 0:
385
+ audio_len = audio_len + ((L - fixed_frame) - (audio_len - (L - first_fixed_frame)) % (L - fixed_frame))
386
+ input_values = F.pad(input_values, (0, audio_len * int(self.args.sample_rate / self.args.fps) - input_values.shape[1]), mode='constant', value=0)
387
+ with torch.no_grad():
388
+ hidden_states = self.audio_encoder(input_values, seq_len=audio_len, output_hidden_states=True)
389
+ audio_embeddings = hidden_states.last_hidden_state
390
+ for mid_hidden_states in hidden_states.hidden_states:
391
+ audio_embeddings = torch.cat((audio_embeddings, mid_hidden_states), -1)
392
+ seq_len = audio_len
393
+ audio_embeddings = audio_embeddings.squeeze(0)
394
+ audio_prefix = torch.zeros_like(audio_embeddings[:first_fixed_frame])
395
+ else:
396
+ audio_embeddings = None
397
+
398
+ # how many overlapping generation windows are needed to cover seq_len
399
+ times = (seq_len - L + first_fixed_frame) // (L-fixed_frame) + 1
400
+ if times * (L-fixed_frame) + fixed_frame < seq_len:
401
+ times += 1
402
+ video = []
403
+ image_emb = {}
404
+ img_lat = None
405
+ if self.args.i2v:
406
+ self.pipe.load_models_to_device(['vae'])
407
+ img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device, dtype=self.dtype)
408
+
409
+ msk = torch.zeros_like(img_lat.repeat(1, 1, T, 1, 1)[:,:1], dtype=self.dtype)
410
+ image_cat = img_lat.repeat(1, 1, T, 1, 1)
411
+ msk[:, :, 1:] = 1
412
+ image_emb["y"] = torch.cat([image_cat, msk], dim=1)
413
+
414
+ for t in range(times):
415
+ print(f"[{t+1}/{times}]")
416
+ audio_emb = {}
417
+ if t == 0:
418
+ overlap = first_fixed_frame
419
+ else:
420
+ overlap = fixed_frame
421
+ image_emb["y"][:, -1:, :prefix_lat_frame] = 0 # 第一次推理是mask只有1,往后都是mask overlap
422
+ prefix_overlap = (3 + overlap) // 4
423
+ if audio_embeddings is not None:
424
+ if t == 0:
425
+ audio_tensor = audio_embeddings[
426
+ :min(L - overlap, audio_embeddings.shape[0])
427
+ ]
428
+ else:
429
+ audio_start = L - first_fixed_frame + (t - 1) * (L - overlap)
430
+ audio_tensor = audio_embeddings[
431
+ audio_start: min(audio_start + L - overlap, audio_embeddings.shape[0])
432
+ ]
433
+
434
+ audio_tensor = torch.cat([audio_prefix, audio_tensor], dim=0)
435
+ audio_prefix = audio_tensor[-fixed_frame:]
436
+ audio_tensor = audio_tensor.unsqueeze(0).to(device=self.device, dtype=self.dtype)
437
+ audio_emb["audio_emb"] = audio_tensor
438
+ else:
439
+ audio_prefix = None
440
+ if image is not None and img_lat is None:
441
+ self.pipe.load_models_to_device(['vae'])
442
+ img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device, dtype=self.dtype)
443
+ assert img_lat.shape[2] == prefix_overlap
444
+ img_lat = torch.cat([img_lat, torch.zeros_like(img_lat[:, :, :1].repeat(1, 1, T - prefix_overlap, 1, 1), dtype=self.dtype)], dim=2)
445
+ frames, _, latents = self.pipe.log_video(img_lat, prompt, prefix_overlap, image_emb, audio_emb,
446
+ negative_prompt, num_inference_steps=num_steps,
447
+ cfg_scale=guidance_scale, audio_cfg_scale=audio_scale if audio_scale is not None else guidance_scale,
448
+ return_latent=True,
449
+ tea_cache_l1_thresh=self.args.tea_cache_l1_thresh,tea_cache_model_id="Wan2.1-T2V-14B")
450
+
451
+ torch.cuda.empty_cache()
452
+ img_lat = None
453
+ image = (frames[:, -fixed_frame:].clip(0, 1) * 2.0 - 1.0).permute(0, 2, 1, 3, 4).contiguous()
454
+
455
+ if t == 0:
456
+ video.append(frames)
457
+ else:
458
+ video.append(frames[:, overlap:])
459
+ video = torch.cat(video, dim=1)
460
+ video = video[:, :ori_audio_len + 1]
461
+
462
+ return video
463
+
464
+
465
+ snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-14B", local_dir="./pretrained_models/Wan2.1-T2V-14B")
466
+ snapshot_download(repo_id="facebook/wav2vec2-base-960h", local_dir="./pretrained_models/wav2vec2-base-960h")
467
+ snapshot_download(repo_id="OmniAvatar/OmniAvatar-14B", local_dir="./pretrained_models/OmniAvatar-14B")
468
+
469
+
470
+ # snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-1.3B", local_dir="./pretrained_models/Wan2.1-T2V-1.3B")
471
+ # snapshot_download(repo_id="facebook/wav2vec2-base-960h", local_dir="./pretrained_models/wav2vec2-base-960h")
472
+ # snapshot_download(repo_id="OmniAvatar/OmniAvatar-1.3B", local_dir="./pretrained_models/OmniAvatar-1.3B")
473
+
474
+ import tempfile
475
+
476
+ from PIL import Image
477
+
478
+
479
+ set_seed(args.seed)
480
+ seq_len = args.seq_len
481
+ inferpipe = WanInferencePipeline(args)
482
+
483
+
484
+ def update_generate_button(image_path, audio_path, text, num_steps):
485
+
486
+ if image_path is None or audio_path is None:
487
+ return gr.update(value="⌚ Zero GPU Required: --")
488
+
489
+ duration_s = get_duration(image_path, audio_path, text, num_steps, None, None)
490
+
491
+ return gr.update(value=f"⌚ Zero GPU Required: ~{duration_s}.0s")
492
+
493
+ def get_duration(image_path, audio_path, text, num_steps, session_id, progress):
494
+
495
+ audio_chunks = inferpipe.get_times(
496
+ prompt=text,
497
+ image_path=image_path,
498
+ audio_path=audio_path,
499
+ seq_len=args.seq_len,
500
+ num_steps=num_steps
501
+ )
502
+
503
+ warmup_s = 30
504
+ duration_s = (20 * num_steps) + warmup_s
505
+
506
+ if audio_chunks > 1:
507
+ duration_s = (20 * num_steps * audio_chunks) + warmup_s
508
+
509
+ print(f'{audio_chunks} audio chunk(s), estimated duration {duration_s}s')
510
+
511
+ return int(duration_s)
512
+
513
+ def preprocess_img(image_path, session_id = None):
514
+
515
+ if session_id is None:
516
+ session_id = uuid.uuid4().hex
517
+
518
+ image = Image.open(image_path).convert("RGB")
519
+
520
+ image = inferpipe.transform(image).unsqueeze(0).to(dtype=inferpipe.dtype)
521
+
522
+ _, _, h, w = image.shape
523
+ select_size = match_size(getattr( args, f'image_sizes_{ args.max_hw}'), h, w)
524
+ image = resize_pad(image, (h, w), select_size)
525
+ image = image * 2.0 - 1.0
526
+ image = image[:, :, None]
527
+
528
+ output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
529
+
530
+ img_dir = output_dir + '/image'
531
+ os.makedirs(img_dir, exist_ok=True)
532
+ input_img_path = os.path.join(img_dir, f"img_input.jpg")
533
+
534
+ image = tensor_to_pil(image)
535
+ image.save(input_img_path)
536
+
537
+ return input_img_path
538
+
539
+
540
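+ # ZeroGPU: get_duration estimates the GPU time to reserve from the audio chunk count and step count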
+ @spaces.GPU(duration=get_duration)
541
+ def infer(image_path, audio_path, text, num_steps, session_id = None, progress=gr.Progress(track_tqdm=True),):
542
+
543
+ if session_id is None:
544
+ session_id = uuid.uuid4().hex
545
+
546
+ output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
547
+
548
+ audio_dir = output_dir + '/audio'
549
+ os.makedirs(audio_dir, exist_ok=True)
550
+ if args.silence_duration_s > 0:
551
+ input_audio_path = os.path.join(audio_dir, f"audio_input.wav")
552
+ else:
553
+ input_audio_path = audio_path
554
+ prompt_dir = output_dir + '/prompt'
555
+ os.makedirs(prompt_dir, exist_ok=True)
556
+
557
+ if args.silence_duration_s > 0:
558
+ add_silence_to_audio_ffmpeg(audio_path, input_audio_path, args.silence_duration_s)
559
+
560
+ tmp2_audio_path = os.path.join(audio_dir, f"audio_out.wav")
561
+ prompt_path = os.path.join(prompt_dir, f"prompt.txt")
562
+
563
+ video = inferpipe(
564
+ prompt=text,
565
+ image_path=image_path,
566
+ audio_path=input_audio_path,
567
+ seq_len=args.seq_len,
568
+ num_steps=num_steps
569
+ )
570
+
571
+ torch.cuda.empty_cache()
572
+
573
+ add_silence_to_audio_ffmpeg(audio_path, tmp2_audio_path, 1.0 / args.fps + args.silence_duration_s)
574
+ video_paths = save_video_as_grid_and_mp4(video,
575
+ output_dir,
576
+ args.fps,
577
+ prompt=text,
578
+ prompt_path = prompt_path,
579
+ audio_path=tmp2_audio_path if args.use_audio else None,
580
+ prefix=f'result')
581
+
582
+ return video_paths[0]
583
+
584
+ def cleanup(request: gr.Request):
585
+
586
+ sid = request.session_hash
587
+ if sid:
588
+ d1 = os.path.join(os.environ["PROCESSED_RESULTS"], sid)
589
+ shutil.rmtree(d1, ignore_errors=True)
590
+
591
+ def start_session(request: gr.Request):
592
+
593
+ return request.session_hash
594
+
595
+ css = """
596
+ #col-container {
597
+ margin: 0 auto;
598
+ max-width: 1560px;
599
+ }
600
+ """
601
+ theme = gr.themes.Ocean()
602
+
603
+ with gr.Blocks(css=css, theme=theme) as demo:
604
+
605
+ session_state = gr.State()
606
+ demo.load(start_session, outputs=[session_state])
607
+
608
+
609
+ with gr.Column(elem_id="col-container"):
610
+ gr.HTML(
611
+ """
612
+ <div style="text-align: left;">
613
+ <p style="font-size:16px; display: inline; margin: 0;">
614
+ <strong>OmniAvatar</strong> – Efficient Audio-Driven Avatar Video Generation with Adaptive Body Animation
615
+ </p>
616
+ <a href="https://github.com/Omni-Avatar/OmniAvatar" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
617
+ <img src="https://img.shields.io/badge/GitHub-Repo-blue" alt="GitHub Repo">
618
+ </a>
619
+ </div>
620
+ <div style="text-align: left;">
621
+ HF Space by :<a href="https://twitter.com/alexandernasa/" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
622
+ <img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Follow Me" alt="GitHub Repo">
623
+ </a>
624
+ </div>
625
+
626
+ <div style="text-align: left;">
627
+ <a href="https://huggingface.co/alexnasa">
628
+ <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
629
+ </a>
630
+ </div>
631
+
632
+ """
633
+ )
634
+
635
+ with gr.Row():
636
+
637
+ with gr.Column():
638
+
639
+ image_input = gr.Image(label="Reference Image", type="filepath", height=512)
640
+ audio_input = gr.Audio(label="Input Audio", type="filepath")
641
+
642
+
643
+ with gr.Column():
644
+
645
+ output_video = gr.Video(label="Avatar", height=512)
646
+ num_steps = gr.Slider(1, 50, value=8, step=1, label="Steps")
647
+ time_required = gr.Text(value="⌚ Zero GPU Required: --", show_label=False)
648
+ infer_btn = gr.Button("🦜 Avatar Me", variant="primary")
649
+ with gr.Accordion("Advanced Settings", open=False):
650
+ text_input = gr.Textbox(label="Prompt Text", lines=4, value="A realistic video of a person speaking directly to the camera on a sofa, with dynamic and rhythmic hand gestures that complement their speech. Their hands are clearly visible, independent, and unobstructed. Their facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.")
651
+
652
+ with gr.Column():
653
+
654
+ examples = gr.Examples(
655
+ examples=[
656
+ [
657
+ "examples/images/female-001.png",
658
+ "examples/audios/mushroom.wav",
659
+ "A realistic video of a woman speaking and sometimes looking directly to the camera, sitting on a sofa, with dynamic and rhythmic and extensive hand gestures that complement his speech. His hands are clearly visible, independent, and unobstructed. His facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.",
660
+ 12
661
+ ],
662
+ [
663
+ "examples/images/male-001.png",
664
+ "examples/audios/reality.wav",
665
+ "A realistic video of a man moving his hands extensively and speaking. The motion of his hands matches his speech. His facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.",
666
+ 8
667
+ ],
668
+ ],
669
+ inputs=[image_input, audio_input, text_input, num_steps],
670
+ outputs=[output_video],
671
+ fn=infer,
672
+ cache_examples=True
673
+ )
674
+
675
+ infer_btn.click(
676
+ fn=infer,
677
+ inputs=[image_input, audio_input, text_input, num_steps, session_state],
678
+ outputs=[output_video]
679
+ )
680
+ image_input.upload(fn=preprocess_img, inputs=[image_input, session_state], outputs=[image_input]).then(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps], outputs=[time_required])
681
+ audio_input.upload(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps], outputs=[time_required])
682
+ num_steps.change(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps], outputs=[time_required])
683
+
684
+
685
+ if __name__ == "__main__":
686
+ demo.unload(cleanup)
687
+ demo.queue()
688
  demo.launch(ssr_mode=False)