from prefigure.prefigure import get_all_args, push_wandb_config
import spaces
import json
import os
os.environ["GRADIO_TEMP_DIR"] = "./.gradio_tmp"
import re
import torch
import torchaudio
# import pytorch_lightning as pl
import lightning as L
from lightning.pytorch.callbacks import Timer, ModelCheckpoint, BasePredictionWriter
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.tuner import Tuner
from lightning.pytorch import seed_everything
import random
from datetime import datetime
# from think_sound.data.dataset import create_dataloader_from_config
from think_sound.data.datamodule import DataModule
from think_sound.models import create_model_from_config
from think_sound.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model
from think_sound.training import create_training_wrapper_from_config, create_demo_callback_from_config
from think_sound.training.utils import copy_state_dict
from think_sound.inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler
from data_utils.v2a_utils.feature_utils_224 import FeaturesUtils
from torch.utils.data import Dataset
from typing import Optional, Union
from torchvision.transforms import v2
from torio.io import StreamingMediaDecoder
from torchvision.utils import save_image
from transformers import AutoProcessor
import torch.nn.functional as F
import gradio as gr
import tempfile
import subprocess
from huggingface_hub import hf_hub_download
from moviepy.editor import VideoFileClip
os.system("conda install -c conda-forge 'ffmpeg<7'")  # install an ffmpeg (<7) build used for video decoding and muxing

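# Frame-extraction settings: the CLIP branch uses 224x224 frames at 8 fps,
# the Synchformer branch 224x224 frames at 25 fps.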
_CLIP_SIZE = 224
_CLIP_FPS = 8.0

_SYNC_SIZE = 224
_SYNC_FPS = 25.0

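# Zero-pad a (frames, channels, height, width) video tensor to a square
# spatial size, keeping the original content centered.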
def pad_to_square(video_tensor):
    if len(video_tensor.shape) != 4:
        raise ValueError("Input tensor must have shape (l, c, h, w)")

    l, c, h, w = video_tensor.shape
    max_side = max(h, w)

    pad_h = max_side - h
    pad_w = max_side - w
    
    padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)

    video_padded = F.pad(video_tensor, pad=padding, mode='constant', value=0)

    return video_padded


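# Preprocessing helper (named after the VGGSound dataset loader it is modeled on):
# decodes a video into a CLIP-rate frame stream and a Synchformer-rate frame stream.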
class VGGSound(Dataset):

    def __init__(
        self,
        sample_rate: int = 44_100,
        duration_sec: float = 9.0,
        audio_samples: int = None,
        normalize_audio: bool = False,
    ):
        if audio_samples is None:
            self.audio_samples = int(sample_rate * duration_sec)
        else:
            self.audio_samples = audio_samples
            effective_duration = audio_samples / sample_rate
            # make sure the duration is close enough, within 15ms
            assert abs(effective_duration - duration_sec) < 0.015, \
                f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'

        self.sample_rate = sample_rate
        self.duration_sec = duration_sec

        self.expected_audio_length = self.audio_samples
        self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
        self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)

        self.clip_transform = v2.Compose([
            v2.Lambda(pad_to_square),          # pad to a square first
            v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
        ])
        self.clip_processor = AutoProcessor.from_pretrained("facebook/metaclip-h14-fullcc2.5b")
        self.sync_transform = v2.Compose([
            v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
            v2.CenterCrop(_SYNC_SIZE),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

        self.resampler = {}

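    # Decode `video_path` into the two frame streams, pad short streams by
    # repeating the last frame, apply the CLIP processor / sync transform,
    # and return them together with the caption.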
    def sample(self, video_path, label):
        video_id = video_path

        reader = StreamingMediaDecoder(video_path)
        reader.add_basic_video_stream(
            frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
            frame_rate=_CLIP_FPS,
            format='rgb24',
        )
        reader.add_basic_video_stream(
            frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
            frame_rate=_SYNC_FPS,
            format='rgb24',
        ) 

        reader.fill_buffer()
        data_chunk = reader.pop_chunks()

        clip_chunk = data_chunk[0]
        sync_chunk = data_chunk[1]

        if sync_chunk is None:
            raise RuntimeError(f'Sync video returned None {video_id}')

        clip_chunk = clip_chunk[:self.clip_expected_length]
        # import ipdb
        # ipdb.set_trace()
        if clip_chunk.shape[0] != self.clip_expected_length:
            current_length = clip_chunk.shape[0]
            padding_needed = self.clip_expected_length - current_length
            
            # Allow only a few frames of padding (the check below permits at most 3)
            assert padding_needed < 4, f'Expected at most 3 frames of padding, but {padding_needed} needed'

            # If assertion passes, proceed with padding
            if padding_needed > 0:
                last_frame = clip_chunk[-1]
                # Repeat the last frame to reach the expected length
                padding = last_frame.repeat(padding_needed, 1, 1, 1)
                clip_chunk = torch.cat((clip_chunk, padding), dim=0)
            # raise RuntimeError(f'CLIP video wrong length {video_id}, '
            #                    f'expected {self.clip_expected_length}, '
            #                    f'got {clip_chunk.shape[0]}')
        
        # save_image(clip_chunk[0] / 255.0,'ori.png')
        clip_chunk = pad_to_square(clip_chunk)

        clip_chunk = self.clip_processor(images=clip_chunk, return_tensors="pt")["pixel_values"]

        sync_chunk = sync_chunk[:self.sync_expected_length]
        if sync_chunk.shape[0] != self.sync_expected_length:
            # pad with the last frame, but only by a bounded number of frames
            current_length = sync_chunk.shape[0]
            padding_needed = self.sync_expected_length - current_length
            assert padding_needed < 12, f'sync stream can be padded by at most 11 frames, but {padding_needed} needed'
            last_frame = sync_chunk[-1]
            # repeat the last frame to reach the expected length
            padding = last_frame.repeat(padding_needed, 1, 1, 1)
            sync_chunk = torch.cat((sync_chunk, padding), dim=0)
            # raise RuntimeError(f'Sync video wrong length {video_id}, '
            #                    f'expected {self.sync_expected_length}, '
            #                    f'got {sync_chunk.shape[0]}')
        
        sync_chunk = self.sync_transform(sync_chunk)
        # assert audio_chunk.shape[1] == self.expected_audio_length and clip_chunk.shape[0] == self.clip_expected_length \
        # and sync_chunk.shape[0] == self.sync_expected_length, 'error processed data shape'
        data = {
            'id': video_id,
            'caption': label,
            # 'audio': audio_chunk,
            'clip_video': clip_chunk,
            'sync_video': sync_chunk,
        }

        return data

# Select compute devices (use a second GPU for feature extraction if available)
if torch.cuda.is_available():
    device = 'cuda'
    extra_device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'
else:
    device = 'cpu'
    extra_device = 'cpu'

print(f"load in device {device}")

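# Download the pretrained VAE and Synchformer checkpoints from the ThinkSound
# model repo and build the feature extractor on the secondary device.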
vae_ckpt = hf_hub_download(repo_id="liuhuadai/ThinkSound", filename="vae.ckpt",repo_type="model")
synchformer_ckpt = hf_hub_download(repo_id="liuhuadai/ThinkSound", filename="synchformer_state_dict.pth",repo_type="model")
feature_extractor = FeaturesUtils(
    vae_ckpt=vae_ckpt,
    vae_config='think_sound/configs/model_configs/autoencoders/stable_audio_2_0_vae.json',
    enable_conditions=True,
    synchformer_ckpt=synchformer_ckpt
).eval().to(extra_device)



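# Build the diffusion model from its JSON config, optionally compile it, and
# load the pretrained diffusion and VAE weights.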
args = get_all_args()

seed = 10086

seed_everything(seed, workers=True)


# Load the model config JSON
with open("think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3.json") as f:
    model_config = json.load(f)

model = create_model_from_config(model_config)

# Optionally speed up with torch.compile
if args.compile:
    model = torch.compile(model)
    
if args.pretrained_ckpt_path:
    copy_state_dict(model, load_ckpt_state_dict(args.pretrained_ckpt_path,prefix='diffusion.')) # autoencoder.  diffusion.

if args.remove_pretransform_weight_norm == "pre_load":
    remove_weight_norm_from_model(model.pretransform)


load_vae_state = load_ckpt_state_dict(vae_ckpt, prefix='autoencoder.') 
# new_state_dict = {k.replace("autoencoder.", ""): v for k, v in load_vae_state.items() if k.startswith("autoencoder.")}
model.pretransform.load_state_dict(load_vae_state)

# Remove weight_norm from the pretransform if specified
if args.remove_pretransform_weight_norm == "post_load":
    remove_weight_norm_from_model(model.pretransform)
ckpt_path = hf_hub_download(repo_id="liuhuadai/ThinkSound", filename="thinksound.ckpt",repo_type="model")
training_wrapper = create_training_wrapper_from_config(model_config, model)
# Choose map_location based on the available device when loading the checkpoint
training_wrapper.load_state_dict(torch.load(ckpt_path, map_location=device)['state_dict'])

training_wrapper.to(device)

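# Read the clip duration in seconds with moviepy; it determines the preprocessing
# window and the latent sequence length.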
def get_video_duration(video_path):
    video = VideoFileClip(video_path)
    return video.duration

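# Generate an audio track for `video_path`: extract text and video features,
# condition the diffusion model, sample a latent, decode it with the VAE, and
# write a temporary 44.1 kHz WAV file.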
@spaces.GPU(duration=60)
@torch.inference_mode()
@torch.no_grad()
def get_audio(video_path, caption):
    # Allow an empty caption
    if caption is None:
        caption = ''
    timer = Timer(duration="00:15:00:00")
    # Get the video duration in seconds
    duration_sec = get_video_duration(video_path)
    print(f"video duration: {duration_sec:.2f} s")
    preprocesser = VGGSound(duration_sec=duration_sec)
    data = preprocesser.sample(video_path, caption)



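    # Encode the caption (MetaCLIP + T5) and the two frame streams (MetaCLIP
    # video encoder + Synchformer); features stay on CPU until conditioning.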
    preprocessed_data = {}
    metaclip_global_text_features, metaclip_text_features = feature_extractor.encode_text(data['caption'])
    preprocessed_data['metaclip_global_text_features'] = metaclip_global_text_features.detach().cpu().squeeze(0)
    preprocessed_data['metaclip_text_features'] = metaclip_text_features.detach().cpu().squeeze(0)

    t5_features = feature_extractor.encode_t5_text(data['caption'])
    preprocessed_data['t5_features'] = t5_features.detach().cpu().squeeze(0)

    clip_features = feature_extractor.encode_video_with_clip(data['clip_video'].unsqueeze(0).to(extra_device))
    preprocessed_data['metaclip_features'] = clip_features.detach().cpu().squeeze(0)

    sync_features = feature_extractor.encode_video_with_sync(data['sync_video'].unsqueeze(0).to(extra_device))
    preprocessed_data['sync_features'] = sync_features.detach().cpu().squeeze(0)
    preprocessed_data['video_exist'] = torch.tensor(True)
    print("clip_shape", preprocessed_data['metaclip_features'].shape)
    print("sync_shape", preprocessed_data['sync_features'].shape)
    sync_seq_len = preprocessed_data['sync_features'].shape[0]
    clip_seq_len = preprocessed_data['metaclip_features'].shape[0]
    latent_seq_len = int(194 / 9 * duration_sec)
    training_wrapper.diffusion.model.model.update_seq_lengths(latent_seq_len, clip_seq_len, sync_seq_len)

    metadata = [preprocessed_data]

    batch_size = 1
    length = latent_seq_len
    with torch.amp.autocast(device):
        conditioning = training_wrapper.diffusion.conditioner(metadata, training_wrapper.device)
    
    video_exist = torch.stack([item['video_exist'] for item in metadata],dim=0)
    conditioning['metaclip_features'][~video_exist] = training_wrapper.diffusion.model.model.empty_clip_feat
    conditioning['sync_features'][~video_exist] = training_wrapper.diffusion.model.model.empty_sync_feat

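    # Sample the latent (24 steps, CFG scale 5) and decode it to a waveform
    # with the VAE pretransform.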
    cond_inputs = training_wrapper.diffusion.get_conditioning_inputs(conditioning)
    noise = torch.randn([batch_size, training_wrapper.diffusion.io_channels, length]).to(training_wrapper.device)
    with torch.amp.autocast(device):
        model = training_wrapper.diffusion.model
        if training_wrapper.diffusion_objective == "v":
            fakes = sample(model, noise, 24, 0, **cond_inputs, cfg_scale=5, batch_cfg=True)
        elif training_wrapper.diffusion_objective == "rectified_flow":
            import time
            start_time = time.time()
            fakes = sample_discrete_euler(model, noise, 24, **cond_inputs, cfg_scale=5, batch_cfg=True)
            end_time = time.time()
            execution_time = end_time - start_time
            print(f"执行时间: {execution_time:.2f} 秒")
        if training_wrapper.diffusion.pretransform is not None:
            fakes = training_wrapper.diffusion.pretransform.decode(fakes)

    audios = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
    # Save the generated audio to a temporary WAV file
    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_audio:
        torchaudio.save(tmp_audio.name, audios[0], 44100)
        audio_path = tmp_audio.name
    return audio_path

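# Generate audio for the uploaded video and mux it onto the original video
# stream with ffmpeg, returning a new temporary MP4.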
def synthesize_video_with_audio(video_file, caption):
    # Allow an empty caption
    if caption is None:
        caption = ''
    audio_path = get_audio(video_file, caption)
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_video:
        output_video_path = tmp_video.name
    # ffmpeg command: replace the original audio track with the generated audio
    cmd = [
        'ffmpeg', '-y', '-i', video_file, '-i', audio_path,
        '-c:v', 'copy', '-map', '0:v:0', '-map', '1:a:0',
        '-shortest', output_video_path
    ]
    subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    return output_video_path

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# ThinkSound\nUpload a video and an optional caption to get the video back with generated audio!")
    with gr.Row():
        video_input = gr.Video(label="upload video")
        caption_input = gr.Textbox(label="caption(optional)", placeholder="can be empty", lines=1)
    output_video = gr.Video(label="output video")
    btn = gr.Button("start synthesize")
    btn.click(fn=synthesize_video_with_audio, inputs=[video_input, caption_input], outputs=output_video)

    gr.Examples(
        examples=[
            ["./examples/1_mute.mp4", "Playing Trumpet", "./examples/1.mp4"],
            ["./examples/2_mute.mp4", "Axe striking", "./examples/2.mp4"],
            ["./examples/3_mute.mp4", "Gentle Sucking Sounds From the Pacifier", "./examples/3.mp4"],
            ["./examples/4_mute.mp4", "train passing by", "./examples/4.mp4"],
            ["./examples/5_mute.mp4", "Lighting Firecrackers", "./examples/5.mp4"]
        ],
        inputs=[video_input, caption_input, output_video],
    )
    
demo.launch(share=True)