lorocksUMD committed
Commit 099ac14 · verified · 1 Parent(s): 7ffd6dd

Upload 32 files

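This commit rewrites DenseAV's intra-package imports as absolute paths rooted at DenseAV.denseav (for example, "from constants import *" becomes "from DenseAV.denseav.constants import *"), so the modules resolve when they are imported from outside the package's own source directory. A minimal sketch of how the rewritten imports would be used; the checkout location /workspace is an assumption, not part of the commit:

import sys

# Assumption: the repository was cloned to /workspace/DenseAV, so its parent
# directory must be on sys.path for "DenseAV.denseav.*" to resolve.
sys.path.insert(0, "/workspace")

# Matches the import style introduced by this commit.
from DenseAV.denseav.shared import get_image_featurizer

# Per the call sites in train.py, the helper returns (model, patch_size, dim).
model, patch_size, dim = get_image_featurizer("dino16")
print(patch_size, dim)  # 16, 384 for the DINO ViT-S/16 backbone
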
DenseAV/denseav/aggregators.py CHANGED
@@ -6,7 +6,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from tqdm import tqdm
 
-from constants import *
+from DenseAV.denseav.constants import *
 
 
 @torch.jit.script
DenseAV/denseav/aligners.py CHANGED
@@ -4,7 +4,7 @@ import torch
 import torch.nn.functional as F
 from torch.nn import ModuleList
 
-from featurizers.DINO import Block
+from DenseAV.denseav.featurizers.DINO import Block
 
 
 class ChannelNorm(torch.nn.Module):
DenseAV/denseav/eval_utils.py CHANGED
@@ -9,7 +9,7 @@ from torchmetrics.functional.classification import binary_average_precision
 from tqdm import tqdm
 
 from constants import *
-from shared import unnorm, remove_axes
+from DenseAV.denseav.shared import unnorm, remove_axes
 
 
 def prep_heatmap(sims, masks, h, w):
DenseAV/denseav/evaluate.py CHANGED
@@ -4,8 +4,8 @@ from omegaconf import DictConfig, OmegaConf
 from pytorch_lightning import Trainer
 from pytorch_lightning import seed_everything
 from pytorch_lightning.loggers import TensorBoardLogger
-from data.AVDatasets import AVDataModule
-from shared import load_trained_model
+from DenseAV.denseav.data.AVDatasets import AVDataModule
+from DenseAV.denseav.shared import load_trained_model
 
 
 @hydra.main(config_path="configs", config_name="av_align.yaml")
DenseAV/denseav/plotting.py CHANGED
@@ -10,7 +10,7 @@ import torch.nn.functional as F
 import torchvision
 from moviepy.editor import VideoFileClip, AudioFileClip
 from base64 import b64encode
-from shared import pca
+from DenseAV.denseav.shared import pca
 
 
 def write_video_with_audio(video_frames, audio_array, video_fps, audio_fps, output_path):
DenseAV/denseav/shared.py CHANGED
@@ -90,37 +90,37 @@ def get_image_featurizer(name, token_type="key", **kwargs):
     name = name.lower()
 
     if name == "vit":
-        from featurizers.DINO import DINOFeaturizer
+        from DenseAV.denseav.featurizers.DINO import DINOFeaturizer
         patch_size = 16
         model = DINOFeaturizer("vit_small_patch16_224", patch_size, token_type)
         dim = 384
     elif name == "dino16":
-        from featurizers.DINO import DINOFeaturizer
+        from DenseAV.denseav.featurizers.DINO import DINOFeaturizer
         patch_size = 16
         model = DINOFeaturizer("dino_vits16", patch_size, token_type)
         dim = 384
     elif name == "dino8":
-        from featurizers.DINO import DINOFeaturizer
+        from DenseAV.denseav.featurizers.DINO import DINOFeaturizer
         patch_size = 8
         model = DINOFeaturizer("dino_vits8", patch_size, token_type)
         dim = 384
     elif name == "clip":
-        from featurizers.CLIP import CLIPFeaturizer
+        from DenseAV.denseav.featurizers.CLIP import CLIPFeaturizer
         patch_size = 16
         model = CLIPFeaturizer()
         dim = 512
     elif name == "cavmae":
-        from featurizers.CAVMAE import CAVMAEImageFeaturizer
+        from DenseAV.denseav.featurizers.CAVMAE import CAVMAEImageFeaturizer
         model = CAVMAEImageFeaturizer(kwargs["output_root"], model=kwargs.get("model"))
         dim = 768
         patch_size = 16
     elif name == "fnac":
-        from featurizers.FNACAVL import FNACImageFeaturizer
+        from DenseAV.denseav.featurizers.FNACAVL import FNACImageFeaturizer
         model = FNACImageFeaturizer(kwargs["output_root"], model=kwargs.get("model"))
         dim = 512
         patch_size = 16
     elif name == "imagebind":
-        from featurizers.ImageBind import ImageBindImageFeaturizer
+        from DenseAV.denseav.featurizers.ImageBind import ImageBindImageFeaturizer
         model = ImageBindImageFeaturizer(kwargs["output_root"], model=kwargs.get("model"))
         dim = 1024
         patch_size = 16
@@ -131,12 +131,12 @@ def get_image_featurizer(name, token_type="key", **kwargs):
         patch_size = 1
         dim = 2048
     elif name == "davenet":
-        from featurizers.DAVENet import DavenetImageFeaturizer
+        from DenseAV.denseav.featurizers.DAVENet import DavenetImageFeaturizer
         model = DavenetImageFeaturizer()
         patch_size = 1
         dim = 1024
     elif name == "dinov2":
-        from featurizers.DINOv2 import DINOv2Featurizer
+        from DenseAV.denseav.featurizers.DINOv2 import DINOv2Featurizer
         model = DINOv2Featurizer()
         patch_size = 14
         dim = 768
@@ -147,29 +147,29 @@ def get_image_featurizer(name, token_type="key", **kwargs):
 
 def get_audio_featurizer(name, **kwargs):
     if name == "davenet":
-        from featurizers.DAVENet import DavenetAudioFeaturizer
+        from DenseAV.denseav.featurizers.DAVENet import DavenetAudioFeaturizer
         model = DavenetAudioFeaturizer()
         dim = 1024
     elif name == "dino8":
         model, _, dim = get_image_featurizer("dino8")
     elif name == "hubert":
-        from featurizers.Hubert import Hubert
+        from DenseAV.denseav.featurizers.Hubert import Hubert
         model = Hubert()
         dim = 1024
     elif name == "cavmae":
-        from featurizers.CAVMAE import CAVMAEAudioFeaturizer
+        from DenseAV.denseav.featurizers.CAVMAE import CAVMAEAudioFeaturizer
         model = CAVMAEAudioFeaturizer(kwargs["output_root"], model=kwargs.get("model"))
         dim = 768
     elif name == "imagebind":
-        from featurizers.ImageBind import ImageBindAudioFeaturizer
+        from DenseAV.denseav.featurizers.ImageBind import ImageBindAudioFeaturizer
         model = ImageBindAudioFeaturizer(kwargs["output_root"], model=kwargs.get("model"))
         dim = 1024
     elif name == "audiomae":
-        from featurizers.AudioMAE import AudioMAE
+        from DenseAV.denseav.featurizers.AudioMAE import AudioMAE
         model = AudioMAE(kwargs["output_root"], False)
         dim = 768
     elif name == "audiomae-finetuned":
-        from featurizers.AudioMAE import AudioMAE
+        from DenseAV.denseav.featurizers.AudioMAE import AudioMAE
         model = AudioMAE(kwargs["output_root"], True)
         dim = 768
     else:
DenseAV/denseav/train.py CHANGED
@@ -1,1222 +1,1222 @@
1
- import os
2
- from collections import deque
3
- from itertools import combinations
4
- from os.path import join
5
-
6
- import hydra
7
- import numpy as np
8
- import pytorch_lightning as pl
9
- import torch
10
- import torch.distributed as dist
11
- import torch.nn.functional as F
12
- from omegaconf import DictConfig, OmegaConf
13
- from peft import get_peft_model, LoraConfig
14
- from pytorch_lightning import Trainer
15
- from pytorch_lightning import seed_everything
16
- from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
17
- from pytorch_lightning.loggers import TensorBoardLogger
18
- from pytorch_lightning.utilities import grad_norm
19
- from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, SequentialLR, LambdaLR
20
- from torchmetrics.functional.classification import binary_average_precision
21
-
22
- from huggingface_hub import PyTorchModelHubMixin
23
-
24
- from DenseAV.denseav.aggregators import get_aggregator
25
- from aligners import get_aligner, ProgressiveGrowing
26
- from constants import *
27
- from data.AVDatasets import AVDataModule
28
- from shared import flatten_preds, GatherLayer, \
29
- get_image_featurizer, get_audio_featurizer, RollingAvg, create_model_from_cfg
30
-
31
- torch.multiprocessing.set_sharing_strategy('file_system')
32
-
33
-
34
- def _imposter_indices_helper(true_indices: torch.Tensor, samples: torch.Tensor):
35
- mask = (true_indices == samples).to(torch.int64)
36
- n = mask.shape[0]
37
-
38
- if not mask.any():
39
- return samples
40
- else:
41
- new_samples = torch.randint(0, n, size=(n,), device=true_indices.device)
42
- comb_samples = mask * new_samples + (1 - mask) * samples
43
- return _imposter_indices_helper(true_indices, comb_samples)
44
-
45
-
46
- def imposter_indices(n, device):
47
- return _imposter_indices_helper(
48
- torch.arange(0, n, device=device),
49
- torch.randint(0, n, size=(n,), device=device))
50
-
51
-
52
- def get_sim_per_row(image_outputs, audio_outputs, n_frames, sim_type):
53
- max_t = audio_outputs.shape[-1]
54
- oh = F.one_hot(n_frames - 1, num_classes=max_t)
55
- audio_mask = 1 - torch.cumsum(oh, dim=1)
56
- audio_mask = F.pad(audio_mask, [1, 0], value=1)[:, :max_t].to(audio_outputs.dtype)
57
-
58
- full_sim = torch.einsum("bct,bchw->bthw", audio_outputs, image_outputs)
59
- expanded_am = audio_mask.unsqueeze(-1).unsqueeze(-1)
60
-
61
- if sim_type.endswith("mi"):
62
- offset = 10 * (full_sim.max() - full_sim.min())
63
- full_sim = (full_sim - ((1 - expanded_am) * offset)).max(1, keepdim=True).values
64
-
65
- if sim_type.startswith("mi"):
66
- full_sim = full_sim.max(-1, keepdim=True).values.max(-2, keepdim=True).values
67
-
68
- if sim_type.endswith("sa"):
69
- full_sim = (full_sim * (expanded_am / expanded_am.sum(1, keepdim=True).clamp_min(1))).sum(1, keepdim=True)
70
-
71
- return full_sim.mean(dim=[1, 2, 3])
72
-
73
-
74
- def sampled_margin_rank_loss(image_outputs, audio_outputs, n_frames, sim_type, margin=1.):
75
- """
76
- Computes the triplet margin ranking loss for each anchor image/caption pair
77
- The impostor image/caption is randomly sampled from the minibatch
78
- """
79
- assert (image_outputs.dim() == 4)
80
- assert (audio_outputs.dim() == 3)
81
- n = image_outputs.size(0)
82
- imp_ind_i = imposter_indices(n, image_outputs.device)
83
- imp_ind_a = imposter_indices(n, image_outputs.device)
84
- true_sim = get_sim_per_row(image_outputs, audio_outputs, n_frames, sim_type)
85
- imp_sim_i = get_sim_per_row(image_outputs[imp_ind_i], audio_outputs, n_frames, sim_type)
86
- imp_sim_a = get_sim_per_row(image_outputs, audio_outputs[imp_ind_a], n_frames[imp_ind_a], sim_type)
87
- a2i_loss = (margin + imp_sim_i - true_sim).clamp_min(0)
88
- i2a_loss = (margin + imp_sim_a - true_sim).clamp_min(0)
89
- return (a2i_loss + i2a_loss).mean() / 2
90
-
91
-
92
- class SimilarityCalibrator(torch.nn.Module):
93
-
94
- def __init__(self, cal_init, max_w=100, min_w=.01, subtract_mean=True, use_bias=False):
95
- super().__init__()
96
- self.max_w = max_w
97
- self.min_w = min_w
98
- self.w = torch.nn.Parameter(torch.tensor([cal_init]).log())
99
-
100
- self.use_bias = use_bias
101
- if self.use_bias:
102
- self.b = torch.nn.Parameter(torch.tensor([0.0]))
103
-
104
- self.subtract_mean = subtract_mean
105
-
106
- def get_w(self):
107
- return torch.exp(self.w).clamp_max(self.max_w).clamp_min(self.min_w)
108
-
109
- def forward(self, x):
110
- sims = self.get_w() * x
111
-
112
- if self.use_bias:
113
- sims = sims + self.b
114
-
115
- if self.subtract_mean:
116
- return sims - sims.mean()
117
- else:
118
- return sims
119
-
120
-
121
- class SpatialDropout(torch.nn.Module):
122
-
123
- def __init__(self, p, *args, **kwargs):
124
- super().__init__(*args, **kwargs)
125
- self.p = p
126
-
127
- def forward(self, x):
128
- b, c, h, w = x.shape
129
- dropout = torch.rand((b, 1, h, w), dtype=x.dtype, device=x.device) > self.p
130
-
131
- if self.training:
132
- return x * dropout
133
- else:
134
- return x
135
-
136
-
137
- class LitAVAligner(pl.LightningModule, PyTorchModelHubMixin, repo_url="https://github.com/mhamilton723/DenseAV", license="mit", tags=["denseav"]):
138
- def __init__(self,
139
- code_dim,
140
- image_model_type,
141
- image_model_token_type,
142
- image_aligner_type,
143
- image_pool_width,
144
- audio_model_type,
145
- audio_aligner_type,
146
- audio_pool_width,
147
- audio_lora,
148
- audio_lora_rank,
149
- image_lora,
150
- image_lora_rank,
151
- gradient_clipping,
152
- learn_audio_cls,
153
- silence_l1,
154
- silence_l2,
155
- tv_weight,
156
- nonneg_sim,
157
- nonneg_pressure,
158
- pretrain_lr,
159
- lr,
160
- lr_warmup,
161
- lr_schedule,
162
- lr_cycle_length,
163
- optimizer,
164
- gather_tensors,
165
- sim_agg_type,
166
- sim_agg_heads,
167
- sim_use_cls,
168
- disentangle_weight,
169
- norm_vectors,
170
- cal_init,
171
- cal_balance_weight,
172
- loss_type,
173
- loss_margin,
174
- mask_silence,
175
- finetune_image_model,
176
- finetune_audio_model,
177
- use_cached_embs,
178
- output_root,
179
- neg_audio,
180
- neg_audio_weight,
181
- head_agg,
182
- adaptive_clipping,
183
- specialization_weight,
184
- spatial_dropout,
185
- channel_dropout,
186
- mixup_weight,
187
- memory_buffer_size,
188
- loss_leak,
189
- ):
190
- super().__init__()
191
-
192
- self.code_dim = code_dim
193
- self.image_model_type = image_model_type
194
- self.image_model_token_type = image_model_token_type
195
- self.image_aligner_type = image_aligner_type
196
- self.image_pool_width = image_pool_width
197
- self.audio_model_type = audio_model_type
198
- self.audio_aligner_type = audio_aligner_type
199
- self.audio_pool_width = audio_pool_width
200
-
201
- self.gradient_clipping = gradient_clipping
202
- self.learn_audio_cls = learn_audio_cls
203
- self.silence_l1 = silence_l1
204
- self.silence_l2 = silence_l2
205
-
206
- self.tv_weight = tv_weight
207
- self.nonneg_sim = nonneg_sim
208
- self.nonneg_pressure = nonneg_pressure
209
- self.pretrain_lr = pretrain_lr
210
- self.lr = lr
211
- self.lr_warmup = lr_warmup
212
- self.lr_schedule = lr_schedule
213
- self.lr_cycle_length = lr_cycle_length
214
- self.optimizer = optimizer
215
- self.gather_tensors = gather_tensors
216
- self.sim_agg_type = sim_agg_type
217
- self.sim_agg_heads = sim_agg_heads
218
- self.sim_use_cls = sim_use_cls
219
- self.disentangle_weight = disentangle_weight
220
-
221
- self.norm_vectors = norm_vectors
222
- self.cal_init = cal_init
223
- self.cal_balance_weight = cal_balance_weight
224
- self.loss_type = loss_type
225
- self.loss_margin = loss_margin
226
- self.mask_silence = mask_silence
227
- self.finetune_image_model = finetune_image_model
228
- self.finetune_audio_model = finetune_audio_model
229
- self.use_cached_embs = use_cached_embs
230
- self.output_root = output_root
231
- self.audio_lora = audio_lora
232
- self.audio_lora_rank = audio_lora_rank
233
- self.image_lora = image_lora
234
- self.image_lora_rank = image_lora_rank
235
- self.neg_audio = neg_audio
236
- self.neg_audio_weight = neg_audio_weight
237
- self.head_agg = head_agg
238
-
239
- self.adaptive_clipping = adaptive_clipping
240
- self.specialization_weight = specialization_weight
241
- self.spatial_dropout = spatial_dropout
242
- self.channel_dropout = channel_dropout
243
- self.mixup_weight = mixup_weight
244
-
245
- self.memory_buffer_size = memory_buffer_size
246
- self.memory_buffer = deque(maxlen=self.memory_buffer_size)
247
- self.loss_leak = loss_leak
248
-
249
- self.full_train = False # Added by me
250
-
251
- if self.audio_model_type in {"audiomae", "audiomae-finetuned", "cavmae", "cavmae-mixed", "imagebind"}:
252
- self.audio_input = "spec"
253
- elif self.audio_model_type == "davenet":
254
- self.audio_input = "davenet_spec"
255
- elif self.audio_model_type == "fnac":
256
- self.audio_input = "fnac_spec"
257
- else:
258
- self.audio_input = "audio"
259
-
260
- extra_model_args = dict(output_root=output_root)
261
-
262
- self.image_model, _, self.image_feat_dim = get_image_featurizer(
263
- image_model_type, token_type=self.image_model_token_type, **extra_model_args)
264
-
265
- self.image_model.eval()
266
- if not self.finetune_image_model:
267
- for param in self.image_model.parameters():
268
- param.requires_grad = False
269
-
270
- if image_model_type in {"cavmae", "cavmae-mixed", "imagebind", "fnac"}:
271
- extra_model_args["model"] = self.image_model.model
272
-
273
- if use_cached_embs:
274
- _, self.audio_feat_dim = get_audio_featurizer(audio_model_type, **extra_model_args)
275
- else:
276
- self.audio_model, self.audio_feat_dim = get_audio_featurizer(audio_model_type, **extra_model_args)
277
-
278
- self.audio_model.eval()
279
- if not self.finetune_audio_model:
280
- for param in self.audio_model.parameters():
281
- param.requires_grad = False
282
-
283
- if self.image_lora:
284
- if self.image_model_type in {"sam", "dino8", "dinov2", "cavmae", "cavmae-mixed"}:
285
- target_modules = ["qkv"]
286
- elif self.image_model_type == "clip":
287
- target_modules = ["out_proj"]
288
- elif self.image_model_type == "imagebind":
289
- target_modules = ["out_proj", "fc1", "fc2"]
290
- else:
291
- target_modules = ["q", "k", "v"]
292
-
293
- peft_config = LoraConfig(
294
- target_modules=target_modules,
295
- inference_mode=False,
296
- r=image_lora_rank,
297
- lora_alpha=32,
298
- lora_dropout=0.1
299
- )
300
- self.image_model = get_peft_model(self.image_model, peft_config)
301
- self.image_model.print_trainable_parameters()
302
-
303
- if self.audio_lora:
304
- if self.audio_model_type == "hubert":
305
- target_modules = ["q_proj", "k_proj", "v_proj"]
306
- else:
307
- target_modules = ["q", "k", "v"]
308
-
309
- peft_config = LoraConfig(
310
- inference_mode=False,
311
- target_modules=target_modules,
312
- r=audio_lora_rank,
313
- lora_alpha=32,
314
- lora_dropout=0.1
315
- )
316
- self.audio_model = get_peft_model(self.audio_model, peft_config)
317
- self.audio_model.print_trainable_parameters()
318
-
319
- shared_aligner_args = dict(out_dim=self.code_dim)
320
-
321
- self.audio_aligner = get_aligner(
322
- self.audio_aligner_type, self.audio_feat_dim, **shared_aligner_args)
323
- self.image_aligner = get_aligner(
324
- self.image_aligner_type, self.image_feat_dim, **shared_aligner_args)
325
-
326
- if self.loss_type == "nce":
327
- self.sim_cal = SimilarityCalibrator(self.cal_init, subtract_mean=True, use_bias=False)
328
- else:
329
- self.sim_cal = SimilarityCalibrator(self.cal_init, subtract_mean=False, use_bias=True)
330
-
331
- if self.learn_audio_cls:
332
- self.audio_cls = torch.nn.Parameter(torch.randn(self.audio_feat_dim))
333
-
334
- if self.spatial_dropout > 0.0:
335
- self.spatial_dropout_layer = SpatialDropout(self.spatial_dropout)
336
-
337
- if self.channel_dropout > 0.0:
338
- self.channel_dropout_layer = torch.nn.Dropout2d(self.channel_dropout)
339
-
340
- self.sim_agg = get_aggregator(
341
- self.sim_agg_type,
342
- self.nonneg_sim,
343
- self.mask_silence,
344
- self.sim_agg_heads,
345
- self.head_agg,
346
- self.sim_use_cls,
347
- dim=self.image_feat_dim
348
- )
349
-
350
- self.hparams_logged = False
351
- self.rolling_avg = RollingAvg(50)
352
- self.grad_avg = RollingAvg(50, nonzero=True)
353
-
354
- self.save_hyperparameters()
355
-
356
- def set_full_train(self, full_train):
357
- self.full_train = full_train
358
-
359
- def prep_feats(self, feats, is_audio):
360
-
361
- if not is_audio and self.training and self.image_pool_width > 1:
362
- feats = torch.nn.AvgPool2d(self.image_pool_width)(feats)
363
-
364
- if is_audio and self.training and self.audio_pool_width > 1:
365
- feats = torch.nn.AvgPool2d((1, self.audio_pool_width))(feats)
366
-
367
- if self.norm_vectors:
368
- feats = F.normalize(feats, dim=1)
369
-
370
- return feats
371
-
372
- def on_before_optimizer_step(self, optimizer, optimizer_idx):
373
- norms = grad_norm(self, norm_type=2)
374
- avg_grads = self.grad_avg.get_all()
375
- params = {
376
- f"grad_2.0_norm/{name}": p
377
- for name, p in self.named_parameters()
378
- if p.grad is not None
379
- }
380
-
381
- if self.adaptive_clipping:
382
- for k in norms.keys():
383
- if k in params:
384
- avg_grad = max(avg_grads.get(k, norms[k]), 1e-5)
385
- if self.global_step > 10 and norms[k] > avg_grad * 5:
386
- print(f"Bad grad for {k}: {norms[k]} scaling to {avg_grad * 5}")
387
- torch.nn.utils.clip_grad_norm_(params[k], avg_grad * 5)
388
- norms[k] = avg_grad * 5
389
-
390
- if norms[k] > self.gradient_clipping:
391
- # print(f"Bad grad for {k}: {norms[k]} scaling to {self.gradient_clipping}")
392
- torch.nn.utils.clip_grad_norm_(params[k], self.gradient_clipping)
393
-
394
- # self.grad_avg.add_all(norms)
395
- # self.log_dict(norms)
396
-
397
- def interpolate_mask(self, mask, target_length, discrete):
398
- b, t = mask.shape
399
-
400
- mask = F.interpolate(mask.reshape(b, 1, 1, t), (1, target_length), mode="bilinear") \
401
- .reshape(b, target_length)
402
-
403
- if discrete:
404
- mask = mask > 0.01
405
- sums = mask.sum(1)
406
- all_zeros = torch.where(sums == 0)[0]
407
- if len(all_zeros) > 0:
408
- print("Fixing a bad mask")
409
- for entry in all_zeros:
410
- mask[entry, torch.randint(0, target_length - 1, size=())] = True
411
- else:
412
- return mask
413
- return mask
414
-
415
- def forward_audio(self, batch):
416
- if self.use_cached_embs:
417
- audio_feats = batch["audio_emb"]
418
- if "audio_cls" in batch:
419
- audio_cls = batch["audio_cls"]
420
- else:
421
- audio_cls = None
422
- else:
423
- audio = batch[self.audio_input]
424
-
425
- if self.full_train:
426
- audio_feats, audio_cls = self.audio_model(audio, include_cls=True)
427
- else:
428
- with torch.no_grad():
429
- audio_feats, audio_cls = self.audio_model(audio, include_cls=True)
430
-
431
- mask = batch[AUDIO_MASK] if AUDIO_MASK in batch else torch.ones_like(audio)
432
- pos_mask = batch[AUDIO_POS_MASK] if AUDIO_POS_MASK in batch else torch.ones_like(audio)
433
-
434
- if self.learn_audio_cls:
435
- assert audio_cls is None
436
- audio_cls = torch.broadcast_to(self.audio_cls.unsqueeze(0), (audio_feats.shape[0], audio_feats.shape[1]))
437
-
438
- aligned_audio_feats, aligned_audio_cls = self.audio_aligner(audio_feats, audio_cls)
439
-
440
- if self.channel_dropout > 0.0:
441
- aligned_audio_feats = self.channel_dropout_layer(aligned_audio_feats)
442
-
443
- aligned_audio_feats = self.prep_feats(aligned_audio_feats, is_audio=True)
444
- audio_mask = self.interpolate_mask(mask, aligned_audio_feats.shape[-1], True)
445
- audio_pos_mask = self.interpolate_mask(pos_mask, aligned_audio_feats.shape[-1], False)
446
-
447
- ret = {
448
- AUDIO_MASK: audio_mask,
449
- AUDIO_POS_MASK: audio_pos_mask,
450
- AUDIO_FEATS: aligned_audio_feats,
451
- }
452
-
453
- if aligned_audio_cls is not None:
454
- ret[AUDIO_CLS] = aligned_audio_cls
455
-
456
- return ret
457
-
458
- # @autocast(device_type="cuda", enabled=False)
459
- def forward_image(self, batch, max_batch_size=None):
460
-
461
- with torch.no_grad():
462
- image = batch[IMAGE_INPUT]
463
- b, nf, c, h, w = image.shape
464
- image = image.reshape(b * nf, c, h, w)
465
-
466
- if max_batch_size is None:
467
- max_batch_size = image.shape[0]
468
-
469
- chunks = [image[i:i + max_batch_size] for i in range(0, image.shape[0], max_batch_size)]
470
-
471
- all_image_feats = []
472
- all_image_cls = []
473
-
474
- for chunk in chunks:
475
- if self.full_train:
476
- image_feats, image_cls = self.image_model(chunk, include_cls=True)
477
- else:
478
- with torch.no_grad():
479
- image_feats, image_cls = self.image_model(chunk, include_cls=True)
480
-
481
- aligned_image_feats, aligned_image_cls = self.image_aligner(image_feats, image_cls)
482
-
483
- all_image_feats.append(aligned_image_feats)
484
- all_image_cls.append(aligned_image_cls)
485
-
486
- # Stitch the chunks back together
487
- aligned_image_feats = torch.cat(all_image_feats, dim=0)
488
- aligned_image_cls = torch.cat(all_image_cls, dim=0)
489
-
490
- if self.channel_dropout > 0.0:
491
- aligned_image_feats = self.channel_dropout_layer(aligned_image_feats)
492
-
493
- if self.spatial_dropout > 0.0:
494
- aligned_image_feats = self.spatial_dropout_layer(aligned_image_feats)
495
-
496
- aligned_image_feats = self.prep_feats(aligned_image_feats, is_audio=False)
497
- ret = {IMAGE_FEATS: aligned_image_feats}
498
-
499
- if IMAGE_MASK in batch:
500
- with torch.no_grad():
501
- mask = batch[IMAGE_MASK]
502
- mask = mask.reshape(b * nf, 1, h, w)
503
- b, c, h, w = aligned_image_feats.shape
504
- mask = F.adaptive_avg_pool2d(mask.to(aligned_image_feats), output_size=(h, w))
505
- ret[IMAGE_MASK] = mask
506
-
507
- if aligned_image_cls is not None:
508
- ret[IMAGE_CLS] = aligned_image_cls
509
-
510
- return ret
511
-
512
- def forward(self, batch):
513
- audio_feat_dict = self.forward_audio(batch)
514
- image_feat_dict = self.forward_image(batch)
515
- return {**image_feat_dict, **audio_feat_dict}
516
-
517
- def contrast_loss(self, sims):
518
- b = sims.shape[0]
519
- sims = sims - torch.eye(b, b, device=sims.device) * self.loss_margin
520
- sims_1 = sims
521
- sims_2 = sims.permute(1, 0)
522
-
523
- if self.loss_leak > 0.0:
524
- id = torch.eye(sims_1.shape[0], sims_1.shape[1], device=sims.device, dtype=sims.dtype)
525
- label_mask = id * (1 - self.loss_leak)
526
- label_mask += (1 - id) * self.loss_leak / (sims_1.shape[0] - 1)
527
- label_mask /= label_mask.sum(dim=1, keepdim=True)
528
- else:
529
- label_mask = torch.eye(sims_1.shape[0], sims_1.shape[1], device=sims.device, dtype=sims.dtype)
530
-
531
- labels = torch.arange(0, sims.shape[0], device=sims.device)
532
- self.rolling_avg.add(f"acc/1", (sims.argmax(dim=1) == labels).to(sims).mean())
533
- self.rolling_avg.add(f"acc/2", (sims.argmax(dim=0) == labels).to(sims).mean())
534
-
535
- if self.loss_type == "margin":
536
- margin_loss_tensor = (sims - torch.diag(sims)).clamp_min(0)
537
- margin_loss = margin_loss_tensor.mean()
538
- self.rolling_avg.add(f"loss/frac_nonzero", (margin_loss_tensor > 0).to(sims).mean())
539
- self.rolling_avg.add(f"loss/margin", margin_loss)
540
- return margin_loss
541
- elif self.loss_type == "ce":
542
- ce_loss = 1 / 2 * F.cross_entropy(sims_1, labels) + \
543
- 1 / 2 * F.cross_entropy(sims_2, labels)
544
- self.rolling_avg.add(f"loss/ce", ce_loss)
545
- return ce_loss
546
- elif self.loss_type == "bce":
547
- bce_loss = F.binary_cross_entropy_with_logits(sims_1.flatten(), label_mask.flatten())
548
- self.rolling_avg.add(f"loss/bce", bce_loss)
549
- return bce_loss
550
- elif self.loss_type == "nce":
551
- nce_loss = 1 / 2 * (-F.log_softmax(sims_1, dim=-1) * label_mask).sum(1).mean() + \
552
- 1 / 2 * (-F.log_softmax(sims_2, dim=-1) * label_mask).sum(1).mean()
553
- self.rolling_avg.add(f"loss/nce", nce_loss)
554
- return nce_loss
555
- else:
556
- raise ValueError(f"Unknown loss type {self.loss_type}")
557
-
558
- def loss(self, preds):
559
- image_feats = preds[IMAGE_FEATS]
560
- audio_feats = preds[AUDIO_FEATS]
561
- audio_mask = preds[AUDIO_MASK]
562
- image_mask = preds[IMAGE_MASK]
563
- audio_pos_mask = preds[AUDIO_POS_MASK]
564
- if DATA_SOURCE in preds:
565
- source = preds[DATA_SOURCE].to(torch.int64)
566
- else:
567
- source = None
568
-
569
- uncal_sims = self.sim_agg(preds, agg_heads=True)
570
- sims = self.sim_cal(uncal_sims)
571
-
572
- _mask = 1 - torch.eye(sims.shape[0], device=sims.device)
573
- self.log(f"sim/pos", torch.diag(sims).mean())
574
- self.log(f"sim/neg", (sims * _mask).sum() / (_mask.sum()))
575
- self.log(f"sim/uncal_pos", torch.diag(uncal_sims).mean())
576
- self.log(f"sim/uncal_neg", (uncal_sims * _mask).sum() / (_mask.sum()))
577
-
578
- b, c, h, w = image_feats.shape
579
- b, c, f, t = audio_feats.shape
580
- n_samples = 250
581
-
582
- nh = self.sim_agg_heads
583
- image_feats_by_head = image_feats.reshape(b, self.sim_agg_heads, c // nh, h, w)
584
- audio_feats_by_head = audio_feats.reshape(b, self.sim_agg_heads, c // nh, f, t)
585
-
586
- def maybe_clamp(t):
587
- return t.clamp_min(0) if self.nonneg_sim else t
588
-
589
- paired_sim_raw = self.sim_agg.get_pairwise_sims(preds, raw=True, agg_sim=False, agg_heads=False)
590
- paired_sim = maybe_clamp(paired_sim_raw)
591
-
592
- loss = 0.0
593
-
594
- if self.nonneg_pressure:
595
- afb, afk, afc, aff, aft = audio_feats_by_head.shape
596
- ifb, ifk, ifc, ifh, ifw = image_feats_by_head.shape
597
- assert (afb == ifb)
598
-
599
- device = audio_feats_by_head.device
600
- random_b = torch.randint(0, afb, size=(n_samples,), device=device)
601
- random_t = torch.randint(0, aft, size=(n_samples,), device=device)
602
- random_f = torch.randint(0, aff, size=(n_samples,), device=device)
603
- random_h = torch.randint(0, ifh, size=(n_samples,), device=device)
604
- random_w = torch.randint(0, ifw, size=(n_samples,), device=device)
605
-
606
- random_audio_feats = audio_feats_by_head[random_b, :, :, random_f, random_t]
607
- random_image_feats = image_feats_by_head[random_b, :, :, random_h, random_w]
608
- random_sim_raw = torch.einsum("bkc,dkc->bdk", random_audio_feats, random_image_feats)
609
-
610
- nonneg_loss = random_sim_raw.clamp_max(0).square().mean()
611
- self.rolling_avg.add(f"loss/nonneg", nonneg_loss)
612
- loss += nonneg_loss * self.nonneg_pressure
613
-
614
- if self.silence_l1 > 0 or self.silence_l2 > 0:
615
- masked_b, masked_t = torch.where(~audio_mask)
616
- if len(masked_b) > n_samples:
617
- subset = torch.randperm(len(masked_b))[:n_samples]
618
- masked_b = masked_b[subset]
619
- masked_t = masked_t[subset]
620
-
621
- if len(masked_b) == n_samples:
622
- silent_audio_feats = audio_feats_by_head[masked_b, :, :, :, masked_t].mean(-1) # d k c
623
- silence_tensor = maybe_clamp(
624
- torch.einsum("bkchw,dkc->bkdhw", image_feats_by_head, silent_audio_feats))
625
-
626
- silence_l1_loss = silence_tensor.abs().mean()
627
- self.rolling_avg.add(f"loss/silence_l1", silence_l1_loss)
628
- loss += silence_l1_loss * self.silence_l1
629
-
630
- silence_l2_loss = silence_tensor.square().mean()
631
- self.rolling_avg.add(f"loss/silence_l2", silence_l2_loss)
632
- loss += silence_l2_loss * self.silence_l2
633
- else:
634
- pass
635
-
636
- if self.neg_audio_weight > 0 and self.neg_audio:
637
- b, t = audio_pos_mask.shape
638
- negative_weight = ((1 - audio_pos_mask) * audio_mask.to(sims)).reshape(b, 1, 1, 1, 1, t)
639
- negative_weight = torch.broadcast_to(negative_weight, paired_sim.shape)
640
- if negative_weight.sum() > 0:
641
- neg_audio_loss = (paired_sim.square() * negative_weight).sum() \
642
- / negative_weight.sum().clamp_min(0.1)
643
- self.rolling_avg.add(f"loss/neg_audio", neg_audio_loss)
644
- self.rolling_avg.add(f"loss/neg_weight_avg", negative_weight.mean())
645
- loss += neg_audio_loss * self.neg_audio_weight
646
- else:
647
- print("WARNING: No negative samples found in batch")
648
-
649
- if self.tv_weight > 0:
650
- tv_loss = (paired_sim[:, :, :, :, :, 1:] - paired_sim[:, :, :, :, :, :-1]).square().mean()
651
- self.rolling_avg.add(f"loss/tv", tv_loss)
652
- loss += tv_loss * self.tv_weight
653
-
654
- self.log(f"cal/w", self.sim_cal.get_w())
655
- if self.cal_balance_weight > 0.0:
656
- cal_balance = (np.log(self.cal_init) - torch.log(self.sim_cal.get_w().clamp_min(.00000001))) \
657
- .clamp_min(0).square().mean()
658
- self.rolling_avg.add(f"loss/cal_balance", cal_balance)
659
- loss += cal_balance * self.cal_balance_weight
660
-
661
- if self.disentangle_weight > 0.0:
662
- assert source is not None
663
- assert self.sim_agg_heads % 2 == 0
664
-
665
- dilation = self.sim_agg_heads // 2
666
- sources_oh = F.one_hot(source, num_classes=2)
667
- b, h = sources_oh.shape
668
- sources_mask = 1 - torch.broadcast_to(sources_oh.unsqueeze(-1), (b, h, dilation)) \
669
- .reshape(b, h * dilation).to(paired_sim)
670
- disentangle_loss = torch.einsum("bkhwft,bk->bhwft", paired_sim, sources_mask).square().mean()
671
- self.rolling_avg.add(f"loss/disentangle", disentangle_loss)
672
- loss += disentangle_loss * self.disentangle_weight
673
-
674
- if self.specialization_weight > 0.0 and self.sim_agg_heads > 1:
675
- total_specialization_loss = 0.0
676
- combos = list(combinations(range(self.sim_agg_heads), 2))
677
- for i, j in combos:
678
- specialization_loss_pair = (paired_sim[:, i].abs() * paired_sim[:, j].abs()).mean()
679
- total_specialization_loss += specialization_loss_pair
680
- avg_specialization_loss = total_specialization_loss / len(combos)
681
- self.rolling_avg.add(f"loss/specialize", avg_specialization_loss)
682
- loss += avg_specialization_loss * self.specialization_weight
683
-
684
- if self.mixup_weight > 0.0:
685
- b, _, h, w = image_mask.shape
686
- neg_img_mask = torch.broadcast_to(
687
- 1 - image_mask.to(paired_sim).reshape(b, 1, h, w, 1, 1),
688
- paired_sim.shape)
689
- image_mixup_loss = (paired_sim * neg_img_mask).square().sum() / neg_img_mask.sum().clamp_min(0.1)
690
- self.rolling_avg.add(f"loss/image_mixup", image_mixup_loss)
691
- loss += image_mixup_loss * self.mixup_weight
692
-
693
- sims = sims
694
- loss += self.contrast_loss(sims)
695
- self.rolling_avg.add(f"loss/total", loss)
696
-
697
- return loss
698
-
699
- def setup_hparams(self):
700
- recalls = ['A_r1', 'A_r5', 'A_r10', 'I_r1', 'I_r5', 'I_r10']
701
-
702
- if self.trainer.datamodule.use_extra_val_sets:
703
- datasets = ["Places", "AudioSet"]
704
- else:
705
- datasets = ["Val"]
706
-
707
- heads = ["total"]
708
-
709
- metric_names = [
710
- "hp/speech_basic_ap", "hp/speech_advanced_ap", "hp/sound_basic_ap",
711
- "hp/speech_basic_iou", "hp/speech_advanced_iou", "hp/sound_basic_iou",
712
- ]
713
- for dataset in datasets:
714
- for head in heads:
715
- for recall in recalls:
716
- metric_names.append(f"hp/{dataset}/{head}/{recall}")
717
-
718
- if self.sim_agg_heads == 2:
719
- metric_names.extend(["hp/ap_dis", "hp/act_dis"])
720
-
721
- if hasattr(self.trainer, "datamodule"):
722
- all_hparams = {**self.hparams, **self.trainer.datamodule.hparams}
723
- else:
724
- all_hparams = self.hparams
725
-
726
- starting_values = {n: torch.nan for n in metric_names}
727
- self.logger.log_hyperparams(all_hparams, starting_values)
728
-
729
- def on_train_start(self):
730
- self.setup_hparams()
731
- self.hparams_logged = True
732
-
733
- def on_train_batch_start(self, batch, batch_idx):
734
- remake_optimizers = False
735
-
736
- if isinstance(self.image_aligner, ProgressiveGrowing):
737
- should_remake = self.image_aligner.maybe_change_phase(self.global_step)
738
- remake_optimizers = remake_optimizers or should_remake
739
- if isinstance(self.audio_aligner, ProgressiveGrowing):
740
- should_remake = self.audio_aligner.maybe_change_phase(self.global_step)
741
- remake_optimizers = remake_optimizers or should_remake
742
-
743
- if remake_optimizers:
744
- raise NotImplementedError()
745
-
746
- def _combine_preds(self, all_preds):
747
- temp = {}
748
- new_preds = {}
749
-
750
- # Collect tensors for each key into lists
751
- for d in all_preds:
752
- for key, value in d.items():
753
- if isinstance(value, torch.Tensor):
754
- if key not in temp:
755
- temp[key] = []
756
- temp[key].append(value)
757
-
758
- # Concatenate all tensors for each key using a single call to torch.cat
759
- for key, tensor_list in temp.items():
760
- new_preds[key] = torch.cat(tensor_list)
761
- return new_preds
762
-
763
- def training_step(self, batch, batch_idx):
764
- assert batch[IMAGE_INPUT].shape[1] == 1
765
-
766
- preds = self.forward(batch)
767
- if DATA_SOURCE in batch:
768
- preds[DATA_SOURCE] = batch[DATA_SOURCE]
769
-
770
- if self.trainer.world_size > 1 and self.gather_tensors:
771
- for k, v in preds.items():
772
- new_v = v.contiguous()
773
- preds[k] = torch.cat(GatherLayer.apply(new_v), dim=0)
774
-
775
- if self.memory_buffer_size > 0:
776
- new_preds = self._combine_preds(list(self.memory_buffer) + [preds])
777
- else:
778
- new_preds = preds
779
-
780
- loss = self.loss(new_preds)
781
-
782
- if self.memory_buffer_size > 0:
783
- self.memory_buffer.append(self._recursive_detach(preds, gather=False))
784
-
785
- if self.trainer.is_global_zero and self.global_step % 50 == 1:
786
- writer = self.logger.experiment
787
- self.rolling_avg.logall(lambda k, v: writer.add_scalar(k, v, global_step=self.global_step))
788
-
789
- if self.trainer.scaler is not None:
790
- self.log("loss_scale", self.trainer.scaler.get_scale())
791
-
792
- if self.global_step % 10000 == 0 and self.global_step > 0:
793
- print("RESETTING TFEVENT FILE")
794
- self.logger.experiment.close()
795
- self.logger.experiment._get_file_writer()
796
-
797
- return loss
798
-
799
- def on_validation_start(self) -> None:
800
- if not self.hparams_logged:
801
- self.setup_hparams()
802
- self.hparams_logged = True
803
-
804
- def _auto_gather(self, t):
805
- if t.dtype == torch.bool:
806
- t = t.to(torch.float)
807
-
808
- if self.trainer.num_devices == 1:
809
- return t.cpu()
810
-
811
- t = torch.clone(t).contiguous()
812
- if self.trainer.is_global_zero:
813
- gather_list = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
814
- dist.gather(t, gather_list)
815
- return torch.cat(gather_list, dim=0).cpu()
816
- else:
817
- dist.gather(t)
818
-
819
- def validation_step(self, batch, batch_idx, dataloader_idx=0):
820
-
821
- with torch.no_grad():
822
- preds = self.forward(batch)
823
-
824
- ret = {}
825
- for k in preds.keys():
826
- if k in preds:
827
- ret[k] = self._auto_gather(preds[k])
828
-
829
- batch_keys = [IMAGE_INPUT, "spec", "semseg", "num_pixels_per_class", 'total_length']
830
- for k in batch_keys:
831
- if k in batch:
832
- ret[k] = self._auto_gather(batch[k])
833
-
834
- if "metadata" in batch:
835
- if isinstance(batch["metadata"]["id"], torch.Tensor):
836
- ret["id"] = self._auto_gather(batch["metadata"]["id"])
837
- ret["index"] = self._auto_gather(batch["metadata"]["index"])
838
-
839
- return ret
840
-
841
- def _calc_recalls(self, sim):
842
- top_10_a = sim.topk(10, 0).indices == torch.arange(sim.shape[0]).unsqueeze(0)
843
- top_10_i = (sim.topk(10, 1).indices == torch.arange(sim.shape[0]).unsqueeze(1)).permute(1, 0)
844
- a_recall = lambda p: top_10_a[0:p].any(0).to(sim).mean()
845
- i_recall = lambda p: top_10_i[0:p].any(0).to(sim).mean()
846
- return {'A_r1': a_recall(1),
847
- 'A_r5': a_recall(5),
848
- 'A_r10': a_recall(10),
849
- 'I_r1': i_recall(1),
850
- 'I_r5': i_recall(5),
851
- 'I_r10': i_recall(10)}
852
-
853
- def calc_recalls(self, preds, dataset):
854
- sim = self.sim_agg.forward_batched(
855
- preds=preds,
856
- agg_heads=False,
857
- batch_size=4,
858
- ).cpu()
859
-
860
- all_metrics = dict()
861
- for k, v in self._calc_recalls(sim.sum(-1)).items():
862
- all_metrics[f"hp/{dataset}/total/" + k] = v
863
-
864
- return all_metrics
865
-
866
- def retrieval_validation(self, outputs, dataset_name):
867
- if len(outputs) == 0:
868
- return
869
-
870
- if self.trainer.is_global_zero:
871
- results = flatten_preds(outputs)
872
- if not self.trainer.sanity_checking:
873
- print(results[IMAGE_FEATS].shape[0])
874
- # assert (results[IMAGE_FEATS].shape[0] == 1000)
875
- results[IMAGE_FEATS] = results[IMAGE_FEATS].cpu()
876
- results[AUDIO_FEATS] = results[AUDIO_FEATS].cuda()
877
- if self.sim_use_cls:
878
- results[AUDIO_CLS] = results[AUDIO_CLS].cuda()
879
- results[AUDIO_CLS] = results[AUDIO_CLS].cuda()
880
-
881
- results[AUDIO_MASK] = results[AUDIO_MASK].cuda()
882
-
883
- recalls = self.calc_recalls(results, dataset_name)
884
-
885
- results[IMAGE_FEATS] = results[IMAGE_FEATS].cuda()
886
-
887
- writer = self.logger.experiment
888
- print("here")
889
- for name, v in recalls.items():
890
- writer.add_scalar(f"{name}", v, self.global_step + 1)
891
-
892
- def semseg_validation(self, speech_preds, sound_preds):
893
-
894
- if self.trainer.is_global_zero:
895
- from eval_utils import get_paired_heatmaps
896
- def prep_preds(preds, loader):
897
- results = flatten_preds(preds)
898
- metadata = loader.dataset.metadata
899
- ordered_metadata = metadata.iloc[results["index"].numpy(), :].copy()
900
- ordered_metadata["order"] = range(len(ordered_metadata))
901
- return results, ordered_metadata
902
-
903
- [_, _, speech_loader, sound_loader] = self.trainer.val_dataloaders
904
- speech_results, speech_metadata = prep_preds(speech_preds, speech_loader)
905
- sound_results, sound_metadata = prep_preds(sound_preds, sound_loader)
906
-
907
- self.sound_metrics, unique_sound_indices = get_paired_heatmaps(
908
- self, sound_results, sound_metadata["ade_class_id"], None)
909
-
910
- self.speech_metrics, unique_word_indices = get_paired_heatmaps(
911
- self, speech_results, speech_metadata["ade_class_id"], speech_metadata["timing"])
912
-
913
- writer = self.logger.experiment
914
-
915
- all_metrics = {
916
- **{"sound_" + k: v for k, v in self.sound_metrics.items()},
917
- **{"speech_" + k: v for k, v in self.speech_metrics.items()},
918
- }
919
-
920
- for k, v in all_metrics.items():
921
- writer.add_scalar(f"hp/{k}", torch.tensor(v).mean(), self.global_step + 1)
922
-
923
- def disentangle_validation(self, word_preds, sound_preds):
924
-
925
- if len(word_preds) == 0 or len(sound_preds) == 0:
926
- return
927
-
928
- if self.trainer.is_global_zero:
929
- word_preds = flatten_preds(word_preds)
930
- sound_preds = flatten_preds(sound_preds)
931
-
932
- word_scores = self.sim_agg.get_pairwise_sims(
933
- word_preds,
934
- raw=False,
935
- agg_sim=True,
936
- agg_heads=False,
937
- )
938
-
939
- sound_scores = self.sim_agg.get_pairwise_sims(
940
- sound_preds,
941
- raw=False,
942
- agg_sim=True,
943
- agg_heads=False,
944
- )
945
-
946
- all_scores = torch.cat([word_scores, sound_scores], dim=0)
947
- all_scores -= all_scores.min(dim=0, keepdim=True).values
948
- all_scores /= all_scores.max(dim=0, keepdim=True).values.clamp_min(.0001)
949
-
950
- is_words = torch.cat([
951
- torch.ones(word_scores.shape[0]),
952
- torch.zeros(sound_scores.shape[0])], dim=0).to(torch.bool)
953
-
954
- assert all_scores.shape[1] == 2
955
- ap_matrix = torch.zeros(2, 2)
956
- act_matrix = torch.zeros(2, 2)
957
-
958
- for head in range(2):
959
- # writer.add_histogram(f"h{head}_all_scores", all_scores[:, head])
960
- for dataset_num in range(2):
961
- if dataset_num == 0:
962
- labels = is_words
963
- else:
964
- labels = ~is_words
965
-
966
- ap_matrix[head, dataset_num] = binary_average_precision(
967
- all_scores[:, head].cpu(), labels.to(torch.int64).cpu())
968
-
969
- act_matrix[head, dataset_num] = 1 - (all_scores[:, head][labels]).mean()
970
-
971
- ap_dis = max(.5 * (ap_matrix[0, 0] + ap_matrix[1, 1]),
972
- .5 * (ap_matrix[0, 1] + ap_matrix[1, 0]))
973
-
974
- act_dis = max(.5 * (act_matrix[0, 0] + act_matrix[1, 1]),
975
- .5 * (act_matrix[0, 1] + act_matrix[1, 0]))
976
-
977
- print("AP", ap_matrix)
978
- print("AP dis", ap_dis)
979
- print("Act", act_matrix)
980
- print("Act dis", act_dis)
981
-
982
- writer = self.logger.experiment
983
- writer.add_scalar("hp/ap_dis", ap_dis, self.global_step + 1)
984
- writer.add_scalar("hp/act_dis", act_dis, self.global_step + 1)
985
-
986
- def validation_epoch_end(self, outputs) -> None:
987
- print("Val end")
988
- with torch.no_grad():
989
- if self.trainer.datamodule.use_extra_val_sets:
990
- if self.sim_agg_heads == 2:
991
- self.disentangle_validation(outputs[0], outputs[1])
992
- self.retrieval_validation(outputs[0], "Places")
993
- self.retrieval_validation(outputs[1], "AudioSet")
994
- self.semseg_validation(outputs[2], outputs[3])
995
-
996
- else:
997
- print("HERE!")
998
- self.retrieval_validation(outputs, "Val")
999
-
1000
- writer = self.logger.experiment
1001
- writer.flush()
1002
-
1003
- def _recursive_detach(self, obj, gather=True):
1004
- if isinstance(obj, torch.Tensor):
1005
- if gather:
1006
- return self._auto_gather(obj)
1007
- else:
1008
- obj.detach()
1009
- elif isinstance(obj, dict):
1010
- return {k: self._recursive_detach(v, gather) for k, v in obj.items()}
1011
- elif isinstance(obj, list):
1012
- return [self._recursive_detach(v, gather) for v in obj]
1013
- else:
1014
- return obj
1015
-
1016
- def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
1017
- with torch.no_grad():
1018
- predictions = {}
1019
- for k, v in batch.items():
1020
- predictions[k] = self._recursive_detach(v)
1021
- for k, v in self.forward(batch).items():
1022
- predictions[k] = self._auto_gather(v)
1023
-
1024
- return predictions
1025
-
1026
- def _configure_optimizers(self, full_train, lr):
1027
- params = [
1028
- *self.audio_aligner.parameters(),
1029
- *self.image_aligner.parameters(),
1030
- *self.sim_cal.parameters(),
1031
- *self.sim_agg.parameters()
1032
- ]
1033
-
1034
- if (self.finetune_image_model or self.image_lora) and full_train:
1035
- params.extend(self.image_model.parameters())
1036
-
1037
- if (self.finetune_audio_model or self.audio_lora) and full_train:
1038
- params.extend(self.audio_model.parameters())
1039
-
1040
- if self.learn_audio_cls:
1041
- params.append(self.audio_cls)
1042
-
1043
- last_epoch = self.global_step - 1
1044
- if self.optimizer == "adam":
1045
- opt = torch.optim.Adam(params, lr=lr, eps=1e-7)
1046
- elif self.optimizer == "nadam":
1047
- opt = torch.optim.NAdam(params, lr=lr, eps=1e-7)
1048
- else:
1049
- raise ValueError(f"Unknown optimizer {self.optimizer}")
1050
-
1051
- if self.lr_schedule == "sgdr":
1052
- scheduler = CosineAnnealingWarmRestarts(
1053
- opt, self.lr_cycle_length, 2, eta_min=lr * 2e-2, last_epoch=last_epoch)
1054
- else:
1055
- scheduler = LambdaLR(opt, lr_lambda=lambda step: 1.0, last_epoch=last_epoch)
1056
-
1057
- if self.lr_warmup > 0:
1058
- warmup = LambdaLR(
1059
- opt,
1060
- lr_lambda=lambda step: min(max(float(step), 0.0) / self.lr_warmup, 1.0),
1061
- last_epoch=last_epoch,
1062
- )
1063
- scheduler = SequentialLR(
1064
- opt,
1065
- schedulers=[warmup, scheduler],
1066
- milestones=[self.lr_warmup],
1067
- last_epoch=last_epoch)
1068
-
1069
- scheduler = {"scheduler": scheduler, "interval": "step"}
1070
-
1071
- return [opt], [scheduler]
1072
-
1073
- def configure_optimizers(self):
1074
- if self.full_train:
1075
- return self._configure_optimizers(self.full_train, self.lr)
1076
- else:
1077
- return self._configure_optimizers(self.full_train, self.pretrain_lr)
1078
-
1079
-
1080
- @hydra.main(config_path="configs", config_name="av_align.yaml", version_base=None)
1081
- def my_app(cfg: DictConfig) -> None:
1082
- print(OmegaConf.to_yaml(cfg))
1083
- seed_everything(cfg.seed, workers=True)
1084
-
1085
- exp_name = f"{cfg.resume_prefix}"
1086
-
1087
- if cfg.image_model_type == "dino8":
1088
- patch_size = 8 * cfg.image_pool_width
1089
- elif cfg.image_model_type == "cavmae":
1090
- patch_size = 16 * cfg.image_pool_width
1091
- elif cfg.image_model_type == "imagebind":
1092
- patch_size = 16 * cfg.image_pool_width
1093
- elif cfg.image_model_type == "clip":
1094
- patch_size = 16 * cfg.image_pool_width
1095
- elif cfg.image_model_type == "cavmae-mixed":
1096
- patch_size = 16 * cfg.image_pool_width
1097
- elif cfg.image_model_type == "dinov2":
1098
- patch_size = 14 * cfg.image_pool_width
1099
- else:
1100
- raise ValueError(f"Unknown patch size for model {cfg.image_model_type}")
1101
-
1102
- datamodule = AVDataModule(
1103
- dataset_name=cfg.dataset_name,
1104
- load_size=cfg.load_size,
1105
- image_aug=cfg.image_aug,
1106
- audio_aug=cfg.audio_aug,
1107
- extra_audio_masking=cfg.extra_audio_masking,
1108
- audio_model_type=cfg.audio_model_type,
1109
- pytorch_data_dir=cfg.pytorch_data_dir,
1110
- use_cached_embs=cfg.use_cached_embs,
1111
- batch_size=cfg.batch_size,
1112
- num_workers=cfg.num_workers,
1113
- audio_level=cfg.audio_level,
1114
- neg_audio=cfg.neg_audio,
1115
- use_original_val_set=not cfg.use_extra_val_sets,
1116
- use_extra_val_sets=cfg.use_extra_val_sets,
1117
- data_for_plotting=False,
1118
- quad_mixup=cfg.quad_mixup,
1119
- bg_mixup=cfg.bg_mixup,
1120
- patch_mixup=cfg.patch_mixup,
1121
- patch_size=patch_size
1122
- )
1123
- datamodule.maybe_unpack(remove_source=cfg.submitting_to_aml)
1124
-
1125
- aligner = create_model_from_cfg(LitAVAligner, cfg, {})
1126
-
1127
- if cfg.starting_weights is not None:
1128
- loaded = torch.load(join(cfg.output_root, cfg.starting_weights), map_location='cpu')
1129
- state = loaded["state_dict"]
1130
- aligner.load_state_dict(state, strict=cfg.load_strict)
1131
- del state
1132
- del loaded
1133
-
1134
- if cfg.num_gpus > 1:
1135
- # strategy = "ddp_sharded" # _find_unused_parameters_true"
1136
- strategy = "ddp" # _find_unused_parameters_true"
1137
- else:
1138
- strategy = "auto"
1139
-
1140
- if cfg.dataset_name in {"places-audio", "mixed", "audio-set", "mixed-full"}:
1141
- val_args = dict(check_val_every_n_epoch=2)
1142
- elif cfg.dataset_name in {"dolphin"}:
1143
- val_args = dict(check_val_every_n_epoch=5)
1144
- else:
1145
- val_args = dict(val_check_interval=10000)
1146
-
1147
- # val_args = dict(val_check_interval=1000)
1148
-
1149
- def maybe_get_ckpt(ckpt_dir):
1150
- if cfg.auto_resume and os.path.exists(ckpt_dir):
1151
- print(f"Attempting to resume from {ckpt_dir}")
1152
- candidates = os.listdir(ckpt_dir)
1153
- assert (len(candidates) == 1)
1154
- return join(ckpt_dir, candidates[0])
1155
- elif cfg.auto_resume:
1156
- print(f"Could not find checkpoint at {ckpt_dir}")
1157
- return None
1158
- else:
1159
- return None
1160
-
1161
- log_dir = join(cfg.output_root, "logs", cfg.grouping_name, exp_name)
1162
- ckpt_dir = join(cfg.output_root, "checkpoints", cfg.grouping_name, exp_name)
1163
-
1164
- import gc
1165
- torch.cuda.empty_cache()
1166
- gc.collect()
1167
-
1168
- def run_exp(aligner, full_train):
1169
- trainer_args = dict(
1170
- accelerator='gpu',
1171
- strategy=strategy,
1172
- devices=cfg.num_gpus,
1173
- num_sanity_val_steps=cfg.num_sanity_val_steps,
1174
- log_every_n_steps=50,
1175
- reload_dataloaders_every_n_epochs=10,
1176
- precision="16",
1177
- # profiler="simple",
1178
- # precision="bf16",
1179
- max_steps=cfg.max_steps,
1180
- **val_args)
1181
-
1182
- aligner.set_full_train(full_train)
1183
- if full_train:
1184
- suffix = "train"
1185
- else:
1186
- suffix = "pretrain"
1187
- trainer_args["max_steps"] = cfg.pretrain_steps
1188
-
1189
- print(f"Starting {suffix} phase")
1190
-
1191
- logger = TensorBoardLogger(join(log_dir, suffix), default_hp_metric=False)
1192
- callbacks = [
1193
- ModelCheckpoint(join(ckpt_dir, suffix), every_n_epochs=1),
1194
- LearningRateMonitor(logging_interval='step'),
1195
- ]
1196
- Trainer(logger=logger,
1197
- callbacks=callbacks,
1198
- **trainer_args).fit(
1199
- aligner,
1200
- datamodule=datamodule,
1201
- ckpt_path=maybe_get_ckpt(join(ckpt_dir, suffix)))
1202
-
1203
- train_chkpt = maybe_get_ckpt(join(ckpt_dir, "train"))
1204
-
1205
- gc.collect()
1206
- if torch.cuda.is_available():
1207
- torch.cuda.empty_cache()
1208
-
1209
- if cfg.pretrain_steps > 0 and train_chkpt is None:
1210
- print("---"*10)
1211
- print("Setup with full_train = False")
1212
- run_exp(aligner, full_train=False)
1213
- print("---"*10)
1214
- else:
1215
- print("---"*10)
1216
- print("Setup with full_train = False")
1217
- run_exp(aligner, full_train=True)
1218
- print("---"*10)
1219
-
1220
-
1221
- if __name__ == "__main__":
1222
- my_app()
 
1
+ import os
2
+ from collections import deque
3
+ from itertools import combinations
4
+ from os.path import join
5
+
6
+ import hydra
7
+ import numpy as np
8
+ import pytorch_lightning as pl
9
+ import torch
10
+ import torch.distributed as dist
11
+ import torch.nn.functional as F
12
+ from omegaconf import DictConfig, OmegaConf
13
+ from peft import get_peft_model, LoraConfig
14
+ from pytorch_lightning import Trainer
15
+ from pytorch_lightning import seed_everything
16
+ from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
17
+ from pytorch_lightning.loggers import TensorBoardLogger
18
+ from pytorch_lightning.utilities import grad_norm
19
+ from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, SequentialLR, LambdaLR
20
+ from torchmetrics.functional.classification import binary_average_precision
21
+
22
+ from huggingface_hub import PyTorchModelHubMixin
23
+
24
+ from DenseAV.denseav.aggregators import get_aggregator
25
+ from DenseAV.denseav.aligners import get_aligner, ProgressiveGrowing
26
+ from DenseAV.denseav.constants import *
27
+ from DenseAV.denseav.data.AVDatasets import AVDataModule
28
+ from DenseAV.denseav.shared import flatten_preds, GatherLayer, \
29
+ get_image_featurizer, get_audio_featurizer, RollingAvg, create_model_from_cfg
30
+
31
+ torch.multiprocessing.set_sharing_strategy('file_system')
32
+
33
+
34
+ def _imposter_indices_helper(true_indices: torch.Tensor, samples: torch.Tensor):
35
+ mask = (true_indices == samples).to(torch.int64)
36
+ n = mask.shape[0]
37
+
38
+ if not mask.any():
39
+ return samples
40
+ else:
41
+ new_samples = torch.randint(0, n, size=(n,), device=true_indices.device)
42
+ comb_samples = mask * new_samples + (1 - mask) * samples
43
+ return _imposter_indices_helper(true_indices, comb_samples)
44
+
45
+
46
+ def imposter_indices(n, device):
47
+ return _imposter_indices_helper(
48
+ torch.arange(0, n, device=device),
49
+ torch.randint(0, n, size=(n,), device=device))
50
+
51
+
52
+ def get_sim_per_row(image_outputs, audio_outputs, n_frames, sim_type):
53
+ max_t = audio_outputs.shape[-1]
54
+ oh = F.one_hot(n_frames - 1, num_classes=max_t)
55
+ audio_mask = 1 - torch.cumsum(oh, dim=1)
56
+ audio_mask = F.pad(audio_mask, [1, 0], value=1)[:, :max_t].to(audio_outputs.dtype)
57
+
58
+ full_sim = torch.einsum("bct,bchw->bthw", audio_outputs, image_outputs)
59
+ expanded_am = audio_mask.unsqueeze(-1).unsqueeze(-1)
60
+
61
+ if sim_type.endswith("mi"):
62
+ offset = 10 * (full_sim.max() - full_sim.min())
63
+ full_sim = (full_sim - ((1 - expanded_am) * offset)).max(1, keepdim=True).values
64
+
65
+ if sim_type.startswith("mi"):
66
+ full_sim = full_sim.max(-1, keepdim=True).values.max(-2, keepdim=True).values
67
+
68
+ if sim_type.endswith("sa"):
69
+ full_sim = (full_sim * (expanded_am / expanded_am.sum(1, keepdim=True).clamp_min(1))).sum(1, keepdim=True)
70
+
71
+ return full_sim.mean(dim=[1, 2, 3])
72
+
73
+
74
+ def sampled_margin_rank_loss(image_outputs, audio_outputs, n_frames, sim_type, margin=1.):
75
+ """
76
+ Computes the triplet margin ranking loss for each anchor image/caption pair
77
+ The impostor image/caption is randomly sampled from the minibatch
78
+ """
79
+ assert (image_outputs.dim() == 4)
80
+ assert (audio_outputs.dim() == 3)
81
+ n = image_outputs.size(0)
82
+ imp_ind_i = imposter_indices(n, image_outputs.device)
83
+ imp_ind_a = imposter_indices(n, image_outputs.device)
84
+ true_sim = get_sim_per_row(image_outputs, audio_outputs, n_frames, sim_type)
85
+ imp_sim_i = get_sim_per_row(image_outputs[imp_ind_i], audio_outputs, n_frames, sim_type)
86
+ imp_sim_a = get_sim_per_row(image_outputs, audio_outputs[imp_ind_a], n_frames[imp_ind_a], sim_type)
87
+ a2i_loss = (margin + imp_sim_i - true_sim).clamp_min(0)
88
+ i2a_loss = (margin + imp_sim_a - true_sim).clamp_min(0)
89
+ return (a2i_loss + i2a_loss).mean() / 2
90
+
91
+
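+ # Editor's note (illustrative, not part of the original file): the ranking loss above
+ # expects image features shaped (B, C, H, W), audio features shaped (B, C, T), and a
+ # per-clip frame count; a toy invocation might look like:
+ #
+ #     img = torch.randn(4, 512, 14, 14)
+ #     aud = torch.randn(4, 512, 100)
+ #     n_frames = torch.tensor([100, 80, 60, 100])
+ #     loss = sampled_margin_rank_loss(img, aud, n_frames, sim_type="misa")
+ #
+ # For each clip it samples one imposter image and one imposter audio from the batch and
+ # hinges the true similarity above both imposter similarities by `margin`.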
92
+ class SimilarityCalibrator(torch.nn.Module):
93
+
94
+ def __init__(self, cal_init, max_w=100, min_w=.01, subtract_mean=True, use_bias=False):
95
+ super().__init__()
96
+ self.max_w = max_w
97
+ self.min_w = min_w
98
+ self.w = torch.nn.Parameter(torch.tensor([cal_init]).log())
99
+
100
+ self.use_bias = use_bias
101
+ if self.use_bias:
102
+ self.b = torch.nn.Parameter(torch.tensor([0.0]))
103
+
104
+ self.subtract_mean = subtract_mean
105
+
106
+ def get_w(self):
107
+ return torch.exp(self.w).clamp_max(self.max_w).clamp_min(self.min_w)
108
+
109
+ def forward(self, x):
110
+ sims = self.get_w() * x
111
+
112
+ if self.use_bias:
113
+ sims = sims + self.b
114
+
115
+ if self.subtract_mean:
116
+ return sims - sims.mean()
117
+ else:
118
+ return sims
119
+
120
+
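+ # Editor's note (illustrative, not part of the original file): SimilarityCalibrator is a
+ # learnable temperature stored in log-space, so the effective scale exp(w) stays positive
+ # and is clamped to [min_w, max_w]; "nce" runs subtract the batch mean while other losses
+ # use a learnable bias instead. For example:
+ #
+ #     cal = SimilarityCalibrator(cal_init=7.0)
+ #     scaled = cal(torch.randn(8, 8))   # mean-centred similarities scaled by ~7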
121
+ class SpatialDropout(torch.nn.Module):
122
+
123
+ def __init__(self, p, *args, **kwargs):
124
+ super().__init__(*args, **kwargs)
125
+ self.p = p
126
+
127
+ def forward(self, x):
128
+ b, c, h, w = x.shape
129
+ dropout = torch.rand((b, 1, h, w), dtype=x.dtype, device=x.device) > self.p
130
+
131
+ if self.training:
132
+ return x * dropout
133
+ else:
134
+ return x
135
+
136
+
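+ # Editor's note (not part of the original file): unlike torch.nn.Dropout2d, which zeroes
+ # whole channels and rescales survivors by 1/(1-p), this SpatialDropout zeroes entire
+ # spatial positions (all channels at a given h, w), applies no rescaling, and is only
+ # active in training mode.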
137
+ class LitAVAligner(pl.LightningModule, PyTorchModelHubMixin, repo_url="https://github.com/mhamilton723/DenseAV", license="mit", tags=["denseav"]):
138
+ def __init__(self,
139
+ code_dim,
140
+ image_model_type,
141
+ image_model_token_type,
142
+ image_aligner_type,
143
+ image_pool_width,
144
+ audio_model_type,
145
+ audio_aligner_type,
146
+ audio_pool_width,
147
+ audio_lora,
148
+ audio_lora_rank,
149
+ image_lora,
150
+ image_lora_rank,
151
+ gradient_clipping,
152
+ learn_audio_cls,
153
+ silence_l1,
154
+ silence_l2,
155
+ tv_weight,
156
+ nonneg_sim,
157
+ nonneg_pressure,
158
+ pretrain_lr,
159
+ lr,
160
+ lr_warmup,
161
+ lr_schedule,
162
+ lr_cycle_length,
163
+ optimizer,
164
+ gather_tensors,
165
+ sim_agg_type,
166
+ sim_agg_heads,
167
+ sim_use_cls,
168
+ disentangle_weight,
169
+ norm_vectors,
170
+ cal_init,
171
+ cal_balance_weight,
172
+ loss_type,
173
+ loss_margin,
174
+ mask_silence,
175
+ finetune_image_model,
176
+ finetune_audio_model,
177
+ use_cached_embs,
178
+ output_root,
179
+ neg_audio,
180
+ neg_audio_weight,
181
+ head_agg,
182
+ adaptive_clipping,
183
+ specialization_weight,
184
+ spatial_dropout,
185
+ channel_dropout,
186
+ mixup_weight,
187
+ memory_buffer_size,
188
+ loss_leak,
189
+ ):
190
+ super().__init__()
191
+
192
+ self.code_dim = code_dim
193
+ self.image_model_type = image_model_type
194
+ self.image_model_token_type = image_model_token_type
195
+ self.image_aligner_type = image_aligner_type
196
+ self.image_pool_width = image_pool_width
197
+ self.audio_model_type = audio_model_type
198
+ self.audio_aligner_type = audio_aligner_type
199
+ self.audio_pool_width = audio_pool_width
200
+
201
+ self.gradient_clipping = gradient_clipping
202
+ self.learn_audio_cls = learn_audio_cls
203
+ self.silence_l1 = silence_l1
204
+ self.silence_l2 = silence_l2
205
+
206
+ self.tv_weight = tv_weight
207
+ self.nonneg_sim = nonneg_sim
208
+ self.nonneg_pressure = nonneg_pressure
209
+ self.pretrain_lr = pretrain_lr
210
+ self.lr = lr
211
+ self.lr_warmup = lr_warmup
212
+ self.lr_schedule = lr_schedule
213
+ self.lr_cycle_length = lr_cycle_length
214
+ self.optimizer = optimizer
215
+ self.gather_tensors = gather_tensors
216
+ self.sim_agg_type = sim_agg_type
217
+ self.sim_agg_heads = sim_agg_heads
218
+ self.sim_use_cls = sim_use_cls
219
+ self.disentangle_weight = disentangle_weight
220
+
221
+ self.norm_vectors = norm_vectors
222
+ self.cal_init = cal_init
223
+ self.cal_balance_weight = cal_balance_weight
224
+ self.loss_type = loss_type
225
+ self.loss_margin = loss_margin
226
+ self.mask_silence = mask_silence
227
+ self.finetune_image_model = finetune_image_model
228
+ self.finetune_audio_model = finetune_audio_model
229
+ self.use_cached_embs = use_cached_embs
230
+ self.output_root = output_root
231
+ self.audio_lora = audio_lora
232
+ self.audio_lora_rank = audio_lora_rank
233
+ self.image_lora = image_lora
234
+ self.image_lora_rank = image_lora_rank
235
+ self.neg_audio = neg_audio
236
+ self.neg_audio_weight = neg_audio_weight
237
+ self.head_agg = head_agg
238
+
239
+ self.adaptive_clipping = adaptive_clipping
240
+ self.specialization_weight = specialization_weight
241
+ self.spatial_dropout = spatial_dropout
242
+ self.channel_dropout = channel_dropout
243
+ self.mixup_weight = mixup_weight
244
+
245
+ self.memory_buffer_size = memory_buffer_size
246
+ self.memory_buffer = deque(maxlen=self.memory_buffer_size)
247
+ self.loss_leak = loss_leak
248
+
249
+ self.full_train = False  # Added by the uploader: start in the frozen-backbone (pretrain) mode until set_full_train() is called
250
+
251
+ if self.audio_model_type in {"audiomae", "audiomae-finetuned", "cavmae", "cavmae-mixed", "imagebind"}:
252
+ self.audio_input = "spec"
253
+ elif self.audio_model_type == "davenet":
254
+ self.audio_input = "davenet_spec"
255
+ elif self.audio_model_type == "fnac":
256
+ self.audio_input = "fnac_spec"
257
+ else:
258
+ self.audio_input = "audio"
259
+
260
+ extra_model_args = dict(output_root=output_root)
261
+
262
+ self.image_model, _, self.image_feat_dim = get_image_featurizer(
263
+ image_model_type, token_type=self.image_model_token_type, **extra_model_args)
264
+
265
+ self.image_model.eval()
266
+ if not self.finetune_image_model:
267
+ for param in self.image_model.parameters():
268
+ param.requires_grad = False
269
+
270
+ if image_model_type in {"cavmae", "cavmae-mixed", "imagebind", "fnac"}:
271
+ extra_model_args["model"] = self.image_model.model
272
+
273
+ if use_cached_embs:
274
+ _, self.audio_feat_dim = get_audio_featurizer(audio_model_type, **extra_model_args)
275
+ else:
276
+ self.audio_model, self.audio_feat_dim = get_audio_featurizer(audio_model_type, **extra_model_args)
277
+
278
+ self.audio_model.eval()
279
+ if not self.finetune_audio_model:
280
+ for param in self.audio_model.parameters():
281
+ param.requires_grad = False
282
+
283
+ if self.image_lora:
284
+ if self.image_model_type in {"sam", "dino8", "dinov2", "cavmae", "cavmae-mixed"}:
285
+ target_modules = ["qkv"]
286
+ elif self.image_model_type == "clip":
287
+ target_modules = ["out_proj"]
288
+ elif self.image_model_type == "imagebind":
289
+ target_modules = ["out_proj", "fc1", "fc2"]
290
+ else:
291
+ target_modules = ["q", "k", "v"]
292
+
293
+ peft_config = LoraConfig(
294
+ target_modules=target_modules,
295
+ inference_mode=False,
296
+ r=image_lora_rank,
297
+ lora_alpha=32,
298
+ lora_dropout=0.1
299
+ )
300
+ self.image_model = get_peft_model(self.image_model, peft_config)
301
+ self.image_model.print_trainable_parameters()
302
+
303
+ if self.audio_lora:
304
+ if self.audio_model_type == "hubert":
305
+ target_modules = ["q_proj", "k_proj", "v_proj"]
306
+ else:
307
+ target_modules = ["q", "k", "v"]
308
+
309
+ peft_config = LoraConfig(
310
+ inference_mode=False,
311
+ target_modules=target_modules,
312
+ r=audio_lora_rank,
313
+ lora_alpha=32,
314
+ lora_dropout=0.1
315
+ )
316
+ self.audio_model = get_peft_model(self.audio_model, peft_config)
317
+ self.audio_model.print_trainable_parameters()
318
+
319
+ shared_aligner_args = dict(out_dim=self.code_dim)
320
+
321
+ self.audio_aligner = get_aligner(
322
+ self.audio_aligner_type, self.audio_feat_dim, **shared_aligner_args)
323
+ self.image_aligner = get_aligner(
324
+ self.image_aligner_type, self.image_feat_dim, **shared_aligner_args)
325
+
326
+ if self.loss_type == "nce":
327
+ self.sim_cal = SimilarityCalibrator(self.cal_init, subtract_mean=True, use_bias=False)
328
+ else:
329
+ self.sim_cal = SimilarityCalibrator(self.cal_init, subtract_mean=False, use_bias=True)
330
+
331
+ if self.learn_audio_cls:
332
+ self.audio_cls = torch.nn.Parameter(torch.randn(self.audio_feat_dim))
333
+
334
+ if self.spatial_dropout > 0.0:
335
+ self.spatial_dropout_layer = SpatialDropout(self.spatial_dropout)
336
+
337
+ if self.channel_dropout > 0.0:
338
+ self.channel_dropout_layer = torch.nn.Dropout2d(self.channel_dropout)
339
+
340
+ self.sim_agg = get_aggregator(
341
+ self.sim_agg_type,
342
+ self.nonneg_sim,
343
+ self.mask_silence,
344
+ self.sim_agg_heads,
345
+ self.head_agg,
346
+ self.sim_use_cls,
347
+ dim=self.image_feat_dim
348
+ )
349
+
350
+ self.hparams_logged = False
351
+ self.rolling_avg = RollingAvg(50)
352
+ self.grad_avg = RollingAvg(50, nonzero=True)
353
+
354
+ self.save_hyperparameters()
355
+
356
+ def set_full_train(self, full_train):
357
+ self.full_train = full_train
358
+
359
+ def prep_feats(self, feats, is_audio):
360
+
361
+ if not is_audio and self.training and self.image_pool_width > 1:
362
+ feats = torch.nn.AvgPool2d(self.image_pool_width)(feats)
363
+
364
+ if is_audio and self.training and self.audio_pool_width > 1:
365
+ feats = torch.nn.AvgPool2d((1, self.audio_pool_width))(feats)
366
+
367
+ if self.norm_vectors:
368
+ feats = F.normalize(feats, dim=1)
369
+
370
+ return feats
371
+
372
+ def on_before_optimizer_step(self, optimizer, optimizer_idx):
373
+ norms = grad_norm(self, norm_type=2)
374
+ avg_grads = self.grad_avg.get_all()
375
+ params = {
376
+ f"grad_2.0_norm/{name}": p
377
+ for name, p in self.named_parameters()
378
+ if p.grad is not None
379
+ }
380
+
381
+ if self.adaptive_clipping:
382
+ for k in norms.keys():
383
+ if k in params:
384
+ avg_grad = max(avg_grads.get(k, norms[k]), 1e-5)
385
+ if self.global_step > 10 and norms[k] > avg_grad * 5:
386
+ print(f"Bad grad for {k}: {norms[k]} scaling to {avg_grad * 5}")
387
+ torch.nn.utils.clip_grad_norm_(params[k], avg_grad * 5)
388
+ norms[k] = avg_grad * 5
389
+
390
+ if norms[k] > self.gradient_clipping:
391
+ # print(f"Bad grad for {k}: {norms[k]} scaling to {self.gradient_clipping}")
392
+ torch.nn.utils.clip_grad_norm_(params[k], self.gradient_clipping)
393
+
394
+ # self.grad_avg.add_all(norms)
395
+ # self.log_dict(norms)
396
+
397
+ def interpolate_mask(self, mask, target_length, discrete):
398
+ b, t = mask.shape
399
+
400
+ mask = F.interpolate(mask.reshape(b, 1, 1, t), (1, target_length), mode="bilinear") \
401
+ .reshape(b, target_length)
402
+
403
+ if discrete:
404
+ mask = mask > 0.01
405
+ sums = mask.sum(1)
406
+ all_zeros = torch.where(sums == 0)[0]
407
+ if len(all_zeros) > 0:
408
+ print("Fixing a bad mask")
409
+ for entry in all_zeros:
410
+ mask[entry, torch.randint(0, target_length - 1, size=())] = True
411
+ else:
412
+ return mask
413
+ return mask
414
+
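+ # Editor's note (not part of the original file): the bilinear interpolation above
+ # resamples a per-clip mask from raw audio length to the audio-feature length; when
+ # `discrete` is set it is thresholded back to boolean, and any row that ends up
+ # all-False gets one random True entry so downstream mask-weighted averages never
+ # divide by zero.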
415
+ def forward_audio(self, batch):
416
+ if self.use_cached_embs:
417
+ audio_feats = batch["audio_emb"]
418
+ if "audio_cls" in batch:
419
+ audio_cls = batch["audio_cls"]
420
+ else:
421
+ audio_cls = None
422
+ else:
423
+ audio = batch[self.audio_input]
424
+
425
+ if self.full_train:
426
+ audio_feats, audio_cls = self.audio_model(audio, include_cls=True)
427
+ else:
428
+ with torch.no_grad():
429
+ audio_feats, audio_cls = self.audio_model(audio, include_cls=True)
430
+
431
+ mask = batch[AUDIO_MASK] if AUDIO_MASK in batch else torch.ones_like(audio)
432
+ pos_mask = batch[AUDIO_POS_MASK] if AUDIO_POS_MASK in batch else torch.ones_like(audio)
433
+
434
+ if self.learn_audio_cls:
435
+ assert audio_cls is None
436
+ audio_cls = torch.broadcast_to(self.audio_cls.unsqueeze(0), (audio_feats.shape[0], audio_feats.shape[1]))
437
+
438
+ aligned_audio_feats, aligned_audio_cls = self.audio_aligner(audio_feats, audio_cls)
439
+
440
+ if self.channel_dropout > 0.0:
441
+ aligned_audio_feats = self.channel_dropout_layer(aligned_audio_feats)
442
+
443
+ aligned_audio_feats = self.prep_feats(aligned_audio_feats, is_audio=True)
444
+ audio_mask = self.interpolate_mask(mask, aligned_audio_feats.shape[-1], True)
445
+ audio_pos_mask = self.interpolate_mask(pos_mask, aligned_audio_feats.shape[-1], False)
446
+
447
+ ret = {
448
+ AUDIO_MASK: audio_mask,
449
+ AUDIO_POS_MASK: audio_pos_mask,
450
+ AUDIO_FEATS: aligned_audio_feats,
451
+ }
452
+
453
+ if aligned_audio_cls is not None:
454
+ ret[AUDIO_CLS] = aligned_audio_cls
455
+
456
+ return ret
457
+
458
+ # @autocast(device_type="cuda", enabled=False)
459
+ def forward_image(self, batch, max_batch_size=None):
460
+
461
+ with torch.no_grad():
462
+ image = batch[IMAGE_INPUT]
463
+ b, nf, c, h, w = image.shape
464
+ image = image.reshape(b * nf, c, h, w)
465
+
466
+ if max_batch_size is None:
467
+ max_batch_size = image.shape[0]
468
+
469
+ chunks = [image[i:i + max_batch_size] for i in range(0, image.shape[0], max_batch_size)]
470
+
471
+ all_image_feats = []
472
+ all_image_cls = []
473
+
474
+ for chunk in chunks:
475
+ if self.full_train:
476
+ image_feats, image_cls = self.image_model(chunk, include_cls=True)
477
+ else:
478
+ with torch.no_grad():
479
+ image_feats, image_cls = self.image_model(chunk, include_cls=True)
480
+
481
+ aligned_image_feats, aligned_image_cls = self.image_aligner(image_feats, image_cls)
482
+
483
+ all_image_feats.append(aligned_image_feats)
484
+ all_image_cls.append(aligned_image_cls)
485
+
486
+ # Stitch the chunks back together
487
+ aligned_image_feats = torch.cat(all_image_feats, dim=0)
488
+ aligned_image_cls = torch.cat(all_image_cls, dim=0)
489
+
490
+ if self.channel_dropout > 0.0:
491
+ aligned_image_feats = self.channel_dropout_layer(aligned_image_feats)
492
+
493
+ if self.spatial_dropout > 0.0:
494
+ aligned_image_feats = self.spatial_dropout_layer(aligned_image_feats)
495
+
496
+ aligned_image_feats = self.prep_feats(aligned_image_feats, is_audio=False)
497
+ ret = {IMAGE_FEATS: aligned_image_feats}
498
+
499
+ if IMAGE_MASK in batch:
500
+ with torch.no_grad():
501
+ mask = batch[IMAGE_MASK]
502
+ mask = mask.reshape(b * nf, 1, h, w)
503
+ b, c, h, w = aligned_image_feats.shape
504
+ mask = F.adaptive_avg_pool2d(mask.to(aligned_image_feats), output_size=(h, w))
505
+ ret[IMAGE_MASK] = mask
506
+
507
+ if aligned_image_cls is not None:
508
+ ret[IMAGE_CLS] = aligned_image_cls
509
+
510
+ return ret
511
+
512
+ def forward(self, batch):
513
+ audio_feat_dict = self.forward_audio(batch)
514
+ image_feat_dict = self.forward_image(batch)
515
+ return {**image_feat_dict, **audio_feat_dict}
516
+
517
+ def contrast_loss(self, sims):
518
+ b = sims.shape[0]
519
+ sims = sims - torch.eye(b, b, device=sims.device) * self.loss_margin
520
+ sims_1 = sims
521
+ sims_2 = sims.permute(1, 0)
522
+
523
+ if self.loss_leak > 0.0:
524
+ id = torch.eye(sims_1.shape[0], sims_1.shape[1], device=sims.device, dtype=sims.dtype)
525
+ label_mask = id * (1 - self.loss_leak)
526
+ label_mask += (1 - id) * self.loss_leak / (sims_1.shape[0] - 1)
527
+ label_mask /= label_mask.sum(dim=1, keepdim=True)
528
+ else:
529
+ label_mask = torch.eye(sims_1.shape[0], sims_1.shape[1], device=sims.device, dtype=sims.dtype)
530
+
531
+ labels = torch.arange(0, sims.shape[0], device=sims.device)
532
+ self.rolling_avg.add(f"acc/1", (sims.argmax(dim=1) == labels).to(sims).mean())
533
+ self.rolling_avg.add(f"acc/2", (sims.argmax(dim=0) == labels).to(sims).mean())
534
+
535
+ if self.loss_type == "margin":
536
+ margin_loss_tensor = (sims - torch.diag(sims)).clamp_min(0)
537
+ margin_loss = margin_loss_tensor.mean()
538
+ self.rolling_avg.add(f"loss/frac_nonzero", (margin_loss_tensor > 0).to(sims).mean())
539
+ self.rolling_avg.add(f"loss/margin", margin_loss)
540
+ return margin_loss
541
+ elif self.loss_type == "ce":
542
+ ce_loss = 1 / 2 * F.cross_entropy(sims_1, labels) + \
543
+ 1 / 2 * F.cross_entropy(sims_2, labels)
544
+ self.rolling_avg.add(f"loss/ce", ce_loss)
545
+ return ce_loss
546
+ elif self.loss_type == "bce":
547
+ bce_loss = F.binary_cross_entropy_with_logits(sims_1.flatten(), label_mask.flatten())
548
+ self.rolling_avg.add(f"loss/bce", bce_loss)
549
+ return bce_loss
550
+ elif self.loss_type == "nce":
551
+ nce_loss = 1 / 2 * (-F.log_softmax(sims_1, dim=-1) * label_mask).sum(1).mean() + \
552
+ 1 / 2 * (-F.log_softmax(sims_2, dim=-1) * label_mask).sum(1).mean()
553
+ self.rolling_avg.add(f"loss/nce", nce_loss)
554
+ return nce_loss
555
+ else:
556
+ raise ValueError(f"Unknown loss type {self.loss_type}")
557
+
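+ # Editor's note (illustrative, not part of the original file): with loss_leak > 0 the
+ # "nce" and "bce" objectives use a label-smoothed target: each row of the label mask puts
+ # (1 - loss_leak) on the matching pair and loss_leak / (B - 1) on every other column,
+ # e.g. for B = 3 and loss_leak = 0.1 the target row is [0.9, 0.05, 0.05] (already summing
+ # to 1, so the renormalization is a no-op). The margin subtracted from the diagonal
+ # beforehand pushes positives to beat negatives by at least `loss_margin`.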
558
+ def loss(self, preds):
559
+ image_feats = preds[IMAGE_FEATS]
560
+ audio_feats = preds[AUDIO_FEATS]
561
+ audio_mask = preds[AUDIO_MASK]
562
+ image_mask = preds[IMAGE_MASK]
563
+ audio_pos_mask = preds[AUDIO_POS_MASK]
564
+ if DATA_SOURCE in preds:
565
+ source = preds[DATA_SOURCE].to(torch.int64)
566
+ else:
567
+ source = None
568
+
569
+ uncal_sims = self.sim_agg(preds, agg_heads=True)
570
+ sims = self.sim_cal(uncal_sims)
571
+
572
+ _mask = 1 - torch.eye(sims.shape[0], device=sims.device)
573
+ self.log(f"sim/pos", torch.diag(sims).mean())
574
+ self.log(f"sim/neg", (sims * _mask).sum() / (_mask.sum()))
575
+ self.log(f"sim/uncal_pos", torch.diag(uncal_sims).mean())
576
+ self.log(f"sim/uncal_neg", (uncal_sims * _mask).sum() / (_mask.sum()))
577
+
578
+ b, c, h, w = image_feats.shape
579
+ b, c, f, t = audio_feats.shape
580
+ n_samples = 250
581
+
582
+ nh = self.sim_agg_heads
583
+ image_feats_by_head = image_feats.reshape(b, self.sim_agg_heads, c // nh, h, w)
584
+ audio_feats_by_head = audio_feats.reshape(b, self.sim_agg_heads, c // nh, f, t)
585
+
586
+ def maybe_clamp(t):
587
+ return t.clamp_min(0) if self.nonneg_sim else t
588
+
589
+ paired_sim_raw = self.sim_agg.get_pairwise_sims(preds, raw=True, agg_sim=False, agg_heads=False)
590
+ paired_sim = maybe_clamp(paired_sim_raw)
591
+
592
+ loss = 0.0
593
+
594
+ if self.nonneg_pressure:
595
+ afb, afk, afc, aff, aft = audio_feats_by_head.shape
596
+ ifb, ifk, ifc, ifh, ifw = image_feats_by_head.shape
597
+ assert (afb == ifb)
598
+
599
+ device = audio_feats_by_head.device
600
+ random_b = torch.randint(0, afb, size=(n_samples,), device=device)
601
+ random_t = torch.randint(0, aft, size=(n_samples,), device=device)
602
+ random_f = torch.randint(0, aff, size=(n_samples,), device=device)
603
+ random_h = torch.randint(0, ifh, size=(n_samples,), device=device)
604
+ random_w = torch.randint(0, ifw, size=(n_samples,), device=device)
605
+
606
+ random_audio_feats = audio_feats_by_head[random_b, :, :, random_f, random_t]
607
+ random_image_feats = image_feats_by_head[random_b, :, :, random_h, random_w]
608
+ random_sim_raw = torch.einsum("bkc,dkc->bdk", random_audio_feats, random_image_feats)
609
+
610
+ nonneg_loss = random_sim_raw.clamp_max(0).square().mean()
611
+ self.rolling_avg.add(f"loss/nonneg", nonneg_loss)
612
+ loss += nonneg_loss * self.nonneg_pressure
613
+
614
+ if self.silence_l1 > 0 or self.silence_l2 > 0:
615
+ masked_b, masked_t = torch.where(~audio_mask)
616
+ if len(masked_b) > n_samples:
617
+ subset = torch.randperm(len(masked_b))[:n_samples]
618
+ masked_b = masked_b[subset]
619
+ masked_t = masked_t[subset]
620
+
621
+ if len(masked_b) == n_samples:
622
+ silent_audio_feats = audio_feats_by_head[masked_b, :, :, :, masked_t].mean(-1) # d k c
623
+ silence_tensor = maybe_clamp(
624
+ torch.einsum("bkchw,dkc->bkdhw", image_feats_by_head, silent_audio_feats))
625
+
626
+ silence_l1_loss = silence_tensor.abs().mean()
627
+ self.rolling_avg.add(f"loss/silence_l1", silence_l1_loss)
628
+ loss += silence_l1_loss * self.silence_l1
629
+
630
+ silence_l2_loss = silence_tensor.square().mean()
631
+ self.rolling_avg.add(f"loss/silence_l2", silence_l2_loss)
632
+ loss += silence_l2_loss * self.silence_l2
633
+ else:
634
+ pass
635
+
636
+ if self.neg_audio_weight > 0 and self.neg_audio:
637
+ b, t = audio_pos_mask.shape
638
+ negative_weight = ((1 - audio_pos_mask) * audio_mask.to(sims)).reshape(b, 1, 1, 1, 1, t)
639
+ negative_weight = torch.broadcast_to(negative_weight, paired_sim.shape)
640
+ if negative_weight.sum() > 0:
641
+ neg_audio_loss = (paired_sim.square() * negative_weight).sum() \
642
+ / negative_weight.sum().clamp_min(0.1)
643
+ self.rolling_avg.add(f"loss/neg_audio", neg_audio_loss)
644
+ self.rolling_avg.add(f"loss/neg_weight_avg", negative_weight.mean())
645
+ loss += neg_audio_loss * self.neg_audio_weight
646
+ else:
647
+ print("WARNING: No negative samples found in batch")
648
+
649
+ if self.tv_weight > 0:
650
+ tv_loss = (paired_sim[:, :, :, :, :, 1:] - paired_sim[:, :, :, :, :, :-1]).square().mean()
651
+ self.rolling_avg.add(f"loss/tv", tv_loss)
652
+ loss += tv_loss * self.tv_weight
653
+
654
+ self.log(f"cal/w", self.sim_cal.get_w())
655
+ if self.cal_balance_weight > 0.0:
656
+ cal_balance = (np.log(self.cal_init) - torch.log(self.sim_cal.get_w().clamp_min(.00000001))) \
657
+ .clamp_min(0).square().mean()
658
+ self.rolling_avg.add(f"loss/cal_balance", cal_balance)
659
+ loss += cal_balance * self.cal_balance_weight
660
+
661
+ if self.disentangle_weight > 0.0:
662
+ assert source is not None
663
+ assert self.sim_agg_heads % 2 == 0
664
+
665
+ dilation = self.sim_agg_heads // 2
666
+ sources_oh = F.one_hot(source, num_classes=2)
667
+ b, h = sources_oh.shape
668
+ sources_mask = 1 - torch.broadcast_to(sources_oh.unsqueeze(-1), (b, h, dilation)) \
669
+ .reshape(b, h * dilation).to(paired_sim)
670
+ disentangle_loss = torch.einsum("bkhwft,bk->bhwft", paired_sim, sources_mask).square().mean()
671
+ self.rolling_avg.add(f"loss/disentangle", disentangle_loss)
672
+ loss += disentangle_loss * self.disentangle_weight
673
+
674
+ if self.specialization_weight > 0.0 and self.sim_agg_heads > 1:
675
+ total_specialization_loss = 0.0
676
+ combos = list(combinations(range(self.sim_agg_heads), 2))
677
+ for i, j in combos:
678
+ specialization_loss_pair = (paired_sim[:, i].abs() * paired_sim[:, j].abs()).mean()
679
+ total_specialization_loss += specialization_loss_pair
680
+ avg_specialization_loss = total_specialization_loss / len(combos)
681
+ self.rolling_avg.add(f"loss/specialize", avg_specialization_loss)
682
+ loss += avg_specialization_loss * self.specialization_weight
683
+
684
+ if self.mixup_weight > 0.0:
685
+ b, _, h, w = image_mask.shape
686
+ neg_img_mask = torch.broadcast_to(
687
+ 1 - image_mask.to(paired_sim).reshape(b, 1, h, w, 1, 1),
688
+ paired_sim.shape)
689
+ image_mixup_loss = (paired_sim * neg_img_mask).square().sum() / neg_img_mask.sum().clamp_min(0.1)
690
+ self.rolling_avg.add(f"loss/image_mixup", image_mixup_loss)
691
+ loss += image_mixup_loss * self.mixup_weight
692
+
693
+ sims = sims
694
+ loss += self.contrast_loss(sims)
695
+ self.rolling_avg.add(f"loss/total", loss)
696
+
697
+ return loss
698
+
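+ # Editor's note (not part of the original file): the total objective above is the
+ # contrastive loss plus a set of optional regularizers (non-negativity pressure, silence
+ # suppression, negative-audio suppression, temporal TV smoothing, calibration balance,
+ # cross-dataset disentanglement, head specialization, and mixup-region suppression),
+ # each gated by its corresponding weight hyperparameter.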
699
+ def setup_hparams(self):
700
+ recalls = ['A_r1', 'A_r5', 'A_r10', 'I_r1', 'I_r5', 'I_r10']
701
+
702
+ if self.trainer.datamodule.use_extra_val_sets:
703
+ datasets = ["Places", "AudioSet"]
704
+ else:
705
+ datasets = ["Val"]
706
+
707
+ heads = ["total"]
708
+
709
+ metric_names = [
710
+ "hp/speech_basic_ap", "hp/speech_advanced_ap", "hp/sound_basic_ap",
711
+ "hp/speech_basic_iou", "hp/speech_advanced_iou", "hp/sound_basic_iou",
712
+ ]
713
+ for dataset in datasets:
714
+ for head in heads:
715
+ for recall in recalls:
716
+ metric_names.append(f"hp/{dataset}/{head}/{recall}")
717
+
718
+ if self.sim_agg_heads == 2:
719
+ metric_names.extend(["hp/ap_dis", "hp/act_dis"])
720
+
721
+ if hasattr(self.trainer, "datamodule"):
722
+ all_hparams = {**self.hparams, **self.trainer.datamodule.hparams}
723
+ else:
724
+ all_hparams = self.hparams
725
+
726
+ starting_values = {n: torch.nan for n in metric_names}
727
+ self.logger.log_hyperparams(all_hparams, starting_values)
728
+
729
+ def on_train_start(self):
730
+ self.setup_hparams()
731
+ self.hparams_logged = True
732
+
733
+ def on_train_batch_start(self, batch, batch_idx):
734
+ remake_optimizers = False
735
+
736
+ if isinstance(self.image_aligner, ProgressiveGrowing):
737
+ should_remake = self.image_aligner.maybe_change_phase(self.global_step)
738
+ remake_optimizers = remake_optimizers or should_remake
739
+ if isinstance(self.audio_aligner, ProgressiveGrowing):
740
+ should_remake = self.audio_aligner.maybe_change_phase(self.global_step)
741
+ remake_optimizers = remake_optimizers or should_remake
742
+
743
+ if remake_optimizers:
744
+ raise NotImplementedError()
745
+
746
+ def _combine_preds(self, all_preds):
747
+ temp = {}
748
+ new_preds = {}
749
+
750
+ # Collect tensors for each key into lists
751
+ for d in all_preds:
752
+ for key, value in d.items():
753
+ if isinstance(value, torch.Tensor):
754
+ if key not in temp:
755
+ temp[key] = []
756
+ temp[key].append(value)
757
+
758
+ # Concatenate all tensors for each key using a single call to torch.cat
759
+ for key, tensor_list in temp.items():
760
+ new_preds[key] = torch.cat(tensor_list)
761
+ return new_preds
762
+
763
+ def training_step(self, batch, batch_idx):
764
+ assert batch[IMAGE_INPUT].shape[1] == 1
765
+
766
+ preds = self.forward(batch)
767
+ if DATA_SOURCE in batch:
768
+ preds[DATA_SOURCE] = batch[DATA_SOURCE]
769
+
770
+ if self.trainer.world_size > 1 and self.gather_tensors:
771
+ for k, v in preds.items():
772
+ new_v = v.contiguous()
773
+ preds[k] = torch.cat(GatherLayer.apply(new_v), dim=0)
774
+
775
+ if self.memory_buffer_size > 0:
776
+ new_preds = self._combine_preds(list(self.memory_buffer) + [preds])
777
+ else:
778
+ new_preds = preds
779
+
780
+ loss = self.loss(new_preds)
781
+
782
+ if self.memory_buffer_size > 0:
783
+ self.memory_buffer.append(self._recursive_detach(preds, gather=False))
784
+
785
+ if self.trainer.is_global_zero and self.global_step % 50 == 1:
786
+ writer = self.logger.experiment
787
+ self.rolling_avg.logall(lambda k, v: writer.add_scalar(k, v, global_step=self.global_step))
788
+
789
+ if self.trainer.scaler is not None:
790
+ self.log("loss_scale", self.trainer.scaler.get_scale())
791
+
792
+ if self.global_step % 10000 == 0 and self.global_step > 0:
793
+ print("RESETTING TFEVENT FILE")
794
+ self.logger.experiment.close()
795
+ self.logger.experiment._get_file_writer()
796
+
797
+ return loss
798
+
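+ # Editor's note (not part of the original file): when `gather_tensors` is set and
+ # world_size > 1, the per-GPU prediction dicts are concatenated across devices via
+ # GatherLayer so the contrastive loss sees the full global batch; when
+ # `memory_buffer_size` > 0, detached predictions from recent steps are appended as
+ # additional rows of the similarity matrix before the loss is computed.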
799
+ def on_validation_start(self) -> None:
800
+ if not self.hparams_logged:
801
+ self.setup_hparams()
802
+ self.hparams_logged = True
803
+
804
+ def _auto_gather(self, t):
805
+ if t.dtype == torch.bool:
806
+ t = t.to(torch.float)
807
+
808
+ if self.trainer.num_devices == 1:
809
+ return t.cpu()
810
+
811
+ t = torch.clone(t).contiguous()
812
+ if self.trainer.is_global_zero:
813
+ gather_list = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
814
+ dist.gather(t, gather_list)
815
+ return torch.cat(gather_list, dim=0).cpu()
816
+ else:
817
+ dist.gather(t)
818
+
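+ # Editor's note (not part of the original file): on multi-GPU runs `_auto_gather` returns
+ # the concatenated tensor only on rank 0; every other rank participates in dist.gather
+ # and returns None, which is why the validation hooks below guard their work with
+ # `self.trainer.is_global_zero`.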
819
+ def validation_step(self, batch, batch_idx, dataloader_idx=0):
820
+
821
+ with torch.no_grad():
822
+ preds = self.forward(batch)
823
+
824
+ ret = {}
825
+ for k in preds.keys():
826
+ if k in preds:
827
+ ret[k] = self._auto_gather(preds[k])
828
+
829
+ batch_keys = [IMAGE_INPUT, "spec", "semseg", "num_pixels_per_class", 'total_length']
830
+ for k in batch_keys:
831
+ if k in batch:
832
+ ret[k] = self._auto_gather(batch[k])
833
+
834
+ if "metadata" in batch:
835
+ if isinstance(batch["metadata"]["id"], torch.Tensor):
836
+ ret["id"] = self._auto_gather(batch["metadata"]["id"])
837
+ ret["index"] = self._auto_gather(batch["metadata"]["index"])
838
+
839
+ return ret
840
+
841
+ def _calc_recalls(self, sim):
842
+ top_10_a = sim.topk(10, 0).indices == torch.arange(sim.shape[0]).unsqueeze(0)
843
+ top_10_i = (sim.topk(10, 1).indices == torch.arange(sim.shape[0]).unsqueeze(1)).permute(1, 0)
844
+ a_recall = lambda p: top_10_a[0:p].any(0).to(sim).mean()
845
+ i_recall = lambda p: top_10_i[0:p].any(0).to(sim).mean()
846
+ return {'A_r1': a_recall(1),
847
+ 'A_r5': a_recall(5),
848
+ 'A_r10': a_recall(10),
849
+ 'I_r1': i_recall(1),
850
+ 'I_r5': i_recall(5),
851
+ 'I_r10': i_recall(10)}
852
+
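+ # Editor's note (illustrative, not part of the original file): `sim` is the cross-modal
+ # similarity matrix for the validation set with true pairs on the diagonal, so recall@k
+ # checks whether index i lands in the top-k of row/column i (topk uses k=10, so at least
+ # 10 validation pairs are assumed). A toy check:
+ #
+ #     sim = torch.eye(12) + 0.01 * torch.randn(12, 12)
+ #     recs = self._calc_recalls(sim)   # expect recs['A_r1'] == recs['I_r1'] == 1.0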
853
+ def calc_recalls(self, preds, dataset):
854
+ sim = self.sim_agg.forward_batched(
855
+ preds=preds,
856
+ agg_heads=False,
857
+ batch_size=4,
858
+ ).cpu()
859
+
860
+ all_metrics = dict()
861
+ for k, v in self._calc_recalls(sim.sum(-1)).items():
862
+ all_metrics[f"hp/{dataset}/total/" + k] = v
863
+
864
+ return all_metrics
865
+
866
+ def retrieval_validation(self, outputs, dataset_name):
867
+ if len(outputs) == 0:
868
+ return
869
+
870
+ if self.trainer.is_global_zero:
871
+ results = flatten_preds(outputs)
872
+ if not self.trainer.sanity_checking:
873
+ print(results[IMAGE_FEATS].shape[0])
874
+ # assert (results[IMAGE_FEATS].shape[0] == 1000)
875
+ results[IMAGE_FEATS] = results[IMAGE_FEATS].cpu()
876
+ results[AUDIO_FEATS] = results[AUDIO_FEATS].cuda()
877
+ if self.sim_use_cls:
878
+ results[AUDIO_CLS] = results[AUDIO_CLS].cuda()
880
+
881
+ results[AUDIO_MASK] = results[AUDIO_MASK].cuda()
882
+
883
+ recalls = self.calc_recalls(results, dataset_name)
884
+
885
+ results[IMAGE_FEATS] = results[IMAGE_FEATS].cuda()
886
+
887
+ writer = self.logger.experiment
888
+ print("here")
889
+ for name, v in recalls.items():
890
+ writer.add_scalar(f"{name}", v, self.global_step + 1)
891
+
892
+ def semseg_validation(self, speech_preds, sound_preds):
893
+
894
+ if self.trainer.is_global_zero:
895
+ from DenseAV.denseav.eval_utils import get_paired_heatmaps
896
+ def prep_preds(preds, loader):
897
+ results = flatten_preds(preds)
898
+ metadata = loader.dataset.metadata
899
+ ordered_metadata = metadata.iloc[results["index"].numpy(), :].copy()
900
+ ordered_metadata["order"] = range(len(ordered_metadata))
901
+ return results, ordered_metadata
902
+
903
+ [_, _, speech_loader, sound_loader] = self.trainer.val_dataloaders
904
+ speech_results, speech_metadata = prep_preds(speech_preds, speech_loader)
905
+ sound_results, sound_metadata = prep_preds(sound_preds, sound_loader)
906
+
907
+ self.sound_metrics, unique_sound_indices = get_paired_heatmaps(
908
+ self, sound_results, sound_metadata["ade_class_id"], None)
909
+
910
+ self.speech_metrics, unique_word_indices = get_paired_heatmaps(
911
+ self, speech_results, speech_metadata["ade_class_id"], speech_metadata["timing"])
912
+
913
+ writer = self.logger.experiment
914
+
915
+ all_metrics = {
916
+ **{"sound_" + k: v for k, v in self.sound_metrics.items()},
917
+ **{"speech_" + k: v for k, v in self.speech_metrics.items()},
918
+ }
919
+
920
+ for k, v in all_metrics.items():
921
+ writer.add_scalar(f"hp/{k}", torch.tensor(v).mean(), self.global_step + 1)
922
+
923
+ def disentangle_validation(self, word_preds, sound_preds):
924
+
925
+ if len(word_preds) == 0 or len(sound_preds) == 0:
926
+ return
927
+
928
+ if self.trainer.is_global_zero:
929
+ word_preds = flatten_preds(word_preds)
930
+ sound_preds = flatten_preds(sound_preds)
931
+
932
+ word_scores = self.sim_agg.get_pairwise_sims(
933
+ word_preds,
934
+ raw=False,
935
+ agg_sim=True,
936
+ agg_heads=False,
937
+ )
938
+
939
+ sound_scores = self.sim_agg.get_pairwise_sims(
940
+ sound_preds,
941
+ raw=False,
942
+ agg_sim=True,
943
+ agg_heads=False,
944
+ )
945
+
946
+ all_scores = torch.cat([word_scores, sound_scores], dim=0)
947
+ all_scores -= all_scores.min(dim=0, keepdim=True).values
948
+ all_scores /= all_scores.max(dim=0, keepdim=True).values.clamp_min(.0001)
949
+
950
+ is_words = torch.cat([
951
+ torch.ones(word_scores.shape[0]),
952
+ torch.zeros(sound_scores.shape[0])], dim=0).to(torch.bool)
953
+
954
+ assert all_scores.shape[1] == 2
955
+ ap_matrix = torch.zeros(2, 2)
956
+ act_matrix = torch.zeros(2, 2)
957
+
958
+ for head in range(2):
959
+ # writer.add_histogram(f"h{head}_all_scores", all_scores[:, head])
960
+ for dataset_num in range(2):
961
+ if dataset_num == 0:
962
+ labels = is_words
963
+ else:
964
+ labels = ~is_words
965
+
966
+ ap_matrix[head, dataset_num] = binary_average_precision(
967
+ all_scores[:, head].cpu(), labels.to(torch.int64).cpu())
968
+
969
+ act_matrix[head, dataset_num] = 1 - (all_scores[:, head][labels]).mean()
970
+
971
+ ap_dis = max(.5 * (ap_matrix[0, 0] + ap_matrix[1, 1]),
972
+ .5 * (ap_matrix[0, 1] + ap_matrix[1, 0]))
973
+
974
+ act_dis = max(.5 * (act_matrix[0, 0] + act_matrix[1, 1]),
975
+ .5 * (act_matrix[0, 1] + act_matrix[1, 0]))
976
+
977
+ print("AP", ap_matrix)
978
+ print("AP dis", ap_dis)
979
+ print("Act", act_matrix)
980
+ print("Act dis", act_dis)
981
+
982
+ writer = self.logger.experiment
983
+ writer.add_scalar("hp/ap_dis", ap_dis, self.global_step + 1)
984
+ writer.add_scalar("hp/act_dis", act_dis, self.global_step + 1)
985
+
986
+ def validation_epoch_end(self, outputs) -> None:
987
+ print("Val end")
988
+ with torch.no_grad():
989
+ if self.trainer.datamodule.use_extra_val_sets:
990
+ if self.sim_agg_heads == 2:
991
+ self.disentangle_validation(outputs[0], outputs[1])
992
+ self.retrieval_validation(outputs[0], "Places")
993
+ self.retrieval_validation(outputs[1], "AudioSet")
994
+ self.semseg_validation(outputs[2], outputs[3])
995
+
996
+ else:
997
+ print("HERE!")
998
+ self.retrieval_validation(outputs, "Val")
999
+
1000
+ writer = self.logger.experiment
1001
+ writer.flush()
1002
+
1003
+ def _recursive_detach(self, obj, gather=True):
1004
+ if isinstance(obj, torch.Tensor):
1005
+ if gather:
1006
+ return self._auto_gather(obj)
1007
+ else:
1008
+ return obj.detach()
1009
+ elif isinstance(obj, dict):
1010
+ return {k: self._recursive_detach(v, gather) for k, v in obj.items()}
1011
+ elif isinstance(obj, list):
1012
+ return [self._recursive_detach(v, gather) for v in obj]
1013
+ else:
1014
+ return obj
1015
+
1016
+ def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
1017
+ with torch.no_grad():
1018
+ predictions = {}
1019
+ for k, v in batch.items():
1020
+ predictions[k] = self._recursive_detach(v)
1021
+ for k, v in self.forward(batch).items():
1022
+ predictions[k] = self._auto_gather(v)
1023
+
1024
+ return predictions
1025
+
1026
+ def _configure_optimizers(self, full_train, lr):
1027
+ params = [
1028
+ *self.audio_aligner.parameters(),
1029
+ *self.image_aligner.parameters(),
1030
+ *self.sim_cal.parameters(),
1031
+ *self.sim_agg.parameters()
1032
+ ]
1033
+
1034
+ if (self.finetune_image_model or self.image_lora) and full_train:
1035
+ params.extend(self.image_model.parameters())
1036
+
1037
+ if (self.finetune_audio_model or self.audio_lora) and full_train:
1038
+ params.extend(self.audio_model.parameters())
1039
+
1040
+ if self.learn_audio_cls:
1041
+ params.append(self.audio_cls)
1042
+
1043
+ last_epoch = self.global_step - 1
1044
+ if self.optimizer == "adam":
1045
+ opt = torch.optim.Adam(params, lr=lr, eps=1e-7)
1046
+ elif self.optimizer == "nadam":
1047
+ opt = torch.optim.NAdam(params, lr=lr, eps=1e-7)
1048
+ else:
1049
+ raise ValueError(f"Unknown optimizer {self.optimizer}")
1050
+
1051
+ if self.lr_schedule == "sgdr":
1052
+ scheduler = CosineAnnealingWarmRestarts(
1053
+ opt, self.lr_cycle_length, 2, eta_min=lr * 2e-2, last_epoch=last_epoch)
1054
+ else:
1055
+ scheduler = LambdaLR(opt, lr_lambda=lambda step: 1.0, last_epoch=last_epoch)
1056
+
1057
+ if self.lr_warmup > 0:
1058
+ warmup = LambdaLR(
1059
+ opt,
1060
+ lr_lambda=lambda step: min(max(float(step), 0.0) / self.lr_warmup, 1.0),
1061
+ last_epoch=last_epoch,
1062
+ )
1063
+ scheduler = SequentialLR(
1064
+ opt,
1065
+ schedulers=[warmup, scheduler],
1066
+ milestones=[self.lr_warmup],
1067
+ last_epoch=last_epoch)
1068
+
1069
+ scheduler = {"scheduler": scheduler, "interval": "step"}
1070
+
1071
+ return [opt], [scheduler]
1072
+
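+ # Editor's note (not part of the original file): only the aligners, calibrator, and
+ # aggregator are always optimized; backbone (or LoRA) parameters are added only when the
+ # corresponding finetune/LoRA flag is set and the model is in the full-train phase. The
+ # schedule is a linear warmup over `lr_warmup` steps chained via SequentialLR into either
+ # cosine annealing with warm restarts ("sgdr") or a constant rate, stepped once per
+ # optimizer step.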
1073
+ def configure_optimizers(self):
1074
+ if self.full_train:
1075
+ return self._configure_optimizers(self.full_train, self.lr)
1076
+ else:
1077
+ return self._configure_optimizers(self.full_train, self.pretrain_lr)
1078
+
1079
+
1080
+ @hydra.main(config_path="configs", config_name="av_align.yaml", version_base=None)
1081
+ def my_app(cfg: DictConfig) -> None:
1082
+ print(OmegaConf.to_yaml(cfg))
1083
+ seed_everything(cfg.seed, workers=True)
1084
+
1085
+ exp_name = f"{cfg.resume_prefix}"
1086
+
1087
+ if cfg.image_model_type == "dino8":
1088
+ patch_size = 8 * cfg.image_pool_width
1089
+ elif cfg.image_model_type == "cavmae":
1090
+ patch_size = 16 * cfg.image_pool_width
1091
+ elif cfg.image_model_type == "imagebind":
1092
+ patch_size = 16 * cfg.image_pool_width
1093
+ elif cfg.image_model_type == "clip":
1094
+ patch_size = 16 * cfg.image_pool_width
1095
+ elif cfg.image_model_type == "cavmae-mixed":
1096
+ patch_size = 16 * cfg.image_pool_width
1097
+ elif cfg.image_model_type == "dinov2":
1098
+ patch_size = 14 * cfg.image_pool_width
1099
+ else:
1100
+ raise ValueError(f"Unknown patch size for model {cfg.image_model_type}")
1101
+
1102
+ datamodule = AVDataModule(
1103
+ dataset_name=cfg.dataset_name,
1104
+ load_size=cfg.load_size,
1105
+ image_aug=cfg.image_aug,
1106
+ audio_aug=cfg.audio_aug,
1107
+ extra_audio_masking=cfg.extra_audio_masking,
1108
+ audio_model_type=cfg.audio_model_type,
1109
+ pytorch_data_dir=cfg.pytorch_data_dir,
1110
+ use_cached_embs=cfg.use_cached_embs,
1111
+ batch_size=cfg.batch_size,
1112
+ num_workers=cfg.num_workers,
1113
+ audio_level=cfg.audio_level,
1114
+ neg_audio=cfg.neg_audio,
1115
+ use_original_val_set=not cfg.use_extra_val_sets,
1116
+ use_extra_val_sets=cfg.use_extra_val_sets,
1117
+ data_for_plotting=False,
1118
+ quad_mixup=cfg.quad_mixup,
1119
+ bg_mixup=cfg.bg_mixup,
1120
+ patch_mixup=cfg.patch_mixup,
1121
+ patch_size=patch_size
1122
+ )
1123
+ datamodule.maybe_unpack(remove_source=cfg.submitting_to_aml)
1124
+
1125
+ aligner = create_model_from_cfg(LitAVAligner, cfg, {})
1126
+
1127
+ if cfg.starting_weights is not None:
1128
+ loaded = torch.load(join(cfg.output_root, cfg.starting_weights), map_location='cpu')
1129
+ state = loaded["state_dict"]
1130
+ aligner.load_state_dict(state, strict=cfg.load_strict)
1131
+ del state
1132
+ del loaded
1133
+
1134
+ if cfg.num_gpus > 1:
1135
+ # strategy = "ddp_sharded" # _find_unused_parameters_true"
1136
+ strategy = "ddp" # _find_unused_parameters_true"
1137
+ else:
1138
+ strategy = "auto"
1139
+
1140
+ if cfg.dataset_name in {"places-audio", "mixed", "audio-set", "mixed-full"}:
1141
+ val_args = dict(check_val_every_n_epoch=2)
1142
+ elif cfg.dataset_name in {"dolphin"}:
1143
+ val_args = dict(check_val_every_n_epoch=5)
1144
+ else:
1145
+ val_args = dict(val_check_interval=10000)
1146
+
1147
+ # val_args = dict(val_check_interval=1000)
1148
+
1149
+ def maybe_get_ckpt(ckpt_dir):
1150
+ if cfg.auto_resume and os.path.exists(ckpt_dir):
1151
+ print(f"Attempting to resume from {ckpt_dir}")
1152
+ candidates = os.listdir(ckpt_dir)
1153
+ assert (len(candidates) == 1)
1154
+ return join(ckpt_dir, candidates[0])
1155
+ elif cfg.auto_resume:
1156
+ print(f"Could not find checkpoint at {ckpt_dir}")
1157
+ return None
1158
+ else:
1159
+ return None
1160
+
1161
+ log_dir = join(cfg.output_root, "logs", cfg.grouping_name, exp_name)
1162
+ ckpt_dir = join(cfg.output_root, "checkpoints", cfg.grouping_name, exp_name)
1163
+
1164
+ import gc
1165
+ torch.cuda.empty_cache()
1166
+ gc.collect()
1167
+
1168
+ def run_exp(aligner, full_train):
1169
+ trainer_args = dict(
1170
+ accelerator='gpu',
1171
+ strategy=strategy,
1172
+ devices=cfg.num_gpus,
1173
+ num_sanity_val_steps=cfg.num_sanity_val_steps,
1174
+ log_every_n_steps=50,
1175
+ reload_dataloaders_every_n_epochs=10,
1176
+ precision="16",
1177
+ # profiler="simple",
1178
+ # precision="bf16",
1179
+ max_steps=cfg.max_steps,
1180
+ **val_args)
1181
+
1182
+ aligner.set_full_train(full_train)
1183
+ if full_train:
1184
+ suffix = "train"
1185
+ else:
1186
+ suffix = "pretrain"
1187
+ trainer_args["max_steps"] = cfg.pretrain_steps
1188
+
1189
+ print(f"Starting {suffix} phase")
1190
+
1191
+ logger = TensorBoardLogger(join(log_dir, suffix), default_hp_metric=False)
1192
+ callbacks = [
1193
+ ModelCheckpoint(join(ckpt_dir, suffix), every_n_epochs=1),
1194
+ LearningRateMonitor(logging_interval='step'),
1195
+ ]
1196
+ Trainer(logger=logger,
1197
+ callbacks=callbacks,
1198
+ **trainer_args).fit(
1199
+ aligner,
1200
+ datamodule=datamodule,
1201
+ ckpt_path=maybe_get_ckpt(join(ckpt_dir, suffix)))
1202
+
1203
+ train_chkpt = maybe_get_ckpt(join(ckpt_dir, "train"))
1204
+
1205
+ gc.collect()
1206
+ if torch.cuda.is_available():
1207
+ torch.cuda.empty_cache()
1208
+
1209
+ if cfg.pretrain_steps > 0 and train_chkpt is None:
1210
+ print("---"*10)
1211
+ print("Setup with full_train = False")
1212
+ run_exp(aligner, full_train=False)
1213
+ print("---"*10)
1214
+ else:
1215
+ print("---"*10)
1216
+ print("Setup with full_train = False")
1217
+ run_exp(aligner, full_train=True)
1218
+ print("---"*10)
1219
+
1220
+
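+ # Editor's note (not part of the original file): each run executes a single phase. If
+ # `pretrain_steps > 0` and no checkpoint exists under the "train" directory, the script
+ # runs the pretraining phase (full_train=False: frozen backbones, pretrain_lr,
+ # max_steps = pretrain_steps); otherwise it runs the full-training phase
+ # (full_train=True) with the main learning rate.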
1221
+ if __name__ == "__main__":
1222
+ my_app()