LPDoctor commited on
Commit
18db7f4
·
1 Parent(s): f94464b

Implement core functionality for ThinkSound audio generation app, including video processing, audio synthesis, and Gradio interface setup. Update README with new title and emoji.

Browse files
README.md CHANGED
@@ -1,13 +1,11 @@
1
  ---
2
- title: ThinkSound Audio App
3
- emoji: 👀
4
- colorFrom: pink
5
  colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 5.36.2
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: ThinkSound
3
+ emoji: 🔊
4
+ colorFrom: blue
5
  colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 5.35.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
+ ---
 
 
ThinkSound ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 9eeb443af2ab75afc046e27494c522dfd87fa2c1
app.py CHANGED
@@ -1,7 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
1
+ from prefigure.prefigure import get_all_args, push_wandb_config
2
+ import spaces
3
+ import json
4
+ import os
5
+ os.environ["GRADIO_TEMP_DIR"] = "./.gradio_tmp"
6
+ import re
7
+ import torch
8
+ import torchaudio
9
+ # import pytorch_lightning as pl
10
+ import lightning as L
11
+ from lightning.pytorch.callbacks import Timer, ModelCheckpoint, BasePredictionWriter
12
+ from lightning.pytorch.callbacks import Callback
13
+ from lightning.pytorch.tuner import Tuner
14
+ from lightning.pytorch import seed_everything
15
+ import random
16
+ from datetime import datetime
17
+ from ThinkSound.data.datamodule import DataModule
18
+ from ThinkSound.models import create_model_from_config
19
+ from ThinkSound.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model
20
+ from ThinkSound.training import create_training_wrapper_from_config, create_demo_callback_from_config
21
+ from ThinkSound.training.utils import copy_state_dict
22
+ from ThinkSound.inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler
23
+ from data_utils.v2a_utils.feature_utils_224 import FeaturesUtils
24
+ from torch.utils.data import Dataset
25
+ from typing import Optional, Union
26
+ from torchvision.transforms import v2
27
+ from torio.io import StreamingMediaDecoder
28
+ from torchvision.utils import save_image
29
+ from transformers import AutoProcessor
30
+ import torch.nn.functional as F
31
  import gradio as gr
32
+ import tempfile
33
+ import subprocess
34
+ from huggingface_hub import hf_hub_download
35
+ from moviepy.editor import VideoFileClip
36
+ # os.system("conda install -c conda-forge 'ffmpeg<7'")
37
+
38
+ _CLIP_SIZE = 224
39
+ _CLIP_FPS = 8.0
40
+
41
+ _SYNC_SIZE = 224
42
+ _SYNC_FPS = 25.0
43
+
44
+ def pad_to_square(video_tensor):
45
+ if len(video_tensor.shape) != 4:
46
+ raise ValueError("Input tensor must have shape (l, c, h, w)")
47
+
48
+ l, c, h, w = video_tensor.shape
49
+ max_side = max(h, w)
50
+
51
+ pad_h = max_side - h
52
+ pad_w = max_side - w
53
+
54
+ padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
55
+
56
+ video_padded = F.pad(video_tensor, pad=padding, mode='constant', value=0)
57
+
58
+ return video_padded
59
+
60
+
61
+ class VGGSound(Dataset):
62
+
63
+ def __init__(
64
+ self,
65
+ sample_rate: int = 44_100,
66
+ duration_sec: float = 9.0,
67
+ audio_samples: int = None,
68
+ normalize_audio: bool = False,
69
+ ):
70
+ if audio_samples is None:
71
+ self.audio_samples = int(sample_rate * duration_sec)
72
+ else:
73
+ self.audio_samples = audio_samples
74
+ effective_duration = audio_samples / sample_rate
75
+ # make sure the duration is close enough, within 15ms
76
+ assert abs(effective_duration - duration_sec) < 0.015, \
77
+ f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
78
+
79
+ self.sample_rate = sample_rate
80
+ self.duration_sec = duration_sec
81
+
82
+ self.expected_audio_length = self.audio_samples
83
+ self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
84
+ self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
85
+
86
+ self.clip_transform = v2.Compose([
87
+ v2.Lambda(pad_to_square), # 先填充为正方形
88
+ v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
89
+ v2.ToImage(),
90
+ v2.ToDtype(torch.float32, scale=True),
91
+ ])
92
+ self.clip_processor = AutoProcessor.from_pretrained("facebook/metaclip-h14-fullcc2.5b")
93
+ self.sync_transform = v2.Compose([
94
+ v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
95
+ v2.CenterCrop(_SYNC_SIZE),
96
+ v2.ToImage(),
97
+ v2.ToDtype(torch.float32, scale=True),
98
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
99
+ ])
100
+
101
+ self.resampler = {}
102
+
103
+ def sample(self, video_path,label,cot):
104
+ video_id = video_path
105
+
106
+ reader = StreamingMediaDecoder(video_path)
107
+ reader.add_basic_video_stream(
108
+ frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
109
+ frame_rate=_CLIP_FPS,
110
+ format='rgb24',
111
+ )
112
+ reader.add_basic_video_stream(
113
+ frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
114
+ frame_rate=_SYNC_FPS,
115
+ format='rgb24',
116
+ )
117
+
118
+ reader.fill_buffer()
119
+ data_chunk = reader.pop_chunks()
120
+
121
+ clip_chunk = data_chunk[0]
122
+ sync_chunk = data_chunk[1]
123
+
124
+ if sync_chunk is None:
125
+ raise RuntimeError(f'Sync video returned None {video_id}')
126
+
127
+ clip_chunk = clip_chunk[:self.clip_expected_length]
128
+ # import ipdb
129
+ # ipdb.set_trace()
130
+ if clip_chunk.shape[0] != self.clip_expected_length:
131
+ current_length = clip_chunk.shape[0]
132
+ padding_needed = self.clip_expected_length - current_length
133
+
134
+ # Check that padding needed is no more than 2
135
+ assert padding_needed < 4, f'Padding no more than 2 frames allowed, but {padding_needed} needed'
136
+
137
+ # If assertion passes, proceed with padding
138
+ if padding_needed > 0:
139
+ last_frame = clip_chunk[-1]
140
+ log.info(last_frame.shape)
141
+ # Repeat the last frame to reach the expected length
142
+ padding = last_frame.repeat(padding_needed, 1, 1, 1)
143
+ clip_chunk = torch.cat((clip_chunk, padding), dim=0)
144
+ # raise RuntimeError(f'CLIP video wrong length {video_id}, '
145
+ # f'expected {self.clip_expected_length}, '
146
+ # f'got {clip_chunk.shape[0]}')
147
+
148
+ # save_image(clip_chunk[0] / 255.0,'ori.png')
149
+ clip_chunk = pad_to_square(clip_chunk)
150
+
151
+ clip_chunk = self.clip_processor(images=clip_chunk, return_tensors="pt")["pixel_values"]
152
+
153
+ sync_chunk = sync_chunk[:self.sync_expected_length]
154
+ if sync_chunk.shape[0] != self.sync_expected_length:
155
+ # padding using the last frame, but no more than 2
156
+ current_length = sync_chunk.shape[0]
157
+ last_frame = sync_chunk[-1]
158
+
159
+ padding = last_frame.repeat(self.sync_expected_length - current_length, 1, 1, 1)
160
+ assert self.sync_expected_length - current_length < 12, f'sync can pad no more than 2 while {self.sync_expected_length - current_length}'
161
+ sync_chunk = torch.cat((sync_chunk, padding), dim=0)
162
+ # raise RuntimeError(f'Sync video wrong length {video_id}, '
163
+ # f'expected {self.sync_expected_length}, '
164
+ # f'got {sync_chunk.shape[0]}')
165
+
166
+ sync_chunk = self.sync_transform(sync_chunk)
167
+ # assert audio_chunk.shape[1] == self.expected_audio_length and clip_chunk.shape[0] == self.clip_expected_length \
168
+ # and sync_chunk.shape[0] == self.sync_expected_length, 'error processed data shape'
169
+ data = {
170
+ 'id': video_id,
171
+ 'caption': label,
172
+ 'caption_cot': cot,
173
+ # 'audio': audio_chunk,
174
+ 'clip_video': clip_chunk,
175
+ 'sync_video': sync_chunk,
176
+ }
177
+
178
+ return data
179
+
180
+ # 检查设备
181
+ if torch.cuda.is_available():
182
+ device = 'cuda'
183
+ extra_device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'
184
+ else:
185
+ device = 'cpu'
186
+ extra_device = 'cpu'
187
+
188
+ print(f"load in device {device}")
189
+
190
+ vae_ckpt = hf_hub_download(repo_id="FunAudioLLM/ThinkSound", filename="vae.ckpt",repo_type="model")
191
+ synchformer_ckpt = hf_hub_download(repo_id="FunAudioLLM/ThinkSound", filename="synchformer_state_dict.pth",repo_type="model")
192
+
193
+ feature_extractor = FeaturesUtils(
194
+ vae_ckpt=None,
195
+ vae_config='ThinkSound/configs/model_configs/stable_audio_2_0_vae.json',
196
+ enable_conditions=True,
197
+ synchformer_ckpt=synchformer_ckpt
198
+ ).eval().to(extra_device)
199
+
200
+ args = get_all_args()
201
+
202
+ seed = 10086
203
+
204
+ seed_everything(seed, workers=True)
205
+
206
+
207
+ #Get JSON config from args.model_config
208
+ with open("ThinkSound/configs/model_configs/thinksound.json") as f:
209
+ model_config = json.load(f)
210
+
211
+ diffusion_model = create_model_from_config(model_config)
212
+ ckpt_path = hf_hub_download(repo_id="FunAudioLLM/ThinkSound", filename="thinksound_light.ckpt",repo_type="model")
213
+ diffusion_model.load_state_dict(torch.load(ckpt_path))
214
+ diffusion_model.to(device)
215
+
216
+ ## speed by torch.compile
217
+ if args.compile:
218
+ diffusion_model = torch.compile(diffusion_model)
219
+
220
+
221
+ load_vae_state = load_ckpt_state_dict(vae_ckpt, prefix='autoencoder.')
222
+ # new_state_dict = {k.replace("autoencoder.", ""): v for k, v in load_vae_state.items() if k.startswith("autoencoder.")}
223
+ diffusion_model.pretransform.load_state_dict(load_vae_state)
224
+
225
+ def get_video_duration(video_path):
226
+ video = VideoFileClip(video_path)
227
+ return video.duration
228
+
229
+ @spaces.GPU(duration=60)
230
+ @torch.inference_mode()
231
+ @torch.no_grad()
232
+ def synthesize_video_with_audio(video_file, caption, cot):
233
+ yield "⏳ Extracting Features…", None
234
+ video_path = video_file
235
+ if caption is None:
236
+ caption = ''
237
+ if cot is None:
238
+ cot = caption
239
+ timer = Timer(duration="00:15:00:00")
240
+ #get video duration
241
+ duration_sec = get_video_duration(video_path)
242
+ print(duration_sec)
243
+ preprocesser = VGGSound(duration_sec=duration_sec)
244
+ data = preprocesser.sample(video_path, caption, cot)
245
+
246
+
247
+ preprocessed_data = {}
248
+ metaclip_global_text_features, metaclip_text_features = feature_extractor.encode_text(data['caption'])
249
+ preprocessed_data['metaclip_global_text_features'] = metaclip_global_text_features.detach().cpu().squeeze(0)
250
+ preprocessed_data['metaclip_text_features'] = metaclip_text_features.detach().cpu().squeeze(0)
251
+
252
+ t5_features = feature_extractor.encode_t5_text(data['caption_cot'])
253
+ preprocessed_data['t5_features'] = t5_features.detach().cpu().squeeze(0)
254
+
255
+ clip_features = feature_extractor.encode_video_with_clip(data['clip_video'].unsqueeze(0).to(extra_device))
256
+ preprocessed_data['metaclip_features'] = clip_features.detach().cpu().squeeze(0)
257
+
258
+ sync_features = feature_extractor.encode_video_with_sync(data['sync_video'].unsqueeze(0).to(extra_device))
259
+ preprocessed_data['sync_features'] = sync_features.detach().cpu().squeeze(0)
260
+ preprocessed_data['video_exist'] = torch.tensor(True)
261
+ print("clip_shape", preprocessed_data['metaclip_features'].shape)
262
+ print("sync_shape", preprocessed_data['sync_features'].shape)
263
+ sync_seq_len = preprocessed_data['sync_features'].shape[0]
264
+ clip_seq_len = preprocessed_data['metaclip_features'].shape[0]
265
+ latent_seq_len = (int)(194/9*duration_sec)
266
+ diffusion_model.model.model.update_seq_lengths(latent_seq_len, clip_seq_len, sync_seq_len)
267
+
268
+ metadata = [preprocessed_data]
269
+
270
+ batch_size = 1
271
+ length = latent_seq_len
272
+ with torch.amp.autocast(device):
273
+ conditioning = diffusion_model.conditioner(metadata, device)
274
+
275
+ video_exist = torch.stack([item['video_exist'] for item in metadata],dim=0)
276
+ conditioning['metaclip_features'][~video_exist] = diffusion_model.model.model.empty_clip_feat
277
+ conditioning['sync_features'][~video_exist] = diffusion_model.model.model.empty_sync_feat
278
+
279
+ yield "⏳ Inferring…", None
280
+
281
+ cond_inputs = diffusion_model.get_conditioning_inputs(conditioning)
282
+ noise = torch.randn([batch_size, diffusion_model.io_channels, length]).to(device)
283
+ with torch.amp.autocast(device):
284
+ if diffusion_model.diffusion_objective == "v":
285
+ fakes = sample(diffusion_model.model, noise, 24, 0, **cond_inputs, cfg_scale=5, batch_cfg=True)
286
+ elif diffusion_model.diffusion_objective == "rectified_flow":
287
+ import time
288
+ start_time = time.time()
289
+ fakes = sample_discrete_euler(diffusion_model.model, noise, 24, **cond_inputs, cfg_scale=5, batch_cfg=True)
290
+ end_time = time.time()
291
+ execution_time = end_time - start_time
292
+ print(f"execution_time: {execution_time:.2f} 秒")
293
+
294
+ if diffusion_model.pretransform is not None:
295
+ fakes = diffusion_model.pretransform.decode(fakes)
296
+
297
+ audios = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
298
+ with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_audio:
299
+ torchaudio.save(tmp_audio.name, audios[0], 44100)
300
+ audio_path = tmp_audio.name
301
+
302
+ with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_video:
303
+ output_video_path = tmp_video.name
304
+
305
+ cmd = [
306
+ 'ffmpeg', '-y', '-i', video_file, '-i', audio_path,
307
+ '-c:v', 'copy', '-map', '0:v:0', '-map', '1:a:0',
308
+ '-shortest', output_video_path
309
+ ]
310
+ subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
311
+
312
+ # return output_video_path
313
+ yield "✅ Generation completed!", output_video_path
314
+
315
+ demo = gr.Interface(
316
+ fn=synthesize_video_with_audio,
317
+ inputs=[
318
+ gr.Video(label="Upload Video"),
319
+ gr.Textbox(label="Caption (optional)", placeholder="can be empty",),
320
+ gr.Textbox(label="CoT Description (optional)", lines=6, placeholder="can be empty",),
321
+ ],
322
+ outputs=[
323
+ gr.Text(label="Status"),
324
+ gr.Video(label="Result"),
325
+ ],
326
+ title="ThinkSound Demo",
327
+ description="Upload a video, caption, or CoT to generate audio. For an enhanced experience, we automatically merge the generated audio with your original silent video. (Note: Flexible audio generation lengths are supported.:)",
328
+ examples=[
329
+ ["examples/3_mute.mp4", "Gentle Sucking Sounds From the Pacifier", "Begin by creating a soft, steady background of light pacifier suckling. Add subtle, breathy rhythms to mimic a newborn's gentle mouth movements. Keep the sound smooth, natural, and soothing."],
330
+ ["examples/2_mute.mp4", "Printer Printing", "Generate a continuous printer printing sound with periodic beeps and paper movement, plus a cat pawing at the machine. Add subtle ambient room noise for authenticity, keeping the focus on printing, beeps, and the cat's interaction."],
331
+ ["examples/5_mute.mp4", "Lighting Firecrackers", "Generate the sound of firecrackers lighting and exploding repeatedly on the ground, followed by fireworks bursting in the sky. Incorporate occasional subtle echoes to mimic an outdoor night ambiance, with no human voices present."],
332
+ ["examples/4_mute.mp4", "Plastic Debris Handling", "Begin with the sound of hands scooping up loose plastic debris, followed by the subtle cascading noise as the pieces fall and scatter back down. Include soft crinkling and rustling to emphasize the texture of the plastic. Add ambient factory background noise with distant machinery to create an industrial atmosphere."]
333
+ ],
334
+ cache_examples=True
335
+ )
336
+
337
+ if __name__ == "__main__":
338
+ demo.queue().launch(share=True)
339
+
340
+ demo.launch(share=True)
341
+
342
 
 
 
343
 
 
 
cot_vgg_demo_caption.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ demo.npz
data_utils/ext/synchformer/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Vladimir Iashin
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
data_utils/ext/synchformer/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from data_utils.ext.synchformer.synchformer import Synchformer
data_utils/ext/synchformer/divided_224_16x4.yaml ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TRAIN:
2
+ ENABLE: True
3
+ DATASET: Ssv2
4
+ BATCH_SIZE: 32
5
+ EVAL_PERIOD: 5
6
+ CHECKPOINT_PERIOD: 5
7
+ AUTO_RESUME: True
8
+ CHECKPOINT_EPOCH_RESET: True
9
+ CHECKPOINT_FILE_PATH: /checkpoint/fmetze/neurips_sota/40944587/checkpoints/checkpoint_epoch_00035.pyth
10
+ DATA:
11
+ NUM_FRAMES: 16
12
+ SAMPLING_RATE: 4
13
+ TRAIN_JITTER_SCALES: [256, 320]
14
+ TRAIN_CROP_SIZE: 224
15
+ TEST_CROP_SIZE: 224
16
+ INPUT_CHANNEL_NUM: [3]
17
+ MEAN: [0.5, 0.5, 0.5]
18
+ STD: [0.5, 0.5, 0.5]
19
+ PATH_TO_DATA_DIR: /private/home/mandelapatrick/slowfast/data/ssv2
20
+ PATH_PREFIX: /datasets01/SomethingV2/092720/20bn-something-something-v2-frames
21
+ INV_UNIFORM_SAMPLE: True
22
+ RANDOM_FLIP: False
23
+ REVERSE_INPUT_CHANNEL: True
24
+ USE_RAND_AUGMENT: True
25
+ RE_PROB: 0.0
26
+ USE_REPEATED_AUG: False
27
+ USE_RANDOM_RESIZE_CROPS: False
28
+ COLORJITTER: False
29
+ GRAYSCALE: False
30
+ GAUSSIAN: False
31
+ SOLVER:
32
+ BASE_LR: 1e-4
33
+ LR_POLICY: steps_with_relative_lrs
34
+ LRS: [1, 0.1, 0.01]
35
+ STEPS: [0, 20, 30]
36
+ MAX_EPOCH: 35
37
+ MOMENTUM: 0.9
38
+ WEIGHT_DECAY: 5e-2
39
+ WARMUP_EPOCHS: 0.0
40
+ OPTIMIZING_METHOD: adamw
41
+ USE_MIXED_PRECISION: True
42
+ SMOOTHING: 0.2
43
+ SLOWFAST:
44
+ ALPHA: 8
45
+ VIT:
46
+ PATCH_SIZE: 16
47
+ PATCH_SIZE_TEMP: 2
48
+ CHANNELS: 3
49
+ EMBED_DIM: 768
50
+ DEPTH: 12
51
+ NUM_HEADS: 12
52
+ MLP_RATIO: 4
53
+ QKV_BIAS: True
54
+ VIDEO_INPUT: True
55
+ TEMPORAL_RESOLUTION: 8
56
+ USE_MLP: True
57
+ DROP: 0.0
58
+ POS_DROPOUT: 0.0
59
+ DROP_PATH: 0.2
60
+ IM_PRETRAINED: True
61
+ HEAD_DROPOUT: 0.0
62
+ HEAD_ACT: tanh
63
+ PRETRAINED_WEIGHTS: vit_1k
64
+ ATTN_LAYER: divided
65
+ MODEL:
66
+ NUM_CLASSES: 174
67
+ ARCH: slow
68
+ MODEL_NAME: VisionTransformer
69
+ LOSS_FUNC: cross_entropy
70
+ TEST:
71
+ ENABLE: True
72
+ DATASET: Ssv2
73
+ BATCH_SIZE: 64
74
+ NUM_ENSEMBLE_VIEWS: 1
75
+ NUM_SPATIAL_CROPS: 3
76
+ DATA_LOADER:
77
+ NUM_WORKERS: 4
78
+ PIN_MEMORY: True
79
+ NUM_GPUS: 8
80
+ NUM_SHARDS: 4
81
+ RNG_SEED: 0
82
+ OUTPUT_DIR: .
83
+ TENSORBOARD:
84
+ ENABLE: True
data_utils/ext/synchformer/motionformer.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from pathlib import Path
3
+
4
+ import einops
5
+ import torch
6
+ from omegaconf import OmegaConf
7
+ from timm.layers import trunc_normal_
8
+ from torch import nn
9
+
10
+ from data_utils.ext.synchformer.utils import check_if_file_exists_else_download
11
+ from data_utils.ext.synchformer.video_model_builder import VisionTransformer
12
+
13
+ FILE2URL = {
14
+ # cfg
15
+ 'motionformer_224_16x4.yaml':
16
+ 'https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/motionformer_224_16x4.yaml',
17
+ 'joint_224_16x4.yaml':
18
+ 'https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/joint_224_16x4.yaml',
19
+ 'divided_224_16x4.yaml':
20
+ 'https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/divided_224_16x4.yaml',
21
+ # ckpt
22
+ 'ssv2_motionformer_224_16x4.pyth':
23
+ 'https://dl.fbaipublicfiles.com/motionformer/ssv2_motionformer_224_16x4.pyth',
24
+ 'ssv2_joint_224_16x4.pyth':
25
+ 'https://dl.fbaipublicfiles.com/motionformer/ssv2_joint_224_16x4.pyth',
26
+ 'ssv2_divided_224_16x4.pyth':
27
+ 'https://dl.fbaipublicfiles.com/motionformer/ssv2_divided_224_16x4.pyth',
28
+ }
29
+
30
+
31
+ class MotionFormer(VisionTransformer):
32
+ ''' This class serves three puposes:
33
+ 1. Renames the class to MotionFormer.
34
+ 2. Downloads the cfg from the original repo and patches it if needed.
35
+ 3. Takes care of feature extraction by redefining .forward()
36
+ - if `extract_features=True` and `factorize_space_time=False`,
37
+ the output is of shape (B, T, D) where T = 1 + (224 // 16) * (224 // 16) * 8
38
+ - if `extract_features=True` and `factorize_space_time=True`, the output is of shape (B*S, D)
39
+ and spatial and temporal transformer encoder layers are used.
40
+ - if `extract_features=True` and `factorize_space_time=True` as well as `add_global_repr=True`
41
+ the output is of shape (B, D) and spatial and temporal transformer encoder layers
42
+ are used as well as the global representation is extracted from segments (extra pos emb
43
+ is added).
44
+ '''
45
+
46
+ def __init__(
47
+ self,
48
+ extract_features: bool = False,
49
+ ckpt_path: str = None,
50
+ factorize_space_time: bool = None,
51
+ agg_space_module: str = None,
52
+ agg_time_module: str = None,
53
+ add_global_repr: bool = True,
54
+ agg_segments_module: str = None,
55
+ max_segments: int = None,
56
+ ):
57
+ self.extract_features = extract_features
58
+ self.ckpt_path = ckpt_path
59
+ self.factorize_space_time = factorize_space_time
60
+
61
+ if self.ckpt_path is not None:
62
+ check_if_file_exists_else_download(self.ckpt_path, FILE2URL)
63
+ ckpt = torch.load(self.ckpt_path, map_location='cpu')
64
+ mformer_ckpt2cfg = {
65
+ 'ssv2_motionformer_224_16x4.pyth': 'motionformer_224_16x4.yaml',
66
+ 'ssv2_joint_224_16x4.pyth': 'joint_224_16x4.yaml',
67
+ 'ssv2_divided_224_16x4.pyth': 'divided_224_16x4.yaml',
68
+ }
69
+ # init from motionformer ckpt or from our Stage I ckpt
70
+ # depending on whether the feat extractor was pre-trained on AVCLIPMoCo or not, we need to
71
+ # load the state dict differently
72
+ was_pt_on_avclip = self.ckpt_path.endswith(
73
+ '.pt') # checks if it is a stage I ckpt (FIXME: a bit generic)
74
+ if self.ckpt_path.endswith(tuple(mformer_ckpt2cfg.keys())):
75
+ cfg_fname = mformer_ckpt2cfg[Path(self.ckpt_path).name]
76
+ elif was_pt_on_avclip:
77
+ # TODO: this is a hack, we should be able to get the cfg from the ckpt (earlier ckpt didn't have it)
78
+ s1_cfg = ckpt.get('args', None) # Stage I cfg
79
+ if s1_cfg is not None:
80
+ s1_vfeat_extractor_ckpt_path = s1_cfg.model.params.vfeat_extractor.params.ckpt_path
81
+ # if the stage I ckpt was initialized from a motionformer ckpt or train from scratch
82
+ if s1_vfeat_extractor_ckpt_path is not None:
83
+ cfg_fname = mformer_ckpt2cfg[Path(s1_vfeat_extractor_ckpt_path).name]
84
+ else:
85
+ cfg_fname = 'divided_224_16x4.yaml'
86
+ else:
87
+ cfg_fname = 'divided_224_16x4.yaml'
88
+ else:
89
+ raise ValueError(f'ckpt_path {self.ckpt_path} is not supported.')
90
+ else:
91
+ was_pt_on_avclip = False
92
+ cfg_fname = 'divided_224_16x4.yaml'
93
+ # logging.info(f'No ckpt_path provided, using {cfg_fname} config.')
94
+
95
+ if cfg_fname in ['motionformer_224_16x4.yaml', 'divided_224_16x4.yaml']:
96
+ pos_emb_type = 'separate'
97
+ elif cfg_fname == 'joint_224_16x4.yaml':
98
+ pos_emb_type = 'joint'
99
+
100
+ self.mformer_cfg_path = Path(__file__).absolute().parent / cfg_fname
101
+
102
+ check_if_file_exists_else_download(self.mformer_cfg_path, FILE2URL)
103
+ mformer_cfg = OmegaConf.load(self.mformer_cfg_path)
104
+ logging.info(f'Loading MotionFormer config from {self.mformer_cfg_path.absolute()}')
105
+
106
+ # patch the cfg (from the default cfg defined in the repo `Motionformer/slowfast/config/defaults.py`)
107
+ mformer_cfg.VIT.ATTN_DROPOUT = 0.0
108
+ mformer_cfg.VIT.POS_EMBED = pos_emb_type
109
+ mformer_cfg.VIT.USE_ORIGINAL_TRAJ_ATTN_CODE = True
110
+ mformer_cfg.VIT.APPROX_ATTN_TYPE = 'none' # guessing
111
+ mformer_cfg.VIT.APPROX_ATTN_DIM = 64 # from ckpt['cfg']
112
+
113
+ # finally init VisionTransformer with the cfg
114
+ super().__init__(mformer_cfg)
115
+
116
+ # load the ckpt now if ckpt is provided and not from AVCLIPMoCo-pretrained ckpt
117
+ if (self.ckpt_path is not None) and (not was_pt_on_avclip):
118
+ _ckpt_load_status = self.load_state_dict(ckpt['model_state'], strict=False)
119
+ if len(_ckpt_load_status.missing_keys) > 0 or len(
120
+ _ckpt_load_status.unexpected_keys) > 0:
121
+ logging.warning(f'Loading exact vfeat_extractor ckpt from {self.ckpt_path} failed.' \
122
+ f'Missing keys: {_ckpt_load_status.missing_keys}, ' \
123
+ f'Unexpected keys: {_ckpt_load_status.unexpected_keys}')
124
+ else:
125
+ logging.info(f'Loading vfeat_extractor ckpt from {self.ckpt_path} succeeded.')
126
+
127
+ if self.extract_features:
128
+ assert isinstance(self.norm,
129
+ nn.LayerNorm), 'early x[:, 1:, :] may not be safe for per-tr weights'
130
+ # pre-logits are Sequential(nn.Linear(emb, emd), act) and `act` is tanh but see the logger
131
+ self.pre_logits = nn.Identity()
132
+ # we don't need the classification head (saving memory)
133
+ self.head = nn.Identity()
134
+ self.head_drop = nn.Identity()
135
+ # avoiding code duplication (used only if agg_*_module is TransformerEncoderLayer)
136
+ transf_enc_layer_kwargs = dict(
137
+ d_model=self.embed_dim,
138
+ nhead=self.num_heads,
139
+ activation=nn.GELU(),
140
+ batch_first=True,
141
+ dim_feedforward=self.mlp_ratio * self.embed_dim,
142
+ dropout=self.drop_rate,
143
+ layer_norm_eps=1e-6,
144
+ norm_first=True,
145
+ )
146
+ # define adapters if needed
147
+ if self.factorize_space_time:
148
+ if agg_space_module == 'TransformerEncoderLayer':
149
+ self.spatial_attn_agg = SpatialTransformerEncoderLayer(
150
+ **transf_enc_layer_kwargs)
151
+ elif agg_space_module == 'AveragePooling':
152
+ self.spatial_attn_agg = AveragePooling(avg_pattern='BS D t h w -> BS D t',
153
+ then_permute_pattern='BS D t -> BS t D')
154
+ if agg_time_module == 'TransformerEncoderLayer':
155
+ self.temp_attn_agg = TemporalTransformerEncoderLayer(**transf_enc_layer_kwargs)
156
+ elif agg_time_module == 'AveragePooling':
157
+ self.temp_attn_agg = AveragePooling(avg_pattern='BS t D -> BS D')
158
+ elif 'Identity' in agg_time_module:
159
+ self.temp_attn_agg = nn.Identity()
160
+ # define a global aggregation layer (aggregarate over segments)
161
+ self.add_global_repr = add_global_repr
162
+ if add_global_repr:
163
+ if agg_segments_module == 'TransformerEncoderLayer':
164
+ # we can reuse the same layer as for temporal factorization (B, dim_to_agg, D) -> (B, D)
165
+ # we need to add pos emb (PE) because previously we added the same PE for each segment
166
+ pos_max_len = max_segments if max_segments is not None else 16 # 16 = 10sec//0.64sec + 1
167
+ self.global_attn_agg = TemporalTransformerEncoderLayer(
168
+ add_pos_emb=True,
169
+ pos_emb_drop=mformer_cfg.VIT.POS_DROPOUT,
170
+ pos_max_len=pos_max_len,
171
+ **transf_enc_layer_kwargs)
172
+ elif agg_segments_module == 'AveragePooling':
173
+ self.global_attn_agg = AveragePooling(avg_pattern='B S D -> B D')
174
+
175
+ if was_pt_on_avclip:
176
+ # we need to filter out the state_dict of the AVCLIP model (has both A and V extractors)
177
+ # and keep only the state_dict of the feat extractor
178
+ ckpt_weights = dict()
179
+ for k, v in ckpt['state_dict'].items():
180
+ if k.startswith(('module.v_encoder.', 'v_encoder.')):
181
+ k = k.replace('module.', '').replace('v_encoder.', '')
182
+ ckpt_weights[k] = v
183
+ _load_status = self.load_state_dict(ckpt_weights, strict=False)
184
+ if len(_load_status.missing_keys) > 0 or len(_load_status.unexpected_keys) > 0:
185
+ logging.warning(f'Loading exact vfeat_extractor ckpt from {self.ckpt_path} failed. \n' \
186
+ f'Missing keys ({len(_load_status.missing_keys)}): ' \
187
+ f'{_load_status.missing_keys}, \n' \
188
+ f'Unexpected keys ({len(_load_status.unexpected_keys)}): ' \
189
+ f'{_load_status.unexpected_keys} \n' \
190
+ f'temp_attn_agg are expected to be missing if ckpt was pt contrastively.')
191
+ else:
192
+ logging.info(f'Loading vfeat_extractor ckpt from {self.ckpt_path} succeeded.')
193
+
194
+ # patch_embed is not used in MotionFormer, only patch_embed_3d, because cfg.VIT.PATCH_SIZE_TEMP > 1
195
+ # but it used to calculate the number of patches, so we need to set keep it
196
+ self.patch_embed.requires_grad_(False)
197
+
198
+ def forward(self, x):
199
+ '''
200
+ x is of shape (B, S, C, T, H, W) where S is the number of segments.
201
+ '''
202
+ # Batch, Segments, Channels, T=frames, Height, Width
203
+ B, S, C, T, H, W = x.shape
204
+ # Motionformer expects a tensor of shape (1, B, C, T, H, W).
205
+ # The first dimension (1) is a dummy dimension to make the input tensor and won't be used:
206
+ # see `video_model_builder.video_input`.
207
+ # x = x.unsqueeze(0) # (1, B, S, C, T, H, W)
208
+
209
+ orig_shape = (B, S, C, T, H, W)
210
+ x = x.view(B * S, C, T, H, W) # flatten batch and segments
211
+ x = self.forward_segments(x, orig_shape=orig_shape)
212
+ # unpack the segments (using rest dimensions to support different shapes e.g. (BS, D) or (BS, t, D))
213
+ x = x.view(B, S, *x.shape[1:])
214
+ # x is now of shape (B*S, D) or (B*S, t, D) if `self.temp_attn_agg` is `Identity`
215
+
216
+ return x # x is (B, S, ...)
217
+
218
+ def forward_segments(self, x, orig_shape: tuple) -> torch.Tensor:
219
+ '''x is of shape (1, BS, C, T, H, W) where S is the number of segments.'''
220
+ x, x_mask = self.forward_features(x)
221
+
222
+ assert self.extract_features
223
+
224
+ # (BS, T, D) where T = 1 + (224 // 16) * (224 // 16) * 8
225
+ x = x[:,
226
+ 1:, :] # without the CLS token for efficiency (should be safe for LayerNorm and FC)
227
+ x = self.norm(x)
228
+ x = self.pre_logits(x)
229
+ if self.factorize_space_time:
230
+ x = self.restore_spatio_temp_dims(x, orig_shape) # (B*S, D, t, h, w) <- (B*S, t*h*w, D)
231
+
232
+ x = self.spatial_attn_agg(x, x_mask) # (B*S, t, D)
233
+ x = self.temp_attn_agg(
234
+ x) # (B*S, D) or (BS, t, D) if `self.temp_attn_agg` is `Identity`
235
+
236
+ return x
237
+
238
+ def restore_spatio_temp_dims(self, feats: torch.Tensor, orig_shape: tuple) -> torch.Tensor:
239
+ '''
240
+ feats are of shape (B*S, T, D) where T = 1 + (224 // 16) * (224 // 16) * 8
241
+ Our goal is to make them of shape (B*S, t, h, w, D) where h, w are the spatial dimensions.
242
+ From `self.patch_embed_3d`, it follows that we could reshape feats with:
243
+ `feats.transpose(1, 2).view(B*S, D, t, h, w)`
244
+ '''
245
+ B, S, C, T, H, W = orig_shape
246
+ D = self.embed_dim
247
+
248
+ # num patches in each dimension
249
+ t = T // self.patch_embed_3d.z_block_size
250
+ h = self.patch_embed_3d.height
251
+ w = self.patch_embed_3d.width
252
+
253
+ feats = feats.permute(0, 2, 1) # (B*S, D, T)
254
+ feats = feats.view(B * S, D, t, h, w) # (B*S, D, t, h, w)
255
+
256
+ return feats
257
+
258
+
259
+ class BaseEncoderLayer(nn.TransformerEncoderLayer):
260
+ '''
261
+ This is a wrapper around nn.TransformerEncoderLayer that adds a CLS token
262
+ to the sequence and outputs the CLS token's representation.
263
+ This base class parents both SpatialEncoderLayer and TemporalEncoderLayer for the RGB stream
264
+ and the FrequencyEncoderLayer and TemporalEncoderLayer for the audio stream stream.
265
+ We also, optionally, add a positional embedding to the input sequence which
266
+ allows to reuse it for global aggregation (of segments) for both streams.
267
+ '''
268
+
269
+ def __init__(self,
270
+ add_pos_emb: bool = False,
271
+ pos_emb_drop: float = None,
272
+ pos_max_len: int = None,
273
+ *args_transformer_enc,
274
+ **kwargs_transformer_enc):
275
+ super().__init__(*args_transformer_enc, **kwargs_transformer_enc)
276
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, self.self_attn.embed_dim))
277
+ trunc_normal_(self.cls_token, std=.02)
278
+
279
+ # add positional embedding
280
+ self.add_pos_emb = add_pos_emb
281
+ if add_pos_emb:
282
+ self.pos_max_len = 1 + pos_max_len # +1 (for CLS)
283
+ self.pos_emb = nn.Parameter(torch.zeros(1, self.pos_max_len, self.self_attn.embed_dim))
284
+ self.pos_drop = nn.Dropout(pos_emb_drop)
285
+ trunc_normal_(self.pos_emb, std=.02)
286
+
287
+ self.apply(self._init_weights)
288
+
289
+ def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None):
290
+ ''' x is of shape (B, N, D); if provided x_mask is of shape (B, N)'''
291
+ batch_dim = x.shape[0]
292
+
293
+ # add CLS token
294
+ cls_tokens = self.cls_token.expand(batch_dim, -1, -1) # expanding to match batch dimension
295
+ x = torch.cat((cls_tokens, x), dim=-2) # (batch_dim, 1+seq_len, D)
296
+ if x_mask is not None:
297
+ cls_mask = torch.ones((batch_dim, 1), dtype=torch.bool,
298
+ device=x_mask.device) # 1=keep; 0=mask
299
+ x_mask_w_cls = torch.cat((cls_mask, x_mask), dim=-1) # (batch_dim, 1+seq_len)
300
+ B, N = x_mask_w_cls.shape
301
+ # torch expects (N, N) or (B*num_heads, N, N) mask (sadness ahead); torch masks
302
+ x_mask_w_cls = x_mask_w_cls.reshape(B, 1, 1, N)\
303
+ .expand(-1, self.self_attn.num_heads, N, -1)\
304
+ .reshape(B * self.self_attn.num_heads, N, N)
305
+ assert x_mask_w_cls.dtype == x_mask_w_cls.bool().dtype, 'x_mask_w_cls.dtype != bool'
306
+ x_mask_w_cls = ~x_mask_w_cls # invert mask (1=mask)
307
+ else:
308
+ x_mask_w_cls = None
309
+
310
+ # add positional embedding
311
+ if self.add_pos_emb:
312
+ seq_len = x.shape[
313
+ 1] # (don't even think about moving it before the CLS token concatenation)
314
+ assert seq_len <= self.pos_max_len, f'Seq len ({seq_len}) > pos_max_len ({self.pos_max_len})'
315
+ x = x + self.pos_emb[:, :seq_len, :]
316
+ x = self.pos_drop(x)
317
+
318
+ # apply encoder layer (calls nn.TransformerEncoderLayer.forward);
319
+ x = super().forward(src=x, src_mask=x_mask_w_cls) # (batch_dim, 1+seq_len, D)
320
+
321
+ # CLS token is expected to hold spatial information for each frame
322
+ x = x[:, 0, :] # (batch_dim, D)
323
+
324
+ return x
325
+
326
+ def _init_weights(self, m):
327
+ if isinstance(m, nn.Linear):
328
+ trunc_normal_(m.weight, std=.02)
329
+ if isinstance(m, nn.Linear) and m.bias is not None:
330
+ nn.init.constant_(m.bias, 0)
331
+ elif isinstance(m, nn.LayerNorm):
332
+ nn.init.constant_(m.bias, 0)
333
+ nn.init.constant_(m.weight, 1.0)
334
+
335
+ @torch.jit.ignore
336
+ def no_weight_decay(self):
337
+ return {'cls_token', 'pos_emb'}
338
+
339
+
340
+ class SpatialTransformerEncoderLayer(BaseEncoderLayer):
341
+ ''' Aggregates spatial dimensions by applying attention individually to each frame. '''
342
+
343
+ def __init__(self, *args, **kwargs):
344
+ super().__init__(*args, **kwargs)
345
+
346
+ def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor:
347
+ ''' x is of shape (B*S, D, t, h, w) where S is the number of segments.
348
+ if specified x_mask (B*S, t, h, w), 0=masked, 1=kept
349
+ Returns a tensor of shape (B*S, t, D) pooling spatial information for each frame. '''
350
+ BS, D, t, h, w = x.shape
351
+
352
+ # time as a batch dimension and flatten spatial dimensions as sequence
353
+ x = einops.rearrange(x, 'BS D t h w -> (BS t) (h w) D')
354
+ # similar to mask
355
+ if x_mask is not None:
356
+ x_mask = einops.rearrange(x_mask, 'BS t h w -> (BS t) (h w)')
357
+
358
+ # apply encoder layer (BaseEncoderLayer.forward) - it will add CLS token and output its representation
359
+ x = super().forward(x=x, x_mask=x_mask) # (B*S*t, D)
360
+
361
+ # reshape back to (B*S, t, D)
362
+ x = einops.rearrange(x, '(BS t) D -> BS t D', BS=BS, t=t)
363
+
364
+ # (B*S, t, D)
365
+ return x
366
+
367
+
368
+ class TemporalTransformerEncoderLayer(BaseEncoderLayer):
369
+ ''' Aggregates temporal dimension with attention. Also used with pos emb as global aggregation
370
+ in both streams. '''
371
+
372
+ def __init__(self, *args, **kwargs):
373
+ super().__init__(*args, **kwargs)
374
+
375
+ def forward(self, x):
376
+ ''' x is of shape (B*S, t, D) where S is the number of segments.
377
+ Returns a tensor of shape (B*S, D) pooling temporal information. '''
378
+ BS, t, D = x.shape
379
+
380
+ # apply encoder layer (BaseEncoderLayer.forward) - it will add CLS token and output its representation
381
+ x = super().forward(x) # (B*S, D)
382
+
383
+ return x # (B*S, D)
384
+
385
+
386
+ class AveragePooling(nn.Module):
387
+
388
+ def __init__(self, avg_pattern: str, then_permute_pattern: str = None) -> None:
389
+ ''' patterns are e.g. "bs t d -> bs d" '''
390
+ super().__init__()
391
+ # TODO: need to register them as buffers (but fails because these are strings)
392
+ self.reduce_fn = 'mean'
393
+ self.avg_pattern = avg_pattern
394
+ self.then_permute_pattern = then_permute_pattern
395
+
396
+ def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor:
397
+ x = einops.reduce(x, self.avg_pattern, self.reduce_fn)
398
+ if self.then_permute_pattern is not None:
399
+ x = einops.rearrange(x, self.then_permute_pattern)
400
+ return x
data_utils/ext/synchformer/synchformer.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Any, Mapping
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+ from data_utils.ext.synchformer.motionformer import MotionFormer
8
+
9
+
10
+ class Synchformer(nn.Module):
11
+
12
+ def __init__(self):
13
+ super().__init__()
14
+
15
+ self.vfeat_extractor = MotionFormer(extract_features=True,
16
+ factorize_space_time=True,
17
+ agg_space_module='TransformerEncoderLayer',
18
+ agg_time_module='torch.nn.Identity',
19
+ add_global_repr=False)
20
+
21
+ # self.vfeat_extractor = instantiate_from_config(vfeat_extractor)
22
+ # self.afeat_extractor = instantiate_from_config(afeat_extractor)
23
+ # # bridging the s3d latent dim (1024) into what is specified in the config
24
+ # # to match e.g. the transformer dim
25
+ # self.vproj = instantiate_from_config(vproj)
26
+ # self.aproj = instantiate_from_config(aproj)
27
+ # self.transformer = instantiate_from_config(transformer)
28
+
29
+ def forward(self, vis):
30
+ B, S, Tv, C, H, W = vis.shape
31
+ vis = vis.permute(0, 1, 3, 2, 4, 5) # (B, S, C, Tv, H, W)
32
+ # feat extractors return a tuple of segment-level and global features (ignored for sync)
33
+ # (B, S, tv, D), e.g. (B, 7, 8, 768)
34
+ vis = self.vfeat_extractor(vis)
35
+ return vis
36
+
37
+ def load_state_dict(self, sd: Mapping[str, Any], strict: bool = True):
38
+ # discard all entries except vfeat_extractor
39
+ sd = {k: v for k, v in sd.items() if k.startswith('vfeat_extractor')}
40
+
41
+ return super().load_state_dict(sd, strict)
42
+
43
+
44
+ if __name__ == "__main__":
45
+ model = Synchformer().cuda().eval()
46
+ sd = torch.load('./ext_weights/synchformer_state_dict.pth', weights_only=True)
47
+ model.load_state_dict(sd)
48
+
49
+ vid = torch.randn(2, 7, 16, 3, 224, 224).cuda()
50
+ features = model.extract_vfeats(vid, for_loop=False).detach().cpu()
51
+ print(features.shape)
52
+
53
+ # extract and save the state dict only
54
+ # sd = torch.load('./ext_weights/sync_model_audioset.pt')['model']
55
+ # torch.save(sd, './ext_weights/synchformer_state_dict.pth')
data_utils/ext/synchformer/utils.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from hashlib import md5
2
+ from pathlib import Path
3
+
4
+ import requests
5
+ from tqdm import tqdm
6
+
7
+ PARENT_LINK = 'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a'
8
+ FNAME2LINK = {
9
+ # S3: Synchability: AudioSet (run 2)
10
+ '24-01-22T20-34-52.pt':
11
+ f'{PARENT_LINK}/sync/sync_models/24-01-22T20-34-52/24-01-22T20-34-52.pt',
12
+ 'cfg-24-01-22T20-34-52.yaml':
13
+ f'{PARENT_LINK}/sync/sync_models/24-01-22T20-34-52/cfg-24-01-22T20-34-52.yaml',
14
+ # S2: Synchformer: AudioSet (run 2)
15
+ '24-01-04T16-39-21.pt':
16
+ f'{PARENT_LINK}/sync/sync_models/24-01-04T16-39-21/24-01-04T16-39-21.pt',
17
+ 'cfg-24-01-04T16-39-21.yaml':
18
+ f'{PARENT_LINK}/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml',
19
+ # S2: Synchformer: AudioSet (run 1)
20
+ '23-08-28T11-23-23.pt':
21
+ f'{PARENT_LINK}/sync/sync_models/23-08-28T11-23-23/23-08-28T11-23-23.pt',
22
+ 'cfg-23-08-28T11-23-23.yaml':
23
+ f'{PARENT_LINK}/sync/sync_models/23-08-28T11-23-23/cfg-23-08-28T11-23-23.yaml',
24
+ # S2: Synchformer: LRS3 (run 2)
25
+ '23-12-23T18-33-57.pt':
26
+ f'{PARENT_LINK}/sync/sync_models/23-12-23T18-33-57/23-12-23T18-33-57.pt',
27
+ 'cfg-23-12-23T18-33-57.yaml':
28
+ f'{PARENT_LINK}/sync/sync_models/23-12-23T18-33-57/cfg-23-12-23T18-33-57.yaml',
29
+ # S2: Synchformer: VGS (run 2)
30
+ '24-01-02T10-00-53.pt':
31
+ f'{PARENT_LINK}/sync/sync_models/24-01-02T10-00-53/24-01-02T10-00-53.pt',
32
+ 'cfg-24-01-02T10-00-53.yaml':
33
+ f'{PARENT_LINK}/sync/sync_models/24-01-02T10-00-53/cfg-24-01-02T10-00-53.yaml',
34
+ # SparseSync: ft VGGSound-Full
35
+ '22-09-21T21-00-52.pt':
36
+ f'{PARENT_LINK}/sync/sync_models/22-09-21T21-00-52/22-09-21T21-00-52.pt',
37
+ 'cfg-22-09-21T21-00-52.yaml':
38
+ f'{PARENT_LINK}/sync/sync_models/22-09-21T21-00-52/cfg-22-09-21T21-00-52.yaml',
39
+ # SparseSync: ft VGGSound-Sparse
40
+ '22-07-28T15-49-45.pt':
41
+ f'{PARENT_LINK}/sync/sync_models/22-07-28T15-49-45/22-07-28T15-49-45.pt',
42
+ 'cfg-22-07-28T15-49-45.yaml':
43
+ f'{PARENT_LINK}/sync/sync_models/22-07-28T15-49-45/cfg-22-07-28T15-49-45.yaml',
44
+ # SparseSync: only pt on LRS3
45
+ '22-07-13T22-25-49.pt':
46
+ f'{PARENT_LINK}/sync/sync_models/22-07-13T22-25-49/22-07-13T22-25-49.pt',
47
+ 'cfg-22-07-13T22-25-49.yaml':
48
+ f'{PARENT_LINK}/sync/sync_models/22-07-13T22-25-49/cfg-22-07-13T22-25-49.yaml',
49
+ # SparseSync: feature extractors
50
+ 'ResNetAudio-22-08-04T09-51-04.pt':
51
+ f'{PARENT_LINK}/sync/ResNetAudio-22-08-04T09-51-04.pt', # 2s
52
+ 'ResNetAudio-22-08-03T23-14-49.pt':
53
+ f'{PARENT_LINK}/sync/ResNetAudio-22-08-03T23-14-49.pt', # 3s
54
+ 'ResNetAudio-22-08-03T23-14-28.pt':
55
+ f'{PARENT_LINK}/sync/ResNetAudio-22-08-03T23-14-28.pt', # 4s
56
+ 'ResNetAudio-22-06-24T08-10-33.pt':
57
+ f'{PARENT_LINK}/sync/ResNetAudio-22-06-24T08-10-33.pt', # 5s
58
+ 'ResNetAudio-22-06-24T17-31-07.pt':
59
+ f'{PARENT_LINK}/sync/ResNetAudio-22-06-24T17-31-07.pt', # 6s
60
+ 'ResNetAudio-22-06-24T23-57-11.pt':
61
+ f'{PARENT_LINK}/sync/ResNetAudio-22-06-24T23-57-11.pt', # 7s
62
+ 'ResNetAudio-22-06-25T04-35-42.pt':
63
+ f'{PARENT_LINK}/sync/ResNetAudio-22-06-25T04-35-42.pt', # 8s
64
+ }
65
+
66
+
67
+ def check_if_file_exists_else_download(path, fname2link=FNAME2LINK, chunk_size=1024):
68
+ '''Checks if file exists, if not downloads it from the link to the path'''
69
+ path = Path(path)
70
+ if not path.exists():
71
+ path.parent.mkdir(exist_ok=True, parents=True)
72
+ link = fname2link.get(path.name, None)
73
+ if link is None:
74
+ raise ValueError(f'Cant find the checkpoint file: {path}.',
75
+ f'Please download it manually and ensure the path exists.')
76
+ with requests.get(fname2link[path.name], stream=True) as r:
77
+ total_size = int(r.headers.get('content-length', 0))
78
+ with tqdm(total=total_size, unit='B', unit_scale=True) as pbar:
79
+ with open(path, 'wb') as f:
80
+ for data in r.iter_content(chunk_size=chunk_size):
81
+ if data:
82
+ f.write(data)
83
+ pbar.update(chunk_size)
84
+
85
+
86
+ def get_md5sum(path):
87
+ hash_md5 = md5()
88
+ with open(path, 'rb') as f:
89
+ for chunk in iter(lambda: f.read(4096 * 8), b''):
90
+ hash_md5.update(chunk)
91
+ md5sum = hash_md5.hexdigest()
92
+ return md5sum
data_utils/ext/synchformer/video_model_builder.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3
+ # Copyright 2020 Ross Wightman
4
+ # Modified Model definition
5
+
6
+ from collections import OrderedDict
7
+ from functools import partial
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from timm.layers import trunc_normal_
12
+
13
+ from data_utils.ext.synchformer import vit_helper
14
+
15
+
16
+ class VisionTransformer(nn.Module):
17
+ """ Vision Transformer with support for patch or hybrid CNN input stage """
18
+
19
+ def __init__(self, cfg):
20
+ super().__init__()
21
+ self.img_size = cfg.DATA.TRAIN_CROP_SIZE
22
+ self.patch_size = cfg.VIT.PATCH_SIZE
23
+ self.in_chans = cfg.VIT.CHANNELS
24
+ if cfg.TRAIN.DATASET == "Epickitchens":
25
+ self.num_classes = [97, 300]
26
+ else:
27
+ self.num_classes = cfg.MODEL.NUM_CLASSES
28
+ self.embed_dim = cfg.VIT.EMBED_DIM
29
+ self.depth = cfg.VIT.DEPTH
30
+ self.num_heads = cfg.VIT.NUM_HEADS
31
+ self.mlp_ratio = cfg.VIT.MLP_RATIO
32
+ self.qkv_bias = cfg.VIT.QKV_BIAS
33
+ self.drop_rate = cfg.VIT.DROP
34
+ self.drop_path_rate = cfg.VIT.DROP_PATH
35
+ self.head_dropout = cfg.VIT.HEAD_DROPOUT
36
+ self.video_input = cfg.VIT.VIDEO_INPUT
37
+ self.temporal_resolution = cfg.VIT.TEMPORAL_RESOLUTION
38
+ self.use_mlp = cfg.VIT.USE_MLP
39
+ self.num_features = self.embed_dim
40
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
41
+ self.attn_drop_rate = cfg.VIT.ATTN_DROPOUT
42
+ self.head_act = cfg.VIT.HEAD_ACT
43
+ self.cfg = cfg
44
+
45
+ # Patch Embedding
46
+ self.patch_embed = vit_helper.PatchEmbed(img_size=224,
47
+ patch_size=self.patch_size,
48
+ in_chans=self.in_chans,
49
+ embed_dim=self.embed_dim)
50
+
51
+ # 3D Patch Embedding
52
+ self.patch_embed_3d = vit_helper.PatchEmbed3D(img_size=self.img_size,
53
+ temporal_resolution=self.temporal_resolution,
54
+ patch_size=self.patch_size,
55
+ in_chans=self.in_chans,
56
+ embed_dim=self.embed_dim,
57
+ z_block_size=self.cfg.VIT.PATCH_SIZE_TEMP)
58
+ self.patch_embed_3d.proj.weight.data = torch.zeros_like(
59
+ self.patch_embed_3d.proj.weight.data)
60
+
61
+ # Number of patches
62
+ if self.video_input:
63
+ num_patches = self.patch_embed.num_patches * self.temporal_resolution
64
+ else:
65
+ num_patches = self.patch_embed.num_patches
66
+ self.num_patches = num_patches
67
+
68
+ # CLS token
69
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
70
+ trunc_normal_(self.cls_token, std=.02)
71
+
72
+ # Positional embedding
73
+ self.pos_embed = nn.Parameter(
74
+ torch.zeros(1, self.patch_embed.num_patches + 1, self.embed_dim))
75
+ self.pos_drop = nn.Dropout(p=cfg.VIT.POS_DROPOUT)
76
+ trunc_normal_(self.pos_embed, std=.02)
77
+
78
+ if self.cfg.VIT.POS_EMBED == "joint":
79
+ self.st_embed = nn.Parameter(torch.zeros(1, num_patches + 1, self.embed_dim))
80
+ trunc_normal_(self.st_embed, std=.02)
81
+ elif self.cfg.VIT.POS_EMBED == "separate":
82
+ self.temp_embed = nn.Parameter(torch.zeros(1, self.temporal_resolution, self.embed_dim))
83
+
84
+ # Layer Blocks
85
+ dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)]
86
+ if self.cfg.VIT.ATTN_LAYER == "divided":
87
+ self.blocks = nn.ModuleList([
88
+ vit_helper.DividedSpaceTimeBlock(
89
+ attn_type=cfg.VIT.ATTN_LAYER,
90
+ dim=self.embed_dim,
91
+ num_heads=self.num_heads,
92
+ mlp_ratio=self.mlp_ratio,
93
+ qkv_bias=self.qkv_bias,
94
+ drop=self.drop_rate,
95
+ attn_drop=self.attn_drop_rate,
96
+ drop_path=dpr[i],
97
+ norm_layer=norm_layer,
98
+ ) for i in range(self.depth)
99
+ ])
100
+ else:
101
+ self.blocks = nn.ModuleList([
102
+ vit_helper.Block(attn_type=cfg.VIT.ATTN_LAYER,
103
+ dim=self.embed_dim,
104
+ num_heads=self.num_heads,
105
+ mlp_ratio=self.mlp_ratio,
106
+ qkv_bias=self.qkv_bias,
107
+ drop=self.drop_rate,
108
+ attn_drop=self.attn_drop_rate,
109
+ drop_path=dpr[i],
110
+ norm_layer=norm_layer,
111
+ use_original_code=self.cfg.VIT.USE_ORIGINAL_TRAJ_ATTN_CODE)
112
+ for i in range(self.depth)
113
+ ])
114
+ self.norm = norm_layer(self.embed_dim)
115
+
116
+ # MLP head
117
+ if self.use_mlp:
118
+ hidden_dim = self.embed_dim
119
+ if self.head_act == 'tanh':
120
+ # logging.info("Using TanH activation in MLP")
121
+ act = nn.Tanh()
122
+ elif self.head_act == 'gelu':
123
+ # logging.info("Using GELU activation in MLP")
124
+ act = nn.GELU()
125
+ else:
126
+ # logging.info("Using ReLU activation in MLP")
127
+ act = nn.ReLU()
128
+ self.pre_logits = nn.Sequential(
129
+ OrderedDict([
130
+ ('fc', nn.Linear(self.embed_dim, hidden_dim)),
131
+ ('act', act),
132
+ ]))
133
+ else:
134
+ self.pre_logits = nn.Identity()
135
+
136
+ # Classifier Head
137
+ self.head_drop = nn.Dropout(p=self.head_dropout)
138
+ if isinstance(self.num_classes, (list, )) and len(self.num_classes) > 1:
139
+ for a, i in enumerate(range(len(self.num_classes))):
140
+ setattr(self, "head%d" % a, nn.Linear(self.embed_dim, self.num_classes[i]))
141
+ else:
142
+ self.head = nn.Linear(self.embed_dim,
143
+ self.num_classes) if self.num_classes > 0 else nn.Identity()
144
+
145
+ # Initialize weights
146
+ self.apply(self._init_weights)
147
+
148
+ def _init_weights(self, m):
149
+ if isinstance(m, nn.Linear):
150
+ trunc_normal_(m.weight, std=.02)
151
+ if isinstance(m, nn.Linear) and m.bias is not None:
152
+ nn.init.constant_(m.bias, 0)
153
+ elif isinstance(m, nn.LayerNorm):
154
+ nn.init.constant_(m.bias, 0)
155
+ nn.init.constant_(m.weight, 1.0)
156
+
157
+ @torch.jit.ignore
158
+ def no_weight_decay(self):
159
+ if self.cfg.VIT.POS_EMBED == "joint":
160
+ return {'pos_embed', 'cls_token', 'st_embed'}
161
+ else:
162
+ return {'pos_embed', 'cls_token', 'temp_embed'}
163
+
164
+ def get_classifier(self):
165
+ return self.head
166
+
167
+ def reset_classifier(self, num_classes, global_pool=''):
168
+ self.num_classes = num_classes
169
+ self.head = (nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity())
170
+
171
+ def forward_features(self, x):
172
+ # if self.video_input:
173
+ # x = x[0]
174
+ B = x.shape[0]
175
+
176
+ # Tokenize input
177
+ # if self.cfg.VIT.PATCH_SIZE_TEMP > 1:
178
+ # for simplicity of mapping between content dimensions (input x) and token dims (after patching)
179
+ # we use the same trick as for AST (see modeling_ast.ASTModel.forward for the details):
180
+
181
+ # apply patching on input
182
+ x = self.patch_embed_3d(x)
183
+ tok_mask = None
184
+
185
+ # else:
186
+ # tok_mask = None
187
+ # # 2D tokenization
188
+ # if self.video_input:
189
+ # x = x.permute(0, 2, 1, 3, 4)
190
+ # (B, T, C, H, W) = x.shape
191
+ # x = x.reshape(B * T, C, H, W)
192
+
193
+ # x = self.patch_embed(x)
194
+
195
+ # if self.video_input:
196
+ # (B2, T2, D2) = x.shape
197
+ # x = x.reshape(B, T * T2, D2)
198
+
199
+ # Append CLS token
200
+ cls_tokens = self.cls_token.expand(B, -1, -1)
201
+ x = torch.cat((cls_tokens, x), dim=1)
202
+ # if tok_mask is not None:
203
+ # # prepend 1(=keep) to the mask to account for the CLS token as well
204
+ # tok_mask = torch.cat((torch.ones_like(tok_mask[:, [0]]), tok_mask), dim=1)
205
+
206
+ # Interpolate positinoal embeddings
207
+ # if self.cfg.DATA.TRAIN_CROP_SIZE != 224:
208
+ # pos_embed = self.pos_embed
209
+ # N = pos_embed.shape[1] - 1
210
+ # npatch = int((x.size(1) - 1) / self.temporal_resolution)
211
+ # class_emb = pos_embed[:, 0]
212
+ # pos_embed = pos_embed[:, 1:]
213
+ # dim = x.shape[-1]
214
+ # pos_embed = torch.nn.functional.interpolate(
215
+ # pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
216
+ # scale_factor=math.sqrt(npatch / N),
217
+ # mode='bicubic',
218
+ # )
219
+ # pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
220
+ # new_pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
221
+ # else:
222
+ new_pos_embed = self.pos_embed
223
+ npatch = self.patch_embed.num_patches
224
+
225
+ # Add positional embeddings to input
226
+ if self.video_input:
227
+ if self.cfg.VIT.POS_EMBED == "separate":
228
+ cls_embed = self.pos_embed[:, 0, :].unsqueeze(1)
229
+ tile_pos_embed = new_pos_embed[:, 1:, :].repeat(1, self.temporal_resolution, 1)
230
+ tile_temporal_embed = self.temp_embed.repeat_interleave(npatch, 1)
231
+ total_pos_embed = tile_pos_embed + tile_temporal_embed
232
+ total_pos_embed = torch.cat([cls_embed, total_pos_embed], dim=1)
233
+ x = x + total_pos_embed
234
+ elif self.cfg.VIT.POS_EMBED == "joint":
235
+ x = x + self.st_embed
236
+ else:
237
+ # image input
238
+ x = x + new_pos_embed
239
+
240
+ # Apply positional dropout
241
+ x = self.pos_drop(x)
242
+
243
+ # Encoding using transformer layers
244
+ for i, blk in enumerate(self.blocks):
245
+ x = blk(x,
246
+ seq_len=npatch,
247
+ num_frames=self.temporal_resolution,
248
+ approx=self.cfg.VIT.APPROX_ATTN_TYPE,
249
+ num_landmarks=self.cfg.VIT.APPROX_ATTN_DIM,
250
+ tok_mask=tok_mask)
251
+
252
+ ### v-iashin: I moved it to the forward pass
253
+ # x = self.norm(x)[:, 0]
254
+ # x = self.pre_logits(x)
255
+ ###
256
+ return x, tok_mask
257
+
258
+ # def forward(self, x):
259
+ # x = self.forward_features(x)
260
+ # ### v-iashin: here. This should leave the same forward output as before
261
+ # x = self.norm(x)[:, 0]
262
+ # x = self.pre_logits(x)
263
+ # ###
264
+ # x = self.head_drop(x)
265
+ # if isinstance(self.num_classes, (list, )) and len(self.num_classes) > 1:
266
+ # output = []
267
+ # for head in range(len(self.num_classes)):
268
+ # x_out = getattr(self, "head%d" % head)(x)
269
+ # if not self.training:
270
+ # x_out = torch.nn.functional.softmax(x_out, dim=-1)
271
+ # output.append(x_out)
272
+ # return output
273
+ # else:
274
+ # x = self.head(x)
275
+ # if not self.training:
276
+ # x = torch.nn.functional.softmax(x, dim=-1)
277
+ # return x
data_utils/ext/synchformer/vit_helper.py ADDED
@@ -0,0 +1,399 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3
+ # Copyright 2020 Ross Wightman
4
+ # Modified Model definition
5
+ """Video models."""
6
+
7
+ import math
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from einops import rearrange, repeat
12
+ from timm.layers import to_2tuple
13
+ from torch import einsum
14
+ from torch.nn import functional as F
15
+
16
+ default_cfgs = {
17
+ 'vit_1k':
18
+ 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
19
+ 'vit_1k_large':
20
+ 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
21
+ }
22
+
23
+
24
+ def qkv_attn(q, k, v, tok_mask: torch.Tensor = None):
25
+ sim = einsum('b i d, b j d -> b i j', q, k)
26
+ # apply masking if provided, tok_mask is (B*S*H, N): 1s - keep; sim is (B*S*H, H, N, N)
27
+ if tok_mask is not None:
28
+ BSH, N = tok_mask.shape
29
+ sim = sim.masked_fill(tok_mask.view(BSH, 1, N) == 0,
30
+ float('-inf')) # 1 - broadcasts across N
31
+ attn = sim.softmax(dim=-1)
32
+ out = einsum('b i j, b j d -> b i d', attn, v)
33
+ return out
34
+
35
+
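+ # Minimal usage sketch for qkv_attn (shapes below are illustrative assumptions, not
+ # taken from the training pipeline): q/k/v are (B*, N, d) and tok_mask is a (B*, N)
+ # keep-mask where 0 marks keys that should receive no attention.
+ # q = k = v = torch.randn(2, 5, 8)
+ # keep = torch.ones(2, 5); keep[:, -1] = 0   # drop the last token as a key
+ # out = qkv_attn(q, k, v, tok_mask=keep)     # -> (2, 5, 8); masked keys get zero weight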
36
+ class DividedAttention(nn.Module):
37
+
38
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
39
+ super().__init__()
40
+ self.num_heads = num_heads
41
+ head_dim = dim // num_heads
42
+ self.scale = head_dim**-0.5
43
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
44
+ self.proj = nn.Linear(dim, dim)
45
+
46
+ # init qkv to zeros; proj starts with weight 1 and bias 0
47
+ self.qkv.weight.data.fill_(0)
48
+ self.qkv.bias.data.fill_(0)
49
+ self.proj.weight.data.fill_(1)
50
+ self.proj.bias.data.fill_(0)
51
+
52
+ self.attn_drop = nn.Dropout(attn_drop)
53
+ self.proj_drop = nn.Dropout(proj_drop)
54
+
55
+ def forward(self, x, einops_from, einops_to, tok_mask: torch.Tensor = None, **einops_dims):
56
+ # num of heads variable
57
+ h = self.num_heads
58
+
59
+ # project x to q, k, v values
60
+ q, k, v = self.qkv(x).chunk(3, dim=-1)
61
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
62
+ if tok_mask is not None:
63
+ # replicate token mask across heads (b, n) -> (b, h, n) -> (b*h, n) -- same as qkv but w/o d
64
+ assert len(tok_mask.shape) == 2
65
+ tok_mask = tok_mask.unsqueeze(1).expand(-1, h, -1).reshape(-1, tok_mask.shape[1])
66
+
67
+ # Scale q
68
+ q *= self.scale
69
+
70
+ # Take out cls_q, cls_k, cls_v
71
+ (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, 0:1], t[:, 1:]), (q, k, v))
72
+ # the same for masking
73
+ if tok_mask is not None:
74
+ cls_mask, mask_ = tok_mask[:, 0:1], tok_mask[:, 1:]
75
+ else:
76
+ cls_mask, mask_ = None, None
77
+
78
+ # let CLS token attend to key / values of all patches across time and space
79
+ cls_out = qkv_attn(cls_q, k, v, tok_mask=tok_mask)
80
+
81
+ # rearrange across time or space
82
+ q_, k_, v_ = map(lambda t: rearrange(t, f'{einops_from} -> {einops_to}', **einops_dims),
83
+ (q_, k_, v_))
84
+
85
+ # expand CLS token keys and values across time or space and concat
86
+ r = q_.shape[0] // cls_k.shape[0]
87
+ cls_k, cls_v = map(lambda t: repeat(t, 'b () d -> (b r) () d', r=r), (cls_k, cls_v))
88
+
89
+ k_ = torch.cat((cls_k, k_), dim=1)
90
+ v_ = torch.cat((cls_v, v_), dim=1)
91
+
92
+ # the same for masking (if provided)
93
+ if tok_mask is not None:
94
+ # since mask does not have the latent dim (d), we need to remove it from einops dims
95
+ mask_ = rearrange(mask_, f'{einops_from} -> {einops_to}'.replace(' d', ''),
96
+ **einops_dims)
97
+ cls_mask = repeat(cls_mask, 'b () -> (b r) ()',
98
+ r=r) # expand cls_mask across time or space
99
+ mask_ = torch.cat((cls_mask, mask_), dim=1)
100
+
101
+ # attention
102
+ out = qkv_attn(q_, k_, v_, tok_mask=mask_)
103
+
104
+ # merge back time or space
105
+ out = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims)
106
+
107
+ # concat back the cls token
108
+ out = torch.cat((cls_out, out), dim=1)
109
+
110
+ # merge back the heads
111
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
112
+
113
+ ## to out
114
+ x = self.proj(out)
115
+ x = self.proj_drop(x)
116
+ return x
117
+
118
+
119
+ class DividedSpaceTimeBlock(nn.Module):
120
+
121
+ def __init__(self,
122
+ dim=768,
123
+ num_heads=12,
124
+ attn_type='divided',
125
+ mlp_ratio=4.,
126
+ qkv_bias=False,
127
+ drop=0.,
128
+ attn_drop=0.,
129
+ drop_path=0.,
130
+ act_layer=nn.GELU,
131
+ norm_layer=nn.LayerNorm):
132
+ super().__init__()
133
+
134
+ self.einops_from_space = 'b (f n) d'
135
+ self.einops_to_space = '(b f) n d'
136
+ self.einops_from_time = 'b (f n) d'
137
+ self.einops_to_time = '(b n) f d'
138
+
139
+ self.norm1 = norm_layer(dim)
140
+
141
+ self.attn = DividedAttention(dim,
142
+ num_heads=num_heads,
143
+ qkv_bias=qkv_bias,
144
+ attn_drop=attn_drop,
145
+ proj_drop=drop)
146
+
147
+ self.timeattn = DividedAttention(dim,
148
+ num_heads=num_heads,
149
+ qkv_bias=qkv_bias,
150
+ attn_drop=attn_drop,
151
+ proj_drop=drop)
152
+
153
+ # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
154
+ self.drop_path = nn.Identity()
155
+ self.norm2 = norm_layer(dim)
156
+ mlp_hidden_dim = int(dim * mlp_ratio)
157
+ self.mlp = Mlp(in_features=dim,
158
+ hidden_features=mlp_hidden_dim,
159
+ act_layer=act_layer,
160
+ drop=drop)
161
+ self.norm3 = norm_layer(dim)
162
+
163
+ def forward(self,
164
+ x,
165
+ seq_len=196,
166
+ num_frames=8,
167
+ approx='none',
168
+ num_landmarks=128,
169
+ tok_mask: torch.Tensor = None):
170
+ time_output = self.timeattn(self.norm3(x),
171
+ self.einops_from_time,
172
+ self.einops_to_time,
173
+ n=seq_len,
174
+ tok_mask=tok_mask)
175
+ time_residual = x + time_output
176
+
177
+ space_output = self.attn(self.norm1(time_residual),
178
+ self.einops_from_space,
179
+ self.einops_to_space,
180
+ f=num_frames,
181
+ tok_mask=tok_mask)
182
+ space_residual = time_residual + self.drop_path(space_output)
183
+
184
+ x = space_residual
185
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
186
+ return x
187
+
188
+
189
+ class Mlp(nn.Module):
190
+
191
+ def __init__(self,
192
+ in_features,
193
+ hidden_features=None,
194
+ out_features=None,
195
+ act_layer=nn.GELU,
196
+ drop=0.):
197
+ super().__init__()
198
+ out_features = out_features or in_features
199
+ hidden_features = hidden_features or in_features
200
+ self.fc1 = nn.Linear(in_features, hidden_features)
201
+ self.act = act_layer()
202
+ self.fc2 = nn.Linear(hidden_features, out_features)
203
+ self.drop = nn.Dropout(drop)
204
+
205
+ def forward(self, x):
206
+ x = self.fc1(x)
207
+ x = self.act(x)
208
+ x = self.drop(x)
209
+ x = self.fc2(x)
210
+ x = self.drop(x)
211
+ return x
212
+
213
+
214
+ class PatchEmbed(nn.Module):
215
+ """ Image to Patch Embedding
216
+ """
217
+
218
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
219
+ super().__init__()
220
+ img_size = img_size if type(img_size) is tuple else to_2tuple(img_size)
221
+ patch_size = patch_size if type(patch_size) is tuple else to_2tuple(patch_size)
222
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
223
+ self.img_size = img_size
224
+ self.patch_size = patch_size
225
+ self.num_patches = num_patches
226
+
227
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
228
+
229
+ def forward(self, x):
230
+ B, C, H, W = x.shape
231
+ x = self.proj(x).flatten(2).transpose(1, 2)
232
+ return x
233
+
234
+
235
+ class PatchEmbed3D(nn.Module):
236
+ """ Image to Patch Embedding """
237
+
238
+ def __init__(self,
239
+ img_size=224,
240
+ temporal_resolution=4,
241
+ in_chans=3,
242
+ patch_size=16,
243
+ z_block_size=2,
244
+ embed_dim=768,
245
+ flatten=True):
246
+ super().__init__()
247
+ self.height = (img_size // patch_size)
248
+ self.width = (img_size // patch_size)
249
+ ### v-iashin: these two are incorrect
250
+ # self.frames = (temporal_resolution // z_block_size)
251
+ # self.num_patches = self.height * self.width * self.frames
252
+ self.z_block_size = z_block_size
253
+ ###
254
+ self.proj = nn.Conv3d(in_chans,
255
+ embed_dim,
256
+ kernel_size=(z_block_size, patch_size, patch_size),
257
+ stride=(z_block_size, patch_size, patch_size))
258
+ self.flatten = flatten
259
+
260
+ def forward(self, x):
261
+ B, C, T, H, W = x.shape
262
+ x = self.proj(x)
263
+ if self.flatten:
264
+ x = x.flatten(2).transpose(1, 2)
265
+ return x
266
+
267
+
268
+ class HeadMLP(nn.Module):
269
+
270
+ def __init__(self, n_input, n_classes, n_hidden=512, p=0.1):
271
+ super(HeadMLP, self).__init__()
272
+ self.n_input = n_input
273
+ self.n_classes = n_classes
274
+ self.n_hidden = n_hidden
275
+ if n_hidden is None:
276
+ # use linear classifier
277
+ self.block_forward = nn.Sequential(nn.Dropout(p=p),
278
+ nn.Linear(n_input, n_classes, bias=True))
279
+ else:
280
+ # use simple MLP classifier
281
+ self.block_forward = nn.Sequential(nn.Dropout(p=p),
282
+ nn.Linear(n_input, n_hidden, bias=True),
283
+ nn.BatchNorm1d(n_hidden), nn.ReLU(inplace=True),
284
+ nn.Dropout(p=p),
285
+ nn.Linear(n_hidden, n_classes, bias=True))
286
+ print(f"Dropout-NLP: {p}")
287
+
288
+ def forward(self, x):
289
+ return self.block_forward(x)
290
+
291
+
292
+ def _conv_filter(state_dict, patch_size=16):
293
+ """ convert patch embedding weight from manual patchify + linear proj to conv"""
294
+ out_dict = {}
295
+ for k, v in state_dict.items():
296
+ if 'patch_embed.proj.weight' in k:
297
+ v = v.reshape((v.shape[0], 3, patch_size, patch_size))
298
+ out_dict[k] = v
299
+ return out_dict
300
+
301
+
302
+ def adapt_input_conv(in_chans, conv_weight, agg='sum'):
303
+ conv_type = conv_weight.dtype
304
+ conv_weight = conv_weight.float()
305
+ O, I, J, K = conv_weight.shape
306
+ if in_chans == 1:
307
+ if I > 3:
308
+ assert conv_weight.shape[1] % 3 == 0
309
+ # For models with space2depth stems
310
+ conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)
311
+ conv_weight = conv_weight.sum(dim=2, keepdim=False)
312
+ else:
313
+ if agg == 'sum':
314
+ print("Summing conv1 weights")
315
+ conv_weight = conv_weight.sum(dim=1, keepdim=True)
316
+ else:
317
+ print("Averaging conv1 weights")
318
+ conv_weight = conv_weight.mean(dim=1, keepdim=True)
319
+ elif in_chans != 3:
320
+ if I != 3:
321
+ raise NotImplementedError('Weight format not supported by conversion.')
322
+ else:
323
+ if agg == 'sum':
324
+ print("Summing conv1 weights")
325
+ repeat = int(math.ceil(in_chans / 3))
326
+ conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
327
+ conv_weight *= (3 / float(in_chans))
328
+ else:
329
+ print("Averaging conv1 weights")
330
+ conv_weight = conv_weight.mean(dim=1, keepdim=True)
331
+ conv_weight = conv_weight.repeat(1, in_chans, 1, 1)
332
+ conv_weight = conv_weight.to(conv_type)
333
+ return conv_weight
334
+
335
+
336
+ def load_pretrained(model,
337
+ cfg=None,
338
+ num_classes=1000,
339
+ in_chans=3,
340
+ filter_fn=None,
341
+ strict=True,
342
+ progress=False):
343
+ # Load state dict
344
+ assert cfg.VIT.PRETRAINED_WEIGHTS in default_cfgs, f"{cfg.VIT.PRETRAINED_WEIGHTS} not in [vit_1k, vit_1k_large]"
345
+ state_dict = torch.hub.load_state_dict_from_url(url=default_cfgs[cfg.VIT.PRETRAINED_WEIGHTS])
346
+
347
+ if filter_fn is not None:
348
+ state_dict = filter_fn(state_dict)
349
+
350
+ input_convs = 'patch_embed.proj'
351
+ if input_convs is not None and in_chans != 3:
352
+ if isinstance(input_convs, str):
353
+ input_convs = (input_convs, )
354
+ for input_conv_name in input_convs:
355
+ weight_name = input_conv_name + '.weight'
356
+ try:
357
+ state_dict[weight_name] = adapt_input_conv(in_chans,
358
+ state_dict[weight_name],
359
+ agg='avg')
360
+ print(
361
+ f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)'
362
+ )
363
+ except NotImplementedError as e:
364
+ del state_dict[weight_name]
365
+ strict = False
366
+ print(
367
+ f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.'
368
+ )
369
+
370
+ classifier_name = 'head'
371
+ label_offset = cfg.get('label_offset', 0)
372
+ pretrain_classes = 1000
373
+ if num_classes != pretrain_classes:
374
+ # completely discard fully connected if model num_classes doesn't match pretrained weights
375
+ del state_dict[classifier_name + '.weight']
376
+ del state_dict[classifier_name + '.bias']
377
+ strict = False
378
+ elif label_offset > 0:
379
+ # special case for pretrained weights with an extra background class in pretrained weights
380
+ classifier_weight = state_dict[classifier_name + '.weight']
381
+ state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
382
+ classifier_bias = state_dict[classifier_name + '.bias']
383
+ state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
384
+
385
+ loaded_state = state_dict
386
+ self_state = model.state_dict()
387
+ all_names = set(self_state.keys())
388
+ saved_names = set([])
389
+ for name, param in loaded_state.items():
390
+ param = param
391
+ if 'module.' in name:
392
+ name = name.replace('module.', '')
393
+ if name in self_state.keys() and param.shape == self_state[name].shape:
394
+ saved_names.add(name)
395
+ self_state[name].copy_(param)
396
+ else:
397
+ print(f"didn't load: {name} of shape: {param.shape}")
398
+ print("Missing Keys:")
399
+ print(all_names - saved_names)
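+ # Hedged invocation sketch for load_pretrained; the cfg layout below is an assumption
+ # (any mapping-style config exposing VIT.PRETRAINED_WEIGHTS and .get() works, a
+ # yacs CfgNode is shown as one possibility, and `model` is a matching ViT module):
+ # from yacs.config import CfgNode
+ # cfg = CfgNode({'VIT': CfgNode({'PRETRAINED_WEIGHTS': 'vit_1k'})})
+ # load_pretrained(model, cfg=cfg, num_classes=0, in_chans=3,
+ #                 filter_fn=_conv_filter, strict=False)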
data_utils/v2a_utils/__init__.py ADDED
File without changes
data_utils/v2a_utils/audio_text_dataset.py ADDED
@@ -0,0 +1,173 @@
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Optional, Union
4
+ from PIL import Image
5
+
6
+ import pandas as pd
7
+ import torch
8
+ import torchaudio
9
+ from torch.utils.data.dataset import Dataset
10
+ from torchvision.transforms import v2
11
+ from torio.io import StreamingMediaDecoder
12
+ from torchvision.utils import save_image
13
+ from transformers import AutoProcessor
14
+ import torch.nn.functional as F
15
+ import numpy as np
16
+
17
+ import logging
18
+ log = logging.getLogger()
19
+
20
+ _CLIP_SIZE = 224
21
+ _CLIP_FPS = 8.0
22
+
23
+ _SYNC_SIZE = 224
24
+ _SYNC_FPS = 25.0
25
+
26
+
27
+ class Audio_Text(Dataset):
28
+
29
+ def __init__(
30
+ self,
31
+ root: Union[str, Path],
32
+ *,
33
+ tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv',
34
+ sample_rate: int = 44_100,
35
+ duration_sec: float = 9.0,
36
+ audio_samples: Optional[int] = 397312,
37
+ normalize_audio: bool = False,
38
+ start_row: Optional[int] = None,
39
+ end_row: Optional[int] = None,
40
+ save_dir: str = 'data/vggsound/video_latents_text/train'
41
+ ):
42
+ self.root = Path(root)
43
+ self.normalize_audio = normalize_audio
44
+ if audio_samples is None:
45
+ self.audio_samples = int(sample_rate * duration_sec)
46
+ else:
47
+ self.audio_samples = audio_samples
48
+ effective_duration = audio_samples / sample_rate
49
+ # make sure the duration is close enough, within 15ms
50
+ assert abs(effective_duration - duration_sec) < 0.015, \
51
+ f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
52
+
53
+ # videos = sorted(os.listdir(self.root))
54
+ # videos = set([Path(v).stem for v in videos]) # remove extensions
55
+ videos = []
56
+ self.labels = []
57
+ self.videos = []
58
+ self.cots = []
59
+ missing_videos = []
60
+ # read the tsv for subset information
61
+ df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records')
62
+
63
+ # Restrict the range of rows to process
64
+ if start_row is not None and end_row is not None:
65
+ df_list = df_list[start_row:end_row]
66
+ for record in df_list:
67
+ id = record['id']
68
+ if os.path.exists(f'{save_dir}/{id}.pth'): continue
69
+ label = record['caption']
70
+ # if id in videos:
71
+ self.labels.append(label)
72
+ # print(label,'debug1!!!!!!!!!')
73
+ self.cots.append(record['caption_cot'])
74
+ # self.labels[id] = label
75
+ self.videos.append(id)
76
+ # else:
77
+ # missing_videos.append(id)
78
+
79
+ log.info(f'{len(videos)} videos found in {root}')
80
+ log.info(f'{len(self.videos)} videos found in {tsv_path}')
81
+ log.info(f'{len(missing_videos)} videos missing in {root}')
82
+
83
+ self.sample_rate = sample_rate
84
+ self.duration_sec = duration_sec
85
+
86
+ self.expected_audio_length = self.audio_samples
87
+ self.resampler = {}
88
+
89
+ def sample(self, idx: int):
90
+ video_id = self.videos[idx]
91
+ label = self.labels[idx]
92
+ cot = self.cots[idx]
93
+ audio_path = os.path.join(self.root, f'{video_id}.wav')
94
+ if not os.path.exists(audio_path):
95
+ audio_path = os.path.join(self.root, f'{video_id}.flac')
96
+ if not os.path.exists(audio_path):
97
+ raise RuntimeError(f'Audio does not exist: {audio_path}')
98
+ audio_chunk, sample_rate = torchaudio.load(audio_path)
99
+ if len(audio_chunk.shape) != 2:
100
+ raise RuntimeError(f'error audio shape {video_id}')
101
+
102
+ abs_max = audio_chunk[0].abs().max()
103
+
104
+ if abs_max <= 1e-6:
105
+ if audio_chunk.shape[0] > 1 and audio_chunk[1].abs().max() > 1e-6:
106
+ audio_chunk = audio_chunk[1:2]
107
+ else:
108
+ raise RuntimeError(f'Audio is silent {video_id}')
109
+
110
+ # ensure the stereo audio
111
+ if audio_chunk.shape[0] < 2:
112
+ audio_chunk = audio_chunk.repeat(2, 1)
113
+ elif audio_chunk.shape[0] > 2:
114
+ audio_chunk = audio_chunk[:2]
115
+
116
+ # resample
117
+ if sample_rate == self.sample_rate:
118
+ audio_chunk = audio_chunk
119
+ else:
120
+ if sample_rate not in self.resampler:
121
+ # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
122
+ self.resampler[sample_rate] = torchaudio.transforms.Resample(
123
+ sample_rate,
124
+ self.sample_rate,
125
+ lowpass_filter_width=64,
126
+ rolloff=0.9475937167399596,
127
+ resampling_method='sinc_interp_kaiser',
128
+ beta=14.769656459379492,
129
+ )
130
+ audio_chunk = self.resampler[sample_rate](audio_chunk)
131
+
132
+ if audio_chunk.shape[1] < self.expected_audio_length:
133
+ # zero-padding audio
134
+ padding_length = self.expected_audio_length - audio_chunk.shape[1]
135
+ # Create a padding tensor of shape [batch_size, padding_length] filled with zeros
136
+ padding = torch.zeros(audio_chunk.shape[0], padding_length)
137
+ # Concatenate the original audio and the padding along dimension 1
138
+ audio_chunk = torch.cat((audio_chunk, padding), dim=1)
139
+ # raise RuntimeError(f'Audio too short {video_id}')
140
+ audio_chunk = audio_chunk[:,:self.expected_audio_length]
141
+ assert audio_chunk.shape == (2, 397312), f'error shape:{video_id},{audio_chunk.shape}'
142
+ # print(label,'debug2!!!!!!!!!')
143
+ data = {
144
+ 'id': video_id,
145
+ 'caption': label,
146
+ 'caption_cot': cot,
147
+ 'audio': audio_chunk,
148
+ }
149
+
150
+ return data
151
+
152
+ def __getitem__(self, idx: int):
153
+ try:
154
+ return self.sample(idx)
155
+ except Exception as e:
156
+ log.error(f'Error loading video {self.videos[idx]}: {e}')
157
+ return None
158
+
159
+ def __len__(self):
160
+ return len(self.labels)
161
+
162
+
163
+ # dataset = Audio_Text(
164
+ # root="data/vggsound/video/train",
165
+ # tsv_path="data/vggsound/split_txt/temp.csv",
166
+ # sample_rate=44100,
167
+ # duration_sec=9.0,
168
+ # audio_samples=397312,
169
+ # start_row=0,
170
+ # end_row=None,
171
+ # save_dir="data/vggsound/video_224_latents_text/train"
172
+ # )
173
+ # dataset[0]
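+ # Hedged sketch: __getitem__ returns None when a sample fails to load, so a DataLoader
+ # over Audio_Text needs a collate_fn that drops Nones (the helper name below is ours,
+ # not part of this repo):
+ # from torch.utils.data import DataLoader
+ # from torch.utils.data.dataloader import default_collate
+ # def skip_none_collate(batch):
+ #     batch = [b for b in batch if b is not None]
+ #     return default_collate(batch) if batch else None
+ # loader = DataLoader(dataset, batch_size=4, collate_fn=skip_none_collate)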
data_utils/v2a_utils/audioset_224.py ADDED
@@ -0,0 +1,315 @@
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Optional, Union
4
+ from PIL import Image
5
+
6
+ import pandas as pd
7
+ import torch
8
+ import torchaudio
9
+ from torch.utils.data.dataset import Dataset
10
+ from torchvision.transforms import v2
11
+ from torio.io import StreamingMediaDecoder
12
+ from torchvision.utils import save_image
13
+ from transformers import AutoProcessor
14
+ import torch.nn.functional as F
15
+ import numpy as np
16
+
17
+ import logging
18
+ log = logging.getLogger()
19
+
20
+ _CLIP_SIZE = 224
21
+ _CLIP_FPS = 8.0
22
+
23
+ _SYNC_SIZE = 224
24
+ _SYNC_FPS = 25.0
25
+
26
+ def save_tensor_as_image(tensor, save_path):
+ """
+ Save an RGB image array of shape (1, 3, H, W) to an image file.
+
+ :param tensor: input NumPy array of shape (1, 3, H, W).
+ :param save_path: path where the image is saved.
+ """
+ # # Remove the batch dimension, giving (3, H, W)
+ # tensor = tensor.squeeze(0)
+
+ # Reorder the axes to (H, W, 3)
+ image_array = np.transpose(tensor, (1, 2, 0))
+
+ # Check whether the array has a suitable dtype
+ if image_array.dtype != np.uint8:
+ # If it is not uint8, normalize first, then convert
+ image_array = (image_array - image_array.min()) / (image_array.max() - image_array.min()) * 255
+ image_array = image_array.astype(np.uint8)
+
+ # Create the image object
+ image = Image.fromarray(image_array)
+
+ # Save the image
+ image.save(save_path)
+ print(f"Image saved to {save_path}")
+
+ def pad_to_square(video_tensor):
+ # Validate the input shape
+ if len(video_tensor.shape) != 4:
+ raise ValueError("Input tensor must have shape (l, c, h, w)")
+
+ l, c, h, w = video_tensor.shape
+ max_side = max(h, w)
+
+ # Compute the padding needed along each dimension: (left, right, top, bottom)
+ pad_h = max_side - h
+ pad_w = max_side - w
+
+ # Build the padding tuple (left, right, top, bottom)
+ # Image padding applies to the last two dimensions (h and w), so only those two need to be specified
+ padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
+
+ # Pad the video tensor with F.pad
+ # The padding arguments are (left, right, top, bottom)
+ video_padded = F.pad(video_tensor, pad=padding, mode='constant', value=0)
+
+ return video_padded
73
+
74
+ class Audioset(Dataset):
75
+
76
+ def __init__(
77
+ self,
78
+ root: Union[str, Path],
79
+ *,
80
+ tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv',
81
+ sample_rate: int = 44_100,
82
+ duration_sec: float = 9.0,
83
+ audio_samples: Optional[int] = 397312,
84
+ normalize_audio: bool = False,
85
+ start_row: Optional[int] = None,
86
+ end_row: Optional[int] = None,
87
+ save_dir: str = 'data/vggsound/video_latents_text/train'
88
+ ):
89
+ self.root = Path(root)
90
+ self.normalize_audio = normalize_audio
91
+ if audio_samples is None:
92
+ self.audio_samples = int(sample_rate * duration_sec)
93
+ else:
94
+ self.audio_samples = audio_samples
95
+ effective_duration = audio_samples / sample_rate
96
+ # make sure the duration is close enough, within 15ms
97
+ assert abs(effective_duration - duration_sec) < 0.015, \
98
+ f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
99
+
100
+ # videos = sorted(os.listdir(self.root))
101
+ # videos = set([Path(v).stem for v in videos]) # remove extensions
102
+ videos = []
103
+ self.labels = []
104
+ self.videos = []
105
+ self.caption_t5s = []
106
+ missing_videos = []
107
+ # read the tsv for subset information
108
+ df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records')
109
+
110
+ # Restrict the range of rows to process
111
+ if start_row is not None and end_row is not None:
112
+ df_list = df_list[start_row:end_row]
113
+
114
+ for record in df_list:
115
+ id = record['id']
116
+ if os.path.exists(f'{save_dir}/{id}.pth'): continue
117
+ label = record['label']
118
+ caption_t5 = record['caption_t5']
119
+ # if id in videos:
120
+ self.labels.append(label)
121
+ # self.labels[id] = label
122
+ self.videos.append(id)
123
+ self.caption_t5s.append(caption_t5)
124
+ # else:
125
+ # missing_videos.append(id)
126
+
127
+ log.info(f'{len(videos)} videos found in {root}')
128
+ log.info(f'{len(self.videos)} videos found in {tsv_path}')
129
+ log.info(f'{len(missing_videos)} videos missing in {root}')
130
+
131
+ self.sample_rate = sample_rate
132
+ self.duration_sec = duration_sec
133
+
134
+ self.expected_audio_length = self.audio_samples
135
+ self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
136
+ self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
137
+
138
+ self.clip_transform = v2.Compose([
139
+ v2.Lambda(pad_to_square), # pad to a square first
140
+ v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
141
+ v2.ToImage(),
142
+ v2.ToDtype(torch.float32, scale=True),
143
+ ])
144
+ self.clip_processor = AutoProcessor.from_pretrained("useful_ckpts/metaclip-huge")
145
+ self.sync_transform = v2.Compose([
146
+ v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
147
+ v2.CenterCrop(_SYNC_SIZE),
148
+ v2.ToImage(),
149
+ v2.ToDtype(torch.float32, scale=True),
150
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
151
+ ])
152
+
153
+ self.resampler = {}
154
+
155
+ def sample(self, idx: int) -> dict[str, torch.Tensor]:
156
+ video_id = self.videos[idx]
157
+ label = self.labels[idx]
158
+ caption_t5 = self.caption_t5s[idx]
159
+
160
+ reader = StreamingMediaDecoder(self.root / (video_id + '.mp4'))
161
+ reader.add_basic_video_stream(
162
+ frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
163
+ frame_rate=_CLIP_FPS,
164
+ format='rgb24',
165
+ )
166
+ reader.add_basic_video_stream(
167
+ frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
168
+ frame_rate=_SYNC_FPS,
169
+ format='rgb24',
170
+ )
171
+ # reader.add_basic_audio_stream(frames_per_chunk=2**30,)
172
+
173
+ reader.fill_buffer()
174
+ data_chunk = reader.pop_chunks()
175
+
176
+ clip_chunk = data_chunk[0]
177
+ sync_chunk = data_chunk[1]
178
+ audio_path = os.path.join("dataset/3_Audioset/audios/sound",video_id+'.wav')
179
+ assert os.path.exists(audio_path), f'{audio_path} not exists'
180
+ audio_chunk, sr = torchaudio.load(audio_path)
181
+ # audio_chunk = data_chunk[2]
182
+ if len(audio_chunk.shape) != 2:
183
+ raise RuntimeError(f'error audio shape {video_id}')
184
+ if clip_chunk is None:
185
+ raise RuntimeError(f'CLIP video returned None {video_id}')
186
+
187
+ if sync_chunk is None:
188
+ raise RuntimeError(f'Sync video returned None {video_id}')
189
+ sample_rate = int(sr)
190
+ # audio_chunk = audio_chunk.transpose(0, 1)
191
+ abs_max = audio_chunk[0].abs().max()
192
+ # audio_chunk = audio_chunk.mean(dim=0) # mono
193
+ # if self.normalize_audio:
194
+ # abs_max = audio_chunk.abs().max()
195
+ # audio_chunk = audio_chunk / abs_max * 0.95
196
+ if abs_max <= 1e-6:
197
+ if audio_chunk.shape[0] > 1 and audio_chunk[1].abs().max() > 1e-6:
198
+ audio_chunk = audio_chunk[1:2]
199
+ else:
200
+ raise RuntimeError(f'Audio is silent {video_id}')
201
+
202
+ # ensure the stereo audio
203
+ if audio_chunk.shape[0] < 2:
204
+ audio_chunk = audio_chunk.repeat(2, 1)
205
+
206
+ # resample
207
+ if sample_rate == self.sample_rate:
208
+ audio_chunk = audio_chunk
209
+ else:
210
+ if sample_rate not in self.resampler:
211
+ # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
212
+ self.resampler[sample_rate] = torchaudio.transforms.Resample(
213
+ sample_rate,
214
+ self.sample_rate,
215
+ lowpass_filter_width=64,
216
+ rolloff=0.9475937167399596,
217
+ resampling_method='sinc_interp_kaiser',
218
+ beta=14.769656459379492,
219
+ )
220
+ audio_chunk = self.resampler[sample_rate](audio_chunk)
221
+
222
+ if audio_chunk.shape[1] < self.expected_audio_length:
223
+ # zero-padding audio
224
+ padding_length = self.expected_audio_length - audio_chunk.shape[1]
225
+ # Create a padding tensor of shape [batch_size, padding_length] filled with zeros
226
+ padding = torch.zeros(audio_chunk.shape[0], padding_length)
227
+ # Concatenate the original audio and the padding along dimension 1
228
+ audio_chunk = torch.cat((audio_chunk, padding), dim=1)
229
+ # raise RuntimeError(f'Audio too short {video_id}')
230
+ audio_chunk = audio_chunk[:,:self.expected_audio_length]
231
+ # truncate the video
232
+ clip_chunk = clip_chunk[:self.clip_expected_length]
233
+ # import ipdb
234
+ # ipdb.set_trace()
235
+ if clip_chunk.shape[0] != self.clip_expected_length:
236
+ current_length = clip_chunk.shape[0]
237
+ padding_needed = self.clip_expected_length - current_length
238
+
239
+ # Check that padding needed is no more than 2
240
+ assert padding_needed < 4, f'Padding no more than 2 frames allowed, but {padding_needed} needed'
241
+
242
+ # If assertion passes, proceed with padding
243
+ if padding_needed > 0:
244
+ last_frame = clip_chunk[-1]
245
+ log.info(last_frame.shape)
246
+ # Repeat the last frame to reach the expected length
247
+ padding = last_frame.repeat(padding_needed, 1, 1, 1)
248
+ clip_chunk = torch.cat((clip_chunk, padding), dim=0)
249
+ # raise RuntimeError(f'CLIP video wrong length {video_id}, '
250
+ # f'expected {self.clip_expected_length}, '
251
+ # f'got {clip_chunk.shape[0]}')
252
+
253
+ # save_image(clip_chunk[0] / 255.0,'ori.png')
254
+ clip_chunk = pad_to_square(clip_chunk)
255
+ # save_image(clip_chunk[0] / 255.0,'square.png')
256
+ # clip_chunk = self.clip_transform(clip_chunk)
257
+ # import ipdb
258
+ # ipdb.set_trace()
259
+ clip_chunk = self.clip_processor(images=clip_chunk, return_tensors="pt")["pixel_values"]
260
+ # log.info(clip_chunk.shape)
261
+ # save_tensor_as_image(clip_chunk[0].numpy(),'scale.png')
262
+ # log.info(clip_chunk[0])
263
+ # clip_chunk = outputs
264
+ # text_ids = outputs["input_ids"]
265
+ # temp_img = clip_chunk[0].permute(1, 2, 0) * 255
266
+ # save_image(clip_chunk[0],'scale.png')
267
+ sync_chunk = sync_chunk[:self.sync_expected_length]
268
+ if sync_chunk.shape[0] != self.sync_expected_length:
269
+ # padding using the last frame, but no more than 2
270
+ current_length = sync_chunk.shape[0]
271
+ last_frame = sync_chunk[-1]
272
+ # Repeat the last frame for padding
273
+ padding = last_frame.repeat(self.sync_expected_length - current_length, 1, 1, 1)
274
+ assert self.sync_expected_length - current_length < 12, f'sync can pad no more than 2 while {self.sync_expected_length - current_length}'
275
+ sync_chunk = torch.cat((sync_chunk, padding), dim=0)
276
+ # raise RuntimeError(f'Sync video wrong length {video_id}, '
277
+ # f'expected {self.sync_expected_length}, '
278
+ # f'got {sync_chunk.shape[0]}')
279
+
280
+ sync_chunk = self.sync_transform(sync_chunk)
281
+ assert audio_chunk.shape[1] == self.expected_audio_length and clip_chunk.shape[0] == self.clip_expected_length \
282
+ and sync_chunk.shape[0] == self.sync_expected_length, 'error processed data shape'
283
+ data = {
284
+ 'id': video_id,
285
+ 'caption': label,
286
+ 'caption_t5': caption_t5,
287
+ 'audio': audio_chunk,
288
+ 'clip_video': clip_chunk,
289
+ 'sync_video': sync_chunk,
290
+ }
291
+
292
+ return data
293
+
294
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
295
+ try:
296
+ return self.sample(idx)
297
+ except Exception as e:
298
+ log.error(f'Error loading video {self.videos[idx]}: {e}')
299
+ return None
300
+
301
+ def __len__(self):
302
+ return len(self.labels)
303
+
304
+
305
+ # dataset = Audioset(
306
+ # root="dataset/3_Audioset/video/sound",
307
+ # tsv_path="dataset/3_Audioset/split_txt/unbalanced_sound_filtered_aligned_novgg_noout.csv",
308
+ # sample_rate=44100,
309
+ # duration_sec=9.0,
310
+ # audio_samples=397312,
311
+ # start_row=0,
312
+ # end_row=None,
313
+ # save_dir="dataset/3_Audioset/video_text_latents/"
314
+ # )
315
+ # dataset[0]
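+ # Hedged example of the pad_to_square step on a (L, C, H, W) clip: a 240x320 frame
+ # chunk is zero-padded to 320x320 (40 rows above and 40 below each frame):
+ # frames = torch.zeros(9, 3, 240, 320)
+ # pad_to_square(frames).shape   # torch.Size([9, 3, 320, 320])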
data_utils/v2a_utils/audioset_video_224.py ADDED
@@ -0,0 +1,268 @@
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Optional, Union
4
+ from PIL import Image
5
+
6
+ import pandas as pd
7
+ import torch
8
+ import torchaudio
9
+ from torch.utils.data.dataset import Dataset
10
+ from torchvision.transforms import v2
11
+ from torio.io import StreamingMediaDecoder
12
+ from torchvision.utils import save_image
13
+ from transformers import AutoProcessor
14
+ import torch.nn.functional as F
15
+ import numpy as np
16
+
17
+ import logging
18
+ log = logging.getLogger()
19
+
20
+ _CLIP_SIZE = 224
21
+ _CLIP_FPS = 8.0
22
+
23
+ _SYNC_SIZE = 224
24
+ _SYNC_FPS = 25.0
25
+
26
+ def save_tensor_as_image(tensor, save_path):
+ """
+ Save an RGB image array of shape (1, 3, H, W) to an image file.
+
+ :param tensor: input NumPy array of shape (1, 3, H, W).
+ :param save_path: path where the image is saved.
+ """
+ # # Remove the batch dimension, giving (3, H, W)
+ # tensor = tensor.squeeze(0)
+
+ # Reorder the axes to (H, W, 3)
+ image_array = np.transpose(tensor, (1, 2, 0))
+
+ # Check whether the array has a suitable dtype
+ if image_array.dtype != np.uint8:
+ # If it is not uint8, normalize first, then convert
+ image_array = (image_array - image_array.min()) / (image_array.max() - image_array.min()) * 255
+ image_array = image_array.astype(np.uint8)
+
+ # Create the image object
+ image = Image.fromarray(image_array)
+
+ # Save the image
+ image.save(save_path)
+ print(f"Image saved to {save_path}")
+
+ def pad_to_square(video_tensor):
+ # Validate the input shape
+ if len(video_tensor.shape) != 4:
+ raise ValueError("Input tensor must have shape (l, c, h, w)")
+
+ l, c, h, w = video_tensor.shape
+ max_side = max(h, w)
+
+ # Compute the padding needed along each dimension: (left, right, top, bottom)
+ pad_h = max_side - h
+ pad_w = max_side - w
+
+ # Build the padding tuple (left, right, top, bottom)
+ # Image padding applies to the last two dimensions (h and w), so only those two need to be specified
+ padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
+
+ # Pad the video tensor with F.pad
+ # The padding arguments are (left, right, top, bottom)
+ video_padded = F.pad(video_tensor, pad=padding, mode='constant', value=0)
+
+ return video_padded
73
+
74
+ class Audioset(Dataset):
75
+
76
+ def __init__(
77
+ self,
78
+ root: Union[str, Path],
79
+ *,
80
+ tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv',
81
+ duration_sec: float = 10.0,
82
+ start_row: Optional[int] = None,
83
+ end_row: Optional[int] = None,
84
+ save_dir: str = 'data/vggsound/video_latents_text/train'
85
+ ):
86
+ self.root = Path(root)
87
+
88
+ # videos = sorted(os.listdir(self.root))
89
+ # videos = set([Path(v).stem for v in videos]) # remove extensions
90
+ videos = []
91
+ self.captions = []
92
+ self.videos = []
93
+ self.caption_t5s = []
94
+ missing_videos = []
95
+ # read the tsv for subset information
96
+ df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records')
97
+
98
+ # Restrict the range of rows to process
99
+ if start_row is not None and end_row is not None:
100
+ df_list = df_list[start_row:end_row]
101
+ with open(tsv_path.replace('.csv','.txt')) as file:
102
+ paths = file.readlines()
103
+ for record, path in zip(df_list,paths):
104
+ id = Path(record['id']).stem
105
+ # if os.path.exists(f'{save_dir}/{id}.pth'): continue
106
+ caption = record['caption']
107
+ caption_t5 = record['caption_t5']
108
+ path = path.strip()
109
+ part = Path(path).parent
110
+ video_id = Path(path).stem[1:]
111
+ video_path = os.path.join('dataset/3_Audioset/video',part,f'{video_id}.mp4')
112
+ assert os.path.exists(video_path), 'video must exist'
113
+ # if id in videos:
114
+ self.captions.append(caption)
115
+ self.caption_t5s.append(caption_t5)
116
+ # self.labels[id] = label
117
+ self.videos.append(video_path)
118
+ # else:
119
+ # missing_videos.append(id)
120
+ assert len(self.captions) == len(self.caption_t5s) and len(self.captions) == len(self.videos), 'error length'
121
+ log.info(f'{len(videos)} videos found in {root}')
122
+ log.info(f'{len(self.videos)} videos found in {tsv_path}')
123
+ log.info(f'{len(missing_videos)} videos missing in {root}')
124
+
125
+ self.duration_sec = duration_sec
126
+
127
+ self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
128
+ self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
129
+
130
+ self.clip_transform = v2.Compose([
131
+ v2.Lambda(pad_to_square), # pad to a square first
132
+ v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
133
+ v2.ToImage(),
134
+ v2.ToDtype(torch.float32, scale=True),
135
+ ])
136
+ self.clip_processor = AutoProcessor.from_pretrained("useful_ckpts/metaclip-huge")
137
+ self.sync_transform = v2.Compose([
138
+ v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
139
+ v2.CenterCrop(_SYNC_SIZE),
140
+ v2.ToImage(),
141
+ v2.ToDtype(torch.float32, scale=True),
142
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
143
+ ])
144
+
145
+ self.resampler = {}
146
+
147
+ def sample(self, idx: int) -> dict[str, torch.Tensor]:
148
+ video_path = self.videos[idx]
149
+ video_id = 'Y'+str(Path(video_path).stem)
150
+ caption = self.captions[idx]
151
+ caption_t5 = self.caption_t5s[idx]
152
+
153
+ reader = StreamingMediaDecoder(video_path)
154
+ reader.add_basic_video_stream(
155
+ frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
156
+ frame_rate=_CLIP_FPS,
157
+ format='rgb24',
158
+ )
159
+ reader.add_basic_video_stream(
160
+ frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
161
+ frame_rate=_SYNC_FPS,
162
+ format='rgb24',
163
+ )
164
+
165
+ reader.fill_buffer()
166
+ data_chunk = reader.pop_chunks()
167
+
168
+ clip_chunk = data_chunk[0]
169
+ sync_chunk = data_chunk[1]
170
+
171
+ if clip_chunk is None:
172
+ raise RuntimeError(f'CLIP video returned None {video_id}')
173
+ # if clip_chunk.shape[0] < self.clip_expected_length:
174
+ # raise RuntimeError(
175
+ # f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}'
176
+ # )
177
+
178
+ if sync_chunk is None:
179
+ raise RuntimeError(f'Sync video returned None {video_id}')
180
+ # if sync_chunk.shape[0] < self.sync_expected_length:
181
+ # raise RuntimeError(
182
+ # f'Sync video too short {video_id}, expected {self.sync_expected_length}, got {sync_chunk.shape[0]}'
183
+ # )
184
+
185
+
186
+ # truncate the video
187
+ clip_chunk = clip_chunk[:self.clip_expected_length]
188
+ # import ipdb
189
+ # ipdb.set_trace()
190
+ if clip_chunk.shape[0] != self.clip_expected_length:
191
+ current_length = clip_chunk.shape[0]
192
+ padding_needed = self.clip_expected_length - current_length
193
+
194
+ # Check that padding needed is no more than 2
195
+ assert padding_needed < 4, f'Padding no more than 2 frames allowed, but {padding_needed} needed'
196
+
197
+ # If assertion passes, proceed with padding
198
+ if padding_needed > 0:
199
+ last_frame = clip_chunk[-1]
200
+ log.info(clip_chunk.shape)
201
+ # Repeat the last frame to reach the expected length
202
+ padding = last_frame.repeat(padding_needed, 1, 1, 1)
203
+ clip_chunk = torch.cat((clip_chunk, padding), dim=0)
204
+ # raise RuntimeError(f'CLIP video wrong length {video_id}, '
205
+ # f'expected {self.clip_expected_length}, '
206
+ # f'got {clip_chunk.shape[0]}')
207
+
208
+ # save_image(clip_chunk[0] / 255.0,'ori.png')
209
+ clip_chunk = pad_to_square(clip_chunk)
210
+ # save_image(clip_chunk[0] / 255.0,'square.png')
211
+ # clip_chunk = self.clip_transform(clip_chunk)
212
+ # import ipdb
213
+ # ipdb.set_trace()
214
+ clip_chunk = self.clip_processor(images=clip_chunk, return_tensors="pt")["pixel_values"]
215
+ # log.info(clip_chunk.shape)
216
+ # save_tensor_as_image(clip_chunk[0].numpy(),'scale.png')
217
+ # log.info(clip_chunk[0])
218
+ # clip_chunk = outputs
219
+ # text_ids = outputs["input_ids"]
220
+ # temp_img = clip_chunk[0].permute(1, 2, 0) * 255
221
+ # save_image(clip_chunk[0],'scale.png')
222
+ sync_chunk = sync_chunk[:self.sync_expected_length]
223
+ if sync_chunk.shape[0] != self.sync_expected_length:
224
+ # padding using the last frame, but no more than 2
225
+ current_length = sync_chunk.shape[0]
226
+ last_frame = sync_chunk[-1]
227
+ # Repeat the last frame for padding
228
+ padding = last_frame.repeat(self.sync_expected_length - current_length, 1, 1, 1)
229
+ assert self.sync_expected_length - current_length < 12, f'sync can pad no more than 2 while {self.sync_expected_length - current_length}'
230
+ sync_chunk = torch.cat((sync_chunk, padding), dim=0)
231
+ # raise RuntimeError(f'Sync video wrong length {video_id}, '
232
+ # f'expected {self.sync_expected_length}, '
233
+ # f'got {sync_chunk.shape[0]}')
234
+
235
+ sync_chunk = self.sync_transform(sync_chunk)
236
+ assert clip_chunk.shape[0] == self.clip_expected_length and sync_chunk.shape[0] == self.sync_expected_length, 'error processed data shape'
237
+ data = {
238
+ 'id': video_id,
239
+ 'caption': caption,
240
+ 'caption_t5': caption_t5,
241
+ 'clip_video': clip_chunk,
242
+ 'sync_video': sync_chunk,
243
+ }
244
+
245
+ return data
246
+
247
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
248
+ try:
249
+ return self.sample(idx)
250
+ except Exception as e:
251
+ log.error(f'Error loading video {self.videos[idx]}: {e}')
252
+ return None
253
+
254
+ def __len__(self):
255
+ return len(self.captions)
256
+
257
+
258
+ # dataset = Audioset(
259
+ # root="data/vggsound/video/train",
260
+ # tsv_path="data/vggsound/split_txt/temp.csv",
261
+ # sample_rate=44100,
262
+ # duration_sec=9.0,
263
+ # audio_samples=397312,
264
+ # start_row=0,
265
+ # end_row=None,
266
+ # save_dir="data/vggsound/video_224_latents_text/train"
267
+ # )
268
+ # dataset[0]
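+ # Hedged shape note: with duration_sec=10.0 the dataset expects
+ # int(_CLIP_FPS * 10.0) = 80 CLIP frames and int(_SYNC_FPS * 10.0) = 250 sync frames
+ # per clip (output keys as produced by sample() above):
+ # item = dataset[0]
+ # item['clip_video'].shape   # (80, 3, 224, 224)
+ # item['sync_video'].shape   # (250, 3, 224, 224)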
data_utils/v2a_utils/feature_utils_224.py ADDED
@@ -0,0 +1,182 @@
1
+ from typing import Literal, Optional
2
+ import json
3
+ import open_clip
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from open_clip import create_model_from_pretrained
9
+ from torchvision.transforms import Normalize
10
+ from ThinkSound.models.factory import create_model_from_config
11
+ from ThinkSound.models.utils import load_ckpt_state_dict
12
+ from ThinkSound.training.utils import copy_state_dict
13
+ from transformers import AutoModel
14
+ from transformers import AutoProcessor
15
+ from transformers import T5EncoderModel, AutoTokenizer
16
+ import logging
17
+ from data_utils.ext.synchformer import Synchformer
18
+
19
+ log = logging.getLogger()
20
+
21
+ def patch_clip(clip_model):
22
+ # a hack to make it output last hidden states
23
+ # https://github.com/mlfoundations/open_clip/blob/fc5a37b72d705f760ebbc7915b84729816ed471f/src/open_clip/model.py#L269
24
+ def new_get_text_features(self, input_ids=None, attention_mask=None, position_ids=None,
25
+ output_attentions: Optional[bool] = None,
26
+ output_hidden_states: Optional[bool] = None,
27
+ return_dict: Optional[bool] = None):
28
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
29
+ output_hidden_states = (
30
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
31
+ )
32
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
33
+
34
+ text_outputs = self.text_model(
35
+ input_ids=input_ids,
36
+ attention_mask=attention_mask,
37
+ position_ids=position_ids,
38
+ output_attentions=output_attentions,
39
+ output_hidden_states=output_hidden_states,
40
+ return_dict=return_dict,
41
+ )
42
+ last_hidden_state = text_outputs[0]
43
+ pooled_output = text_outputs[1]
44
+ text_features = self.text_projection(pooled_output)
45
+
46
+ return text_features, last_hidden_state
47
+
48
+ clip_model.get_text_features = new_get_text_features.__get__(clip_model)
49
+ return clip_model
50
+
51
+
52
+ class FeaturesUtils(nn.Module):
53
+
54
+ def __init__(
55
+ self,
56
+ *,
57
+ vae_ckpt: Optional[str] = None,
58
+ vae_config: Optional[str] = None,
59
+ synchformer_ckpt: Optional[str] = None,
60
+ enable_conditions: bool = True,
61
+ need_vae_encoder: bool = True,
62
+ ):
63
+ super().__init__()
64
+
65
+ if enable_conditions:
66
+ self.clip_model = AutoModel.from_pretrained("facebook/metaclip-h14-fullcc2.5b")
67
+ self.clip_model = patch_clip(self.clip_model)
68
+ self.t5_tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-xl")
69
+ self.t5_model = T5EncoderModel.from_pretrained("google/t5-v1_1-xl")
70
+ self.clip_processor = AutoProcessor.from_pretrained("facebook/metaclip-h14-fullcc2.5b")
71
+ # self.clip_preprocess = Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
72
+ # std=[0.26862954, 0.26130258, 0.27577711])
73
+ self.synchformer = Synchformer()
74
+ self.synchformer.load_state_dict(
75
+ torch.load(synchformer_ckpt, weights_only=True, map_location='cpu'))
76
+
77
+ # self.tokenizer = open_clip.get_tokenizer('ViT-H-14-378-quickgelu') # same as 'ViT-H-14'
78
+ else:
79
+ self.clip_model = None
80
+ self.synchformer = None
81
+ self.tokenizer = None
82
+
83
+ if vae_ckpt is not None:
84
+ with open(vae_config) as f:
85
+ vae_config = json.load(f)
86
+ self.vae = create_model_from_config(vae_config)
87
+ print(f"Loading model checkpoint from {vae_ckpt}")
88
+ # Load checkpoint
89
+ copy_state_dict(self.vae, load_ckpt_state_dict(vae_ckpt,prefix='autoencoder.'))#,prefix='autoencoder.'
90
+ else:
91
+ self.tod = None
92
+
93
+ def compile(self):
94
+ if self.clip_model is not None:
95
+ self.clip_model.encode_image = torch.compile(self.clip_model.encode_image)
96
+ self.clip_model.encode_text = torch.compile(self.clip_model.encode_text)
97
+ if self.synchformer is not None:
98
+ self.synchformer = torch.compile(self.synchformer)
99
+
100
+
101
+ def train(self, mode: bool) -> None:
102
+ return super().train(False)
103
+
104
+ @torch.inference_mode()
105
+ def encode_video_with_clip(self, x: torch.Tensor, batch_size: int = -1) -> torch.Tensor:
106
+ assert self.clip_model is not None, 'CLIP is not loaded'
107
+ # x: (B, T, C, H, W) H/W: 384
108
+ b, t, c, h, w = x.shape
109
+
110
+ assert c == 3 and h == 224 and w == 224
111
+ # x = self.clip_preprocess(x)
112
+ x = rearrange(x, 'b t c h w -> (b t) c h w')
113
+ outputs = []
114
+ if batch_size < 0:
115
+ batch_size = b * t
116
+ for i in range(0, b * t, batch_size):
117
+ outputs.append(self.clip_model.get_image_features(x[i:i + batch_size]))
118
+ x = torch.cat(outputs, dim=0)
119
+ # x = self.clip_model.encode_image(x, normalize=True)
120
+ x = rearrange(x, '(b t) d -> b t d', b=b)
121
+ return x
122
+
123
+ @torch.inference_mode()
124
+ def encode_video_with_sync(self, x: torch.Tensor, batch_size: int = -1) -> torch.Tensor:
125
+ assert self.synchformer is not None, 'Synchformer is not loaded'
126
+ # x: (B, T, C, H, W) H/W: 384
127
+ b, t, c, h, w = x.shape
128
+ # import ipdb
129
+ # ipdb.set_trace()
130
+ assert c == 3 and h == 224 and w == 224
131
+
132
+ # partition the video
133
+ segment_size = 16
134
+ step_size = 8
135
+ num_segments = (t - segment_size) // step_size + 1
136
+ segments = []
137
+ for i in range(num_segments):
138
+ segments.append(x[:, i * step_size:i * step_size + segment_size])
139
+ x = torch.stack(segments, dim=1) # (B, S, T, C, H, W)
140
+
141
+ outputs = []
142
+ if batch_size < 0:
143
+ batch_size = b
144
+ x = rearrange(x, 'b s t c h w -> (b s) 1 t c h w')
145
+ for i in range(0, b * num_segments, batch_size):
146
+ outputs.append(self.synchformer(x[i:i + batch_size]))
147
+ x = torch.cat(outputs, dim=0)
148
+ x = rearrange(x, '(b s) 1 t d -> b (s t) d', b=b)
149
+ return x
150
+
151
+ @torch.inference_mode()
152
+ def encode_text(self, text: list[str]) -> torch.Tensor:
153
+ assert self.clip_model is not None, 'CLIP is not loaded'
154
+ # assert self.tokenizer is not None, 'Tokenizer is not loaded'
155
+ # x: (B, L)
156
+ tokens = self.clip_processor(text=text, truncation=True, max_length=77, padding="max_length",return_tensors="pt").to(self.device)
157
+ return self.clip_model.get_text_features(**tokens)
158
+
159
+ @torch.inference_mode()
160
+ def encode_t5_text(self, text: list[str]) -> torch.Tensor:
161
+ assert self.t5_model is not None, 'T5 model is not loaded'
162
+ assert self.t5_tokenizer is not None, 'T5 Tokenizer is not loaded'
163
+ # x: (B, L)
164
+ inputs = self.t5_tokenizer(text,
165
+ truncation=True,
166
+ max_length=77,
167
+ padding="max_length",
168
+ return_tensors="pt").to(self.device)
169
+ return self.t5_model(**inputs).last_hidden_state
170
+
171
+ @torch.inference_mode()
172
+ def encode_audio(self, x) -> torch.Tensor:
173
+ x = self.vae.encode(x)
174
+ return x
175
+
176
+ @property
177
+ def device(self):
178
+ return next(self.parameters()).device
179
+
180
+ @property
181
+ def dtype(self):
182
+ return next(self.parameters()).dtype
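+ # Hedged construction sketch for FeaturesUtils (the checkpoint/config paths are
+ # placeholders, not files shipped with this repo):
+ # feats = FeaturesUtils(
+ #     vae_ckpt='ckpts/vae.ckpt',
+ #     vae_config='ckpts/vae_config.json',
+ #     synchformer_ckpt='ckpts/synchformer_state_dict.pth',
+ #     enable_conditions=True,
+ # ).eval().to('cuda')
+ # For a sync clip of T frames, encode_video_with_sync splits it into
+ # (T - 16) // 8 + 1 overlapping 16-frame segments (stride 8) before Synchformer.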
data_utils/v2a_utils/vggsound.py ADDED
@@ -0,0 +1,259 @@
1
+ import logging
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Optional, Union
5
+
6
+ import pandas as pd
7
+ import torch
8
+ import torchaudio
9
+ from torch.utils.data.dataset import Dataset
10
+ from torchvision.transforms import v2
11
+ from torio.io import StreamingMediaDecoder
12
+ from torchvision.utils import save_image
13
+
14
+ log = logging.getLogger()
15
+
16
+ _CLIP_SIZE = 384
17
+ _CLIP_FPS = 8.0
18
+
19
+ _SYNC_SIZE = 224
20
+ _SYNC_FPS = 25.0
21
+
22
+
23
+ class VGGSound(Dataset):
24
+
25
+ def __init__(
26
+ self,
27
+ root: Union[str, Path],
28
+ *,
29
+ tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv',
30
+ sample_rate: int = 44_100,
31
+ duration_sec: float = 9.0,
32
+ audio_samples: Optional[int] = 397312,
33
+ normalize_audio: bool = False,
34
+ start_row: Optional[int] = None,
35
+ end_row: Optional[int] = None,
36
+ save_dir: str = 'data/vggsound/video_latents_text/train'
37
+ ):
38
+ self.root = Path(root)
39
+ self.normalize_audio = normalize_audio
40
+ if audio_samples is None:
41
+ self.audio_samples = int(sample_rate * duration_sec)
42
+ else:
43
+ self.audio_samples = audio_samples
44
+ effective_duration = audio_samples / sample_rate
45
+ # make sure the duration is close enough, within 15ms
46
+ assert abs(effective_duration - duration_sec) < 0.015, \
47
+ f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
48
+
49
+ videos = sorted(os.listdir(self.root))
50
+ videos = set([Path(v).stem for v in videos]) # remove extensions
51
+ # videos = []
52
+ self.labels = []
53
+ self.videos = []
54
+ missing_videos = []
55
+ # read the tsv for subset information
56
+ df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records')
57
+
58
+ # Restrict the range of rows to process
59
+ if start_row is not None and end_row is not None:
60
+ df_list = df_list[start_row:end_row]
61
+
62
+ for record in df_list:
63
+ id = record['id']
64
+ if os.path.exists(f'{save_dir}/{id}.pth'): continue
65
+ label = record['caption']
66
+ if id in videos:
67
+ self.labels.append(label)
69
+ self.videos.append(id)
70
+ else:
71
+ missing_videos.append(id)
72
+
73
+ log.info(f'{len(videos)} videos found in {root}')
74
+ log.info(f'{len(self.videos)} videos found in {tsv_path}')
75
+ log.info(f'{len(missing_videos)} videos missing in {root}')
76
+
77
+ self.sample_rate = sample_rate
78
+ self.duration_sec = duration_sec
79
+
80
+ self.expected_audio_length = self.audio_samples
81
+ self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
82
+ self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
83
+
84
+ self.clip_transform = v2.Compose([
85
+ v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
86
+ v2.ToImage(),
87
+ v2.ToDtype(torch.float32, scale=True),
88
+ ])
89
+
90
+ self.sync_transform = v2.Compose([
91
+ v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
92
+ v2.CenterCrop(_SYNC_SIZE),
93
+ v2.ToImage(),
94
+ v2.ToDtype(torch.float32, scale=True),
95
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
96
+ ])
97
+
98
+ self.resampler = {}
99
+
100
+ def sample(self, idx: int) -> dict[str, torch.Tensor]:
101
+ video_id = self.videos[idx]
102
+ label = self.labels[idx]
103
+
104
+ reader = StreamingMediaDecoder(self.root / (video_id + '.mp4'))
105
+ reader.add_basic_video_stream(
106
+ frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
107
+ frame_rate=_CLIP_FPS,
108
+ format='rgb24',
109
+ )
110
+ reader.add_basic_video_stream(
111
+ frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
112
+ frame_rate=_SYNC_FPS,
113
+ format='rgb24',
114
+ )
115
+ reader.add_basic_audio_stream(frames_per_chunk=2**30,)
116
+
117
+ reader.fill_buffer()
118
+ data_chunk = reader.pop_chunks()
119
+
120
+ clip_chunk = data_chunk[0]
121
+ sync_chunk = data_chunk[1]
122
+ audio_chunk = data_chunk[2]
123
+ if len(audio_chunk.shape) != 2:
124
+ raise RuntimeError(f'error audio shape {video_id}')
125
+ if clip_chunk is None:
126
+ raise RuntimeError(f'CLIP video returned None {video_id}')
127
+ # if clip_chunk.shape[0] < self.clip_expected_length:
128
+ # raise RuntimeError(
129
+ # f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}'
130
+ # )
131
+
132
+ if sync_chunk is None:
133
+ raise RuntimeError(f'Sync video returned None {video_id}')
134
+ # if sync_chunk.shape[0] < self.sync_expected_length:
135
+ # raise RuntimeError(
136
+ # f'Sync video too short {video_id}, expected {self.sync_expected_length}, got {sync_chunk.shape[0]}'
137
+ # )
138
+ # import ipdb
139
+ # ipdb.set_trace()
140
+ # process audio
141
+ sample_rate = int(reader.get_out_stream_info(2).sample_rate)
142
+ audio_chunk = audio_chunk.transpose(0, 1)
143
+ abs_max = audio_chunk[0].abs().max()
144
+ # audio_chunk = audio_chunk.mean(dim=0) # mono
145
+ # if self.normalize_audio:
146
+ # abs_max = audio_chunk.abs().max()
147
+ # audio_chunk = audio_chunk / abs_max * 0.95
148
+ if abs_max <= 1e-6:
149
+ if audio_chunk.shape[0] > 1 and audio_chunk[1].abs().max() > 1e-6:
150
+ audio_chunk = audio_chunk[1:2]
151
+ else:
152
+ raise RuntimeError(f'Audio is silent {video_id}')
153
+
154
+
155
+ # if abs_max <= 1e-6:
156
+ # raise RuntimeError(f'Audio is silent {video_id}')
157
+
158
+ # ensure the stereo audio
159
+ if audio_chunk.shape[0] < 2:
160
+ audio_chunk = audio_chunk.repeat(2, 1)
161
+
162
+ # resample
163
+ if sample_rate == self.sample_rate:
164
+ audio_chunk = audio_chunk
165
+ else:
166
+ if sample_rate not in self.resampler:
167
+ # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
168
+ self.resampler[sample_rate] = torchaudio.transforms.Resample(
169
+ sample_rate,
170
+ self.sample_rate,
171
+ lowpass_filter_width=64,
172
+ rolloff=0.9475937167399596,
173
+ resampling_method='sinc_interp_kaiser',
174
+ beta=14.769656459379492,
175
+ )
176
+ audio_chunk = self.resampler[sample_rate](audio_chunk)
177
+
178
+ if audio_chunk.shape[1] < self.expected_audio_length:
179
+ # zero-padding audio
180
+ padding_length = self.expected_audio_length - audio_chunk.shape[1]
181
+ # Create a padding tensor of shape [batch_size, padding_length] filled with zeros
182
+ padding = torch.zeros(audio_chunk.shape[0], padding_length)
183
+ # Concatenate the original audio and the padding along dimension 1
184
+ audio_chunk = torch.cat((audio_chunk, padding), dim=1)
185
+ # raise RuntimeError(f'Audio too short {video_id}')
186
+ audio_chunk = audio_chunk[:,:self.expected_audio_length]
187
+ # truncate the video
188
+ clip_chunk = clip_chunk[:self.clip_expected_length]
189
+ # import ipdb
190
+ # ipdb.set_trace()
191
+ if clip_chunk.shape[0] != self.clip_expected_length:
192
+ current_length = clip_chunk.shape[0]
193
+ padding_needed = self.clip_expected_length - current_length
194
+
195
+ # Allow at most 3 frames of repeated-frame padding
196
+ assert padding_needed < 4, f'At most 3 frames of padding allowed, but {padding_needed} needed'
197
+
198
+ # If assertion passes, proceed with padding
199
+ if padding_needed > 0:
200
+ last_frame = clip_chunk[-1]
201
+ log.info(last_frame.shape)
202
+ # Repeat the last frame to reach the expected length
203
+ padding = last_frame.repeat(padding_needed, 1, 1, 1)
204
+ clip_chunk = torch.cat((clip_chunk, padding), dim=0)
205
+ # raise RuntimeError(f'CLIP video wrong length {video_id}, '
206
+ # f'expected {self.clip_expected_length}, '
207
+ # f'got {clip_chunk.shape[0]}')
208
+ # save_image(clip_chunk[0] / 255.0,'ori.png')
209
+ clip_chunk = self.clip_transform(clip_chunk)
210
+ # temp_img = clip_chunk[0].permute(1, 2, 0) * 255
211
+ # save_image(clip_chunk[0],'scale.png')
212
+ sync_chunk = sync_chunk[:self.sync_expected_length]
213
+ if sync_chunk.shape[0] != self.sync_expected_length:
214
+ # pad by repeating the last frame, up to the limit asserted below
215
+ current_length = sync_chunk.shape[0]
216
+ last_frame = sync_chunk[-1]
217
+ # repeat the last frame to reach the expected length
218
+ padding = last_frame.repeat(self.sync_expected_length - current_length, 1, 1, 1)
219
+ assert self.sync_expected_length - current_length < 12, f'sync padding limited to 11 frames, but {self.sync_expected_length - current_length} needed'
220
+ sync_chunk = torch.cat((sync_chunk, padding), dim=0)
221
+ # raise RuntimeError(f'Sync video wrong length {video_id}, '
222
+ # f'expected {self.sync_expected_length}, '
223
+ # f'got {sync_chunk.shape[0]}')
224
+
225
+ sync_chunk = self.sync_transform(sync_chunk)
226
+ assert audio_chunk.shape[1] == self.expected_audio_length and clip_chunk.shape[0] == self.clip_expected_length \
227
+ and sync_chunk.shape[0] == self.sync_expected_length, 'error processed data shape'
228
+ data = {
229
+ 'id': video_id,
230
+ 'caption': label,
231
+ 'audio': audio_chunk,
232
+ 'clip_video': clip_chunk,
233
+ 'sync_video': sync_chunk,
234
+ }
235
+
236
+ return data
237
+
238
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
239
+ try:
240
+ return self.sample(idx)
241
+ except Exception as e:
242
+ log.error(f'Error loading video {self.videos[idx]}: {e}')
243
+ return None
244
+
245
+ def __len__(self):
246
+ return len(self.labels)
247
+
248
+
249
+ # dataset = VGGSound(
250
+ # root="data/vggsound/video/test",
251
+ # tsv_path="data/vggsound/split_txt/temp.csv",
252
+ # sample_rate=44100,
253
+ # duration_sec=9.0,
254
+ # audio_samples=397312,
255
+ # start_row=0,
256
+ # end_row=None,
257
+ # save_dir="data/vggsound/video_latents_text/test"
258
+ # )
259
+ # dataset[0]
data_utils/v2a_utils/vggsound_224.py ADDED
@@ -0,0 +1,320 @@
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Optional, Union
4
+ from PIL import Image
5
+
6
+ import pandas as pd
7
+ import torch
8
+ import torchaudio
9
+ from torch.utils.data.dataset import Dataset
10
+ from torchvision.transforms import v2
11
+ from torio.io import StreamingMediaDecoder
12
+ from torchvision.utils import save_image
13
+ from transformers import AutoProcessor
14
+ import torch.nn.functional as F
15
+ import numpy as np
16
+
17
+ import logging
18
+ log = logging.getLogger()
19
+
20
+ _CLIP_SIZE = 224
21
+ _CLIP_FPS = 8.0
22
+
23
+ _SYNC_SIZE = 224
24
+ _SYNC_FPS = 25.0
25
+
26
+ def save_tensor_as_image(tensor, save_path):
27
+ """
28
+ Save an RGB image array of shape (1, 3, H, W) as an image file.
29
+
30
+ :param tensor: input NumPy array of shape (1, 3, H, W).
31
+ :param save_path: path where the image is saved.
32
+ """
33
+ # # remove the batch dimension, giving (3, H, W)
34
+ # tensor = tensor.squeeze(0)
35
+
36
+ # reorder the axes to (H, W, 3)
37
+ image_array = np.transpose(tensor, (1, 2, 0))
38
+
39
+ # check that the array has a suitable dtype
40
+ if image_array.dtype != np.uint8:
41
+ # if it is not uint8, rescale to [0, 255] and then convert
42
+ image_array = (image_array - image_array.min()) / (image_array.max() - image_array.min()) * 255
43
+ image_array = image_array.astype(np.uint8)
44
+
45
+ # create the image object
46
+ image = Image.fromarray(image_array)
47
+
48
+ # save the image
49
+ image.save(save_path)
50
+ print(f"Image saved to {save_path}")
51
+
52
+ def pad_to_square(video_tensor):
53
+ # validate the input shape
54
+ if len(video_tensor.shape) != 4:
55
+ raise ValueError("Input tensor must have shape (l, c, h, w)")
56
+
57
+ l, c, h, w = video_tensor.shape
58
+ max_side = max(h, w)
59
+
60
+ # compute the padding needed for each side: (left, right, top, bottom)
61
+ pad_h = max_side - h
62
+ pad_w = max_side - w
63
+
64
+ # build the padding tuple (left, right, top, bottom)
65
+ # padding is applied to the last two dimensions (h and w), so only those two are specified
66
+ padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
67
+
68
+ # pad the video tensor with F.pad
69
+ # the padding argument is (left, right, top, bottom)
70
+ video_padded = F.pad(video_tensor, pad=padding, mode='constant', value=0)
71
+
72
+ return video_padded
73
+
74
+ class VGGSound(Dataset):
75
+
76
+ def __init__(
77
+ self,
78
+ root: Union[str, Path],
79
+ *,
80
+ tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv',
81
+ sample_rate: int = 44_100,
82
+ duration_sec: float = 9.0,
83
+ audio_samples: Optional[int] = 397312,
84
+ normalize_audio: bool = False,
85
+ start_row: Optional[int] = None,
86
+ end_row: Optional[int] = None,
87
+ save_dir: str = 'data/vggsound/video_latents_text/train'
88
+ ):
89
+ self.root = Path(root)
90
+ self.normalize_audio = normalize_audio
91
+ if audio_samples is None:
92
+ self.audio_samples = int(sample_rate * duration_sec)
93
+ else:
94
+ self.audio_samples = audio_samples
95
+ effective_duration = audio_samples / sample_rate
96
+ # make sure the duration is close enough, within 15ms
97
+ assert abs(effective_duration - duration_sec) < 0.015, \
98
+ f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
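+ # e.g. the default 397312 samples is ~9.009 s at 44.1 kHz, well inside the 15 ms tolerance checked above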
99
+
100
+ # videos = sorted(os.listdir(self.root))
101
+ # videos = set([Path(v).stem for v in videos]) # remove extensions
102
+ videos = []
103
+ self.labels = []
104
+ self.videos = []
105
+ missing_videos = []
106
+ # read the tsv for subset information
107
+ df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records')
108
+
109
+ # restrict processing to the requested row range
110
+ if start_row is not None and end_row is not None:
111
+ df_list = df_list[start_row:end_row]
112
+
113
+ for record in df_list:
114
+ id = record['id']
115
+ if os.path.exists(f'{save_dir}/{id}.pth'): continue
116
+ label = record['label']
117
+ # if id in videos:
118
+ self.labels.append(label)
119
+ # self.labels[id] = label
120
+ self.videos.append(id)
121
+ # else:
122
+ # missing_videos.append(id)
123
+
124
+ log.info(f'{len(videos)} videos found in {root}')
125
+ log.info(f'{len(self.videos)} videos found in {tsv_path}')
126
+ log.info(f'{len(missing_videos)} videos missing in {root}')
127
+
128
+ self.sample_rate = sample_rate
129
+ self.duration_sec = duration_sec
130
+
131
+ self.expected_audio_length = self.audio_samples
132
+ self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
133
+ self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
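+ # with the defaults (8 fps CLIP stream, 25 fps sync stream, 9 s clips) this gives 72 CLIP frames and 225 sync frames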
134
+
135
+ self.clip_transform = v2.Compose([
136
+ v2.Lambda(pad_to_square), # pad to a square first
137
+ v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
138
+ v2.ToImage(),
139
+ v2.ToDtype(torch.float32, scale=True),
140
+ ])
141
+ self.clip_processor = AutoProcessor.from_pretrained("facebook/metaclip-h14-fullcc2.5b")
142
+ self.sync_transform = v2.Compose([
143
+ v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
144
+ v2.CenterCrop(_SYNC_SIZE),
145
+ v2.ToImage(),
146
+ v2.ToDtype(torch.float32, scale=True),
147
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
148
+ ])
149
+
150
+ self.resampler = {}
151
+
152
+ def sample(self, idx: int) -> dict[str, torch.Tensor]:
153
+ video_id = self.videos[idx]
154
+ label = self.labels[idx]
155
+
156
+ reader = StreamingMediaDecoder(self.root / (video_id + '.mp4'))
157
+ reader.add_basic_video_stream(
158
+ frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
159
+ frame_rate=_CLIP_FPS,
160
+ format='rgb24',
161
+ )
162
+ reader.add_basic_video_stream(
163
+ frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
164
+ frame_rate=_SYNC_FPS,
165
+ format='rgb24',
166
+ )
167
+ reader.add_basic_audio_stream(frames_per_chunk=2**30,)
168
+
169
+ reader.fill_buffer()
170
+ data_chunk = reader.pop_chunks()
171
+
172
+ clip_chunk = data_chunk[0]
173
+ sync_chunk = data_chunk[1]
174
+ audio_chunk = data_chunk[2]
175
+ if len(audio_chunk.shape) != 2:
176
+ raise RuntimeError(f'error audio shape {video_id}')
177
+ if clip_chunk is None:
178
+ raise RuntimeError(f'CLIP video returned None {video_id}')
179
+ # if clip_chunk.shape[0] < self.clip_expected_length:
180
+ # raise RuntimeError(
181
+ # f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}'
182
+ # )
183
+
184
+ if sync_chunk is None:
185
+ raise RuntimeError(f'Sync video returned None {video_id}')
186
+ # if sync_chunk.shape[0] < self.sync_expected_length:
187
+ # raise RuntimeError(
188
+ # f'Sync video too short {video_id}, expected {self.sync_expected_length}, got {sync_chunk.shape[0]}'
189
+ # )
190
+ # import ipdb
191
+ # ipdb.set_trace()
192
+ # process audio
193
+ # import ipdb
194
+ # ipdb.set_trace()
195
+ sample_rate = int(reader.get_out_stream_info(2).sample_rate)
196
+ audio_chunk = audio_chunk.transpose(0, 1)
197
+ abs_max = audio_chunk[0].abs().max()
198
+ # audio_chunk = audio_chunk.mean(dim=0) # mono
199
+ # if self.normalize_audio:
200
+ # abs_max = audio_chunk.abs().max()
201
+ # audio_chunk = audio_chunk / abs_max * 0.95
202
+ if abs_max <= 1e-6:
203
+ if audio_chunk.shape[0] > 1 and audio_chunk[1].abs().max() > 1e-6:
204
+ audio_chunk = audio_chunk[1:2]
205
+ else:
206
+ raise RuntimeError(f'Audio is silent {video_id}')
207
+
208
+ # ensure the stereo audio
209
+ if audio_chunk.shape[0] < 2:
210
+ audio_chunk = audio_chunk.repeat(2, 1)
211
+
212
+ # resample
213
+ if sample_rate == self.sample_rate:
214
+ audio_chunk = audio_chunk
215
+ else:
216
+ if sample_rate not in self.resampler:
217
+ # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
218
+ self.resampler[sample_rate] = torchaudio.transforms.Resample(
219
+ sample_rate,
220
+ self.sample_rate,
221
+ lowpass_filter_width=64,
222
+ rolloff=0.9475937167399596,
223
+ resampling_method='sinc_interp_kaiser',
224
+ beta=14.769656459379492,
225
+ )
226
+ audio_chunk = self.resampler[sample_rate](audio_chunk)
227
+
228
+ if audio_chunk.shape[1] < self.expected_audio_length:
229
+ # zero-padding audio
230
+ padding_length = self.expected_audio_length - audio_chunk.shape[1]
231
+ # create a zero-valued padding tensor of shape [num_channels, padding_length]
232
+ padding = torch.zeros(audio_chunk.shape[0], padding_length)
233
+ # concatenate the original audio and the padding along dim 1
234
+ audio_chunk = torch.cat((audio_chunk, padding), dim=1)
235
+ # raise RuntimeError(f'Audio too short {video_id}')
236
+ audio_chunk = audio_chunk[:,:self.expected_audio_length]
237
+ # truncate the video
238
+ clip_chunk = clip_chunk[:self.clip_expected_length]
239
+ # import ipdb
240
+ # ipdb.set_trace()
241
+ if clip_chunk.shape[0] != self.clip_expected_length:
242
+ current_length = clip_chunk.shape[0]
243
+ padding_needed = self.clip_expected_length - current_length
244
+
245
+ # Allow at most 3 frames of repeated-frame padding
246
+ assert padding_needed < 4, f'At most 3 frames of padding allowed, but {padding_needed} needed'
247
+
248
+ # If assertion passes, proceed with padding
249
+ if padding_needed > 0:
250
+ last_frame = clip_chunk[-1]
251
+ log.info(last_frame.shape)
252
+ # Repeat the last frame to reach the expected length
253
+ padding = last_frame.repeat(padding_needed, 1, 1, 1)
254
+ clip_chunk = torch.cat((clip_chunk, padding), dim=0)
255
+ # raise RuntimeError(f'CLIP video wrong length {video_id}, '
256
+ # f'expected {self.clip_expected_length}, '
257
+ # f'got {clip_chunk.shape[0]}')
258
+
259
+ # save_image(clip_chunk[0] / 255.0,'ori.png')
260
+ clip_chunk = pad_to_square(clip_chunk)
261
+ # save_image(clip_chunk[0] / 255.0,'square.png')
262
+ # clip_chunk = self.clip_transform(clip_chunk)
263
+ # import ipdb
264
+ # ipdb.set_trace()
265
+ clip_chunk = self.clip_processor(images=clip_chunk, return_tensors="pt")["pixel_values"]
266
+ # log.info(clip_chunk.shape)
267
+ # save_tensor_as_image(clip_chunk[0].numpy(),'scale.png')
268
+ # log.info(clip_chunk[0])
269
+ # clip_chunk = outputs
270
+ # text_ids = outputs["input_ids"]
271
+ # temp_img = clip_chunk[0].permute(1, 2, 0) * 255
272
+ # save_image(clip_chunk[0],'scale.png')
273
+ sync_chunk = sync_chunk[:self.sync_expected_length]
274
+ if sync_chunk.shape[0] != self.sync_expected_length:
275
+ # pad by repeating the last frame, up to the limit asserted below
276
+ current_length = sync_chunk.shape[0]
277
+ last_frame = sync_chunk[-1]
278
+ # repeat the last frame to reach the expected length
279
+ padding = last_frame.repeat(self.sync_expected_length - current_length, 1, 1, 1)
280
+ assert self.sync_expected_length - current_length < 12, f'sync padding limited to 11 frames, but {self.sync_expected_length - current_length} needed'
281
+ sync_chunk = torch.cat((sync_chunk, padding), dim=0)
282
+ # raise RuntimeError(f'Sync video wrong length {video_id}, '
283
+ # f'expected {self.sync_expected_length}, '
284
+ # f'got {sync_chunk.shape[0]}')
285
+
286
+ sync_chunk = self.sync_transform(sync_chunk)
287
+ assert audio_chunk.shape[1] == self.expected_audio_length and clip_chunk.shape[0] == self.clip_expected_length \
288
+ and sync_chunk.shape[0] == self.sync_expected_length, 'error processed data shape'
289
+ data = {
290
+ 'id': video_id,
291
+ 'caption': label,
292
+ 'audio': audio_chunk,
293
+ 'clip_video': clip_chunk,
294
+ 'sync_video': sync_chunk,
295
+ }
296
+
297
+ return data
298
+
299
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
300
+ try:
301
+ return self.sample(idx)
302
+ except Exception as e:
303
+ log.error(f'Error loading video {self.videos[idx]}: {e}')
304
+ return None
305
+
306
+ def __len__(self):
307
+ return len(self.labels)
308
+
309
+
310
+ # dataset = VGGSound(
311
+ # root="data/vggsound/video/train",
312
+ # tsv_path="data/vggsound/split_txt/temp.csv",
313
+ # sample_rate=44100,
314
+ # duration_sec=9.0,
315
+ # audio_samples=397312,
316
+ # start_row=0,
317
+ # end_row=None,
318
+ # save_dir="data/vggsound/video_224_latents_text/train"
319
+ # )
320
+ # dataset[0]
data_utils/v2a_utils/vggsound_224_no_audio.py ADDED
@@ -0,0 +1,275 @@
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Optional, Union
4
+ from PIL import Image
5
+
6
+ import pandas as pd
7
+ import torch
8
+ import torchaudio
9
+ from torch.utils.data.dataset import Dataset
10
+ from torchvision.transforms import v2
11
+ from torio.io import StreamingMediaDecoder
12
+ from torchvision.utils import save_image
13
+ from transformers import AutoProcessor
14
+ import torch.nn.functional as F
15
+ import numpy as np
16
+
17
+ import logging
18
+ log = logging.getLogger()
19
+
20
+ _CLIP_SIZE = 224
21
+ _CLIP_FPS = 8.0
22
+
23
+ _SYNC_SIZE = 224
24
+ _SYNC_FPS = 25.0
25
+
26
+ def save_tensor_as_image(tensor, save_path):
27
+ """
28
+ Save an RGB image array of shape (1, 3, H, W) as an image file.
29
+
30
+ :param tensor: input NumPy array of shape (1, 3, H, W).
31
+ :param save_path: path where the image is saved.
32
+ """
33
+ # # remove the batch dimension, giving (3, H, W)
34
+ # tensor = tensor.squeeze(0)
35
+
36
+ # reorder the axes to (H, W, 3)
37
+ image_array = np.transpose(tensor, (1, 2, 0))
38
+
39
+ # check that the array has a suitable dtype
40
+ if image_array.dtype != np.uint8:
41
+ # if it is not uint8, rescale to [0, 255] and then convert
42
+ image_array = (image_array - image_array.min()) / (image_array.max() - image_array.min()) * 255
43
+ image_array = image_array.astype(np.uint8)
44
+
45
+ # create the image object
46
+ image = Image.fromarray(image_array)
47
+
48
+ # save the image
49
+ image.save(save_path)
50
+ print(f"Image saved to {save_path}")
51
+
52
+ def pad_to_square(video_tensor):
53
+ # validate the input shape
54
+ if len(video_tensor.shape) != 4:
55
+ raise ValueError("Input tensor must have shape (l, c, h, w)")
56
+
57
+ l, c, h, w = video_tensor.shape
58
+ max_side = max(h, w)
59
+
60
+ # compute the padding needed for each side: (left, right, top, bottom)
61
+ pad_h = max_side - h
62
+ pad_w = max_side - w
63
+
64
+ # build the padding tuple (left, right, top, bottom)
65
+ # padding is applied to the last two dimensions (h and w), so only those two are specified
66
+ padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
67
+
68
+ # pad the video tensor with F.pad
69
+ # the padding argument is (left, right, top, bottom)
70
+ video_padded = F.pad(video_tensor, pad=padding, mode='constant', value=0)
71
+
72
+ return video_padded
73
+
74
+ class VGGSound(Dataset):
75
+
76
+ def __init__(
77
+ self,
78
+ root: Union[str, Path],
79
+ *,
80
+ tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv',
81
+ sample_rate: int = 44_100,
82
+ duration_sec: float = 9.0,
83
+ audio_samples: Optional[int] = 397312,
84
+ normalize_audio: bool = False,
85
+ start_row: Optional[int] = None,
86
+ end_row: Optional[int] = None,
87
+ save_dir: str = 'data/vggsound/video_latents_text/train'
88
+ ):
89
+ self.root = Path(root)
90
+ self.normalize_audio = normalize_audio
91
+ if audio_samples is None:
92
+ self.audio_samples = int(sample_rate * duration_sec)
93
+ else:
94
+ self.audio_samples = audio_samples
95
+ effective_duration = audio_samples / sample_rate
96
+ # make sure the duration is close enough, within 15ms
97
+ assert abs(effective_duration - duration_sec) < 0.015, \
98
+ f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
99
+
100
+ # videos = sorted(os.listdir(self.root))
101
+ # videos = set([Path(v).stem for v in videos]) # remove extensions
102
+ videos = []
103
+ self.labels = []
104
+ self.videos = []
105
+ self.caption_cot = []
106
+ missing_videos = []
107
+ # read the tsv for subset information
108
+ df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records')
109
+
110
+ # restrict processing to the requested row range
111
+ if start_row is not None and end_row is not None:
112
+ df_list = df_list[start_row:end_row]
113
+
114
+ for record in df_list:
115
+ id = record['id']
116
+ if os.path.exists(f'{save_dir}/{id}.pth'): continue
117
+ label = record['caption']
118
+ caption_cot = record['caption_cot']
119
+ # if id in videos:
120
+ self.labels.append(label)
121
+ # self.labels[id] = label
122
+ self.videos.append(id)
123
+ self.caption_cot.append(caption_cot)
124
+ # else:
125
+ # missing_videos.append(id)
126
+
127
+ log.info(f'{len(videos)} videos found in {root}')
128
+ log.info(f'{len(self.videos)} videos found in {tsv_path}')
129
+ log.info(f'{len(missing_videos)} videos missing in {root}')
130
+
131
+ self.sample_rate = sample_rate
132
+ self.duration_sec = duration_sec
133
+
134
+ self.expected_audio_length = self.audio_samples
135
+ self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
136
+ self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
137
+
138
+ self.clip_transform = v2.Compose([
139
+ v2.Lambda(pad_to_square), # pad to a square first
140
+ v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
141
+ v2.ToImage(),
142
+ v2.ToDtype(torch.float32, scale=True),
143
+ ])
144
+ self.clip_processor = AutoProcessor.from_pretrained("facebook/metaclip-h14-fullcc2.5b")
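+ # the MetaCLIP processor performs its own resizing and normalisation, so self.clip_transform is left unused in sample() below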
145
+ self.sync_transform = v2.Compose([
146
+ v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
147
+ v2.CenterCrop(_SYNC_SIZE),
148
+ v2.ToImage(),
149
+ v2.ToDtype(torch.float32, scale=True),
150
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
151
+ ])
152
+
153
+ self.resampler = {}
154
+
155
+ def sample(self, idx: int) -> dict[str, torch.Tensor]:
156
+ video_id = self.videos[idx]
157
+ label = self.labels[idx]
158
+ caption_cot = self.caption_cot[idx]
159
+
160
+ reader = StreamingMediaDecoder(self.root / (video_id + '.mp4'))
161
+ reader.add_basic_video_stream(
162
+ frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
163
+ frame_rate=_CLIP_FPS,
164
+ format='rgb24',
165
+ )
166
+ reader.add_basic_video_stream(
167
+ frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
168
+ frame_rate=_SYNC_FPS,
169
+ format='rgb24',
170
+ )
171
+ # reader.add_basic_audio_stream(frames_per_chunk=2**30,)
172
+
173
+ reader.fill_buffer()
174
+ data_chunk = reader.pop_chunks()
175
+
176
+ clip_chunk = data_chunk[0]
177
+ sync_chunk = data_chunk[1]
178
+ # audio_chunk = data_chunk[2]
179
+ # if len(audio_chunk.shape) != 2:
180
+ # raise RuntimeError(f'error audio shape {video_id}')
181
+ if clip_chunk is None:
182
+ raise RuntimeError(f'CLIP video returned None {video_id}')
183
+ # if clip_chunk.shape[0] < self.clip_expected_length:
184
+ # raise RuntimeError(
185
+ # f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}'
186
+ # )
187
+
188
+ if sync_chunk is None:
189
+ raise RuntimeError(f'Sync video returned None {video_id}')
190
+
191
+ # truncate the video
192
+ clip_chunk = clip_chunk[:self.clip_expected_length]
193
+ # import ipdb
194
+ # ipdb.set_trace()
195
+ if clip_chunk.shape[0] != self.clip_expected_length:
196
+ current_length = clip_chunk.shape[0]
197
+ padding_needed = self.clip_expected_length - current_length
198
+
199
+ # Pad with the last frame for however many frames are missing (the length assert below is disabled)
200
+ # assert padding_needed < 4, f'Padding no more than 2 frames allowed, but {padding_needed} needed'
201
+
202
+ # If assertion passes, proceed with padding
203
+ if padding_needed > 0:
204
+ last_frame = clip_chunk[-1]
205
+ log.info(last_frame.shape)
206
+ # Repeat the last frame to reach the expected length
207
+ padding = last_frame.repeat(padding_needed, 1, 1, 1)
208
+ clip_chunk = torch.cat((clip_chunk, padding), dim=0)
209
+ # raise RuntimeError(f'CLIP video wrong length {video_id}, '
210
+ # f'expected {self.clip_expected_length}, '
211
+ # f'got {clip_chunk.shape[0]}')
212
+
213
+ # save_image(clip_chunk[0] / 255.0,'ori.png')
214
+ clip_chunk = pad_to_square(clip_chunk)
215
+ # save_image(clip_chunk[0] / 255.0,'square.png')
216
+ # clip_chunk = self.clip_transform(clip_chunk)
217
+ # import ipdb
218
+ # ipdb.set_trace()
219
+ clip_chunk = self.clip_processor(images=clip_chunk, return_tensors="pt")["pixel_values"]
220
+ # log.info(clip_chunk.shape)
221
+ # save_tensor_as_image(clip_chunk[0].numpy(),'scale.png')
222
+ # log.info(clip_chunk[0])
223
+ # clip_chunk = outputs
224
+ # text_ids = outputs["input_ids"]
225
+ # temp_img = clip_chunk[0].permute(1, 2, 0) * 255
226
+ # save_image(clip_chunk[0],'scale.png')
227
+ sync_chunk = sync_chunk[:self.sync_expected_length]
228
+ if sync_chunk.shape[0] != self.sync_expected_length:
229
+ # pad by repeating the last frame (the length assert below is disabled)
230
+ current_length = sync_chunk.shape[0]
231
+ last_frame = sync_chunk[-1]
232
+ # repeat the last frame to reach the expected length
233
+ padding = last_frame.repeat(self.sync_expected_length - current_length, 1, 1, 1)
234
+ # assert self.sync_expected_length - current_length < 12, f'sync can pad no more than 2 while {self.sync_expected_length - current_length}'
235
+ sync_chunk = torch.cat((sync_chunk, padding), dim=0)
236
+ # raise RuntimeError(f'Sync video wrong length {video_id}, '
237
+ # f'expected {self.sync_expected_length}, '
238
+ # f'got {sync_chunk.shape[0]}')
239
+
240
+ sync_chunk = self.sync_transform(sync_chunk)
241
+ # assert audio_chunk.shape[1] == self.expected_audio_length and clip_chunk.shape[0] == self.clip_expected_length \
242
+ # and sync_chunk.shape[0] == self.sync_expected_length, 'error processed data shape'
243
+ data = {
244
+ 'id': video_id,
245
+ 'caption': label,
246
+ # 'audio': audio_chunk,
247
+ 'clip_video': clip_chunk,
248
+ 'sync_video': sync_chunk,
249
+ 'caption_cot': caption_cot,
250
+ }
251
+
252
+ return data
253
+
254
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
255
+ try:
256
+ return self.sample(idx)
257
+ except Exception as e:
258
+ log.error(f'Error loading video {self.videos[idx]}: {e}')
259
+ return None
260
+
261
+ def __len__(self):
262
+ return len(self.labels)
263
+
264
+
265
+ # dataset = VGGSound(
266
+ # root="data/vggsound/video/train",
267
+ # tsv_path="data/vggsound/split_txt/temp.csv",
268
+ # sample_rate=44100,
269
+ # duration_sec=9.0,
270
+ # audio_samples=397312,
271
+ # start_row=0,
272
+ # end_row=None,
273
+ # save_dir="data/vggsound/video_224_latents_text/train"
274
+ # )
275
+ # dataset[0]
data_utils/v2a_utils/vggsound_224_no_sync.py ADDED
@@ -0,0 +1,223 @@
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Optional, Union
4
+ from PIL import Image
5
+
6
+ import pandas as pd
7
+ import torch
8
+ import torchaudio
9
+ from torch.utils.data.dataset import Dataset
10
+ from torchvision.transforms import v2
11
+ from torio.io import StreamingMediaDecoder
12
+ from torchvision.utils import save_image
13
+ from transformers import AutoProcessor
14
+ import torch.nn.functional as F
15
+ import numpy as np
16
+
17
+ import logging
18
+ log = logging.getLogger()
19
+
20
+ _CLIP_SIZE = 224
21
+ _CLIP_FPS = 8.0
22
+
23
+ _SYNC_SIZE = 224
24
+ _SYNC_FPS = 25.0
25
+
26
+ def save_tensor_as_image(tensor, save_path):
27
+ """
28
+ Save an RGB image array of shape (1, 3, H, W) as an image file.
29
+
30
+ :param tensor: input NumPy array of shape (1, 3, H, W).
31
+ :param save_path: path where the image is saved.
32
+ """
33
+ # # remove the batch dimension, giving (3, H, W)
34
+ # tensor = tensor.squeeze(0)
35
+
36
+ # reorder the axes to (H, W, 3)
37
+ image_array = np.transpose(tensor, (1, 2, 0))
38
+
39
+ # check that the array has a suitable dtype
40
+ if image_array.dtype != np.uint8:
41
+ # if it is not uint8, rescale to [0, 255] and then convert
42
+ image_array = (image_array - image_array.min()) / (image_array.max() - image_array.min()) * 255
43
+ image_array = image_array.astype(np.uint8)
44
+
45
+ # create the image object
46
+ image = Image.fromarray(image_array)
47
+
48
+ # save the image
49
+ image.save(save_path)
50
+ print(f"Image saved to {save_path}")
51
+
52
+ def pad_to_square(video_tensor):
53
+ # validate the input shape
54
+ if len(video_tensor.shape) != 4:
55
+ raise ValueError("Input tensor must have shape (l, c, h, w)")
56
+
57
+ l, c, h, w = video_tensor.shape
58
+ max_side = max(h, w)
59
+
60
+ # compute the padding needed for each side: (left, right, top, bottom)
61
+ pad_h = max_side - h
62
+ pad_w = max_side - w
63
+
64
+ # build the padding tuple (left, right, top, bottom)
65
+ # padding is applied to the last two dimensions (h and w), so only those two are specified
66
+ padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
67
+
68
+ # pad the video tensor with F.pad
69
+ # the padding argument is (left, right, top, bottom)
70
+ video_padded = F.pad(video_tensor, pad=padding, mode='constant', value=0)
71
+
72
+ return video_padded
73
+
74
+ class VGGSound(Dataset):
75
+
76
+ def __init__(
77
+ self,
78
+ root: Union[str, Path],
79
+ *,
80
+ tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv',
81
+ sample_rate: int = 44_100,
82
+ duration_sec: float = 9.0,
83
+ audio_samples: Optional[int] = 397312,
84
+ normalize_audio: bool = False,
85
+ start_row: Optional[int] = None,
86
+ end_row: Optional[int] = None,
87
+ save_dir: str = 'data/vggsound/video_latents_text/train'
88
+ ):
89
+ self.root = Path(root)
90
+ self.normalize_audio = normalize_audio
91
+ if audio_samples is None:
92
+ self.audio_samples = int(sample_rate * duration_sec)
93
+ else:
94
+ self.audio_samples = audio_samples
95
+ effective_duration = audio_samples / sample_rate
96
+ # make sure the duration is close enough, within 15ms
97
+ assert abs(effective_duration - duration_sec) < 0.015, \
98
+ f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
99
+
100
+ # videos = sorted(os.listdir(self.root))
101
+ # videos = set([Path(v).stem for v in videos]) # remove extensions
102
+ videos = []
103
+ self.labels = []
104
+ self.videos = []
105
+ missing_videos = []
106
+ # read the tsv for subset information
107
+ df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records')
108
+
109
+ # restrict processing to the requested row range
110
+ if start_row is not None and end_row is not None:
111
+ df_list = df_list[start_row:end_row]
112
+
113
+ for record in df_list:
114
+ id = record['id']
115
+ if os.path.exists(f'{save_dir}/{id}.pth'): continue
116
+ label = record['label']
117
+ # if id in videos:
118
+ self.labels.append(label)
119
+ # self.labels[id] = label
120
+ self.videos.append(id)
121
+ # else:
122
+ # missing_videos.append(id)
123
+
124
+ log.info(f'{len(videos)} videos found in {root}')
125
+ log.info(f'{len(self.videos)} videos found in {tsv_path}')
126
+ log.info(f'{len(missing_videos)} videos missing in {root}')
127
+
128
+ self.sample_rate = sample_rate
129
+ self.duration_sec = duration_sec
130
+
131
+ self.expected_audio_length = self.audio_samples
132
+ self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
133
+ self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
134
+
135
+ self.clip_transform = v2.Compose([
136
+ v2.Lambda(pad_to_square), # pad to a square first
137
+ v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
138
+ v2.ToImage(),
139
+ v2.ToDtype(torch.float32, scale=True),
140
+ ])
141
+ self.clip_processor = AutoProcessor.from_pretrained("useful_ckpts/metaclip-huge")
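+ # NOTE: this variant loads the MetaCLIP processor from a local useful_ckpts/ folder rather than the Hugging Face Hub id used by the other loaders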
142
+
143
+ self.resampler = {}
144
+
145
+ def sample(self, idx: int) -> dict[str, torch.Tensor]:
146
+ video_id = self.videos[idx]
147
+ label = self.labels[idx]
148
+
149
+ reader = StreamingMediaDecoder(self.root / (video_id + '.mp4'))
150
+ reader.add_basic_video_stream(
151
+ frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
152
+ frame_rate=_CLIP_FPS,
153
+ format='rgb24',
154
+ )
155
+
156
+ reader.fill_buffer()
157
+ data_chunk = reader.pop_chunks()
158
+
159
+ clip_chunk = data_chunk[0]
160
+ if clip_chunk is None:
161
+ raise RuntimeError(f'CLIP video returned None {video_id}')
162
+
163
+
164
+ # truncate the video
165
+ clip_chunk = clip_chunk[:self.clip_expected_length]
166
+ # import ipdb
167
+ # ipdb.set_trace()
168
+ if clip_chunk.shape[0] != self.clip_expected_length:
169
+ current_length = clip_chunk.shape[0]
170
+ padding_needed = self.clip_expected_length - current_length
171
+
172
+ # Allow at most 3 frames of repeated-frame padding
173
+ assert padding_needed < 4, f'At most 3 frames of padding allowed, but {padding_needed} needed'
174
+
175
+ # If assertion passes, proceed with padding
176
+ if padding_needed > 0:
177
+ last_frame = clip_chunk[-1]
178
+ log.info(last_frame.shape)
179
+ # Repeat the last frame to reach the expected length
180
+ padding = last_frame.repeat(padding_needed, 1, 1, 1)
181
+ clip_chunk = torch.cat((clip_chunk, padding), dim=0)
182
+ # raise RuntimeError(f'CLIP video wrong length {video_id}, '
183
+ # f'expected {self.clip_expected_length}, '
184
+ # f'got {clip_chunk.shape[0]}')
185
+
186
+ # save_image(clip_chunk[0] / 255.0,'ori.png')
187
+ clip_chunk = pad_to_square(clip_chunk)
188
+ # save_image(clip_chunk[0] / 255.0,'square.png')
189
+ # clip_chunk = self.clip_transform(clip_chunk)
190
+ # import ipdb
191
+ # ipdb.set_trace()
192
+ clip_chunk = self.clip_processor(images=clip_chunk, return_tensors="pt")["pixel_values"]
193
+
194
+ data = {
195
+ 'id': video_id,
196
+ 'caption': label,
197
+ 'clip_video': clip_chunk,
198
+ }
199
+
200
+ return data
201
+
202
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
203
+ try:
204
+ return self.sample(idx)
205
+ except Exception as e:
206
+ log.error(f'Error loading video {self.videos[idx]}: {e}')
207
+ return None
208
+
209
+ def __len__(self):
210
+ return len(self.labels)
211
+
212
+
213
+ # dataset = VGGSound(
214
+ # root="data/vggsound/video/train",
215
+ # tsv_path="data/vggsound/split_txt/temp.csv",
216
+ # sample_rate=44100,
217
+ # duration_sec=9.0,
218
+ # audio_samples=397312,
219
+ # start_row=0,
220
+ # end_row=None,
221
+ # save_dir="data/vggsound/video_224_latents_text/train"
222
+ # )
223
+ # dataset[0]
data_utils/v2a_utils/vggsound_text.py ADDED
@@ -0,0 +1,109 @@
1
+ import logging
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Optional, Union
5
+
6
+ import pandas as pd
7
+ import torch
8
+ import torchaudio
9
+ from torch.utils.data.dataset import Dataset
10
+ from torchvision.transforms import v2
11
+ from torio.io import StreamingMediaDecoder
12
+ from torchvision.utils import save_image
13
+
14
+ log = logging.getLogger()
15
+
16
+ _CLIP_SIZE = 384
17
+ _CLIP_FPS = 8.0
18
+
19
+ _SYNC_SIZE = 224
20
+ _SYNC_FPS = 25.0
21
+
22
+
23
+ class VGGSound(Dataset):
24
+
25
+ def __init__(
26
+ self,
27
+ root: Union[str, Path],
28
+ *,
29
+ tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv',
30
+ start_row: Optional[int] = None,
31
+ end_row: Optional[int] = None,
32
+ save_dir: str = 'data/vggsound/video_latents_text/train'
33
+ ):
34
+ self.root = Path(root)
35
+
36
+ # videos = sorted(os.listdir(self.root))
37
+ # videos = set([Path(v).stem for v in videos]) # remove extensions
38
+ videos = []
39
+ self.labels = []
40
+ self.cots = []
41
+ self.videos = []
42
+ missing_videos = []
43
+ # read the tsv for subset information
44
+ df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records')
45
+
46
+ # restrict processing to the requested row range
47
+ if start_row is not None and end_row is not None:
48
+ df_list = df_list[start_row:end_row]
49
+
50
+ for record in df_list:
51
+ id = record['id']
52
+ # if os.path.exists(f'{save_dir}/{id}.pth'):
53
+ # continue
54
+ # try:
55
+ # torch.load(f'{save_dir}/{id}.pth')
56
+ # continue
57
+ # except:
58
+ # print(f'error load file: {save_dir}/{id}.pth')
59
+ # os.system(f'rm -f {save_dir}/{id}.pth')
60
+ label = record['caption']
61
+ # if id in videos:
62
+ self.labels.append(label)
63
+ self.cots.append(record['caption_cot'])
64
+ # self.labels[id] = label
65
+ self.videos.append(id)
66
+ # else:
67
+ # missing_videos.append(id)
68
+
69
+ log.info(f'{len(videos)} videos found in {root}')
70
+ log.info(f'{len(self.videos)} videos found in {tsv_path}')
71
+ log.info(f'{len(missing_videos)} videos missing in {root}')
72
+
73
+
74
+
75
+
76
+ def sample(self, idx: int) -> dict[str, torch.Tensor]:
77
+ video_id = self.videos[idx]
78
+ label = self.labels[idx]
79
+ cot = self.cots[idx]
80
+ data = {
81
+ 'id': video_id,
82
+ 'caption': label,
83
+ 'caption_cot': cot
84
+ }
85
+
86
+ return data
87
+
88
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
89
+ try:
90
+ return self.sample(idx)
91
+ except Exception as e:
92
+ log.error(f'Error loading video {self.videos[idx]}: {e}')
93
+ return None
94
+
95
+ def __len__(self):
96
+ return len(self.labels)
97
+
98
+
99
+ # dataset = VGGSound(
100
+ # root="data/vggsound/video/test",
101
+ # tsv_path="data/vggsound/split_txt/temp.csv",
102
+ # sample_rate=44100,
103
+ # duration_sec=9.0,
104
+ # audio_samples=397312,
105
+ # start_row=0,
106
+ # end_row=None,
107
+ # save_dir="data/vggsound/video_latents_text/test"
108
+ # )
109
+ # dataset[0]
defaults.ini ADDED
@@ -0,0 +1,68 @@
1
+
2
+ [DEFAULTS]
3
+
4
+ #name of the run
5
+ name = stable_audio_tools
6
+
7
+ # the batch size
8
+ batch_size = 8
9
+ test_batch_size = 1
10
+
11
+ # predict ckpt directory
12
+ ckpt_dir = "ckpts"
13
+
14
+ # number of GPUs to use for training
15
+ num_gpus = 1
16
+
17
+ # number of nodes to use for training
18
+ num_nodes = 1
19
+
20
+ # Multi-GPU strategy for PyTorch Lightning
21
+ strategy = ""
22
+
23
+ # Precision to use for training
24
+ precision = "bf16-mixed"
25
+
26
+ # number of CPU workers for the DataLoader
27
+ num_workers = 8
28
+
29
+ # the random seed
30
+ seed = 42
31
+
32
+ # Batches for gradient accumulation
33
+ accum_batches = 1
34
+
35
+ # Number of steps between checkpoints
36
+ checkpoint_every = 2000
37
+
38
+ # trainer checkpoint file to restart training from
39
+ ckpt_path = ''
40
+
41
+ # model checkpoint file to start a new training run from
42
+ pretrained_ckpt_path = ''
43
+
44
+ # Checkpoint path for the pretransform model if needed
45
+ pretransform_ckpt_path = ''
46
+
47
+ # configuration file specifying model hyperparameters
48
+ model_config = ''
49
+
50
+ # configuration for datasets
51
+ dataset_config = ''
52
+
53
+ # directory to save the checkpoints in
54
+ save_dir = ''
55
+
56
+ # gradient_clip_val passed into PyTorch Lightning Trainer
57
+ gradient_clip_val = 0.0
58
+
59
+ # remove the weight norm from the pretransform model
60
+ remove_pretransform_weight_norm = ''
61
+
62
+ compile = False
63
+
64
+ repeat_num = 5
65
+
66
+ duration_sec = '9'
67
+
68
+ results_dir = 'results'
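+ # These values are the defaults consumed by prefigure's get_all_args(); they can typically be overridden per run from the command line.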
demo_test.csv ADDED
@@ -0,0 +1,17 @@
1
+ id,caption,caption_cot
2
+ W1nb2hIeDKc_000021,striking bowling,"Start with a background of ambient music, then add consistent sounds of bowling balls striking pins to emphasize the action. Include occasional subtle sounds of pins rattling and settling. Keep human voices or other noises minimal or absent for authenticity."
3
+ YYRdv32TJnc_000184,plastic bottle crushing,"Start with the sound of crushing plastic bottles, including crinkling and crunching. Add background noise resembling a factory environment, with machinery sounds. Incorporate subtle rustling and paper crinkling to suggest manipulation of plastic items."
4
+ Rp39_WnX5Fk_000380,"subway, metro, underground","Generate subway sounds including ambient station noise, train doors opening and closing, engine hum, wheels on tracks, and conductor announcements to produce an accurate underground train environment."
5
+ -KqXcm-I2zY_000087,playing tennis,"Generate sounds of tennis hitting a racket, the ball bouncing, and the girl’s grunts, with distant tennis court ambient noise. Avoid unrelated sounds like horses, basketballs, or indoor voices. Focus on clear tennis scene with realistic audio cues."
6
+ 0W_wPc-zV3I_000101,hedge trimmer running,"Generate the sound of a hedge trimmer running steadily, focusing on consistent motor noise and cutting sounds. Ensure minimal background noise or voices, capturing the primary sound of the trimmer in operation. Avoid including any chainsaw or unrelated sounds for accuracy."
7
+ _Betmm6FaWo_000096,writing on blackboard with chalk,"The audio should feature consistent sounds of chalk scratching the blackboard, including occasional voice instructions, encouragement, and children’s chatter, with background music playing softly or fading in/out to match the scene's atmosphere. The sounds of laughter and chatter should be lively but balanced with the primary chalk and voice sounds for clarity. Overall, the audio combines educational sounds with background activity to reflect a classroom or play environment."
8
+ xmTfE3F2huE_000854,chopping food,"Generate rhythmic chopping sounds consistent with meat or food being sliced, incorporating occasional rustling noises like a plastic bag. Avoid adding human voices or train sounds to match the correct audio descriptions, ensuring a focused, realistic kitchen chopping scene."
9
+ ZaUaqnLdg6k_000030,skateboarding,"Generate the audio featuring skateboarding sounds with wheels rolling on various surfaces, including ramps, rails, and sidewalks, capturing the sound of tricks and landings. Include subtle ambient background noise to suggest an outdoor setting, avoiding any human voices or singing. Focus on realistic skateboarding sounds, emphasizing wheel contact, impacts, and movement."
10
+ _ZC6yk5iE1I_000026,playing trumpet,"Generate a continuous trumpet sound with melodic variations, mimicking the sound of a person playing the trumpet idealy in a musical setting, ensuring clarity and realistic tone. Avoid extraneous noise or background sounds to reflect the focus on trumpet playing. The audio should resemble a skilled player producing expressive, melodious trumpet notes."
11
+ 55L7peYRB_Q_000120,using sewing machines,"Generate ambient sewing room sounds with consistent sewing machine hum, minimal background noise, and no human voices, focusing on characteristic machine noise to match the correct descriptions."
12
+ 4p8n4Zf-WMM_000190,lighting firecrackers,"Generate the sound of firecrackers lighting and exploding repeatedly, mixed with distant background sounds of crickets chirping. Incorporate occasional subtle echoes to mimic outdoor night ambiance, with no human voices present. End with a series of sharp cracker bursts to create a lively, festive atmosphere."
13
+ yLazKv68TeA_000078,people eating crisps,"Create audio with consistent crisp sounds of people eating chips, including crinkling paper and breathing. Include subtle chewing noises to match the activity. Avoid background music or voices for clarity."
14
+ _XyxrZDZ36E_000034,hammering nails,"Generate audio with consistent hammering sounds, featuring a rhythmic pattern of nails being driven into a surface, with occasional ambient background sounds like birds chirping and distant traffic. Avoid human voices, focusing on realistic hammer strikes and natural outdoor environment sounds. Ensure the hammering tone is steady and clear, matching the description of continuous nail hammering."
15
+ 1u1orBeV4xI_000428,ripping paper,"Start with a subtle tearing sound of paper being ripped, emphasizing a continuous, consistent noise. Ensure the sound has slight variations to mimic real tearing. No background or additional noises are needed, focusing solely on the tearing action."
16
+ JFG4YvcJ3bo_000228,playing bongo,"Generate a lively percussion track featuring rhythmic djembe beats, with a melodic guitar strumming softly in the background to enhance the musical atmosphere. Ensure no human voice is included, focusing on the percussive and guitar sounds. Maintain a natural, well-balanced stereo mix to highlight the instruments' interplay."
17
+ 1pViEqMXJH0_000030,printer printing,"Generate a continuous printer printing sound with periodic beeps, resembling typical printer noise, including paper movement and occasional beeps for realism. Add subtle ambient background noise, like faint room sounds, to enhance authenticity. Ensure the primary focus remains on the printing and beeping sounds, consistent with the correct audio descriptions."
examples/1.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8884c466292b46510c298a9ee88d8a584c86cb750afb558108c0850413e21e51
3
+ size 634576
examples/1_mute.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b0ca4223b15583d8099d023bac3e86725bfd5cfbbad771ef67d31c1ad953bdc3
3
+ size 482981
examples/2.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a8c7a3dd144c91d690b07892b30544c24000008873116451291f59553f3908a4
3
+ size 368050
examples/2_mute.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e09de17f3d3f631a5a7cd5dfb2b5b32a78fbcd3c1b90673dfa36ce798647f1e8
3
+ size 216098
examples/3.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9dccddd67a954b12d34d481107c499460c69bebd92913b8f092724fcdf1c5baf
3
+ size 1716778
examples/3_mute.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:103ed8d5e4fbe8d6954dde463c2a997acd0e21c13895404706dbbaab39e2b086
3
+ size 1564981
examples/4.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a3a399f94d5748b77497e28e1fca4e70c3900e1583cab4ecaefb242507b9fe1b
3
+ size 3642290
examples/4_mute.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f9aa6d9ef7523cec4f5e0087d9f6c6b86efceb694f0433da5d07bdf57eea1247
3
+ size 3490447
examples/5.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:189a6cfcac18470a3f18285013f63f46c7b5996f31e0d3ecf617d9f7f91fdfeb
3
+ size 738718
examples/5_mute.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:103c2b1517d9cbecd81b394c93ebe36f166e95f635ceee280c28073233f08173
3
+ size 586982
extract_latents.py ADDED
@@ -0,0 +1,128 @@
1
+ import argparse
2
+ import os
3
+ import torch
4
+ import torch.distributed as dist
5
+ from torch.utils.data import DataLoader
6
+ from tqdm import tqdm
7
+ import logging
8
+ from data_utils.v2a_utils.vggsound_224_no_audio import VGGSound
9
+ from data_utils.v2a_utils.feature_utils_224 import FeaturesUtils
10
+ import torchaudio
11
+ from einops import rearrange
12
+ from torch.utils.data.dataloader import default_collate
13
+ import numpy as np
14
+ from huggingface_hub import hf_hub_download
15
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
16
+
17
+ def setup(rank, world_size):
18
+ dist.init_process_group("nccl", rank=rank, world_size=world_size)
19
+ torch.cuda.set_device(rank)
20
+
21
+ def cleanup():
22
+ dist.destroy_process_group()
23
+
24
+ def error_avoidance_collate(batch):
25
+ batch = list(filter(lambda x: x is not None, batch))
26
+ return default_collate(batch)
27
+
28
+ def main(args):
29
+
30
+ print(f"Using root: {args.root}, tsv_path: {args.tsv_path}, save_dir: {args.save_dir}")
31
+ dataset = VGGSound(
32
+ root=args.root,
33
+ tsv_path=args.tsv_path,
34
+ sample_rate=args.sample_rate,
35
+ duration_sec=args.duration_sec,
36
+ audio_samples=args.audio_samples,
37
+ start_row=args.start_row,
38
+ end_row=args.end_row,
39
+ save_dir=args.save_dir
40
+ )
41
+ save_dir = args.save_dir
42
+ os.makedirs(save_dir, exist_ok=True)
43
+
44
+ dataloader = DataLoader(dataset, batch_size=2, num_workers=8, drop_last=False,collate_fn=error_avoidance_collate)
45
+
46
+ print(f"Dataset length: {len(dataset)}")
47
+ feature_extractor = FeaturesUtils(
48
+ vae_ckpt=None,
49
+ vae_config=args.vae_config,
50
+ enable_conditions=True,
51
+ synchformer_ckpt=args.synchformer_ckpt
52
+ ).eval().cuda()
53
+
54
+ feature_extractor = feature_extractor
55
+
56
+ for i, data in enumerate(tqdm(dataloader, desc="Processing", unit="batch")):
57
+ ids = data['id']
58
+ with torch.no_grad():
59
+ # audio = data['audio'].cuda(rank, non_blocking=True)
60
+ output = {
61
+ 'caption': str(data['caption']),
62
+ 'caption_cot': str(data['caption_cot'])
63
+ }
64
+ print(output)
65
+
66
+ # latent = feature_extractor.module.encode_audio(audio)
67
+ # output['latent'] = latent.detach().cpu()
68
+
69
+ clip_video = data['clip_video'].cuda()
70
+ clip_features = feature_extractor.encode_video_with_clip(clip_video)
71
+ output['metaclip_features'] = clip_features.detach().cpu()
72
+
73
+ sync_video = data['sync_video'].cuda()
74
+ sync_features = feature_extractor.encode_video_with_sync(sync_video)
75
+ output['sync_features'] = sync_features.detach().cpu()
76
+
77
+ caption = data['caption']
78
+ metaclip_global_text_features, metaclip_text_features = feature_extractor.encode_text(caption)
79
+ output['metaclip_global_text_features'] = metaclip_global_text_features.detach().cpu()
80
+ output['metaclip_text_features'] = metaclip_text_features.detach().cpu()
81
+
82
+ caption_cot = data['caption_cot']
83
+ t5_features = feature_extractor.encode_t5_text(caption_cot)
84
+ output['t5_features'] = t5_features.detach().cpu()
85
+
86
+ for j in range(len(ids)):
87
+ sample_output = {
88
+ 'id': ids[j],
89
+ 'caption': output['caption'][j],
90
+ 'caption_cot': output['caption_cot'][j],
91
+ # 'latent': output['latent'][j],
92
+ 'metaclip_features': output['metaclip_features'][j],
93
+ 'sync_features': output['sync_features'][j],
94
+ 'metaclip_global_text_features': output['metaclip_global_text_features'][j],
95
+ 'metaclip_text_features': output['metaclip_text_features'][j],
96
+ 't5_features': output['t5_features'][j],
97
+ }
98
+ # torch.save(sample_output, f'{save_dir}/{ids[j]}.pth')
99
+ np.savez(f'{save_dir}/demo.npz', **sample_output)
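+ # NOTE: every sample is written to the same demo.npz, so only the last processed sample is kept (the per-id torch.save above is commented out)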
100
+
101
+ ## test the sync between videos and audios
102
+ # torchaudio.save(f'input_{i}.wav',data['audio'],sample_rate=44100)
103
+ # recon_audio = feature_extractor.decode_audio(latent)
104
+ # recon_audio = rearrange(recon_audio, "b d n -> d (b n)")
105
+ # id = data['id']
106
+ # torchaudio.save(f'recon_{i}.wav',recon_audio.cpu(),sample_rate=44100)
107
+ # os.system(f'ffmpeg -y -i dataset/vggsound/video/train/{id}.mp4 -i recon_{i}.wav -t 9 -map 0:v -map 1:a -c:v copy -c:a aac -strict experimental -shortest out_{i}.mp4')
108
+ # os.system(f'ffmpeg -y -i dataset/vggsound/video/train/{id}.mp4 -i input_{i}.wav -t 9 -map 0:v -map 1:a -c:v copy -c:a aac -strict experimental -shortest input_{i}.mp4')
109
+
110
+
111
+ if __name__ == '__main__':
112
+ parser = argparse.ArgumentParser(description='Extract Video Training Latents')
113
+ parser.add_argument('--root', type=str, default='videos', help='Root directory of the video dataset')
114
+ parser.add_argument('--tsv_path', type=str, default='cot_coarse/cot.csv', help='Path to the TSV file')
115
+ parser.add_argument('--save-dir', type=str, default='results', help='Save Directory')
116
+ parser.add_argument('--sample_rate', type=int, default=44100, help='Sample rate of the audio')
117
+ parser.add_argument('--duration_sec', type=float, default=9.0, help='Duration of the audio in seconds')
118
+ parser.add_argument('--vae_ckpt', type=str, default='ckpts/vae.ckpt', help='Path to the VAE checkpoint')
119
+ parser.add_argument('--vae_config', type=str, default='ThinkSound/configs/model_configs/stable_audio_2_0_vae.json', help='Path to the VAE configuration file')
120
+ parser.add_argument('--synchformer_ckpt', type=str, default='ckpts/synchformer_state_dict.pth', help='Path to the Synchformer checkpoint')
121
+ parser.add_argument('--start-row', type=int, default=0, help='start row')
122
+ parser.add_argument('--end-row', type=int, default=None, help='end row')
123
+
124
+ args = parser.parse_args()
125
+ args.audio_samples = int(args.sample_rate * args.duration_sec)
126
+
127
+ main(args=args)
128
+
predict.py ADDED
@@ -0,0 +1,214 @@
1
+ from prefigure.prefigure import get_all_args, push_wandb_config
2
+ import json
3
+ import os
4
+ import re
5
+ import torch
6
+ import torchaudio
7
+ # import pytorch_lightning as pl
8
+ import lightning as L
9
+ from lightning.pytorch.callbacks import Timer, ModelCheckpoint, BasePredictionWriter
10
+ from lightning.pytorch.callbacks import Callback
11
+ from lightning.pytorch.tuner import Tuner
12
+ from lightning.pytorch import seed_everything
13
+ import random
14
+ from datetime import datetime
15
+
16
+ from ThinkSound.data.datamodule import DataModule
17
+ from ThinkSound.models import create_model_from_config
18
+ from ThinkSound.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model
19
+ from ThinkSound.training import create_training_wrapper_from_config, create_demo_callback_from_config
20
+ from ThinkSound.training.utils import copy_state_dict
21
+ from huggingface_hub import hf_hub_download
22
+
23
+ class ExceptionCallback(Callback):
24
+ def on_exception(self, trainer, module, err):
25
+ print(f'{type(err).__name__}: {err}')
26
+
27
+ class ModelConfigEmbedderCallback(Callback):
28
+ def __init__(self, model_config):
29
+ self.model_config = model_config
30
+
31
+ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
32
+ checkpoint["model_config"] = self.model_config
33
+
34
+ class CustomWriter(BasePredictionWriter):
35
+
36
+ def __init__(self, output_dir, write_interval='batch', batch_size=32):
37
+ super().__init__(write_interval)
38
+ self.output_dir = output_dir
39
+ self.batch_size = batch_size
40
+
41
+ def write_on_batch_end(self, trainer, pl_module, predictions, batch_indices, batch, batch_idx, dataloader_idx):
42
+
43
+ audios = predictions
44
+ ids = [item['id'] for item in batch[1]]
45
+ current_date = datetime.now()
46
+
47
+ formatted_date = current_date.strftime('%m%d')
48
+ os.makedirs(os.path.join(self.output_dir, f'{formatted_date}_batch_size{self.batch_size}'),exist_ok=True)
49
+ for audio, id in zip(audios, ids):
50
+ save_path = os.path.join(self.output_dir, f'{formatted_date}_batch_size{self.batch_size}', f'{id}.wav')
51
+ torchaudio.save(save_path, audio, 44100)
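+ # each prediction is saved as a 44.1 kHz wav named {id}.wav inside a dated '{MMDD}_batch_size{N}' folder under output_dir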
52
+
53
+ def main():
54
+
55
+ args = get_all_args()
56
+
57
+
58
+ # args.pretransform_ckpt_path = hf_hub_download(
59
+ # repo_id="liuhuadai/ThinkSound",
60
+ # filename="vae.ckpt"
61
+ # )
62
+
63
+ args.pretransform_ckpt_path = "./ckpts/vae.ckpt"
64
+
65
+
66
+ seed = 10086
67
+
68
+ # Set a different seed for each process if using SLURM
69
+ if os.environ.get("SLURM_PROCID") is not None:
70
+ seed += int(os.environ.get("SLURM_PROCID"))
71
+
72
+ # random.seed(seed)
73
+ # torch.manual_seed(seed)
74
+ seed_everything(seed, workers=True)
75
+
76
+ #Get JSON config from args.model_config
77
+ with open(args.model_config) as f:
78
+ model_config = json.load(f)
79
+
80
+ with open(args.dataset_config) as f:
81
+ dataset_config = json.load(f)
82
+
83
+ for td in dataset_config["test_datasets"]:
84
+ td["path"] = args.results_dir
85
+
86
+ # train_dl = create_dataloader_from_config(
87
+ # dataset_config,
88
+ # batch_size=args.batch_size,
89
+ # num_workers=args.num_workers,
90
+ # sample_rate=model_config["sample_rate"],
91
+ # sample_size=model_config["sample_size"],
92
+ # audio_channels=model_config.get("audio_channels", 2),
93
+ # )
94
+
95
+
96
+ duration=(float)(args.duration_sec)
97
+
98
+ dm = DataModule(
99
+ dataset_config,
100
+ batch_size=args.batch_size,
101
+ test_batch_size=args.test_batch_size,
102
+ num_workers=args.num_workers,
103
+ sample_rate=model_config["sample_rate"],
104
+ sample_size=(float)(args.duration_sec) * model_config["sample_rate"],
105
+ audio_channels=model_config.get("audio_channels", 2),
106
+ latent_length=round(44100/64/32*duration),
107
+ )
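+ # latent_length above: 44100/(64*32) is roughly 21.5 latent frames per second of audio, i.e. 194 frames for the default 9 s clip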
108
+
109
+ model_config["sample_size"] = duration * model_config["sample_rate"]
110
+ model_config["model"]["diffusion"]["config"]["sync_seq_len"] = 24*int(duration)
111
+ model_config["model"]["diffusion"]["config"]["clip_seq_len"] = 8*int(duration)
112
+ model_config["model"]["diffusion"]["config"]["latent_seq_len"] = round(44100/64/32*duration)
113
+
114
+ model = create_model_from_config(model_config)
115
+
116
+ ## speed by torch.compile
117
+ if args.compile:
118
+ model = torch.compile(model)
119
+
120
+ if args.pretrained_ckpt_path:
121
+ copy_state_dict(model, load_ckpt_state_dict(args.pretrained_ckpt_path,prefix='diffusion.')) # autoencoder. diffusion.
122
+
123
+ if args.remove_pretransform_weight_norm == "pre_load":
124
+ remove_weight_norm_from_model(model.pretransform)
125
+ # import ipdb
126
+ # ipdb.set_trace()
127
+ if args.pretransform_ckpt_path:
128
+ load_vae_state = load_ckpt_state_dict(args.pretransform_ckpt_path, prefix='autoencoder.')
129
+ # new_state_dict = {k.replace("autoencoder.", ""): v for k, v in load_vae_state.items() if k.startswith("autoencoder.")}
130
+ model.pretransform.load_state_dict(load_vae_state)
131
+
132
+ # Remove weight_norm from the pretransform if specified
133
+ if args.remove_pretransform_weight_norm == "post_load":
134
+ remove_weight_norm_from_model(model.pretransform)
135
+
136
+ training_wrapper = create_training_wrapper_from_config(model_config, model)
137
+
138
+ # wandb_logger = L.pytorch.loggers.WandbLogger(project=args.name)
139
+ # wandb_logger.watch(training_wrapper)
140
+
141
+ exc_callback = ExceptionCallback()
142
+
143
+ # if args.save_dir and isinstance(wandb_logger.experiment.id, str):
144
+ # checkpoint_dir = os.path.join(args.save_dir, wandb_logger.experiment.project, wandb_logger.experiment.id, "checkpoints")
145
+ # else:
146
+ # checkpoint_dir = None
147
+
148
+ # ckpt_callback = ModelCheckpoint(every_n_train_steps=args.checkpoint_every, dirpath=checkpoint_dir, monitor='val_loss', mode='min', save_top_k=10)
149
+ save_model_config_callback = ModelConfigEmbedderCallback(model_config)
150
+ audio_dir = args.results_dir
151
+ pred_writer = CustomWriter(output_dir=audio_dir, write_interval="batch", batch_size=args.test_batch_size)
152
+ timer = Timer(duration="00:15:00:00")
153
+ demo_callback = create_demo_callback_from_config(model_config, demo_dl=dm)
154
+
155
+ #Combine args and config dicts
156
+ args_dict = vars(args)
157
+ args_dict.update({"model_config": model_config})
158
+ args_dict.update({"dataset_config": dataset_config})
159
+ # push_wandb_config(wandb_logger, args_dict)
160
+
161
+ #Set multi-GPU strategy if specified
162
+ if args.strategy:
163
+ if args.strategy == "deepspeed":
164
+ from pytorch_lightning.strategies import DeepSpeedStrategy
165
+ strategy = DeepSpeedStrategy(stage=2,
166
+ contiguous_gradients=True,
167
+ overlap_comm=True,
168
+ reduce_scatter=True,
169
+ reduce_bucket_size=5e8,
170
+ allgather_bucket_size=5e8,
171
+ load_full_weights=True
172
+ )
173
+ else:
174
+ strategy = args.strategy
175
+ else:
176
+ strategy = 'ddp_find_unused_parameters_true' if args.num_gpus > 1 else "auto"
177
+
178
+ trainer = L.Trainer(
179
+ devices=args.num_gpus,
180
+ accelerator="gpu",
181
+ num_nodes = args.num_nodes,
182
+ strategy=strategy,
183
+ precision=args.precision,
184
+ accumulate_grad_batches=args.accum_batches,
185
+ callbacks=[demo_callback, exc_callback, save_model_config_callback, timer, pred_writer],
186
+ log_every_n_steps=1,
187
+ max_epochs=1000,
188
+ default_root_dir=args.save_dir,
189
+ gradient_clip_val=args.gradient_clip_val,
190
+ reload_dataloaders_every_n_epochs = 0,
191
+ check_val_every_n_epoch=2,
192
+ )
193
+
194
+
195
+
196
+ # ckpt_path = hf_hub_download(
197
+ # repo_id="liuhuadai/ThinkSound",
198
+ # filename="thinksound.ckpt"
199
+ # )
200
+ ckpt_path = 'ckpts/thinksound.ckpt'
201
+
202
+
203
+
204
+ current_date = datetime.now()
205
+ formatted_date = current_date.strftime('%m%d')
206
+
207
+ audio_dir = f'{formatted_date}_step68k_batch_size'+str(args.test_batch_size)
208
+ metrics_path = os.path.join(args.ckpt_dir, 'audios',audio_dir,'cache',"output_metrics.json")
209
+ # if os.path.exists(metrics_path): continue
210
+
211
+ trainer.predict(training_wrapper, dm, return_predictions=False,ckpt_path=ckpt_path)
212
+
213
+ if __name__ == '__main__':
214
+ main()
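
For reference, the duration-dependent sequence lengths set in main() above are plain arithmetic on the clip length. A minimal sketch (illustrative only; the 9-second duration and the helper name are assumptions, not part of the commit):

    def sequence_lengths(duration_sec: float) -> dict:
        # Mirrors the assignments in main(): 44100 / 64 / 32 ~= 21.53 latent frames per second
        latent_rate = 44100 / 64 / 32
        return {
            "latent_seq_len": round(latent_rate * duration_sec),  # 194 for 9.0 s
            "clip_seq_len": 8 * int(duration_sec),                # 72 for 9.0 s
            "sync_seq_len": 24 * int(duration_sec),               # 216 for 9.0 s
        }

    print(sequence_lengths(9.0))
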
pyproject.toml ADDED
@@ -0,0 +1,3 @@
+ [build-system]
+ requires = ["setuptools"]
+ build-backend = "setuptools.build_meta"
requirements.txt ADDED
@@ -0,0 +1,254 @@
+ modelscope
+ absl-py==2.2.2
+ accelerate==1.6.0
+ aeiou==0.0.20
+ aiobotocore==2.22.0
+ aiofiles==23.2.1
+ aiohappyeyeballs==2.6.1
+ aiohttp==3.11.18
+ aioitertools==0.12.0
+ aiosignal==1.3.2
+ alias-free-torch==0.0.6
+ annotated-types==0.7.0
+ antlr4-python3-runtime==4.9.3
+ anyio==4.9.0
+ appdirs==1.4.4
+ argbind==0.3.9
+ asttokens==3.0.0
+ async-timeout==5.0.1
+ attrs==25.3.0
+ audiobox_aesthetics==0.0.2
+ audioread==3.0.1
+ auraloss==0.4.0
+ av==14.4.0
+ bleach==6.2.0
+ bokeh==3.7.3
+ botocore==1.37.3
+ braceexpand==0.1.7
+ Brotli==1.1.0
+ certifi==2025.4.26
+ cffi==1.17.1
+ charset-normalizer==3.4.1
+ clean-fid==0.1.35
+ click==8.1.8
+ clip-anytorch==2.6.0
+ cloudpickle==3.1.1
+ colorcet==3.1.0
+ colorlog==6.9.0
+ configparser==7.2.0
+ contourpy==1.3.2
+ cycler==0.12.1
+ Cython==3.1.1
+ dctorch==0.1.2
+ decorator==4.4.2
+ decord==0.6.0
+ descript-audio-codec==1.0.0
+ docker-pycreds==0.4.0
+ docstring_parser==0.16
+ einops==0.7.0
+ einops-exts==0.0.4
+ ema-pytorch==0.2.3
+ encodec==0.1.1
+ exceptiongroup==1.2.2
+ executing==2.2.0
+ fastapi==0.115.12
+ fastcore==1.8.2
+ ffmpeg==1.4
+ ffmpy==0.5.0
+ filelock==3.18.0
+ fire==0.7.0
+ flatten-dict==0.4.2
+ fonttools==4.58.0
+ frozenlist==1.6.0
+ fsspec==2025.5.0
+ ftfy==6.3.1
+ future==1.0.0
+ fvcore==0.1.5.post20221221
+ gin-config==0.5.0
+ gitdb==4.0.12
+ GitPython==3.1.44
+ gradio==3.50.0
+ gradio_client==0.6.1
+ groovy==0.1.2
+ grpcio==1.71.0
+ h11==0.16.0
+ hf_xet
+ h5py==3.13.0
+ hjson==3.1.0
+ holoviews==1.20.2
+ httpcore==1.0.9
+ httpx==0.28.1
+ huggingface-hub==0.30.2
+ hydra-colorlog==1.2.0
+ hydra-core==1.3.2
+ idna==3.10
+ imageio==2.37.0
+ imageio-ffmpeg==0.4.9
+ importlib-resources==5.12.0
+ importlib_metadata==8.7.0
+ iopath==0.1.10
+ ipython==8.36.0
+ jedi==0.19.2
+ Jinja2==3.1.0
+ jmespath==1.0.1
+ joblib==1.5.0
+ jsonmerge==1.9.2
+ jsonschema==4.23.0
+ jsonschema-specifications==2025.4.1
+ julius==0.2.7
+ k-diffusion==0.1.1
+ kiwisolver==1.4.8
+ kornia==0.8.1
+ kornia_rs==0.1.9
+ laion-clap==1.1.4
+ latex2mathml==3.77.0
+ lazy_loader==0.4
+ librosa==0.9.2
+ lightning==2.5.1.post0
+ lightning-utilities==0.14.3
+ linkify-it-py==2.0.3
+ llvmlite==0.43.0
+ local-attention==1.8.6
+ Markdown==3.8
+ markdown-it-py==3.0.0
+ markdown2==2.5.3
+ MarkupSafe==2.1.5
+ matplotlib==3.10.3
+ matplotlib-inline==0.1.7
+ mdit-py-plugins==0.4.2
+ mdurl==0.1.2
+ moviepy==1.0.3
+ mpmath==1.3.0
+ multidict==6.4.4
+ multiprocessing-logging==0.2.4
+ mutagen==1.47.0
+ narwhals==1.40.0
+ networkx==3.4.2
+ ninja==1.11.1.3
+ nitrous_ema==0.0.1
+ numba==0.60.0
+ numpy==1.23.5
+ omegaconf==2.3.0
+ open_clip_torch==2.32.0
+ openai==1.33.0
+ opencv-python==4.11.0.86
+ orjson==3.10.18
+ pafy==0.5.3.1
+ pandas==2.0.2
+ panel==1.7.0
+ param==2.2.0
+ parameterized==0.9.0
+ parso==0.8.4
+ pathtools==0.1.2
+ pedalboard==0.7.4
+ pexpect==4.9.0
+ pillow
+ platformdirs==4.3.8
+ plotly==6.1.1
+ pooch==1.8.2
+ prefigure==0.0.9
+ proglog==0.1.10
+ progressbar==2.5
+ prompt_toolkit==3.0.51
+ propcache==0.3.1
+ protobuf==3.19.6
+ psutil==7.0.0
+ ptyprocess==0.7.0
+ pure_eval==0.2.3
+ py-cpuinfo==9.0.0
+ pycparser==2.22
+ pydantic==2.11.5
+ pydantic_core==2.33.2
+ pydub==0.25.1
+ Pygments==2.19.1
+ pyloudnorm==0.1.1
+ pynndescent==0.5.13
+ pynvml==12.0.0
+ pyparsing==3.2.3
+ pystoi==0.4.1
+ pysubs2==1.8.0
+ python-dateutil==2.9.0.post0
+ python-dotenv==1.0.1
+ python-multipart==0.0.20
+ pytorch-lightning==2.5.1.post0
+ pytorchvideo==0.1.5
+ pytz==2025.2
+ pyviz_comms==3.0.4
+ PyWavelets==1.4.1
+ PyYAML==6.0.2
+ randomname==0.2.1
+ referencing==0.36.2
+ regex==2024.11.6
+ requests==2.32.3
+ resampy==0.4.3
+ rich==14.0.0
+ rpds-py==0.25.1
+ ruff==0.11.11
+ s3fs==2025.5.0
+ safehttpx==0.1.6
+ safetensors==0.5.3
+ scenedetect==0.6.3
+ scikit-image==0.24.0
+ scikit-learn==1.6.1
+ scipy==1.15.3
+ semantic-version==2.10.0
+ sentencepiece==0.1.99
+ sentry-sdk==2.29.1
+ setproctitle==1.3.6
+ shellingham==1.5.4
+ shortuuid==1.0.13
+ six==1.17.0
+ smmap==5.0.2
+ sniffio==1.3.1
+ SoundFile==0.10.2
+ sox==1.3.0
+ stack-data==0.6.3
+ starlette==0.46.2
+ submitit==1.5.2
+ svgwrite==1.4.3
+ sympy==1.13.1
+ tabulate==0.9.0
+ tensorboard-data-server==0.7.2
+ termcolor==3.1.0
+ threadpoolctl==3.6.0
+ tifffile==2025.5.10
+ timm==1.0.15
+ tokenizers==0.19
+ tomlkit==0.13.2
+ torch==2.4.0
+ torch-stoi==0.2.3
+ torchaudio==2.4.0
+ torchdiffeq==0.2.5
+ torchlibrosa==0.1.0
+ torchmetrics==0.11.4
+ torchsde==0.2.6
+ torchvision==0.19.0
+ tornado==6.5.1
+ git+https://github.com/patrick-kidger/torchcubicspline.git
+ tqdm==4.67.1
+ traitlets==5.14.3
+ trampoline==0.1.2
+ transformers==4.43
+ triton==3.0.0
+ typer==0.15.4
+ typing-inspection==0.4.1
+ typing_extensions==4.12.2
+ tzdata==2025.2
+ uc-micro-py==1.0.3
+ umap-learn==0.5.7
+ urllib3==2.4.0
+ uvicorn==0.34.2
+ v-diffusion-pytorch==0.0.2
+ vector-quantize-pytorch==1.9.14
+ wcwidth==0.2.13
+ webdataset==0.2.48
+ webencodings==0.5.1
+ Werkzeug==3.1.3
+ wget==3.2
+ wrapt==1.17.2
+ x-transformers==1.26.6
+ xyzservices==2025.4.0
+ yacs==0.1.8
+ yarl==1.20.0
+ zipp==3.21.0
+ altair==5.5.0
scripts/demo.sh ADDED
@@ -0,0 +1,82 @@
+ #!/bin/bash
+
+ # Check number of arguments
+ if [ "$#" -ne 3 ]; then
+     echo "Usage: $0 <video_path> <title> <description>"
+     exit 1
+ fi
+
+ VIDEO_PATH="$1"
+ TITLE="$2"
+ DESCRIPTION="$3"
+
+ # Generate unique ID
+ UNIQUE_ID=$(uuidgen | cut -c 1-8)
+
+ # Create necessary directories
+ mkdir -p videos cot_coarse results
+
+ # Get video filename and extension
+ VIDEO_FILE=$(basename "$VIDEO_PATH")
+ VIDEO_EXT="${VIDEO_FILE##*.}"
+ VIDEO_ID="${VIDEO_FILE%.*}"
+ TEMP_VIDEO_PATH="videos/${VIDEO_ID}_${UNIQUE_ID}.mp4"
+
+ # Convert video to MP4 format if needed
+ if [ "${VIDEO_EXT,,}" != "mp4" ]; then
+     echo "⏳ Converting video to MP4 format..."
+     ffmpeg -y -i "$VIDEO_PATH" -c:v libx264 -preset fast -c:a aac -strict experimental "$TEMP_VIDEO_PATH" >/dev/null 2>&1
+     if [ $? -ne 0 ]; then
+         echo "❌ Video conversion failed"
+         exit 2
+     fi
+ else
+     cp "$VIDEO_PATH" "$TEMP_VIDEO_PATH"
+ fi
+
+ # Calculate video duration
+ DURATION=$(ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 "$TEMP_VIDEO_PATH")
+ DURATION_SEC=${DURATION%.*}
+ echo "Duration is: $DURATION_SEC"
+
+ # Create cot.csv file
+ CAPTION_COT=$(echo "$DESCRIPTION" | tr '"' "'")
+ CSV_PATH="cot_coarse/cot.csv"
+ echo "id,caption,caption_cot" > "$CSV_PATH"
+ echo "${VIDEO_ID}_${UNIQUE_ID},$TITLE,\"$CAPTION_COT\"" >> "$CSV_PATH"
+
+ # Run feature extraction
+ echo "⏳ Extracting features..."
+ python extract_latents.py --duration_sec "$DURATION_SEC" 2>&1
+ if [ $? -ne 0 ]; then
+     echo "❌ Feature extraction failed"
+     rm -f "$TEMP_VIDEO_PATH"
+     exit 3
+ fi
+
+ # Run inference
+ echo "⏳ Running model inference..."
+ bash scripts/infer.sh --duration-sec "$DURATION_SEC" 2>&1
+ if [ $? -ne 0 ]; then
+     echo "❌ Inference failed"
+     rm -f "$TEMP_VIDEO_PATH"
+     exit 4
+ fi
+
+ # Get generated audio file
+ CURRENT_DATE=$(date +"%m%d")
+ AUDIO_PATH="results/${CURRENT_DATE}_batch_size1/demo.wav"
+
+ # Check if audio file exists
+ if [ ! -f "$AUDIO_PATH" ]; then
+     echo "❌ Generated audio file not found"
+     rm -f "$TEMP_VIDEO_PATH"
+     exit 5
+ fi
+
+ # Clean up temporary video file
+ rm -f "$TEMP_VIDEO_PATH"
+
+ echo "✅ Audio generated successfully!"
+ echo "Audio file path: $AUDIO_PATH"
scripts/infer.sh ADDED
@@ -0,0 +1,76 @@
+ #!/bin/bash
+
+ # Variable definitions
+ ckpt_dir="ckpts/thinksound.ckpt"
+ test_batch_size=1
+ dataset_config="ThinkSound/configs/multimodal_dataset_demo.json"
+ model_config="ThinkSound/configs/model_configs/thinksound.json"
+ pretransform_ckpt_path="ckpts/vae.ckpt"
+ # Default values
+ debug_mode="true"
+ node_rank=0
+
+ result_path="results"
+
+ while [[ $# -gt 0 ]]; do
+     case "$1" in
+         --duration-sec)
+             if [[ -n "$2" && "$2" != --* ]]; then
+                 duration_sec="$2"
+                 shift 2
+             else
+                 echo "❌ Argument --duration-sec requires a value"
+                 exit 1
+             fi
+             ;;
+         --result-path)
+             if [[ -n "$2" && "$2" != --* ]]; then
+                 result_path="$2"
+                 shift 2
+             else
+                 echo "❌ Argument --result-path requires a path"
+                 exit 1
+             fi
+             ;;
+         *)
+             echo "❌ Unknown argument: $1"
+             exit 1
+             ;;
+     esac
+ done
+
+ export NODE_RANK=$node_rank
+ export RANK=$node_rank
+
+ num_gpus=1
+ num_nodes=1
+
+ export WORLD_SIZE=$((num_gpus * num_nodes))
+ # Print the configuration
+ echo "Inference Configuration:"
+ echo "Checkpoint Directory: $ckpt_dir"
+ echo "Dataset Config: $dataset_config"
+ echo "Model Config: $model_config"
+ echo "Pretransform Checkpoint Path: $pretransform_ckpt_path"
+ echo "Num GPUs: $num_gpus"
+ echo "Num Nodes: $num_nodes"
+ echo "Test Batch Size: $test_batch_size"
+ echo "Num Workers: 32"
+ echo "Node Rank: $node_rank"
+ echo "WORLD SIZE: $WORLD_SIZE"
+
+ python predict.py \
+     --dataset-config "$dataset_config" \
+     --model-config "$model_config" \
+     --ckpt-dir "$ckpt_dir" \
+     --pretransform-ckpt-path "$pretransform_ckpt_path" \
+     --checkpoint-every 2000 \
+     --num-gpus "$num_gpus" \
+     --num-nodes "$num_nodes" \
+     --batch-size 1 \
+     --test-batch-size $test_batch_size \
+     --num-workers 32 \
+     --duration-sec $duration_sec \
+     --results-dir $result_path
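
Note that the generated audio lands under the --results-dir passed to predict.py above: CustomWriter names the output folder from the current date and the test batch size. A minimal sketch of reconstructing that path for this script's defaults (results directory "results", test batch size 1; the clip id is a placeholder):

    import os
    from datetime import datetime

    results_dir = "results"  # default result_path in this script
    batch_folder = f"{datetime.now().strftime('%m%d')}_batch_size1"
    clip_id = "example_clip"  # placeholder for the id written to cot_coarse/cot.csv
    print(os.path.join(results_dir, batch_folder, f"{clip_id}.wav"))
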
setup.py ADDED
@@ -0,0 +1,44 @@
+ from setuptools import setup, find_packages
+
+ setup(
+     name='thinksound',
+     version='0.0.16',
+     url='https://github.com/liuhuadai/thinksound.git',
+     author='liuhuadai',
+     description='a unified Any2Audio generation framework guided by Chain-of-Thought (CoT) reasoning',
+     packages=find_packages(),
+     install_requires=[
+         'aeiou==0.0.20',
+         'alias-free-torch==0.0.6',
+         'auraloss==0.4.0',
+         'descript-audio-codec==1.0.0',
+         'einops==0.7.0',
+         'einops-exts==0.0.4',
+         'ema-pytorch==0.2.3',
+         'encodec==0.1.1',
+         # 'gradio>=3.42.0',
+         'huggingface_hub',
+         'importlib-resources==5.12.0',
+         'k-diffusion==0.1.1',
+         'laion-clap==1.1.4',
+         'local-attention==1.8.6',
+         'pandas==2.0.2',
+         'pedalboard==0.7.4',
+         'prefigure==0.0.9',
+         'pytorch_lightning==2.1.0',
+         'PyWavelets==1.4.1',
+         'safetensors',
+         'sentencepiece==0.1.99',
+         's3fs',
+         'torch>=2.0.1',
+         'torchaudio>=2.0.2',
+         'torchmetrics==0.11.4',
+         'tqdm',
+         'transformers',
+         'v-diffusion-pytorch==0.0.2',
+         'vector-quantize-pytorch==1.9.14',
+         'wandb==0.15.4',
+         'webdataset==0.2.48',
+         'x-transformers<1.27.0'
+     ],
+ )