Spaces:
Running
on
Zero
Running
on
Zero
Implement core functionality for ThinkSound audio generation app, including video processing, audio synthesis, and Gradio interface setup. Update README with new title and emoji.
Browse files- README.md +5 -7
- ThinkSound +1 -0
- app.py +340 -4
- cot_vgg_demo_caption.txt +1 -0
- data_utils/ext/synchformer/LICENSE +21 -0
- data_utils/ext/synchformer/__init__.py +1 -0
- data_utils/ext/synchformer/divided_224_16x4.yaml +84 -0
- data_utils/ext/synchformer/motionformer.py +400 -0
- data_utils/ext/synchformer/synchformer.py +55 -0
- data_utils/ext/synchformer/utils.py +92 -0
- data_utils/ext/synchformer/video_model_builder.py +277 -0
- data_utils/ext/synchformer/vit_helper.py +399 -0
- data_utils/v2a_utils/__init__.py +0 -0
- data_utils/v2a_utils/audio_text_dataset.py +173 -0
- data_utils/v2a_utils/audioset_224.py +315 -0
- data_utils/v2a_utils/audioset_video_224.py +268 -0
- data_utils/v2a_utils/feature_utils_224.py +182 -0
- data_utils/v2a_utils/vggsound.py +259 -0
- data_utils/v2a_utils/vggsound_224.py +320 -0
- data_utils/v2a_utils/vggsound_224_no_audio.py +275 -0
- data_utils/v2a_utils/vggsound_224_no_sync.py +223 -0
- data_utils/v2a_utils/vggsound_text.py +109 -0
- defaults.ini +68 -0
- demo_test.csv +17 -0
- examples/1.mp4 +3 -0
- examples/1_mute.mp4 +3 -0
- examples/2.mp4 +3 -0
- examples/2_mute.mp4 +3 -0
- examples/3.mp4 +3 -0
- examples/3_mute.mp4 +3 -0
- examples/4.mp4 +3 -0
- examples/4_mute.mp4 +3 -0
- examples/5.mp4 +3 -0
- examples/5_mute.mp4 +3 -0
- extract_latents.py +128 -0
- predict.py +214 -0
- pyproject.toml +3 -0
- requirements.txt +254 -0
- scripts/demo.sh +82 -0
- scripts/infer.sh +76 -0
- setup.py +44 -0
README.md
CHANGED
@@ -1,13 +1,11 @@
|
|
1 |
---
|
2 |
-
title: ThinkSound
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
colorTo: indigo
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 5.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: apache-2.0
|
11 |
-
---
|
12 |
-
|
13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: ThinkSound
|
3 |
+
emoji: 🔊
|
4 |
+
colorFrom: blue
|
5 |
colorTo: indigo
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 5.35.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: apache-2.0
|
11 |
+
---
|
|
|
|
ThinkSound
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Subproject commit 9eeb443af2ab75afc046e27494c522dfd87fa2c1
|
app.py
CHANGED
@@ -1,7 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
def greet(name):
|
4 |
-
return "Hello " + name + "!!"
|
5 |
|
6 |
-
demo = gr.Interface(fn=greet, inputs="text", outputs="text")
|
7 |
-
demo.launch()
|
|
|
1 |
+
from prefigure.prefigure import get_all_args, push_wandb_config
|
2 |
+
import spaces
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
os.environ["GRADIO_TEMP_DIR"] = "./.gradio_tmp"
|
6 |
+
import re
|
7 |
+
import torch
|
8 |
+
import torchaudio
|
9 |
+
# import pytorch_lightning as pl
|
10 |
+
import lightning as L
|
11 |
+
from lightning.pytorch.callbacks import Timer, ModelCheckpoint, BasePredictionWriter
|
12 |
+
from lightning.pytorch.callbacks import Callback
|
13 |
+
from lightning.pytorch.tuner import Tuner
|
14 |
+
from lightning.pytorch import seed_everything
|
15 |
+
import random
|
16 |
+
from datetime import datetime
|
17 |
+
from ThinkSound.data.datamodule import DataModule
|
18 |
+
from ThinkSound.models import create_model_from_config
|
19 |
+
from ThinkSound.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model
|
20 |
+
from ThinkSound.training import create_training_wrapper_from_config, create_demo_callback_from_config
|
21 |
+
from ThinkSound.training.utils import copy_state_dict
|
22 |
+
from ThinkSound.inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler
|
23 |
+
from data_utils.v2a_utils.feature_utils_224 import FeaturesUtils
|
24 |
+
from torch.utils.data import Dataset
|
25 |
+
from typing import Optional, Union
|
26 |
+
from torchvision.transforms import v2
|
27 |
+
from torio.io import StreamingMediaDecoder
|
28 |
+
from torchvision.utils import save_image
|
29 |
+
from transformers import AutoProcessor
|
30 |
+
import torch.nn.functional as F
|
31 |
import gradio as gr
|
32 |
+
import tempfile
|
33 |
+
import subprocess
|
34 |
+
from huggingface_hub import hf_hub_download
|
35 |
+
from moviepy.editor import VideoFileClip
|
36 |
+
# os.system("conda install -c conda-forge 'ffmpeg<7'")
|
37 |
+
|
38 |
+
_CLIP_SIZE = 224
|
39 |
+
_CLIP_FPS = 8.0
|
40 |
+
|
41 |
+
_SYNC_SIZE = 224
|
42 |
+
_SYNC_FPS = 25.0
|
43 |
+
|
44 |
+
def pad_to_square(video_tensor):
|
45 |
+
if len(video_tensor.shape) != 4:
|
46 |
+
raise ValueError("Input tensor must have shape (l, c, h, w)")
|
47 |
+
|
48 |
+
l, c, h, w = video_tensor.shape
|
49 |
+
max_side = max(h, w)
|
50 |
+
|
51 |
+
pad_h = max_side - h
|
52 |
+
pad_w = max_side - w
|
53 |
+
|
54 |
+
padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
|
55 |
+
|
56 |
+
video_padded = F.pad(video_tensor, pad=padding, mode='constant', value=0)
|
57 |
+
|
58 |
+
return video_padded
|
59 |
+
|
60 |
+
|
61 |
+
class VGGSound(Dataset):
|
62 |
+
|
63 |
+
def __init__(
|
64 |
+
self,
|
65 |
+
sample_rate: int = 44_100,
|
66 |
+
duration_sec: float = 9.0,
|
67 |
+
audio_samples: int = None,
|
68 |
+
normalize_audio: bool = False,
|
69 |
+
):
|
70 |
+
if audio_samples is None:
|
71 |
+
self.audio_samples = int(sample_rate * duration_sec)
|
72 |
+
else:
|
73 |
+
self.audio_samples = audio_samples
|
74 |
+
effective_duration = audio_samples / sample_rate
|
75 |
+
# make sure the duration is close enough, within 15ms
|
76 |
+
assert abs(effective_duration - duration_sec) < 0.015, \
|
77 |
+
f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
|
78 |
+
|
79 |
+
self.sample_rate = sample_rate
|
80 |
+
self.duration_sec = duration_sec
|
81 |
+
|
82 |
+
self.expected_audio_length = self.audio_samples
|
83 |
+
self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
|
84 |
+
self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
|
85 |
+
|
86 |
+
self.clip_transform = v2.Compose([
|
87 |
+
v2.Lambda(pad_to_square), # 先填充为正方形
|
88 |
+
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
89 |
+
v2.ToImage(),
|
90 |
+
v2.ToDtype(torch.float32, scale=True),
|
91 |
+
])
|
92 |
+
self.clip_processor = AutoProcessor.from_pretrained("facebook/metaclip-h14-fullcc2.5b")
|
93 |
+
self.sync_transform = v2.Compose([
|
94 |
+
v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
|
95 |
+
v2.CenterCrop(_SYNC_SIZE),
|
96 |
+
v2.ToImage(),
|
97 |
+
v2.ToDtype(torch.float32, scale=True),
|
98 |
+
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
99 |
+
])
|
100 |
+
|
101 |
+
self.resampler = {}
|
102 |
+
|
103 |
+
def sample(self, video_path,label,cot):
|
104 |
+
video_id = video_path
|
105 |
+
|
106 |
+
reader = StreamingMediaDecoder(video_path)
|
107 |
+
reader.add_basic_video_stream(
|
108 |
+
frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
|
109 |
+
frame_rate=_CLIP_FPS,
|
110 |
+
format='rgb24',
|
111 |
+
)
|
112 |
+
reader.add_basic_video_stream(
|
113 |
+
frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
|
114 |
+
frame_rate=_SYNC_FPS,
|
115 |
+
format='rgb24',
|
116 |
+
)
|
117 |
+
|
118 |
+
reader.fill_buffer()
|
119 |
+
data_chunk = reader.pop_chunks()
|
120 |
+
|
121 |
+
clip_chunk = data_chunk[0]
|
122 |
+
sync_chunk = data_chunk[1]
|
123 |
+
|
124 |
+
if sync_chunk is None:
|
125 |
+
raise RuntimeError(f'Sync video returned None {video_id}')
|
126 |
+
|
127 |
+
clip_chunk = clip_chunk[:self.clip_expected_length]
|
128 |
+
# import ipdb
|
129 |
+
# ipdb.set_trace()
|
130 |
+
if clip_chunk.shape[0] != self.clip_expected_length:
|
131 |
+
current_length = clip_chunk.shape[0]
|
132 |
+
padding_needed = self.clip_expected_length - current_length
|
133 |
+
|
134 |
+
# Check that padding needed is no more than 2
|
135 |
+
assert padding_needed < 4, f'Padding no more than 2 frames allowed, but {padding_needed} needed'
|
136 |
+
|
137 |
+
# If assertion passes, proceed with padding
|
138 |
+
if padding_needed > 0:
|
139 |
+
last_frame = clip_chunk[-1]
|
140 |
+
log.info(last_frame.shape)
|
141 |
+
# Repeat the last frame to reach the expected length
|
142 |
+
padding = last_frame.repeat(padding_needed, 1, 1, 1)
|
143 |
+
clip_chunk = torch.cat((clip_chunk, padding), dim=0)
|
144 |
+
# raise RuntimeError(f'CLIP video wrong length {video_id}, '
|
145 |
+
# f'expected {self.clip_expected_length}, '
|
146 |
+
# f'got {clip_chunk.shape[0]}')
|
147 |
+
|
148 |
+
# save_image(clip_chunk[0] / 255.0,'ori.png')
|
149 |
+
clip_chunk = pad_to_square(clip_chunk)
|
150 |
+
|
151 |
+
clip_chunk = self.clip_processor(images=clip_chunk, return_tensors="pt")["pixel_values"]
|
152 |
+
|
153 |
+
sync_chunk = sync_chunk[:self.sync_expected_length]
|
154 |
+
if sync_chunk.shape[0] != self.sync_expected_length:
|
155 |
+
# padding using the last frame, but no more than 2
|
156 |
+
current_length = sync_chunk.shape[0]
|
157 |
+
last_frame = sync_chunk[-1]
|
158 |
+
|
159 |
+
padding = last_frame.repeat(self.sync_expected_length - current_length, 1, 1, 1)
|
160 |
+
assert self.sync_expected_length - current_length < 12, f'sync can pad no more than 2 while {self.sync_expected_length - current_length}'
|
161 |
+
sync_chunk = torch.cat((sync_chunk, padding), dim=0)
|
162 |
+
# raise RuntimeError(f'Sync video wrong length {video_id}, '
|
163 |
+
# f'expected {self.sync_expected_length}, '
|
164 |
+
# f'got {sync_chunk.shape[0]}')
|
165 |
+
|
166 |
+
sync_chunk = self.sync_transform(sync_chunk)
|
167 |
+
# assert audio_chunk.shape[1] == self.expected_audio_length and clip_chunk.shape[0] == self.clip_expected_length \
|
168 |
+
# and sync_chunk.shape[0] == self.sync_expected_length, 'error processed data shape'
|
169 |
+
data = {
|
170 |
+
'id': video_id,
|
171 |
+
'caption': label,
|
172 |
+
'caption_cot': cot,
|
173 |
+
# 'audio': audio_chunk,
|
174 |
+
'clip_video': clip_chunk,
|
175 |
+
'sync_video': sync_chunk,
|
176 |
+
}
|
177 |
+
|
178 |
+
return data
|
179 |
+
|
180 |
+
# 检查设备
|
181 |
+
if torch.cuda.is_available():
|
182 |
+
device = 'cuda'
|
183 |
+
extra_device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'
|
184 |
+
else:
|
185 |
+
device = 'cpu'
|
186 |
+
extra_device = 'cpu'
|
187 |
+
|
188 |
+
print(f"load in device {device}")
|
189 |
+
|
190 |
+
vae_ckpt = hf_hub_download(repo_id="FunAudioLLM/ThinkSound", filename="vae.ckpt",repo_type="model")
|
191 |
+
synchformer_ckpt = hf_hub_download(repo_id="FunAudioLLM/ThinkSound", filename="synchformer_state_dict.pth",repo_type="model")
|
192 |
+
|
193 |
+
feature_extractor = FeaturesUtils(
|
194 |
+
vae_ckpt=None,
|
195 |
+
vae_config='ThinkSound/configs/model_configs/stable_audio_2_0_vae.json',
|
196 |
+
enable_conditions=True,
|
197 |
+
synchformer_ckpt=synchformer_ckpt
|
198 |
+
).eval().to(extra_device)
|
199 |
+
|
200 |
+
args = get_all_args()
|
201 |
+
|
202 |
+
seed = 10086
|
203 |
+
|
204 |
+
seed_everything(seed, workers=True)
|
205 |
+
|
206 |
+
|
207 |
+
#Get JSON config from args.model_config
|
208 |
+
with open("ThinkSound/configs/model_configs/thinksound.json") as f:
|
209 |
+
model_config = json.load(f)
|
210 |
+
|
211 |
+
diffusion_model = create_model_from_config(model_config)
|
212 |
+
ckpt_path = hf_hub_download(repo_id="FunAudioLLM/ThinkSound", filename="thinksound_light.ckpt",repo_type="model")
|
213 |
+
diffusion_model.load_state_dict(torch.load(ckpt_path))
|
214 |
+
diffusion_model.to(device)
|
215 |
+
|
216 |
+
## speed by torch.compile
|
217 |
+
if args.compile:
|
218 |
+
diffusion_model = torch.compile(diffusion_model)
|
219 |
+
|
220 |
+
|
221 |
+
load_vae_state = load_ckpt_state_dict(vae_ckpt, prefix='autoencoder.')
|
222 |
+
# new_state_dict = {k.replace("autoencoder.", ""): v for k, v in load_vae_state.items() if k.startswith("autoencoder.")}
|
223 |
+
diffusion_model.pretransform.load_state_dict(load_vae_state)
|
224 |
+
|
225 |
+
def get_video_duration(video_path):
|
226 |
+
video = VideoFileClip(video_path)
|
227 |
+
return video.duration
|
228 |
+
|
229 |
+
@spaces.GPU(duration=60)
|
230 |
+
@torch.inference_mode()
|
231 |
+
@torch.no_grad()
|
232 |
+
def synthesize_video_with_audio(video_file, caption, cot):
|
233 |
+
yield "⏳ Extracting Features…", None
|
234 |
+
video_path = video_file
|
235 |
+
if caption is None:
|
236 |
+
caption = ''
|
237 |
+
if cot is None:
|
238 |
+
cot = caption
|
239 |
+
timer = Timer(duration="00:15:00:00")
|
240 |
+
#get video duration
|
241 |
+
duration_sec = get_video_duration(video_path)
|
242 |
+
print(duration_sec)
|
243 |
+
preprocesser = VGGSound(duration_sec=duration_sec)
|
244 |
+
data = preprocesser.sample(video_path, caption, cot)
|
245 |
+
|
246 |
+
|
247 |
+
preprocessed_data = {}
|
248 |
+
metaclip_global_text_features, metaclip_text_features = feature_extractor.encode_text(data['caption'])
|
249 |
+
preprocessed_data['metaclip_global_text_features'] = metaclip_global_text_features.detach().cpu().squeeze(0)
|
250 |
+
preprocessed_data['metaclip_text_features'] = metaclip_text_features.detach().cpu().squeeze(0)
|
251 |
+
|
252 |
+
t5_features = feature_extractor.encode_t5_text(data['caption_cot'])
|
253 |
+
preprocessed_data['t5_features'] = t5_features.detach().cpu().squeeze(0)
|
254 |
+
|
255 |
+
clip_features = feature_extractor.encode_video_with_clip(data['clip_video'].unsqueeze(0).to(extra_device))
|
256 |
+
preprocessed_data['metaclip_features'] = clip_features.detach().cpu().squeeze(0)
|
257 |
+
|
258 |
+
sync_features = feature_extractor.encode_video_with_sync(data['sync_video'].unsqueeze(0).to(extra_device))
|
259 |
+
preprocessed_data['sync_features'] = sync_features.detach().cpu().squeeze(0)
|
260 |
+
preprocessed_data['video_exist'] = torch.tensor(True)
|
261 |
+
print("clip_shape", preprocessed_data['metaclip_features'].shape)
|
262 |
+
print("sync_shape", preprocessed_data['sync_features'].shape)
|
263 |
+
sync_seq_len = preprocessed_data['sync_features'].shape[0]
|
264 |
+
clip_seq_len = preprocessed_data['metaclip_features'].shape[0]
|
265 |
+
latent_seq_len = (int)(194/9*duration_sec)
|
266 |
+
diffusion_model.model.model.update_seq_lengths(latent_seq_len, clip_seq_len, sync_seq_len)
|
267 |
+
|
268 |
+
metadata = [preprocessed_data]
|
269 |
+
|
270 |
+
batch_size = 1
|
271 |
+
length = latent_seq_len
|
272 |
+
with torch.amp.autocast(device):
|
273 |
+
conditioning = diffusion_model.conditioner(metadata, device)
|
274 |
+
|
275 |
+
video_exist = torch.stack([item['video_exist'] for item in metadata],dim=0)
|
276 |
+
conditioning['metaclip_features'][~video_exist] = diffusion_model.model.model.empty_clip_feat
|
277 |
+
conditioning['sync_features'][~video_exist] = diffusion_model.model.model.empty_sync_feat
|
278 |
+
|
279 |
+
yield "⏳ Inferring…", None
|
280 |
+
|
281 |
+
cond_inputs = diffusion_model.get_conditioning_inputs(conditioning)
|
282 |
+
noise = torch.randn([batch_size, diffusion_model.io_channels, length]).to(device)
|
283 |
+
with torch.amp.autocast(device):
|
284 |
+
if diffusion_model.diffusion_objective == "v":
|
285 |
+
fakes = sample(diffusion_model.model, noise, 24, 0, **cond_inputs, cfg_scale=5, batch_cfg=True)
|
286 |
+
elif diffusion_model.diffusion_objective == "rectified_flow":
|
287 |
+
import time
|
288 |
+
start_time = time.time()
|
289 |
+
fakes = sample_discrete_euler(diffusion_model.model, noise, 24, **cond_inputs, cfg_scale=5, batch_cfg=True)
|
290 |
+
end_time = time.time()
|
291 |
+
execution_time = end_time - start_time
|
292 |
+
print(f"execution_time: {execution_time:.2f} 秒")
|
293 |
+
|
294 |
+
if diffusion_model.pretransform is not None:
|
295 |
+
fakes = diffusion_model.pretransform.decode(fakes)
|
296 |
+
|
297 |
+
audios = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
|
298 |
+
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_audio:
|
299 |
+
torchaudio.save(tmp_audio.name, audios[0], 44100)
|
300 |
+
audio_path = tmp_audio.name
|
301 |
+
|
302 |
+
with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_video:
|
303 |
+
output_video_path = tmp_video.name
|
304 |
+
|
305 |
+
cmd = [
|
306 |
+
'ffmpeg', '-y', '-i', video_file, '-i', audio_path,
|
307 |
+
'-c:v', 'copy', '-map', '0:v:0', '-map', '1:a:0',
|
308 |
+
'-shortest', output_video_path
|
309 |
+
]
|
310 |
+
subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
311 |
+
|
312 |
+
# return output_video_path
|
313 |
+
yield "✅ Generation completed!", output_video_path
|
314 |
+
|
315 |
+
demo = gr.Interface(
|
316 |
+
fn=synthesize_video_with_audio,
|
317 |
+
inputs=[
|
318 |
+
gr.Video(label="Upload Video"),
|
319 |
+
gr.Textbox(label="Caption (optional)", placeholder="can be empty",),
|
320 |
+
gr.Textbox(label="CoT Description (optional)", lines=6, placeholder="can be empty",),
|
321 |
+
],
|
322 |
+
outputs=[
|
323 |
+
gr.Text(label="Status"),
|
324 |
+
gr.Video(label="Result"),
|
325 |
+
],
|
326 |
+
title="ThinkSound Demo",
|
327 |
+
description="Upload a video, caption, or CoT to generate audio. For an enhanced experience, we automatically merge the generated audio with your original silent video. (Note: Flexible audio generation lengths are supported.:)",
|
328 |
+
examples=[
|
329 |
+
["examples/3_mute.mp4", "Gentle Sucking Sounds From the Pacifier", "Begin by creating a soft, steady background of light pacifier suckling. Add subtle, breathy rhythms to mimic a newborn's gentle mouth movements. Keep the sound smooth, natural, and soothing."],
|
330 |
+
["examples/2_mute.mp4", "Printer Printing", "Generate a continuous printer printing sound with periodic beeps and paper movement, plus a cat pawing at the machine. Add subtle ambient room noise for authenticity, keeping the focus on printing, beeps, and the cat's interaction."],
|
331 |
+
["examples/5_mute.mp4", "Lighting Firecrackers", "Generate the sound of firecrackers lighting and exploding repeatedly on the ground, followed by fireworks bursting in the sky. Incorporate occasional subtle echoes to mimic an outdoor night ambiance, with no human voices present."],
|
332 |
+
["examples/4_mute.mp4", "Plastic Debris Handling", "Begin with the sound of hands scooping up loose plastic debris, followed by the subtle cascading noise as the pieces fall and scatter back down. Include soft crinkling and rustling to emphasize the texture of the plastic. Add ambient factory background noise with distant machinery to create an industrial atmosphere."]
|
333 |
+
],
|
334 |
+
cache_examples=True
|
335 |
+
)
|
336 |
+
|
337 |
+
if __name__ == "__main__":
|
338 |
+
demo.queue().launch(share=True)
|
339 |
+
|
340 |
+
demo.launch(share=True)
|
341 |
+
|
342 |
|
|
|
|
|
343 |
|
|
|
|
cot_vgg_demo_caption.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
demo.npz
|
data_utils/ext/synchformer/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2024 Vladimir Iashin
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
data_utils/ext/synchformer/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from data_utils.ext.synchformer.synchformer import Synchformer
|
data_utils/ext/synchformer/divided_224_16x4.yaml
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
TRAIN:
|
2 |
+
ENABLE: True
|
3 |
+
DATASET: Ssv2
|
4 |
+
BATCH_SIZE: 32
|
5 |
+
EVAL_PERIOD: 5
|
6 |
+
CHECKPOINT_PERIOD: 5
|
7 |
+
AUTO_RESUME: True
|
8 |
+
CHECKPOINT_EPOCH_RESET: True
|
9 |
+
CHECKPOINT_FILE_PATH: /checkpoint/fmetze/neurips_sota/40944587/checkpoints/checkpoint_epoch_00035.pyth
|
10 |
+
DATA:
|
11 |
+
NUM_FRAMES: 16
|
12 |
+
SAMPLING_RATE: 4
|
13 |
+
TRAIN_JITTER_SCALES: [256, 320]
|
14 |
+
TRAIN_CROP_SIZE: 224
|
15 |
+
TEST_CROP_SIZE: 224
|
16 |
+
INPUT_CHANNEL_NUM: [3]
|
17 |
+
MEAN: [0.5, 0.5, 0.5]
|
18 |
+
STD: [0.5, 0.5, 0.5]
|
19 |
+
PATH_TO_DATA_DIR: /private/home/mandelapatrick/slowfast/data/ssv2
|
20 |
+
PATH_PREFIX: /datasets01/SomethingV2/092720/20bn-something-something-v2-frames
|
21 |
+
INV_UNIFORM_SAMPLE: True
|
22 |
+
RANDOM_FLIP: False
|
23 |
+
REVERSE_INPUT_CHANNEL: True
|
24 |
+
USE_RAND_AUGMENT: True
|
25 |
+
RE_PROB: 0.0
|
26 |
+
USE_REPEATED_AUG: False
|
27 |
+
USE_RANDOM_RESIZE_CROPS: False
|
28 |
+
COLORJITTER: False
|
29 |
+
GRAYSCALE: False
|
30 |
+
GAUSSIAN: False
|
31 |
+
SOLVER:
|
32 |
+
BASE_LR: 1e-4
|
33 |
+
LR_POLICY: steps_with_relative_lrs
|
34 |
+
LRS: [1, 0.1, 0.01]
|
35 |
+
STEPS: [0, 20, 30]
|
36 |
+
MAX_EPOCH: 35
|
37 |
+
MOMENTUM: 0.9
|
38 |
+
WEIGHT_DECAY: 5e-2
|
39 |
+
WARMUP_EPOCHS: 0.0
|
40 |
+
OPTIMIZING_METHOD: adamw
|
41 |
+
USE_MIXED_PRECISION: True
|
42 |
+
SMOOTHING: 0.2
|
43 |
+
SLOWFAST:
|
44 |
+
ALPHA: 8
|
45 |
+
VIT:
|
46 |
+
PATCH_SIZE: 16
|
47 |
+
PATCH_SIZE_TEMP: 2
|
48 |
+
CHANNELS: 3
|
49 |
+
EMBED_DIM: 768
|
50 |
+
DEPTH: 12
|
51 |
+
NUM_HEADS: 12
|
52 |
+
MLP_RATIO: 4
|
53 |
+
QKV_BIAS: True
|
54 |
+
VIDEO_INPUT: True
|
55 |
+
TEMPORAL_RESOLUTION: 8
|
56 |
+
USE_MLP: True
|
57 |
+
DROP: 0.0
|
58 |
+
POS_DROPOUT: 0.0
|
59 |
+
DROP_PATH: 0.2
|
60 |
+
IM_PRETRAINED: True
|
61 |
+
HEAD_DROPOUT: 0.0
|
62 |
+
HEAD_ACT: tanh
|
63 |
+
PRETRAINED_WEIGHTS: vit_1k
|
64 |
+
ATTN_LAYER: divided
|
65 |
+
MODEL:
|
66 |
+
NUM_CLASSES: 174
|
67 |
+
ARCH: slow
|
68 |
+
MODEL_NAME: VisionTransformer
|
69 |
+
LOSS_FUNC: cross_entropy
|
70 |
+
TEST:
|
71 |
+
ENABLE: True
|
72 |
+
DATASET: Ssv2
|
73 |
+
BATCH_SIZE: 64
|
74 |
+
NUM_ENSEMBLE_VIEWS: 1
|
75 |
+
NUM_SPATIAL_CROPS: 3
|
76 |
+
DATA_LOADER:
|
77 |
+
NUM_WORKERS: 4
|
78 |
+
PIN_MEMORY: True
|
79 |
+
NUM_GPUS: 8
|
80 |
+
NUM_SHARDS: 4
|
81 |
+
RNG_SEED: 0
|
82 |
+
OUTPUT_DIR: .
|
83 |
+
TENSORBOARD:
|
84 |
+
ENABLE: True
|
data_utils/ext/synchformer/motionformer.py
ADDED
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import einops
|
5 |
+
import torch
|
6 |
+
from omegaconf import OmegaConf
|
7 |
+
from timm.layers import trunc_normal_
|
8 |
+
from torch import nn
|
9 |
+
|
10 |
+
from data_utils.ext.synchformer.utils import check_if_file_exists_else_download
|
11 |
+
from data_utils.ext.synchformer.video_model_builder import VisionTransformer
|
12 |
+
|
13 |
+
FILE2URL = {
|
14 |
+
# cfg
|
15 |
+
'motionformer_224_16x4.yaml':
|
16 |
+
'https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/motionformer_224_16x4.yaml',
|
17 |
+
'joint_224_16x4.yaml':
|
18 |
+
'https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/joint_224_16x4.yaml',
|
19 |
+
'divided_224_16x4.yaml':
|
20 |
+
'https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/divided_224_16x4.yaml',
|
21 |
+
# ckpt
|
22 |
+
'ssv2_motionformer_224_16x4.pyth':
|
23 |
+
'https://dl.fbaipublicfiles.com/motionformer/ssv2_motionformer_224_16x4.pyth',
|
24 |
+
'ssv2_joint_224_16x4.pyth':
|
25 |
+
'https://dl.fbaipublicfiles.com/motionformer/ssv2_joint_224_16x4.pyth',
|
26 |
+
'ssv2_divided_224_16x4.pyth':
|
27 |
+
'https://dl.fbaipublicfiles.com/motionformer/ssv2_divided_224_16x4.pyth',
|
28 |
+
}
|
29 |
+
|
30 |
+
|
31 |
+
class MotionFormer(VisionTransformer):
|
32 |
+
''' This class serves three puposes:
|
33 |
+
1. Renames the class to MotionFormer.
|
34 |
+
2. Downloads the cfg from the original repo and patches it if needed.
|
35 |
+
3. Takes care of feature extraction by redefining .forward()
|
36 |
+
- if `extract_features=True` and `factorize_space_time=False`,
|
37 |
+
the output is of shape (B, T, D) where T = 1 + (224 // 16) * (224 // 16) * 8
|
38 |
+
- if `extract_features=True` and `factorize_space_time=True`, the output is of shape (B*S, D)
|
39 |
+
and spatial and temporal transformer encoder layers are used.
|
40 |
+
- if `extract_features=True` and `factorize_space_time=True` as well as `add_global_repr=True`
|
41 |
+
the output is of shape (B, D) and spatial and temporal transformer encoder layers
|
42 |
+
are used as well as the global representation is extracted from segments (extra pos emb
|
43 |
+
is added).
|
44 |
+
'''
|
45 |
+
|
46 |
+
def __init__(
|
47 |
+
self,
|
48 |
+
extract_features: bool = False,
|
49 |
+
ckpt_path: str = None,
|
50 |
+
factorize_space_time: bool = None,
|
51 |
+
agg_space_module: str = None,
|
52 |
+
agg_time_module: str = None,
|
53 |
+
add_global_repr: bool = True,
|
54 |
+
agg_segments_module: str = None,
|
55 |
+
max_segments: int = None,
|
56 |
+
):
|
57 |
+
self.extract_features = extract_features
|
58 |
+
self.ckpt_path = ckpt_path
|
59 |
+
self.factorize_space_time = factorize_space_time
|
60 |
+
|
61 |
+
if self.ckpt_path is not None:
|
62 |
+
check_if_file_exists_else_download(self.ckpt_path, FILE2URL)
|
63 |
+
ckpt = torch.load(self.ckpt_path, map_location='cpu')
|
64 |
+
mformer_ckpt2cfg = {
|
65 |
+
'ssv2_motionformer_224_16x4.pyth': 'motionformer_224_16x4.yaml',
|
66 |
+
'ssv2_joint_224_16x4.pyth': 'joint_224_16x4.yaml',
|
67 |
+
'ssv2_divided_224_16x4.pyth': 'divided_224_16x4.yaml',
|
68 |
+
}
|
69 |
+
# init from motionformer ckpt or from our Stage I ckpt
|
70 |
+
# depending on whether the feat extractor was pre-trained on AVCLIPMoCo or not, we need to
|
71 |
+
# load the state dict differently
|
72 |
+
was_pt_on_avclip = self.ckpt_path.endswith(
|
73 |
+
'.pt') # checks if it is a stage I ckpt (FIXME: a bit generic)
|
74 |
+
if self.ckpt_path.endswith(tuple(mformer_ckpt2cfg.keys())):
|
75 |
+
cfg_fname = mformer_ckpt2cfg[Path(self.ckpt_path).name]
|
76 |
+
elif was_pt_on_avclip:
|
77 |
+
# TODO: this is a hack, we should be able to get the cfg from the ckpt (earlier ckpt didn't have it)
|
78 |
+
s1_cfg = ckpt.get('args', None) # Stage I cfg
|
79 |
+
if s1_cfg is not None:
|
80 |
+
s1_vfeat_extractor_ckpt_path = s1_cfg.model.params.vfeat_extractor.params.ckpt_path
|
81 |
+
# if the stage I ckpt was initialized from a motionformer ckpt or train from scratch
|
82 |
+
if s1_vfeat_extractor_ckpt_path is not None:
|
83 |
+
cfg_fname = mformer_ckpt2cfg[Path(s1_vfeat_extractor_ckpt_path).name]
|
84 |
+
else:
|
85 |
+
cfg_fname = 'divided_224_16x4.yaml'
|
86 |
+
else:
|
87 |
+
cfg_fname = 'divided_224_16x4.yaml'
|
88 |
+
else:
|
89 |
+
raise ValueError(f'ckpt_path {self.ckpt_path} is not supported.')
|
90 |
+
else:
|
91 |
+
was_pt_on_avclip = False
|
92 |
+
cfg_fname = 'divided_224_16x4.yaml'
|
93 |
+
# logging.info(f'No ckpt_path provided, using {cfg_fname} config.')
|
94 |
+
|
95 |
+
if cfg_fname in ['motionformer_224_16x4.yaml', 'divided_224_16x4.yaml']:
|
96 |
+
pos_emb_type = 'separate'
|
97 |
+
elif cfg_fname == 'joint_224_16x4.yaml':
|
98 |
+
pos_emb_type = 'joint'
|
99 |
+
|
100 |
+
self.mformer_cfg_path = Path(__file__).absolute().parent / cfg_fname
|
101 |
+
|
102 |
+
check_if_file_exists_else_download(self.mformer_cfg_path, FILE2URL)
|
103 |
+
mformer_cfg = OmegaConf.load(self.mformer_cfg_path)
|
104 |
+
logging.info(f'Loading MotionFormer config from {self.mformer_cfg_path.absolute()}')
|
105 |
+
|
106 |
+
# patch the cfg (from the default cfg defined in the repo `Motionformer/slowfast/config/defaults.py`)
|
107 |
+
mformer_cfg.VIT.ATTN_DROPOUT = 0.0
|
108 |
+
mformer_cfg.VIT.POS_EMBED = pos_emb_type
|
109 |
+
mformer_cfg.VIT.USE_ORIGINAL_TRAJ_ATTN_CODE = True
|
110 |
+
mformer_cfg.VIT.APPROX_ATTN_TYPE = 'none' # guessing
|
111 |
+
mformer_cfg.VIT.APPROX_ATTN_DIM = 64 # from ckpt['cfg']
|
112 |
+
|
113 |
+
# finally init VisionTransformer with the cfg
|
114 |
+
super().__init__(mformer_cfg)
|
115 |
+
|
116 |
+
# load the ckpt now if ckpt is provided and not from AVCLIPMoCo-pretrained ckpt
|
117 |
+
if (self.ckpt_path is not None) and (not was_pt_on_avclip):
|
118 |
+
_ckpt_load_status = self.load_state_dict(ckpt['model_state'], strict=False)
|
119 |
+
if len(_ckpt_load_status.missing_keys) > 0 or len(
|
120 |
+
_ckpt_load_status.unexpected_keys) > 0:
|
121 |
+
logging.warning(f'Loading exact vfeat_extractor ckpt from {self.ckpt_path} failed.' \
|
122 |
+
f'Missing keys: {_ckpt_load_status.missing_keys}, ' \
|
123 |
+
f'Unexpected keys: {_ckpt_load_status.unexpected_keys}')
|
124 |
+
else:
|
125 |
+
logging.info(f'Loading vfeat_extractor ckpt from {self.ckpt_path} succeeded.')
|
126 |
+
|
127 |
+
if self.extract_features:
|
128 |
+
assert isinstance(self.norm,
|
129 |
+
nn.LayerNorm), 'early x[:, 1:, :] may not be safe for per-tr weights'
|
130 |
+
# pre-logits are Sequential(nn.Linear(emb, emd), act) and `act` is tanh but see the logger
|
131 |
+
self.pre_logits = nn.Identity()
|
132 |
+
# we don't need the classification head (saving memory)
|
133 |
+
self.head = nn.Identity()
|
134 |
+
self.head_drop = nn.Identity()
|
135 |
+
# avoiding code duplication (used only if agg_*_module is TransformerEncoderLayer)
|
136 |
+
transf_enc_layer_kwargs = dict(
|
137 |
+
d_model=self.embed_dim,
|
138 |
+
nhead=self.num_heads,
|
139 |
+
activation=nn.GELU(),
|
140 |
+
batch_first=True,
|
141 |
+
dim_feedforward=self.mlp_ratio * self.embed_dim,
|
142 |
+
dropout=self.drop_rate,
|
143 |
+
layer_norm_eps=1e-6,
|
144 |
+
norm_first=True,
|
145 |
+
)
|
146 |
+
# define adapters if needed
|
147 |
+
if self.factorize_space_time:
|
148 |
+
if agg_space_module == 'TransformerEncoderLayer':
|
149 |
+
self.spatial_attn_agg = SpatialTransformerEncoderLayer(
|
150 |
+
**transf_enc_layer_kwargs)
|
151 |
+
elif agg_space_module == 'AveragePooling':
|
152 |
+
self.spatial_attn_agg = AveragePooling(avg_pattern='BS D t h w -> BS D t',
|
153 |
+
then_permute_pattern='BS D t -> BS t D')
|
154 |
+
if agg_time_module == 'TransformerEncoderLayer':
|
155 |
+
self.temp_attn_agg = TemporalTransformerEncoderLayer(**transf_enc_layer_kwargs)
|
156 |
+
elif agg_time_module == 'AveragePooling':
|
157 |
+
self.temp_attn_agg = AveragePooling(avg_pattern='BS t D -> BS D')
|
158 |
+
elif 'Identity' in agg_time_module:
|
159 |
+
self.temp_attn_agg = nn.Identity()
|
160 |
+
# define a global aggregation layer (aggregarate over segments)
|
161 |
+
self.add_global_repr = add_global_repr
|
162 |
+
if add_global_repr:
|
163 |
+
if agg_segments_module == 'TransformerEncoderLayer':
|
164 |
+
# we can reuse the same layer as for temporal factorization (B, dim_to_agg, D) -> (B, D)
|
165 |
+
# we need to add pos emb (PE) because previously we added the same PE for each segment
|
166 |
+
pos_max_len = max_segments if max_segments is not None else 16 # 16 = 10sec//0.64sec + 1
|
167 |
+
self.global_attn_agg = TemporalTransformerEncoderLayer(
|
168 |
+
add_pos_emb=True,
|
169 |
+
pos_emb_drop=mformer_cfg.VIT.POS_DROPOUT,
|
170 |
+
pos_max_len=pos_max_len,
|
171 |
+
**transf_enc_layer_kwargs)
|
172 |
+
elif agg_segments_module == 'AveragePooling':
|
173 |
+
self.global_attn_agg = AveragePooling(avg_pattern='B S D -> B D')
|
174 |
+
|
175 |
+
if was_pt_on_avclip:
|
176 |
+
# we need to filter out the state_dict of the AVCLIP model (has both A and V extractors)
|
177 |
+
# and keep only the state_dict of the feat extractor
|
178 |
+
ckpt_weights = dict()
|
179 |
+
for k, v in ckpt['state_dict'].items():
|
180 |
+
if k.startswith(('module.v_encoder.', 'v_encoder.')):
|
181 |
+
k = k.replace('module.', '').replace('v_encoder.', '')
|
182 |
+
ckpt_weights[k] = v
|
183 |
+
_load_status = self.load_state_dict(ckpt_weights, strict=False)
|
184 |
+
if len(_load_status.missing_keys) > 0 or len(_load_status.unexpected_keys) > 0:
|
185 |
+
logging.warning(f'Loading exact vfeat_extractor ckpt from {self.ckpt_path} failed. \n' \
|
186 |
+
f'Missing keys ({len(_load_status.missing_keys)}): ' \
|
187 |
+
f'{_load_status.missing_keys}, \n' \
|
188 |
+
f'Unexpected keys ({len(_load_status.unexpected_keys)}): ' \
|
189 |
+
f'{_load_status.unexpected_keys} \n' \
|
190 |
+
f'temp_attn_agg are expected to be missing if ckpt was pt contrastively.')
|
191 |
+
else:
|
192 |
+
logging.info(f'Loading vfeat_extractor ckpt from {self.ckpt_path} succeeded.')
|
193 |
+
|
194 |
+
# patch_embed is not used in MotionFormer, only patch_embed_3d, because cfg.VIT.PATCH_SIZE_TEMP > 1
|
195 |
+
# but it used to calculate the number of patches, so we need to set keep it
|
196 |
+
self.patch_embed.requires_grad_(False)
|
197 |
+
|
198 |
+
def forward(self, x):
|
199 |
+
'''
|
200 |
+
x is of shape (B, S, C, T, H, W) where S is the number of segments.
|
201 |
+
'''
|
202 |
+
# Batch, Segments, Channels, T=frames, Height, Width
|
203 |
+
B, S, C, T, H, W = x.shape
|
204 |
+
# Motionformer expects a tensor of shape (1, B, C, T, H, W).
|
205 |
+
# The first dimension (1) is a dummy dimension to make the input tensor and won't be used:
|
206 |
+
# see `video_model_builder.video_input`.
|
207 |
+
# x = x.unsqueeze(0) # (1, B, S, C, T, H, W)
|
208 |
+
|
209 |
+
orig_shape = (B, S, C, T, H, W)
|
210 |
+
x = x.view(B * S, C, T, H, W) # flatten batch and segments
|
211 |
+
x = self.forward_segments(x, orig_shape=orig_shape)
|
212 |
+
# unpack the segments (using rest dimensions to support different shapes e.g. (BS, D) or (BS, t, D))
|
213 |
+
x = x.view(B, S, *x.shape[1:])
|
214 |
+
# x is now of shape (B*S, D) or (B*S, t, D) if `self.temp_attn_agg` is `Identity`
|
215 |
+
|
216 |
+
return x # x is (B, S, ...)
|
217 |
+
|
218 |
+
def forward_segments(self, x, orig_shape: tuple) -> torch.Tensor:
|
219 |
+
'''x is of shape (1, BS, C, T, H, W) where S is the number of segments.'''
|
220 |
+
x, x_mask = self.forward_features(x)
|
221 |
+
|
222 |
+
assert self.extract_features
|
223 |
+
|
224 |
+
# (BS, T, D) where T = 1 + (224 // 16) * (224 // 16) * 8
|
225 |
+
x = x[:,
|
226 |
+
1:, :] # without the CLS token for efficiency (should be safe for LayerNorm and FC)
|
227 |
+
x = self.norm(x)
|
228 |
+
x = self.pre_logits(x)
|
229 |
+
if self.factorize_space_time:
|
230 |
+
x = self.restore_spatio_temp_dims(x, orig_shape) # (B*S, D, t, h, w) <- (B*S, t*h*w, D)
|
231 |
+
|
232 |
+
x = self.spatial_attn_agg(x, x_mask) # (B*S, t, D)
|
233 |
+
x = self.temp_attn_agg(
|
234 |
+
x) # (B*S, D) or (BS, t, D) if `self.temp_attn_agg` is `Identity`
|
235 |
+
|
236 |
+
return x
|
237 |
+
|
238 |
+
def restore_spatio_temp_dims(self, feats: torch.Tensor, orig_shape: tuple) -> torch.Tensor:
|
239 |
+
'''
|
240 |
+
feats are of shape (B*S, T, D) where T = 1 + (224 // 16) * (224 // 16) * 8
|
241 |
+
Our goal is to make them of shape (B*S, t, h, w, D) where h, w are the spatial dimensions.
|
242 |
+
From `self.patch_embed_3d`, it follows that we could reshape feats with:
|
243 |
+
`feats.transpose(1, 2).view(B*S, D, t, h, w)`
|
244 |
+
'''
|
245 |
+
B, S, C, T, H, W = orig_shape
|
246 |
+
D = self.embed_dim
|
247 |
+
|
248 |
+
# num patches in each dimension
|
249 |
+
t = T // self.patch_embed_3d.z_block_size
|
250 |
+
h = self.patch_embed_3d.height
|
251 |
+
w = self.patch_embed_3d.width
|
252 |
+
|
253 |
+
feats = feats.permute(0, 2, 1) # (B*S, D, T)
|
254 |
+
feats = feats.view(B * S, D, t, h, w) # (B*S, D, t, h, w)
|
255 |
+
|
256 |
+
return feats
|
257 |
+
|
258 |
+
|
259 |
+
class BaseEncoderLayer(nn.TransformerEncoderLayer):
|
260 |
+
'''
|
261 |
+
This is a wrapper around nn.TransformerEncoderLayer that adds a CLS token
|
262 |
+
to the sequence and outputs the CLS token's representation.
|
263 |
+
This base class parents both SpatialEncoderLayer and TemporalEncoderLayer for the RGB stream
|
264 |
+
and the FrequencyEncoderLayer and TemporalEncoderLayer for the audio stream stream.
|
265 |
+
We also, optionally, add a positional embedding to the input sequence which
|
266 |
+
allows to reuse it for global aggregation (of segments) for both streams.
|
267 |
+
'''
|
268 |
+
|
269 |
+
def __init__(self,
|
270 |
+
add_pos_emb: bool = False,
|
271 |
+
pos_emb_drop: float = None,
|
272 |
+
pos_max_len: int = None,
|
273 |
+
*args_transformer_enc,
|
274 |
+
**kwargs_transformer_enc):
|
275 |
+
super().__init__(*args_transformer_enc, **kwargs_transformer_enc)
|
276 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.self_attn.embed_dim))
|
277 |
+
trunc_normal_(self.cls_token, std=.02)
|
278 |
+
|
279 |
+
# add positional embedding
|
280 |
+
self.add_pos_emb = add_pos_emb
|
281 |
+
if add_pos_emb:
|
282 |
+
self.pos_max_len = 1 + pos_max_len # +1 (for CLS)
|
283 |
+
self.pos_emb = nn.Parameter(torch.zeros(1, self.pos_max_len, self.self_attn.embed_dim))
|
284 |
+
self.pos_drop = nn.Dropout(pos_emb_drop)
|
285 |
+
trunc_normal_(self.pos_emb, std=.02)
|
286 |
+
|
287 |
+
self.apply(self._init_weights)
|
288 |
+
|
289 |
+
def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None):
|
290 |
+
''' x is of shape (B, N, D); if provided x_mask is of shape (B, N)'''
|
291 |
+
batch_dim = x.shape[0]
|
292 |
+
|
293 |
+
# add CLS token
|
294 |
+
cls_tokens = self.cls_token.expand(batch_dim, -1, -1) # expanding to match batch dimension
|
295 |
+
x = torch.cat((cls_tokens, x), dim=-2) # (batch_dim, 1+seq_len, D)
|
296 |
+
if x_mask is not None:
|
297 |
+
cls_mask = torch.ones((batch_dim, 1), dtype=torch.bool,
|
298 |
+
device=x_mask.device) # 1=keep; 0=mask
|
299 |
+
x_mask_w_cls = torch.cat((cls_mask, x_mask), dim=-1) # (batch_dim, 1+seq_len)
|
300 |
+
B, N = x_mask_w_cls.shape
|
301 |
+
# torch expects (N, N) or (B*num_heads, N, N) mask (sadness ahead); torch masks
|
302 |
+
x_mask_w_cls = x_mask_w_cls.reshape(B, 1, 1, N)\
|
303 |
+
.expand(-1, self.self_attn.num_heads, N, -1)\
|
304 |
+
.reshape(B * self.self_attn.num_heads, N, N)
|
305 |
+
assert x_mask_w_cls.dtype == x_mask_w_cls.bool().dtype, 'x_mask_w_cls.dtype != bool'
|
306 |
+
x_mask_w_cls = ~x_mask_w_cls # invert mask (1=mask)
|
307 |
+
else:
|
308 |
+
x_mask_w_cls = None
|
309 |
+
|
310 |
+
# add positional embedding
|
311 |
+
if self.add_pos_emb:
|
312 |
+
seq_len = x.shape[
|
313 |
+
1] # (don't even think about moving it before the CLS token concatenation)
|
314 |
+
assert seq_len <= self.pos_max_len, f'Seq len ({seq_len}) > pos_max_len ({self.pos_max_len})'
|
315 |
+
x = x + self.pos_emb[:, :seq_len, :]
|
316 |
+
x = self.pos_drop(x)
|
317 |
+
|
318 |
+
# apply encoder layer (calls nn.TransformerEncoderLayer.forward);
|
319 |
+
x = super().forward(src=x, src_mask=x_mask_w_cls) # (batch_dim, 1+seq_len, D)
|
320 |
+
|
321 |
+
# CLS token is expected to hold spatial information for each frame
|
322 |
+
x = x[:, 0, :] # (batch_dim, D)
|
323 |
+
|
324 |
+
return x
|
325 |
+
|
326 |
+
def _init_weights(self, m):
|
327 |
+
if isinstance(m, nn.Linear):
|
328 |
+
trunc_normal_(m.weight, std=.02)
|
329 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
330 |
+
nn.init.constant_(m.bias, 0)
|
331 |
+
elif isinstance(m, nn.LayerNorm):
|
332 |
+
nn.init.constant_(m.bias, 0)
|
333 |
+
nn.init.constant_(m.weight, 1.0)
|
334 |
+
|
335 |
+
@torch.jit.ignore
|
336 |
+
def no_weight_decay(self):
|
337 |
+
return {'cls_token', 'pos_emb'}
|
338 |
+
|
339 |
+
|
340 |
+
class SpatialTransformerEncoderLayer(BaseEncoderLayer):
|
341 |
+
''' Aggregates spatial dimensions by applying attention individually to each frame. '''
|
342 |
+
|
343 |
+
def __init__(self, *args, **kwargs):
|
344 |
+
super().__init__(*args, **kwargs)
|
345 |
+
|
346 |
+
def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor:
|
347 |
+
''' x is of shape (B*S, D, t, h, w) where S is the number of segments.
|
348 |
+
if specified x_mask (B*S, t, h, w), 0=masked, 1=kept
|
349 |
+
Returns a tensor of shape (B*S, t, D) pooling spatial information for each frame. '''
|
350 |
+
BS, D, t, h, w = x.shape
|
351 |
+
|
352 |
+
# time as a batch dimension and flatten spatial dimensions as sequence
|
353 |
+
x = einops.rearrange(x, 'BS D t h w -> (BS t) (h w) D')
|
354 |
+
# similar to mask
|
355 |
+
if x_mask is not None:
|
356 |
+
x_mask = einops.rearrange(x_mask, 'BS t h w -> (BS t) (h w)')
|
357 |
+
|
358 |
+
# apply encoder layer (BaseEncoderLayer.forward) - it will add CLS token and output its representation
|
359 |
+
x = super().forward(x=x, x_mask=x_mask) # (B*S*t, D)
|
360 |
+
|
361 |
+
# reshape back to (B*S, t, D)
|
362 |
+
x = einops.rearrange(x, '(BS t) D -> BS t D', BS=BS, t=t)
|
363 |
+
|
364 |
+
# (B*S, t, D)
|
365 |
+
return x
|
366 |
+
|
367 |
+
|
368 |
+
class TemporalTransformerEncoderLayer(BaseEncoderLayer):
|
369 |
+
''' Aggregates temporal dimension with attention. Also used with pos emb as global aggregation
|
370 |
+
in both streams. '''
|
371 |
+
|
372 |
+
def __init__(self, *args, **kwargs):
|
373 |
+
super().__init__(*args, **kwargs)
|
374 |
+
|
375 |
+
def forward(self, x):
|
376 |
+
''' x is of shape (B*S, t, D) where S is the number of segments.
|
377 |
+
Returns a tensor of shape (B*S, D) pooling temporal information. '''
|
378 |
+
BS, t, D = x.shape
|
379 |
+
|
380 |
+
# apply encoder layer (BaseEncoderLayer.forward) - it will add CLS token and output its representation
|
381 |
+
x = super().forward(x) # (B*S, D)
|
382 |
+
|
383 |
+
return x # (B*S, D)
|
384 |
+
|
385 |
+
|
386 |
+
class AveragePooling(nn.Module):
|
387 |
+
|
388 |
+
def __init__(self, avg_pattern: str, then_permute_pattern: str = None) -> None:
|
389 |
+
''' patterns are e.g. "bs t d -> bs d" '''
|
390 |
+
super().__init__()
|
391 |
+
# TODO: need to register them as buffers (but fails because these are strings)
|
392 |
+
self.reduce_fn = 'mean'
|
393 |
+
self.avg_pattern = avg_pattern
|
394 |
+
self.then_permute_pattern = then_permute_pattern
|
395 |
+
|
396 |
+
def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor:
|
397 |
+
x = einops.reduce(x, self.avg_pattern, self.reduce_fn)
|
398 |
+
if self.then_permute_pattern is not None:
|
399 |
+
x = einops.rearrange(x, self.then_permute_pattern)
|
400 |
+
return x
|
data_utils/ext/synchformer/synchformer.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from typing import Any, Mapping
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
|
7 |
+
from data_utils.ext.synchformer.motionformer import MotionFormer
|
8 |
+
|
9 |
+
|
10 |
+
class Synchformer(nn.Module):
|
11 |
+
|
12 |
+
def __init__(self):
|
13 |
+
super().__init__()
|
14 |
+
|
15 |
+
self.vfeat_extractor = MotionFormer(extract_features=True,
|
16 |
+
factorize_space_time=True,
|
17 |
+
agg_space_module='TransformerEncoderLayer',
|
18 |
+
agg_time_module='torch.nn.Identity',
|
19 |
+
add_global_repr=False)
|
20 |
+
|
21 |
+
# self.vfeat_extractor = instantiate_from_config(vfeat_extractor)
|
22 |
+
# self.afeat_extractor = instantiate_from_config(afeat_extractor)
|
23 |
+
# # bridging the s3d latent dim (1024) into what is specified in the config
|
24 |
+
# # to match e.g. the transformer dim
|
25 |
+
# self.vproj = instantiate_from_config(vproj)
|
26 |
+
# self.aproj = instantiate_from_config(aproj)
|
27 |
+
# self.transformer = instantiate_from_config(transformer)
|
28 |
+
|
29 |
+
def forward(self, vis):
|
30 |
+
B, S, Tv, C, H, W = vis.shape
|
31 |
+
vis = vis.permute(0, 1, 3, 2, 4, 5) # (B, S, C, Tv, H, W)
|
32 |
+
# feat extractors return a tuple of segment-level and global features (ignored for sync)
|
33 |
+
# (B, S, tv, D), e.g. (B, 7, 8, 768)
|
34 |
+
vis = self.vfeat_extractor(vis)
|
35 |
+
return vis
|
36 |
+
|
37 |
+
def load_state_dict(self, sd: Mapping[str, Any], strict: bool = True):
|
38 |
+
# discard all entries except vfeat_extractor
|
39 |
+
sd = {k: v for k, v in sd.items() if k.startswith('vfeat_extractor')}
|
40 |
+
|
41 |
+
return super().load_state_dict(sd, strict)
|
42 |
+
|
43 |
+
|
44 |
+
if __name__ == "__main__":
|
45 |
+
model = Synchformer().cuda().eval()
|
46 |
+
sd = torch.load('./ext_weights/synchformer_state_dict.pth', weights_only=True)
|
47 |
+
model.load_state_dict(sd)
|
48 |
+
|
49 |
+
vid = torch.randn(2, 7, 16, 3, 224, 224).cuda()
|
50 |
+
features = model.extract_vfeats(vid, for_loop=False).detach().cpu()
|
51 |
+
print(features.shape)
|
52 |
+
|
53 |
+
# extract and save the state dict only
|
54 |
+
# sd = torch.load('./ext_weights/sync_model_audioset.pt')['model']
|
55 |
+
# torch.save(sd, './ext_weights/synchformer_state_dict.pth')
|
data_utils/ext/synchformer/utils.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from hashlib import md5
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import requests
|
5 |
+
from tqdm import tqdm
|
6 |
+
|
7 |
+
PARENT_LINK = 'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a'
|
8 |
+
FNAME2LINK = {
|
9 |
+
# S3: Synchability: AudioSet (run 2)
|
10 |
+
'24-01-22T20-34-52.pt':
|
11 |
+
f'{PARENT_LINK}/sync/sync_models/24-01-22T20-34-52/24-01-22T20-34-52.pt',
|
12 |
+
'cfg-24-01-22T20-34-52.yaml':
|
13 |
+
f'{PARENT_LINK}/sync/sync_models/24-01-22T20-34-52/cfg-24-01-22T20-34-52.yaml',
|
14 |
+
# S2: Synchformer: AudioSet (run 2)
|
15 |
+
'24-01-04T16-39-21.pt':
|
16 |
+
f'{PARENT_LINK}/sync/sync_models/24-01-04T16-39-21/24-01-04T16-39-21.pt',
|
17 |
+
'cfg-24-01-04T16-39-21.yaml':
|
18 |
+
f'{PARENT_LINK}/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml',
|
19 |
+
# S2: Synchformer: AudioSet (run 1)
|
20 |
+
'23-08-28T11-23-23.pt':
|
21 |
+
f'{PARENT_LINK}/sync/sync_models/23-08-28T11-23-23/23-08-28T11-23-23.pt',
|
22 |
+
'cfg-23-08-28T11-23-23.yaml':
|
23 |
+
f'{PARENT_LINK}/sync/sync_models/23-08-28T11-23-23/cfg-23-08-28T11-23-23.yaml',
|
24 |
+
# S2: Synchformer: LRS3 (run 2)
|
25 |
+
'23-12-23T18-33-57.pt':
|
26 |
+
f'{PARENT_LINK}/sync/sync_models/23-12-23T18-33-57/23-12-23T18-33-57.pt',
|
27 |
+
'cfg-23-12-23T18-33-57.yaml':
|
28 |
+
f'{PARENT_LINK}/sync/sync_models/23-12-23T18-33-57/cfg-23-12-23T18-33-57.yaml',
|
29 |
+
# S2: Synchformer: VGS (run 2)
|
30 |
+
'24-01-02T10-00-53.pt':
|
31 |
+
f'{PARENT_LINK}/sync/sync_models/24-01-02T10-00-53/24-01-02T10-00-53.pt',
|
32 |
+
'cfg-24-01-02T10-00-53.yaml':
|
33 |
+
f'{PARENT_LINK}/sync/sync_models/24-01-02T10-00-53/cfg-24-01-02T10-00-53.yaml',
|
34 |
+
# SparseSync: ft VGGSound-Full
|
35 |
+
'22-09-21T21-00-52.pt':
|
36 |
+
f'{PARENT_LINK}/sync/sync_models/22-09-21T21-00-52/22-09-21T21-00-52.pt',
|
37 |
+
'cfg-22-09-21T21-00-52.yaml':
|
38 |
+
f'{PARENT_LINK}/sync/sync_models/22-09-21T21-00-52/cfg-22-09-21T21-00-52.yaml',
|
39 |
+
# SparseSync: ft VGGSound-Sparse
|
40 |
+
'22-07-28T15-49-45.pt':
|
41 |
+
f'{PARENT_LINK}/sync/sync_models/22-07-28T15-49-45/22-07-28T15-49-45.pt',
|
42 |
+
'cfg-22-07-28T15-49-45.yaml':
|
43 |
+
f'{PARENT_LINK}/sync/sync_models/22-07-28T15-49-45/cfg-22-07-28T15-49-45.yaml',
|
44 |
+
# SparseSync: only pt on LRS3
|
45 |
+
'22-07-13T22-25-49.pt':
|
46 |
+
f'{PARENT_LINK}/sync/sync_models/22-07-13T22-25-49/22-07-13T22-25-49.pt',
|
47 |
+
'cfg-22-07-13T22-25-49.yaml':
|
48 |
+
f'{PARENT_LINK}/sync/sync_models/22-07-13T22-25-49/cfg-22-07-13T22-25-49.yaml',
|
49 |
+
# SparseSync: feature extractors
|
50 |
+
'ResNetAudio-22-08-04T09-51-04.pt':
|
51 |
+
f'{PARENT_LINK}/sync/ResNetAudio-22-08-04T09-51-04.pt', # 2s
|
52 |
+
'ResNetAudio-22-08-03T23-14-49.pt':
|
53 |
+
f'{PARENT_LINK}/sync/ResNetAudio-22-08-03T23-14-49.pt', # 3s
|
54 |
+
'ResNetAudio-22-08-03T23-14-28.pt':
|
55 |
+
f'{PARENT_LINK}/sync/ResNetAudio-22-08-03T23-14-28.pt', # 4s
|
56 |
+
'ResNetAudio-22-06-24T08-10-33.pt':
|
57 |
+
f'{PARENT_LINK}/sync/ResNetAudio-22-06-24T08-10-33.pt', # 5s
|
58 |
+
'ResNetAudio-22-06-24T17-31-07.pt':
|
59 |
+
f'{PARENT_LINK}/sync/ResNetAudio-22-06-24T17-31-07.pt', # 6s
|
60 |
+
'ResNetAudio-22-06-24T23-57-11.pt':
|
61 |
+
f'{PARENT_LINK}/sync/ResNetAudio-22-06-24T23-57-11.pt', # 7s
|
62 |
+
'ResNetAudio-22-06-25T04-35-42.pt':
|
63 |
+
f'{PARENT_LINK}/sync/ResNetAudio-22-06-25T04-35-42.pt', # 8s
|
64 |
+
}
|
65 |
+
|
66 |
+
|
67 |
+
def check_if_file_exists_else_download(path, fname2link=FNAME2LINK, chunk_size=1024):
|
68 |
+
'''Checks if file exists, if not downloads it from the link to the path'''
|
69 |
+
path = Path(path)
|
70 |
+
if not path.exists():
|
71 |
+
path.parent.mkdir(exist_ok=True, parents=True)
|
72 |
+
link = fname2link.get(path.name, None)
|
73 |
+
if link is None:
|
74 |
+
raise ValueError(f'Cant find the checkpoint file: {path}.',
|
75 |
+
f'Please download it manually and ensure the path exists.')
|
76 |
+
with requests.get(fname2link[path.name], stream=True) as r:
|
77 |
+
total_size = int(r.headers.get('content-length', 0))
|
78 |
+
with tqdm(total=total_size, unit='B', unit_scale=True) as pbar:
|
79 |
+
with open(path, 'wb') as f:
|
80 |
+
for data in r.iter_content(chunk_size=chunk_size):
|
81 |
+
if data:
|
82 |
+
f.write(data)
|
83 |
+
pbar.update(chunk_size)
|
84 |
+
|
85 |
+
|
86 |
+
def get_md5sum(path):
|
87 |
+
hash_md5 = md5()
|
88 |
+
with open(path, 'rb') as f:
|
89 |
+
for chunk in iter(lambda: f.read(4096 * 8), b''):
|
90 |
+
hash_md5.update(chunk)
|
91 |
+
md5sum = hash_md5.hexdigest()
|
92 |
+
return md5sum
|
data_utils/ext/synchformer/video_model_builder.py
ADDED
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
3 |
+
# Copyright 2020 Ross Wightman
|
4 |
+
# Modified Model definition
|
5 |
+
|
6 |
+
from collections import OrderedDict
|
7 |
+
from functools import partial
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from timm.layers import trunc_normal_
|
12 |
+
|
13 |
+
from data_utils.ext.synchformer import vit_helper
|
14 |
+
|
15 |
+
|
16 |
+
class VisionTransformer(nn.Module):
|
17 |
+
""" Vision Transformer with support for patch or hybrid CNN input stage """
|
18 |
+
|
19 |
+
def __init__(self, cfg):
|
20 |
+
super().__init__()
|
21 |
+
self.img_size = cfg.DATA.TRAIN_CROP_SIZE
|
22 |
+
self.patch_size = cfg.VIT.PATCH_SIZE
|
23 |
+
self.in_chans = cfg.VIT.CHANNELS
|
24 |
+
if cfg.TRAIN.DATASET == "Epickitchens":
|
25 |
+
self.num_classes = [97, 300]
|
26 |
+
else:
|
27 |
+
self.num_classes = cfg.MODEL.NUM_CLASSES
|
28 |
+
self.embed_dim = cfg.VIT.EMBED_DIM
|
29 |
+
self.depth = cfg.VIT.DEPTH
|
30 |
+
self.num_heads = cfg.VIT.NUM_HEADS
|
31 |
+
self.mlp_ratio = cfg.VIT.MLP_RATIO
|
32 |
+
self.qkv_bias = cfg.VIT.QKV_BIAS
|
33 |
+
self.drop_rate = cfg.VIT.DROP
|
34 |
+
self.drop_path_rate = cfg.VIT.DROP_PATH
|
35 |
+
self.head_dropout = cfg.VIT.HEAD_DROPOUT
|
36 |
+
self.video_input = cfg.VIT.VIDEO_INPUT
|
37 |
+
self.temporal_resolution = cfg.VIT.TEMPORAL_RESOLUTION
|
38 |
+
self.use_mlp = cfg.VIT.USE_MLP
|
39 |
+
self.num_features = self.embed_dim
|
40 |
+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
41 |
+
self.attn_drop_rate = cfg.VIT.ATTN_DROPOUT
|
42 |
+
self.head_act = cfg.VIT.HEAD_ACT
|
43 |
+
self.cfg = cfg
|
44 |
+
|
45 |
+
# Patch Embedding
|
46 |
+
self.patch_embed = vit_helper.PatchEmbed(img_size=224,
|
47 |
+
patch_size=self.patch_size,
|
48 |
+
in_chans=self.in_chans,
|
49 |
+
embed_dim=self.embed_dim)
|
50 |
+
|
51 |
+
# 3D Patch Embedding
|
52 |
+
self.patch_embed_3d = vit_helper.PatchEmbed3D(img_size=self.img_size,
|
53 |
+
temporal_resolution=self.temporal_resolution,
|
54 |
+
patch_size=self.patch_size,
|
55 |
+
in_chans=self.in_chans,
|
56 |
+
embed_dim=self.embed_dim,
|
57 |
+
z_block_size=self.cfg.VIT.PATCH_SIZE_TEMP)
|
58 |
+
self.patch_embed_3d.proj.weight.data = torch.zeros_like(
|
59 |
+
self.patch_embed_3d.proj.weight.data)
|
60 |
+
|
61 |
+
# Number of patches
|
62 |
+
if self.video_input:
|
63 |
+
num_patches = self.patch_embed.num_patches * self.temporal_resolution
|
64 |
+
else:
|
65 |
+
num_patches = self.patch_embed.num_patches
|
66 |
+
self.num_patches = num_patches
|
67 |
+
|
68 |
+
# CLS token
|
69 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
|
70 |
+
trunc_normal_(self.cls_token, std=.02)
|
71 |
+
|
72 |
+
# Positional embedding
|
73 |
+
self.pos_embed = nn.Parameter(
|
74 |
+
torch.zeros(1, self.patch_embed.num_patches + 1, self.embed_dim))
|
75 |
+
self.pos_drop = nn.Dropout(p=cfg.VIT.POS_DROPOUT)
|
76 |
+
trunc_normal_(self.pos_embed, std=.02)
|
77 |
+
|
78 |
+
if self.cfg.VIT.POS_EMBED == "joint":
|
79 |
+
self.st_embed = nn.Parameter(torch.zeros(1, num_patches + 1, self.embed_dim))
|
80 |
+
trunc_normal_(self.st_embed, std=.02)
|
81 |
+
elif self.cfg.VIT.POS_EMBED == "separate":
|
82 |
+
self.temp_embed = nn.Parameter(torch.zeros(1, self.temporal_resolution, self.embed_dim))
|
83 |
+
|
84 |
+
# Layer Blocks
|
85 |
+
dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)]
|
86 |
+
if self.cfg.VIT.ATTN_LAYER == "divided":
|
87 |
+
self.blocks = nn.ModuleList([
|
88 |
+
vit_helper.DividedSpaceTimeBlock(
|
89 |
+
attn_type=cfg.VIT.ATTN_LAYER,
|
90 |
+
dim=self.embed_dim,
|
91 |
+
num_heads=self.num_heads,
|
92 |
+
mlp_ratio=self.mlp_ratio,
|
93 |
+
qkv_bias=self.qkv_bias,
|
94 |
+
drop=self.drop_rate,
|
95 |
+
attn_drop=self.attn_drop_rate,
|
96 |
+
drop_path=dpr[i],
|
97 |
+
norm_layer=norm_layer,
|
98 |
+
) for i in range(self.depth)
|
99 |
+
])
|
100 |
+
else:
|
101 |
+
self.blocks = nn.ModuleList([
|
102 |
+
vit_helper.Block(attn_type=cfg.VIT.ATTN_LAYER,
|
103 |
+
dim=self.embed_dim,
|
104 |
+
num_heads=self.num_heads,
|
105 |
+
mlp_ratio=self.mlp_ratio,
|
106 |
+
qkv_bias=self.qkv_bias,
|
107 |
+
drop=self.drop_rate,
|
108 |
+
attn_drop=self.attn_drop_rate,
|
109 |
+
drop_path=dpr[i],
|
110 |
+
norm_layer=norm_layer,
|
111 |
+
use_original_code=self.cfg.VIT.USE_ORIGINAL_TRAJ_ATTN_CODE)
|
112 |
+
for i in range(self.depth)
|
113 |
+
])
|
114 |
+
self.norm = norm_layer(self.embed_dim)
|
115 |
+
|
116 |
+
# MLP head
|
117 |
+
if self.use_mlp:
|
118 |
+
hidden_dim = self.embed_dim
|
119 |
+
if self.head_act == 'tanh':
|
120 |
+
# logging.info("Using TanH activation in MLP")
|
121 |
+
act = nn.Tanh()
|
122 |
+
elif self.head_act == 'gelu':
|
123 |
+
# logging.info("Using GELU activation in MLP")
|
124 |
+
act = nn.GELU()
|
125 |
+
else:
|
126 |
+
# logging.info("Using ReLU activation in MLP")
|
127 |
+
act = nn.ReLU()
|
128 |
+
self.pre_logits = nn.Sequential(
|
129 |
+
OrderedDict([
|
130 |
+
('fc', nn.Linear(self.embed_dim, hidden_dim)),
|
131 |
+
('act', act),
|
132 |
+
]))
|
133 |
+
else:
|
134 |
+
self.pre_logits = nn.Identity()
|
135 |
+
|
136 |
+
# Classifier Head
|
137 |
+
self.head_drop = nn.Dropout(p=self.head_dropout)
|
138 |
+
if isinstance(self.num_classes, (list, )) and len(self.num_classes) > 1:
|
139 |
+
for a, i in enumerate(range(len(self.num_classes))):
|
140 |
+
setattr(self, "head%d" % a, nn.Linear(self.embed_dim, self.num_classes[i]))
|
141 |
+
else:
|
142 |
+
self.head = nn.Linear(self.embed_dim,
|
143 |
+
self.num_classes) if self.num_classes > 0 else nn.Identity()
|
144 |
+
|
145 |
+
# Initialize weights
|
146 |
+
self.apply(self._init_weights)
|
147 |
+
|
148 |
+
def _init_weights(self, m):
|
149 |
+
if isinstance(m, nn.Linear):
|
150 |
+
trunc_normal_(m.weight, std=.02)
|
151 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
152 |
+
nn.init.constant_(m.bias, 0)
|
153 |
+
elif isinstance(m, nn.LayerNorm):
|
154 |
+
nn.init.constant_(m.bias, 0)
|
155 |
+
nn.init.constant_(m.weight, 1.0)
|
156 |
+
|
157 |
+
@torch.jit.ignore
|
158 |
+
def no_weight_decay(self):
|
159 |
+
if self.cfg.VIT.POS_EMBED == "joint":
|
160 |
+
return {'pos_embed', 'cls_token', 'st_embed'}
|
161 |
+
else:
|
162 |
+
return {'pos_embed', 'cls_token', 'temp_embed'}
|
163 |
+
|
164 |
+
def get_classifier(self):
|
165 |
+
return self.head
|
166 |
+
|
167 |
+
def reset_classifier(self, num_classes, global_pool=''):
|
168 |
+
self.num_classes = num_classes
|
169 |
+
self.head = (nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity())
|
170 |
+
|
171 |
+
def forward_features(self, x):
|
172 |
+
# if self.video_input:
|
173 |
+
# x = x[0]
|
174 |
+
B = x.shape[0]
|
175 |
+
|
176 |
+
# Tokenize input
|
177 |
+
# if self.cfg.VIT.PATCH_SIZE_TEMP > 1:
|
178 |
+
# for simplicity of mapping between content dimensions (input x) and token dims (after patching)
|
179 |
+
# we use the same trick as for AST (see modeling_ast.ASTModel.forward for the details):
|
180 |
+
|
181 |
+
# apply patching on input
|
182 |
+
x = self.patch_embed_3d(x)
|
183 |
+
tok_mask = None
|
184 |
+
|
185 |
+
# else:
|
186 |
+
# tok_mask = None
|
187 |
+
# # 2D tokenization
|
188 |
+
# if self.video_input:
|
189 |
+
# x = x.permute(0, 2, 1, 3, 4)
|
190 |
+
# (B, T, C, H, W) = x.shape
|
191 |
+
# x = x.reshape(B * T, C, H, W)
|
192 |
+
|
193 |
+
# x = self.patch_embed(x)
|
194 |
+
|
195 |
+
# if self.video_input:
|
196 |
+
# (B2, T2, D2) = x.shape
|
197 |
+
# x = x.reshape(B, T * T2, D2)
|
198 |
+
|
199 |
+
# Append CLS token
|
200 |
+
cls_tokens = self.cls_token.expand(B, -1, -1)
|
201 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
202 |
+
# if tok_mask is not None:
|
203 |
+
# # prepend 1(=keep) to the mask to account for the CLS token as well
|
204 |
+
# tok_mask = torch.cat((torch.ones_like(tok_mask[:, [0]]), tok_mask), dim=1)
|
205 |
+
|
206 |
+
# Interpolate positinoal embeddings
|
207 |
+
# if self.cfg.DATA.TRAIN_CROP_SIZE != 224:
|
208 |
+
# pos_embed = self.pos_embed
|
209 |
+
# N = pos_embed.shape[1] - 1
|
210 |
+
# npatch = int((x.size(1) - 1) / self.temporal_resolution)
|
211 |
+
# class_emb = pos_embed[:, 0]
|
212 |
+
# pos_embed = pos_embed[:, 1:]
|
213 |
+
# dim = x.shape[-1]
|
214 |
+
# pos_embed = torch.nn.functional.interpolate(
|
215 |
+
# pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
|
216 |
+
# scale_factor=math.sqrt(npatch / N),
|
217 |
+
# mode='bicubic',
|
218 |
+
# )
|
219 |
+
# pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
220 |
+
# new_pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
|
221 |
+
# else:
|
222 |
+
new_pos_embed = self.pos_embed
|
223 |
+
npatch = self.patch_embed.num_patches
|
224 |
+
|
225 |
+
# Add positional embeddings to input
|
226 |
+
if self.video_input:
|
227 |
+
if self.cfg.VIT.POS_EMBED == "separate":
|
228 |
+
cls_embed = self.pos_embed[:, 0, :].unsqueeze(1)
|
229 |
+
tile_pos_embed = new_pos_embed[:, 1:, :].repeat(1, self.temporal_resolution, 1)
|
230 |
+
tile_temporal_embed = self.temp_embed.repeat_interleave(npatch, 1)
|
231 |
+
total_pos_embed = tile_pos_embed + tile_temporal_embed
|
232 |
+
total_pos_embed = torch.cat([cls_embed, total_pos_embed], dim=1)
|
233 |
+
x = x + total_pos_embed
|
234 |
+
elif self.cfg.VIT.POS_EMBED == "joint":
|
235 |
+
x = x + self.st_embed
|
236 |
+
else:
|
237 |
+
# image input
|
238 |
+
x = x + new_pos_embed
|
239 |
+
|
240 |
+
# Apply positional dropout
|
241 |
+
x = self.pos_drop(x)
|
242 |
+
|
243 |
+
# Encoding using transformer layers
|
244 |
+
for i, blk in enumerate(self.blocks):
|
245 |
+
x = blk(x,
|
246 |
+
seq_len=npatch,
|
247 |
+
num_frames=self.temporal_resolution,
|
248 |
+
approx=self.cfg.VIT.APPROX_ATTN_TYPE,
|
249 |
+
num_landmarks=self.cfg.VIT.APPROX_ATTN_DIM,
|
250 |
+
tok_mask=tok_mask)
|
251 |
+
|
252 |
+
### v-iashin: I moved it to the forward pass
|
253 |
+
# x = self.norm(x)[:, 0]
|
254 |
+
# x = self.pre_logits(x)
|
255 |
+
###
|
256 |
+
return x, tok_mask
|
257 |
+
|
258 |
+
# def forward(self, x):
|
259 |
+
# x = self.forward_features(x)
|
260 |
+
# ### v-iashin: here. This should leave the same forward output as before
|
261 |
+
# x = self.norm(x)[:, 0]
|
262 |
+
# x = self.pre_logits(x)
|
263 |
+
# ###
|
264 |
+
# x = self.head_drop(x)
|
265 |
+
# if isinstance(self.num_classes, (list, )) and len(self.num_classes) > 1:
|
266 |
+
# output = []
|
267 |
+
# for head in range(len(self.num_classes)):
|
268 |
+
# x_out = getattr(self, "head%d" % head)(x)
|
269 |
+
# if not self.training:
|
270 |
+
# x_out = torch.nn.functional.softmax(x_out, dim=-1)
|
271 |
+
# output.append(x_out)
|
272 |
+
# return output
|
273 |
+
# else:
|
274 |
+
# x = self.head(x)
|
275 |
+
# if not self.training:
|
276 |
+
# x = torch.nn.functional.softmax(x, dim=-1)
|
277 |
+
# return x
|
data_utils/ext/synchformer/vit_helper.py
ADDED
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
3 |
+
# Copyright 2020 Ross Wightman
|
4 |
+
# Modified Model definition
|
5 |
+
"""Video models."""
|
6 |
+
|
7 |
+
import math
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from einops import rearrange, repeat
|
12 |
+
from timm.layers import to_2tuple
|
13 |
+
from torch import einsum
|
14 |
+
from torch.nn import functional as F
|
15 |
+
|
16 |
+
default_cfgs = {
|
17 |
+
'vit_1k':
|
18 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
|
19 |
+
'vit_1k_large':
|
20 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
|
21 |
+
}
|
22 |
+
|
23 |
+
|
24 |
+
def qkv_attn(q, k, v, tok_mask: torch.Tensor = None):
|
25 |
+
sim = einsum('b i d, b j d -> b i j', q, k)
|
26 |
+
# apply masking if provided, tok_mask is (B*S*H, N): 1s - keep; sim is (B*S*H, H, N, N)
|
27 |
+
if tok_mask is not None:
|
28 |
+
BSH, N = tok_mask.shape
|
29 |
+
sim = sim.masked_fill(tok_mask.view(BSH, 1, N) == 0,
|
30 |
+
float('-inf')) # 1 - broadcasts across N
|
31 |
+
attn = sim.softmax(dim=-1)
|
32 |
+
out = einsum('b i j, b j d -> b i d', attn, v)
|
33 |
+
return out
|
34 |
+
|
35 |
+
|
36 |
+
class DividedAttention(nn.Module):
|
37 |
+
|
38 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
|
39 |
+
super().__init__()
|
40 |
+
self.num_heads = num_heads
|
41 |
+
head_dim = dim // num_heads
|
42 |
+
self.scale = head_dim**-0.5
|
43 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
44 |
+
self.proj = nn.Linear(dim, dim)
|
45 |
+
|
46 |
+
# init to zeros
|
47 |
+
self.qkv.weight.data.fill_(0)
|
48 |
+
self.qkv.bias.data.fill_(0)
|
49 |
+
self.proj.weight.data.fill_(1)
|
50 |
+
self.proj.bias.data.fill_(0)
|
51 |
+
|
52 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
53 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
54 |
+
|
55 |
+
def forward(self, x, einops_from, einops_to, tok_mask: torch.Tensor = None, **einops_dims):
|
56 |
+
# num of heads variable
|
57 |
+
h = self.num_heads
|
58 |
+
|
59 |
+
# project x to q, k, v vaalues
|
60 |
+
q, k, v = self.qkv(x).chunk(3, dim=-1)
|
61 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
62 |
+
if tok_mask is not None:
|
63 |
+
# replicate token mask across heads (b, n) -> (b, h, n) -> (b*h, n) -- same as qkv but w/o d
|
64 |
+
assert len(tok_mask.shape) == 2
|
65 |
+
tok_mask = tok_mask.unsqueeze(1).expand(-1, h, -1).reshape(-1, tok_mask.shape[1])
|
66 |
+
|
67 |
+
# Scale q
|
68 |
+
q *= self.scale
|
69 |
+
|
70 |
+
# Take out cls_q, cls_k, cls_v
|
71 |
+
(cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, 0:1], t[:, 1:]), (q, k, v))
|
72 |
+
# the same for masking
|
73 |
+
if tok_mask is not None:
|
74 |
+
cls_mask, mask_ = tok_mask[:, 0:1], tok_mask[:, 1:]
|
75 |
+
else:
|
76 |
+
cls_mask, mask_ = None, None
|
77 |
+
|
78 |
+
# let CLS token attend to key / values of all patches across time and space
|
79 |
+
cls_out = qkv_attn(cls_q, k, v, tok_mask=tok_mask)
|
80 |
+
|
81 |
+
# rearrange across time or space
|
82 |
+
q_, k_, v_ = map(lambda t: rearrange(t, f'{einops_from} -> {einops_to}', **einops_dims),
|
83 |
+
(q_, k_, v_))
|
84 |
+
|
85 |
+
# expand CLS token keys and values across time or space and concat
|
86 |
+
r = q_.shape[0] // cls_k.shape[0]
|
87 |
+
cls_k, cls_v = map(lambda t: repeat(t, 'b () d -> (b r) () d', r=r), (cls_k, cls_v))
|
88 |
+
|
89 |
+
k_ = torch.cat((cls_k, k_), dim=1)
|
90 |
+
v_ = torch.cat((cls_v, v_), dim=1)
|
91 |
+
|
92 |
+
# the same for masking (if provided)
|
93 |
+
if tok_mask is not None:
|
94 |
+
# since mask does not have the latent dim (d), we need to remove it from einops dims
|
95 |
+
mask_ = rearrange(mask_, f'{einops_from} -> {einops_to}'.replace(' d', ''),
|
96 |
+
**einops_dims)
|
97 |
+
cls_mask = repeat(cls_mask, 'b () -> (b r) ()',
|
98 |
+
r=r) # expand cls_mask across time or space
|
99 |
+
mask_ = torch.cat((cls_mask, mask_), dim=1)
|
100 |
+
|
101 |
+
# attention
|
102 |
+
out = qkv_attn(q_, k_, v_, tok_mask=mask_)
|
103 |
+
|
104 |
+
# merge back time or space
|
105 |
+
out = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims)
|
106 |
+
|
107 |
+
# concat back the cls token
|
108 |
+
out = torch.cat((cls_out, out), dim=1)
|
109 |
+
|
110 |
+
# merge back the heads
|
111 |
+
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
112 |
+
|
113 |
+
## to out
|
114 |
+
x = self.proj(out)
|
115 |
+
x = self.proj_drop(x)
|
116 |
+
return x
|
117 |
+
|
118 |
+
|
119 |
+
class DividedSpaceTimeBlock(nn.Module):
|
120 |
+
|
121 |
+
def __init__(self,
|
122 |
+
dim=768,
|
123 |
+
num_heads=12,
|
124 |
+
attn_type='divided',
|
125 |
+
mlp_ratio=4.,
|
126 |
+
qkv_bias=False,
|
127 |
+
drop=0.,
|
128 |
+
attn_drop=0.,
|
129 |
+
drop_path=0.,
|
130 |
+
act_layer=nn.GELU,
|
131 |
+
norm_layer=nn.LayerNorm):
|
132 |
+
super().__init__()
|
133 |
+
|
134 |
+
self.einops_from_space = 'b (f n) d'
|
135 |
+
self.einops_to_space = '(b f) n d'
|
136 |
+
self.einops_from_time = 'b (f n) d'
|
137 |
+
self.einops_to_time = '(b n) f d'
|
138 |
+
|
139 |
+
self.norm1 = norm_layer(dim)
|
140 |
+
|
141 |
+
self.attn = DividedAttention(dim,
|
142 |
+
num_heads=num_heads,
|
143 |
+
qkv_bias=qkv_bias,
|
144 |
+
attn_drop=attn_drop,
|
145 |
+
proj_drop=drop)
|
146 |
+
|
147 |
+
self.timeattn = DividedAttention(dim,
|
148 |
+
num_heads=num_heads,
|
149 |
+
qkv_bias=qkv_bias,
|
150 |
+
attn_drop=attn_drop,
|
151 |
+
proj_drop=drop)
|
152 |
+
|
153 |
+
# self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
154 |
+
self.drop_path = nn.Identity()
|
155 |
+
self.norm2 = norm_layer(dim)
|
156 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
157 |
+
self.mlp = Mlp(in_features=dim,
|
158 |
+
hidden_features=mlp_hidden_dim,
|
159 |
+
act_layer=act_layer,
|
160 |
+
drop=drop)
|
161 |
+
self.norm3 = norm_layer(dim)
|
162 |
+
|
163 |
+
def forward(self,
|
164 |
+
x,
|
165 |
+
seq_len=196,
|
166 |
+
num_frames=8,
|
167 |
+
approx='none',
|
168 |
+
num_landmarks=128,
|
169 |
+
tok_mask: torch.Tensor = None):
|
170 |
+
time_output = self.timeattn(self.norm3(x),
|
171 |
+
self.einops_from_time,
|
172 |
+
self.einops_to_time,
|
173 |
+
n=seq_len,
|
174 |
+
tok_mask=tok_mask)
|
175 |
+
time_residual = x + time_output
|
176 |
+
|
177 |
+
space_output = self.attn(self.norm1(time_residual),
|
178 |
+
self.einops_from_space,
|
179 |
+
self.einops_to_space,
|
180 |
+
f=num_frames,
|
181 |
+
tok_mask=tok_mask)
|
182 |
+
space_residual = time_residual + self.drop_path(space_output)
|
183 |
+
|
184 |
+
x = space_residual
|
185 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
186 |
+
return x
|
187 |
+
|
188 |
+
|
189 |
+
class Mlp(nn.Module):
|
190 |
+
|
191 |
+
def __init__(self,
|
192 |
+
in_features,
|
193 |
+
hidden_features=None,
|
194 |
+
out_features=None,
|
195 |
+
act_layer=nn.GELU,
|
196 |
+
drop=0.):
|
197 |
+
super().__init__()
|
198 |
+
out_features = out_features or in_features
|
199 |
+
hidden_features = hidden_features or in_features
|
200 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
201 |
+
self.act = act_layer()
|
202 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
203 |
+
self.drop = nn.Dropout(drop)
|
204 |
+
|
205 |
+
def forward(self, x):
|
206 |
+
x = self.fc1(x)
|
207 |
+
x = self.act(x)
|
208 |
+
x = self.drop(x)
|
209 |
+
x = self.fc2(x)
|
210 |
+
x = self.drop(x)
|
211 |
+
return x
|
212 |
+
|
213 |
+
|
214 |
+
class PatchEmbed(nn.Module):
|
215 |
+
""" Image to Patch Embedding
|
216 |
+
"""
|
217 |
+
|
218 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
219 |
+
super().__init__()
|
220 |
+
img_size = img_size if type(img_size) is tuple else to_2tuple(img_size)
|
221 |
+
patch_size = img_size if type(patch_size) is tuple else to_2tuple(patch_size)
|
222 |
+
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
|
223 |
+
self.img_size = img_size
|
224 |
+
self.patch_size = patch_size
|
225 |
+
self.num_patches = num_patches
|
226 |
+
|
227 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
228 |
+
|
229 |
+
def forward(self, x):
|
230 |
+
B, C, H, W = x.shape
|
231 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
232 |
+
return x
|
233 |
+
|
234 |
+
|
235 |
+
class PatchEmbed3D(nn.Module):
|
236 |
+
""" Image to Patch Embedding """
|
237 |
+
|
238 |
+
def __init__(self,
|
239 |
+
img_size=224,
|
240 |
+
temporal_resolution=4,
|
241 |
+
in_chans=3,
|
242 |
+
patch_size=16,
|
243 |
+
z_block_size=2,
|
244 |
+
embed_dim=768,
|
245 |
+
flatten=True):
|
246 |
+
super().__init__()
|
247 |
+
self.height = (img_size // patch_size)
|
248 |
+
self.width = (img_size // patch_size)
|
249 |
+
### v-iashin: these two are incorrect
|
250 |
+
# self.frames = (temporal_resolution // z_block_size)
|
251 |
+
# self.num_patches = self.height * self.width * self.frames
|
252 |
+
self.z_block_size = z_block_size
|
253 |
+
###
|
254 |
+
self.proj = nn.Conv3d(in_chans,
|
255 |
+
embed_dim,
|
256 |
+
kernel_size=(z_block_size, patch_size, patch_size),
|
257 |
+
stride=(z_block_size, patch_size, patch_size))
|
258 |
+
self.flatten = flatten
|
259 |
+
|
260 |
+
def forward(self, x):
|
261 |
+
B, C, T, H, W = x.shape
|
262 |
+
x = self.proj(x)
|
263 |
+
if self.flatten:
|
264 |
+
x = x.flatten(2).transpose(1, 2)
|
265 |
+
return x
|
266 |
+
|
267 |
+
|
268 |
+
class HeadMLP(nn.Module):
|
269 |
+
|
270 |
+
def __init__(self, n_input, n_classes, n_hidden=512, p=0.1):
|
271 |
+
super(HeadMLP, self).__init__()
|
272 |
+
self.n_input = n_input
|
273 |
+
self.n_classes = n_classes
|
274 |
+
self.n_hidden = n_hidden
|
275 |
+
if n_hidden is None:
|
276 |
+
# use linear classifier
|
277 |
+
self.block_forward = nn.Sequential(nn.Dropout(p=p),
|
278 |
+
nn.Linear(n_input, n_classes, bias=True))
|
279 |
+
else:
|
280 |
+
# use simple MLP classifier
|
281 |
+
self.block_forward = nn.Sequential(nn.Dropout(p=p),
|
282 |
+
nn.Linear(n_input, n_hidden, bias=True),
|
283 |
+
nn.BatchNorm1d(n_hidden), nn.ReLU(inplace=True),
|
284 |
+
nn.Dropout(p=p),
|
285 |
+
nn.Linear(n_hidden, n_classes, bias=True))
|
286 |
+
print(f"Dropout-NLP: {p}")
|
287 |
+
|
288 |
+
def forward(self, x):
|
289 |
+
return self.block_forward(x)
|
290 |
+
|
291 |
+
|
292 |
+
def _conv_filter(state_dict, patch_size=16):
|
293 |
+
""" convert patch embedding weight from manual patchify + linear proj to conv"""
|
294 |
+
out_dict = {}
|
295 |
+
for k, v in state_dict.items():
|
296 |
+
if 'patch_embed.proj.weight' in k:
|
297 |
+
v = v.reshape((v.shape[0], 3, patch_size, patch_size))
|
298 |
+
out_dict[k] = v
|
299 |
+
return out_dict
|
300 |
+
|
301 |
+
|
302 |
+
def adapt_input_conv(in_chans, conv_weight, agg='sum'):
|
303 |
+
conv_type = conv_weight.dtype
|
304 |
+
conv_weight = conv_weight.float()
|
305 |
+
O, I, J, K = conv_weight.shape
|
306 |
+
if in_chans == 1:
|
307 |
+
if I > 3:
|
308 |
+
assert conv_weight.shape[1] % 3 == 0
|
309 |
+
# For models with space2depth stems
|
310 |
+
conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)
|
311 |
+
conv_weight = conv_weight.sum(dim=2, keepdim=False)
|
312 |
+
else:
|
313 |
+
if agg == 'sum':
|
314 |
+
print("Summing conv1 weights")
|
315 |
+
conv_weight = conv_weight.sum(dim=1, keepdim=True)
|
316 |
+
else:
|
317 |
+
print("Averaging conv1 weights")
|
318 |
+
conv_weight = conv_weight.mean(dim=1, keepdim=True)
|
319 |
+
elif in_chans != 3:
|
320 |
+
if I != 3:
|
321 |
+
raise NotImplementedError('Weight format not supported by conversion.')
|
322 |
+
else:
|
323 |
+
if agg == 'sum':
|
324 |
+
print("Summing conv1 weights")
|
325 |
+
repeat = int(math.ceil(in_chans / 3))
|
326 |
+
conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
|
327 |
+
conv_weight *= (3 / float(in_chans))
|
328 |
+
else:
|
329 |
+
print("Averaging conv1 weights")
|
330 |
+
conv_weight = conv_weight.mean(dim=1, keepdim=True)
|
331 |
+
conv_weight = conv_weight.repeat(1, in_chans, 1, 1)
|
332 |
+
conv_weight = conv_weight.to(conv_type)
|
333 |
+
return conv_weight
|
334 |
+
|
335 |
+
|
336 |
+
def load_pretrained(model,
|
337 |
+
cfg=None,
|
338 |
+
num_classes=1000,
|
339 |
+
in_chans=3,
|
340 |
+
filter_fn=None,
|
341 |
+
strict=True,
|
342 |
+
progress=False):
|
343 |
+
# Load state dict
|
344 |
+
assert (f"{cfg.VIT.PRETRAINED_WEIGHTS} not in [vit_1k, vit_1k_large]")
|
345 |
+
state_dict = torch.hub.load_state_dict_from_url(url=default_cfgs[cfg.VIT.PRETRAINED_WEIGHTS])
|
346 |
+
|
347 |
+
if filter_fn is not None:
|
348 |
+
state_dict = filter_fn(state_dict)
|
349 |
+
|
350 |
+
input_convs = 'patch_embed.proj'
|
351 |
+
if input_convs is not None and in_chans != 3:
|
352 |
+
if isinstance(input_convs, str):
|
353 |
+
input_convs = (input_convs, )
|
354 |
+
for input_conv_name in input_convs:
|
355 |
+
weight_name = input_conv_name + '.weight'
|
356 |
+
try:
|
357 |
+
state_dict[weight_name] = adapt_input_conv(in_chans,
|
358 |
+
state_dict[weight_name],
|
359 |
+
agg='avg')
|
360 |
+
print(
|
361 |
+
f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)'
|
362 |
+
)
|
363 |
+
except NotImplementedError as e:
|
364 |
+
del state_dict[weight_name]
|
365 |
+
strict = False
|
366 |
+
print(
|
367 |
+
f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.'
|
368 |
+
)
|
369 |
+
|
370 |
+
classifier_name = 'head'
|
371 |
+
label_offset = cfg.get('label_offset', 0)
|
372 |
+
pretrain_classes = 1000
|
373 |
+
if num_classes != pretrain_classes:
|
374 |
+
# completely discard fully connected if model num_classes doesn't match pretrained weights
|
375 |
+
del state_dict[classifier_name + '.weight']
|
376 |
+
del state_dict[classifier_name + '.bias']
|
377 |
+
strict = False
|
378 |
+
elif label_offset > 0:
|
379 |
+
# special case for pretrained weights with an extra background class in pretrained weights
|
380 |
+
classifier_weight = state_dict[classifier_name + '.weight']
|
381 |
+
state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
|
382 |
+
classifier_bias = state_dict[classifier_name + '.bias']
|
383 |
+
state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
|
384 |
+
|
385 |
+
loaded_state = state_dict
|
386 |
+
self_state = model.state_dict()
|
387 |
+
all_names = set(self_state.keys())
|
388 |
+
saved_names = set([])
|
389 |
+
for name, param in loaded_state.items():
|
390 |
+
param = param
|
391 |
+
if 'module.' in name:
|
392 |
+
name = name.replace('module.', '')
|
393 |
+
if name in self_state.keys() and param.shape == self_state[name].shape:
|
394 |
+
saved_names.add(name)
|
395 |
+
self_state[name].copy_(param)
|
396 |
+
else:
|
397 |
+
print(f"didnt load: {name} of shape: {param.shape}")
|
398 |
+
print("Missing Keys:")
|
399 |
+
print(all_names - saved_names)
|
data_utils/v2a_utils/__init__.py
ADDED
File without changes
|
data_utils/v2a_utils/audio_text_dataset.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
from typing import Optional, Union
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
import pandas as pd
|
7 |
+
import torch
|
8 |
+
import torchaudio
|
9 |
+
from torch.utils.data.dataset import Dataset
|
10 |
+
from torchvision.transforms import v2
|
11 |
+
from torio.io import StreamingMediaDecoder
|
12 |
+
from torchvision.utils import save_image
|
13 |
+
from transformers import AutoProcessor
|
14 |
+
import torch.nn.functional as F
|
15 |
+
import numpy as np
|
16 |
+
|
17 |
+
import logging
|
18 |
+
log = logging.getLogger()
|
19 |
+
|
20 |
+
_CLIP_SIZE = 224
|
21 |
+
_CLIP_FPS = 8.0
|
22 |
+
|
23 |
+
_SYNC_SIZE = 224
|
24 |
+
_SYNC_FPS = 25.0
|
25 |
+
|
26 |
+
|
27 |
+
class Audio_Text(Dataset):
|
28 |
+
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
root: Union[str, Path],
|
32 |
+
*,
|
33 |
+
tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv',
|
34 |
+
sample_rate: int = 44_100,
|
35 |
+
duration_sec: float = 9.0,
|
36 |
+
audio_samples: Optional[int] = 397312,
|
37 |
+
normalize_audio: bool = False,
|
38 |
+
start_row: Optional[int] = None,
|
39 |
+
end_row: Optional[int] = None,
|
40 |
+
save_dir: str = 'data/vggsound/video_latents_text/train'
|
41 |
+
):
|
42 |
+
self.root = Path(root)
|
43 |
+
self.normalize_audio = normalize_audio
|
44 |
+
if audio_samples is None:
|
45 |
+
self.audio_samples = int(sample_rate * duration_sec)
|
46 |
+
else:
|
47 |
+
self.audio_samples = audio_samples
|
48 |
+
effective_duration = audio_samples / sample_rate
|
49 |
+
# make sure the duration is close enough, within 15ms
|
50 |
+
assert abs(effective_duration - duration_sec) < 0.015, \
|
51 |
+
f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
|
52 |
+
|
53 |
+
# videos = sorted(os.listdir(self.root))
|
54 |
+
# videos = set([Path(v).stem for v in videos]) # remove extensions
|
55 |
+
videos = []
|
56 |
+
self.labels = []
|
57 |
+
self.videos = []
|
58 |
+
self.cots = []
|
59 |
+
missing_videos = []
|
60 |
+
# read the tsv for subset information
|
61 |
+
df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records')
|
62 |
+
|
63 |
+
# 控制处理的行范围
|
64 |
+
if start_row is not None and end_row is not None:
|
65 |
+
df_list = df_list[start_row:end_row]
|
66 |
+
for record in df_list:
|
67 |
+
id = record['id']
|
68 |
+
if os.path.exists(f'{save_dir}/{id}.pth'): continue
|
69 |
+
label = record['caption']
|
70 |
+
# if id in videos:
|
71 |
+
self.labels.append(label)
|
72 |
+
# print(label,'debug1!!!!!!!!!')
|
73 |
+
self.cots.append(record['caption_cot'])
|
74 |
+
# self.labels[id] = label
|
75 |
+
self.videos.append(id)
|
76 |
+
# else:
|
77 |
+
# missing_videos.append(id)
|
78 |
+
|
79 |
+
log.info(f'{len(videos)} videos found in {root}')
|
80 |
+
log.info(f'{len(self.videos)} videos found in {tsv_path}')
|
81 |
+
log.info(f'{len(missing_videos)} videos missing in {root}')
|
82 |
+
|
83 |
+
self.sample_rate = sample_rate
|
84 |
+
self.duration_sec = duration_sec
|
85 |
+
|
86 |
+
self.expected_audio_length = self.audio_samples
|
87 |
+
self.resampler = {}
|
88 |
+
|
89 |
+
def sample(self, idx: int):
|
90 |
+
video_id = self.videos[idx]
|
91 |
+
label = self.labels[idx]
|
92 |
+
cot = self.cots[idx]
|
93 |
+
audio_path = os.path.join(self.root, f'{video_id}.wav')
|
94 |
+
if not os.path.exists(audio_path):
|
95 |
+
audio_path = os.path.join(self.root, f'{video_id}.flac')
|
96 |
+
if not os.path.exists(audio_path):
|
97 |
+
raise RuntimeError(f'Audio is not exist {audio_path}')
|
98 |
+
audio_chunk, sample_rate = torchaudio.load(audio_path)
|
99 |
+
if len(audio_chunk.shape) != 2:
|
100 |
+
raise RuntimeError(f'error audio shape {video_id}')
|
101 |
+
|
102 |
+
abs_max = audio_chunk[0].abs().max()
|
103 |
+
|
104 |
+
if abs_max <= 1e-6:
|
105 |
+
if audio_chunk.shape[0] > 1 and audio_chunk[1].abs().max() > 1e-6:
|
106 |
+
audio_chunk = audio_chunk[1:2]
|
107 |
+
else:
|
108 |
+
raise RuntimeError(f'Audio is silent {video_id}')
|
109 |
+
|
110 |
+
# ensure the stereo audio
|
111 |
+
if audio_chunk.shape[0] < 2:
|
112 |
+
audio_chunk = audio_chunk.repeat(2, 1)
|
113 |
+
elif audio_chunk.shape[0] > 2:
|
114 |
+
audio_chunk = audio_chunk[:2]
|
115 |
+
|
116 |
+
# resample
|
117 |
+
if sample_rate == self.sample_rate:
|
118 |
+
audio_chunk = audio_chunk
|
119 |
+
else:
|
120 |
+
if sample_rate not in self.resampler:
|
121 |
+
# https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
|
122 |
+
self.resampler[sample_rate] = torchaudio.transforms.Resample(
|
123 |
+
sample_rate,
|
124 |
+
self.sample_rate,
|
125 |
+
lowpass_filter_width=64,
|
126 |
+
rolloff=0.9475937167399596,
|
127 |
+
resampling_method='sinc_interp_kaiser',
|
128 |
+
beta=14.769656459379492,
|
129 |
+
)
|
130 |
+
audio_chunk = self.resampler[sample_rate](audio_chunk)
|
131 |
+
|
132 |
+
if audio_chunk.shape[1] < self.expected_audio_length:
|
133 |
+
# zero-padding audio
|
134 |
+
padding_length = self.expected_audio_length - audio_chunk.shape[1]
|
135 |
+
# 创建 padding 张量,大小为 [batch_size, padding_length],值为0
|
136 |
+
padding = torch.zeros(audio_chunk.shape[0], padding_length)
|
137 |
+
# 将原始音频和 padding 沿第 1 维度拼接在一起
|
138 |
+
audio_chunk = torch.cat((audio_chunk, padding), dim=1)
|
139 |
+
# raise RuntimeError(f'Audio too short {video_id}')
|
140 |
+
audio_chunk = audio_chunk[:,:self.expected_audio_length]
|
141 |
+
assert audio_chunk.shape == (2, 397312), f'error shape:{video_id},{audio_chunk.shape}'
|
142 |
+
# print(label,'debug2!!!!!!!!!')
|
143 |
+
data = {
|
144 |
+
'id': video_id,
|
145 |
+
'caption': label,
|
146 |
+
'caption_cot': cot,
|
147 |
+
'audio': audio_chunk,
|
148 |
+
}
|
149 |
+
|
150 |
+
return data
|
151 |
+
|
152 |
+
def __getitem__(self, idx: int):
|
153 |
+
try:
|
154 |
+
return self.sample(idx)
|
155 |
+
except Exception as e:
|
156 |
+
log.error(f'Error loading video {self.videos[idx]}: {e}')
|
157 |
+
return None
|
158 |
+
|
159 |
+
def __len__(self):
|
160 |
+
return len(self.labels)
|
161 |
+
|
162 |
+
|
163 |
+
# dataset = VGGSound(
|
164 |
+
# root="data/vggsound/video/train",
|
165 |
+
# tsv_path="data/vggsound/split_txt/temp.csv",
|
166 |
+
# sample_rate=44100,
|
167 |
+
# duration_sec=9.0,
|
168 |
+
# audio_samples=397312,
|
169 |
+
# start_row=0,
|
170 |
+
# end_row=None,
|
171 |
+
# save_dir="data/vggsound/video_224_latents_text/train"
|
172 |
+
# )
|
173 |
+
# dataset[0]
|
data_utils/v2a_utils/audioset_224.py
ADDED
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
from typing import Optional, Union
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
import pandas as pd
|
7 |
+
import torch
|
8 |
+
import torchaudio
|
9 |
+
from torch.utils.data.dataset import Dataset
|
10 |
+
from torchvision.transforms import v2
|
11 |
+
from torio.io import StreamingMediaDecoder
|
12 |
+
from torchvision.utils import save_image
|
13 |
+
from transformers import AutoProcessor
|
14 |
+
import torch.nn.functional as F
|
15 |
+
import numpy as np
|
16 |
+
|
17 |
+
import logging
|
18 |
+
log = logging.getLogger()
|
19 |
+
|
20 |
+
_CLIP_SIZE = 224
|
21 |
+
_CLIP_FPS = 8.0
|
22 |
+
|
23 |
+
_SYNC_SIZE = 224
|
24 |
+
_SYNC_FPS = 25.0
|
25 |
+
|
26 |
+
def save_tensor_as_image(tensor, save_path):
|
27 |
+
"""
|
28 |
+
将形状为 (1, 3, H, W) 的 RGB 图像数组保存为图片文件。
|
29 |
+
|
30 |
+
:param tensor: 输入的 NumPy 数组 (1, 3, H, W)。
|
31 |
+
:param save_path: 图片保存路径。
|
32 |
+
"""
|
33 |
+
# # 移除批次维度,变成 (3, H, W)
|
34 |
+
# tensor = tensor.squeeze(0)
|
35 |
+
|
36 |
+
# 交换轴顺序,变为 (H, W, 3)
|
37 |
+
image_array = np.transpose(tensor, (1, 2, 0))
|
38 |
+
|
39 |
+
# 检查数组是否为合适的数据类型
|
40 |
+
if image_array.dtype != np.uint8:
|
41 |
+
# 如果不是 uint8,首先标准化,然后转换
|
42 |
+
image_array = (image_array - image_array.min()) / (image_array.max() - image_array.min()) * 255
|
43 |
+
image_array = image_array.astype(np.uint8)
|
44 |
+
|
45 |
+
# 创建图像对象
|
46 |
+
image = Image.fromarray(image_array)
|
47 |
+
|
48 |
+
# 保存图片
|
49 |
+
image.save(save_path)
|
50 |
+
print(f"Image saved to {save_path}")
|
51 |
+
|
52 |
+
def pad_to_square(video_tensor):
|
53 |
+
# 验证输入的形状
|
54 |
+
if len(video_tensor.shape) != 4:
|
55 |
+
raise ValueError("Input tensor must have shape (l, c, h, w)")
|
56 |
+
|
57 |
+
l, c, h, w = video_tensor.shape
|
58 |
+
max_side = max(h, w)
|
59 |
+
|
60 |
+
# 计算每一维度需要的填充量:(left, right, top, bottom)
|
61 |
+
pad_h = max_side - h
|
62 |
+
pad_w = max_side - w
|
63 |
+
|
64 |
+
# 创建padding tuple (left, right, top, bottom)
|
65 |
+
# 因为图像的填充是作用在最后两个维度 h 和 w 上,所以我们需要指定这两个维度的填充
|
66 |
+
padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
|
67 |
+
|
68 |
+
# 使用F.pad对视频张量进行填充操作
|
69 |
+
# 填充参数为 (left, right, top, bottom)
|
70 |
+
video_padded = F.pad(video_tensor, pad=padding, mode='constant', value=0)
|
71 |
+
|
72 |
+
return video_padded
|
73 |
+
|
74 |
+
class Audioset(Dataset):
|
75 |
+
|
76 |
+
def __init__(
|
77 |
+
self,
|
78 |
+
root: Union[str, Path],
|
79 |
+
*,
|
80 |
+
tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv',
|
81 |
+
sample_rate: int = 44_100,
|
82 |
+
duration_sec: float = 9.0,
|
83 |
+
audio_samples: Optional[int] = 397312,
|
84 |
+
normalize_audio: bool = False,
|
85 |
+
start_row: Optional[int] = None,
|
86 |
+
end_row: Optional[int] = None,
|
87 |
+
save_dir: str = 'data/vggsound/video_latents_text/train'
|
88 |
+
):
|
89 |
+
self.root = Path(root)
|
90 |
+
self.normalize_audio = normalize_audio
|
91 |
+
if audio_samples is None:
|
92 |
+
self.audio_samples = int(sample_rate * duration_sec)
|
93 |
+
else:
|
94 |
+
self.audio_samples = audio_samples
|
95 |
+
effective_duration = audio_samples / sample_rate
|
96 |
+
# make sure the duration is close enough, within 15ms
|
97 |
+
assert abs(effective_duration - duration_sec) < 0.015, \
|
98 |
+
f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
|
99 |
+
|
100 |
+
# videos = sorted(os.listdir(self.root))
|
101 |
+
# videos = set([Path(v).stem for v in videos]) # remove extensions
|
102 |
+
videos = []
|
103 |
+
self.labels = []
|
104 |
+
self.videos = []
|
105 |
+
self.caption_t5s = []
|
106 |
+
missing_videos = []
|
107 |
+
# read the tsv for subset information
|
108 |
+
df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records')
|
109 |
+
|
110 |
+
# 控制处理的行范围
|
111 |
+
if start_row is not None and end_row is not None:
|
112 |
+
df_list = df_list[start_row:end_row]
|
113 |
+
|
114 |
+
for record in df_list:
|
115 |
+
id = record['id']
|
116 |
+
if os.path.exists(f'{save_dir}/{id}.pth'): continue
|
117 |
+
label = record['label']
|
118 |
+
caption_t5 = record['caption_t5']
|
119 |
+
# if id in videos:
|
120 |
+
self.labels.append(label)
|
121 |
+
# self.labels[id] = label
|
122 |
+
self.videos.append(id)
|
123 |
+
self.caption_t5s.append(caption_t5)
|
124 |
+
# else:
|
125 |
+
# missing_videos.append(id)
|
126 |
+
|
127 |
+
log.info(f'{len(videos)} videos found in {root}')
|
128 |
+
log.info(f'{len(self.videos)} videos found in {tsv_path}')
|
129 |
+
log.info(f'{len(missing_videos)} videos missing in {root}')
|
130 |
+
|
131 |
+
self.sample_rate = sample_rate
|
132 |
+
self.duration_sec = duration_sec
|
133 |
+
|
134 |
+
self.expected_audio_length = self.audio_samples
|
135 |
+
self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
|
136 |
+
self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
|
137 |
+
|
138 |
+
self.clip_transform = v2.Compose([
|
139 |
+
v2.Lambda(pad_to_square), # 先填充为正方形
|
140 |
+
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
141 |
+
v2.ToImage(),
|
142 |
+
v2.ToDtype(torch.float32, scale=True),
|
143 |
+
])
|
144 |
+
self.clip_processor = AutoProcessor.from_pretrained("useful_ckpts/metaclip-huge")
|
145 |
+
self.sync_transform = v2.Compose([
|
146 |
+
v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
|
147 |
+
v2.CenterCrop(_SYNC_SIZE),
|
148 |
+
v2.ToImage(),
|
149 |
+
v2.ToDtype(torch.float32, scale=True),
|
150 |
+
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
151 |
+
])
|
152 |
+
|
153 |
+
self.resampler = {}
|
154 |
+
|
155 |
+
def sample(self, idx: int) -> dict[str, torch.Tensor]:
|
156 |
+
video_id = self.videos[idx]
|
157 |
+
label = self.labels[idx]
|
158 |
+
caption_t5 = self.caption_t5s[idx]
|
159 |
+
|
160 |
+
reader = StreamingMediaDecoder(self.root / (video_id + '.mp4'))
|
161 |
+
reader.add_basic_video_stream(
|
162 |
+
frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
|
163 |
+
frame_rate=_CLIP_FPS,
|
164 |
+
format='rgb24',
|
165 |
+
)
|
166 |
+
reader.add_basic_video_stream(
|
167 |
+
frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
|
168 |
+
frame_rate=_SYNC_FPS,
|
169 |
+
format='rgb24',
|
170 |
+
)
|
171 |
+
# reader.add_basic_audio_stream(frames_per_chunk=2**30,)
|
172 |
+
|
173 |
+
reader.fill_buffer()
|
174 |
+
data_chunk = reader.pop_chunks()
|
175 |
+
|
176 |
+
clip_chunk = data_chunk[0]
|
177 |
+
sync_chunk = data_chunk[1]
|
178 |
+
audio_path = os.path.join("dataset/3_Audioset/audios/sound",video_id+'.wav')
|
179 |
+
assert os.path.exists(audio_path), f'{audio_path} not exists'
|
180 |
+
audio_chunk, sr = torchaudio.load(audio_path)
|
181 |
+
# audio_chunk = data_chunk[2]
|
182 |
+
if len(audio_chunk.shape) != 2:
|
183 |
+
raise RuntimeError(f'error audio shape {video_id}')
|
184 |
+
if clip_chunk is None:
|
185 |
+
raise RuntimeError(f'CLIP video returned None {video_id}')
|
186 |
+
|
187 |
+
if sync_chunk is None:
|
188 |
+
raise RuntimeError(f'Sync video returned None {video_id}')
|
189 |
+
sample_rate = int(sr)
|
190 |
+
# audio_chunk = audio_chunk.transpose(0, 1)
|
191 |
+
abs_max = audio_chunk[0].abs().max()
|
192 |
+
# audio_chunk = audio_chunk.mean(dim=0) # mono
|
193 |
+
# if self.normalize_audio:
|
194 |
+
# abs_max = audio_chunk.abs().max()
|
195 |
+
# audio_chunk = audio_chunk / abs_max * 0.95
|
196 |
+
if abs_max <= 1e-6:
|
197 |
+
if audio_chunk.shape[0] > 1 and audio_chunk[1].abs().max() > 1e-6:
|
198 |
+
audio_chunk = audio_chunk[1:2]
|
199 |
+
else:
|
200 |
+
raise RuntimeError(f'Audio is silent {video_id}')
|
201 |
+
|
202 |
+
# ensure the stereo audio
|
203 |
+
if audio_chunk.shape[0] < 2:
|
204 |
+
audio_chunk = audio_chunk.repeat(2, 1)
|
205 |
+
|
206 |
+
# resample
|
207 |
+
if sample_rate == self.sample_rate:
|
208 |
+
audio_chunk = audio_chunk
|
209 |
+
else:
|
210 |
+
if sample_rate not in self.resampler:
|
211 |
+
# https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
|
212 |
+
self.resampler[sample_rate] = torchaudio.transforms.Resample(
|
213 |
+
sample_rate,
|
214 |
+
self.sample_rate,
|
215 |
+
lowpass_filter_width=64,
|
216 |
+
rolloff=0.9475937167399596,
|
217 |
+
resampling_method='sinc_interp_kaiser',
|
218 |
+
beta=14.769656459379492,
|
219 |
+
)
|
220 |
+
audio_chunk = self.resampler[sample_rate](audio_chunk)
|
221 |
+
|
222 |
+
if audio_chunk.shape[1] < self.expected_audio_length:
|
223 |
+
# zero-padding audio
|
224 |
+
padding_length = self.expected_audio_length - audio_chunk.shape[1]
|
225 |
+
# 创建 padding 张量,大小为 [batch_size, padding_length],值为0
|
226 |
+
padding = torch.zeros(audio_chunk.shape[0], padding_length)
|
227 |
+
# 将原始音频和 padding 沿第 1 维度拼接在一起
|
228 |
+
audio_chunk = torch.cat((audio_chunk, padding), dim=1)
|
229 |
+
# raise RuntimeError(f'Audio too short {video_id}')
|
230 |
+
audio_chunk = audio_chunk[:,:self.expected_audio_length]
|
231 |
+
# truncate the video
|
232 |
+
clip_chunk = clip_chunk[:self.clip_expected_length]
|
233 |
+
# import ipdb
|
234 |
+
# ipdb.set_trace()
|
235 |
+
if clip_chunk.shape[0] != self.clip_expected_length:
|
236 |
+
current_length = clip_chunk.shape[0]
|
237 |
+
padding_needed = self.clip_expected_length - current_length
|
238 |
+
|
239 |
+
# Check that padding needed is no more than 2
|
240 |
+
assert padding_needed < 4, f'Padding no more than 2 frames allowed, but {padding_needed} needed'
|
241 |
+
|
242 |
+
# If assertion passes, proceed with padding
|
243 |
+
if padding_needed > 0:
|
244 |
+
last_frame = clip_chunk[-1]
|
245 |
+
log.info(last_frame.shape)
|
246 |
+
# Repeat the last frame to reach the expected length
|
247 |
+
padding = last_frame.repeat(padding_needed, 1, 1, 1)
|
248 |
+
clip_chunk = torch.cat((clip_chunk, padding), dim=0)
|
249 |
+
# raise RuntimeError(f'CLIP video wrong length {video_id}, '
|
250 |
+
# f'expected {self.clip_expected_length}, '
|
251 |
+
# f'got {clip_chunk.shape[0]}')
|
252 |
+
|
253 |
+
# save_image(clip_chunk[0] / 255.0,'ori.png')
|
254 |
+
clip_chunk = pad_to_square(clip_chunk)
|
255 |
+
# save_image(clip_chunk[0] / 255.0,'square.png')
|
256 |
+
# clip_chunk = self.clip_transform(clip_chunk)
|
257 |
+
# import ipdb
|
258 |
+
# ipdb.set_trace()
|
259 |
+
clip_chunk = self.clip_processor(images=clip_chunk, return_tensors="pt")["pixel_values"]
|
260 |
+
# log.info(clip_chunk.shape)
|
261 |
+
# save_tensor_as_image(clip_chunk[0].numpy(),'scale.png')
|
262 |
+
# log.info(clip_chunk[0])
|
263 |
+
# clip_chunk = outputs
|
264 |
+
# text_ids = outputs["input_ids"]
|
265 |
+
# temp_img = clip_chunk[0].permute(1, 2, 0) * 255
|
266 |
+
# save_image(clip_chunk[0],'scale.png')
|
267 |
+
sync_chunk = sync_chunk[:self.sync_expected_length]
|
268 |
+
if sync_chunk.shape[0] != self.sync_expected_length:
|
269 |
+
# padding using the last frame, but no more than 2
|
270 |
+
current_length = sync_chunk.shape[0]
|
271 |
+
last_frame = sync_chunk[-1]
|
272 |
+
# 重复最后一帧以进行填充
|
273 |
+
padding = last_frame.repeat(self.sync_expected_length - current_length, 1, 1, 1)
|
274 |
+
assert self.sync_expected_length - current_length < 12, f'sync can pad no more than 2 while {self.sync_expected_length - current_length}'
|
275 |
+
sync_chunk = torch.cat((sync_chunk, padding), dim=0)
|
276 |
+
# raise RuntimeError(f'Sync video wrong length {video_id}, '
|
277 |
+
# f'expected {self.sync_expected_length}, '
|
278 |
+
# f'got {sync_chunk.shape[0]}')
|
279 |
+
|
280 |
+
sync_chunk = self.sync_transform(sync_chunk)
|
281 |
+
assert audio_chunk.shape[1] == self.expected_audio_length and clip_chunk.shape[0] == self.clip_expected_length \
|
282 |
+
and sync_chunk.shape[0] == self.sync_expected_length, 'error processed data shape'
|
283 |
+
data = {
|
284 |
+
'id': video_id,
|
285 |
+
'caption': label,
|
286 |
+
'caption_t5': caption_t5,
|
287 |
+
'audio': audio_chunk,
|
288 |
+
'clip_video': clip_chunk,
|
289 |
+
'sync_video': sync_chunk,
|
290 |
+
}
|
291 |
+
|
292 |
+
return data
|
293 |
+
|
294 |
+
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
295 |
+
try:
|
296 |
+
return self.sample(idx)
|
297 |
+
except Exception as e:
|
298 |
+
log.error(f'Error loading video {self.videos[idx]}: {e}')
|
299 |
+
return None
|
300 |
+
|
301 |
+
def __len__(self):
|
302 |
+
return len(self.labels)
|
303 |
+
|
304 |
+
|
305 |
+
# dataset = Audioset(
|
306 |
+
# root="dataset/3_Audioset/video/sound",
|
307 |
+
# tsv_path="dataset/3_Audioset/split_txt/unbalanced_sound_filtered_aligned_novgg_noout.csv",
|
308 |
+
# sample_rate=44100,
|
309 |
+
# duration_sec=9.0,
|
310 |
+
# audio_samples=397312,
|
311 |
+
# start_row=0,
|
312 |
+
# end_row=None,
|
313 |
+
# save_dir="dataset/3_Audioset/video_text_latents/"
|
314 |
+
# )
|
315 |
+
# dataset[0]
|
data_utils/v2a_utils/audioset_video_224.py
ADDED
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
from typing import Optional, Union
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
import pandas as pd
|
7 |
+
import torch
|
8 |
+
import torchaudio
|
9 |
+
from torch.utils.data.dataset import Dataset
|
10 |
+
from torchvision.transforms import v2
|
11 |
+
from torio.io import StreamingMediaDecoder
|
12 |
+
from torchvision.utils import save_image
|
13 |
+
from transformers import AutoProcessor
|
14 |
+
import torch.nn.functional as F
|
15 |
+
import numpy as np
|
16 |
+
|
17 |
+
import logging
|
18 |
+
log = logging.getLogger()
|
19 |
+
|
20 |
+
_CLIP_SIZE = 224
|
21 |
+
_CLIP_FPS = 8.0
|
22 |
+
|
23 |
+
_SYNC_SIZE = 224
|
24 |
+
_SYNC_FPS = 25.0
|
25 |
+
|
26 |
+
def save_tensor_as_image(tensor, save_path):
|
27 |
+
"""
|
28 |
+
将形状为 (1, 3, H, W) 的 RGB 图像数组保存为图片文件。
|
29 |
+
|
30 |
+
:param tensor: 输入的 NumPy 数组 (1, 3, H, W)。
|
31 |
+
:param save_path: 图片保存路径。
|
32 |
+
"""
|
33 |
+
# # 移除批次维度,变成 (3, H, W)
|
34 |
+
# tensor = tensor.squeeze(0)
|
35 |
+
|
36 |
+
# 交换轴顺序,变为 (H, W, 3)
|
37 |
+
image_array = np.transpose(tensor, (1, 2, 0))
|
38 |
+
|
39 |
+
# 检查数组是否为合适的数据类型
|
40 |
+
if image_array.dtype != np.uint8:
|
41 |
+
# 如果不是 uint8,首先标准化,然后转换
|
42 |
+
image_array = (image_array - image_array.min()) / (image_array.max() - image_array.min()) * 255
|
43 |
+
image_array = image_array.astype(np.uint8)
|
44 |
+
|
45 |
+
# 创建图像对象
|
46 |
+
image = Image.fromarray(image_array)
|
47 |
+
|
48 |
+
# 保存图片
|
49 |
+
image.save(save_path)
|
50 |
+
print(f"Image saved to {save_path}")
|
51 |
+
|
52 |
+
def pad_to_square(video_tensor):
|
53 |
+
# 验证输入的形状
|
54 |
+
if len(video_tensor.shape) != 4:
|
55 |
+
raise ValueError("Input tensor must have shape (l, c, h, w)")
|
56 |
+
|
57 |
+
l, c, h, w = video_tensor.shape
|
58 |
+
max_side = max(h, w)
|
59 |
+
|
60 |
+
# 计算每一维度需要的填充量:(left, right, top, bottom)
|
61 |
+
pad_h = max_side - h
|
62 |
+
pad_w = max_side - w
|
63 |
+
|
64 |
+
# 创建padding tuple (left, right, top, bottom)
|
65 |
+
# 因为图像的填充是作用在最后两个维度 h 和 w 上,所以我们需要指定这两个维度的填充
|
66 |
+
padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
|
67 |
+
|
68 |
+
# 使用F.pad对视频张量进行填充操作
|
69 |
+
# 填充参数为 (left, right, top, bottom)
|
70 |
+
video_padded = F.pad(video_tensor, pad=padding, mode='constant', value=0)
|
71 |
+
|
72 |
+
return video_padded
|
73 |
+
|
74 |
+
class Audioset(Dataset):
|
75 |
+
|
76 |
+
def __init__(
|
77 |
+
self,
|
78 |
+
root: Union[str, Path],
|
79 |
+
*,
|
80 |
+
tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv',
|
81 |
+
duration_sec: float = 10.0,
|
82 |
+
start_row: Optional[int] = None,
|
83 |
+
end_row: Optional[int] = None,
|
84 |
+
save_dir: str = 'data/vggsound/video_latents_text/train'
|
85 |
+
):
|
86 |
+
self.root = Path(root)
|
87 |
+
|
88 |
+
# videos = sorted(os.listdir(self.root))
|
89 |
+
# videos = set([Path(v).stem for v in videos]) # remove extensions
|
90 |
+
videos = []
|
91 |
+
self.captions = []
|
92 |
+
self.videos = []
|
93 |
+
self.caption_t5s = []
|
94 |
+
missing_videos = []
|
95 |
+
# read the tsv for subset information
|
96 |
+
df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records')
|
97 |
+
|
98 |
+
# 控制处理的行范围
|
99 |
+
if start_row is not None and end_row is not None:
|
100 |
+
df_list = df_list[start_row:end_row]
|
101 |
+
with open(tsv_path.replace('.csv','.txt')) as file:
|
102 |
+
paths = file.readlines()
|
103 |
+
for record, path in zip(df_list,paths):
|
104 |
+
id = Path(record['id']).stem
|
105 |
+
# if os.path.exists(f'{save_dir}/{id}.pth'): continue
|
106 |
+
caption = record['caption']
|
107 |
+
caption_t5 = record['caption_t5']
|
108 |
+
path = path.strip()
|
109 |
+
part = Path(path).parent
|
110 |
+
video_id = Path(path).stem[1:]
|
111 |
+
video_path = os.path.join('dataset/3_Audioset/video',part,f'{video_id}.mp4')
|
112 |
+
assert os.path.exists(video_path), 'video must exist'
|
113 |
+
# if id in videos:
|
114 |
+
self.captions.append(caption)
|
115 |
+
self.caption_t5s.append(caption_t5)
|
116 |
+
# self.labels[id] = label
|
117 |
+
self.videos.append(video_path)
|
118 |
+
# else:
|
119 |
+
# missing_videos.append(id)
|
120 |
+
assert len(self.captions) == len(self.caption_t5s) and len(self.captions) == len(self.videos), 'error length'
|
121 |
+
log.info(f'{len(videos)} videos found in {root}')
|
122 |
+
log.info(f'{len(self.videos)} videos found in {tsv_path}')
|
123 |
+
log.info(f'{len(missing_videos)} videos missing in {root}')
|
124 |
+
|
125 |
+
self.duration_sec = duration_sec
|
126 |
+
|
127 |
+
self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
|
128 |
+
self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
|
129 |
+
|
130 |
+
self.clip_transform = v2.Compose([
|
131 |
+
v2.Lambda(pad_to_square), # 先填充为正方形
|
132 |
+
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
133 |
+
v2.ToImage(),
|
134 |
+
v2.ToDtype(torch.float32, scale=True),
|
135 |
+
])
|
136 |
+
self.clip_processor = AutoProcessor.from_pretrained("useful_ckpts/metaclip-huge")
|
137 |
+
self.sync_transform = v2.Compose([
|
138 |
+
v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
|
139 |
+
v2.CenterCrop(_SYNC_SIZE),
|
140 |
+
v2.ToImage(),
|
141 |
+
v2.ToDtype(torch.float32, scale=True),
|
142 |
+
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
143 |
+
])
|
144 |
+
|
145 |
+
self.resampler = {}
|
146 |
+
|
147 |
+
def sample(self, idx: int) -> dict[str, torch.Tensor]:
|
148 |
+
video_path = self.videos[idx]
|
149 |
+
video_id = 'Y'+str(Path(video_path).stem)
|
150 |
+
caption = self.captions[idx]
|
151 |
+
caption_t5 = self.caption_t5s[idx]
|
152 |
+
|
153 |
+
reader = StreamingMediaDecoder(video_path)
|
154 |
+
reader.add_basic_video_stream(
|
155 |
+
frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
|
156 |
+
frame_rate=_CLIP_FPS,
|
157 |
+
format='rgb24',
|
158 |
+
)
|
159 |
+
reader.add_basic_video_stream(
|
160 |
+
frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
|
161 |
+
frame_rate=_SYNC_FPS,
|
162 |
+
format='rgb24',
|
163 |
+
)
|
164 |
+
|
165 |
+
reader.fill_buffer()
|
166 |
+
data_chunk = reader.pop_chunks()
|
167 |
+
|
168 |
+
clip_chunk = data_chunk[0]
|
169 |
+
sync_chunk = data_chunk[1]
|
170 |
+
|
171 |
+
if clip_chunk is None:
|
172 |
+
raise RuntimeError(f'CLIP video returned None {video_id}')
|
173 |
+
# if clip_chunk.shape[0] < self.clip_expected_length:
|
174 |
+
# raise RuntimeError(
|
175 |
+
# f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}'
|
176 |
+
# )
|
177 |
+
|
178 |
+
if sync_chunk is None:
|
179 |
+
raise RuntimeError(f'Sync video returned None {video_id}')
|
180 |
+
# if sync_chunk.shape[0] < self.sync_expected_length:
|
181 |
+
# raise RuntimeError(
|
182 |
+
# f'Sync video too short {video_id}, expected {self.sync_expected_length}, got {sync_chunk.shape[0]}'
|
183 |
+
# )
|
184 |
+
|
185 |
+
|
186 |
+
# truncate the video
|
187 |
+
clip_chunk = clip_chunk[:self.clip_expected_length]
|
188 |
+
# import ipdb
|
189 |
+
# ipdb.set_trace()
|
190 |
+
if clip_chunk.shape[0] != self.clip_expected_length:
|
191 |
+
current_length = clip_chunk.shape[0]
|
192 |
+
padding_needed = self.clip_expected_length - current_length
|
193 |
+
|
194 |
+
# Check that padding needed is no more than 2
|
195 |
+
assert padding_needed < 4, f'Padding no more than 2 frames allowed, but {padding_needed} needed'
|
196 |
+
|
197 |
+
# If assertion passes, proceed with padding
|
198 |
+
if padding_needed > 0:
|
199 |
+
last_frame = clip_chunk[-1]
|
200 |
+
log.info(clip_chunk.shape)
|
201 |
+
# Repeat the last frame to reach the expected length
|
202 |
+
padding = last_frame.repeat(padding_needed, 1, 1, 1)
|
203 |
+
clip_chunk = torch.cat((clip_chunk, padding), dim=0)
|
204 |
+
# raise RuntimeError(f'CLIP video wrong length {video_id}, '
|
205 |
+
# f'expected {self.clip_expected_length}, '
|
206 |
+
# f'got {clip_chunk.shape[0]}')
|
207 |
+
|
208 |
+
# save_image(clip_chunk[0] / 255.0,'ori.png')
|
209 |
+
clip_chunk = pad_to_square(clip_chunk)
|
210 |
+
# save_image(clip_chunk[0] / 255.0,'square.png')
|
211 |
+
# clip_chunk = self.clip_transform(clip_chunk)
|
212 |
+
# import ipdb
|
213 |
+
# ipdb.set_trace()
|
214 |
+
clip_chunk = self.clip_processor(images=clip_chunk, return_tensors="pt")["pixel_values"]
|
215 |
+
# log.info(clip_chunk.shape)
|
216 |
+
# save_tensor_as_image(clip_chunk[0].numpy(),'scale.png')
|
217 |
+
# log.info(clip_chunk[0])
|
218 |
+
# clip_chunk = outputs
|
219 |
+
# text_ids = outputs["input_ids"]
|
220 |
+
# temp_img = clip_chunk[0].permute(1, 2, 0) * 255
|
221 |
+
# save_image(clip_chunk[0],'scale.png')
|
222 |
+
sync_chunk = sync_chunk[:self.sync_expected_length]
|
223 |
+
if sync_chunk.shape[0] != self.sync_expected_length:
|
224 |
+
# padding using the last frame, but no more than 2
|
225 |
+
current_length = sync_chunk.shape[0]
|
226 |
+
last_frame = sync_chunk[-1]
|
227 |
+
# 重复最后一帧以进行填充
|
228 |
+
padding = last_frame.repeat(self.sync_expected_length - current_length, 1, 1, 1)
|
229 |
+
assert self.sync_expected_length - current_length < 12, f'sync can pad no more than 2 while {self.sync_expected_length - current_length}'
|
230 |
+
sync_chunk = torch.cat((sync_chunk, padding), dim=0)
|
231 |
+
# raise RuntimeError(f'Sync video wrong length {video_id}, '
|
232 |
+
# f'expected {self.sync_expected_length}, '
|
233 |
+
# f'got {sync_chunk.shape[0]}')
|
234 |
+
|
235 |
+
sync_chunk = self.sync_transform(sync_chunk)
|
236 |
+
assert clip_chunk.shape[0] == self.clip_expected_length and sync_chunk.shape[0] == self.sync_expected_length, 'error processed data shape'
|
237 |
+
data = {
|
238 |
+
'id': video_id,
|
239 |
+
'caption': caption,
|
240 |
+
'caption_t5': caption_t5,
|
241 |
+
'clip_video': clip_chunk,
|
242 |
+
'sync_video': sync_chunk,
|
243 |
+
}
|
244 |
+
|
245 |
+
return data
|
246 |
+
|
247 |
+
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
248 |
+
try:
|
249 |
+
return self.sample(idx)
|
250 |
+
except Exception as e:
|
251 |
+
log.error(f'Error loading video {self.videos[idx]}: {e}')
|
252 |
+
return None
|
253 |
+
|
254 |
+
def __len__(self):
|
255 |
+
return len(self.captions)
|
256 |
+
|
257 |
+
|
258 |
+
# dataset = VGGSound(
|
259 |
+
# root="data/vggsound/video/train",
|
260 |
+
# tsv_path="data/vggsound/split_txt/temp.csv",
|
261 |
+
# sample_rate=44100,
|
262 |
+
# duration_sec=9.0,
|
263 |
+
# audio_samples=397312,
|
264 |
+
# start_row=0,
|
265 |
+
# end_row=None,
|
266 |
+
# save_dir="data/vggsound/video_224_latents_text/train"
|
267 |
+
# )
|
268 |
+
# dataset[0]
|
data_utils/v2a_utils/feature_utils_224.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Literal, Optional
|
2 |
+
import json
|
3 |
+
import open_clip
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from einops import rearrange
|
8 |
+
from open_clip import create_model_from_pretrained
|
9 |
+
from torchvision.transforms import Normalize
|
10 |
+
from ThinkSound.models.factory import create_model_from_config
|
11 |
+
from ThinkSound.models.utils import load_ckpt_state_dict
|
12 |
+
from ThinkSound.training.utils import copy_state_dict
|
13 |
+
from transformers import AutoModel
|
14 |
+
from transformers import AutoProcessor
|
15 |
+
from transformers import T5EncoderModel, AutoTokenizer
|
16 |
+
import logging
|
17 |
+
from data_utils.ext.synchformer import Synchformer
|
18 |
+
|
19 |
+
log = logging.getLogger()
|
20 |
+
|
21 |
+
def patch_clip(clip_model):
|
22 |
+
# a hack to make it output last hidden states
|
23 |
+
# https://github.com/mlfoundations/open_clip/blob/fc5a37b72d705f760ebbc7915b84729816ed471f/src/open_clip/model.py#L269
|
24 |
+
def new_get_text_features(self, input_ids=None, attention_mask=None, position_ids=None,
|
25 |
+
output_attentions: Optional[bool] = None,
|
26 |
+
output_hidden_states: Optional[bool] = None,
|
27 |
+
return_dict: Optional[bool] = None):
|
28 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
29 |
+
output_hidden_states = (
|
30 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
31 |
+
)
|
32 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
33 |
+
|
34 |
+
text_outputs = self.text_model(
|
35 |
+
input_ids=input_ids,
|
36 |
+
attention_mask=attention_mask,
|
37 |
+
position_ids=position_ids,
|
38 |
+
output_attentions=output_attentions,
|
39 |
+
output_hidden_states=output_hidden_states,
|
40 |
+
return_dict=return_dict,
|
41 |
+
)
|
42 |
+
last_hidden_state = text_outputs[0]
|
43 |
+
pooled_output = text_outputs[1]
|
44 |
+
text_features = self.text_projection(pooled_output)
|
45 |
+
|
46 |
+
return text_features, last_hidden_state
|
47 |
+
|
48 |
+
clip_model.get_text_features = new_get_text_features.__get__(clip_model)
|
49 |
+
return clip_model
|
50 |
+
|
51 |
+
|
52 |
+
class FeaturesUtils(nn.Module):
|
53 |
+
|
54 |
+
def __init__(
|
55 |
+
self,
|
56 |
+
*,
|
57 |
+
vae_ckpt: Optional[str] = None,
|
58 |
+
vae_config: Optional[str] = None,
|
59 |
+
synchformer_ckpt: Optional[str] = None,
|
60 |
+
enable_conditions: bool = True,
|
61 |
+
need_vae_encoder: bool = True,
|
62 |
+
):
|
63 |
+
super().__init__()
|
64 |
+
|
65 |
+
if enable_conditions:
|
66 |
+
self.clip_model = AutoModel.from_pretrained("facebook/metaclip-h14-fullcc2.5b")
|
67 |
+
self.clip_model = patch_clip(self.clip_model)
|
68 |
+
self.t5_tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-xl")
|
69 |
+
self.t5_model = T5EncoderModel.from_pretrained("google/t5-v1_1-xl")
|
70 |
+
self.clip_processor = AutoProcessor.from_pretrained("facebook/metaclip-h14-fullcc2.5b")
|
71 |
+
# self.clip_preprocess = Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
|
72 |
+
# std=[0.26862954, 0.26130258, 0.27577711])
|
73 |
+
self.synchformer = Synchformer()
|
74 |
+
self.synchformer.load_state_dict(
|
75 |
+
torch.load(synchformer_ckpt, weights_only=True, map_location='cpu'))
|
76 |
+
|
77 |
+
# self.tokenizer = open_clip.get_tokenizer('ViT-H-14-378-quickgelu') # same as 'ViT-H-14'
|
78 |
+
else:
|
79 |
+
self.clip_model = None
|
80 |
+
self.synchformer = None
|
81 |
+
self.tokenizer = None
|
82 |
+
|
83 |
+
if vae_ckpt is not None:
|
84 |
+
with open(vae_config) as f:
|
85 |
+
vae_config = json.load(f)
|
86 |
+
self.vae = create_model_from_config(vae_config)
|
87 |
+
print(f"Loading model checkpoint from {vae_ckpt}")
|
88 |
+
# Load checkpoint
|
89 |
+
copy_state_dict(self.vae, load_ckpt_state_dict(vae_ckpt,prefix='autoencoder.'))#,prefix='autoencoder.'
|
90 |
+
else:
|
91 |
+
self.tod = None
|
92 |
+
|
93 |
+
def compile(self):
|
94 |
+
if self.clip_model is not None:
|
95 |
+
self.clip_model.encode_image = torch.compile(self.clip_model.encode_image)
|
96 |
+
self.clip_model.encode_text = torch.compile(self.clip_model.encode_text)
|
97 |
+
if self.synchformer is not None:
|
98 |
+
self.synchformer = torch.compile(self.synchformer)
|
99 |
+
|
100 |
+
|
101 |
+
def train(self, mode: bool) -> None:
|
102 |
+
return super().train(False)
|
103 |
+
|
104 |
+
@torch.inference_mode()
|
105 |
+
def encode_video_with_clip(self, x: torch.Tensor, batch_size: int = -1) -> torch.Tensor:
|
106 |
+
assert self.clip_model is not None, 'CLIP is not loaded'
|
107 |
+
# x: (B, T, C, H, W) H/W: 384
|
108 |
+
b, t, c, h, w = x.shape
|
109 |
+
|
110 |
+
assert c == 3 and h == 224 and w == 224
|
111 |
+
# x = self.clip_preprocess(x)
|
112 |
+
x = rearrange(x, 'b t c h w -> (b t) c h w')
|
113 |
+
outputs = []
|
114 |
+
if batch_size < 0:
|
115 |
+
batch_size = b * t
|
116 |
+
for i in range(0, b * t, batch_size):
|
117 |
+
outputs.append(self.clip_model.get_image_features(x[i:i + batch_size]))
|
118 |
+
x = torch.cat(outputs, dim=0)
|
119 |
+
# x = self.clip_model.encode_image(x, normalize=True)
|
120 |
+
x = rearrange(x, '(b t) d -> b t d', b=b)
|
121 |
+
return x
|
122 |
+
|
123 |
+
@torch.inference_mode()
|
124 |
+
def encode_video_with_sync(self, x: torch.Tensor, batch_size: int = -1) -> torch.Tensor:
|
125 |
+
assert self.synchformer is not None, 'Synchformer is not loaded'
|
126 |
+
# x: (B, T, C, H, W) H/W: 384
|
127 |
+
b, t, c, h, w = x.shape
|
128 |
+
# import ipdb
|
129 |
+
# ipdb.set_trace()
|
130 |
+
assert c == 3 and h == 224 and w == 224
|
131 |
+
|
132 |
+
# partition the video
|
133 |
+
segment_size = 16
|
134 |
+
step_size = 8
|
135 |
+
num_segments = (t - segment_size) // step_size + 1
|
136 |
+
segments = []
|
137 |
+
for i in range(num_segments):
|
138 |
+
segments.append(x[:, i * step_size:i * step_size + segment_size])
|
139 |
+
x = torch.stack(segments, dim=1) # (B, S, T, C, H, W)
|
140 |
+
|
141 |
+
outputs = []
|
142 |
+
if batch_size < 0:
|
143 |
+
batch_size = b
|
144 |
+
x = rearrange(x, 'b s t c h w -> (b s) 1 t c h w')
|
145 |
+
for i in range(0, b * num_segments, batch_size):
|
146 |
+
outputs.append(self.synchformer(x[i:i + batch_size]))
|
147 |
+
x = torch.cat(outputs, dim=0)
|
148 |
+
x = rearrange(x, '(b s) 1 t d -> b (s t) d', b=b)
|
149 |
+
return x
|
150 |
+
|
151 |
+
@torch.inference_mode()
|
152 |
+
def encode_text(self, text: list[str]) -> torch.Tensor:
|
153 |
+
assert self.clip_model is not None, 'CLIP is not loaded'
|
154 |
+
# assert self.tokenizer is not None, 'Tokenizer is not loaded'
|
155 |
+
# x: (B, L)
|
156 |
+
tokens = self.clip_processor(text=text, truncation=True, max_length=77, padding="max_length",return_tensors="pt").to(self.device)
|
157 |
+
return self.clip_model.get_text_features(**tokens)
|
158 |
+
|
159 |
+
@torch.inference_mode()
|
160 |
+
def encode_t5_text(self, text: list[str]) -> torch.Tensor:
|
161 |
+
assert self.t5_model is not None, 'T5 model is not loaded'
|
162 |
+
assert self.t5_tokenizer is not None, 'T5 Tokenizer is not loaded'
|
163 |
+
# x: (B, L)
|
164 |
+
inputs = self.t5_tokenizer(text,
|
165 |
+
truncation=True,
|
166 |
+
max_length=77,
|
167 |
+
padding="max_length",
|
168 |
+
return_tensors="pt").to(self.device)
|
169 |
+
return self.t5_model(**inputs).last_hidden_state
|
170 |
+
|
171 |
+
@torch.inference_mode()
|
172 |
+
def encode_audio(self, x) -> torch.Tensor:
|
173 |
+
x = self.vae.encode(x)
|
174 |
+
return x
|
175 |
+
|
176 |
+
@property
|
177 |
+
def device(self):
|
178 |
+
return next(self.parameters()).device
|
179 |
+
|
180 |
+
@property
|
181 |
+
def dtype(self):
|
182 |
+
return next(self.parameters()).dtype
|
data_utils/v2a_utils/vggsound.py
ADDED
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import Optional, Union
|
5 |
+
|
6 |
+
import pandas as pd
|
7 |
+
import torch
|
8 |
+
import torchaudio
|
9 |
+
from torch.utils.data.dataset import Dataset
|
10 |
+
from torchvision.transforms import v2
|
11 |
+
from torio.io import StreamingMediaDecoder
|
12 |
+
from torchvision.utils import save_image
|
13 |
+
|
14 |
+
log = logging.getLogger()
|
15 |
+
|
16 |
+
_CLIP_SIZE = 384
|
17 |
+
_CLIP_FPS = 8.0
|
18 |
+
|
19 |
+
_SYNC_SIZE = 224
|
20 |
+
_SYNC_FPS = 25.0
|
21 |
+
|
22 |
+
|
23 |
+
class VGGSound(Dataset):
|
24 |
+
|
25 |
+
def __init__(
|
26 |
+
self,
|
27 |
+
root: Union[str, Path],
|
28 |
+
*,
|
29 |
+
tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv',
|
30 |
+
sample_rate: int = 44_100,
|
31 |
+
duration_sec: float = 9.0,
|
32 |
+
audio_samples: Optional[int] = 397312,
|
33 |
+
normalize_audio: bool = False,
|
34 |
+
start_row: Optional[int] = None,
|
35 |
+
end_row: Optional[int] = None,
|
36 |
+
save_dir: str = 'data/vggsound/video_latents_text/train'
|
37 |
+
):
|
38 |
+
self.root = Path(root)
|
39 |
+
self.normalize_audio = normalize_audio
|
40 |
+
if audio_samples is None:
|
41 |
+
self.audio_samples = int(sample_rate * duration_sec)
|
42 |
+
else:
|
43 |
+
self.audio_samples = audio_samples
|
44 |
+
effective_duration = audio_samples / sample_rate
|
45 |
+
# make sure the duration is close enough, within 15ms
|
46 |
+
assert abs(effective_duration - duration_sec) < 0.015, \
|
47 |
+
f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
|
48 |
+
|
49 |
+
videos = sorted(os.listdir(self.root))
|
50 |
+
videos = set([Path(v).stem for v in videos]) # remove extensions
|
51 |
+
# videos = []
|
52 |
+
self.labels = []
|
53 |
+
self.videos = []
|
54 |
+
missing_videos = []
|
55 |
+
# read the tsv for subset information
|
56 |
+
df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records')
|
57 |
+
|
58 |
+
# 控制处理的行范围
|
59 |
+
if start_row is not None and end_row is not None:
|
60 |
+
df_list = df_list[start_row:end_row]
|
61 |
+
|
62 |
+
for record in df_list:
|
63 |
+
id = record['id']
|
64 |
+
if os.path.exists(f'{save_dir}/{id}.pth'): continue
|
65 |
+
label = record['caption']
|
66 |
+
if id in videos:
|
67 |
+
# self.labels.append(label)
|
68 |
+
self.labels[id] = label
|
69 |
+
self.videos.append(id)
|
70 |
+
else:
|
71 |
+
missing_videos.append(id)
|
72 |
+
|
73 |
+
log.info(f'{len(videos)} videos found in {root}')
|
74 |
+
log.info(f'{len(self.videos)} videos found in {tsv_path}')
|
75 |
+
log.info(f'{len(missing_videos)} videos missing in {root}')
|
76 |
+
|
77 |
+
self.sample_rate = sample_rate
|
78 |
+
self.duration_sec = duration_sec
|
79 |
+
|
80 |
+
self.expected_audio_length = self.audio_samples
|
81 |
+
self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
|
82 |
+
self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
|
83 |
+
|
84 |
+
self.clip_transform = v2.Compose([
|
85 |
+
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
86 |
+
v2.ToImage(),
|
87 |
+
v2.ToDtype(torch.float32, scale=True),
|
88 |
+
])
|
89 |
+
|
90 |
+
self.sync_transform = v2.Compose([
|
91 |
+
v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
|
92 |
+
v2.CenterCrop(_SYNC_SIZE),
|
93 |
+
v2.ToImage(),
|
94 |
+
v2.ToDtype(torch.float32, scale=True),
|
95 |
+
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
96 |
+
])
|
97 |
+
|
98 |
+
self.resampler = {}
|
99 |
+
|
100 |
+
def sample(self, idx: int) -> dict[str, torch.Tensor]:
|
101 |
+
video_id = self.videos[idx]
|
102 |
+
label = self.labels[idx]
|
103 |
+
|
104 |
+
reader = StreamingMediaDecoder(self.root / (video_id + '.mp4'))
|
105 |
+
reader.add_basic_video_stream(
|
106 |
+
frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
|
107 |
+
frame_rate=_CLIP_FPS,
|
108 |
+
format='rgb24',
|
109 |
+
)
|
110 |
+
reader.add_basic_video_stream(
|
111 |
+
frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
|
112 |
+
frame_rate=_SYNC_FPS,
|
113 |
+
format='rgb24',
|
114 |
+
)
|
115 |
+
reader.add_basic_audio_stream(frames_per_chunk=2**30,)
|
116 |
+
|
117 |
+
reader.fill_buffer()
|
118 |
+
data_chunk = reader.pop_chunks()
|
119 |
+
|
120 |
+
clip_chunk = data_chunk[0]
|
121 |
+
sync_chunk = data_chunk[1]
|
122 |
+
audio_chunk = data_chunk[2]
|
123 |
+
if len(audio_chunk.shape) != 2:
|
124 |
+
raise RuntimeError(f'error audio shape {video_id}')
|
125 |
+
if clip_chunk is None:
|
126 |
+
raise RuntimeError(f'CLIP video returned None {video_id}')
|
127 |
+
# if clip_chunk.shape[0] < self.clip_expected_length:
|
128 |
+
# raise RuntimeError(
|
129 |
+
# f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}'
|
130 |
+
# )
|
131 |
+
|
132 |
+
if sync_chunk is None:
|
133 |
+
raise RuntimeError(f'Sync video returned None {video_id}')
|
134 |
+
# if sync_chunk.shape[0] < self.sync_expected_length:
|
135 |
+
# raise RuntimeError(
|
136 |
+
# f'Sync video too short {video_id}, expected {self.sync_expected_length}, got {sync_chunk.shape[0]}'
|
137 |
+
# )
|
138 |
+
# import ipdb
|
139 |
+
# ipdb.set_trace()
|
140 |
+
# process audio
|
141 |
+
sample_rate = int(reader.get_out_stream_info(2).sample_rate)
|
142 |
+
audio_chunk = audio_chunk.transpose(0, 1)
|
143 |
+
abs_max = audio_chunk[0].abs().max()
|
144 |
+
# audio_chunk = audio_chunk.mean(dim=0) # mono
|
145 |
+
# if self.normalize_audio:
|
146 |
+
# abs_max = audio_chunk.abs().max()
|
147 |
+
# audio_chunk = audio_chunk / abs_max * 0.95
|
148 |
+
if abs_max <= 1e-6:
|
149 |
+
if audio_chunk.shape[0] > 1 and audio_chunk[1].abs().max() > 1e-6:
|
150 |
+
audio_chunk = audio_chunk[1:2]
|
151 |
+
else:
|
152 |
+
raise RuntimeError(f'Audio is silent {video_id}')
|
153 |
+
|
154 |
+
|
155 |
+
# if abs_max <= 1e-6:
|
156 |
+
# raise RuntimeError(f'Audio is silent {video_id}')
|
157 |
+
|
158 |
+
# ensure the stereo audio
|
159 |
+
if audio_chunk.shape[0] < 2:
|
160 |
+
audio_chunk = audio_chunk.repeat(2, 1)
|
161 |
+
|
162 |
+
# resample
|
163 |
+
if sample_rate == self.sample_rate:
|
164 |
+
audio_chunk = audio_chunk
|
165 |
+
else:
|
166 |
+
if sample_rate not in self.resampler:
|
167 |
+
# https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
|
168 |
+
self.resampler[sample_rate] = torchaudio.transforms.Resample(
|
169 |
+
sample_rate,
|
170 |
+
self.sample_rate,
|
171 |
+
lowpass_filter_width=64,
|
172 |
+
rolloff=0.9475937167399596,
|
173 |
+
resampling_method='sinc_interp_kaiser',
|
174 |
+
beta=14.769656459379492,
|
175 |
+
)
|
176 |
+
audio_chunk = self.resampler[sample_rate](audio_chunk)
|
177 |
+
|
178 |
+
if audio_chunk.shape[1] < self.expected_audio_length:
|
179 |
+
# zero-padding audio
|
180 |
+
padding_length = self.expected_audio_length - audio_chunk.shape[1]
|
181 |
+
# 创建 padding 张量,大小为 [batch_size, padding_length],值为0
|
182 |
+
padding = torch.zeros(audio_chunk.shape[0], padding_length)
|
183 |
+
# 将原始音频和 padding 沿第 1 维度拼接在一起
|
184 |
+
audio_chunk = torch.cat((audio_chunk, padding), dim=1)
|
185 |
+
# raise RuntimeError(f'Audio too short {video_id}')
|
186 |
+
audio_chunk = audio_chunk[:,:self.expected_audio_length]
|
187 |
+
# truncate the video
|
188 |
+
clip_chunk = clip_chunk[:self.clip_expected_length]
|
189 |
+
# import ipdb
|
190 |
+
# ipdb.set_trace()
|
191 |
+
if clip_chunk.shape[0] != self.clip_expected_length:
|
192 |
+
current_length = clip_chunk.shape[0]
|
193 |
+
padding_needed = self.clip_expected_length - current_length
|
194 |
+
|
195 |
+
# Check that padding needed is no more than 2
|
196 |
+
assert padding_needed < 4, f'Padding no more than 2 frames allowed, but {padding_needed} needed'
|
197 |
+
|
198 |
+
# If assertion passes, proceed with padding
|
199 |
+
if padding_needed > 0:
|
200 |
+
last_frame = clip_chunk[-1]
|
201 |
+
log.info(last_frame.shape)
|
202 |
+
# Repeat the last frame to reach the expected length
|
203 |
+
padding = last_frame.repeat(padding_needed, 1, 1, 1)
|
204 |
+
clip_chunk = torch.cat((clip_chunk, padding), dim=0)
|
205 |
+
# raise RuntimeError(f'CLIP video wrong length {video_id}, '
|
206 |
+
# f'expected {self.clip_expected_length}, '
|
207 |
+
# f'got {clip_chunk.shape[0]}')
|
208 |
+
# save_image(clip_chunk[0] / 255.0,'ori.png')
|
209 |
+
clip_chunk = self.clip_transform(clip_chunk)
|
210 |
+
# temp_img = clip_chunk[0].permute(1, 2, 0) * 255
|
211 |
+
# save_image(clip_chunk[0],'scale.png')
|
212 |
+
sync_chunk = sync_chunk[:self.sync_expected_length]
|
213 |
+
if sync_chunk.shape[0] != self.sync_expected_length:
|
214 |
+
# padding using the last frame, but no more than 2
|
215 |
+
current_length = sync_chunk.shape[0]
|
216 |
+
last_frame = sync_chunk[-1]
|
217 |
+
# 重复最后一帧以进行填充
|
218 |
+
padding = last_frame.repeat(self.sync_expected_length - current_length, 1, 1, 1)
|
219 |
+
assert self.sync_expected_length - current_length < 12, f'sync can pad no more than 2 while {self.sync_expected_length - current_length}'
|
220 |
+
sync_chunk = torch.cat((sync_chunk, padding), dim=0)
|
221 |
+
# raise RuntimeError(f'Sync video wrong length {video_id}, '
|
222 |
+
# f'expected {self.sync_expected_length}, '
|
223 |
+
# f'got {sync_chunk.shape[0]}')
|
224 |
+
|
225 |
+
sync_chunk = self.sync_transform(sync_chunk)
|
226 |
+
assert audio_chunk.shape[1] == self.expected_audio_length and clip_chunk.shape[0] == self.clip_expected_length \
|
227 |
+
and sync_chunk.shape[0] == self.sync_expected_length, 'error processed data shape'
|
228 |
+
data = {
|
229 |
+
'id': video_id,
|
230 |
+
'caption': label,
|
231 |
+
'audio': audio_chunk,
|
232 |
+
'clip_video': clip_chunk,
|
233 |
+
'sync_video': sync_chunk,
|
234 |
+
}
|
235 |
+
|
236 |
+
return data
|
237 |
+
|
238 |
+
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
239 |
+
try:
|
240 |
+
return self.sample(idx)
|
241 |
+
except Exception as e:
|
242 |
+
log.error(f'Error loading video {self.videos[idx]}: {e}')
|
243 |
+
return None
|
244 |
+
|
245 |
+
def __len__(self):
|
246 |
+
return len(self.labels)
|
247 |
+
|
248 |
+
|
249 |
+
# dataset = VGGSound(
|
250 |
+
# root="data/vggsound/video/test",
|
251 |
+
# tsv_path="data/vggsound/split_txt/temp.csv",
|
252 |
+
# sample_rate=44100,
|
253 |
+
# duration_sec=9.0,
|
254 |
+
# audio_samples=397312,
|
255 |
+
# start_row=0,
|
256 |
+
# end_row=None,
|
257 |
+
# save_dir="data/vggsound/video_latents_text/test"
|
258 |
+
# )
|
259 |
+
# dataset[0]
|
data_utils/v2a_utils/vggsound_224.py
ADDED
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
from typing import Optional, Union
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
import pandas as pd
|
7 |
+
import torch
|
8 |
+
import torchaudio
|
9 |
+
from torch.utils.data.dataset import Dataset
|
10 |
+
from torchvision.transforms import v2
|
11 |
+
from torio.io import StreamingMediaDecoder
|
12 |
+
from torchvision.utils import save_image
|
13 |
+
from transformers import AutoProcessor
|
14 |
+
import torch.nn.functional as F
|
15 |
+
import numpy as np
|
16 |
+
|
17 |
+
import logging
|
18 |
+
log = logging.getLogger()
|
19 |
+
|
20 |
+
_CLIP_SIZE = 224
|
21 |
+
_CLIP_FPS = 8.0
|
22 |
+
|
23 |
+
_SYNC_SIZE = 224
|
24 |
+
_SYNC_FPS = 25.0
|
25 |
+
|
26 |
+
def save_tensor_as_image(tensor, save_path):
|
27 |
+
"""
|
28 |
+
将形状为 (1, 3, H, W) 的 RGB 图像数组保存为图片文件。
|
29 |
+
|
30 |
+
:param tensor: 输入的 NumPy 数组 (1, 3, H, W)。
|
31 |
+
:param save_path: 图片保存路径。
|
32 |
+
"""
|
33 |
+
# # 移除批次维度,变成 (3, H, W)
|
34 |
+
# tensor = tensor.squeeze(0)
|
35 |
+
|
36 |
+
# 交换轴顺序,变为 (H, W, 3)
|
37 |
+
image_array = np.transpose(tensor, (1, 2, 0))
|
38 |
+
|
39 |
+
# 检查数组是否为合适的数据类型
|
40 |
+
if image_array.dtype != np.uint8:
|
41 |
+
# 如果不是 uint8,首先标准化,然后转换
|
42 |
+
image_array = (image_array - image_array.min()) / (image_array.max() - image_array.min()) * 255
|
43 |
+
image_array = image_array.astype(np.uint8)
|
44 |
+
|
45 |
+
# 创建图像对象
|
46 |
+
image = Image.fromarray(image_array)
|
47 |
+
|
48 |
+
# 保存图片
|
49 |
+
image.save(save_path)
|
50 |
+
print(f"Image saved to {save_path}")
|
51 |
+
|
52 |
+
def pad_to_square(video_tensor):
|
53 |
+
# 验证输入的形状
|
54 |
+
if len(video_tensor.shape) != 4:
|
55 |
+
raise ValueError("Input tensor must have shape (l, c, h, w)")
|
56 |
+
|
57 |
+
l, c, h, w = video_tensor.shape
|
58 |
+
max_side = max(h, w)
|
59 |
+
|
60 |
+
# 计算每一维度需要的填充量:(left, right, top, bottom)
|
61 |
+
pad_h = max_side - h
|
62 |
+
pad_w = max_side - w
|
63 |
+
|
64 |
+
# 创建padding tuple (left, right, top, bottom)
|
65 |
+
# 因为图像的填充是作用在最后两个维度 h 和 w 上,所以我们需要指定这两个维度的填充
|
66 |
+
padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
|
67 |
+
|
68 |
+
# 使用F.pad对视频张量进行填充操作
|
69 |
+
# 填充参数为 (left, right, top, bottom)
|
70 |
+
video_padded = F.pad(video_tensor, pad=padding, mode='constant', value=0)
|
71 |
+
|
72 |
+
return video_padded
|
73 |
+
|
74 |
+
class VGGSound(Dataset):
|
75 |
+
|
76 |
+
def __init__(
|
77 |
+
self,
|
78 |
+
root: Union[str, Path],
|
79 |
+
*,
|
80 |
+
tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv',
|
81 |
+
sample_rate: int = 44_100,
|
82 |
+
duration_sec: float = 9.0,
|
83 |
+
audio_samples: Optional[int] = 397312,
|
84 |
+
normalize_audio: bool = False,
|
85 |
+
start_row: Optional[int] = None,
|
86 |
+
end_row: Optional[int] = None,
|
87 |
+
save_dir: str = 'data/vggsound/video_latents_text/train'
|
88 |
+
):
|
89 |
+
self.root = Path(root)
|
90 |
+
self.normalize_audio = normalize_audio
|
91 |
+
if audio_samples is None:
|
92 |
+
self.audio_samples = int(sample_rate * duration_sec)
|
93 |
+
else:
|
94 |
+
self.audio_samples = audio_samples
|
95 |
+
effective_duration = audio_samples / sample_rate
|
96 |
+
# make sure the duration is close enough, within 15ms
|
97 |
+
assert abs(effective_duration - duration_sec) < 0.015, \
|
98 |
+
f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
|
99 |
+
|
100 |
+
# videos = sorted(os.listdir(self.root))
|
101 |
+
# videos = set([Path(v).stem for v in videos]) # remove extensions
|
102 |
+
videos = []
|
103 |
+
self.labels = []
|
104 |
+
self.videos = []
|
105 |
+
missing_videos = []
|
106 |
+
# read the tsv for subset information
|
107 |
+
df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records')
|
108 |
+
|
109 |
+
# 控制处理的行范围
|
110 |
+
if start_row is not None and end_row is not None:
|
111 |
+
df_list = df_list[start_row:end_row]
|
112 |
+
|
113 |
+
for record in df_list:
|
114 |
+
id = record['id']
|
115 |
+
if os.path.exists(f'{save_dir}/{id}.pth'): continue
|
116 |
+
label = record['label']
|
117 |
+
# if id in videos:
|
118 |
+
self.labels.append(label)
|
119 |
+
# self.labels[id] = label
|
120 |
+
self.videos.append(id)
|
121 |
+
# else:
|
122 |
+
# missing_videos.append(id)
|
123 |
+
|
124 |
+
log.info(f'{len(videos)} videos found in {root}')
|
125 |
+
log.info(f'{len(self.videos)} videos found in {tsv_path}')
|
126 |
+
log.info(f'{len(missing_videos)} videos missing in {root}')
|
127 |
+
|
128 |
+
self.sample_rate = sample_rate
|
129 |
+
self.duration_sec = duration_sec
|
130 |
+
|
131 |
+
self.expected_audio_length = self.audio_samples
|
132 |
+
self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
|
133 |
+
self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
|
134 |
+
|
135 |
+
self.clip_transform = v2.Compose([
|
136 |
+
v2.Lambda(pad_to_square), # 先填充为正方形
|
137 |
+
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
138 |
+
v2.ToImage(),
|
139 |
+
v2.ToDtype(torch.float32, scale=True),
|
140 |
+
])
|
141 |
+
self.clip_processor = AutoProcessor.from_pretrained("facebook/metaclip-h14-fullcc2.5b")
|
142 |
+
self.sync_transform = v2.Compose([
|
143 |
+
v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
|
144 |
+
v2.CenterCrop(_SYNC_SIZE),
|
145 |
+
v2.ToImage(),
|
146 |
+
v2.ToDtype(torch.float32, scale=True),
|
147 |
+
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
148 |
+
])
|
149 |
+
|
150 |
+
self.resampler = {}
|
151 |
+
|
152 |
+
def sample(self, idx: int) -> dict[str, torch.Tensor]:
|
153 |
+
video_id = self.videos[idx]
|
154 |
+
label = self.labels[idx]
|
155 |
+
|
156 |
+
reader = StreamingMediaDecoder(self.root / (video_id + '.mp4'))
|
157 |
+
reader.add_basic_video_stream(
|
158 |
+
frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
|
159 |
+
frame_rate=_CLIP_FPS,
|
160 |
+
format='rgb24',
|
161 |
+
)
|
162 |
+
reader.add_basic_video_stream(
|
163 |
+
frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
|
164 |
+
frame_rate=_SYNC_FPS,
|
165 |
+
format='rgb24',
|
166 |
+
)
|
167 |
+
reader.add_basic_audio_stream(frames_per_chunk=2**30,)
|
168 |
+
|
169 |
+
reader.fill_buffer()
|
170 |
+
data_chunk = reader.pop_chunks()
|
171 |
+
|
172 |
+
clip_chunk = data_chunk[0]
|
173 |
+
sync_chunk = data_chunk[1]
|
174 |
+
audio_chunk = data_chunk[2]
|
175 |
+
if len(audio_chunk.shape) != 2:
|
176 |
+
raise RuntimeError(f'error audio shape {video_id}')
|
177 |
+
if clip_chunk is None:
|
178 |
+
raise RuntimeError(f'CLIP video returned None {video_id}')
|
179 |
+
# if clip_chunk.shape[0] < self.clip_expected_length:
|
180 |
+
# raise RuntimeError(
|
181 |
+
# f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}'
|
182 |
+
# )
|
183 |
+
|
184 |
+
if sync_chunk is None:
|
185 |
+
raise RuntimeError(f'Sync video returned None {video_id}')
|
186 |
+
# if sync_chunk.shape[0] < self.sync_expected_length:
|
187 |
+
# raise RuntimeError(
|
188 |
+
# f'Sync video too short {video_id}, expected {self.sync_expected_length}, got {sync_chunk.shape[0]}'
|
189 |
+
# )
|
190 |
+
# import ipdb
|
191 |
+
# ipdb.set_trace()
|
192 |
+
# process audio
|
193 |
+
# import ipdb
|
194 |
+
# ipdb.set_trace()
|
195 |
+
sample_rate = int(reader.get_out_stream_info(2).sample_rate)
|
196 |
+
audio_chunk = audio_chunk.transpose(0, 1)
|
197 |
+
abs_max = audio_chunk[0].abs().max()
|
198 |
+
# audio_chunk = audio_chunk.mean(dim=0) # mono
|
199 |
+
# if self.normalize_audio:
|
200 |
+
# abs_max = audio_chunk.abs().max()
|
201 |
+
# audio_chunk = audio_chunk / abs_max * 0.95
|
202 |
+
if abs_max <= 1e-6:
|
203 |
+
if audio_chunk.shape[0] > 1 and audio_chunk[1].abs().max() > 1e-6:
|
204 |
+
audio_chunk = audio_chunk[1:2]
|
205 |
+
else:
|
206 |
+
raise RuntimeError(f'Audio is silent {video_id}')
|
207 |
+
|
208 |
+
# ensure the stereo audio
|
209 |
+
if audio_chunk.shape[0] < 2:
|
210 |
+
audio_chunk = audio_chunk.repeat(2, 1)
|
211 |
+
|
212 |
+
# resample
|
213 |
+
if sample_rate == self.sample_rate:
|
214 |
+
audio_chunk = audio_chunk
|
215 |
+
else:
|
216 |
+
if sample_rate not in self.resampler:
|
217 |
+
# https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
|
218 |
+
self.resampler[sample_rate] = torchaudio.transforms.Resample(
|
219 |
+
sample_rate,
|
220 |
+
self.sample_rate,
|
221 |
+
lowpass_filter_width=64,
|
222 |
+
rolloff=0.9475937167399596,
|
223 |
+
resampling_method='sinc_interp_kaiser',
|
224 |
+
beta=14.769656459379492,
|
225 |
+
)
|
226 |
+
audio_chunk = self.resampler[sample_rate](audio_chunk)
|
227 |
+
|
228 |
+
if audio_chunk.shape[1] < self.expected_audio_length:
|
229 |
+
# zero-padding audio
|
230 |
+
padding_length = self.expected_audio_length - audio_chunk.shape[1]
|
231 |
+
# 创建 padding 张量,大小为 [batch_size, padding_length],值为0
|
232 |
+
padding = torch.zeros(audio_chunk.shape[0], padding_length)
|
233 |
+
# 将原始音频和 padding 沿第 1 维度拼接在一起
|
234 |
+
audio_chunk = torch.cat((audio_chunk, padding), dim=1)
|
235 |
+
# raise RuntimeError(f'Audio too short {video_id}')
|
236 |
+
audio_chunk = audio_chunk[:,:self.expected_audio_length]
|
237 |
+
# truncate the video
|
238 |
+
clip_chunk = clip_chunk[:self.clip_expected_length]
|
239 |
+
# import ipdb
|
240 |
+
# ipdb.set_trace()
|
241 |
+
if clip_chunk.shape[0] != self.clip_expected_length:
|
242 |
+
current_length = clip_chunk.shape[0]
|
243 |
+
padding_needed = self.clip_expected_length - current_length
|
244 |
+
|
245 |
+
# Check that padding needed is no more than 2
|
246 |
+
assert padding_needed < 4, f'Padding no more than 2 frames allowed, but {padding_needed} needed'
|
247 |
+
|
248 |
+
# If assertion passes, proceed with padding
|
249 |
+
if padding_needed > 0:
|
250 |
+
last_frame = clip_chunk[-1]
|
251 |
+
log.info(last_frame.shape)
|
252 |
+
# Repeat the last frame to reach the expected length
|
253 |
+
padding = last_frame.repeat(padding_needed, 1, 1, 1)
|
254 |
+
clip_chunk = torch.cat((clip_chunk, padding), dim=0)
|
255 |
+
# raise RuntimeError(f'CLIP video wrong length {video_id}, '
|
256 |
+
# f'expected {self.clip_expected_length}, '
|
257 |
+
# f'got {clip_chunk.shape[0]}')
|
258 |
+
|
259 |
+
# save_image(clip_chunk[0] / 255.0,'ori.png')
|
260 |
+
clip_chunk = pad_to_square(clip_chunk)
|
261 |
+
# save_image(clip_chunk[0] / 255.0,'square.png')
|
262 |
+
# clip_chunk = self.clip_transform(clip_chunk)
|
263 |
+
# import ipdb
|
264 |
+
# ipdb.set_trace()
|
265 |
+
clip_chunk = self.clip_processor(images=clip_chunk, return_tensors="pt")["pixel_values"]
|
266 |
+
# log.info(clip_chunk.shape)
|
267 |
+
# save_tensor_as_image(clip_chunk[0].numpy(),'scale.png')
|
268 |
+
# log.info(clip_chunk[0])
|
269 |
+
# clip_chunk = outputs
|
270 |
+
# text_ids = outputs["input_ids"]
|
271 |
+
# temp_img = clip_chunk[0].permute(1, 2, 0) * 255
|
272 |
+
# save_image(clip_chunk[0],'scale.png')
|
273 |
+
sync_chunk = sync_chunk[:self.sync_expected_length]
|
274 |
+
if sync_chunk.shape[0] != self.sync_expected_length:
|
275 |
+
# padding using the last frame, but no more than 2
|
276 |
+
current_length = sync_chunk.shape[0]
|
277 |
+
last_frame = sync_chunk[-1]
|
278 |
+
# 重复最后一帧以进行填充
|
279 |
+
padding = last_frame.repeat(self.sync_expected_length - current_length, 1, 1, 1)
|
280 |
+
assert self.sync_expected_length - current_length < 12, f'sync can pad no more than 2 while {self.sync_expected_length - current_length}'
|
281 |
+
sync_chunk = torch.cat((sync_chunk, padding), dim=0)
|
282 |
+
# raise RuntimeError(f'Sync video wrong length {video_id}, '
|
283 |
+
# f'expected {self.sync_expected_length}, '
|
284 |
+
# f'got {sync_chunk.shape[0]}')
|
285 |
+
|
286 |
+
sync_chunk = self.sync_transform(sync_chunk)
|
287 |
+
assert audio_chunk.shape[1] == self.expected_audio_length and clip_chunk.shape[0] == self.clip_expected_length \
|
288 |
+
and sync_chunk.shape[0] == self.sync_expected_length, 'error processed data shape'
|
289 |
+
data = {
|
290 |
+
'id': video_id,
|
291 |
+
'caption': label,
|
292 |
+
'audio': audio_chunk,
|
293 |
+
'clip_video': clip_chunk,
|
294 |
+
'sync_video': sync_chunk,
|
295 |
+
}
|
296 |
+
|
297 |
+
return data
|
298 |
+
|
299 |
+
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
300 |
+
try:
|
301 |
+
return self.sample(idx)
|
302 |
+
except Exception as e:
|
303 |
+
log.error(f'Error loading video {self.videos[idx]}: {e}')
|
304 |
+
return None
|
305 |
+
|
306 |
+
def __len__(self):
|
307 |
+
return len(self.labels)
|
308 |
+
|
309 |
+
|
310 |
+
# dataset = VGGSound(
|
311 |
+
# root="data/vggsound/video/train",
|
312 |
+
# tsv_path="data/vggsound/split_txt/temp.csv",
|
313 |
+
# sample_rate=44100,
|
314 |
+
# duration_sec=9.0,
|
315 |
+
# audio_samples=397312,
|
316 |
+
# start_row=0,
|
317 |
+
# end_row=None,
|
318 |
+
# save_dir="data/vggsound/video_224_latents_text/train"
|
319 |
+
# )
|
320 |
+
# dataset[0]
|
data_utils/v2a_utils/vggsound_224_no_audio.py
ADDED
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
from typing import Optional, Union
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
import pandas as pd
|
7 |
+
import torch
|
8 |
+
import torchaudio
|
9 |
+
from torch.utils.data.dataset import Dataset
|
10 |
+
from torchvision.transforms import v2
|
11 |
+
from torio.io import StreamingMediaDecoder
|
12 |
+
from torchvision.utils import save_image
|
13 |
+
from transformers import AutoProcessor
|
14 |
+
import torch.nn.functional as F
|
15 |
+
import numpy as np
|
16 |
+
|
17 |
+
import logging
|
18 |
+
log = logging.getLogger()
|
19 |
+
|
20 |
+
_CLIP_SIZE = 224
|
21 |
+
_CLIP_FPS = 8.0
|
22 |
+
|
23 |
+
_SYNC_SIZE = 224
|
24 |
+
_SYNC_FPS = 25.0
|
25 |
+
|
26 |
+
def save_tensor_as_image(tensor, save_path):
|
27 |
+
"""
|
28 |
+
将形状为 (1, 3, H, W) 的 RGB 图像数组保存为图片文件。
|
29 |
+
|
30 |
+
:param tensor: 输入的 NumPy 数组 (1, 3, H, W)。
|
31 |
+
:param save_path: 图片保存路径。
|
32 |
+
"""
|
33 |
+
# # 移除批次维度,变成 (3, H, W)
|
34 |
+
# tensor = tensor.squeeze(0)
|
35 |
+
|
36 |
+
# 交换轴顺序,变为 (H, W, 3)
|
37 |
+
image_array = np.transpose(tensor, (1, 2, 0))
|
38 |
+
|
39 |
+
# 检查数组是否为合适的数据类型
|
40 |
+
if image_array.dtype != np.uint8:
|
41 |
+
# 如果不是 uint8,首先标准化,然后转换
|
42 |
+
image_array = (image_array - image_array.min()) / (image_array.max() - image_array.min()) * 255
|
43 |
+
image_array = image_array.astype(np.uint8)
|
44 |
+
|
45 |
+
# 创建图像对象
|
46 |
+
image = Image.fromarray(image_array)
|
47 |
+
|
48 |
+
# 保存图片
|
49 |
+
image.save(save_path)
|
50 |
+
print(f"Image saved to {save_path}")
|
51 |
+
|
52 |
+
def pad_to_square(video_tensor):
|
53 |
+
# 验证输入的形状
|
54 |
+
if len(video_tensor.shape) != 4:
|
55 |
+
raise ValueError("Input tensor must have shape (l, c, h, w)")
|
56 |
+
|
57 |
+
l, c, h, w = video_tensor.shape
|
58 |
+
max_side = max(h, w)
|
59 |
+
|
60 |
+
# 计算每一维度需要的填充量:(left, right, top, bottom)
|
61 |
+
pad_h = max_side - h
|
62 |
+
pad_w = max_side - w
|
63 |
+
|
64 |
+
# 创建padding tuple (left, right, top, bottom)
|
65 |
+
# 因为图像的填充是作用在最后两个维度 h 和 w 上,所以我们需要指定这两个维度的填充
|
66 |
+
padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
|
67 |
+
|
68 |
+
# 使用F.pad对视频张量进行填充操作
|
69 |
+
# 填充参数为 (left, right, top, bottom)
|
70 |
+
video_padded = F.pad(video_tensor, pad=padding, mode='constant', value=0)
|
71 |
+
|
72 |
+
return video_padded
|
73 |
+
|
74 |
+
class VGGSound(Dataset):
|
75 |
+
|
76 |
+
def __init__(
|
77 |
+
self,
|
78 |
+
root: Union[str, Path],
|
79 |
+
*,
|
80 |
+
tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv',
|
81 |
+
sample_rate: int = 44_100,
|
82 |
+
duration_sec: float = 9.0,
|
83 |
+
audio_samples: Optional[int] = 397312,
|
84 |
+
normalize_audio: bool = False,
|
85 |
+
start_row: Optional[int] = None,
|
86 |
+
end_row: Optional[int] = None,
|
87 |
+
save_dir: str = 'data/vggsound/video_latents_text/train'
|
88 |
+
):
|
89 |
+
self.root = Path(root)
|
90 |
+
self.normalize_audio = normalize_audio
|
91 |
+
if audio_samples is None:
|
92 |
+
self.audio_samples = int(sample_rate * duration_sec)
|
93 |
+
else:
|
94 |
+
self.audio_samples = audio_samples
|
95 |
+
effective_duration = audio_samples / sample_rate
|
96 |
+
# make sure the duration is close enough, within 15ms
|
97 |
+
assert abs(effective_duration - duration_sec) < 0.015, \
|
98 |
+
f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
|
99 |
+
|
100 |
+
# videos = sorted(os.listdir(self.root))
|
101 |
+
# videos = set([Path(v).stem for v in videos]) # remove extensions
|
102 |
+
videos = []
|
103 |
+
self.labels = []
|
104 |
+
self.videos = []
|
105 |
+
self.caption_cot = []
|
106 |
+
missing_videos = []
|
107 |
+
# read the tsv for subset information
|
108 |
+
df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records')
|
109 |
+
|
110 |
+
# 控制处理的行范围
|
111 |
+
if start_row is not None and end_row is not None:
|
112 |
+
df_list = df_list[start_row:end_row]
|
113 |
+
|
114 |
+
for record in df_list:
|
115 |
+
id = record['id']
|
116 |
+
if os.path.exists(f'{save_dir}/{id}.pth'): continue
|
117 |
+
label = record['caption']
|
118 |
+
caption_cot = record['caption_cot']
|
119 |
+
# if id in videos:
|
120 |
+
self.labels.append(label)
|
121 |
+
# self.labels[id] = label
|
122 |
+
self.videos.append(id)
|
123 |
+
self.caption_cot.append(caption_cot)
|
124 |
+
# else:
|
125 |
+
# missing_videos.append(id)
|
126 |
+
|
127 |
+
log.info(f'{len(videos)} videos found in {root}')
|
128 |
+
log.info(f'{len(self.videos)} videos found in {tsv_path}')
|
129 |
+
log.info(f'{len(missing_videos)} videos missing in {root}')
|
130 |
+
|
131 |
+
self.sample_rate = sample_rate
|
132 |
+
self.duration_sec = duration_sec
|
133 |
+
|
134 |
+
self.expected_audio_length = self.audio_samples
|
135 |
+
self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
|
136 |
+
self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
|
137 |
+
|
138 |
+
self.clip_transform = v2.Compose([
|
139 |
+
v2.Lambda(pad_to_square), # 先填充为正方形
|
140 |
+
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
141 |
+
v2.ToImage(),
|
142 |
+
v2.ToDtype(torch.float32, scale=True),
|
143 |
+
])
|
144 |
+
self.clip_processor = AutoProcessor.from_pretrained("facebook/metaclip-h14-fullcc2.5b")
|
145 |
+
self.sync_transform = v2.Compose([
|
146 |
+
v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
|
147 |
+
v2.CenterCrop(_SYNC_SIZE),
|
148 |
+
v2.ToImage(),
|
149 |
+
v2.ToDtype(torch.float32, scale=True),
|
150 |
+
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
151 |
+
])
|
152 |
+
|
153 |
+
self.resampler = {}
|
154 |
+
|
155 |
+
def sample(self, idx: int) -> dict[str, torch.Tensor]:
|
156 |
+
video_id = self.videos[idx]
|
157 |
+
label = self.labels[idx]
|
158 |
+
caption_cot = self.caption_cot[idx]
|
159 |
+
|
160 |
+
reader = StreamingMediaDecoder(self.root / (video_id + '.mp4'))
|
161 |
+
reader.add_basic_video_stream(
|
162 |
+
frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
|
163 |
+
frame_rate=_CLIP_FPS,
|
164 |
+
format='rgb24',
|
165 |
+
)
|
166 |
+
reader.add_basic_video_stream(
|
167 |
+
frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
|
168 |
+
frame_rate=_SYNC_FPS,
|
169 |
+
format='rgb24',
|
170 |
+
)
|
171 |
+
# reader.add_basic_audio_stream(frames_per_chunk=2**30,)
|
172 |
+
|
173 |
+
reader.fill_buffer()
|
174 |
+
data_chunk = reader.pop_chunks()
|
175 |
+
|
176 |
+
clip_chunk = data_chunk[0]
|
177 |
+
sync_chunk = data_chunk[1]
|
178 |
+
# audio_chunk = data_chunk[2]
|
179 |
+
# if len(audio_chunk.shape) != 2:
|
180 |
+
# raise RuntimeError(f'error audio shape {video_id}')
|
181 |
+
if clip_chunk is None:
|
182 |
+
raise RuntimeError(f'CLIP video returned None {video_id}')
|
183 |
+
# if clip_chunk.shape[0] < self.clip_expected_length:
|
184 |
+
# raise RuntimeError(
|
185 |
+
# f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}'
|
186 |
+
# )
|
187 |
+
|
188 |
+
if sync_chunk is None:
|
189 |
+
raise RuntimeError(f'Sync video returned None {video_id}')
|
190 |
+
|
191 |
+
# truncate the video
|
192 |
+
clip_chunk = clip_chunk[:self.clip_expected_length]
|
193 |
+
# import ipdb
|
194 |
+
# ipdb.set_trace()
|
195 |
+
if clip_chunk.shape[0] != self.clip_expected_length:
|
196 |
+
current_length = clip_chunk.shape[0]
|
197 |
+
padding_needed = self.clip_expected_length - current_length
|
198 |
+
|
199 |
+
# Check that padding needed is no more than 2
|
200 |
+
# assert padding_needed < 4, f'Padding no more than 2 frames allowed, but {padding_needed} needed'
|
201 |
+
|
202 |
+
# If assertion passes, proceed with padding
|
203 |
+
if padding_needed > 0:
|
204 |
+
last_frame = clip_chunk[-1]
|
205 |
+
log.info(last_frame.shape)
|
206 |
+
# Repeat the last frame to reach the expected length
|
207 |
+
padding = last_frame.repeat(padding_needed, 1, 1, 1)
|
208 |
+
clip_chunk = torch.cat((clip_chunk, padding), dim=0)
|
209 |
+
# raise RuntimeError(f'CLIP video wrong length {video_id}, '
|
210 |
+
# f'expected {self.clip_expected_length}, '
|
211 |
+
# f'got {clip_chunk.shape[0]}')
|
212 |
+
|
213 |
+
# save_image(clip_chunk[0] / 255.0,'ori.png')
|
214 |
+
clip_chunk = pad_to_square(clip_chunk)
|
215 |
+
# save_image(clip_chunk[0] / 255.0,'square.png')
|
216 |
+
# clip_chunk = self.clip_transform(clip_chunk)
|
217 |
+
# import ipdb
|
218 |
+
# ipdb.set_trace()
|
219 |
+
clip_chunk = self.clip_processor(images=clip_chunk, return_tensors="pt")["pixel_values"]
|
220 |
+
# log.info(clip_chunk.shape)
|
221 |
+
# save_tensor_as_image(clip_chunk[0].numpy(),'scale.png')
|
222 |
+
# log.info(clip_chunk[0])
|
223 |
+
# clip_chunk = outputs
|
224 |
+
# text_ids = outputs["input_ids"]
|
225 |
+
# temp_img = clip_chunk[0].permute(1, 2, 0) * 255
|
226 |
+
# save_image(clip_chunk[0],'scale.png')
|
227 |
+
sync_chunk = sync_chunk[:self.sync_expected_length]
|
228 |
+
if sync_chunk.shape[0] != self.sync_expected_length:
|
229 |
+
# padding using the last frame, but no more than 2
|
230 |
+
current_length = sync_chunk.shape[0]
|
231 |
+
last_frame = sync_chunk[-1]
|
232 |
+
# 重复最后一帧以进行填充
|
233 |
+
padding = last_frame.repeat(self.sync_expected_length - current_length, 1, 1, 1)
|
234 |
+
# assert self.sync_expected_length - current_length < 12, f'sync can pad no more than 2 while {self.sync_expected_length - current_length}'
|
235 |
+
sync_chunk = torch.cat((sync_chunk, padding), dim=0)
|
236 |
+
# raise RuntimeError(f'Sync video wrong length {video_id}, '
|
237 |
+
# f'expected {self.sync_expected_length}, '
|
238 |
+
# f'got {sync_chunk.shape[0]}')
|
239 |
+
|
240 |
+
sync_chunk = self.sync_transform(sync_chunk)
|
241 |
+
# assert audio_chunk.shape[1] == self.expected_audio_length and clip_chunk.shape[0] == self.clip_expected_length \
|
242 |
+
# and sync_chunk.shape[0] == self.sync_expected_length, 'error processed data shape'
|
243 |
+
data = {
|
244 |
+
'id': video_id,
|
245 |
+
'caption': label,
|
246 |
+
# 'audio': audio_chunk,
|
247 |
+
'clip_video': clip_chunk,
|
248 |
+
'sync_video': sync_chunk,
|
249 |
+
'caption_cot': caption_cot,
|
250 |
+
}
|
251 |
+
|
252 |
+
return data
|
253 |
+
|
254 |
+
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
255 |
+
try:
|
256 |
+
return self.sample(idx)
|
257 |
+
except Exception as e:
|
258 |
+
log.error(f'Error loading video {self.videos[idx]}: {e}')
|
259 |
+
return None
|
260 |
+
|
261 |
+
def __len__(self):
|
262 |
+
return len(self.labels)
|
263 |
+
|
264 |
+
|
265 |
+
# dataset = VGGSound(
|
266 |
+
# root="data/vggsound/video/train",
|
267 |
+
# tsv_path="data/vggsound/split_txt/temp.csv",
|
268 |
+
# sample_rate=44100,
|
269 |
+
# duration_sec=9.0,
|
270 |
+
# audio_samples=397312,
|
271 |
+
# start_row=0,
|
272 |
+
# end_row=None,
|
273 |
+
# save_dir="data/vggsound/video_224_latents_text/train"
|
274 |
+
# )
|
275 |
+
# dataset[0]
|
data_utils/v2a_utils/vggsound_224_no_sync.py
ADDED
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
from typing import Optional, Union
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
import pandas as pd
|
7 |
+
import torch
|
8 |
+
import torchaudio
|
9 |
+
from torch.utils.data.dataset import Dataset
|
10 |
+
from torchvision.transforms import v2
|
11 |
+
from torio.io import StreamingMediaDecoder
|
12 |
+
from torchvision.utils import save_image
|
13 |
+
from transformers import AutoProcessor
|
14 |
+
import torch.nn.functional as F
|
15 |
+
import numpy as np
|
16 |
+
|
17 |
+
import logging
|
18 |
+
log = logging.getLogger()
|
19 |
+
|
20 |
+
_CLIP_SIZE = 224
|
21 |
+
_CLIP_FPS = 8.0
|
22 |
+
|
23 |
+
_SYNC_SIZE = 224
|
24 |
+
_SYNC_FPS = 25.0
|
25 |
+
|
26 |
+
def save_tensor_as_image(tensor, save_path):
|
27 |
+
"""
|
28 |
+
将形状为 (1, 3, H, W) 的 RGB 图像数组保存为图片文件。
|
29 |
+
|
30 |
+
:param tensor: 输入的 NumPy 数组 (1, 3, H, W)。
|
31 |
+
:param save_path: 图片保存路径。
|
32 |
+
"""
|
33 |
+
# # 移除批次维度,变成 (3, H, W)
|
34 |
+
# tensor = tensor.squeeze(0)
|
35 |
+
|
36 |
+
# 交换轴顺序,变为 (H, W, 3)
|
37 |
+
image_array = np.transpose(tensor, (1, 2, 0))
|
38 |
+
|
39 |
+
# 检查数组是否为合适的数据类型
|
40 |
+
if image_array.dtype != np.uint8:
|
41 |
+
# 如果不是 uint8,首先标准化,然后转换
|
42 |
+
image_array = (image_array - image_array.min()) / (image_array.max() - image_array.min()) * 255
|
43 |
+
image_array = image_array.astype(np.uint8)
|
44 |
+
|
45 |
+
# 创建图像对象
|
46 |
+
image = Image.fromarray(image_array)
|
47 |
+
|
48 |
+
# 保存图片
|
49 |
+
image.save(save_path)
|
50 |
+
print(f"Image saved to {save_path}")
|
51 |
+
|
52 |
+
def pad_to_square(video_tensor):
|
53 |
+
# 验证输入的形状
|
54 |
+
if len(video_tensor.shape) != 4:
|
55 |
+
raise ValueError("Input tensor must have shape (l, c, h, w)")
|
56 |
+
|
57 |
+
l, c, h, w = video_tensor.shape
|
58 |
+
max_side = max(h, w)
|
59 |
+
|
60 |
+
# 计算每一维度需要的填充量:(left, right, top, bottom)
|
61 |
+
pad_h = max_side - h
|
62 |
+
pad_w = max_side - w
|
63 |
+
|
64 |
+
# 创建padding tuple (left, right, top, bottom)
|
65 |
+
# 因为图像的填充是作用在最后两个维度 h 和 w 上,所以我们需要指定这两个维度的填充
|
66 |
+
padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
|
67 |
+
|
68 |
+
# 使用F.pad对视频张量进行填充操作
|
69 |
+
# 填充参数为 (left, right, top, bottom)
|
70 |
+
video_padded = F.pad(video_tensor, pad=padding, mode='constant', value=0)
|
71 |
+
|
72 |
+
return video_padded
|
73 |
+
|
74 |
+
class VGGSound(Dataset):
|
75 |
+
|
76 |
+
def __init__(
|
77 |
+
self,
|
78 |
+
root: Union[str, Path],
|
79 |
+
*,
|
80 |
+
tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv',
|
81 |
+
sample_rate: int = 44_100,
|
82 |
+
duration_sec: float = 9.0,
|
83 |
+
audio_samples: Optional[int] = 397312,
|
84 |
+
normalize_audio: bool = False,
|
85 |
+
start_row: Optional[int] = None,
|
86 |
+
end_row: Optional[int] = None,
|
87 |
+
save_dir: str = 'data/vggsound/video_latents_text/train'
|
88 |
+
):
|
89 |
+
self.root = Path(root)
|
90 |
+
self.normalize_audio = normalize_audio
|
91 |
+
if audio_samples is None:
|
92 |
+
self.audio_samples = int(sample_rate * duration_sec)
|
93 |
+
else:
|
94 |
+
self.audio_samples = audio_samples
|
95 |
+
effective_duration = audio_samples / sample_rate
|
96 |
+
# make sure the duration is close enough, within 15ms
|
97 |
+
assert abs(effective_duration - duration_sec) < 0.015, \
|
98 |
+
f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
|
99 |
+
|
100 |
+
# videos = sorted(os.listdir(self.root))
|
101 |
+
# videos = set([Path(v).stem for v in videos]) # remove extensions
|
102 |
+
videos = []
|
103 |
+
self.labels = []
|
104 |
+
self.videos = []
|
105 |
+
missing_videos = []
|
106 |
+
# read the tsv for subset information
|
107 |
+
df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records')
|
108 |
+
|
109 |
+
# 控制处理的行范围
|
110 |
+
if start_row is not None and end_row is not None:
|
111 |
+
df_list = df_list[start_row:end_row]
|
112 |
+
|
113 |
+
for record in df_list:
|
114 |
+
id = record['id']
|
115 |
+
if os.path.exists(f'{save_dir}/{id}.pth'): continue
|
116 |
+
label = record['label']
|
117 |
+
# if id in videos:
|
118 |
+
self.labels.append(label)
|
119 |
+
# self.labels[id] = label
|
120 |
+
self.videos.append(id)
|
121 |
+
# else:
|
122 |
+
# missing_videos.append(id)
|
123 |
+
|
124 |
+
log.info(f'{len(videos)} videos found in {root}')
|
125 |
+
log.info(f'{len(self.videos)} videos found in {tsv_path}')
|
126 |
+
log.info(f'{len(missing_videos)} videos missing in {root}')
|
127 |
+
|
128 |
+
self.sample_rate = sample_rate
|
129 |
+
self.duration_sec = duration_sec
|
130 |
+
|
131 |
+
self.expected_audio_length = self.audio_samples
|
132 |
+
self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
|
133 |
+
self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
|
134 |
+
|
135 |
+
self.clip_transform = v2.Compose([
|
136 |
+
v2.Lambda(pad_to_square), # 先填充为正方形
|
137 |
+
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
138 |
+
v2.ToImage(),
|
139 |
+
v2.ToDtype(torch.float32, scale=True),
|
140 |
+
])
|
141 |
+
self.clip_processor = AutoProcessor.from_pretrained("useful_ckpts/metaclip-huge")
|
142 |
+
|
143 |
+
self.resampler = {}
|
144 |
+
|
145 |
+
def sample(self, idx: int) -> dict[str, torch.Tensor]:
|
146 |
+
video_id = self.videos[idx]
|
147 |
+
label = self.labels[idx]
|
148 |
+
|
149 |
+
reader = StreamingMediaDecoder(self.root / (video_id + '.mp4'))
|
150 |
+
reader.add_basic_video_stream(
|
151 |
+
frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
|
152 |
+
frame_rate=_CLIP_FPS,
|
153 |
+
format='rgb24',
|
154 |
+
)
|
155 |
+
|
156 |
+
reader.fill_buffer()
|
157 |
+
data_chunk = reader.pop_chunks()
|
158 |
+
|
159 |
+
clip_chunk = data_chunk[0]
|
160 |
+
if clip_chunk is None:
|
161 |
+
raise RuntimeError(f'CLIP video returned None {video_id}')
|
162 |
+
|
163 |
+
|
164 |
+
# truncate the video
|
165 |
+
clip_chunk = clip_chunk[:self.clip_expected_length]
|
166 |
+
# import ipdb
|
167 |
+
# ipdb.set_trace()
|
168 |
+
if clip_chunk.shape[0] != self.clip_expected_length:
|
169 |
+
current_length = clip_chunk.shape[0]
|
170 |
+
padding_needed = self.clip_expected_length - current_length
|
171 |
+
|
172 |
+
# Check that padding needed is no more than 2
|
173 |
+
assert padding_needed < 4, f'Padding no more than 2 frames allowed, but {padding_needed} needed'
|
174 |
+
|
175 |
+
# If assertion passes, proceed with padding
|
176 |
+
if padding_needed > 0:
|
177 |
+
last_frame = clip_chunk[-1]
|
178 |
+
log.info(last_frame.shape)
|
179 |
+
# Repeat the last frame to reach the expected length
|
180 |
+
padding = last_frame.repeat(padding_needed, 1, 1, 1)
|
181 |
+
clip_chunk = torch.cat((clip_chunk, padding), dim=0)
|
182 |
+
# raise RuntimeError(f'CLIP video wrong length {video_id}, '
|
183 |
+
# f'expected {self.clip_expected_length}, '
|
184 |
+
# f'got {clip_chunk.shape[0]}')
|
185 |
+
|
186 |
+
# save_image(clip_chunk[0] / 255.0,'ori.png')
|
187 |
+
clip_chunk = pad_to_square(clip_chunk)
|
188 |
+
# save_image(clip_chunk[0] / 255.0,'square.png')
|
189 |
+
# clip_chunk = self.clip_transform(clip_chunk)
|
190 |
+
# import ipdb
|
191 |
+
# ipdb.set_trace()
|
192 |
+
clip_chunk = self.clip_processor(images=clip_chunk, return_tensors="pt")["pixel_values"]
|
193 |
+
|
194 |
+
data = {
|
195 |
+
'id': video_id,
|
196 |
+
'caption': label,
|
197 |
+
'clip_video': clip_chunk,
|
198 |
+
}
|
199 |
+
|
200 |
+
return data
|
201 |
+
|
202 |
+
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
203 |
+
try:
|
204 |
+
return self.sample(idx)
|
205 |
+
except Exception as e:
|
206 |
+
log.error(f'Error loading video {self.videos[idx]}: {e}')
|
207 |
+
return None
|
208 |
+
|
209 |
+
def __len__(self):
|
210 |
+
return len(self.labels)
|
211 |
+
|
212 |
+
|
213 |
+
# dataset = VGGSound(
|
214 |
+
# root="data/vggsound/video/train",
|
215 |
+
# tsv_path="data/vggsound/split_txt/temp.csv",
|
216 |
+
# sample_rate=44100,
|
217 |
+
# duration_sec=9.0,
|
218 |
+
# audio_samples=397312,
|
219 |
+
# start_row=0,
|
220 |
+
# end_row=None,
|
221 |
+
# save_dir="data/vggsound/video_224_latents_text/train"
|
222 |
+
# )
|
223 |
+
# dataset[0]
|
data_utils/v2a_utils/vggsound_text.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import Optional, Union
|
5 |
+
|
6 |
+
import pandas as pd
|
7 |
+
import torch
|
8 |
+
import torchaudio
|
9 |
+
from torch.utils.data.dataset import Dataset
|
10 |
+
from torchvision.transforms import v2
|
11 |
+
from torio.io import StreamingMediaDecoder
|
12 |
+
from torchvision.utils import save_image
|
13 |
+
|
14 |
+
log = logging.getLogger()
|
15 |
+
|
16 |
+
_CLIP_SIZE = 384
|
17 |
+
_CLIP_FPS = 8.0
|
18 |
+
|
19 |
+
_SYNC_SIZE = 224
|
20 |
+
_SYNC_FPS = 25.0
|
21 |
+
|
22 |
+
|
23 |
+
class VGGSound(Dataset):
|
24 |
+
|
25 |
+
def __init__(
|
26 |
+
self,
|
27 |
+
root: Union[str, Path],
|
28 |
+
*,
|
29 |
+
tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv',
|
30 |
+
start_row: Optional[int] = None,
|
31 |
+
end_row: Optional[int] = None,
|
32 |
+
save_dir: str = 'data/vggsound/video_latents_text/train'
|
33 |
+
):
|
34 |
+
self.root = Path(root)
|
35 |
+
|
36 |
+
# videos = sorted(os.listdir(self.root))
|
37 |
+
# videos = set([Path(v).stem for v in videos]) # remove extensions
|
38 |
+
videos = []
|
39 |
+
self.labels = []
|
40 |
+
self.cots = []
|
41 |
+
self.videos = []
|
42 |
+
missing_videos = []
|
43 |
+
# read the tsv for subset information
|
44 |
+
df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records')
|
45 |
+
|
46 |
+
# 控制处理的行范围
|
47 |
+
if start_row is not None and end_row is not None:
|
48 |
+
df_list = df_list[start_row:end_row]
|
49 |
+
|
50 |
+
for record in df_list:
|
51 |
+
id = record['id']
|
52 |
+
# if os.path.exists(f'{save_dir}/{id}.pth'):
|
53 |
+
# continue
|
54 |
+
# try:
|
55 |
+
# torch.load(f'{save_dir}/{id}.pth')
|
56 |
+
# continue
|
57 |
+
# except:
|
58 |
+
# print(f'error load file: {save_dir}/{id}.pth')
|
59 |
+
# os.system(f'rm -f {save_dir}/{id}.pth')
|
60 |
+
label = record['caption']
|
61 |
+
# if id in videos:
|
62 |
+
self.labels.append(label)
|
63 |
+
self.cots.append(record['caption_cot'])
|
64 |
+
# self.labels[id] = label
|
65 |
+
self.videos.append(id)
|
66 |
+
# else:
|
67 |
+
# missing_videos.append(id)
|
68 |
+
|
69 |
+
log.info(f'{len(videos)} videos found in {root}')
|
70 |
+
log.info(f'{len(self.videos)} videos found in {tsv_path}')
|
71 |
+
log.info(f'{len(missing_videos)} videos missing in {root}')
|
72 |
+
|
73 |
+
|
74 |
+
|
75 |
+
|
76 |
+
def sample(self, idx: int) -> dict[str, torch.Tensor]:
|
77 |
+
video_id = self.videos[idx]
|
78 |
+
label = self.labels[idx]
|
79 |
+
cot = self.cots[idx]
|
80 |
+
data = {
|
81 |
+
'id': video_id,
|
82 |
+
'caption': label,
|
83 |
+
'caption_cot': cot
|
84 |
+
}
|
85 |
+
|
86 |
+
return data
|
87 |
+
|
88 |
+
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
89 |
+
try:
|
90 |
+
return self.sample(idx)
|
91 |
+
except Exception as e:
|
92 |
+
log.error(f'Error loading video {self.videos[idx]}: {e}')
|
93 |
+
return None
|
94 |
+
|
95 |
+
def __len__(self):
|
96 |
+
return len(self.labels)
|
97 |
+
|
98 |
+
|
99 |
+
# dataset = VGGSound(
|
100 |
+
# root="data/vggsound/video/test",
|
101 |
+
# tsv_path="data/vggsound/split_txt/temp.csv",
|
102 |
+
# sample_rate=44100,
|
103 |
+
# duration_sec=9.0,
|
104 |
+
# audio_samples=397312,
|
105 |
+
# start_row=0,
|
106 |
+
# end_row=None,
|
107 |
+
# save_dir="data/vggsound/video_latents_text/test"
|
108 |
+
# )
|
109 |
+
# dataset[0]
|
defaults.ini
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
[DEFAULTS]
|
3 |
+
|
4 |
+
#name of the run
|
5 |
+
name = stable_audio_tools
|
6 |
+
|
7 |
+
# the batch size
|
8 |
+
batch_size = 8
|
9 |
+
test_batch_size = 1
|
10 |
+
|
11 |
+
# predict ckpt directory
|
12 |
+
ckpt_dir = "ckpts"
|
13 |
+
|
14 |
+
# number of GPUs to use for training
|
15 |
+
num_gpus = 1
|
16 |
+
|
17 |
+
# number of nodes to use for training
|
18 |
+
num_nodes = 1
|
19 |
+
|
20 |
+
# Multi-GPU strategy for PyTorch Lightning
|
21 |
+
strategy = ""
|
22 |
+
|
23 |
+
# Precision to use for training
|
24 |
+
precision = "bf16-mixed"
|
25 |
+
|
26 |
+
# number of CPU workers for the DataLoader
|
27 |
+
num_workers = 8
|
28 |
+
|
29 |
+
# the random seed
|
30 |
+
seed = 42
|
31 |
+
|
32 |
+
# Batches for gradient accumulation
|
33 |
+
accum_batches = 1
|
34 |
+
|
35 |
+
# Number of steps between checkpoints
|
36 |
+
checkpoint_every = 2000
|
37 |
+
|
38 |
+
# trainer checkpoint file to restart training from
|
39 |
+
ckpt_path = ''
|
40 |
+
|
41 |
+
# model checkpoint file to start a new training run from
|
42 |
+
pretrained_ckpt_path = ''
|
43 |
+
|
44 |
+
# Checkpoint path for the pretransform model if needed
|
45 |
+
pretransform_ckpt_path = ''
|
46 |
+
|
47 |
+
# configuration model specifying model hyperparameters
|
48 |
+
model_config = ''
|
49 |
+
|
50 |
+
# configuration for datasets
|
51 |
+
dataset_config = ''
|
52 |
+
|
53 |
+
# directory to save the checkpoints in
|
54 |
+
save_dir = ''
|
55 |
+
|
56 |
+
# gradient_clip_val passed into PyTorch Lightning Trainer
|
57 |
+
gradient_clip_val = 0.0
|
58 |
+
|
59 |
+
# remove the weight norm from the pretransform model
|
60 |
+
remove_pretransform_weight_norm = ''
|
61 |
+
|
62 |
+
compile = False
|
63 |
+
|
64 |
+
repeat_num = 5
|
65 |
+
|
66 |
+
duration_sec = '9'
|
67 |
+
|
68 |
+
results_dir = 'results'
|
demo_test.csv
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
id,caption,caption_cot
|
2 |
+
W1nb2hIeDKc_000021,striking bowling,"Start with a background of ambient music, then add consistent sounds of bowling balls striking pins to emphasize the action. Include occasional subtle sounds of pins rattling and settling. Keep human voices or other noises minimal or absent for authenticity."
|
3 |
+
YYRdv32TJnc_000184,plastic bottle crushing,"Start with the sound of crushing plastic bottles, including crinkling and crunching. Add background noise resembling a factory environment, with machinery sounds. Incorporate subtle rustling and paper crinkling to suggest manipulation of plastic items."
|
4 |
+
Rp39_WnX5Fk_000380,"subway, metro, underground","Generate subway sounds including ambient station noise, train doors opening and closing, engine hum, wheels on tracks, and conductor announcements to produce an accurate underground train environment."
|
5 |
+
-KqXcm-I2zY_000087,playing tennis,"Generate sounds of tennis hitting a racket, the ball bouncing, and the girl’s grunts, with distant tennis court ambient noise. Avoid unrelated sounds like horses, basketballs, or indoor voices. Focus on clear tennis scene with realistic audio cues."
|
6 |
+
0W_wPc-zV3I_000101,hedge trimmer running,"Generate the sound of a hedge trimmer running steadily, focusing on consistent motor noise and cutting sounds. Ensure minimal background noise or voices, capturing the primary sound of the trimmer in operation. Avoid including any chainsaw or unrelated sounds for accuracy."
|
7 |
+
_Betmm6FaWo_000096,writing on blackboard with chalk,"The audio should feature consistent sounds of chalk scratching the blackboard, including occasional voice instructions, encouragement, and children’s chatter, with background music playing softly or fading in/out to match the scene's atmosphere. The sounds of laughter and chatter should be lively but balanced with the primary chalk and voice sounds for clarity. Overall, the audio combines educational sounds with background activity to reflect a classroom or play environment."
|
8 |
+
xmTfE3F2huE_000854,chopping food,"Generate rhythmic chopping sounds consistent with meat or food being sliced, incorporating occasional rustling noises like a plastic bag. Avoid adding human voices or train sounds to match the correct audio descriptions, ensuring a focused, realistic kitchen chopping scene."
|
9 |
+
ZaUaqnLdg6k_000030,skateboarding,"Generate the audio featuring skateboarding sounds with wheels rolling on various surfaces, including ramps, rails, and sidewalks, capturing the sound of tricks and landings. Include subtle ambient background noise to suggest an outdoor setting, avoiding any human voices or singing. Focus on realistic skateboarding sounds, emphasizing wheel contact, impacts, and movement."
|
10 |
+
_ZC6yk5iE1I_000026,playing trumpet,"Generate a continuous trumpet sound with melodic variations, mimicking the sound of a person playing the trumpet idealy in a musical setting, ensuring clarity and realistic tone. Avoid extraneous noise or background sounds to reflect the focus on trumpet playing. The audio should resemble a skilled player producing expressive, melodious trumpet notes."
|
11 |
+
55L7peYRB_Q_000120,using sewing machines,"Generate ambient sewing room sounds with consistent sewing machine hum, minimal background noise, and no human voices, focusing on characteristic machine noise to match the correct descriptions."
|
12 |
+
4p8n4Zf-WMM_000190,lighting firecrackers,"Generate the sound of firecrackers lighting and exploding repeatedly, mixed with distant background sounds of crickets chirping. Incorporate occasional subtle echoes to mimic outdoor night ambiance, with no human voices present. End with a series of sharp cracker bursts to create a lively, festive atmosphere."
|
13 |
+
yLazKv68TeA_000078,people eating crisps,"Create audio with consistent crisp sounds of people eating chips, including crinkling paper and breathing. Include subtle chewing noises to match the activity. Avoid background music or voices for clarity."
|
14 |
+
_XyxrZDZ36E_000034,hammering nails,"Generate audio with consistent hammering sounds, featuring a rhythmic pattern of nails being driven into a surface, with occasional ambient background sounds like birds chirping and distant traffic. Avoid human voices, focusing on realistic hammer strikes and natural outdoor environment sounds. Ensure the hammering tone is steady and clear, matching the description of continuous nail hammering."
|
15 |
+
1u1orBeV4xI_000428,ripping paper,"Start with a subtle tearing sound of paper being ripped, emphasizing a continuous, consistent noise. Ensure the sound has slight variations to mimic real tearing. No background or additional noises are needed, focusing solely on the tearing action."
|
16 |
+
JFG4YvcJ3bo_000228,playing bongo,"Generate a lively percussion track featuring rhythmic djembe beats, with a melodic guitar strumming softly in the background to enhance the musical atmosphere. Ensure no human voice is included, focusing on the percussive and guitar sounds. Maintain a natural, well-balanced stereo mix to highlight the instruments' interplay."
|
17 |
+
1pViEqMXJH0_000030,printer printing,"Generate a continuous printer printing sound with periodic beeps, resembling typical printer noise, including paper movement and occasional beeps for realism. Add subtle ambient background noise, like faint room sounds, to enhance authenticity. Ensure the primary focus remains on the printing and beeping sounds, consistent with the correct audio descriptions."
|
examples/1.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8884c466292b46510c298a9ee88d8a584c86cb750afb558108c0850413e21e51
|
3 |
+
size 634576
|
examples/1_mute.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b0ca4223b15583d8099d023bac3e86725bfd5cfbbad771ef67d31c1ad953bdc3
|
3 |
+
size 482981
|
examples/2.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a8c7a3dd144c91d690b07892b30544c24000008873116451291f59553f3908a4
|
3 |
+
size 368050
|
examples/2_mute.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e09de17f3d3f631a5a7cd5dfb2b5b32a78fbcd3c1b90673dfa36ce798647f1e8
|
3 |
+
size 216098
|
examples/3.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9dccddd67a954b12d34d481107c499460c69bebd92913b8f092724fcdf1c5baf
|
3 |
+
size 1716778
|
examples/3_mute.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:103ed8d5e4fbe8d6954dde463c2a997acd0e21c13895404706dbbaab39e2b086
|
3 |
+
size 1564981
|
examples/4.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a3a399f94d5748b77497e28e1fca4e70c3900e1583cab4ecaefb242507b9fe1b
|
3 |
+
size 3642290
|
examples/4_mute.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f9aa6d9ef7523cec4f5e0087d9f6c6b86efceb694f0433da5d07bdf57eea1247
|
3 |
+
size 3490447
|
examples/5.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:189a6cfcac18470a3f18285013f63f46c7b5996f31e0d3ecf617d9f7f91fdfeb
|
3 |
+
size 738718
|
examples/5_mute.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:103c2b1517d9cbecd81b394c93ebe36f166e95f635ceee280c28073233f08173
|
3 |
+
size 586982
|
extract_latents.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import torch.distributed as dist
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
from tqdm import tqdm
|
7 |
+
import logging
|
8 |
+
from data_utils.v2a_utils.vggsound_224_no_audio import VGGSound
|
9 |
+
from data_utils.v2a_utils.feature_utils_224 import FeaturesUtils
|
10 |
+
import torchaudio
|
11 |
+
from einops import rearrange
|
12 |
+
from torch.utils.data.dataloader import default_collate
|
13 |
+
import numpy as np
|
14 |
+
from huggingface_hub import hf_hub_download
|
15 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
16 |
+
|
17 |
+
def setup(rank, world_size):
|
18 |
+
dist.init_process_group("nccl", rank=rank, world_size=world_size)
|
19 |
+
torch.cuda.set_device(rank)
|
20 |
+
|
21 |
+
def cleanup():
|
22 |
+
dist.destroy_process_group()
|
23 |
+
|
24 |
+
def error_avoidance_collate(batch):
|
25 |
+
batch = list(filter(lambda x: x is not None, batch))
|
26 |
+
return default_collate(batch)
|
27 |
+
|
28 |
+
def main(args):
|
29 |
+
|
30 |
+
print(f"Using root: {args.root}, tsv_path: {args.tsv_path}, save_dir: {args.save_dir}")
|
31 |
+
dataset = VGGSound(
|
32 |
+
root=args.root,
|
33 |
+
tsv_path=args.tsv_path,
|
34 |
+
sample_rate=args.sample_rate,
|
35 |
+
duration_sec=args.duration_sec,
|
36 |
+
audio_samples=args.audio_samples,
|
37 |
+
start_row=args.start_row,
|
38 |
+
end_row=args.end_row,
|
39 |
+
save_dir=args.save_dir
|
40 |
+
)
|
41 |
+
save_dir = args.save_dir
|
42 |
+
os.makedirs(save_dir, exist_ok=True)
|
43 |
+
|
44 |
+
dataloader = DataLoader(dataset, batch_size=2, num_workers=8, drop_last=False,collate_fn=error_avoidance_collate)
|
45 |
+
|
46 |
+
print(f"Dataset length: {len(dataset)}")
|
47 |
+
feature_extractor = FeaturesUtils(
|
48 |
+
vae_ckpt=None,
|
49 |
+
vae_config=args.vae_config,
|
50 |
+
enable_conditions=True,
|
51 |
+
synchformer_ckpt=args.synchformer_ckpt
|
52 |
+
).eval().cuda()
|
53 |
+
|
54 |
+
feature_extractor = feature_extractor
|
55 |
+
|
56 |
+
for i, data in enumerate(tqdm(dataloader, desc="Processing", unit="batch")):
|
57 |
+
ids = data['id']
|
58 |
+
with torch.no_grad():
|
59 |
+
# audio = data['audio'].cuda(rank, non_blocking=True)
|
60 |
+
output = {
|
61 |
+
'caption': str(data['caption']),
|
62 |
+
'caption_cot': str(data['caption_cot'])
|
63 |
+
}
|
64 |
+
print(output)
|
65 |
+
|
66 |
+
# latent = feature_extractor.module.encode_audio(audio)
|
67 |
+
# output['latent'] = latent.detach().cpu()
|
68 |
+
|
69 |
+
clip_video = data['clip_video'].cuda()
|
70 |
+
clip_features = feature_extractor.encode_video_with_clip(clip_video)
|
71 |
+
output['metaclip_features'] = clip_features.detach().cpu()
|
72 |
+
|
73 |
+
sync_video = data['sync_video'].cuda()
|
74 |
+
sync_features = feature_extractor.encode_video_with_sync(sync_video)
|
75 |
+
output['sync_features'] = sync_features.detach().cpu()
|
76 |
+
|
77 |
+
caption = data['caption']
|
78 |
+
metaclip_global_text_features, metaclip_text_features = feature_extractor.encode_text(caption)
|
79 |
+
output['metaclip_global_text_features'] = metaclip_global_text_features.detach().cpu()
|
80 |
+
output['metaclip_text_features'] = metaclip_text_features.detach().cpu()
|
81 |
+
|
82 |
+
caption_cot = data['caption_cot']
|
83 |
+
t5_features = feature_extractor.encode_t5_text(caption_cot)
|
84 |
+
output['t5_features'] = t5_features.detach().cpu()
|
85 |
+
|
86 |
+
for j in range(len(ids)):
|
87 |
+
sample_output = {
|
88 |
+
'id': ids[j],
|
89 |
+
'caption': output['caption'][j],
|
90 |
+
'caption_cot': output['caption_cot'][j],
|
91 |
+
# 'latent': output['latent'][j],
|
92 |
+
'metaclip_features': output['metaclip_features'][j],
|
93 |
+
'sync_features': output['sync_features'][j],
|
94 |
+
'metaclip_global_text_features': output['metaclip_global_text_features'][j],
|
95 |
+
'metaclip_text_features': output['metaclip_text_features'][j],
|
96 |
+
't5_features': output['t5_features'][j],
|
97 |
+
}
|
98 |
+
# torch.save(sample_output, f'{save_dir}/{ids[j]}.pth')
|
99 |
+
np.savez(f'{save_dir}/demo.npz', **sample_output)
|
100 |
+
|
101 |
+
## test the sync between videos and audios
|
102 |
+
# torchaudio.save(f'input_{i}.wav',data['audio'],sample_rate=44100)
|
103 |
+
# recon_audio = feature_extractor.decode_audio(latent)
|
104 |
+
# recon_audio = rearrange(recon_audio, "b d n -> d (b n)")
|
105 |
+
# id = data['id']
|
106 |
+
# torchaudio.save(f'recon_{i}.wav',recon_audio.cpu(),sample_rate=44100)
|
107 |
+
# os.system(f'ffmpeg -y -i dataset/vggsound/video/train/{id}.mp4 -i recon_{i}.wav -t 9 -map 0:v -map 1:a -c:v copy -c:a aac -strict experimental -shortest out_{i}.mp4')
|
108 |
+
# os.system(f'ffmpeg -y -i dataset/vggsound/video/train/{id}.mp4 -i input_{i}.wav -t 9 -map 0:v -map 1:a -c:v copy -c:a aac -strict experimental -shortest input_{i}.mp4')
|
109 |
+
|
110 |
+
|
111 |
+
if __name__ == '__main__':
|
112 |
+
parser = argparse.ArgumentParser(description='Extract Video Training Latents')
|
113 |
+
parser.add_argument('--root', type=str, default='videos', help='Root directory of the video dataset')
|
114 |
+
parser.add_argument('--tsv_path', type=str, default='cot_coarse/cot.csv', help='Path to the TSV file')
|
115 |
+
parser.add_argument('--save-dir', type=str, default='results', help='Save Directory')
|
116 |
+
parser.add_argument('--sample_rate', type=int, default=44100, help='Sample rate of the audio')
|
117 |
+
parser.add_argument('--duration_sec', type=float, default=9.0, help='Duration of the audio in seconds')
|
118 |
+
parser.add_argument('--vae_ckpt', type=str, default='ckpts/vae.ckpt', help='Path to the VAE checkpoint')
|
119 |
+
parser.add_argument('--vae_config', type=str, default='ThinkSound/configs/model_configs/stable_audio_2_0_vae.json', help='Path to the VAE configuration file')
|
120 |
+
parser.add_argument('--synchformer_ckpt', type=str, default='ckpts/synchformer_state_dict.pth', help='Path to the Synchformer checkpoint')
|
121 |
+
parser.add_argument('--start-row', type=int, default=0, help='start row')
|
122 |
+
parser.add_argument('--end-row', type=int, default=None, help='end row')
|
123 |
+
|
124 |
+
args = parser.parse_args()
|
125 |
+
args.audio_samples = int(args.sample_rate * args.duration_sec)
|
126 |
+
|
127 |
+
main(args=args)
|
128 |
+
|
predict.py
ADDED
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from prefigure.prefigure import get_all_args, push_wandb_config
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
import torch
|
6 |
+
import torchaudio
|
7 |
+
# import pytorch_lightning as pl
|
8 |
+
import lightning as L
|
9 |
+
from lightning.pytorch.callbacks import Timer, ModelCheckpoint, BasePredictionWriter
|
10 |
+
from lightning.pytorch.callbacks import Callback
|
11 |
+
from lightning.pytorch.tuner import Tuner
|
12 |
+
from lightning.pytorch import seed_everything
|
13 |
+
import random
|
14 |
+
from datetime import datetime
|
15 |
+
|
16 |
+
from ThinkSound.data.datamodule import DataModule
|
17 |
+
from ThinkSound.models import create_model_from_config
|
18 |
+
from ThinkSound.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model
|
19 |
+
from ThinkSound.training import create_training_wrapper_from_config, create_demo_callback_from_config
|
20 |
+
from ThinkSound.training.utils import copy_state_dict
|
21 |
+
from huggingface_hub import hf_hub_download
|
22 |
+
|
23 |
+
class ExceptionCallback(Callback):
|
24 |
+
def on_exception(self, trainer, module, err):
|
25 |
+
print(f'{type(err).__name__}: {err}')
|
26 |
+
|
27 |
+
class ModelConfigEmbedderCallback(Callback):
|
28 |
+
def __init__(self, model_config):
|
29 |
+
self.model_config = model_config
|
30 |
+
|
31 |
+
def on_save_checkpoint(self, trainer, pl_module, checkpoint):
|
32 |
+
checkpoint["model_config"] = self.model_config
|
33 |
+
|
34 |
+
class CustomWriter(BasePredictionWriter):
|
35 |
+
|
36 |
+
def __init__(self, output_dir, write_interval='batch', batch_size=32):
|
37 |
+
super().__init__(write_interval)
|
38 |
+
self.output_dir = output_dir
|
39 |
+
self.batch_size = batch_size
|
40 |
+
|
41 |
+
def write_on_batch_end(self, trainer, pl_module, predictions, batch_indices, batch, batch_idx, dataloader_idx):
|
42 |
+
|
43 |
+
audios = predictions
|
44 |
+
ids = [item['id'] for item in batch[1]]
|
45 |
+
current_date = datetime.now()
|
46 |
+
|
47 |
+
formatted_date = current_date.strftime('%m%d')
|
48 |
+
os.makedirs(os.path.join(self.output_dir, f'{formatted_date}_batch_size{self.batch_size}'),exist_ok=True)
|
49 |
+
for audio, id in zip(audios, ids):
|
50 |
+
save_path = os.path.join(self.output_dir, f'{formatted_date}_batch_size{self.batch_size}', f'{id}.wav')
|
51 |
+
torchaudio.save(save_path, audio, 44100)
|
52 |
+
|
53 |
+
def main():
|
54 |
+
|
55 |
+
args = get_all_args()
|
56 |
+
|
57 |
+
|
58 |
+
# args.pretransform_ckpt_path = hf_hub_download(
|
59 |
+
# repo_id="liuhuadai/ThinkSound",
|
60 |
+
# filename="vae.ckpt"
|
61 |
+
# )
|
62 |
+
|
63 |
+
args.pretransform_ckpt_path = "./ckpts/vae.ckpt"
|
64 |
+
|
65 |
+
|
66 |
+
seed = 10086
|
67 |
+
|
68 |
+
# Set a different seed for each process if using SLURM
|
69 |
+
if os.environ.get("SLURM_PROCID") is not None:
|
70 |
+
seed += int(os.environ.get("SLURM_PROCID"))
|
71 |
+
|
72 |
+
# random.seed(seed)
|
73 |
+
# torch.manual_seed(seed)
|
74 |
+
seed_everything(seed, workers=True)
|
75 |
+
|
76 |
+
#Get JSON config from args.model_config
|
77 |
+
with open(args.model_config) as f:
|
78 |
+
model_config = json.load(f)
|
79 |
+
|
80 |
+
with open(args.dataset_config) as f:
|
81 |
+
dataset_config = json.load(f)
|
82 |
+
|
83 |
+
for td in dataset_config["test_datasets"]:
|
84 |
+
td["path"] = args.results_dir
|
85 |
+
|
86 |
+
# train_dl = create_dataloader_from_config(
|
87 |
+
# dataset_config,
|
88 |
+
# batch_size=args.batch_size,
|
89 |
+
# num_workers=args.num_workers,
|
90 |
+
# sample_rate=model_config["sample_rate"],
|
91 |
+
# sample_size=model_config["sample_size"],
|
92 |
+
# audio_channels=model_config.get("audio_channels", 2),
|
93 |
+
# )
|
94 |
+
|
95 |
+
|
96 |
+
duration=(float)(args.duration_sec)
|
97 |
+
|
98 |
+
dm = DataModule(
|
99 |
+
dataset_config,
|
100 |
+
batch_size=args.batch_size,
|
101 |
+
test_batch_size=args.test_batch_size,
|
102 |
+
num_workers=args.num_workers,
|
103 |
+
sample_rate=model_config["sample_rate"],
|
104 |
+
sample_size=(float)(args.duration_sec) * model_config["sample_rate"],
|
105 |
+
audio_channels=model_config.get("audio_channels", 2),
|
106 |
+
latent_length=round(44100/64/32*duration),
|
107 |
+
)
|
108 |
+
|
109 |
+
model_config["sample_size"] = duration * model_config["sample_rate"]
|
110 |
+
model_config["model"]["diffusion"]["config"]["sync_seq_len"] = 24*int(duration)
|
111 |
+
model_config["model"]["diffusion"]["config"]["clip_seq_len"] = 8*int(duration)
|
112 |
+
model_config["model"]["diffusion"]["config"]["latent_seq_len"] = round(44100/64/32*duration)
|
113 |
+
|
114 |
+
model = create_model_from_config(model_config)
|
115 |
+
|
116 |
+
## speed by torch.compile
|
117 |
+
if args.compile:
|
118 |
+
model = torch.compile(model)
|
119 |
+
|
120 |
+
if args.pretrained_ckpt_path:
|
121 |
+
copy_state_dict(model, load_ckpt_state_dict(args.pretrained_ckpt_path,prefix='diffusion.')) # autoencoder. diffusion.
|
122 |
+
|
123 |
+
if args.remove_pretransform_weight_norm == "pre_load":
|
124 |
+
remove_weight_norm_from_model(model.pretransform)
|
125 |
+
# import ipdb
|
126 |
+
# ipdb.set_trace()
|
127 |
+
if args.pretransform_ckpt_path:
|
128 |
+
load_vae_state = load_ckpt_state_dict(args.pretransform_ckpt_path, prefix='autoencoder.')
|
129 |
+
# new_state_dict = {k.replace("autoencoder.", ""): v for k, v in load_vae_state.items() if k.startswith("autoencoder.")}
|
130 |
+
model.pretransform.load_state_dict(load_vae_state)
|
131 |
+
|
132 |
+
# Remove weight_norm from the pretransform if specified
|
133 |
+
if args.remove_pretransform_weight_norm == "post_load":
|
134 |
+
remove_weight_norm_from_model(model.pretransform)
|
135 |
+
|
136 |
+
training_wrapper = create_training_wrapper_from_config(model_config, model)
|
137 |
+
|
138 |
+
# wandb_logger = L.pytorch.loggers.WandbLogger(project=args.name)
|
139 |
+
# wandb_logger.watch(training_wrapper)
|
140 |
+
|
141 |
+
exc_callback = ExceptionCallback()
|
142 |
+
|
143 |
+
# if args.save_dir and isinstance(wandb_logger.experiment.id, str):
|
144 |
+
# checkpoint_dir = os.path.join(args.save_dir, wandb_logger.experiment.project, wandb_logger.experiment.id, "checkpoints")
|
145 |
+
# else:
|
146 |
+
# checkpoint_dir = None
|
147 |
+
|
148 |
+
# ckpt_callback = ModelCheckpoint(every_n_train_steps=args.checkpoint_every, dirpath=checkpoint_dir, monitor='val_loss', mode='min', save_top_k=10)
|
149 |
+
save_model_config_callback = ModelConfigEmbedderCallback(model_config)
|
150 |
+
audio_dir = args.results_dir
|
151 |
+
pred_writer = CustomWriter(output_dir=audio_dir, write_interval="batch", batch_size=args.test_batch_size)
|
152 |
+
timer = Timer(duration="00:15:00:00")
|
153 |
+
demo_callback = create_demo_callback_from_config(model_config, demo_dl=dm)
|
154 |
+
|
155 |
+
#Combine args and config dicts
|
156 |
+
args_dict = vars(args)
|
157 |
+
args_dict.update({"model_config": model_config})
|
158 |
+
args_dict.update({"dataset_config": dataset_config})
|
159 |
+
# push_wandb_config(wandb_logger, args_dict)
|
160 |
+
|
161 |
+
#Set multi-GPU strategy if specified
|
162 |
+
if args.strategy:
|
163 |
+
if args.strategy == "deepspeed":
|
164 |
+
from pytorch_lightning.strategies import DeepSpeedStrategy
|
165 |
+
strategy = DeepSpeedStrategy(stage=2,
|
166 |
+
contiguous_gradients=True,
|
167 |
+
overlap_comm=True,
|
168 |
+
reduce_scatter=True,
|
169 |
+
reduce_bucket_size=5e8,
|
170 |
+
allgather_bucket_size=5e8,
|
171 |
+
load_full_weights=True
|
172 |
+
)
|
173 |
+
else:
|
174 |
+
strategy = args.strategy
|
175 |
+
else:
|
176 |
+
strategy = 'ddp_find_unused_parameters_true' if args.num_gpus > 1 else "auto"
|
177 |
+
|
178 |
+
trainer = L.Trainer(
|
179 |
+
devices=args.num_gpus,
|
180 |
+
accelerator="gpu",
|
181 |
+
num_nodes = args.num_nodes,
|
182 |
+
strategy=strategy,
|
183 |
+
precision=args.precision,
|
184 |
+
accumulate_grad_batches=args.accum_batches,
|
185 |
+
callbacks=[demo_callback, exc_callback, save_model_config_callback, timer, pred_writer],
|
186 |
+
log_every_n_steps=1,
|
187 |
+
max_epochs=1000,
|
188 |
+
default_root_dir=args.save_dir,
|
189 |
+
gradient_clip_val=args.gradient_clip_val,
|
190 |
+
reload_dataloaders_every_n_epochs = 0,
|
191 |
+
check_val_every_n_epoch=2,
|
192 |
+
)
|
193 |
+
|
194 |
+
|
195 |
+
|
196 |
+
# ckpt_path = hf_hub_download(
|
197 |
+
# repo_id="liuhuadai/ThinkSound",
|
198 |
+
# filename="thinksound.ckpt"
|
199 |
+
# )
|
200 |
+
ckpt_path = 'ckpts/thinksound.ckpt'
|
201 |
+
|
202 |
+
|
203 |
+
|
204 |
+
current_date = datetime.now()
|
205 |
+
formatted_date = current_date.strftime('%m%d')
|
206 |
+
|
207 |
+
audio_dir = f'{formatted_date}_step68k_batch_size'+str(args.test_batch_size)
|
208 |
+
metrics_path = os.path.join(args.ckpt_dir, 'audios',audio_dir,'cache',"output_metrics.json")
|
209 |
+
# if os.path.exists(metrics_path): continue
|
210 |
+
|
211 |
+
trainer.predict(training_wrapper, dm, return_predictions=False,ckpt_path=ckpt_path)
|
212 |
+
|
213 |
+
if __name__ == '__main__':
|
214 |
+
main()
|
pyproject.toml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
[build-system]
|
2 |
+
requires = ["setuptools"]
|
3 |
+
build-backend = "setuptools.build_meta"
|
requirements.txt
ADDED
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
modelscope
|
2 |
+
absl-py==2.2.2
|
3 |
+
accelerate==1.6.0
|
4 |
+
aeiou==0.0.20
|
5 |
+
aiobotocore==2.22.0
|
6 |
+
aiofiles==23.2.1
|
7 |
+
aiohappyeyeballs==2.6.1
|
8 |
+
aiohttp==3.11.18
|
9 |
+
aioitertools==0.12.0
|
10 |
+
aiosignal==1.3.2
|
11 |
+
alias-free-torch==0.0.6
|
12 |
+
annotated-types==0.7.0
|
13 |
+
antlr4-python3-runtime==4.9.3
|
14 |
+
anyio==4.9.0
|
15 |
+
appdirs==1.4.4
|
16 |
+
argbind==0.3.9
|
17 |
+
asttokens==3.0.0
|
18 |
+
async-timeout==5.0.1
|
19 |
+
attrs==25.3.0
|
20 |
+
audiobox_aesthetics==0.0.2
|
21 |
+
audioread==3.0.1
|
22 |
+
auraloss==0.4.0
|
23 |
+
av==14.4.0
|
24 |
+
bleach==6.2.0
|
25 |
+
bokeh==3.7.3
|
26 |
+
botocore==1.37.3
|
27 |
+
braceexpand==0.1.7
|
28 |
+
Brotli==1.1.0
|
29 |
+
certifi==2025.4.26
|
30 |
+
cffi==1.17.1
|
31 |
+
charset-normalizer==3.4.1
|
32 |
+
clean-fid==0.1.35
|
33 |
+
click==8.1.8
|
34 |
+
clip-anytorch==2.6.0
|
35 |
+
cloudpickle==3.1.1
|
36 |
+
colorcet==3.1.0
|
37 |
+
colorlog==6.9.0
|
38 |
+
configparser==7.2.0
|
39 |
+
contourpy==1.3.2
|
40 |
+
cycler==0.12.1
|
41 |
+
Cython==3.1.1
|
42 |
+
dctorch==0.1.2
|
43 |
+
decorator==4.4.2
|
44 |
+
decord==0.6.0
|
45 |
+
descript-audio-codec==1.0.0
|
46 |
+
docker-pycreds==0.4.0
|
47 |
+
docstring_parser==0.16
|
48 |
+
einops==0.7.0
|
49 |
+
einops-exts==0.0.4
|
50 |
+
ema-pytorch==0.2.3
|
51 |
+
encodec==0.1.1
|
52 |
+
exceptiongroup==1.2.2
|
53 |
+
executing==2.2.0
|
54 |
+
fastapi==0.115.12
|
55 |
+
fastcore==1.8.2
|
56 |
+
ffmpeg==1.4
|
57 |
+
ffmpy==0.5.0
|
58 |
+
filelock==3.18.0
|
59 |
+
fire==0.7.0
|
60 |
+
flatten-dict==0.4.2
|
61 |
+
fonttools==4.58.0
|
62 |
+
frozenlist==1.6.0
|
63 |
+
fsspec==2025.5.0
|
64 |
+
ftfy==6.3.1
|
65 |
+
future==1.0.0
|
66 |
+
fvcore==0.1.5.post20221221
|
67 |
+
gin-config==0.5.0
|
68 |
+
gitdb==4.0.12
|
69 |
+
GitPython==3.1.44
|
70 |
+
gradio==3.50.0
|
71 |
+
gradio_client==0.6.1
|
72 |
+
groovy==0.1.2
|
73 |
+
grpcio==1.71.0
|
74 |
+
h11==0.16.0
|
75 |
+
hf_xet
|
76 |
+
h5py==3.13.0
|
77 |
+
hjson==3.1.0
|
78 |
+
holoviews==1.20.2
|
79 |
+
httpcore==1.0.9
|
80 |
+
httpx==0.28.1
|
81 |
+
huggingface-hub==0.30.2
|
82 |
+
hydra-colorlog==1.2.0
|
83 |
+
hydra-core==1.3.2
|
84 |
+
idna==3.10
|
85 |
+
imageio==2.37.0
|
86 |
+
imageio-ffmpeg==0.4.9
|
87 |
+
importlib-resources==5.12.0
|
88 |
+
importlib_metadata==8.7.0
|
89 |
+
iopath==0.1.10
|
90 |
+
ipython==8.36.0
|
91 |
+
jedi==0.19.2
|
92 |
+
Jinja2==3.1.0
|
93 |
+
jmespath==1.0.1
|
94 |
+
joblib==1.5.0
|
95 |
+
jsonmerge==1.9.2
|
96 |
+
jsonschema==4.23.0
|
97 |
+
jsonschema-specifications==2025.4.1
|
98 |
+
julius==0.2.7
|
99 |
+
k-diffusion==0.1.1
|
100 |
+
kiwisolver==1.4.8
|
101 |
+
kornia==0.8.1
|
102 |
+
kornia_rs==0.1.9
|
103 |
+
laion-clap==1.1.4
|
104 |
+
latex2mathml==3.77.0
|
105 |
+
lazy_loader==0.4
|
106 |
+
librosa==0.9.2
|
107 |
+
lightning==2.5.1.post0
|
108 |
+
lightning-utilities==0.14.3
|
109 |
+
linkify-it-py==2.0.3
|
110 |
+
llvmlite==0.43.0
|
111 |
+
local-attention==1.8.6
|
112 |
+
Markdown==3.8
|
113 |
+
markdown-it-py==3.0.0
|
114 |
+
markdown2==2.5.3
|
115 |
+
MarkupSafe==2.1.5
|
116 |
+
matplotlib==3.10.3
|
117 |
+
matplotlib-inline==0.1.7
|
118 |
+
mdit-py-plugins==0.4.2
|
119 |
+
mdurl==0.1.2
|
120 |
+
moviepy==1.0.3
|
121 |
+
mpmath==1.3.0
|
122 |
+
multidict==6.4.4
|
123 |
+
multiprocessing-logging==0.2.4
|
124 |
+
mutagen==1.47.0
|
125 |
+
narwhals==1.40.0
|
126 |
+
networkx==3.4.2
|
127 |
+
ninja==1.11.1.3
|
128 |
+
nitrous_ema==0.0.1
|
129 |
+
numba==0.60.0
|
130 |
+
numpy==1.23.5
|
131 |
+
omegaconf==2.3.0
|
132 |
+
open_clip_torch==2.32.0
|
133 |
+
openai==1.33.0
|
134 |
+
opencv-python==4.11.0.86
|
135 |
+
orjson==3.10.18
|
136 |
+
pafy==0.5.3.1
|
137 |
+
pandas==2.0.2
|
138 |
+
panel==1.7.0
|
139 |
+
param==2.2.0
|
140 |
+
parameterized==0.9.0
|
141 |
+
parso==0.8.4
|
142 |
+
pathtools==0.1.2
|
143 |
+
pedalboard==0.7.4
|
144 |
+
pexpect==4.9.0
|
145 |
+
pillow
|
146 |
+
platformdirs==4.3.8
|
147 |
+
plotly==6.1.1
|
148 |
+
pooch==1.8.2
|
149 |
+
prefigure==0.0.9
|
150 |
+
proglog==0.1.10
|
151 |
+
progressbar==2.5
|
152 |
+
prompt_toolkit==3.0.51
|
153 |
+
propcache==0.3.1
|
154 |
+
protobuf==3.19.6
|
155 |
+
psutil==7.0.0
|
156 |
+
ptyprocess==0.7.0
|
157 |
+
pure_eval==0.2.3
|
158 |
+
py-cpuinfo==9.0.0
|
159 |
+
pycparser==2.22
|
160 |
+
pydantic==2.11.5
|
161 |
+
pydantic_core==2.33.2
|
162 |
+
pydub==0.25.1
|
163 |
+
Pygments==2.19.1
|
164 |
+
pyloudnorm==0.1.1
|
165 |
+
pynndescent==0.5.13
|
166 |
+
pynvml==12.0.0
|
167 |
+
pyparsing==3.2.3
|
168 |
+
pystoi==0.4.1
|
169 |
+
pysubs2==1.8.0
|
170 |
+
python-dateutil==2.9.0.post0
|
171 |
+
python-dotenv==1.0.1
|
172 |
+
python-multipart==0.0.20
|
173 |
+
pytorch-lightning==2.5.1.post0
|
174 |
+
pytorchvideo==0.1.5
|
175 |
+
pytz==2025.2
|
176 |
+
pyviz_comms==3.0.4
|
177 |
+
PyWavelets==1.4.1
|
178 |
+
PyYAML==6.0.2
|
179 |
+
randomname==0.2.1
|
180 |
+
referencing==0.36.2
|
181 |
+
regex==2024.11.6
|
182 |
+
requests==2.32.3
|
183 |
+
resampy==0.4.3
|
184 |
+
rich==14.0.0
|
185 |
+
rpds-py==0.25.1
|
186 |
+
ruff==0.11.11
|
187 |
+
s3fs==2025.5.0
|
188 |
+
safehttpx==0.1.6
|
189 |
+
safetensors==0.5.3
|
190 |
+
scenedetect==0.6.3
|
191 |
+
scikit-image==0.24.0
|
192 |
+
scikit-learn==1.6.1
|
193 |
+
scipy==1.15.3
|
194 |
+
semantic-version==2.10.0
|
195 |
+
sentencepiece==0.1.99
|
196 |
+
sentry-sdk==2.29.1
|
197 |
+
setproctitle==1.3.6
|
198 |
+
shellingham==1.5.4
|
199 |
+
shortuuid==1.0.13
|
200 |
+
six==1.17.0
|
201 |
+
smmap==5.0.2
|
202 |
+
sniffio==1.3.1
|
203 |
+
SoundFile==0.10.2
|
204 |
+
sox==1.3.0
|
205 |
+
stack-data==0.6.3
|
206 |
+
starlette==0.46.2
|
207 |
+
submitit==1.5.2
|
208 |
+
svgwrite==1.4.3
|
209 |
+
sympy==1.13.1
|
210 |
+
tabulate==0.9.0
|
211 |
+
tensorboard-data-server==0.7.2
|
212 |
+
termcolor==3.1.0
|
213 |
+
threadpoolctl==3.6.0
|
214 |
+
tifffile==2025.5.10
|
215 |
+
timm==1.0.15
|
216 |
+
tokenizers==0.19
|
217 |
+
tomlkit==0.13.2
|
218 |
+
torch==2.4.0
|
219 |
+
torch-stoi==0.2.3
|
220 |
+
torchaudio==2.4.0
|
221 |
+
torchdiffeq==0.2.5
|
222 |
+
torchlibrosa==0.1.0
|
223 |
+
torchmetrics==0.11.4
|
224 |
+
torchsde==0.2.6
|
225 |
+
torchvision==0.19.0
|
226 |
+
tornado==6.5.1
|
227 |
+
git+https://github.com/patrick-kidger/torchcubicspline.git
|
228 |
+
tqdm==4.67.1
|
229 |
+
traitlets==5.14.3
|
230 |
+
trampoline==0.1.2
|
231 |
+
transformers==4.43
|
232 |
+
triton==3.0.0
|
233 |
+
typer==0.15.4
|
234 |
+
typing-inspection==0.4.1
|
235 |
+
typing_extensions==4.12.2
|
236 |
+
tzdata==2025.2
|
237 |
+
uc-micro-py==1.0.3
|
238 |
+
umap-learn==0.5.7
|
239 |
+
urllib3==2.4.0
|
240 |
+
uvicorn==0.34.2
|
241 |
+
v-diffusion-pytorch==0.0.2
|
242 |
+
vector-quantize-pytorch==1.9.14
|
243 |
+
wcwidth==0.2.13
|
244 |
+
webdataset==0.2.48
|
245 |
+
webencodings==0.5.1
|
246 |
+
Werkzeug==3.1.3
|
247 |
+
wget==3.2
|
248 |
+
wrapt==1.17.2
|
249 |
+
x-transformers==1.26.6
|
250 |
+
xyzservices==2025.4.0
|
251 |
+
yacs==0.1.8
|
252 |
+
yarl==1.20.0
|
253 |
+
zipp==3.21.0
|
254 |
+
altair==5.5.0
|
scripts/demo.sh
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# Check number of arguments
|
4 |
+
if [ "$#" -ne 3 ]; then
|
5 |
+
echo "Usage: $0 <video_path> <title> <description>"
|
6 |
+
exit 1
|
7 |
+
fi
|
8 |
+
|
9 |
+
VIDEO_PATH="$1"
|
10 |
+
TITLE="$2"
|
11 |
+
DESCRIPTION="$3"
|
12 |
+
|
13 |
+
# Generate unique ID
|
14 |
+
UNIQUE_ID=$(uuidgen | cut -c 1-8)
|
15 |
+
|
16 |
+
# Create necessary directories
|
17 |
+
mkdir -p videos cot_coarse results
|
18 |
+
|
19 |
+
# Get video filename and extension
|
20 |
+
VIDEO_FILE=$(basename "$VIDEO_PATH")
|
21 |
+
VIDEO_EXT="${VIDEO_FILE##*.}"
|
22 |
+
VIDEO_ID="${VIDEO_FILE%.*}"
|
23 |
+
TEMP_VIDEO_PATH="videos/${VIDEO_ID}_${UNIQUE_ID}.mp4"
|
24 |
+
|
25 |
+
# Convert video to MP4 format if needed
|
26 |
+
if [ "${VIDEO_EXT,,}" != "mp4" ]; then
|
27 |
+
echo "⏳ Converting video to MP4 format..."
|
28 |
+
ffmpeg -y -i "$VIDEO_PATH" -c:v libx264 -preset fast -c:a aac -strict experimental "$TEMP_VIDEO_PATH" >/dev/null 2>&1
|
29 |
+
if [ $? -ne 0 ]; then
|
30 |
+
echo "❌ Video conversion failed"
|
31 |
+
exit 2
|
32 |
+
fi
|
33 |
+
else
|
34 |
+
cp "$VIDEO_PATH" "$TEMP_VIDEO_PATH"
|
35 |
+
fi
|
36 |
+
|
37 |
+
# Calculate video duration
|
38 |
+
DURATION=$(ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 "$TEMP_VIDEO_PATH")
|
39 |
+
DURATION_SEC=${DURATION%.*}
|
40 |
+
echo "Duration is: $DURATION_SEC"
|
41 |
+
|
42 |
+
# Create cot.csv file
|
43 |
+
CAPTION_COT=$(echo "$DESCRIPTION" | tr '"' "'")
|
44 |
+
CSV_PATH="cot_coarse/cot.csv"
|
45 |
+
echo "id,caption,caption_cot" > "$CSV_PATH"
|
46 |
+
echo "${VIDEO_ID}_${UNIQUE_ID},$TITLE,\"$CAPTION_COT\"" >> "$CSV_PATH"
|
47 |
+
|
48 |
+
# Run feature extraction
|
49 |
+
echo "⏳ Extracting features..."
|
50 |
+
python extract_latents.py --duration_sec "$DURATION_SEC" 2>&1
|
51 |
+
if [ $? -ne 0 ]; then
|
52 |
+
echo "❌ Feature extraction failed"
|
53 |
+
rm -f "$TEMP_VIDEO_PATH"
|
54 |
+
exit 3
|
55 |
+
fi
|
56 |
+
|
57 |
+
# Run inference
|
58 |
+
echo "⏳ Running model inference..."
|
59 |
+
bash scripts/infer.sh --duration-sec "$DURATION_SEC" 2>&1
|
60 |
+
if [ $? -ne 0 ]; then
|
61 |
+
echo "❌ Inference failed"
|
62 |
+
rm -f "$TEMP_VIDEO_PATH"
|
63 |
+
exit 4
|
64 |
+
fi
|
65 |
+
|
66 |
+
# Get generated audio file
|
67 |
+
CURRENT_DATE=$(date +"%m%d")
|
68 |
+
AUDIO_PATH="results/${CURRENT_DATE}_batch_size1/demo.wav"
|
69 |
+
|
70 |
+
# Check if audio file exists
|
71 |
+
if [ ! -f "$AUDIO_PATH" ]; then
|
72 |
+
echo "❌ Generated audio file not found"
|
73 |
+
rm -f "$TEMP_VIDEO_PATH"
|
74 |
+
exit 5
|
75 |
+
fi
|
76 |
+
|
77 |
+
# Clean up temporary video file
|
78 |
+
rm -f "$TEMP_VIDEO_PATH"
|
79 |
+
|
80 |
+
|
81 |
+
echo "✅ Audio generated successfully!"
|
82 |
+
echo "Audio file path: $AUDIO_PATH"
|
scripts/infer.sh
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# 变量定义
|
4 |
+
ckpt_dir="ckpts/thinksound.ckpt"
|
5 |
+
test_batch_size=1
|
6 |
+
dataset_config="ThinkSound/configs/multimodal_dataset_demo.json"
|
7 |
+
model_config="ThinkSound/configs/model_configs/thinksound.json"
|
8 |
+
pretransform_ckpt_path="ckpts/vae.ckpt"
|
9 |
+
# 默认值
|
10 |
+
debug_mode="true"
|
11 |
+
node_rank=0
|
12 |
+
|
13 |
+
result_path="results"
|
14 |
+
|
15 |
+
while [[ $# -gt 0 ]]; do
|
16 |
+
case "$1" in
|
17 |
+
--duration-sec)
|
18 |
+
if [[ -n "$2" && "$2" != --* ]]; then
|
19 |
+
duration_sec="$2"
|
20 |
+
shift 2
|
21 |
+
else
|
22 |
+
echo "❌ Argument --duration-sec requires a value"
|
23 |
+
exit 1
|
24 |
+
fi
|
25 |
+
;;
|
26 |
+
--result-path)
|
27 |
+
if [[ -n "$2" && "$2" != --* ]]; then
|
28 |
+
result_path="$2"
|
29 |
+
shift 2
|
30 |
+
else
|
31 |
+
echo "❌ Argument --result-path requires a path"
|
32 |
+
exit 1
|
33 |
+
fi
|
34 |
+
;;
|
35 |
+
*)
|
36 |
+
echo "❌ Unknown argument: $1"
|
37 |
+
exit 1
|
38 |
+
;;
|
39 |
+
esac
|
40 |
+
done
|
41 |
+
|
42 |
+
export NODE_RANK=$node_rank
|
43 |
+
export RANK=$node_rank
|
44 |
+
|
45 |
+
num_gpus=1
|
46 |
+
num_nodes=1
|
47 |
+
|
48 |
+
export WORLD_SIZE=$((num_gpus * num_nodes))
|
49 |
+
# 打印配置信息
|
50 |
+
echo "Training Configuration:"
|
51 |
+
echo "Checkpoint Directory: $ckpt_dir"
|
52 |
+
echo "Dataset Config: $dataset_config"
|
53 |
+
echo "Model Config: $model_config"
|
54 |
+
echo "Pretransform Checkpoint Path: $pretransform_ckpt_path"
|
55 |
+
echo "Num GPUs: $num_gpus"
|
56 |
+
echo "Num Nodes: $num_nodes"
|
57 |
+
echo "Test Batch Size: $test_batch_size"
|
58 |
+
echo "Num Workers: 20"
|
59 |
+
echo "Node Rank: $node_rank"
|
60 |
+
echo "WORLD SIZE: $WORLD_SIZE"
|
61 |
+
|
62 |
+
|
63 |
+
python predict.py \
|
64 |
+
--dataset-config "$dataset_config" \
|
65 |
+
--model-config "$model_config" \
|
66 |
+
--ckpt-dir "$ckpt_dir" \
|
67 |
+
--pretransform-ckpt-path "$pretransform_ckpt_path" \
|
68 |
+
--checkpoint-every 2000 \
|
69 |
+
--num-gpus "$num_gpus" \
|
70 |
+
--num-nodes "$num_nodes" \
|
71 |
+
--batch-size 1 \
|
72 |
+
--test-batch-size $test_batch_size \
|
73 |
+
--num-workers 32 \
|
74 |
+
--duration-sec $duration_sec \
|
75 |
+
--results-dir $result_path \
|
76 |
+
|
setup.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from setuptools import setup, find_packages
|
2 |
+
|
3 |
+
setup(
|
4 |
+
name='thinksound',
|
5 |
+
version='0.0.16',
|
6 |
+
url='https://github.com/liuhuadai/thinksound.git',
|
7 |
+
author='liuhuadai',
|
8 |
+
description='a unified Any2Audio generation framework guided by Chain-of-Thought (CoT) reasoning',
|
9 |
+
packages=find_packages(),
|
10 |
+
install_requires=[
|
11 |
+
'aeiou==0.0.20',
|
12 |
+
'alias-free-torch==0.0.6',
|
13 |
+
'auraloss==0.4.0',
|
14 |
+
'descript-audio-codec==1.0.0',
|
15 |
+
'einops==0.7.0',
|
16 |
+
'einops-exts==0.0.4',
|
17 |
+
'ema-pytorch==0.2.3',
|
18 |
+
'encodec==0.1.1',
|
19 |
+
# 'gradio>=3.42.0',
|
20 |
+
'huggingface_hub',
|
21 |
+
'importlib-resources==5.12.0',
|
22 |
+
'k-diffusion==0.1.1',
|
23 |
+
'laion-clap==1.1.4',
|
24 |
+
'local-attention==1.8.6',
|
25 |
+
'pandas==2.0.2',
|
26 |
+
'pedalboard==0.7.4',
|
27 |
+
'prefigure==0.0.9',
|
28 |
+
'pytorch_lightning==2.1.0',
|
29 |
+
'PyWavelets==1.4.1',
|
30 |
+
'safetensors',
|
31 |
+
'sentencepiece==0.1.99',
|
32 |
+
's3fs',
|
33 |
+
'torch>=2.0.1',
|
34 |
+
'torchaudio>=2.0.2',
|
35 |
+
'torchmetrics==0.11.4',
|
36 |
+
'tqdm',
|
37 |
+
'transformers',
|
38 |
+
'v-diffusion-pytorch==0.0.2',
|
39 |
+
'vector-quantize-pytorch==1.9.14',
|
40 |
+
'wandb==0.15.4',
|
41 |
+
'webdataset==0.2.48',
|
42 |
+
'x-transformers<1.27.0'
|
43 |
+
],
|
44 |
+
)
|