# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import math
import argparse
import shutil
import datetime
import logging
from omegaconf import OmegaConf
from tqdm.auto import tqdm
from einops import rearrange

import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

import diffusers
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.utils.logging import get_logger
from diffusers.optimization import get_scheduler
from diffusers.utils.import_utils import is_xformers_available
from accelerate.utils import set_seed

from latentsync.data.unet_dataset import UNetDataset
from latentsync.models.unet import UNet3DConditionModel
from latentsync.models.syncnet import SyncNet
from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
from latentsync.utils.util import (
    init_dist,
    cosine_loss,
    reversed_forward,
)
from latentsync.utils.util import plot_loss_chart, gather_loss
from latentsync.whisper.audio2feature import Audio2Feature
from latentsync.trepa import TREPALoss
from eval.syncnet import SyncNetEval
from eval.syncnet_detect import SyncNetDetector
from eval.eval_sync_conf import syncnet_eval
import lpips

logger = get_logger(__name__)


def main(config):
    # Initialize distributed training
    local_rank = init_dist()
    global_rank = dist.get_rank()
    num_processes = dist.get_world_size()
    is_main_process = global_rank == 0

    seed = config.run.seed + global_rank
    set_seed(seed)

    # Logging folder
    folder_name = "train" + datetime.datetime.now().strftime(f"-%Y_%m_%d-%H:%M:%S")
    output_dir = os.path.join(config.data.train_output_dir, folder_name)

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # Handle the output folder creation
    if is_main_process:
        diffusers.utils.logging.set_verbosity_info()
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
        os.makedirs(f"{output_dir}/val_videos", exist_ok=True)
        os.makedirs(f"{output_dir}/loss_charts", exist_ok=True)
        shutil.copy(config.unet_config_path, output_dir)
        shutil.copy(config.data.syncnet_config_path, output_dir)

    device = torch.device(local_rank)

    noise_scheduler = DDIMScheduler.from_pretrained("configs")

    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
    vae.config.scaling_factor = 0.18215
    vae.config.shift_factor = 0
    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
    vae.requires_grad_(False)
    vae.to(device)
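    # Note: 0.18215 is the standard latent scaling factor for the SD VAE; with 4 entries in
    # block_out_channels the spatial downsampling factor is 2**3 = 8, so latents are H/8 x W/8.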

    syncnet_eval_model = SyncNetEval(device=device)
    syncnet_eval_model.loadParameters("checkpoints/auxiliary/syncnet_v2.model")

    syncnet_detector = SyncNetDetector(device=device, detect_results_dir="detect_results")

    if config.model.cross_attention_dim == 768:
        whisper_model_path = "checkpoints/whisper/small.pt"
    elif config.model.cross_attention_dim == 384:
        whisper_model_path = "checkpoints/whisper/tiny.pt"
    else:
        raise NotImplementedError("cross_attention_dim must be 768 or 384")
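    # Whisper tiny outputs 384-dim features and Whisper small outputs 768-dim features,
    # so the checkpoint is chosen to match the UNet's audio cross-attention width.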

    audio_encoder = Audio2Feature(
        model_path=whisper_model_path,
        device=device,
        audio_embeds_cache_dir=config.data.audio_embeds_cache_dir,
        num_frames=config.data.num_frames,
    )

    unet, resume_global_step = UNet3DConditionModel.from_pretrained(
        OmegaConf.to_container(config.model),
        config.ckpt.resume_ckpt_path,  # load checkpoint
        device=device,
    )

    if config.model.add_audio_layer and config.run.use_syncnet:
        syncnet_config = OmegaConf.load(config.data.syncnet_config_path)
        if syncnet_config.ckpt.inference_ckpt_path == "":
            raise ValueError("SyncNet path is not provided")
        syncnet = SyncNet(OmegaConf.to_container(syncnet_config.model)).to(device=device, dtype=torch.float16)
        syncnet_checkpoint = torch.load(syncnet_config.ckpt.inference_ckpt_path, map_location=device)
        syncnet.load_state_dict(syncnet_checkpoint["state_dict"])
        syncnet.requires_grad_(False)
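        # SyncNet stays frozen throughout training; it only supplies the lip-sync
        # supervision signal (sync_loss) computed in the training loop below.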

    unet.requires_grad_(True)
    trainable_params = list(unet.parameters())

    if config.optimizer.scale_lr:
        config.optimizer.lr = config.optimizer.lr * num_processes

    optimizer = torch.optim.AdamW(trainable_params, lr=config.optimizer.lr)

    if is_main_process:
        logger.info(f"trainable params number: {len(trainable_params)}")
        logger.info(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M")

    # Enable xformers
    if config.run.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # Enable gradient checkpointing
    if config.run.enable_gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    # Get the training dataset
    train_dataset = UNetDataset(config.data.train_data_dir, config)
    distributed_sampler = DistributedSampler(
        train_dataset,
        num_replicas=num_processes,
        rank=global_rank,
        shuffle=True,
        seed=config.run.seed,
    )

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.data.batch_size,
        shuffle=False,
        sampler=distributed_sampler,
        num_workers=config.data.num_workers,
        pin_memory=False,
        drop_last=True,
        worker_init_fn=train_dataset.worker_init_fn,
    )

    # Get the training iteration
    if config.run.max_train_steps == -1:
        assert config.run.max_train_epochs != -1
        config.run.max_train_steps = config.run.max_train_epochs * len(train_dataloader)

    # Scheduler
    lr_scheduler = get_scheduler(
        config.optimizer.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=config.optimizer.lr_warmup_steps,
        num_training_steps=config.run.max_train_steps,
    )

    if config.run.perceptual_loss_weight != 0 and config.run.pixel_space_supervise:
        lpips_loss_func = lpips.LPIPS(net="vgg").to(device)

    if config.run.trepa_loss_weight != 0 and config.run.pixel_space_supervise:
        trepa_loss_func = TREPALoss(device=device)

    # Validation pipeline
    pipeline = LipsyncPipeline(
        vae=vae,
        audio_encoder=audio_encoder,
        unet=unet,
        scheduler=noise_scheduler,
    ).to(device)
    pipeline.set_progress_bar_config(disable=True)

    # DDP wrapper
    unet = DDP(unet, device_ids=[local_rank], output_device=local_rank)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader))
    # Afterwards we recalculate our number of training epochs
    num_train_epochs = math.ceil(config.run.max_train_steps / num_update_steps_per_epoch)

    # Train!
    total_batch_size = config.data.batch_size * num_processes

    if is_main_process:
        logger.info("***** Running training *****")
        logger.info(f"  Num examples = {len(train_dataset)}")
        logger.info(f"  Num Epochs = {num_train_epochs}")
        logger.info(f"  Instantaneous batch size per device = {config.data.batch_size}")
        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
        logger.info(f"  Total optimization steps = {config.run.max_train_steps}")
    global_step = resume_global_step
    first_epoch = resume_global_step // num_update_steps_per_epoch

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        range(0, config.run.max_train_steps),
        initial=resume_global_step,
        desc="Steps",
        disable=not is_main_process,
    )

    train_step_list = []
    sync_loss_list = []
    recon_loss_list = []

    val_step_list = []
    sync_conf_list = []

    # Support mixed-precision training
    scaler = torch.cuda.amp.GradScaler() if config.run.mixed_precision_training else None
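    # GradScaler dynamically scales the loss during fp16 training to avoid gradient underflow;
    # it is only instantiated when mixed-precision training is enabled.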

    for epoch in range(first_epoch, num_train_epochs):
        train_dataloader.sampler.set_epoch(epoch)
        unet.train()

        for step, batch in enumerate(train_dataloader):
            ### >>>> Training >>>> ###

            if config.model.add_audio_layer:
                if batch["mel"] != []:
                    mel = batch["mel"].to(device, dtype=torch.float16)

                audio_embeds_list = []
                try:
                    for idx in range(len(batch["video_path"])):
                        video_path = batch["video_path"][idx]
                        start_idx = batch["start_idx"][idx]

                        with torch.no_grad():
                            audio_feat = audio_encoder.audio2feat(video_path)
                        audio_embeds = audio_encoder.crop_overlap_audio_window(audio_feat, start_idx)
                        audio_embeds_list.append(audio_embeds)
                except Exception as e:
                    logger.info(f"{type(e).__name__} - {e} - {video_path}")
                    continue

                audio_embeds = torch.stack(audio_embeds_list)  # (B, 16, 50, 384)
                audio_embeds = audio_embeds.to(device, dtype=torch.float16)
            else:
                audio_embeds = None

            # Convert videos to latent space
            gt_images = batch["gt"].to(device, dtype=torch.float16)
            gt_masked_images = batch["masked_gt"].to(device, dtype=torch.float16)
            mask = batch["mask"].to(device, dtype=torch.float16)
            ref_images = batch["ref"].to(device, dtype=torch.float16)

            gt_images = rearrange(gt_images, "b f c h w -> (b f) c h w")
            gt_masked_images = rearrange(gt_masked_images, "b f c h w -> (b f) c h w")
            mask = rearrange(mask, "b f c h w -> (b f) c h w")
            ref_images = rearrange(ref_images, "b f c h w -> (b f) c h w")

            with torch.no_grad():
                gt_latents = vae.encode(gt_images).latent_dist.sample()
                gt_masked_images = vae.encode(gt_masked_images).latent_dist.sample()
                ref_images = vae.encode(ref_images).latent_dist.sample()

            mask = torch.nn.functional.interpolate(mask, size=config.data.resolution // vae_scale_factor)

            gt_latents = (
                rearrange(gt_latents, "(b f) c h w -> b c f h w", f=config.data.num_frames) - vae.config.shift_factor
            ) * vae.config.scaling_factor
            gt_masked_images = (
                rearrange(gt_masked_images, "(b f) c h w -> b c f h w", f=config.data.num_frames)
                - vae.config.shift_factor
            ) * vae.config.scaling_factor
            ref_images = (
                rearrange(ref_images, "(b f) c h w -> b c f h w", f=config.data.num_frames) - vae.config.shift_factor
            ) * vae.config.scaling_factor
            mask = rearrange(mask, "(b f) c h w -> b c f h w", f=config.data.num_frames)
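            # The encoded tensors are mapped into the diffusion latent space as
            # latent = (z - shift_factor) * scaling_factor; the decode step below applies the inverse.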

            # Sample noise that we'll add to the latents
            if config.run.use_mixed_noise:
                # Refer to the paper: https://arxiv.org/abs/2305.10474
                noise_shared_std_dev = (config.run.mixed_noise_alpha**2 / (1 + config.run.mixed_noise_alpha**2)) ** 0.5
                noise_shared = torch.randn_like(gt_latents) * noise_shared_std_dev
                noise_shared = noise_shared[:, :, 0:1].repeat(1, 1, config.data.num_frames, 1, 1)

                noise_ind_std_dev = (1 / (1 + config.run.mixed_noise_alpha**2)) ** 0.5
                noise_ind = torch.randn_like(gt_latents) * noise_ind_std_dev
                noise = noise_ind + noise_shared
            else:
                noise = torch.randn_like(gt_latents)
                noise = noise[:, :, 0:1].repeat(
                    1, 1, config.data.num_frames, 1, 1
                )  # Using the same noise for all frames, refer to the paper: https://arxiv.org/abs/2308.09716
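            # In the mixed-noise branch the per-frame and shared components have variances
            # 1 / (1 + alpha^2) and alpha^2 / (1 + alpha^2), which sum to 1, so the total noise
            # stays unit-variance while all frames share a common component. The plain branch
            # simply repeats one frame's noise across the clip.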

            bsz = gt_latents.shape[0]

            # Sample a random timestep for each video
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=gt_latents.device)
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_tensor = noise_scheduler.add_noise(gt_latents, noise, timesteps)
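            # add_noise follows the standard DDPM forward process:
            # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise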

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                raise NotImplementedError
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            unet_input = torch.cat([noisy_tensor, mask, gt_masked_images, ref_images], dim=1)
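            # Channel-wise conditioning: with a 4-channel VAE latent this concatenation yields
            # 4 (noisy) + 1 (mask) + 4 (masked frame) + 4 (reference frame) = 13 input channels,
            # which must match the UNet's configured in_channels.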

            # Predict the noise and compute loss
            # Mixed-precision training
            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=config.run.mixed_precision_training):
                pred_noise = unet(unet_input, timesteps, encoder_hidden_states=audio_embeds).sample

            if config.run.recon_loss_weight != 0:
                recon_loss = F.mse_loss(pred_noise.float(), target.float(), reduction="mean")
            else:
                recon_loss = 0

            pred_latents = reversed_forward(noise_scheduler, pred_noise, timesteps, noisy_tensor)
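            # reversed_forward is expected to invert the forward process at timestep t, i.e. estimate
            # the clean latents x0_hat from (x_t, predicted noise), so the pixel-space losses below
            # can be computed on decoded frames rather than on noise residuals.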

            if config.run.pixel_space_supervise:
                pred_images = vae.decode(
                    rearrange(pred_latents, "b c f h w -> (b f) c h w") / vae.config.scaling_factor
                    + vae.config.shift_factor
                ).sample

            if config.run.perceptual_loss_weight != 0 and config.run.pixel_space_supervise:
                pred_images_perceptual = pred_images[:, :, pred_images.shape[2] // 2 :, :]
                gt_images_perceptual = gt_images[:, :, gt_images.shape[2] // 2 :, :]
                lpips_loss = lpips_loss_func(pred_images_perceptual.float(), gt_images_perceptual.float()).mean()
            else:
                lpips_loss = 0

            if config.run.trepa_loss_weight != 0 and config.run.pixel_space_supervise:
                trepa_pred_images = rearrange(pred_images, "(b f) c h w -> b c f h w", f=config.data.num_frames)
                trepa_gt_images = rearrange(gt_images, "(b f) c h w -> b c f h w", f=config.data.num_frames)
                trepa_loss = trepa_loss_func(trepa_pred_images, trepa_gt_images)
            else:
                trepa_loss = 0

            if config.model.add_audio_layer and config.run.use_syncnet:
                if config.run.pixel_space_supervise:
                    syncnet_input = rearrange(pred_images, "(b f) c h w -> b (f c) h w", f=config.data.num_frames)
                else:
                    syncnet_input = rearrange(pred_latents, "b c f h w -> b (f c) h w")

                if syncnet_config.data.lower_half:
                    height = syncnet_input.shape[2]
                    syncnet_input = syncnet_input[:, :, height // 2 :, :]
                ones_tensor = torch.ones((config.data.batch_size, 1)).float().to(device=device)
                vision_embeds, audio_embeds = syncnet(syncnet_input, mel)
                sync_loss = cosine_loss(vision_embeds.float(), audio_embeds.float(), ones_tensor).mean()
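                # cosine_loss with a target of ones presumably drives the cosine similarity between
                # the SyncNet vision and audio embeddings toward 1, i.e. the generated mouth region
                # should be judged in-sync with the driving mel spectrogram.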
                sync_loss_list.append(gather_loss(sync_loss, device))
            else:
                sync_loss = 0

            loss = (
                recon_loss * config.run.recon_loss_weight
                + sync_loss * config.run.sync_loss_weight
                + lpips_loss * config.run.perceptual_loss_weight
                + trepa_loss * config.run.trepa_loss_weight
            )

            train_step_list.append(global_step)
            if config.run.recon_loss_weight != 0:
                recon_loss_list.append(gather_loss(recon_loss, device))

            optimizer.zero_grad()

            # Backpropagate
            if config.run.mixed_precision_training:
                scaler.scale(loss).backward()
                """ >>> gradient clipping >>> """
                scaler.unscale_(optimizer)
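                # unscale_ restores the gradients to their true (unscaled) magnitudes so that
                # clip_grad_norm_ below clips against max_grad_norm rather than the scaled values.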
                torch.nn.utils.clip_grad_norm_(unet.parameters(), config.optimizer.max_grad_norm)
                """ <<< gradient clipping <<< """
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                """ >>> gradient clipping >>> """
                torch.nn.utils.clip_grad_norm_(unet.parameters(), config.optimizer.max_grad_norm)
                """ <<< gradient clipping <<< """
                optimizer.step()

            # Check the grad of attn blocks for debugging
            # print(unet.module.up_blocks[3].attentions[2].transformer_blocks[0].audio_cross_attn.attn.to_q.weight.grad)

            lr_scheduler.step()
            progress_bar.update(1)
            global_step += 1

            ### <<<< Training <<<< ###

            # Save checkpoint and conduct validation
            if is_main_process and (global_step % config.ckpt.save_ckpt_steps == 0):
                if config.run.recon_loss_weight != 0:
                    plot_loss_chart(
                        os.path.join(output_dir, f"loss_charts/recon_loss_chart-{global_step}.png"),
                        ("Reconstruction loss", train_step_list, recon_loss_list),
                    )
                if config.model.add_audio_layer:
                    if sync_loss_list != []:
                        plot_loss_chart(
                            os.path.join(output_dir, f"loss_charts/sync_loss_chart-{global_step}.png"),
                            ("Sync loss", train_step_list, sync_loss_list),
                        )
                model_save_path = os.path.join(output_dir, f"checkpoints/checkpoint-{global_step}.pt")
                state_dict = {
                    "global_step": global_step,
                    "state_dict": unet.module.state_dict(),  # to unwrap DDP
                }
                try:
                    torch.save(state_dict, model_save_path)
                    logger.info(f"Saved checkpoint to {model_save_path}")
                except Exception as e:
                    logger.error(f"Error saving model: {e}")

                # Validation
                logger.info("Running validation... ")

                validation_video_out_path = os.path.join(output_dir, f"val_videos/val_video_{global_step}.mp4")
                validation_video_mask_path = os.path.join(output_dir, f"val_videos/val_video_mask.mp4")

                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    pipeline(
                        config.data.val_video_path,
                        config.data.val_audio_path,
                        validation_video_out_path,
                        validation_video_mask_path,
                        num_frames=config.data.num_frames,
                        num_inference_steps=config.run.inference_steps,
                        guidance_scale=config.run.guidance_scale,
                        weight_dtype=torch.float16,
                        width=config.data.resolution,
                        height=config.data.resolution,
                        mask=config.data.mask,
                    )

                logger.info(f"Saved validation video output to {validation_video_out_path}")

                val_step_list.append(global_step)

                if config.model.add_audio_layer:
                    try:
                        _, conf = syncnet_eval(syncnet_eval_model, syncnet_detector, validation_video_out_path, "temp")
                    except Exception as e:
                        logger.info(e)
                        conf = 0
                    sync_conf_list.append(conf)
                    plot_loss_chart(
                        os.path.join(output_dir, f"loss_charts/sync_conf_chart-{global_step}.png"),
                        ("Sync confidence", val_step_list, sync_conf_list),
                    )

            logs = {"step_loss": loss.item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= config.run.max_train_steps:
                break

    progress_bar.close()
    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Config file path
    parser.add_argument("--unet_config_path", type=str, default="configs/unet.yaml")

    args = parser.parse_args()
    config = OmegaConf.load(args.unet_config_path)
    config.unet_config_path = args.unet_config_path

    main(config)