Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,153 Bytes
052cf68 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import logging
import os
from pathlib import Path
from typing import Optional, Union
import pandas as pd
import torch
import torchaudio
from torch.utils.data.dataset import Dataset
from torchvision.transforms import v2
from torio.io import StreamingMediaDecoder
from torchvision.utils import save_image
# Module-wide logger. getLogger() without a name returns the root logger, so
# records emitted here follow whatever configuration the root logger has.
log = logging.getLogger()

# NOTE(review): the four constants below are not referenced anywhere in the
# visible part of this file. Their names suggest a frame size (pixels) and a
# frame rate (frames per second) for a CLIP-style and a sync-style video
# stream respectively -- confirm against the code that consumes them.
_CLIP_SIZE = 384
_CLIP_FPS = 8.0
_SYNC_SIZE = 224
_SYNC_FPS = 25.0
class VGGSound(Dataset):
    """Caption-only view of a VGGSound split described by a CSV file.

    Every item is a dict with the keys ``'id'``, ``'caption'`` and
    ``'caption_cot'``, taken from the CSV columns of the same names.  No video
    or audio is decoded by this class and ``root`` is not scanned on disk.
    """

    def __init__(
        self,
        root: Union[str, Path],
        *,
        tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv',
        start_row: Optional[int] = None,
        end_row: Optional[int] = None,
        save_dir: str = 'data/vggsound/video_latents_text/train'
    ):
        """Read the caption table and keep the requested slice of rows.

        Args:
            root: Directory associated with the videos.  Only stored as
                ``self.root``; its contents are not read here.
            tsv_path: Comma-separated file with ``id``, ``caption`` and
                ``caption_cot`` columns.
            start_row: Index of the first row to keep; ``None`` starts at the
                first row.
            end_row: Index of the row to stop before; ``None`` runs through
                the last row.
            save_dir: Accepted for backward compatibility; currently unused.

        Raises:
            KeyError: If one of the required columns is missing.
        """
        super().__init__()
        self.root = Path(root)

        # Ids are parsed as strings so numeric-looking ids keep their exact
        # text (e.g. leading zeros).
        table = pd.read_csv(tsv_path, sep=',', dtype={'id': str})

        # Restrict processing to the requested row range.  A ``None`` bound is
        # open-ended, so either bound may be supplied on its own.
        table = table.iloc[start_row:end_row]

        # Three parallel lists, aligned by position.
        self.videos = table['id'].tolist()
        self.labels = table['caption'].tolist()
        self.cots = table['caption_cot'].tolist()

        log.info(f'{len(self.videos)} videos found in {tsv_path}')

    def sample(self, idx: int) -> dict[str, object]:
        """Return the ``id`` / ``caption`` / ``caption_cot`` record at ``idx``."""
        return {
            'id': self.videos[idx],
            'caption': self.labels[idx],
            'caption_cot': self.cots[idx],
        }

    def __getitem__(self, idx: int) -> Optional[dict[str, object]]:
        """Return the record at ``idx``, or ``None`` if building it fails.

        Invalid indices (out of range or of the wrong type) propagate as
        ``IndexError`` / ``TypeError`` so that sequence-style iteration still
        terminates; any other failure is logged and reported as ``None``.
        """
        try:
            return self.sample(idx)
        except (IndexError, TypeError):
            # Re-raise directly: looking the id up again for the log message
            # below would fail on the very same index.
            raise
        except Exception as e:
            log.error(f'Error loading video {self.videos[idx]}: {e}')
            return None

    def __len__(self) -> int:
        """Return the number of rows kept from the caption table."""
        return len(self.labels)
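# Example usage (kept commented out):
# NOTE: an earlier version of this example also passed sample_rate=44100,
# duration_sec=9.0 and audio_samples=397312; VGGSound.__init__ does not accept
# those keyword arguments, so they are omitted here.
# dataset = VGGSound(
#     root="data/vggsound/video/test",
#     tsv_path="data/vggsound/split_txt/temp.csv",
#     start_row=0,
#     end_row=None,
#     save_dir="data/vggsound/video_latents_text/test"
# )
# dataset[0]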