File size: 3,153 Bytes
052cf68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import logging
import os
from pathlib import Path
from typing import Optional, Union

import pandas as pd
import torch
import torchaudio
from torch.utils.data.dataset import Dataset
from torchvision.transforms import v2
from torio.io import StreamingMediaDecoder
from torchvision.utils import save_image

log = logging.getLogger()

_CLIP_SIZE = 384
_CLIP_FPS = 8.0

_SYNC_SIZE = 224
_SYNC_FPS = 25.0


class VGGSound(Dataset):
    """Caption-only, map-style dataset over a VGGSound CSV index.

    Each item is a dict with the clip id and its two caption fields, taken
    from the ``id``, ``caption`` and ``caption_cot`` columns of ``tsv_path``.
    No audio or video is decoded by this class.
    """

    def __init__(
        self,
        root: Union[str, Path],
        *,
        tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv',
        start_row: Optional[int] = None,
        end_row: Optional[int] = None,
        save_dir: str = 'data/vggsound/video_latents_text/train'
    ):
        """Read the caption index into memory.

        Args:
            root: Video directory. Stored as ``self.root``; this class does
                not read from it.
            tsv_path: Comma-separated file with at least ``id``, ``caption``
                and ``caption_cot`` columns. ``id`` is read as a string so
                ids with leading zeros are preserved.
            start_row: Index of the first CSV record to keep (slice
                semantics; ``None`` means from the beginning).
            end_row: Index to stop before (slice semantics; ``None`` means to
                the end). Either bound may be given on its own.
            save_dir: Currently unused; kept for interface compatibility.

        Raises:
            KeyError: if a required column is missing from the CSV.
        """
        self.root = Path(root)
        self.labels = []   # ``caption`` column, one entry per kept record
        self.cots = []     # ``caption_cot`` column, parallel to ``labels``
        self.videos = []   # ``id`` column, parallel to ``labels``

        records = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records')

        # Restrict to the requested row range. Slicing with None bounds is a
        # no-op on that side, so a single bound is honoured as well (it used
        # to be ignored unless both were provided).
        records = records[start_row:end_row]

        for record in records:
            self.videos.append(record['id'])
            self.labels.append(record['caption'])
            self.cots.append(record['caption_cot'])

        log.info(f'{len(self.videos)} videos found in {tsv_path}')

    def sample(self, idx: int) -> dict[str, str]:
        """Return the metadata record at position ``idx``.

        Caption values come straight from pandas, so an empty CSV cell would
        surface as NaN rather than a string.

        Raises:
            IndexError: if ``idx`` is out of range.
        """
        return {
            'id': self.videos[idx],
            'caption': self.labels[idx],
            'caption_cot': self.cots[idx],
        }

    def __getitem__(self, idx: int) -> Optional[dict[str, str]]:
        """Return ``self.sample(idx)``, or ``None`` if building it fails.

        ``IndexError`` is propagated so out-of-range access and
        sequence-style iteration (which stops on ``IndexError``) behave like a
        normal sequence instead of being masked.
        """
        try:
            return self.sample(idx)
        except IndexError:
            raise
        except Exception as e:
            # Best-effort: a bad record yields None instead of aborting the
            # loader; presumably filtered downstream (e.g. in a collate_fn) —
            # confirm with the caller.
            log.error(f'Error loading video {self.videos[idx]}: {e}')
            return None

    def __len__(self):
        """Number of records kept after the row-range slice."""
        return len(self.labels)


# Example usage. The constructor accepts only root, tsv_path, start_row,
# end_row and save_dir; audio options such as sample_rate are not supported.
# dataset = VGGSound(
#         root="data/vggsound/video/test",
#         tsv_path="data/vggsound/split_txt/temp.csv",
#         start_row=0,
#         end_row=None,
#         save_dir="data/vggsound/video_latents_text/test"
#     )
# dataset[0]