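"""Extract per-clip conditioning features for VGGSound videos and save them
as one .npz file per clip.

Features produced per sample:
  - metaclip_features: per-frame MetaCLIP visual embeddings
  - sync_features: Synchformer audio-visual synchronization embeddings
  - metaclip_global_text_features / metaclip_text_features: caption embeddings
  - t5_features: T5 embeddings of the chain-of-thought caption

Usage (defaults match the argparse definitions at the bottom; script name per
your checkout):
    python <this_script>.py --root videos --tsv_path cot_coarse/cot.csv --save-dir results
"""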
import argparse
import logging
import os

import numpy as np
import torch
import torch.distributed as dist
import torchaudio  # only needed by the commented-out sync test at the bottom
from einops import rearrange  # likewise only used by the commented-out sync test
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from tqdm import tqdm

from data_utils.v2a_utils.feature_utils_224 import FeaturesUtils
from data_utils.v2a_utils.vggsound_224_no_audio import VGGSound

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Distributed (DDP) helpers. They are defined for optional multi-GPU runs but
# are not invoked in the single-process flow below.
def setup(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def cleanup():
    dist.destroy_process_group()

def error_avoidance_collate(batch):
    """Drop samples the dataset failed to load (returned as None) before collating."""
    batch = list(filter(lambda x: x is not None, batch))
    return default_collate(batch)

def main(args):
    logging.info(f"Using root: {args.root}, tsv_path: {args.tsv_path}, save_dir: {args.save_dir}")
    dataset = VGGSound(
        root=args.root,
        tsv_path=args.tsv_path,
        sample_rate=args.sample_rate,
        duration_sec=args.duration_sec,
        audio_samples=args.audio_samples,
        start_row=args.start_row,
        end_row=args.end_row,
        save_dir=args.save_dir
    )
    save_dir = args.save_dir
    os.makedirs(save_dir, exist_ok=True)
    
    # batch_size and num_workers are hard-coded; adjust them to your GPU/IO budget.
    dataloader = DataLoader(dataset, batch_size=2, num_workers=8, drop_last=False, collate_fn=error_avoidance_collate)

    logging.info(f"Dataset length: {len(dataset)}")
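    # Note: vae_ckpt is passed as None because audio latent encoding is
    # commented out below; pass args.vae_ckpt here to re-enable it.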
    feature_extractor = FeaturesUtils(
        vae_ckpt=None,
        vae_config=args.vae_config,
        enable_conditions=True,
        synchformer_ckpt=args.synchformer_ckpt
    ).eval().cuda()

    for i, data in enumerate(tqdm(dataloader, desc="Processing", unit="batch")):
        ids = data['id']
        with torch.no_grad():
            # audio = data['audio'].cuda(rank, non_blocking=True)
            # Keep the collated caption lists as lists: wrapping them in str()
            # would stringify the whole batch and break the per-sample
            # indexing in the save loop below.
            output = {
                'caption': data['caption'],
                'caption_cot': data['caption_cot']
            }
            logging.debug(output)

            # latent = feature_extractor.module.encode_audio(audio)
            # output['latent'] = latent.detach().cpu()

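            # Visual conditioning: MetaCLIP frame features (semantics) and
            # Synchformer features (audio-visual synchronization cues).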
            clip_video = data['clip_video'].cuda()
            clip_features = feature_extractor.encode_video_with_clip(clip_video)
            output['metaclip_features'] = clip_features.detach().cpu()

            sync_video = data['sync_video'].cuda()
            sync_features = feature_extractor.encode_video_with_sync(sync_video)
            output['sync_features'] = sync_features.detach().cpu()

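            # Text conditioning: MetaCLIP encodes the caption; T5 encodes the
            # chain-of-thought caption.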
            caption = data['caption']
            metaclip_global_text_features, metaclip_text_features = feature_extractor.encode_text(caption)
            output['metaclip_global_text_features'] = metaclip_global_text_features.detach().cpu()
            output['metaclip_text_features'] = metaclip_text_features.detach().cpu()

            caption_cot = data['caption_cot']
            t5_features = feature_extractor.encode_t5_text(caption_cot)
            output['t5_features'] = t5_features.detach().cpu()

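            # Split the batched outputs back into one record per clip.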
            for j in range(len(ids)):
                sample_output = {
                    'id': ids[j],
                    'caption': output['caption'][j],
                    'caption_cot': output['caption_cot'][j],
                    # 'latent': output['latent'][j],
                    'metaclip_features': output['metaclip_features'][j],
                    'sync_features': output['sync_features'][j],
                    'metaclip_global_text_features': output['metaclip_global_text_features'][j],
                    'metaclip_text_features': output['metaclip_text_features'][j],
                    't5_features': output['t5_features'][j],
                }
                # torch.save(sample_output, f'{save_dir}/{ids[j]}.pth')
                # Save one .npz per clip id; a fixed filename would be
                # overwritten on every iteration.
                np.savez(f'{save_dir}/{ids[j]}.npz', **sample_output)

        ## Test the sync between video and audio: decode the audio latents and
        ## mux them back onto the source video with ffmpeg.
        # torchaudio.save(f'input_{i}.wav',data['audio'],sample_rate=44100)
        # recon_audio = feature_extractor.decode_audio(latent)
        # recon_audio = rearrange(recon_audio, "b d n -> d (b n)")
        # id = data['id']
        # torchaudio.save(f'recon_{i}.wav',recon_audio.cpu(),sample_rate=44100)
        # os.system(f'ffmpeg -y -i dataset/vggsound/video/train/{id}.mp4 -i recon_{i}.wav -t 9 -map 0:v -map 1:a -c:v copy -c:a aac -strict experimental -shortest out_{i}.mp4')
        # os.system(f'ffmpeg -y -i dataset/vggsound/video/train/{id}.mp4 -i input_{i}.wav -t 9 -map 0:v -map 1:a -c:v copy -c:a aac -strict experimental -shortest input_{i}.mp4')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Extract video/text conditioning features for training')
    parser.add_argument('--root', type=str, default='videos', help='Root directory of the video dataset')
    parser.add_argument('--tsv_path', type=str, default='cot_coarse/cot.csv', help='Path to the caption TSV/CSV file')
    parser.add_argument('--save-dir', type=str, default='results', help='Directory to save the extracted .npz features')
    parser.add_argument('--sample_rate', type=int, default=44100, help='Sample rate of the audio')
    parser.add_argument('--duration_sec', type=float, default=9.0, help='Duration of the audio in seconds')
    parser.add_argument('--vae_ckpt', type=str, default='ckpts/vae.ckpt', help='Path to the VAE checkpoint (unused unless audio latent encoding is re-enabled)')
    parser.add_argument('--vae_config', type=str, default='ThinkSound/configs/model_configs/stable_audio_2_0_vae.json', help='Path to the VAE configuration file')
    parser.add_argument('--synchformer_ckpt', type=str, default='ckpts/synchformer_state_dict.pth', help='Path to the Synchformer checkpoint')
    parser.add_argument('--start-row', type=int, default=0, help='First TSV row to process')
    parser.add_argument('--end-row', type=int, default=None, help='Last TSV row to process (default: all rows)')

    args = parser.parse_args()
    args.audio_samples = int(args.sample_rate * args.duration_sec)

    main(args=args)