from prefigure.prefigure import get_all_args, push_wandb_config
import json
import os

os.environ["GRADIO_TEMP_DIR"] = "./.gradio_tmp"

import re
import torch
import torchaudio
# import pytorch_lightning as pl
import lightning as L
from lightning.pytorch.callbacks import Timer, ModelCheckpoint, BasePredictionWriter
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.tuner import Tuner
from lightning.pytorch import seed_everything
import random
from datetime import datetime
# from think_sound.data.dataset import create_dataloader_from_config
from think_sound.data.datamodule import DataModule
from think_sound.models import create_model_from_config
from think_sound.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model
from think_sound.training import create_training_wrapper_from_config, create_demo_callback_from_config
from think_sound.training.utils import copy_state_dict
from think_sound.inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler
from data_utils.v2a_utils.feature_utils_224 import FeaturesUtils
from torch.utils.data import Dataset
from typing import Optional, Union
from torchvision.transforms import v2
from torio.io import StreamingMediaDecoder
from torchvision.utils import save_image
from transformers import AutoProcessor
import torch.nn.functional as F
import gradio as gr
import tempfile
import subprocess
from huggingface_hub import hf_hub_download
from moviepy.editor import VideoFileClip

_CLIP_SIZE = 224
_CLIP_FPS = 8.0

_SYNC_SIZE = 224
_SYNC_FPS = 25.0


def pad_to_square(video_tensor):
    """Zero-pad a (l, c, h, w) video tensor so every frame becomes square."""
    if len(video_tensor.shape) != 4:
        raise ValueError("Input tensor must have shape (l, c, h, w)")

    l, c, h, w = video_tensor.shape
    max_side = max(h, w)

    pad_h = max_side - h
    pad_w = max_side - w

    # (left, right, top, bottom) padding for F.pad over the last two dimensions
    padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)

    video_padded = F.pad(video_tensor, pad=padding, mode='constant', value=0)
    return video_padded


class VGGSound(Dataset):

    def __init__(
        self,
        sample_rate: int = 44_100,
        duration_sec: float = 9.0,
        audio_samples: Optional[int] = None,
        normalize_audio: bool = False,
    ):
        if audio_samples is None:
            self.audio_samples = int(sample_rate * duration_sec)
        else:
            self.audio_samples = audio_samples
            effective_duration = audio_samples / sample_rate
            # make sure the duration is close enough, within 15 ms
            assert abs(effective_duration - duration_sec) < 0.015, \
                f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'

        self.sample_rate = sample_rate
        self.duration_sec = duration_sec

        self.expected_audio_length = self.audio_samples
        self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
        self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)

        self.clip_transform = v2.Compose([
            v2.Lambda(pad_to_square),  # pad to a square first
            v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
        ])
        self.clip_processor = AutoProcessor.from_pretrained("facebook/metaclip-h14-fullcc2.5b")
        self.sync_transform = v2.Compose([
            v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
            v2.CenterCrop(_SYNC_SIZE),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

        self.resampler = {}

    def sample(self, video_path, label):
        video_id = video_path

        reader = StreamingMediaDecoder(video_path)
        # First stream: 8 fps frames for the MetaCLIP visual encoder
        reader.add_basic_video_stream(
            frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
            frame_rate=_CLIP_FPS,
            format='rgb24',
        )
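        # Second stream: 25 fps frames for the Synchformer sync encoder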
        reader.add_basic_video_stream(
            frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
            frame_rate=_SYNC_FPS,
            format='rgb24',
        )

        reader.fill_buffer()
        data_chunk = reader.pop_chunks()

        clip_chunk = data_chunk[0]
        sync_chunk = data_chunk[1]

        if sync_chunk is None:
            raise RuntimeError(f'Sync video returned None {video_id}')

        clip_chunk = clip_chunk[:self.clip_expected_length]
        if clip_chunk.shape[0] != self.clip_expected_length:
            current_length = clip_chunk.shape[0]
            padding_needed = self.clip_expected_length - current_length

            # Tolerate only a small shortfall (fewer than 4 missing frames)
            assert padding_needed < 4, \
                f'CLIP video too short: {padding_needed} frames missing for {video_id}'

            if padding_needed > 0:
                # Repeat the last frame to reach the expected length
                last_frame = clip_chunk[-1]
                padding = last_frame.repeat(padding_needed, 1, 1, 1)
                clip_chunk = torch.cat((clip_chunk, padding), dim=0)
            # raise RuntimeError(f'CLIP video wrong length {video_id}, '
            #                    f'expected {self.clip_expected_length}, '
            #                    f'got {clip_chunk.shape[0]}')

        # save_image(clip_chunk[0] / 255.0, 'ori.png')
        clip_chunk = pad_to_square(clip_chunk)
        clip_chunk = self.clip_processor(images=clip_chunk, return_tensors="pt")["pixel_values"]

        sync_chunk = sync_chunk[:self.sync_expected_length]
        if sync_chunk.shape[0] != self.sync_expected_length:
            # Pad with the last frame, tolerating only a small shortfall (fewer than 12 missing frames)
            current_length = sync_chunk.shape[0]
            padding_needed = self.sync_expected_length - current_length
            assert padding_needed < 12, \
                f'Sync video too short: {padding_needed} frames missing for {video_id}'
            last_frame = sync_chunk[-1]
            padding = last_frame.repeat(padding_needed, 1, 1, 1)
            sync_chunk = torch.cat((sync_chunk, padding), dim=0)
            # raise RuntimeError(f'Sync video wrong length {video_id}, '
            #                    f'expected {self.sync_expected_length}, '
            #                    f'got {sync_chunk.shape[0]}')

        sync_chunk = self.sync_transform(sync_chunk)
        # assert audio_chunk.shape[1] == self.expected_audio_length and clip_chunk.shape[0] == self.clip_expected_length \
        #     and sync_chunk.shape[0] == self.sync_expected_length, 'error processed data shape'

        data = {
            'id': video_id,
            'caption': label,
            # 'audio': audio_chunk,
            'clip_video': clip_chunk,
            'sync_video': sync_chunk,
        }

        return data


# Select devices: run the diffusion model on the default GPU and feature extraction
# on a second GPU if one is available
if torch.cuda.is_available():
    device = 'cuda'
    extra_device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'
else:
    device = 'cpu'
    extra_device = 'cpu'

vae_ckpt = hf_hub_download(repo_id="liuhuadai/ThinkSound", filename="vae.ckpt", repo_type="model")
synchformer_ckpt = hf_hub_download(repo_id="liuhuadai/ThinkSound", filename="synchformer_state_dict.pth", repo_type="model")

feature_extractor = FeaturesUtils(
    vae_ckpt=vae_ckpt,
    vae_config='think_sound/configs/model_configs/autoencoders/stable_audio_2_0_vae.json',
    enable_conditions=True,
    synchformer_ckpt=synchformer_ckpt,
).eval().to(extra_device)

args = get_all_args()

seed = 10086
seed_everything(seed, workers=True)

# Load the model config (hard-coded JSON path)
with open("think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3.json") as f:
    model_config = json.load(f)

model = create_model_from_config(model_config)

# Optionally speed up the model with torch.compile
if args.compile:
    model = torch.compile(model)

if args.pretrained_ckpt_path:
    copy_state_dict(model, load_ckpt_state_dict(args.pretrained_ckpt_path, prefix='diffusion.'))  # autoencoder. / diffusion.
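# Load the VAE weights into the model's pretransform; weight norm can be removed
# either before ("pre_load") or after ("post_load") loading, per the config flag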
if args.remove_pretransform_weight_norm == "pre_load":
    remove_weight_norm_from_model(model.pretransform)

load_vae_state = load_ckpt_state_dict(vae_ckpt, prefix='autoencoder.')
# new_state_dict = {k.replace("autoencoder.", ""): v for k, v in load_vae_state.items() if k.startswith("autoencoder.")}
model.pretransform.load_state_dict(load_vae_state)

# Remove weight_norm from the pretransform if specified
if args.remove_pretransform_weight_norm == "post_load":
    remove_weight_norm_from_model(model.pretransform)

ckpt_path = hf_hub_download(repo_id="liuhuadai/ThinkSound", filename="thinksound.ckpt", repo_type="model")
training_wrapper = create_training_wrapper_from_config(model_config, model)
# Choose map_location according to the available device when loading the checkpoint
if device == 'cuda':
    training_wrapper.load_state_dict(torch.load(ckpt_path)['state_dict'])
else:
    training_wrapper.load_state_dict(torch.load(ckpt_path, map_location=torch.device('cpu'))['state_dict'])


def get_video_duration(video_path):
    video = VideoFileClip(video_path)
    return video.duration


def get_audio(video_path, caption):
    # Allow an empty caption
    if caption is None:
        caption = ''
    timer = Timer(duration="00:15:00:00")

    # Get the video duration so preprocessing and latent lengths match the clip
    duration_sec = get_video_duration(video_path)
    print(duration_sec)
    preprocessor = VGGSound(duration_sec=duration_sec)
    data = preprocessor.sample(video_path, caption)

    # Encode text, CLIP-video and sync-video conditioning features
    preprocessed_data = {}
    metaclip_global_text_features, metaclip_text_features = feature_extractor.encode_text(data['caption'])
    preprocessed_data['metaclip_global_text_features'] = metaclip_global_text_features.detach().cpu().squeeze(0)
    preprocessed_data['metaclip_text_features'] = metaclip_text_features.detach().cpu().squeeze(0)

    t5_features = feature_extractor.encode_t5_text(data['caption'])
    preprocessed_data['t5_features'] = t5_features.detach().cpu().squeeze(0)

    clip_features = feature_extractor.encode_video_with_clip(data['clip_video'].unsqueeze(0).to(extra_device))
    preprocessed_data['metaclip_features'] = clip_features.detach().cpu().squeeze(0)

    sync_features = feature_extractor.encode_video_with_sync(data['sync_video'].unsqueeze(0).to(extra_device))
    preprocessed_data['sync_features'] = sync_features.detach().cpu().squeeze(0)

    preprocessed_data['video_exist'] = torch.tensor(True)

    print("clip_shape", preprocessed_data['metaclip_features'].shape)
    print("sync_shape", preprocessed_data['sync_features'].shape)

    sync_seq_len = preprocessed_data['sync_features'].shape[0]
    clip_seq_len = preprocessed_data['metaclip_features'].shape[0]
    # Scale the 194-latents-per-9-seconds ratio to the clip duration
    latent_seq_len = int(194 / 9 * duration_sec)
    training_wrapper.diffusion.model.model.update_seq_lengths(latent_seq_len, clip_seq_len, sync_seq_len)

    metadata = [preprocessed_data]

    batch_size = 1
    length = latent_seq_len
    with torch.amp.autocast(device):
        conditioning = training_wrapper.diffusion.conditioner(metadata, training_wrapper.device)

    # Swap in the learned "empty" embeddings when no video is present
    video_exist = torch.stack([item['video_exist'] for item in metadata], dim=0)
    conditioning['metaclip_features'][~video_exist] = training_wrapper.diffusion.model.model.empty_clip_feat
    conditioning['sync_features'][~video_exist] = training_wrapper.diffusion.model.model.empty_sync_feat

    cond_inputs = training_wrapper.diffusion.get_conditioning_inputs(conditioning)

    noise = torch.randn([batch_size, training_wrapper.diffusion.io_channels, length]).to(training_wrapper.device)
    with torch.amp.autocast(device):
        model = training_wrapper.diffusion.model
        if training_wrapper.diffusion_objective == "v":
            fakes = sample(model, noise, 24, 0, **cond_inputs, cfg_scale=5, batch_cfg=True)
        elif training_wrapper.diffusion_objective == "rectified_flow":
            import time
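            # Time the 24-step Euler sampling pass used for the rectified-flow objective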
            start_time = time.time()
            fakes = sample_discrete_euler(model, noise, 24, **cond_inputs, cfg_scale=5, batch_cfg=True)
            end_time = time.time()
            execution_time = end_time - start_time
            print(f"Execution time: {execution_time:.2f} s")

    if training_wrapper.diffusion.pretransform is not None:
        fakes = training_wrapper.diffusion.pretransform.decode(fakes)

    # Peak-normalize and convert to 16-bit PCM
    audios = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()

    # Save the generated audio to a temporary file
    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_audio:
        torchaudio.save(tmp_audio.name, audios[0], 44100)
        audio_path = tmp_audio.name
    return audio_path


# Warm-up / sanity-check run on one of the bundled examples
get_audio("./examples/3_mute.mp4", "Axe striking")


# Synthesize a new video: mux the generated audio with the original video via ffmpeg
def synthesize_video_with_audio(video_file, caption):
    # Allow an empty caption
    if caption is None:
        caption = ''
    audio_path = get_audio(video_file, caption)
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_video:
        output_video_path = tmp_video.name
    # ffmpeg command: copy the original video stream and replace the audio track with the generated audio
    cmd = [
        'ffmpeg', '-y', '-i', video_file, '-i', audio_path,
        '-c:v', 'copy', '-map', '0:v:0', '-map', '1:a:0',
        '-shortest', output_video_path
    ]
    subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    return output_video_path


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# ThinkSound\nUpload a video and an optional caption, and get the video back with generated audio!")
    with gr.Row():
        video_input = gr.Video(label="Upload video")
        caption_input = gr.Textbox(label="Caption (optional)", placeholder="Can be left empty", lines=1)
    output_video = gr.Video(label="Output video")
    btn = gr.Button("Start synthesis")
    btn.click(fn=synthesize_video_with_audio, inputs=[video_input, caption_input], outputs=output_video)
    gr.Examples(
        examples=[
            ["./examples/1_mute.mp4", "Playing Trumpet"],
            ["./examples/2_mute.mp4", "Axe striking"],
            ["./examples/3_mute.mp4", "Gentle Sucking Sounds From the Pacifier"],
            ["./examples/4_mute.mp4", "train passing by"],
            ["./examples/5_mute.mp4", "Lighting Firecrackers"]
        ],
        inputs=[video_input, caption_input],
    )

demo.launch(share=True)
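# Minimal programmatic usage sketch (bypasses the Gradio UI; assumes the example
# clip exists locally and that ffmpeg is available on PATH):
#
#   out_path = synthesize_video_with_audio("./examples/1_mute.mp4", "Playing Trumpet")
#   print(out_path)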