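"""Extract per-video conditioning features for ThinkSound-style audio generation.

For each clip listed in the input TSV, the script computes MetaCLIP video
features, Synchformer audio-visual sync features, MetaCLIP text features for
the caption, and T5 features for the chain-of-thought caption, then stores
them with np.savez. Audio/VAE latent extraction is wired up but commented out.

Example invocation (the script name is illustrative):
    python extract_features.py --root videos --tsv_path cot_coarse/cot.csv --save-dir results
"""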
import argparse
import logging
import os

import numpy as np
import torch
import torch.distributed as dist
import torchaudio  # used only by the commented-out audio sanity checks below
from einops import rearrange  # used only by the commented-out audio sanity checks below
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from tqdm import tqdm

from data_utils.v2a_utils.feature_utils_224 import FeaturesUtils
from data_utils.v2a_utils.vggsound_224_no_audio import VGGSound
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
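
# Distributed (DDP) process-group helpers. They are kept for multi-GPU
# extraction, but the single-GPU path in main() below does not call them.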
def setup(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()
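
# Collate wrapper that drops samples the dataset failed to load (returned as
# None), so one bad video does not crash the whole batch.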
def error_avoidance_collate(batch):
    batch = list(filter(lambda x: x is not None, batch))
    return default_collate(batch)


def main(args):
    print(f"Using root: {args.root}, tsv_path: {args.tsv_path}, save_dir: {args.save_dir}")
    dataset = VGGSound(
        root=args.root,
        tsv_path=args.tsv_path,
        sample_rate=args.sample_rate,
        duration_sec=args.duration_sec,
        audio_samples=args.audio_samples,
        start_row=args.start_row,
        end_row=args.end_row,
        save_dir=args.save_dir,
    )
    save_dir = args.save_dir
    os.makedirs(save_dir, exist_ok=True)
    dataloader = DataLoader(dataset, batch_size=2, num_workers=8, drop_last=False, collate_fn=error_avoidance_collate)
    print(f"Dataset length: {len(dataset)}")
    feature_extractor = FeaturesUtils(
        vae_ckpt=None,
        vae_config=args.vae_config,
        enable_conditions=True,
        synchformer_ckpt=args.synchformer_ckpt,
    ).eval().cuda()
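
    # Walk the dataset batch by batch, extract every conditioning feature under
    # no_grad, and move the results to CPU for saving.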
    for i, data in enumerate(tqdm(dataloader, desc="Processing", unit="batch")):
        ids = data['id']
        with torch.no_grad():
            # audio = data['audio'].cuda(rank, non_blocking=True)
            # Captions stay as lists of strings so they can be indexed per sample below.
            output = {
                'caption': data['caption'],
                'caption_cot': data['caption_cot'],
            }
            print(output)
            # latent = feature_extractor.module.encode_audio(audio)
            # output['latent'] = latent.detach().cpu()

            # Frame-level MetaCLIP features of the video.
            clip_video = data['clip_video'].cuda()
            clip_features = feature_extractor.encode_video_with_clip(clip_video)
            output['metaclip_features'] = clip_features.detach().cpu()

            # Synchformer features for audio-visual synchronization.
            sync_video = data['sync_video'].cuda()
            sync_features = feature_extractor.encode_video_with_sync(sync_video)
            output['sync_features'] = sync_features.detach().cpu()

            # MetaCLIP text features of the caption (global and per-token).
            caption = data['caption']
            metaclip_global_text_features, metaclip_text_features = feature_extractor.encode_text(caption)
            output['metaclip_global_text_features'] = metaclip_global_text_features.detach().cpu()
            output['metaclip_text_features'] = metaclip_text_features.detach().cpu()

            # T5 features of the chain-of-thought caption.
            caption_cot = data['caption_cot']
            t5_features = feature_extractor.encode_t5_text(caption_cot)
            output['t5_features'] = t5_features.detach().cpu()
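
            # Split the batch back into one dict per sample and save it.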
            for j in range(len(ids)):
                sample_output = {
                    'id': ids[j],
                    'caption': output['caption'][j],
                    'caption_cot': output['caption_cot'][j],
                    # 'latent': output['latent'][j],
                    'metaclip_features': output['metaclip_features'][j],
                    'sync_features': output['sync_features'][j],
                    'metaclip_global_text_features': output['metaclip_global_text_features'][j],
                    'metaclip_text_features': output['metaclip_text_features'][j],
                    't5_features': output['t5_features'][j],
                }
                # torch.save(sample_output, f'{save_dir}/{ids[j]}.pth')
                # NOTE: every sample is written to the same demo file, so each
                # iteration overwrites the previous one; use a per-id name (as
                # in the torch.save line above) for full-dataset extraction.
                np.savez(f'{save_dir}/demo.npz', **sample_output)

            ## test the sync between videos and audios
            # torchaudio.save(f'input_{i}.wav', data['audio'], sample_rate=44100)
            # recon_audio = feature_extractor.decode_audio(latent)
            # recon_audio = rearrange(recon_audio, "b d n -> d (b n)")
            # id = data['id']
            # torchaudio.save(f'recon_{i}.wav', recon_audio.cpu(), sample_rate=44100)
            # os.system(f'ffmpeg -y -i dataset/vggsound/video/train/{id}.mp4 -i recon_{i}.wav -t 9 -map 0:v -map 1:a -c:v copy -c:a aac -strict experimental -shortest out_{i}.mp4')
            # os.system(f'ffmpeg -y -i dataset/vggsound/video/train/{id}.mp4 -i input_{i}.wav -t 9 -map 0:v -map 1:a -c:v copy -c:a aac -strict experimental -shortest input_{i}.mp4')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Extract video conditioning features (MetaCLIP, Synchformer, text) for training')
    parser.add_argument('--root', type=str, default='videos', help='Root directory of the video dataset')
    parser.add_argument('--tsv_path', type=str, default='cot_coarse/cot.csv', help='Path to the TSV/CSV file listing samples and captions')
    parser.add_argument('--save-dir', type=str, default='results', help='Directory where extracted features are written')
    parser.add_argument('--sample_rate', type=int, default=44100, help='Sample rate of the audio')
    parser.add_argument('--duration_sec', type=float, default=9.0, help='Duration of the audio in seconds')
    parser.add_argument('--vae_ckpt', type=str, default='ckpts/vae.ckpt', help='Path to the VAE checkpoint (unused here: audio latents are not extracted)')
    parser.add_argument('--vae_config', type=str, default='ThinkSound/configs/model_configs/stable_audio_2_0_vae.json', help='Path to the VAE configuration file')
    parser.add_argument('--synchformer_ckpt', type=str, default='ckpts/synchformer_state_dict.pth', help='Path to the Synchformer checkpoint')
    parser.add_argument('--start-row', type=int, default=0, help='First row of the TSV to process')
    parser.add_argument('--end-row', type=int, default=None, help='Row of the TSV to stop at (default: process all rows)')
    args = parser.parse_args()
    # Derive the number of audio samples from the sample rate and clip duration.
    args.audio_samples = int(args.sample_rate * args.duration_sec)
    main(args=args)