# ThinkSound / extract_latents.py
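"""Extract per-sample conditioning features for ThinkSound training.

For every video listed in the annotation file, this script computes MetaCLIP
video features, Synchformer sync features, MetaCLIP text features for the
caption, and T5 features for the chain-of-thought (CoT) caption, then saves
them as one .npz file per sample in --save-dir.
"""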
import argparse
import os
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from tqdm import tqdm
import logging
from data_utils.v2a_utils.vggsound_224_no_audio import VGGSound
from data_utils.v2a_utils.feature_utils_224 import FeaturesUtils
# torchaudio and rearrange are only needed by the commented-out sync check below.
import torchaudio
from einops import rearrange
from torch.utils.data.dataloader import default_collate
import numpy as np
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
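
# Distributed (DDP) helpers; not invoked in the single-process flow below.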
def setup(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()

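# Collate function that drops samples the dataset failed to load (returned as
# None), so one bad video does not abort the whole batch.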
def error_avoidance_collate(batch):
batch = list(filter(lambda x: x is not None, batch))
return default_collate(batch)
def main(args):
print(f"Using root: {args.root}, tsv_path: {args.tsv_path}, save_dir: {args.save_dir}")
dataset = VGGSound(
root=args.root,
tsv_path=args.tsv_path,
sample_rate=args.sample_rate,
duration_sec=args.duration_sec,
audio_samples=args.audio_samples,
start_row=args.start_row,
end_row=args.end_row,
save_dir=args.save_dir
)
save_dir = args.save_dir
os.makedirs(save_dir, exist_ok=True)
    dataloader = DataLoader(dataset, batch_size=2, num_workers=8, drop_last=False, collate_fn=error_avoidance_collate)
    logging.info(f"Dataset length: {len(dataset)}")
    feature_extractor = FeaturesUtils(
        vae_ckpt=None,
        vae_config=args.vae_config,
        enable_conditions=True,
        synchformer_ckpt=args.synchformer_ckpt
    ).eval().cuda()
    for i, data in enumerate(tqdm(dataloader, desc="Processing", unit="batch")):
        ids = data['id']
        with torch.no_grad():
            # audio = data['audio'].cuda(rank, non_blocking=True)
            # Keep the collated captions as per-sample lists; wrapping them in
            # str() would merge the whole batch into one string and break the
            # [j] indexing used when saving below.
            output = {
                'caption': data['caption'],
                'caption_cot': data['caption_cot']
            }
# latent = feature_extractor.module.encode_audio(audio)
# output['latent'] = latent.detach().cpu()
            # Visual features: MetaCLIP frames (semantics) and Synchformer frames (A/V sync).
            clip_video = data['clip_video'].cuda()
            clip_features = feature_extractor.encode_video_with_clip(clip_video)
            output['metaclip_features'] = clip_features.detach().cpu()
            sync_video = data['sync_video'].cuda()
            sync_features = feature_extractor.encode_video_with_sync(sync_video)
            output['sync_features'] = sync_features.detach().cpu()
            # Text features: MetaCLIP embeddings for the caption, T5 for the CoT caption.
            caption = data['caption']
            metaclip_global_text_features, metaclip_text_features = feature_extractor.encode_text(caption)
            output['metaclip_global_text_features'] = metaclip_global_text_features.detach().cpu()
            output['metaclip_text_features'] = metaclip_text_features.detach().cpu()
            caption_cot = data['caption_cot']
            t5_features = feature_extractor.encode_t5_text(caption_cot)
            output['t5_features'] = t5_features.detach().cpu()
            # Split the batch and save one .npz archive per sample, keyed by id.
            for j in range(len(ids)):
                sample_output = {
                    'id': ids[j],
                    'caption': output['caption'][j],
                    'caption_cot': output['caption_cot'][j],
                    # 'latent': output['latent'][j],
                    'metaclip_features': output['metaclip_features'][j],
                    'sync_features': output['sync_features'][j],
                    'metaclip_global_text_features': output['metaclip_global_text_features'][j],
                    'metaclip_text_features': output['metaclip_text_features'][j],
                    't5_features': output['t5_features'][j],
                }
                # torch.save(sample_output, f'{save_dir}/{ids[j]}.pth')
                # Save under the sample id; a fixed filename here would
                # overwrite every previous sample with the last one processed.
                np.savez(f'{save_dir}/{ids[j]}.npz', **sample_output)
        ## Optional sanity check: decode audio from the latents and mux it back
        ## into the source video to verify audio-visual sync.
# torchaudio.save(f'input_{i}.wav',data['audio'],sample_rate=44100)
# recon_audio = feature_extractor.decode_audio(latent)
# recon_audio = rearrange(recon_audio, "b d n -> d (b n)")
# id = data['id']
# torchaudio.save(f'recon_{i}.wav',recon_audio.cpu(),sample_rate=44100)
# os.system(f'ffmpeg -y -i dataset/vggsound/video/train/{id}.mp4 -i recon_{i}.wav -t 9 -map 0:v -map 1:a -c:v copy -c:a aac -strict experimental -shortest out_{i}.mp4')
# os.system(f'ffmpeg -y -i dataset/vggsound/video/train/{id}.mp4 -i input_{i}.wav -t 9 -map 0:v -map 1:a -c:v copy -c:a aac -strict experimental -shortest input_{i}.mp4')
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Extract Video Training Latents')
parser.add_argument('--root', type=str, default='videos', help='Root directory of the video dataset')
    parser.add_argument('--tsv_path', type=str, default='cot_coarse/cot.csv', help='Path to the TSV/CSV annotation file')
parser.add_argument('--save-dir', type=str, default='results', help='Save Directory')
parser.add_argument('--sample_rate', type=int, default=44100, help='Sample rate of the audio')
parser.add_argument('--duration_sec', type=float, default=9.0, help='Duration of the audio in seconds')
    parser.add_argument('--vae_ckpt', type=str, default='ckpts/vae.ckpt', help='Path to the VAE checkpoint (not loaded here; audio latent extraction is commented out)')
parser.add_argument('--vae_config', type=str, default='ThinkSound/configs/model_configs/stable_audio_2_0_vae.json', help='Path to the VAE configuration file')
parser.add_argument('--synchformer_ckpt', type=str, default='ckpts/synchformer_state_dict.pth', help='Path to the Synchformer checkpoint')
    parser.add_argument('--start-row', type=int, default=0, help='First row of the annotation file to process')
    parser.add_argument('--end-row', type=int, default=None, help='Last row of the annotation file to process (default: process to the end)')
args = parser.parse_args()
args.audio_samples = int(args.sample_rate * args.duration_sec)
main(args=args)
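
# Example invocation (paths are illustrative and depend on your setup):
#   python extract_latents.py --root videos --tsv_path cot_coarse/cot.csv --save-dir results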